From fe21c6b79026f3a553bc493eaf989fcfd7314881 Mon Sep 17 00:00:00 2001 From: Schrody Date: Thu, 12 Feb 2026 22:00:08 +0100 Subject: [PATCH] Modified client --- mcpClient/mcpClient.py | 664 ++++++++++++++++++++++++------------- mcpClient/requirements.txt | 8 +- mcpServer/requirements.txt | 1 - 3 files changed, 433 insertions(+), 240 deletions(-) delete mode 100644 mcpServer/requirements.txt diff --git a/mcpClient/mcpClient.py b/mcpClient/mcpClient.py index cbe4565..05f35da 100644 --- a/mcpClient/mcpClient.py +++ b/mcpClient/mcpClient.py @@ -1,18 +1,118 @@ #!/usr/bin/env python3 """ -Simple MCP client that uses Ollama models for inference. +MCP client that uses Ollama for inference and LangChain create_agent with +runtime-registered MCP tools (see https://docs.langchain.com/oss/python/langchain/agents#runtime-tool-registration). """ -from fastmcp.client.transports import NodeStdioTransport, PythonStdioTransport, SSETransport, StreamableHttpTransport - - import json import sys import os import asyncio -from typing import Optional, Dict, Any, List +from pathlib import Path +from typing import Optional, Dict, Any, List, Callable, Awaitable + import requests from fastmcp import Client as FastMcpClient +from ollama import ResponseError as OllamaResponseError +from pydantic import BaseModel, ConfigDict, Field, create_model + +# LangChain agent and middleware +try: + from langchain.agents import create_agent + from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse, ToolCallRequest + from langchain_core.tools import StructuredTool, tool + from langchain_ollama import ChatOllama + from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage +except ImportError as e: + print(f"Missing dependency: {e}. 
Install with: pip install langchain langgraph langchain-community langchain-ollama", file=sys.stderr) + sys.exit(1) + + +@tool +def getTime() -> str: + """Get the current time in ISO format.""" + from datetime import datetime + return datetime.now().isoformat() + + +@tool +def countWords(text: str) -> int: + """Count the number of words in a text.""" + return len(text.split()) + + +def loadMcpConfig(configPath: Optional[str] = None) -> Dict[str, str]: + """Load MCP server URLs from mcp.json. Returns dict serverName -> url.""" + if configPath is None: + # Default: mcpServer/mcp.json relative to project root or cwd + base = Path(__file__).resolve().parent.parent + configPath = str(base / "mcpServer" / "mcp.json") + path = Path(configPath) + if not path.exists(): + return {} + try: + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + except (json.JSONDecodeError, OSError) as e: + print(f"Warning: Could not load MCP config from {path}: {e}", file=sys.stderr) + return {} + servers = data.get("mcpServers") or data.get("mcp_servers") or {} + return {name: info.get("url", "") for name, info in servers.items() if isinstance(info, dict) and info.get("url")} + + +class GenericToolArgs(BaseModel): + """Accept any keyword arguments for MCP tool calls (fallback when schema is missing).""" + model_config = ConfigDict(extra="allow") + + +def _jsonSchemaTypeToPython(jsonType: str) -> type: + """Map JSON schema type to Python type.""" + return {"string": str, "integer": int, "number": float, "boolean": bool, "array": list, "object": dict}.get(jsonType, str) + + +def _defaultForJsonType(jsonType: str) -> Any: + """Sensible default for optional MCP params so server does not receive null.""" + return {"string": "", "integer": 0, "number": 0.0, "boolean": False, "array": [], "object": {}}.get(jsonType, "") + + +def _defaultsFromInputSchema(inputSchema: Dict[str, Any]) -> Dict[str, Any]: + """Build default values for all params so we never send null to the MCP server 
(LLM may omit required params).""" + if not inputSchema: + return {} + properties = inputSchema.get("properties") or {} + out: Dict[str, Any] = {} + for name, spec in properties.items(): + if not isinstance(spec, dict): + continue + if "default" in spec: + out[name] = spec["default"] + else: + out[name] = _defaultForJsonType(spec.get("type", "string")) + return out + + +def buildArgsSchemaFromMcpInputSchema(toolName: str, inputSchema: Dict[str, Any]) -> type[BaseModel]: + """Build a Pydantic model from MCP tool inputSchema so the LLM gets exact parameter names (path, content, etc.).""" + if not inputSchema: + return GenericToolArgs + properties = inputSchema.get("properties") or {} + required = set(inputSchema.get("required") or []) + if not properties: + return GenericToolArgs + fields: Dict[str, Any] = {} + for name, spec in properties.items(): + if not isinstance(spec, dict): + continue + desc = spec.get("description", "") + jsonType = spec.get("type", "string") + pyType = _jsonSchemaTypeToPython(jsonType) + if name in required: + fields[name] = (pyType, Field(..., description=desc)) + else: + fields[name] = (Optional[pyType], Field(None, description=desc)) + if not fields: + return GenericToolArgs + return create_model(f"McpArgs_{toolName}", **fields) class OllamaClient: @@ -48,7 +148,7 @@ class OllamaClient: response = requests.post( f"{self.baseUrl}/api/chat", json=payload, - timeout=60*60 + timeout=60*60*60 ) response.raise_for_status() data = response.json() @@ -176,249 +276,317 @@ class McpServerWrapper: return [] -class OllamaMcpClient: - """Simple MCP client that uses Ollama for inference.""" +def _serializeToolResult(result: Any) -> Any: + """Serialize tool result to JSON-serializable format.""" + if hasattr(result, "text"): + return result.text + if hasattr(result, "content"): + content = result.content + if hasattr(content, "text"): + return content.text + return content + if isinstance(result, list): + return [_serializeToolResult(item) for item in 
result] + if isinstance(result, dict): + return {k: _serializeToolResult(v) for k, v in result.items()} + return result - def __init__(self, ollamaClient: OllamaClient, mcpServer: Optional[McpServerWrapper] = None): - self.ollamaClient = ollamaClient - self.mcpServer = mcpServer - self.tools: List[Dict[str, Any]] = [] - self.resources: List[Dict[str, Any]] = [] - def _serializeToolResult(self, result: Any) -> Any: - """Serialize tool result to JSON-serializable format.""" - # Handle TextContent and other content objects - if hasattr(result, 'text'): - return result.text - if hasattr(result, 'content'): - content = result.content - if hasattr(content, 'text'): - return content.text - return content - # Handle lists of content objects - if isinstance(result, list): - return [self._serializeToolResult(item) for item in result] - # Handle dicts - if isinstance(result, dict): - return {k: self._serializeToolResult(v) for k, v in result.items()} - # Already serializable (str, int, float, bool, None) - return result - - async def _loadServerTools(self): - """Load tools from connected MCP server.""" - if self.mcpServer: - serverTools = await self.mcpServer.listServerTools() - for tool in serverTools: - # Handle both Pydantic Tool objects and dicts - if hasattr(tool, "name"): - # Pydantic Tool object - access attributes directly - name = getattr(tool, "name", "") - description = getattr(tool, "description", "") - # Try both camelCase and snake_case for inputSchema - inputSchema = getattr(tool, "inputSchema", getattr(tool, "input_schema", {})) - else: - # Dict - use .get() - name = tool.get("name", "") - description = tool.get("description", "") - inputSchema = tool.get("inputSchema", tool.get("input_schema", {})) - - self.tools.append({ - "name": name, - "description": description, - "inputSchema": inputSchema - }) - - def registerTool(self, name: str, description: str, parameters: Dict[str, Any]): - """Register a tool that can be used by the model.""" - self.tools.append({ 
- "name": name, - "description": description, - "inputSchema": { - "type": "object", - "properties": parameters, - "required": list(parameters.keys()) - } - }) - - async def processRequest(self, prompt: str, context: Optional[List[str]] = None, maxIterations: int = 5) -> str: - """Process a request using Ollama with optional context and tool support.""" - messages = [ - { - "role": "system", - "content": """Sei un Crypto Solver Agent specializzato in sfide CTF (Capture The Flag). Il tuo obiettivo primario è identificare, analizzare e risolvere sfide crittografiche memorizzate nella directory /tmp per recuperare la flag. REGOLE OPERATIVE: Esplorazione: Inizia sempre elencando i file presenti in /tmp. Identifica i file rilevanti come sorgenti Python (.py), output di testo (.txt), file cifrati o chiavi pubbliche/private (.pem, .pub). Analisi: Leggi i file trovati. Determina il tipo di crittografia coinvolta. Casi comuni: RSA: analizza parametri come n, e, c. Verifica se n è piccolo (fattorizzabile), se e è basso (attacco radice e-esima) o se ci sono vulnerabilità note (Wiener, Hastad, moduli comuni). Simmetrica (AES/DES): cerca la modalità (ECB, CBC), vulnerabilità nel IV, o riutilizzo della chiave. XOR/Cifrari Classici: esegui analisi delle frequenze o attacchi a chiave fissa. Encoding: gestisci correttamente Base64, Hex, Big-Endian/Little-Endian. Esecuzione: Scrivi ed esegui script Python per risolvere la sfida. Utilizza librerie come pycryptodome, gmpy2 o sympy se disponibili nell'ambiente. Non limitarti a spiegare la teoria: scrivi il codice necessario a produrre il plaintext. Validazione: Una volta decifrato il contenuto, cerca stringhe nel formato flag{...}. Se il risultato non è leggibile, rivaluta l'approccio e prova una strategia alternativa. REQUISITI DI OUTPUT: Fornisci una breve spiegazione della vulnerabilità trovata. Mostra il codice Python risolutivo che hai generato. Restituisci la flag finale in modo chiaramente visibile. 
LIMITI: Opera esclusivamente all'interno della directory /tmp. Non tentare di forzare la password di sistema; concentrati sulla logica crittografica. Se mancano dati (es. un file citato nel codice non è presente), chiedi esplicitamente o cercalo nelle sottocartelle di /tmp. Inizia ora analizzando il contenuto di /tmp.""" - } - ] - - if context: - messages.append({ - "role": "system", - "content": f"Context:\n{'\n\n'.join(context)}" - }) - - if self.tools: - toolDescriptions = json.dumps(self.tools, indent=2) - messages.append({ - "role": "system", - "content": f"Available tools:\n{toolDescriptions}\n\nTo use a tool, respond with JSON: {{\"tool_name\": \"name\", \"tool_args\": {{...}}}}" - }) - - messages.append({ - "role": "user", - "content": prompt - }) - - iteration = 0 - while iteration < maxIterations: - response = self.ollamaClient.chat(messages) - - # Check if response contains tool call - toolCall = self._parseToolCall(response) - if toolCall: - toolName = toolCall.get("tool_name") - toolArgs = toolCall.get("tool_args", {}) - - # Print agent intent (response before tool call) - print(f"\n[Agent Intent]: {response}", file=sys.stderr) - print(f"[Tool Call Detected]: {toolName} with arguments: {toolArgs}", file=sys.stderr) - - # Try to call the tool - try: - print(f"[Executing Tool]: {toolName} with arguments: {toolArgs}", file=sys.stderr) - toolResult = await self._executeTool(toolName, toolArgs) - # Serialize tool result to JSON-serializable format - serializedResult = self._serializeToolResult(toolResult) - print(f"[Tool Output]: {json.dumps(serializedResult, indent=2)}", file=sys.stderr) - messages.append({ - "role": "assistant", - "content": response - }) - messages.append({ - "role": "user", - "content": f"Tool result: {json.dumps(serializedResult)}" - }) - iteration += 1 - continue - except Exception as e: - print(f"[Tool Error]: {str(e)}", file=sys.stderr) - messages.append({ - "role": "assistant", - "content": response - }) - messages.append({ - 
"role": "user", - "content": f"Tool error: {str(e)}" - }) - iteration += 1 - continue - - # No tool call, return response - print(f"\n[Agent Response (Final)]: {response}", file=sys.stderr) - return response - - return messages[-1].get("content", "Max iterations reached") - - def _parseToolCall(self, response: str) -> Optional[Dict[str, Any]]: - """Try to parse tool call from response.""" - # Try to find JSON object in response +def _makeMcpToolCoroutine( + toolName: str, + server: McpServerWrapper, + defaultArgs: Dict[str, Any], + toolTimeout: Optional[float] = None, +) -> Callable[..., Awaitable[Any]]: + async def _invoke(**kwargs: Any) -> Any: + merged = {**defaultArgs, **kwargs} + # Strip None values - MCP server Zod schemas often reject null for optional params (expect number | undefined, not number | null) + merged = {k: v for k, v in merged.items() if v is not None} try: - # Look for JSON in response - startIdx = response.find("{") - endIdx = response.rfind("}") + 1 - if startIdx >= 0 and endIdx > startIdx: - jsonStr = response[startIdx:endIdx] - parsed = json.loads(jsonStr) - if "tool_name" in parsed: - return parsed - except: - pass - return None + if toolTimeout is not None and toolTimeout > 0: + result = await asyncio.wait_for( + server.callServerTool(toolName, merged), + timeout=toolTimeout, + ) + else: + result = await server.callServerTool(toolName, merged) + except asyncio.TimeoutError: + return ( + f"[Tool timeout] '{toolName}' exceeded {toolTimeout}s. " + "The operation may have hung (e.g. command not found, subprocess blocking). " + "Try an alternative (e.g. 'python' instead of 'python3') or increase --tool-timeout." 
+ ) + return _serializeToolResult(result) + return _invoke - async def _executeTool(self, toolName: str, toolArgs: Dict[str, Any]) -> Any: - """Execute a tool - either from server or local.""" - # First check if it's a server tool - if self.mcpServer: - # Check if tool exists in server tools - for tool in self.mcpServer.serverTools: - # Handle both Pydantic Tool objects and dicts - tool_name = getattr(tool, "name", None) if hasattr(tool, "name") else tool.get("name") if isinstance(tool, dict) else None - if tool_name == toolName: - return await self.mcpServer.callServerTool(toolName, toolArgs) - # Check local tools - if toolName == "get_time": - from datetime import datetime - return datetime.now().isoformat() - elif toolName == "count_words": - text = toolArgs.get("text", "") - return len(text.split()) +async def buildMcpLangChainTools( + mcpServers: List[McpServerWrapper], + toolTimeout: Optional[float] = None, +) -> List[StructuredTool]: + """Build LangChain StructuredTools from connected MCP servers (runtime tool registration).""" + tools: List[StructuredTool] = [] + for server in mcpServers: + rawTools = await server.listServerTools() + for raw in rawTools: + name = getattr(raw, "name", None) or (raw.get("name") if isinstance(raw, dict) else None) + description = getattr(raw, "description", None) or (raw.get("description", "") if isinstance(raw, dict) else "") + inputSchema = getattr(raw, "inputSchema", None) or getattr(raw, "input_schema", None) or (raw.get("inputSchema") or raw.get("input_schema") if isinstance(raw, dict) else None) + if not name: + continue + description = description or f"MCP tool: {name}" + schemaDict = inputSchema or {} + argsSchema = buildArgsSchemaFromMcpInputSchema(name, schemaDict) + defaultArgs = _defaultsFromInputSchema(schemaDict) + tool = StructuredTool.from_function( + name=name, + description=description, + args_schema=argsSchema, + coroutine=_makeMcpToolCoroutine(name, server, defaultArgs, toolTimeout), + ) + 
tools.append(tool) + return tools - raise ValueError(f"Tool '{toolName}' not found") - def listTools(self) -> List[Dict[str, Any]]: - """List all registered tools.""" - return self.tools +class LogToolCallsMiddleware(AgentMiddleware): + """Middleware that logs every tool call (name and args).""" - def listResources(self) -> List[Dict[str, Any]]: - """List all available resources.""" - return self.resources + def wrap_tool_call(self, request: ToolCallRequest, handler: Callable): + _logToolCallRequest(request) + return handler(request) + + async def awrap_tool_call(self, request: ToolCallRequest, handler: Callable): + _logToolCallRequest(request) + return await handler(request) + + +def _extractTextFromAIMessageContent(content: Any) -> str: + """Extract plain text from AIMessage.content (str or list of content blocks).""" + if content is None: + return "" + if isinstance(content, str): + return content.strip() + if isinstance(content, list): + parts: List[str] = [] + for block in content: + if isinstance(block, dict) and "text" in block: + parts.append(str(block["text"])) + elif isinstance(block, str): + parts.append(block) + return "\n".join(parts).strip() if parts else "" + return str(content).strip() + + +def _extractFinalResponse(result: Dict[str, Any]) -> str: + """Extract the final assistant text from agent result; handle recursion limit / no final message.""" + messages = result.get("messages") or [] + for msg in reversed(messages): + if isinstance(msg, AIMessage) and hasattr(msg, "content"): + text = _extractTextFromAIMessageContent(msg.content) + if text: + return text + return ( + "Agent stopped without a final text response (e.g. hit step limit after tool calls). " + "Try again or increase --recursion-limit." 
+ ) + + +def _logToolCallRequest(request: ToolCallRequest) -> None: + tc = request.tool_call + name = tc.get("name") if isinstance(tc, dict) else getattr(tc, "name", None) + args = tc.get("args", tc.get("arguments", {})) if isinstance(tc, dict) else getattr(tc, "args", getattr(tc, "arguments", {})) + argsStr = json.dumps(args, ensure_ascii=False) + if len(argsStr) > 500: + argsStr = argsStr[:497] + "..." + print(f"[Tool Call] {name} args={argsStr}", file=sys.stderr) + + +class McpToolsMiddleware(AgentMiddleware): + """Middleware that adds MCP tools at runtime and handles their execution (runtime tool registration).""" + + def __init__(self, mcpTools: List[StructuredTool], staticToolNames: Optional[List[str]] = None): + self.mcpTools = mcpTools + self.mcpToolsByName = {t.name: t for t in mcpTools} + staticNames = set(staticToolNames or []) + self.validToolNames = staticNames | set(self.mcpToolsByName.keys()) + + def wrap_model_call(self, request: ModelRequest, handler: Callable) -> ModelResponse: + updated = request.override(tools=[*request.tools, *self.mcpTools]) + return handler(updated) + + async def awrap_model_call(self, request: ModelRequest, handler: Callable): + updated = request.override(tools=[*request.tools, *self.mcpTools]) + return await handler(updated) + + def _toolExists(self, name: Optional[str]) -> bool: + return bool(name and name in self.validToolNames) + + def _unknownToolErrorToolMessage(self, request: ToolCallRequest, name: str) -> ToolMessage: + available = ", ".join(sorted(self.validToolNames)) + content = ( + f"[Error] Tool '{name}' does not exist. " + f"Only the following tools are available: {available}. " + "Do not call tools that are not in this list." 
+ ) + tc = request.tool_call + toolCallId = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None) + return ToolMessage( + content=content, + tool_call_id=toolCallId or "unknown", + name=name or "unknown", + status="error", + ) + + def wrap_tool_call(self, request: ToolCallRequest, handler: Callable): + name = request.tool_call.get("name") if isinstance(request.tool_call, dict) else getattr(request.tool_call, "name", None) + if not self._toolExists(name): + return self._unknownToolErrorToolMessage(request, name or "") + if name and name in self.mcpToolsByName: + return handler(request.override(tool=self.mcpToolsByName[name])) + return handler(request) + + async def awrap_tool_call(self, request: ToolCallRequest, handler: Callable): + name = request.tool_call.get("name") if isinstance(request.tool_call, dict) else getattr(request.tool_call, "name", None) + if not self._toolExists(name): + return self._unknownToolErrorToolMessage(request, name or "") + if name and name in self.mcpToolsByName: + return await handler(request.override(tool=self.mcpToolsByName[name])) + return await handler(request) + +''' TODO Use this if you want sequential thinking +SYSTEM_PROMPT = """ +ROLE: +Sei un esperto Analista di Cybersecurity specializzato in CTF (Capture The Flag) e analisi di vulnerabilità. Operi in un ambiente Linux sandbox dove la tua unica area di lavoro è la directory /tmp. + +WORKSPACE CONSTRAINT: IL "SINGLE SOURCE OF TRUTH" +- Obbligo Assoluto: Tutte le operazioni di lettura, scrittura, download e analisi devono avvenire esclusivamente all'interno di /tmp. +- Percorsi: Ogni file deve essere referenziato con il percorso assoluto (es. /tmp/binary.bin). Non usare mai directory come ~/, /home o altre al di fuori di /tmp. +- Condivisione: Ricorda che /tmp è montata su tutti i container MCP (fetch, filesystem, ecc.). Se scarichi un file con fetch in /tmp, il tool filesystem lo troverà immediatamente lì. 
+ +TOOLSET & WORKFLOW: +Utilizza i tuoi tool secondo questa logica: +1. sequentialthinking (Pianificazione): Usa questo tool PRIMA di ogni azione complessa. Suddividi la sfida in step logici (es. 1. Download, 2. Analisi Header, 3. Estrazione Flag). Ti aiuta a non perdere il filo durante task lunghi. +2. fetch (Ingestion): Usalo per recuperare binari, exploit o dati remoti. Salva l'output sempre in /tmp. +3. filesystem (Manipolazione): Usalo per ispezionare i file scaricati, creare script di exploit o leggere file di log e flag direttamente in /tmp. +4. memory (Stato): Utilizza questo tool per memorizzare scoperte chiave, indirizzi di memoria, offset o password trovate durante la sfida. Ti serve per mantenere il contesto tra diverse fasi del ragionamento. + +METODOLOGIA DI ANALISI: +- Ipotesi e Test: Prima di agire, formula un'ipotesi basata sui dati presenti in /tmp. +- Verifica Intermedia: Dopo ogni comando o modifica ai file, verifica il risultato usando il tool filesystem. Non dare mai per scontato che un'operazione sia riuscita senza controllare. +- Pulizia Mentale: Se una strategia fallisce, usa sequentialthinking per rivedere il piano e aggiorna il tool memory con il motivo del fallimento per non ripetere lo stesso errore. + +REGOLE DI COMUNICAZIONE: +- Sii estremamente tecnico, sintetico e preciso. +- Se un file non è presente in /tmp, non provare a indovinarne il contenuto; usa fetch per ottenerlo o filesystem per cercarlo. +- Rispondi con l'output delle tue analisi e l'eventuale flag trovata nel formato richiesto dalla sfida. +""" +''' + +SYSTEM_PROMPT = "ROLE:\nSei un esperto Analista di Cybersecurity specializzato in CTF (Capture The Flag) e analisi di vulnerabilità. 
Operi in un ambiente Linux sandbox dove la tua unica area di lavoro è la directory /tmp.\n\nWORKSPACE CONSTRAINT: IL \"SINGLE SOURCE OF TRUTH\"\n- Obbligo Assoluto: Tutte le operazioni di lettura, scrittura e analisi devono avvenire esclusivamente all'interno di /tmp.\n- Percorsi: Ogni file deve essere referenziato con il percorso assoluto (es. /tmp/binary.bin). Non usare mai directory esterne a /tmp.\n- Condivisione: /tmp è montata su tutti i container MCP. I file creati o modificati da un tool sono immediatamente visibili agli altri.\n\nSTRETTO DIVIETO DI ALLUCINAZIONE TOOL:\n- USA ESCLUSIVAMENTE I TOOL MCP FORNITI: 'memory', 'filesystem'.\n- NON INVENTARE MAI TOOL INESISTENTI: È severamente vietato tentare di richiamare tool come \"run\", \"fetch\", \"execute_command\", \"shell\" o simili.\n- Se un tool non è in questa lista ('memory', 'filesystem'), NON esiste e non puoi usarlo.\n- Se senti la necessità di scaricare dati o eseguire comandi, ricorda che non hai tool per farlo; puoi solo operare sui file già presenti in /tmp tramite 'filesystem' o ragionare sugli stati tramite 'memory'.\n\nTOOLSET & WORKFLOW:\n1. memory (Pianificazione e Stato): È il tuo unico strumento di ragionamento e log. Usalo per definire il piano d'azione, suddividere la sfida in step e memorizzare scoperte (offset, password, indirizzi). Aggiorna la memoria prima di ogni azione.\n2. filesystem (Manipolazione): È il tuo unico strumento operativo. 
Usalo per ispezionare file esistenti, leggere contenuti, creare script o archiviare risultati esclusivamente in /tmp.\n\nMETODOLOGIA DI ANALISI:\n- Ragionamento Persistente: Documenta ogni ipotesi, passo logico e test nel tool memory.\n- Verifica Intermedia: Dopo ogni operazione sul filesystem, usa 'filesystem' per confermare che l'azione abbia prodotto il risultato atteso.\n- Gestione Errori: Se non trovi i file necessari in /tmp, segnalalo chiaramente senza provare a inventare tool per scaricarli o generarli.\n\nREGOLE DI COMUNICAZIONE:\n- Sii estremamente tecnico, sintetico e preciso.\n- Non fare mai riferimento a tool che non siano 'memory' o 'filesystem'." + +class OllamaMcpClient: + """MCP client that uses Ollama and LangChain create_agent with optional runtime MCP tools.""" + + def __init__( + self, + ollamaClient: OllamaClient, + mcpTools: Optional[List[StructuredTool]] = None, + systemPrompt: Optional[str] = None, + ): + self.ollamaClient = ollamaClient + self.mcpTools = mcpTools or [] + self.systemPrompt = systemPrompt or SYSTEM_PROMPT + staticTools: List[Any] = [getTime, countWords] + staticToolNames = [getTime.name, countWords.name] + middleware: List[AgentMiddleware] = [LogToolCallsMiddleware()] + if self.mcpTools: + middleware.append(McpToolsMiddleware(self.mcpTools, staticToolNames=staticToolNames)) + model = ChatOllama( + base_url=ollamaClient.baseUrl, + model=ollamaClient.model, + temperature=0.1, + ) + self.agent = create_agent( + model, + tools=staticTools, + middleware=middleware, + system_prompt=self.systemPrompt, + ) + + async def processRequest(self, prompt: str, context: Optional[List[str]] = None, recursionLimit: int = 50) -> str: + """Process a request using the LangChain agent (ReAct loop with tools).""" + messages: List[Any] = [HumanMessage(content=prompt)] + if context: + messages.insert(0, SystemMessage(content=f"Context:\n{chr(10).join(context)}")) + config: Dict[str, Any] = {"recursion_limit": recursionLimit} + toolParseRetryPrompt = 
( + "ATTENZIONE: Una chiamata write_file ha prodotto JSON non valido. " + "Quando scrivi file con codice Python: usa \\n per le newline nel JSON, escapa le virgolette con \\. " + "Non aggiungere parametri extra (es. overwrite). Usa edit_file per modifiche incrementali se il contenuto è lungo." + ) + try: + result = await self.agent.ainvoke({"messages": messages}, config=config) + except OllamaResponseError as e: + errStr = str(e) + if "error parsing tool call" in errStr: + print(f"[Agent Error]: Tool call parse error, retrying with guidance: {errStr[:200]}...", file=sys.stderr) + retryMessages: List[Any] = [SystemMessage(content=toolParseRetryPrompt)] + retryMessages.extend(messages) + result = await self.agent.ainvoke({"messages": retryMessages}, config=config) + else: + print(f"[Agent Error]: {e}", file=sys.stderr) + raise + except Exception as e: + print(f"[Agent Error]: {e}", file=sys.stderr) + raise + return _extractFinalResponse(result) + + def listTools(self) -> List[str]: + """List tool names (static + MCP).""" + names = [getTime.name, countWords.name] + names.extend(t.name for t in self.mcpTools) + return names async def async_main(args, ollamaClient: OllamaClient): - """Async main function.""" - # Connect to MCP server if specified - mcpServerWrapper = None + """Async main: MCP tools come only from mcp.json (Docker containers exposing SSE). 
Ollama is used only as LLM.""" + mcpTools: List[StructuredTool] = [] + mcpServers: List[McpServerWrapper] = [] + + # MCP servers from config file (mcp.json) – Docker containers with SSE endpoints + serverUrls: Dict[str, str] = loadMcpConfig(args.mcp_config) if args.mcp_server: - headers = {} - if args.mcp_headers: - try: - headers = json.loads(args.mcp_headers) - except json.JSONDecodeError: - print("Warning: Invalid JSON in --mcp-headers, ignoring", file=sys.stderr) + serverUrls["default"] = args.mcp_server.rstrip("/") - mcpServerWrapper = McpServerWrapper(httpUrl=args.mcp_server, headers=headers) - if not await mcpServerWrapper.connect(): - print("Error: Failed to connect to MCP server", file=sys.stderr) - sys.exit(1) - print("Connected to MCP server via streamable HTTP", file=sys.stderr) + # Which servers to use: default = all from mcp.json; or --mcp-tools fetch,filesystem to pick a subset + wantServers = [s.strip() for s in (args.mcp_tools or "").split(",") if s.strip()] + if not wantServers and serverUrls: + wantServers = list(serverUrls.keys()) + print(f"MCP tools from config (all SSE servers): {wantServers}", file=sys.stderr) + for name in wantServers: + url = serverUrls.get(name) + if not url: + print(f"Warning: MCP server '{name}' not in config (known: {list(serverUrls.keys())})", file=sys.stderr) + continue + wrapper = McpServerWrapper(httpUrl=url) + if await wrapper.connect(): + mcpServers.append(wrapper) + print(f"Connected to MCP server '{name}' at {url}", file=sys.stderr) + else: + print(f"Error: Failed to connect to MCP server '{name}' at {url}", file=sys.stderr) - # Initialize MCP client - mcpClient = OllamaMcpClient(ollamaClient, mcpServerWrapper) + if mcpServers: + mcpTools = await buildMcpLangChainTools(mcpServers, toolTimeout=getattr(args, "tool_timeout", None)) + #print(f"Loaded {len(mcpTools)} MCP tools: {[t.name for t in mcpTools]}", file=sys.stderr) - # Load server tools - if mcpServerWrapper: - await mcpClient._loadServerTools() - 
serverTools = await mcpServerWrapper.listServerTools() - if serverTools: - # Handle both Pydantic Tool objects and dicts - tool_names = [ - getattr(t, "name", "") if hasattr(t, "name") else t.get("name", "") if isinstance(t, dict) else "" - for t in serverTools - ] - print(f"Available MCP server tools: {tool_names}", file=sys.stderr) + mcpClient = OllamaMcpClient(ollamaClient, mcpTools=mcpTools) + print(f"Agent tools: {mcpClient.listTools()}", file=sys.stderr) - # Register some example tools - mcpClient.registerTool( - name="get_time", - description="Get the current time", - parameters={} - ) - mcpClient.registerTool( - name="count_words", - description="Count words in a text", - parameters={ - "text": { - "type": "string", - "description": "The text to count words in" - } - } - ) - - # Process prompt or run interactively if args.prompt: - response = await mcpClient.processRequest(args.prompt) + response = await mcpClient.processRequest(args.prompt, recursionLimit=args.recursion_limit) print(response) elif args.interactive: - print("MCP Client with Ollama - Interactive Mode") + print("MCP Client with Ollama (LangChain agent) - Interactive Mode") print("Type 'quit' or 'exit' to exit\n") while True: try: @@ -427,7 +595,7 @@ async def async_main(args, ollamaClient: OllamaClient): break if not prompt: continue - response = await mcpClient.processRequest(prompt) + response = await mcpClient.processRequest(prompt, recursionLimit=args.recursion_limit) print(f"Assistant: {response}\n") except KeyboardInterrupt: print("\nGoodbye!") @@ -435,9 +603,8 @@ async def async_main(args, ollamaClient: OllamaClient): except Exception as e: print(f"Error: {e}", file=sys.stderr) - # Cleanup - if mcpServerWrapper: - await mcpServerWrapper.disconnect() + for wrapper in mcpServers: + await wrapper.disconnect() def main(): @@ -452,7 +619,7 @@ def main(): ) parser.add_argument( "--model", - default="ministral-3", + default="gpt-oss:20b", help="Ollama model to use (default: gpt-oss:20b)" )
parser.add_argument( @@ -470,14 +637,35 @@ def main(): action="store_true", help="Run in interactive mode" ) + parser.add_argument( + "--mcp-config", + default=None, + help="Path to mcp.json (default: mcpServer/mcp.json relative to project)" + ) + parser.add_argument( + "--mcp-tools", + default="", + help="Comma-separated MCP server names from mcp.json (default: all servers in config). E.g. fetch,filesystem" + ) parser.add_argument( "--mcp-server", - help="HTTP URL for MCP server (e.g., 'http://localhost:8000/mcp')", - default="http://localhost:8000/mcp" + help="Override: single MCP SSE URL (e.g. http://localhost:3000/sse). Added as server 'default' in addition to mcp.json." ) parser.add_argument( "--mcp-headers", - help="Additional headers for MCP server as JSON string (e.g., '{\"Authorization\": \"Bearer token\"}')" + help="Additional headers for MCP server as JSON string (e.g. '{\"Authorization\": \"Bearer token\"}')" ) + parser.add_argument( + "--recursion-limit", + type=int, + default=5000, + help="Max agent steps (model + tool calls) before stopping (default: 5000)" + ) + parser.add_argument( + "--tool-timeout", + type=float, + default=60, + help="Timeout in seconds for each MCP tool call. Prevents agent from freezing when a tool hangs (e.g. run with missing executable). Default: 60" ) args = parser.parse_args() diff --git a/mcpClient/requirements.txt b/mcpClient/requirements.txt index 34061a6..674527a 100644 --- a/mcpClient/requirements.txt +++ b/mcpClient/requirements.txt @@ -1,2 +1,8 @@ requests>=2.31.0 -fastmcp>=0.9.0 \ No newline at end of file +fastmcp>=0.9.0 +langchain>=0.3.0 +langchain-core>=0.3.0 +langgraph>=0.2.0 +langchain-community>=0.3.0 +langchain-ollama>=0.2.0 +pydantic>=2.0.0 \ No newline at end of file diff --git a/mcpServer/requirements.txt b/mcpServer/requirements.txt deleted file mode 100644 index 344ab7b..0000000 --- a/mcpServer/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -mcp[cli]>=1.25.0