""" AI Provider abstraction for Claude and Gemini. Handles model-specific API calls and tool execution """ import asyncio import json import os import subprocess from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional class AIProvider(ABC): """Abstract base class for AI providers.""" @abstractmethod async def chat( self, message: str, system_prompt: str, history: List[Dict[str, Any]], tools: Optional[List[Dict[str, Any]]] = None, temperature: float = 0.7, ) -> Dict[str, Any]: """ Send a chat message and get response Returns: { "success": bool, "response": str, "model": str, "tools_used": List[Dict], "stop_reason": str } """ pass class ClaudeProvider(AIProvider): """Anthropic Claude provider.""" # SoFi LLM Proxy configuration PROXY_BASE_URL = "https://internal.sofitest.com/llm-proxy" API_KEY_HELPER = os.path.expanduser("~/.local/bin/llm-proxy-keys") def __init__(self): self.api_key = os.getenv("ANTHROPIC_API_KEY") self.base_url = os.getenv("ANTHROPIC_BASE_URL", self.PROXY_BASE_URL) self.default_model = "claude-sonnet-4-5-20250929" self._proxy_key = None def _get_proxy_key(self) -> Optional[str]: """Get API key from SoFi LLM proxy helper script""" if self._proxy_key: return self._proxy_key try: if os.path.exists(self.API_KEY_HELPER): result = subprocess.run( [self.API_KEY_HELPER], capture_output=True, text=True, timeout=10 ) if result.returncode == 0: # Extract the key from output (last line with sk- prefix) for line in result.stdout.strip().split('\n'): if line.startswith('sk-'): self._proxy_key = line.strip() return self._proxy_key except Exception as e: print(f"Error getting proxy key: {e}") return None def is_available(self) -> bool: """Check if Claude is available.""" try: from anthropic import Anthropic # Available if SDK is installed (proxy may have keys) return True except ImportError: return False def _create_client(self): """Create Anthropic client configured for SoFi proxy""" from anthropic import Anthropic import httpx # Create httpx client that skips SSL verification (for corporate proxy) http_client = httpx.Client(verify=False) # Get API key: prefer env var, then proxy helper api_key = self.api_key or self._get_proxy_key() if not api_key: raise ValueError("No API key available. Set ANTHROPIC_API_KEY or ensure llm-proxy-keys is installed.") return Anthropic( api_key=api_key, base_url=self.base_url, http_client=http_client ) async def chat( self, message: str, system_prompt: str, history: List[Dict[str, Any]], tools: Optional[List[Dict[str, Any]]] = None, temperature: float = 0.7, mcp_handler=None, mcp_context=None, ) -> Dict[str, Any]: """Chat with Claude.""" if not self.is_available(): return { "success": False, "response": "Claude not available. 
class ClaudeProvider(AIProvider):
    """Anthropic Claude provider."""

    # SoFi LLM Proxy configuration
    PROXY_BASE_URL = "https://internal.sofitest.com/llm-proxy"
    API_KEY_HELPER = os.path.expanduser("~/.local/bin/llm-proxy-keys")

    def __init__(self):
        self.api_key = os.getenv("ANTHROPIC_API_KEY")
        self.base_url = os.getenv("ANTHROPIC_BASE_URL", self.PROXY_BASE_URL)
        self.default_model = "claude-sonnet-4-5-20250929"
        self._proxy_key = None

    def _get_proxy_key(self) -> Optional[str]:
        """Get an API key from the SoFi LLM proxy helper script."""
        if self._proxy_key:
            return self._proxy_key
        try:
            if os.path.exists(self.API_KEY_HELPER):
                result = subprocess.run(
                    [self.API_KEY_HELPER],
                    capture_output=True,
                    text=True,
                    timeout=10,
                )
                if result.returncode == 0:
                    # Extract the key from the output (first line with an sk- prefix)
                    for line in result.stdout.strip().split("\n"):
                        if line.startswith("sk-"):
                            self._proxy_key = line.strip()
                            return self._proxy_key
        except Exception as e:
            print(f"Error getting proxy key: {e}")
        return None

    def is_available(self) -> bool:
        """Check whether Claude is available."""
        try:
            from anthropic import Anthropic  # noqa: F401

            # Available whenever the SDK is installed (the proxy may supply keys)
            return True
        except ImportError:
            return False

    def _create_client(self):
        """Create an Anthropic client configured for the SoFi proxy."""
        import httpx
        from anthropic import Anthropic

        # httpx client that skips SSL verification (needed behind the corporate proxy)
        http_client = httpx.Client(verify=False)

        # Get the API key: prefer the env var, then the proxy helper
        api_key = self.api_key or self._get_proxy_key()
        if not api_key:
            raise ValueError(
                "No API key available. Set ANTHROPIC_API_KEY or ensure "
                "llm-proxy-keys is installed."
            )

        return Anthropic(
            api_key=api_key,
            base_url=self.base_url,
            http_client=http_client,
        )

    async def chat(
        self,
        message: str,
        system_prompt: str,
        history: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        temperature: float = 0.7,
        mcp_handler=None,
        mcp_context=None,
    ) -> Dict[str, Any]:
        """Chat with Claude."""
        if not self.is_available():
            return {
                "success": False,
                "response": "Claude not available. Install the anthropic SDK.",
                "model": "error",
                "tools_used": [],
                "stop_reason": "error",
            }

        # Create a client with SoFi proxy settings
        try:
            client = self._create_client()
        except ValueError as e:
            return {
                "success": False,
                "response": str(e),
                "model": "error",
                "tools_used": [],
                "stop_reason": "error",
            }

        # Build messages from the most recent history turns
        messages = []
        for msg in history[-6:]:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if content and role in ["user", "assistant"]:
                messages.append({"role": role, "content": content})
        messages.append({"role": "user", "content": message})

        # API params
        api_params = {
            "model": self.default_model,
            "max_tokens": 4096,
            "temperature": temperature,
            "system": system_prompt,
            "messages": messages,
        }
        if tools:
            api_params["tools"] = tools

        # Make the API call via the SoFi proxy
        try:
            response = await asyncio.to_thread(client.messages.create, **api_params)
        except Exception as e:
            return {
                "success": False,
                "response": f"Claude API error: {str(e)}",
                "model": "error",
                "tools_used": [],
                "stop_reason": "error",
            }

        # Handle the tool-use loop (requires an MCP handler to execute tools)
        tools_used = []
        max_iterations = 5
        iteration = 0
        while (
            response.stop_reason == "tool_use"
            and mcp_handler is not None
            and iteration < max_iterations
        ):
            iteration += 1
            tool_results = []
            for content_block in response.content:
                if content_block.type == "tool_use":
                    tool_name = content_block.name
                    tool_input = content_block.input
                    tool_use_id = content_block.id

                    # Execute the tool via the MCP handler
                    result = await mcp_handler.execute_tool(
                        tool_name=tool_name,
                        arguments=tool_input,
                        context=mcp_context,
                    )
                    tools_used.append(
                        {
                            "tool": tool_name,
                            "success": result.success,
                            "duration_ms": result.duration_ms,
                        }
                    )

                    # Format the result
                    if result.success:
                        tool_result_content = json.dumps(result.result, indent=2)
                    else:
                        tool_result_content = json.dumps({"error": result.error})
                    tool_results.append(
                        {
                            "type": "tool_result",
                            "tool_use_id": tool_use_id,
                            "content": tool_result_content,
                        }
                    )

            # Continue the conversation with the tool results
            messages.append({"role": "assistant", "content": response.content})
            messages.append({"role": "user", "content": tool_results})
            response = await asyncio.to_thread(
                client.messages.create, **{**api_params, "messages": messages}
            )

        # Extract the final response text
        response_text = ""
        for content_block in response.content:
            if hasattr(content_block, "text"):
                response_text += content_block.text

        return {
            "success": True,
            "response": response_text,
            "model": response.model,
            "tools_used": tools_used,
            "stop_reason": response.stop_reason,
        }
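
# Note on the MCP handler contract: both providers duck-type `mcp_handler`,
# expecting an async `execute_tool(tool_name=..., arguments=..., context=...)`
# that returns an object exposing `success`, `result`, `error`, and
# `duration_ms`. The sketch below is illustrative only (ToolResult and
# StubMCPHandler are assumed names, not this module's real MCP handler), but
# it makes a convenient test double for exercising the tool loops.
from dataclasses import dataclass


@dataclass
class ToolResult:
    """The result shape the providers read back from execute_tool."""

    success: bool
    result: Any = None
    error: Optional[str] = None
    duration_ms: float = 0.0


class StubMCPHandler:
    """Minimal stand-in that returns a canned, successful result."""

    async def execute_tool(
        self, tool_name: str, arguments: Dict[str, Any], context=None
    ) -> ToolResult:
        # Echo the call back so tests can observe the round trip
        return ToolResult(success=True, result={"tool": tool_name, "args": arguments})
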
class GeminiProvider(AIProvider):
    """Google Gemini provider."""

    def __init__(self):
        self.api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
        self.default_model = "gemini-2.0-flash-exp"

    def is_available(self) -> bool:
        """Check whether Gemini is available."""
        try:
            import google.generativeai as genai  # noqa: F401

            return bool(self.api_key)
        except ImportError:
            return False

    def _convert_tools_to_gemini_format(
        self, claude_tools: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Convert Claude tool definitions to Gemini function declarations."""
        gemini_tools = []
        for tool in claude_tools:
            # Convert from Claude's format to Gemini's format
            function_declaration = {
                "name": tool.get("name"),
                "description": tool.get("description", ""),
                "parameters": {"type": "object", "properties": {}, "required": []},
            }

            # Convert the input schema
            if "input_schema" in tool:
                schema = tool["input_schema"]
                if "properties" in schema:
                    function_declaration["parameters"]["properties"] = schema["properties"]
                if "required" in schema:
                    function_declaration["parameters"]["required"] = schema["required"]

            gemini_tools.append(function_declaration)
        return gemini_tools

    async def chat(
        self,
        message: str,
        system_prompt: str,
        history: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        temperature: float = 0.7,
        mcp_handler=None,
        mcp_context=None,
    ) -> Dict[str, Any]:
        """Chat with Gemini."""
        if not self.is_available():
            return {
                "success": False,
                "response": (
                    "Gemini not available. Install the google-generativeai SDK "
                    "and set GOOGLE_API_KEY or GEMINI_API_KEY."
                ),
                "model": "error",
                "tools_used": [],
                "stop_reason": "error",
            }

        import google.generativeai as genai

        genai.configure(api_key=self.api_key)

        # Build the chat history from the most recent turns
        gemini_history = []
        for msg in history[-6:]:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if content and role in ["user", "assistant"]:
                gemini_history.append(
                    {"role": "user" if role == "user" else "model", "parts": [content]}
                )

        # Create the model, attaching tools when a handler can execute them
        model_kwargs = {
            "model_name": self.default_model,
            "generation_config": {
                "temperature": temperature,
                "max_output_tokens": 4096,
            },
            "system_instruction": system_prompt,
        }
        if tools and mcp_handler:
            gemini_tools = self._convert_tools_to_gemini_format(tools)
            model_kwargs["tools"] = gemini_tools

        model = genai.GenerativeModel(**model_kwargs)

        # Start the chat
        chat = model.start_chat(history=gemini_history)

        # Send the message, executing tool calls in a bounded loop
        tools_used = []
        max_iterations = 5
        iteration = 0
        current_message = message

        while iteration < max_iterations:
            iteration += 1
            response = await asyncio.to_thread(chat.send_message, current_message)

            # Check for function calls
            if response.candidates and response.candidates[0].content.parts:
                has_function_call = False
                for part in response.candidates[0].content.parts:
                    if hasattr(part, "function_call") and part.function_call:
                        has_function_call = True
                        func_call = part.function_call
                        tool_name = func_call.name
                        tool_args = dict(func_call.args)

                        # Execute the tool
                        result = await mcp_handler.execute_tool(
                            tool_name=tool_name,
                            arguments=tool_args,
                            context=mcp_context,
                        )
                        tools_used.append(
                            {
                                "tool": tool_name,
                                "success": result.success,
                                "duration_ms": result.duration_ms,
                            }
                        )

                        # Format the result for Gemini
                        function_response = {
                            "name": tool_name,
                            "response": (
                                result.result
                                if result.success
                                else {"error": result.error}
                            ),
                        }

                        # Send the function response back as the next message
                        current_message = genai.protos.Content(
                            parts=[
                                genai.protos.Part(
                                    function_response=genai.protos.FunctionResponse(
                                        name=tool_name, response=function_response
                                    )
                                )
                            ]
                        )
                        break

                # If there was no function call, we're done
                if not has_function_call:
                    break
            else:
                break

        # Extract the final response text
        response_text = ""
        if response.candidates and response.candidates[0].content.parts:
            for part in response.candidates[0].content.parts:
                if hasattr(part, "text"):
                    response_text += part.text

        return {
            "success": True,
            "response": response_text,
            "model": self.default_model,
            "tools_used": tools_used,
            "stop_reason": "stop" if response.candidates else "error",
        }


# Factory function
def get_ai_provider(model_name: str) -> AIProvider:
    """Get an AI provider by name."""
    if model_name.lower() in ["gemini", "google"]:
        return GeminiProvider()
    else:
        return ClaudeProvider()
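
# A minimal usage sketch, assuming ANTHROPIC_API_KEY (or the llm-proxy-keys
# helper) is configured; the prompt text here is made up for illustration.
# Running this module directly performs a plain chat with no tools attached.
if __name__ == "__main__":

    async def _demo() -> None:
        provider = get_ai_provider("claude")
        result = await provider.chat(
            message="Say hello in one short sentence.",
            system_prompt="You are a concise assistant.",
            history=[],
        )
        print(result["model"], result["stop_reason"])
        print(result["response"])

    asyncio.run(_demo())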