#!/usr/bin/env python3 """ MCP client that uses Ollama for inference and LangChain create_agent with runtime-registered MCP tools (see https://docs.langchain.com/oss/python/langchain/agents#runtime-tool-registration). """ import json import sys import os import asyncio from pathlib import Path from typing import Optional, Dict, Any, List, Callable, Awaitable import requests from fastmcp import Client as FastMcpClient from ollama import ResponseError as OllamaResponseError from pydantic import BaseModel, ConfigDict, Field, create_model # LangChain agent and middleware try: from langchain.agents import create_agent from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse, ToolCallRequest from langchain_core.tools import StructuredTool, tool from langchain_ollama import ChatOllama from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage except ImportError as e: print(f"Missing dependency: {e}. Install with: pip install langchain langgraph langchain-community langchain-ollama", file=sys.stderr) sys.exit(1) @tool def getTime() -> str: """Get the current time in ISO format.""" from datetime import datetime return datetime.now().isoformat() @tool def countWords(text: str) -> int: """Count the number of words in a text.""" return len(text.split()) def loadMcpConfig(configPath: Optional[str] = None) -> Dict[str, str]: """Load MCP server URLs from mcp.json. Returns dict serverName -> url.""" if configPath is None: # Default: mcpServer/mcp.json relative to project root or cwd base = Path(__file__).resolve().parent.parent configPath = str(base / "mcpServer" / "mcp.json") path = Path(configPath) if not path.exists(): return {} try: with open(path, "r", encoding="utf-8") as f: data = json.load(f) except (json.JSONDecodeError, OSError) as e: print(f"Warning: Could not load MCP config from {path}: {e}", file=sys.stderr) return {} servers = data.get("mcpServers") or data.get("mcp_servers") or {} return {name: info.get("url", "") for name, info in servers.items() if isinstance(info, dict) and info.get("url")} class GenericToolArgs(BaseModel): """Accept any keyword arguments for MCP tool calls (fallback when schema is missing).""" model_config = ConfigDict(extra="allow") def _jsonSchemaTypeToPython(jsonType: str) -> type: """Map JSON schema type to Python type.""" return {"string": str, "integer": int, "number": float, "boolean": bool, "array": list, "object": dict}.get(jsonType, str) def _defaultForJsonType(jsonType: str) -> Any: """Sensible default for optional MCP params so server does not receive null.""" return {"string": "", "integer": 0, "number": 0.0, "boolean": False, "array": [], "object": {}}.get(jsonType, "") def _defaultsFromInputSchema(inputSchema: Dict[str, Any]) -> Dict[str, Any]: """Build default values for all params so we never send null to the MCP server (LLM may omit required params).""" if not inputSchema: return {} properties = inputSchema.get("properties") or {} out: Dict[str, Any] = {} for name, spec in properties.items(): if not isinstance(spec, dict): continue if "default" in spec: out[name] = spec["default"] else: out[name] = _defaultForJsonType(spec.get("type", "string")) return out def buildArgsSchemaFromMcpInputSchema(toolName: str, inputSchema: Dict[str, Any]) -> type[BaseModel]: """Build a Pydantic model from MCP tool inputSchema so the LLM gets exact parameter names (path, content, etc.).""" if not inputSchema: return GenericToolArgs properties = inputSchema.get("properties") or {} required = set(inputSchema.get("required") or []) if not properties: return GenericToolArgs fields: Dict[str, Any] = {} for name, spec in properties.items(): if not isinstance(spec, dict): continue desc = spec.get("description", "") jsonType = spec.get("type", "string") pyType = _jsonSchemaTypeToPython(jsonType) if name in required: fields[name] = (pyType, Field(..., description=desc)) else: fields[name] = (Optional[pyType], Field(None, description=desc)) if not fields: return GenericToolArgs return create_model(f"McpArgs_{toolName}", **fields) class OllamaClient: """Client for interacting with Ollama API.""" def __init__(self, baseUrl: str = "http://localhost:11434", model: str = "gpt-oss:20b"): self.baseUrl = baseUrl self.model = model def listModels(self) -> List[str]: """List available Ollama models.""" try: response = requests.get(f"{self.baseUrl}/api/tags", timeout=10) response.raise_for_status() data = response.json() return [model["name"] for model in data.get("models", [])] except requests.RequestException as e: print(f"Error listing models: {e}", file=sys.stderr) return [] def chat(self, messages: List[Dict[str, str]], options: Optional[Dict[str, Any]] = None) -> str: """Send chat messages to Ollama and get response.""" payload = { "model": self.model, "messages": messages, "stream": False, } if options: payload["options"] = options try: response = requests.post( f"{self.baseUrl}/api/chat", json=payload, timeout=60*60*60 ) response.raise_for_status() data = response.json() return data.get("message", {}).get("content", "") except requests.RequestException as e: print(f"Error in chat request: {e}", file=sys.stderr) raise def generate(self, prompt: str, options: Optional[Dict[str, Any]] = None) -> str: """Generate text from a prompt using Ollama.""" payload = { "model": self.model, "prompt": prompt, "stream": False, } if options: payload["options"] = options try: response = requests.post( f"{self.baseUrl}/api/generate", json=payload, timeout=120 ) response.raise_for_status() data = response.json() return data.get("response", "") except requests.RequestException as e: print(f"Error in generate request: {e}", file=sys.stderr) raise def checkHealth(self) -> bool: """Check if Ollama server is accessible.""" try: response = requests.get(f"{self.baseUrl}/api/tags", timeout=5) return response.status_code == 200 except requests.RequestException: return False class McpServerWrapper: """Wrapper around FastMCP Client for easier use.""" def __init__(self, httpUrl: str, headers: Optional[Dict[str, str]] = None): self.httpUrl = httpUrl.rstrip("/") self.headers = headers or {} self.client: Optional[FastMcpClient] = None self.serverTools: List[Dict[str, Any]] = [] async def connect(self) -> bool: """Connect and initialize with MCP server via HTTP.""" try: # FastMcpClient doesn't support headers parameter directly # Headers would need to be passed via custom transport or auth # For now, we initialize without headers self.client = FastMcpClient(self.httpUrl) await self.client.__aenter__() # Load tools after connection tools = await self.listServerTools() return True except Exception as e: print(f"Error connecting to MCP server: {e}", file=sys.stderr) return False async def disconnect(self): """Disconnect from MCP server.""" if self.client: await self.client.__aexit__(None, None, None) self.client = None async def listServerTools(self) -> List[Dict[str, Any]]: """List tools available from MCP server.""" if not self.client: return [] try: tools = await self.client.list_tools() self.serverTools = tools return tools except Exception as e: print(f"Error listing tools: {e}", file=sys.stderr) return [] async def callServerTool(self, name: str, arguments: Dict[str, Any]) -> Any: """Call a tool on the MCP server.""" if not self.client: raise RuntimeError("Not connected to MCP server") try: result = await self.client.call_tool(name, arguments) # FastMCP call_tool returns a result object with .content if hasattr(result, 'content'): content = result.content # If content is a list, return it as is (will be serialized later) if isinstance(content, list): return content return content elif isinstance(result, list): # Handle list of results if len(result) > 0: # Extract content from each item if it exists contents = [] for item in result: if hasattr(item, 'content'): contents.append(item.content) else: contents.append(item) return contents if len(contents) > 1 else contents[0] if contents else None return result return result except Exception as e: raise RuntimeError(f"Tool call failed: {str(e)}") async def listServerResources(self) -> List[Dict[str, Any]]: """List resources available from MCP server.""" if not self.client: return [] try: resources = await self.client.list_resources() return resources except Exception as e: print(f"Error listing resources: {e}", file=sys.stderr) return [] def _serializeToolResult(result: Any) -> Any: """Serialize tool result to JSON-serializable format.""" if hasattr(result, "text"): return result.text if hasattr(result, "content"): content = result.content if hasattr(content, "text"): return content.text return content if isinstance(result, list): return [_serializeToolResult(item) for item in result] if isinstance(result, dict): return {k: _serializeToolResult(v) for k, v in result.items()} return result def _makeMcpToolCoroutine( toolName: str, server: McpServerWrapper, defaultArgs: Dict[str, Any], toolTimeout: Optional[float] = None, ) -> Callable[..., Awaitable[Any]]: async def _invoke(**kwargs: Any) -> Any: merged = {**defaultArgs, **kwargs} # Strip None values - MCP server Zod schemas often reject null for optional params (expect number | undefined, not number | null) merged = {k: v for k, v in merged.items() if v is not None} try: if toolTimeout is not None and toolTimeout > 0: result = await asyncio.wait_for( server.callServerTool(toolName, merged), timeout=toolTimeout, ) else: result = await server.callServerTool(toolName, merged) except asyncio.TimeoutError: return ( f"[Tool timeout] '{toolName}' exceeded {toolTimeout}s. " "The operation may have hung (e.g. command not found, subprocess blocking). " "Try an alternative (e.g. 'python' instead of 'python3') or increase --tool-timeout." ) return _serializeToolResult(result) return _invoke async def buildMcpLangChainTools( mcpServers: List[McpServerWrapper], toolTimeout: Optional[float] = None, ) -> List[StructuredTool]: """Build LangChain StructuredTools from connected MCP servers (runtime tool registration).""" tools: List[StructuredTool] = [] for server in mcpServers: rawTools = await server.listServerTools() for raw in rawTools: name = getattr(raw, "name", None) or (raw.get("name") if isinstance(raw, dict) else None) description = getattr(raw, "description", None) or (raw.get("description", "") if isinstance(raw, dict) else "") inputSchema = getattr(raw, "inputSchema", None) or getattr(raw, "input_schema", None) or (raw.get("inputSchema") or raw.get("input_schema") if isinstance(raw, dict) else None) if not name: continue description = description or f"MCP tool: {name}" schemaDict = inputSchema or {} argsSchema = buildArgsSchemaFromMcpInputSchema(name, schemaDict) defaultArgs = _defaultsFromInputSchema(schemaDict) tool = StructuredTool.from_function( name=name, description=description, args_schema=argsSchema, coroutine=_makeMcpToolCoroutine(name, server, defaultArgs, toolTimeout), ) tools.append(tool) return tools class LogToolCallsMiddleware(AgentMiddleware): """Middleware that logs every tool call (name and args).""" def wrap_tool_call(self, request: ToolCallRequest, handler: Callable): _logToolCallRequest(request) return handler(request) async def awrap_tool_call(self, request: ToolCallRequest, handler: Callable): _logToolCallRequest(request) return await handler(request) def _extractTextFromAIMessageContent(content: Any) -> str: """Extract plain text from AIMessage.content (str or list of content blocks).""" if content is None: return "" if isinstance(content, str): return content.strip() if isinstance(content, list): parts: List[str] = [] for block in content: if isinstance(block, dict) and "text" in block: parts.append(str(block["text"])) elif isinstance(block, str): parts.append(block) return "\n".join(parts).strip() if parts else "" return str(content).strip() def _extractFinalResponse(result: Dict[str, Any]) -> str: """Extract the final assistant text from agent result; handle recursion limit / no final message.""" messages = result.get("messages") or [] for msg in reversed(messages): if isinstance(msg, AIMessage) and hasattr(msg, "content"): text = _extractTextFromAIMessageContent(msg.content) if text: return text return ( "Agent stopped without a final text response (e.g. hit step limit after tool calls). " "Try again or increase --recursion-limit." ) def _logToolCallRequest(request: ToolCallRequest) -> None: tc = request.tool_call name = tc.get("name") if isinstance(tc, dict) else getattr(tc, "name", None) args = tc.get("args", tc.get("arguments", {})) if isinstance(tc, dict) else getattr(tc, "args", getattr(tc, "arguments", {})) argsStr = json.dumps(args, ensure_ascii=False) if len(argsStr) > 500: argsStr = argsStr[:497] + "..." print(f"[Tool Call] {name} args={argsStr}", file=sys.stderr) class McpToolsMiddleware(AgentMiddleware): """Middleware that adds MCP tools at runtime and handles their execution (runtime tool registration).""" def __init__(self, mcpTools: List[StructuredTool], staticToolNames: Optional[List[str]] = None): self.mcpTools = mcpTools self.mcpToolsByName = {t.name: t for t in mcpTools} staticNames = set(staticToolNames or []) self.validToolNames = staticNames | set(self.mcpToolsByName.keys()) def wrap_model_call(self, request: ModelRequest, handler: Callable) -> ModelResponse: updated = request.override(tools=[*request.tools, *self.mcpTools]) return handler(updated) async def awrap_model_call(self, request: ModelRequest, handler: Callable): updated = request.override(tools=[*request.tools, *self.mcpTools]) return await handler(updated) def _toolExists(self, name: Optional[str]) -> bool: return bool(name and name in self.validToolNames) def _unknownToolErrorToolMessage(self, request: ToolCallRequest, name: str) -> ToolMessage: available = ", ".join(sorted(self.validToolNames)) content = ( f"[Error] Tool '{name}' does not exist. " f"Only the following tools are available: {available}. " "Do not call tools that are not in this list." ) tc = request.tool_call toolCallId = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None) return ToolMessage( content=content, tool_call_id=toolCallId or "unknown", name=name or "unknown", status="error", ) def wrap_tool_call(self, request: ToolCallRequest, handler: Callable): name = request.tool_call.get("name") if isinstance(request.tool_call, dict) else getattr(request.tool_call, "name", None) if not self._toolExists(name): return self._unknownToolErrorToolMessage(request, name or "") if name and name in self.mcpToolsByName: return handler(request.override(tool=self.mcpToolsByName[name])) return handler(request) async def awrap_tool_call(self, request: ToolCallRequest, handler: Callable): name = request.tool_call.get("name") if isinstance(request.tool_call, dict) else getattr(request.tool_call, "name", None) if not self._toolExists(name): return self._unknownToolErrorToolMessage(request, name or "") if name and name in self.mcpToolsByName: return await handler(request.override(tool=self.mcpToolsByName[name])) return await handler(request) ''' TODO Use this if you want sequential thinking SYSTEM_PROMPT = """ ROLE: Sei un esperto Analista di Cybersecurity specializzato in CTF (Capture The Flag) e analisi di vulnerabilità. Operi in un ambiente Linux sandbox dove la tua unica area di lavoro è la directory /tmp. WORKSPACE CONSTRAINT: IL "SINGLE SOURCE OF TRUTH" - Obbligo Assoluto: Tutte le operazioni di lettura, scrittura, download e analisi devono avvenire esclusivamente all'interno di /tmp. - Percorsi: Ogni file deve essere referenziato con il percorso assoluto (es. /tmp/binary.bin). Non usare mai directory come ~/, /home o altre al di fuori di /tmp. - Condivisione: Ricorda che /tmp è montata su tutti i container MCP (fetch, filesystem, ecc.). Se scarichi un file con fetch in /tmp, il tool filesystem lo troverà immediatamente lì. TOOLSET & WORKFLOW: Utilizza i tuoi tool secondo questa logica: 1. sequentialthinking (Pianificazione): Usa questo tool PRIMA di ogni azione complessa. Suddividi la sfida in step logici (es. 1. Download, 2. Analisi Header, 3. Estrazione Flag). Ti aiuta a non perdere il filo durante task lunghi. 2. fetch (Ingestion): Usalo per recuperare binari, exploit o dati remoti. Salva l'output sempre in /tmp. 3. filesystem (Manipolazione): Usalo per ispezionare i file scaricati, creare script di exploit o leggere file di log e flag direttamente in /tmp. 4. memory (Stato): Utilizza questo tool per memorizzare scoperte chiave, indirizzi di memoria, offset o password trovate durante la sfida. Ti serve per mantenere il contesto tra diverse fasi del ragionamento. METODOLOGIA DI ANALISI: - Ipotesi e Test: Prima di agire, formula un'ipotesi basata sui dati presenti in /tmp. - Verifica Intermedia: Dopo ogni comando o modifica ai file, verifica il risultato usando il tool filesystem. Non dare mai per scontato che un'operazione sia riuscita senza controllare. - Pulizia Mentale: Se una strategia fallisce, usa sequentialthinking per rivedere il piano e aggiorna il tool memory con il motivo del fallimento per non ripetere lo stesso errore. REGOLE DI COMUNICAZIONE: - Sii estremamente tecnico, sintetico e preciso. - Se un file non è presente in /tmp, non provare a indovinarne il contenuto; usa fetch per ottenerlo o filesystem per cercarlo. - Rispondi con l'output delle tue analisi e l'eventuale flag trovata nel formato richiesto dalla sfida. """ ''' SYSTEM_PROMPT = "ROLE:\nSei un esperto Analista di Cybersecurity specializzato in CTF (Capture The Flag) e analisi di vulnerabilità. Operi in un ambiente Linux sandbox dove la tua unica area di lavoro è la directory /tmp.\n\nWORKSPACE CONSTRAINT: IL \"SINGLE SOURCE OF TRUTH\"\n- Obbligo Assoluto: Tutte le operazioni di lettura, scrittura e analisi devono avvenire esclusivamente all'interno di /tmp.\n- Percorsi: Ogni file deve essere referenziato con il percorso assoluto (es. /tmp/binary.bin). Non usare mai directory esterne a /tmp.\n- Condivisione: /tmp è montata su tutti i container MCP. I file creati o modificati da un tool sono immediatamente visibili agli altri.\n\nSTRETTO DIVIETO DI ALLUCINAZIONE TOOL:\n- USA ESCLUSIVAMENTE I TOOL MCP FORNITI: 'memory', 'filesystem'.\n\nREGOLE DI COMUNICAZIONE:\n- Sii estremamente tecnico, sintetico e preciso.\n- Non fare mai riferimento a tool che non siano 'memory' o 'filesystem'." class OllamaMcpClient: """MCP client that uses Ollama and LangChain create_agent with optional runtime MCP tools.""" def __init__( self, ollamaClient: OllamaClient, mcpTools: Optional[List[StructuredTool]] = None, systemPrompt: Optional[str] = None, ): self.ollamaClient = ollamaClient self.mcpTools = mcpTools or [] self.systemPrompt = systemPrompt or SYSTEM_PROMPT staticTools: List[Any] = [getTime, countWords] staticToolNames = [getTime.name, countWords.name] middleware: List[AgentMiddleware] = [LogToolCallsMiddleware()] if self.mcpTools: middleware.append(McpToolsMiddleware(self.mcpTools, staticToolNames=staticToolNames)) model = ChatOllama( base_url=ollamaClient.baseUrl, model=ollamaClient.model, temperature=0.1, ) self.agent = create_agent( model, tools=staticTools, middleware=middleware, system_prompt=self.systemPrompt, ) async def processRequest(self, prompt: str, context: Optional[List[str]] = None, recursionLimit: int = 50) -> str: """Process a request using the LangChain agent (ReAct loop with tools).""" messages: List[Any] = [HumanMessage(content=prompt)] if context: messages.insert(0, SystemMessage(content=f"Context:\n{chr(10).join(context)}")) config: Dict[str, Any] = {"recursion_limit": recursionLimit} toolParseRetryPrompt = ( "ATTENZIONE: Una chiamata write_file ha prodotto JSON non valido. " "Quando scrivi file con codice Python: usa \\n per le newline nel JSON, escapa le virgolette con \\. " "Non aggiungere parametri extra (es. overwrite). Usa edit_file per modifiche incrementali se il contenuto è lungo." ) try: result = await self.agent.ainvoke({"messages": messages}, config=config) except OllamaResponseError as e: errStr = str(e) if "error parsing tool call" in errStr: print(f"[Agent Error]: Tool call parse error, retrying with guidance: {errStr[:200]}...", file=sys.stderr) retryMessages: List[Any] = [SystemMessage(content=toolParseRetryPrompt)] retryMessages.extend(messages) result = await self.agent.ainvoke({"messages": retryMessages}, config=config) else: print(f"[Agent Error]: {e}", file=sys.stderr) raise except Exception as e: print(f"[Agent Error]: {e}", file=sys.stderr) raise return _extractFinalResponse(result) def listTools(self) -> List[str]: """List tool names (static + MCP).""" names = [getTime.name, countWords.name] names.extend(t.name for t in self.mcpTools) return names async def async_main(args, ollamaClient: OllamaClient): """Async main: MCP tools come only from mcp.json (Docker containers exposing SSE). Ollama is used only as LLM.""" mcpTools: List[StructuredTool] = [] mcpServers: List[McpServerWrapper] = [] # MCP servers from config file (mcp.json) – Docker containers with SSE endpoints serverUrls: Dict[str, str] = loadMcpConfig(args.mcp_config) if args.mcp_server: serverUrls["default"] = args.mcp_server.rstrip("/") # Which servers to use: default = all from mcp.json; or --mcp-tools fetch,filesystem to pick a subset wantServers = [s.strip() for s in (args.mcp_tools or "").split(",") if s.strip()] if not wantServers and serverUrls: wantServers = list(serverUrls.keys()) print(f"MCP tools from config (all SSE servers): {wantServers}", file=sys.stderr) for name in wantServers: url = serverUrls.get(name) if not url: print(f"Warning: MCP server '{name}' not in config (known: {list(serverUrls.keys())})", file=sys.stderr) continue wrapper = McpServerWrapper(httpUrl=url) if await wrapper.connect(): mcpServers.append(wrapper) print(f"Connected to MCP server '{name}' at {url}", file=sys.stderr) else: print(f"Error: Failed to connect to MCP server '{name}' at {url}", file=sys.stderr) if mcpServers: mcpTools = await buildMcpLangChainTools(mcpServers, toolTimeout=getattr(args, "tool_timeout", None)) #print(f"Loaded {len(mcpTools)} MCP tools: {[t.name for t in mcpTools]}", file=sys.stderr) mcpClient = OllamaMcpClient(ollamaClient, mcpTools=mcpTools) print(f"Agent tools: {mcpClient.listTools()}", file=sys.stderr) if args.prompt: response = await mcpClient.processRequest(args.prompt, recursionLimit=args.recursion_limit) print(response) elif args.interactive: print("MCP Client with Ollama (LangChain agent) - Interactive Mode") print("Type 'quit' or 'exit' to exit\n") while True: try: prompt = input("You: ").strip() if prompt.lower() in ["quit", "exit"]: break if not prompt: continue response = await mcpClient.processRequest(prompt, recursionLimit=args.recursion_limit) print(f"Assistant: {response}\n") except KeyboardInterrupt: print("\nGoodbye!") break except Exception as e: print(f"Error: {e}", file=sys.stderr) for wrapper in mcpServers: await wrapper.disconnect() def main(): """Main function to run the MCP client.""" import argparse parser = argparse.ArgumentParser(description="MCP client using Ollama") parser.add_argument( "--base-url", default="http://localhost:11434", help="Ollama base URL (default: http://localhost:11434)" ) parser.add_argument( "--model", default="gpt-oss:20b", help="Ollama model to use (default: ministral-3)" ) parser.add_argument( "--list-models", action="store_true", help="List available Ollama models and exit" ) parser.add_argument( "--prompt", help="Prompt to send to the model" ) parser.add_argument( "--interactive", "-i", action="store_true", help="Run in interactive mode" ) parser.add_argument( "--mcp-config", default=None, help="Path to mcp.json (default: mcpServer/mcp.json relative to project)" ) parser.add_argument( "--mcp-tools", default="", help="Comma-separated MCP server names from mcp.json (default: all servers in config). E.g. fetch,filesystem" ) parser.add_argument( "--mcp-server", help="Override: single MCP SSE URL (e.g. http://localhost:3000/sse). Added as server 'default' in addition to mcp.json." ) parser.add_argument( "--mcp-headers", help="Additional headers for MCP server as JSON string (e.g. '{\"Authorization\": \"Bearer token\"}')" ) parser.add_argument( "--recursion-limit", type=int, default=5000, help="Max agent steps (model + tool calls) before stopping (default: 50)" ) parser.add_argument( "--tool-timeout", type=float, default=60, help="Timeout in seconds for each MCP tool call. Prevents agent from freezing when a tool hangs (e.g. run with missing executable). Default: 60" ) args = parser.parse_args() # Initialize Ollama client ollamaClient = OllamaClient(baseUrl=args.base_url, model=args.model) # Check health if not ollamaClient.checkHealth(): print(f"Error: Cannot connect to Ollama at {args.base_url}", file=sys.stderr) print("Make sure Ollama is running and accessible.", file=sys.stderr) sys.exit(1) # List models if requested if args.list_models: models = ollamaClient.listModels() print("Available models:") for model in models: print(f" - {model}") sys.exit(0) # Run async main asyncio.run(async_main(args, ollamaClient)) if not args.prompt and not args.interactive: parser.print_help() if __name__ == "__main__": main()