#!/usr/bin/env python3
"""
Simple MCP client that uses Ollama models for inference.
"""

import asyncio
import json
import sys
from typing import Any, Dict, List, Optional

import requests

from fastmcp import Client as FastMcpClient

class OllamaClient:
    """Client for interacting with the Ollama API."""

    def __init__(self, baseUrl: str = "http://localhost:11434", model: str = "gpt-oss:20b"):
        self.baseUrl = baseUrl
        self.model = model

    def listModels(self) -> List[str]:
        """List available Ollama models."""
        try:
            response = requests.get(f"{self.baseUrl}/api/tags", timeout=10)
            response.raise_for_status()
            data = response.json()
            return [model["name"] for model in data.get("models", [])]
        except requests.RequestException as e:
            print(f"Error listing models: {e}", file=sys.stderr)
            return []

    def chat(self, messages: List[Dict[str, str]], options: Optional[Dict[str, Any]] = None) -> str:
        """Send chat messages to Ollama and return the response text."""
        payload = {
            "model": self.model,
            "messages": messages,
            "stream": False,
        }

        if options:
            payload["options"] = options

        try:
            response = requests.post(
                f"{self.baseUrl}/api/chat",
                json=payload,
                timeout=60 * 60,  # one hour: local model inference can be very slow
            )
            response.raise_for_status()
            data = response.json()
            return data.get("message", {}).get("content", "")
        except requests.RequestException as e:
            print(f"Error in chat request: {e}", file=sys.stderr)
            raise

    def generate(self, prompt: str, options: Optional[Dict[str, Any]] = None) -> str:
        """Generate text from a prompt using Ollama."""
        payload = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
        }

        if options:
            payload["options"] = options

        try:
            response = requests.post(
                f"{self.baseUrl}/api/generate",
                json=payload,
                timeout=120,
            )
            response.raise_for_status()
            data = response.json()
            return data.get("response", "")
        except requests.RequestException as e:
            print(f"Error in generate request: {e}", file=sys.stderr)
            raise

    def checkHealth(self) -> bool:
        """Check if the Ollama server is accessible."""
        try:
            response = requests.get(f"{self.baseUrl}/api/tags", timeout=5)
            return response.status_code == 200
        except requests.RequestException:
            return False

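# A minimal usage sketch for OllamaClient (assumes an Ollama server on the
# default port with the requested model already pulled):
#
#     ollama = OllamaClient(model="gpt-oss:20b")
#     if ollama.checkHealth():
#         print(ollama.chat([{"role": "user", "content": "Say hello."}]))
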
class McpServerWrapper:
    """Wrapper around the FastMCP Client for easier use."""

    def __init__(self, httpUrl: str, headers: Optional[Dict[str, str]] = None):
        self.httpUrl = httpUrl.rstrip("/")
        self.headers = headers or {}
        self.client: Optional[FastMcpClient] = None
        self.serverTools: List[Dict[str, Any]] = []

    async def connect(self) -> bool:
        """Connect to and initialize the MCP server over HTTP."""
        try:
            # FastMcpClient doesn't support a headers parameter directly;
            # headers would need to be passed via a custom transport or auth.
            # For now, we initialize without them.
            self.client = FastMcpClient(self.httpUrl)
            await self.client.__aenter__()
            # Load and cache the tool list after connecting
            await self.listServerTools()
            return True
        except Exception as e:
            print(f"Error connecting to MCP server: {e}", file=sys.stderr)
            return False

    async def disconnect(self):
        """Disconnect from the MCP server."""
        if self.client:
            await self.client.__aexit__(None, None, None)
            self.client = None

    async def listServerTools(self) -> List[Dict[str, Any]]:
        """List tools available from the MCP server."""
        if not self.client:
            return []

        try:
            tools = await self.client.list_tools()
            self.serverTools = tools
            return tools
        except Exception as e:
            print(f"Error listing tools: {e}", file=sys.stderr)
            return []

    async def callServerTool(self, name: str, arguments: Dict[str, Any]) -> Any:
        """Call a tool on the MCP server."""
        if not self.client:
            raise RuntimeError("Not connected to MCP server")

        try:
            result = await self.client.call_tool(name, arguments)
            # FastMCP's call_tool returns a result object with .content
            if hasattr(result, "content"):
                return result.content
            if isinstance(result, list):
                if not result:
                    return result
                # Extract .content from each item where present
                contents = [getattr(item, "content", item) for item in result]
                return contents if len(contents) > 1 else contents[0]
            return result
        except Exception as e:
            raise RuntimeError(f"Tool call failed: {e}") from e

    async def listServerResources(self) -> List[Dict[str, Any]]:
        """List resources available from the MCP server."""
        if not self.client:
            return []

        try:
            return await self.client.list_resources()
        except Exception as e:
            print(f"Error listing resources: {e}", file=sys.stderr)
            return []

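# A minimal usage sketch for McpServerWrapper (assumes a streamable-HTTP MCP
# server at the given URL; the "echo" tool name is illustrative, real names
# come from listServerTools()):
#
#     async def demo():
#         server = McpServerWrapper("http://localhost:8000/mcp")
#         if await server.connect():
#             print(await server.listServerTools())
#             print(await server.callServerTool("echo", {"text": "hi"}))
#             await server.disconnect()
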
class OllamaMcpClient:
    """Simple MCP client that uses Ollama for inference."""

    def __init__(self, ollamaClient: OllamaClient, mcpServer: Optional[McpServerWrapper] = None):
        self.ollamaClient = ollamaClient
        self.mcpServer = mcpServer
        self.tools: List[Dict[str, Any]] = []
        self.resources: List[Dict[str, Any]] = []

    def _serializeToolResult(self, result: Any) -> Any:
        """Serialize a tool result to a JSON-serializable format."""
        # Handle TextContent and other content objects
        if hasattr(result, "text"):
            return result.text
        if hasattr(result, "content"):
            content = result.content
            if hasattr(content, "text"):
                return content.text
            return content
        # Handle lists of content objects
        if isinstance(result, list):
            return [self._serializeToolResult(item) for item in result]
        # Handle dicts
        if isinstance(result, dict):
            return {k: self._serializeToolResult(v) for k, v in result.items()}
        # Already serializable (str, int, float, bool, None)
        return result

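    # A quick sketch of what _serializeToolResult produces (FakeText stands in
    # for FastMCP's TextContent; the real class differs):
    #
    #     class FakeText:
    #         text = "hello"
    #
    #     self._serializeToolResult([FakeText(), {"n": 1}])  # -> ["hello", {"n": 1}]
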
    async def _loadServerTools(self):
        """Load tools from the connected MCP server."""
        if not self.mcpServer:
            return
        serverTools = await self.mcpServer.listServerTools()
        for tool in serverTools:
            # Handle both Pydantic Tool objects and plain dicts
            if hasattr(tool, "name"):
                # Pydantic Tool object: access attributes directly
                name = getattr(tool, "name", "")
                description = getattr(tool, "description", "")
                # Try both camelCase and snake_case for the input schema
                inputSchema = getattr(tool, "inputSchema", getattr(tool, "input_schema", {}))
            else:
                # Dict: use .get()
                name = tool.get("name", "")
                description = tool.get("description", "")
                inputSchema = tool.get("inputSchema", tool.get("input_schema", {}))

            self.tools.append({
                "name": name,
                "description": description,
                "inputSchema": inputSchema,
            })

    def registerTool(self, name: str, description: str, parameters: Dict[str, Any]):
        """Register a tool that can be used by the model."""
        self.tools.append({
            "name": name,
            "description": description,
            "inputSchema": {
                "type": "object",
                "properties": parameters,
                "required": list(parameters.keys()),
            },
        })

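    # registerTool only advertises a schema; the implementation is matched by
    # name in _executeTool. For example, registering "count_words" appends:
    #
    #     {"name": "count_words",
    #      "description": "Count words in a text",
    #      "inputSchema": {"type": "object",
    #                      "properties": {"text": {"type": "string", ...}},
    #                      "required": ["text"]}}
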
    async def processRequest(self, prompt: str, context: Optional[List[str]] = None, maxIterations: int = 5) -> str:
        """Process a request using Ollama with optional context and tool support."""
        messages = [
            {
                "role": "system",
                "content": """You are a Crypto Solver Agent specialized in CTF (Capture The Flag) challenges. Your primary goal is to identify, analyze, and solve cryptographic challenges stored in the /tmp directory in order to recover the flag.

OPERATING RULES:
Exploration: Always start by listing the files in /tmp. Identify relevant files such as Python sources (.py), text output (.txt), encrypted files, or public/private keys (.pem, .pub).
Analysis: Read the files you find and determine the type of cryptography involved. Common cases: RSA: analyze parameters such as n, e, c. Check whether n is small (factorable), whether e is low (e-th root attack), or whether known vulnerabilities apply (Wiener, Hastad, common moduli). Symmetric (AES/DES): look for the mode (ECB, CBC), IV vulnerabilities, or key reuse. XOR/classical ciphers: run frequency analysis or fixed-key attacks. Encoding: handle Base64, hex, and big-endian/little-endian correctly.
Execution: Write and run Python scripts to solve the challenge. Use libraries such as pycryptodome, gmpy2, or sympy if they are available in the environment. Do not just explain the theory: write the code needed to produce the plaintext.
Validation: Once the content is decrypted, look for strings in the flag{...} format. If the result is not readable, reconsider your approach and try an alternative strategy.

OUTPUT REQUIREMENTS: Give a brief explanation of the vulnerability you found. Show the Python code you generated to solve it. Return the final flag in a clearly visible way.

LIMITS: Operate exclusively inside the /tmp directory. Do not try to brute-force the system password; focus on the cryptographic logic. If data is missing (e.g. a file referenced in the code is not present), ask for it explicitly or look for it in the subdirectories of /tmp.

Begin now by analyzing the contents of /tmp."""
            }
        ]

        if context:
            # Join outside the f-string: backslashes inside f-string
            # expressions are a syntax error before Python 3.12
            contextBlock = "\n\n".join(context)
            messages.append({
                "role": "system",
                "content": f"Context:\n{contextBlock}"
            })

        if self.tools:
            toolDescriptions = json.dumps(self.tools, indent=2)
            messages.append({
                "role": "system",
                "content": f"Available tools:\n{toolDescriptions}\n\nTo use a tool, respond with JSON: {{\"tool_name\": \"name\", \"tool_args\": {{...}}}}"
            })

        messages.append({
            "role": "user",
            "content": prompt
        })

        iteration = 0
        while iteration < maxIterations:
            response = self.ollamaClient.chat(messages)

            # Check whether the response contains a tool call
            toolCall = self._parseToolCall(response)
            if toolCall:
                toolName = toolCall.get("tool_name")
                toolArgs = toolCall.get("tool_args", {})

                # Print the agent's intent (the response surrounding the tool call)
                print(f"\n[Agent Intent]: {response}", file=sys.stderr)
                print(f"[Tool Call Detected]: {toolName} with arguments: {toolArgs}", file=sys.stderr)

                # Try to call the tool
                try:
                    print(f"[Executing Tool]: {toolName} with arguments: {toolArgs}", file=sys.stderr)
                    toolResult = await self._executeTool(toolName, toolArgs)
                    # Serialize the tool result to a JSON-serializable format
                    serializedResult = self._serializeToolResult(toolResult)
                    print(f"[Tool Output]: {json.dumps(serializedResult, indent=2)}", file=sys.stderr)
                    messages.append({
                        "role": "assistant",
                        "content": response
                    })
                    messages.append({
                        "role": "user",
                        "content": f"Tool result: {json.dumps(serializedResult)}"
                    })
                    iteration += 1
                    continue
                except Exception as e:
                    print(f"[Tool Error]: {str(e)}", file=sys.stderr)
                    messages.append({
                        "role": "assistant",
                        "content": response
                    })
                    messages.append({
                        "role": "user",
                        "content": f"Tool error: {str(e)}"
                    })
                    iteration += 1
                    continue

            # No tool call; return the response
            print(f"\n[Agent Response (Final)]: {response}", file=sys.stderr)
            return response

        # Iteration budget exhausted; fall back to the last message content
        return messages[-1].get("content", "Max iterations reached")

    def _parseToolCall(self, response: str) -> Optional[Dict[str, Any]]:
        """Try to parse a tool call from a model response."""
        # Look for the outermost JSON object in the response
        try:
            startIdx = response.find("{")
            endIdx = response.rfind("}") + 1
            if startIdx >= 0 and endIdx > startIdx:
                jsonStr = response[startIdx:endIdx]
                parsed = json.loads(jsonStr)
                if "tool_name" in parsed:
                    return parsed
        except json.JSONDecodeError:
            pass
        return None

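    # The convention _parseToolCall expects: the model embeds a JSON object in
    # its reply (the tool name here is illustrative), e.g.
    #
    #     'Listing files now. {"tool_name": "list_files", "tool_args": {"path": "/tmp"}}'
    #
    # parses to {"tool_name": "list_files", "tool_args": {"path": "/tmp"}}.
    # A reply without a "tool_name" key is treated as a final answer.
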
    async def _executeTool(self, toolName: str, toolArgs: Dict[str, Any]) -> Any:
        """Execute a tool, either on the server or locally."""
        # First check whether it is a server tool
        if self.mcpServer:
            for tool in self.mcpServer.serverTools:
                # Handle both Pydantic Tool objects and plain dicts
                if hasattr(tool, "name"):
                    tool_name = getattr(tool, "name", None)
                elif isinstance(tool, dict):
                    tool_name = tool.get("name")
                else:
                    tool_name = None
                if tool_name == toolName:
                    return await self.mcpServer.callServerTool(toolName, toolArgs)

        # Fall back to local tools
        if toolName == "get_time":
            from datetime import datetime
            return datetime.now().isoformat()
        elif toolName == "count_words":
            text = toolArgs.get("text", "")
            return len(text.split())

        raise ValueError(f"Tool '{toolName}' not found")

    def listTools(self) -> List[Dict[str, Any]]:
        """List all registered tools."""
        return self.tools

    def listResources(self) -> List[Dict[str, Any]]:
        """List all available resources."""
        return self.resources

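# End-to-end sketch without an MCP server (assumes Ollama is running locally):
#
#     async def demo():
#         agent = OllamaMcpClient(OllamaClient(model="gpt-oss:20b"))
#         agent.registerTool("get_time", "Get the current time", {})
#         print(await agent.processRequest("What time is it?"))
#
#     asyncio.run(demo())
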
async def async_main(args, ollamaClient: OllamaClient):
    """Async main function."""
    # Connect to the MCP server if specified
    mcpServerWrapper = None
    if args.mcp_server:
        headers = {}
        if args.mcp_headers:
            try:
                headers = json.loads(args.mcp_headers)
            except json.JSONDecodeError:
                print("Warning: Invalid JSON in --mcp-headers, ignoring", file=sys.stderr)

        mcpServerWrapper = McpServerWrapper(httpUrl=args.mcp_server, headers=headers)
        if not await mcpServerWrapper.connect():
            print("Error: Failed to connect to MCP server", file=sys.stderr)
            sys.exit(1)
        print("Connected to MCP server via streamable HTTP", file=sys.stderr)

    # Initialize the MCP client
    mcpClient = OllamaMcpClient(ollamaClient, mcpServerWrapper)

    # Load server tools
    if mcpServerWrapper:
        await mcpClient._loadServerTools()
        serverTools = await mcpServerWrapper.listServerTools()
        if serverTools:
            # Handle both Pydantic Tool objects and plain dicts
            tool_names = [
                getattr(t, "name", "") if hasattr(t, "name")
                else t.get("name", "") if isinstance(t, dict)
                else ""
                for t in serverTools
            ]
            print(f"Available MCP server tools: {tool_names}", file=sys.stderr)

    # Register some example local tools
    mcpClient.registerTool(
        name="get_time",
        description="Get the current time",
        parameters={}
    )
    mcpClient.registerTool(
        name="count_words",
        description="Count words in a text",
        parameters={
            "text": {
                "type": "string",
                "description": "The text to count words in"
            }
        }
    )

    # Process a single prompt or run interactively
    if args.prompt:
        response = await mcpClient.processRequest(args.prompt)
        print(response)
    elif args.interactive:
        print("MCP Client with Ollama - Interactive Mode")
        print("Type 'quit' or 'exit' to exit\n")
        while True:
            try:
                prompt = input("You: ").strip()
                if prompt.lower() in ["quit", "exit"]:
                    break
                if not prompt:
                    continue
                response = await mcpClient.processRequest(prompt)
                print(f"Assistant: {response}\n")
            except KeyboardInterrupt:
                print("\nGoodbye!")
                break
            except Exception as e:
                print(f"Error: {e}", file=sys.stderr)

    # Cleanup
    if mcpServerWrapper:
        await mcpServerWrapper.disconnect()


def main():
    """Main function to run the MCP client."""
    import argparse

    parser = argparse.ArgumentParser(description="MCP client using Ollama")
    parser.add_argument(
        "--base-url",
        default="http://localhost:11434",
        help="Ollama base URL (default: http://localhost:11434)"
    )
    parser.add_argument(
        "--model",
        default="ministral-3",
        help="Ollama model to use (default: ministral-3)"
    )
    parser.add_argument(
        "--list-models",
        action="store_true",
        help="List available Ollama models and exit"
    )
    parser.add_argument(
        "--prompt",
        help="Prompt to send to the model"
    )
    parser.add_argument(
        "--interactive",
        "-i",
        action="store_true",
        help="Run in interactive mode"
    )
    parser.add_argument(
        "--mcp-server",
        help="HTTP URL for MCP server (e.g., 'http://localhost:8000/mcp')",
        default="http://localhost:8000/mcp"
    )
    parser.add_argument(
        "--mcp-headers",
        help="Additional headers for MCP server as JSON string (e.g., '{\"Authorization\": \"Bearer token\"}')"
    )

    args = parser.parse_args()

    # Initialize the Ollama client
    ollamaClient = OllamaClient(baseUrl=args.base_url, model=args.model)

    # Check health
    if not ollamaClient.checkHealth():
        print(f"Error: Cannot connect to Ollama at {args.base_url}", file=sys.stderr)
        print("Make sure Ollama is running and accessible.", file=sys.stderr)
        sys.exit(1)

    # List models if requested
    if args.list_models:
        models = ollamaClient.listModels()
        print("Available models:")
        for model in models:
            print(f"  - {model}")
        sys.exit(0)

    # Nothing to do without a prompt or interactive mode; check this
    # before running the async main, not after
    if not args.prompt and not args.interactive:
        parser.print_help()
        sys.exit(0)

    # Run the async main
    asyncio.run(async_main(args, ollamaClient))

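# Example invocations (the file name mcp_client.py is assumed; model and
# server URL are whatever your setup provides):
#
#     python3 mcp_client.py --list-models
#     python3 mcp_client.py --model gpt-oss:20b --prompt "Analyze /tmp"
#     python3 mcp_client.py -i --mcp-server http://localhost:8000/mcp
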
if __name__ == "__main__":
    main()