"""
AI Provider abstraction for Claude and Gemini.

Handles model-specific API calls and tool execution. Each provider exposes an
async ``chat`` coroutine that returns a plain dict with the keys ``success``,
``response``, ``model``, ``tools_used`` and ``stop_reason``.
"""

import os
import json
import asyncio
from typing import List, Dict, Any, Optional
from abc import ABC, abstractmethod

# Maximum number of tool-execution round trips per chat() call. Shared by both
# providers so they allow the same amount of tool use.
MAX_TOOL_ITERATIONS = 5

# Number of trailing history entries considered when building the prompt.
HISTORY_WINDOW = 6


class AIProvider(ABC):
    """Abstract base class for AI providers."""

    @abstractmethod
    async def chat(
        self,
        message: str,
        system_prompt: str,
        history: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        temperature: float = 0.7,
        mcp_handler: Any = None,
        mcp_context: Any = None,
    ) -> Dict[str, Any]:
        """
        Send a chat message and get the model's response.

        Args:
            message: The new user message.
            system_prompt: System instructions for the model.
            history: Prior turns as ``{"role": ..., "content": ...}`` dicts.
            tools: Tool definitions (``name``, ``description``, ``input_schema``).
            temperature: Sampling temperature.
            mcp_handler: Object whose ``execute_tool`` coroutine runs tool calls.
            mcp_context: Passed through unchanged as ``context`` to ``execute_tool``.

        Returns:
            {
                "success": bool,
                "response": str,
                "model": str,
                "tools_used": List[Dict],
                "stop_reason": str
            }
        """
        pass

    @staticmethod
    def _recent_turns(history: Optional[List[Dict[str, Any]]]) -> List[tuple]:
        """Return ``(role, content)`` pairs from the last HISTORY_WINDOW history entries.

        The window is applied first, then entries with empty content or a role
        other than "user"/"assistant" are dropped. ``None`` is treated as empty.
        """
        turns = []
        for msg in (history or [])[-HISTORY_WINDOW:]:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if content and role in ("user", "assistant"):
                turns.append((role, content))
        return turns

    @staticmethod
    async def _execute_tool(
        mcp_handler: Any,
        mcp_context: Any,
        tool_name: str,
        arguments: Dict[str, Any],
        tools_used: List[Dict[str, Any]],
    ) -> Any:
        """Run one tool call through ``mcp_handler`` and append a usage record to ``tools_used``.

        Returns the handler's result object; callers read its ``success``,
        ``result``, ``error`` and ``duration_ms`` attributes.
        """
        result = await mcp_handler.execute_tool(
            tool_name=tool_name,
            arguments=arguments,
            context=mcp_context,
        )
        tools_used.append({
            "tool": tool_name,
            "success": result.success,
            "duration_ms": result.duration_ms,
        })
        return result


class ClaudeProvider(AIProvider):
    """Anthropic Claude provider."""

    def __init__(self):
        self.api_key = os.getenv("ANTHROPIC_API_KEY")
        self.default_model = "claude-sonnet-4-5-20250929"

    def is_available(self) -> bool:
        """Return True when the ``anthropic`` SDK imports and an API key is set."""
        try:
            from anthropic import Anthropic  # noqa: F401  (availability check only)
        except ImportError:
            return False
        return bool(self.api_key)

    @staticmethod
    def _tool_result_block(tool_use_id: str, result: Any) -> Dict[str, Any]:
        """Build an Anthropic ``tool_result`` content block from a tool result.

        Failed calls carry ``is_error: True`` so the model can distinguish an
        error payload from a successful result.
        """
        if result.success:
            return {
                "type": "tool_result",
                "tool_use_id": tool_use_id,
                "content": json.dumps(result.result, indent=2),
            }
        return {
            "type": "tool_result",
            "tool_use_id": tool_use_id,
            "content": json.dumps({"error": result.error}),
            "is_error": True,
        }

    async def chat(
        self,
        message: str,
        system_prompt: str,
        history: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        temperature: float = 0.7,
        mcp_handler: Any = None,
        mcp_context: Any = None,
    ) -> Dict[str, Any]:
        """Chat with Claude, executing requested tools through ``mcp_handler``.

        Tools are only sent to the model when ``mcp_handler`` is provided.
        Up to MAX_TOOL_ITERATIONS tool round trips are performed. The SDK
        client is synchronous, so each API call runs via ``asyncio.to_thread``.
        Exceptions raised by the SDK or the tool handler propagate to the caller.
        """
        if not self.is_available():
            return {
                "success": False,
                "response": "Claude not available. Install anthropic SDK or set ANTHROPIC_API_KEY.",
                "model": "error",
                "tools_used": [],
                "stop_reason": "error"
            }

        from anthropic import Anthropic

        client = Anthropic(api_key=self.api_key)

        messages = [
            {"role": role, "content": content}
            for role, content in self._recent_turns(history)
        ]
        messages.append({"role": "user", "content": message})

        api_params = {
            "model": self.default_model,
            "max_tokens": 4096,
            "temperature": temperature,
            "system": system_prompt,
            "messages": messages
        }
        # Without a handler a tool_use reply could not be executed (the loop
        # below would call None.execute_tool), so only advertise tools when
        # one is present -- same rule as GeminiProvider.
        if tools and mcp_handler is not None:
            api_params["tools"] = tools

        response = await asyncio.to_thread(client.messages.create, **api_params)

        tools_used: List[Dict[str, Any]] = []
        iteration = 0
        while response.stop_reason == "tool_use" and iteration < MAX_TOOL_ITERATIONS:
            iteration += 1
            tool_results = []
            for content_block in response.content:
                if content_block.type != "tool_use":
                    continue
                result = await self._execute_tool(
                    mcp_handler, mcp_context,
                    content_block.name, content_block.input, tools_used,
                )
                tool_results.append(self._tool_result_block(content_block.id, result))

            # Echo the assistant turn (with its tool_use blocks) and answer it.
            messages.append({"role": "assistant", "content": response.content})
            messages.append({"role": "user", "content": tool_results})
            response = await asyncio.to_thread(
                client.messages.create,
                **{**api_params, "messages": messages}
            )

        response_text = "".join(
            block.text for block in response.content if hasattr(block, "text")
        )

        return {
            "success": True,
            "response": response_text,
            "model": response.model,
            "tools_used": tools_used,
            "stop_reason": response.stop_reason
        }


