#!/usr/bin/env python3
"""
MCP client that uses Ollama for inference and LangChain create_agent with
runtime-registered MCP tools (see https://docs.langchain.com/oss/python/langchain/agents#runtime-tool-registration).
"""

import json
import sys
import asyncio
from pathlib import Path
from typing import Optional, Dict, Any, List, Callable, Awaitable

import requests
from fastmcp import Client as FastMcpClient
from ollama import ResponseError as OllamaResponseError
from pydantic import BaseModel, ConfigDict, Field, create_model

# LangChain agent and middleware
try:
    from langchain.agents import create_agent
    from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse, ToolCallRequest
    from langchain_core.tools import StructuredTool, tool
    from langchain_ollama import ChatOllama
    from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage
except ImportError as e:
    print(f"Missing dependency: {e}. Install with: pip install langchain langgraph langchain-community langchain-ollama", file=sys.stderr)
    sys.exit(1)


@tool
def getTime() -> str:
    """Get the current time in ISO format."""
    from datetime import datetime
    return datetime.now().isoformat()


@tool
def countWords(text: str) -> int:
    """Count the number of words in a text."""
    return len(text.split())


def loadMcpConfig(configPath: Optional[str] = None) -> Dict[str, str]:
    """Load MCP server URLs from mcp.json. Returns a dict mapping serverName -> url."""
    if configPath is None:
        # Default: mcpServer/mcp.json relative to the project root
        base = Path(__file__).resolve().parent.parent
        configPath = str(base / "mcpServer" / "mcp.json")
    path = Path(configPath)
    if not path.exists():
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (json.JSONDecodeError, OSError) as e:
        print(f"Warning: Could not load MCP config from {path}: {e}", file=sys.stderr)
        return {}
    servers = data.get("mcpServers") or data.get("mcp_servers") or {}
    return {name: info.get("url", "") for name, info in servers.items() if isinstance(info, dict) and info.get("url")}
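
# Expected mcp.json shape (the server names here are illustrative, not required):
# {
#   "mcpServers": {
#     "filesystem": {"url": "http://localhost:3001/sse"},
#     "memory": {"url": "http://localhost:3002/sse"}
#   }
# }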


class GenericToolArgs(BaseModel):
    """Accept any keyword arguments for MCP tool calls (fallback when the schema is missing)."""
    model_config = ConfigDict(extra="allow")


def _jsonSchemaTypeToPython(jsonType: str) -> type:
    """Map a JSON Schema type to a Python type."""
    return {"string": str, "integer": int, "number": float, "boolean": bool, "array": list, "object": dict}.get(jsonType, str)


def _defaultForJsonType(jsonType: str) -> Any:
    """Sensible default for optional MCP params so the server does not receive null."""
    return {"string": "", "integer": 0, "number": 0.0, "boolean": False, "array": [], "object": {}}.get(jsonType, "")


def _defaultsFromInputSchema(inputSchema: Dict[str, Any]) -> Dict[str, Any]:
    """Build default values for all params so we never send null to the MCP server (the LLM may omit required params)."""
    if not inputSchema:
        return {}
    properties = inputSchema.get("properties") or {}
    out: Dict[str, Any] = {}
    for name, spec in properties.items():
        if not isinstance(spec, dict):
            continue
        if "default" in spec:
            out[name] = spec["default"]
        else:
            out[name] = _defaultForJsonType(spec.get("type", "string"))
    return out


def buildArgsSchemaFromMcpInputSchema(toolName: str, inputSchema: Dict[str, Any]) -> type[BaseModel]:
    """Build a Pydantic model from an MCP tool inputSchema so the LLM sees exact parameter names (path, content, etc.)."""
    if not inputSchema:
        return GenericToolArgs
    properties = inputSchema.get("properties") or {}
    required = set(inputSchema.get("required") or [])
    if not properties:
        return GenericToolArgs
    fields: Dict[str, Any] = {}
    for name, spec in properties.items():
        if not isinstance(spec, dict):
            continue
        desc = spec.get("description", "")
        jsonType = spec.get("type", "string")
        pyType = _jsonSchemaTypeToPython(jsonType)
        if name in required:
            fields[name] = (pyType, Field(..., description=desc))
        else:
            fields[name] = (Optional[pyType], Field(None, description=desc))
    if not fields:
        return GenericToolArgs
    return create_model(f"McpArgs_{toolName}", **fields)
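
# Example: an inputSchema like
#   {"properties": {"path": {"type": "string", "description": "File path"}}, "required": ["path"]}
# produces a model equivalent to
#   class McpArgs_read_file(BaseModel):
#       path: str = Field(..., description="File path")
# (the tool name "read_file" is illustrative).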


class OllamaClient:
    """Client for interacting with the Ollama HTTP API."""

    def __init__(self, baseUrl: str = "http://localhost:11434", model: str = "gpt-oss:20b"):
        self.baseUrl = baseUrl
        self.model = model

    def listModels(self) -> List[str]:
        """List available Ollama models."""
        try:
            response = requests.get(f"{self.baseUrl}/api/tags", timeout=10)
            response.raise_for_status()
            data = response.json()
            return [model["name"] for model in data.get("models", [])]
        except requests.RequestException as e:
            print(f"Error listing models: {e}", file=sys.stderr)
            return []

    def chat(self, messages: List[Dict[str, str]], options: Optional[Dict[str, Any]] = None) -> str:
        """Send chat messages to Ollama and return the response text."""
        payload = {
            "model": self.model,
            "messages": messages,
            "stream": False,
        }
        if options:
            payload["options"] = options
        try:
            response = requests.post(
                f"{self.baseUrl}/api/chat",
                json=payload,
                timeout=60 * 60 * 60,  # 216000 s: effectively unbounded for slow local inference
            )
            response.raise_for_status()
            data = response.json()
            return data.get("message", {}).get("content", "")
        except requests.RequestException as e:
            print(f"Error in chat request: {e}", file=sys.stderr)
            raise

    def generate(self, prompt: str, options: Optional[Dict[str, Any]] = None) -> str:
        """Generate text from a prompt using Ollama."""
        payload = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
        }
        if options:
            payload["options"] = options
        try:
            response = requests.post(
                f"{self.baseUrl}/api/generate",
                json=payload,
                timeout=120,
            )
            response.raise_for_status()
            data = response.json()
            return data.get("response", "")
        except requests.RequestException as e:
            print(f"Error in generate request: {e}", file=sys.stderr)
            raise

    def checkHealth(self) -> bool:
        """Check whether the Ollama server is reachable."""
        try:
            response = requests.get(f"{self.baseUrl}/api/tags", timeout=5)
            return response.status_code == 200
        except requests.RequestException:
            return False


class McpServerWrapper:
    """Wrapper around the FastMCP Client for easier use."""

    def __init__(self, httpUrl: str, headers: Optional[Dict[str, str]] = None):
        self.httpUrl = httpUrl.rstrip("/")
        self.headers = headers or {}
        self.client: Optional[FastMcpClient] = None
        self.serverTools: List[Dict[str, Any]] = []

    async def connect(self) -> bool:
        """Connect to and initialize the MCP server over HTTP."""
        try:
            # FastMcpClient doesn't accept a headers parameter directly; headers
            # would need a custom transport or auth, so for now we connect without them.
            self.client = FastMcpClient(self.httpUrl)
            await self.client.__aenter__()
            # Cache the tool list right after connecting
            await self.listServerTools()
            return True
        except Exception as e:
            print(f"Error connecting to MCP server: {e}", file=sys.stderr)
            return False

    async def disconnect(self):
        """Disconnect from the MCP server."""
        if self.client:
            await self.client.__aexit__(None, None, None)
            self.client = None

    async def listServerTools(self) -> List[Dict[str, Any]]:
        """List tools available from the MCP server."""
        if not self.client:
            return []
        try:
            tools = await self.client.list_tools()
            self.serverTools = tools
            return tools
        except Exception as e:
            print(f"Error listing tools: {e}", file=sys.stderr)
            return []

    async def callServerTool(self, name: str, arguments: Dict[str, Any]) -> Any:
        """Call a tool on the MCP server."""
        if not self.client:
            raise RuntimeError("Not connected to MCP server")
        try:
            result = await self.client.call_tool(name, arguments)
            # FastMCP call_tool returns a result object with .content
            # (lists are serialized later)
            if hasattr(result, 'content'):
                return result.content
            if isinstance(result, list):
                if not result:
                    return result
                # Extract .content from each item where present
                contents = [item.content if hasattr(item, 'content') else item for item in result]
                return contents if len(contents) > 1 else contents[0]
            return result
        except Exception as e:
            raise RuntimeError(f"Tool call failed: {e}")

    async def listServerResources(self) -> List[Dict[str, Any]]:
        """List resources available from the MCP server."""
        if not self.client:
            return []
        try:
            resources = await self.client.list_resources()
            return resources
        except Exception as e:
            print(f"Error listing resources: {e}", file=sys.stderr)
            return []


def _serializeToolResult(result: Any) -> Any:
    """Serialize a tool result into a JSON-serializable form."""
    if hasattr(result, "text"):
        return result.text
    if hasattr(result, "content"):
        content = result.content
        if hasattr(content, "text"):
            return content.text
        return content
    if isinstance(result, list):
        return [_serializeToolResult(item) for item in result]
    if isinstance(result, dict):
        return {k: _serializeToolResult(v) for k, v in result.items()}
    return result
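
# Example: a list of text content blocks (each exposing .text) serializes to a
# list of strings; dicts and lists are walked recursively, plain values pass through.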


def _makeMcpToolCoroutine(
    toolName: str,
    server: McpServerWrapper,
    defaultArgs: Dict[str, Any],
    toolTimeout: Optional[float] = None,
) -> Callable[..., Awaitable[Any]]:
    async def _invoke(**kwargs: Any) -> Any:
        merged = {**defaultArgs, **kwargs}
        # Strip None values: MCP server Zod schemas often reject null for optional
        # params (they expect number | undefined, not number | null)
        merged = {k: v for k, v in merged.items() if v is not None}
        try:
            if toolTimeout is not None and toolTimeout > 0:
                result = await asyncio.wait_for(
                    server.callServerTool(toolName, merged),
                    timeout=toolTimeout,
                )
            else:
                result = await server.callServerTool(toolName, merged)
        except asyncio.TimeoutError:
            return (
                f"[Tool timeout] '{toolName}' exceeded {toolTimeout}s. "
                "The operation may have hung (e.g. command not found, subprocess blocking). "
                "Try an alternative (e.g. 'python' instead of 'python3') or increase --tool-timeout."
            )
        return _serializeToolResult(result)
    return _invoke


async def buildMcpLangChainTools(
    mcpServers: List[McpServerWrapper],
    toolTimeout: Optional[float] = None,
) -> List[StructuredTool]:
    """Build LangChain StructuredTools from connected MCP servers (runtime tool registration)."""
    tools: List[StructuredTool] = []
    for server in mcpServers:
        rawTools = await server.listServerTools()
        for raw in rawTools:
            name = getattr(raw, "name", None) or (raw.get("name") if isinstance(raw, dict) else None)
            description = getattr(raw, "description", None) or (raw.get("description", "") if isinstance(raw, dict) else "")
            inputSchema = getattr(raw, "inputSchema", None) or getattr(raw, "input_schema", None) or (raw.get("inputSchema") or raw.get("input_schema") if isinstance(raw, dict) else None)
            if not name:
                continue
            description = description or f"MCP tool: {name}"
            schemaDict = inputSchema or {}
            argsSchema = buildArgsSchemaFromMcpInputSchema(name, schemaDict)
            defaultArgs = _defaultsFromInputSchema(schemaDict)
            # Note: the local is named structuredTool to avoid shadowing the imported @tool decorator
            structuredTool = StructuredTool.from_function(
                name=name,
                description=description,
                args_schema=argsSchema,
                coroutine=_makeMcpToolCoroutine(name, server, defaultArgs, toolTimeout),
            )
            tools.append(structuredTool)
    return tools
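
# Typical use (a sketch; the URL is illustrative): connect wrappers first, then
# register their tools.
#   wrapper = McpServerWrapper("http://localhost:3001/sse")
#   if await wrapper.connect():
#       tools = await buildMcpLangChainTools([wrapper], toolTimeout=60)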


class LogToolCallsMiddleware(AgentMiddleware):
    """Middleware that logs every tool call (name and args)."""

    def wrap_tool_call(self, request: ToolCallRequest, handler: Callable):
        _logToolCallRequest(request)
        return handler(request)

    async def awrap_tool_call(self, request: ToolCallRequest, handler: Callable):
        _logToolCallRequest(request)
        return await handler(request)


def _extractTextFromAIMessageContent(content: Any) -> str:
    """Extract plain text from AIMessage.content (a str or a list of content blocks)."""
    if content is None:
        return ""
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        parts: List[str] = []
        for block in content:
            if isinstance(block, dict) and "text" in block:
                parts.append(str(block["text"]))
            elif isinstance(block, str):
                parts.append(block)
        return "\n".join(parts).strip() if parts else ""
    return str(content).strip()


def _extractFinalResponse(result: Dict[str, Any]) -> str:
    """Extract the final assistant text from the agent result; handle recursion limit / no final message."""
    messages = result.get("messages") or []
    for msg in reversed(messages):
        if isinstance(msg, AIMessage) and hasattr(msg, "content"):
            text = _extractTextFromAIMessageContent(msg.content)
            if text:
                return text
    return (
        "Agent stopped without a final text response (e.g. hit the step limit after tool calls). "
        "Try again or increase --recursion-limit."
    )


def _logToolCallRequest(request: ToolCallRequest) -> None:
    tc = request.tool_call
    name = tc.get("name") if isinstance(tc, dict) else getattr(tc, "name", None)
    args = tc.get("args", tc.get("arguments", {})) if isinstance(tc, dict) else getattr(tc, "args", getattr(tc, "arguments", {}))
    argsStr = json.dumps(args, ensure_ascii=False)
    if len(argsStr) > 500:
        argsStr = argsStr[:497] + "..."
    print(f"[Tool Call] {name} args={argsStr}", file=sys.stderr)


class McpToolsMiddleware(AgentMiddleware):
    """Middleware that adds MCP tools at runtime and handles their execution (runtime tool registration)."""

    def __init__(self, mcpTools: List[StructuredTool], staticToolNames: Optional[List[str]] = None):
        self.mcpTools = mcpTools
        self.mcpToolsByName = {t.name: t for t in mcpTools}
        staticNames = set(staticToolNames or [])
        self.validToolNames = staticNames | set(self.mcpToolsByName.keys())

    def wrap_model_call(self, request: ModelRequest, handler: Callable) -> ModelResponse:
        updated = request.override(tools=[*request.tools, *self.mcpTools])
        return handler(updated)

    async def awrap_model_call(self, request: ModelRequest, handler: Callable):
        updated = request.override(tools=[*request.tools, *self.mcpTools])
        return await handler(updated)

    def _toolExists(self, name: Optional[str]) -> bool:
        return bool(name and name in self.validToolNames)

    def _unknownToolErrorToolMessage(self, request: ToolCallRequest, name: str) -> ToolMessage:
        available = ", ".join(sorted(self.validToolNames))
        content = (
            f"[Error] Tool '{name}' does not exist. "
            f"Only the following tools are available: {available}. "
            "Do not call tools that are not in this list."
        )
        tc = request.tool_call
        toolCallId = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None)
        return ToolMessage(
            content=content,
            tool_call_id=toolCallId or "unknown",
            name=name or "unknown",
            status="error",
        )

    def wrap_tool_call(self, request: ToolCallRequest, handler: Callable):
        name = request.tool_call.get("name") if isinstance(request.tool_call, dict) else getattr(request.tool_call, "name", None)
        if not self._toolExists(name):
            return self._unknownToolErrorToolMessage(request, name or "<unknown>")
        if name and name in self.mcpToolsByName:
            return handler(request.override(tool=self.mcpToolsByName[name]))
        return handler(request)

    async def awrap_tool_call(self, request: ToolCallRequest, handler: Callable):
        name = request.tool_call.get("name") if isinstance(request.tool_call, dict) else getattr(request.tool_call, "name", None)
        if not self._toolExists(name):
            return self._unknownToolErrorToolMessage(request, name or "<unknown>")
        if name and name in self.mcpToolsByName:
            return await handler(request.override(tool=self.mcpToolsByName[name]))
        return await handler(request)


''' TODO Use this if you want sequential thinking
SYSTEM_PROMPT = """
ROLE:
You are an expert Cybersecurity Analyst specializing in CTF (Capture The Flag) challenges and vulnerability analysis. You operate in a Linux sandbox where your only working area is the /tmp directory.

WORKSPACE CONSTRAINT: THE "SINGLE SOURCE OF TRUTH"
- Absolute requirement: all read, write, download, and analysis operations must happen exclusively inside /tmp.
- Paths: every file must be referenced by its absolute path (e.g. /tmp/binary.bin). Never use directories such as ~/, /home, or anything outside /tmp.
- Sharing: remember that /tmp is mounted on all MCP containers (fetch, filesystem, etc.). If you download a file with fetch into /tmp, the filesystem tool will find it there immediately.

TOOLSET & WORKFLOW:
Use your tools according to this logic:
1. sequentialthinking (Planning): use this tool BEFORE any complex action. Break the challenge into logical steps (e.g. 1. Download, 2. Header analysis, 3. Flag extraction). It helps you keep track during long tasks.
2. fetch (Ingestion): use it to retrieve binaries, exploits, or remote data. Always save the output to /tmp.
3. filesystem (Manipulation): use it to inspect downloaded files, create exploit scripts, or read log and flag files directly in /tmp.
4. memory (State): use this tool to store key findings, memory addresses, offsets, or passwords discovered during the challenge. It keeps your context across reasoning phases.

ANALYSIS METHODOLOGY:
- Hypothesis and test: before acting, form a hypothesis based on the data present in /tmp.
- Intermediate verification: after every command or file modification, verify the result with the filesystem tool. Never assume an operation succeeded without checking.
- Mental cleanup: if a strategy fails, use sequentialthinking to revise the plan and update the memory tool with the reason for the failure so you do not repeat the same mistake.

COMMUNICATION RULES:
- Be extremely technical, concise, and precise.
- If a file is not present in /tmp, do not try to guess its contents; use fetch to obtain it or filesystem to look for it.
- Reply with the output of your analyses and any flag found, in the format required by the challenge.
"""
'''

SYSTEM_PROMPT = """\
ROLE:
You are an expert Cybersecurity Analyst specializing in CTF (Capture The Flag) challenges and vulnerability analysis. You operate in a Linux sandbox where your only working area is the /tmp directory.

WORKSPACE CONSTRAINT: THE "SINGLE SOURCE OF TRUTH"
- Absolute requirement: all read, write, and analysis operations must happen exclusively inside /tmp.
- Paths: every file must be referenced by its absolute path (e.g. /tmp/binary.bin). Never use directories outside /tmp.
- Sharing: /tmp is mounted on all MCP containers. Files created or modified by one tool are immediately visible to the others.

STRICT PROHIBITION ON TOOL HALLUCINATION:
- USE ONLY THE MCP TOOLS PROVIDED: 'memory', 'filesystem'.
- NEVER INVENT NONEXISTENT TOOLS: attempting to call tools such as "run", "fetch", "execute_command", "shell" or similar is strictly forbidden.
- If a tool is not in this list ('memory', 'filesystem'), it does NOT exist and you cannot use it.
- If you feel the need to download data or run commands, remember that you have no tools for that; you can only operate on files already present in /tmp via 'filesystem' or reason about state via 'memory'.

TOOLSET & WORKFLOW:
1. memory (Planning and State): your only reasoning and logging tool. Use it to define the action plan, break the challenge into steps, and store findings (offsets, passwords, addresses). Update memory before every action.
2. filesystem (Manipulation): your only operational tool. Use it to inspect existing files, read contents, create scripts, or store results, exclusively in /tmp.

ANALYSIS METHODOLOGY:
- Persistent reasoning: document every hypothesis, logical step, and test in the memory tool.
- Intermediate verification: after every filesystem operation, use 'filesystem' to confirm the action produced the expected result.
- Error handling: if the files you need are not in /tmp, report it clearly; do not invent tools to download or generate them.

COMMUNICATION RULES:
- Be extremely technical, concise, and precise.
- Never refer to tools other than 'memory' or 'filesystem'.
"""


class OllamaMcpClient:
    """MCP client that uses Ollama and LangChain create_agent with optional runtime MCP tools."""

    def __init__(
        self,
        ollamaClient: OllamaClient,
        mcpTools: Optional[List[StructuredTool]] = None,
        systemPrompt: Optional[str] = None,
    ):
        self.ollamaClient = ollamaClient
        self.mcpTools = mcpTools or []
        self.systemPrompt = systemPrompt or SYSTEM_PROMPT
        staticTools: List[Any] = [getTime, countWords]
        staticToolNames = [getTime.name, countWords.name]
        middleware: List[AgentMiddleware] = [LogToolCallsMiddleware()]
        if self.mcpTools:
            middleware.append(McpToolsMiddleware(self.mcpTools, staticToolNames=staticToolNames))
        model = ChatOllama(
            base_url=ollamaClient.baseUrl,
            model=ollamaClient.model,
            temperature=0.1,
        )
        self.agent = create_agent(
            model,
            tools=staticTools,
            middleware=middleware,
            system_prompt=self.systemPrompt,
        )

    async def processRequest(self, prompt: str, context: Optional[List[str]] = None, recursionLimit: int = 50) -> str:
        """Process a request using the LangChain agent (ReAct loop with tools)."""
        messages: List[Any] = [HumanMessage(content=prompt)]
        if context:
            messages.insert(0, SystemMessage(content=f"Context:\n{chr(10).join(context)}"))
        config: Dict[str, Any] = {"recursion_limit": recursionLimit}
        toolParseRetryPrompt = (
            "WARNING: A write_file call produced invalid JSON. "
            "When writing files containing Python code: use \\n for newlines in the JSON and escape quotes with \\. "
            "Do not add extra parameters (e.g. overwrite). Use edit_file for incremental changes if the content is long."
        )
        try:
            result = await self.agent.ainvoke({"messages": messages}, config=config)
        except OllamaResponseError as e:
            errStr = str(e)
            if "error parsing tool call" in errStr:
                print(f"[Agent Error]: Tool call parse error, retrying with guidance: {errStr[:200]}...", file=sys.stderr)
                retryMessages: List[Any] = [SystemMessage(content=toolParseRetryPrompt)]
                retryMessages.extend(messages)
                result = await self.agent.ainvoke({"messages": retryMessages}, config=config)
            else:
                print(f"[Agent Error]: {e}", file=sys.stderr)
                raise
        except Exception as e:
            print(f"[Agent Error]: {e}", file=sys.stderr)
            raise
        return _extractFinalResponse(result)

    def listTools(self) -> List[str]:
        """List tool names (static + MCP)."""
        names = [getTime.name, countWords.name]
        names.extend(t.name for t in self.mcpTools)
        return names
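
# Minimal programmatic use (a sketch; assumes a reachable Ollama instance):
#   client = OllamaMcpClient(OllamaClient(model="gpt-oss:20b"))
#   answer = asyncio.run(client.processRequest("What time is it?"))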


async def async_main(args, ollamaClient: OllamaClient):
    """Async main: MCP tools come only from mcp.json (Docker containers exposing SSE). Ollama is used only as the LLM."""
    mcpTools: List[StructuredTool] = []
    mcpServers: List[McpServerWrapper] = []

    # MCP servers from the config file (mcp.json) – Docker containers with SSE endpoints
    serverUrls: Dict[str, str] = loadMcpConfig(args.mcp_config)
    if args.mcp_server:
        serverUrls["default"] = args.mcp_server.rstrip("/")

    # Which servers to use: default = all from mcp.json; or --mcp-tools fetch,filesystem to pick a subset
    wantServers = [s.strip() for s in (args.mcp_tools or "").split(",") if s.strip()]
    if not wantServers and serverUrls:
        wantServers = list(serverUrls.keys())
        print(f"MCP tools from config (all SSE servers): {wantServers}", file=sys.stderr)
    for name in wantServers:
        url = serverUrls.get(name)
        if not url:
            print(f"Warning: MCP server '{name}' not in config (known: {list(serverUrls.keys())})", file=sys.stderr)
            continue
        wrapper = McpServerWrapper(httpUrl=url)
        if await wrapper.connect():
            mcpServers.append(wrapper)
            print(f"Connected to MCP server '{name}' at {url}", file=sys.stderr)
        else:
            print(f"Error: Failed to connect to MCP server '{name}' at {url}", file=sys.stderr)

    if mcpServers:
        mcpTools = await buildMcpLangChainTools(mcpServers, toolTimeout=getattr(args, "tool_timeout", None))
        # print(f"Loaded {len(mcpTools)} MCP tools: {[t.name for t in mcpTools]}", file=sys.stderr)

    mcpClient = OllamaMcpClient(ollamaClient, mcpTools=mcpTools)
    print(f"Agent tools: {mcpClient.listTools()}", file=sys.stderr)

    if args.prompt:
        response = await mcpClient.processRequest(args.prompt, recursionLimit=args.recursion_limit)
        print(response)
    elif args.interactive:
        print("MCP Client with Ollama (LangChain agent) - Interactive Mode")
        print("Type 'quit' or 'exit' to exit\n")
        while True:
            try:
                prompt = input("You: ").strip()
                if prompt.lower() in ["quit", "exit"]:
                    break
                if not prompt:
                    continue
                response = await mcpClient.processRequest(prompt, recursionLimit=args.recursion_limit)
                print(f"Assistant: {response}\n")
            except KeyboardInterrupt:
                print("\nGoodbye!")
                break
            except Exception as e:
                print(f"Error: {e}", file=sys.stderr)

    for wrapper in mcpServers:
        await wrapper.disconnect()


def main():
    """Main function to run the MCP client."""
    import argparse

    parser = argparse.ArgumentParser(description="MCP client using Ollama")
    parser.add_argument(
        "--base-url",
        default="http://localhost:11434",
        help="Ollama base URL (default: http://localhost:11434)"
    )
    parser.add_argument(
        "--model",
        default="gpt-oss:20b",
        help="Ollama model to use (default: gpt-oss:20b)"
    )
    parser.add_argument(
        "--list-models",
        action="store_true",
        help="List available Ollama models and exit"
    )
    parser.add_argument(
        "--prompt",
        help="Prompt to send to the model"
    )
    parser.add_argument(
        "--interactive",
        "-i",
        action="store_true",
        help="Run in interactive mode"
    )
    parser.add_argument(
        "--mcp-config",
        default=None,
        help="Path to mcp.json (default: mcpServer/mcp.json relative to the project)"
    )
    parser.add_argument(
        "--mcp-tools",
        default="",
        help="Comma-separated MCP server names from mcp.json (default: all servers in config). E.g. fetch,filesystem"
    )
    parser.add_argument(
        "--mcp-server",
        help="Override: single MCP SSE URL (e.g. http://localhost:3000/sse). Added as server 'default' in addition to mcp.json."
    )
    parser.add_argument(
        "--mcp-headers",
        help="Additional headers for the MCP server as a JSON string (e.g. '{\"Authorization\": \"Bearer token\"}')"
    )
    parser.add_argument(
        "--recursion-limit",
        type=int,
        default=5000,
        help="Max agent steps (model + tool calls) before stopping (default: 5000)"
    )
    parser.add_argument(
        "--tool-timeout",
        type=float,
        default=60,
        help="Timeout in seconds for each MCP tool call. Prevents the agent from freezing when a tool hangs (e.g. run with a missing executable). Default: 60"
    )

    args = parser.parse_args()

    # Initialize the Ollama client
    ollamaClient = OllamaClient(baseUrl=args.base_url, model=args.model)

    # Check health
    if not ollamaClient.checkHealth():
        print(f"Error: Cannot connect to Ollama at {args.base_url}", file=sys.stderr)
        print("Make sure Ollama is running and accessible.", file=sys.stderr)
        sys.exit(1)

    # List models if requested
    if args.list_models:
        models = ollamaClient.listModels()
        print("Available models:")
        for model in models:
            print(f"  - {model}")
        sys.exit(0)

    # Nothing to do: show help instead of spinning up the agent
    if not args.prompt and not args.interactive:
        parser.print_help()
        return

    # Run async main
    asyncio.run(async_main(args, ollamaClient))


if __name__ == "__main__":
    main()