class GeminiProvider(AIProvider):
    """Google Gemini provider."""

    def __init__(self):
        self.api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
        self.default_model = "gemini-2.0-flash-exp"

    def is_available(self) -> bool:
        """Return True when ``google.generativeai`` imports and an API key is set."""
        try:
            import google.generativeai as genai  # noqa: F401  (availability check only)
        except ImportError:
            return False
        return bool(self.api_key)

    def _convert_tools_to_gemini_format(self, claude_tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert Claude-style tool definitions to Gemini function declarations.

        ``input_schema.properties``/``required`` become the declaration's
        ``parameters``. Tools with no properties get no ``parameters`` at all,
        because Gemini rejects an OBJECT schema whose ``properties`` is empty.
        """
        gemini_tools = []
        for tool in claude_tools:
            function_declaration = {
                "name": tool.get("name"),
                "description": tool.get("description", ""),
            }
            schema = tool.get("input_schema") or {}
            properties = schema.get("properties") or {}
            if properties:
                function_declaration["parameters"] = {
                    "type": "object",
                    "properties": dict(properties),
                    "required": list(schema.get("required") or []),
                }
            gemini_tools.append(function_declaration)
        return gemini_tools

    @staticmethod
    def _extract_function_calls(response: Any) -> List[Any]:
        """Return every ``function_call`` in the first candidate of ``response``.

        Gemini can request several calls in a single turn (parallel function
        calling); each one gets a response in the next message.
        """
        if not response.candidates:
            return []
        parts = response.candidates[0].content.parts or []
        return [part.function_call for part in parts if getattr(part, "function_call", None)]

    @staticmethod
    def _extract_text(response: Any) -> str:
        """Concatenate the text parts of the first candidate ("" when there is none)."""
        if not response.candidates:
            return ""
        parts = response.candidates[0].content.parts or []
        return "".join(getattr(part, "text", None) or "" for part in parts)

    async def chat(
        self,
        message: str,
        system_prompt: str,
        history: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        temperature: float = 0.7,
        mcp_handler: Any = None,
        mcp_context: Any = None,
    ) -> Dict[str, Any]:
        """Chat with Gemini, executing requested function calls through ``mcp_handler``.

        Tools are only sent to the model when ``mcp_handler`` is provided.
        After the initial message, up to MAX_TOOL_ITERATIONS rounds of function
        responses are sent back. Every call in a round is answered.
        Exceptions raised by the SDK or the tool handler propagate to the caller.
        """
        if not self.is_available():
            return {
                "success": False,
                "response": "Gemini not available. Install google-generativeai SDK or set GOOGLE_API_KEY/GEMINI_API_KEY.",
                "model": "error",
                "tools_used": [],
                "stop_reason": "error"
            }

        import google.generativeai as genai

        genai.configure(api_key=self.api_key)

        gemini_history = [
            {"role": "user" if role == "user" else "model", "parts": [content]}
            for role, content in self._recent_turns(history)
        ]

        model_kwargs = {
            "model_name": self.default_model,
            "generation_config": {
                "temperature": temperature,
                "max_output_tokens": 4096,
            },
            "system_instruction": system_prompt
        }
        if tools and mcp_handler is not None:
            model_kwargs["tools"] = self._convert_tools_to_gemini_format(tools)

        model = genai.GenerativeModel(**model_kwargs)
        chat = model.start_chat(history=gemini_history)

        tools_used: List[Dict[str, Any]] = []
        response = await asyncio.to_thread(chat.send_message, message)
        function_calls = self._extract_function_calls(response)
        iteration = 0
        while function_calls and iteration < MAX_TOOL_ITERATIONS:
            iteration += 1
            response_parts = []
            for func_call in function_calls:
                tool_name = func_call.name
                result = await self._execute_tool(
                    mcp_handler, mcp_context,
                    tool_name, dict(func_call.args or {}), tools_used,
                )
                payload = {
                    "name": tool_name,
                    "response": result.result if result.success else {"error": result.error}
                }
                response_parts.append(genai.protos.Part(
                    function_response=genai.protos.FunctionResponse(
                        name=tool_name,
                        response=payload
                    )
                ))
            # Answer all calls of this turn in a single message.
            response = await asyncio.to_thread(
                chat.send_message, genai.protos.Content(parts=response_parts)
            )
            function_calls = self._extract_function_calls(response)

        return {
            "success": True,
            "response": self._extract_text(response),
            "model": self.default_model,
            "tools_used": tools_used,
            "stop_reason": "stop" if response.candidates else "error"
        }


# Factory function
def get_ai_provider(model_name: Optional[str]) -> AIProvider:
    """Return a provider instance for ``model_name``.

    "gemini" or "google" selects GeminiProvider. Matching ignores case and
    surrounding whitespace. Anything else, including None or "", falls back
    to ClaudeProvider.
    """
    name = (model_name or "").strip().lower()
    if name in ("gemini", "google"):
        return GeminiProvider()
    return ClaudeProvider()