# dss/apps/api/ai_providers.py
"""
AI Provider abstraction for Claude and Gemini.
Handles model-specific API calls and tool execution
"""
import asyncio
import json
import os
import subprocess
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional


class AIProvider(ABC):
    """Abstract base class for AI providers."""

    @abstractmethod
    async def chat(
        self,
        message: str,
        system_prompt: str,
        history: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        temperature: float = 0.7,
    ) -> Dict[str, Any]:
        """
        Send a chat message and get a response.

        Returns: {
            "success": bool,
            "response": str,
            "model": str,
            "tools_used": List[Dict],
            "stop_reason": str
        }
        """
        pass
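

# Note: the concrete providers below also implement is_available(); it is not
# declared on the ABC, so new providers should follow the same convention
# rather than relying on the abstract interface alone.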


class ClaudeProvider(AIProvider):
    """Anthropic Claude provider."""

    # SoFi LLM Proxy configuration
    PROXY_BASE_URL = "https://internal.sofitest.com/llm-proxy"
    API_KEY_HELPER = os.path.expanduser("~/.local/bin/llm-proxy-keys")

    def __init__(self):
        self.api_key = os.getenv("ANTHROPIC_API_KEY")
        self.base_url = os.getenv("ANTHROPIC_BASE_URL", self.PROXY_BASE_URL)
        self.default_model = "claude-sonnet-4-5-20250929"
        self._proxy_key = None

    def _get_proxy_key(self) -> Optional[str]:
        """Get the API key from the SoFi LLM proxy helper script."""
        if self._proxy_key:
            return self._proxy_key
        try:
            if os.path.exists(self.API_KEY_HELPER):
                result = subprocess.run(
                    [self.API_KEY_HELPER],
                    capture_output=True,
                    text=True,
                    timeout=10,
                )
                if result.returncode == 0:
                    # Extract the key: the first output line with an 'sk-' prefix
                    for line in result.stdout.strip().split('\n'):
                        if line.startswith('sk-'):
                            self._proxy_key = line.strip()
                            return self._proxy_key
        except Exception as e:
            print(f"Error getting proxy key: {e}")
        return None

    def is_available(self) -> bool:
        """Check if Claude is available."""
        try:
            from anthropic import Anthropic  # noqa: F401
            # Available if the SDK is installed (the proxy may supply keys)
            return True
        except ImportError:
            return False

    def _create_client(self):
        """Create an Anthropic client configured for the SoFi proxy."""
        from anthropic import Anthropic
        import httpx

        # Create an httpx client that skips SSL verification (for the
        # corporate proxy)
        http_client = httpx.Client(verify=False)

        # Get the API key: prefer the env var, then the proxy helper
        api_key = self.api_key or self._get_proxy_key()
        if not api_key:
            raise ValueError(
                "No API key available. Set ANTHROPIC_API_KEY or ensure "
                "llm-proxy-keys is installed."
            )
        return Anthropic(
            api_key=api_key,
            base_url=self.base_url,
            http_client=http_client,
        )
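
    # Configuration note: the client prefers ANTHROPIC_API_KEY and falls back
    # to the llm-proxy-keys helper; ANTHROPIC_BASE_URL overrides the default
    # PROXY_BASE_URL. An illustrative override (placeholder values):
    #   export ANTHROPIC_API_KEY="sk-..."
    #   export ANTHROPIC_BASE_URL="https://internal.sofitest.com/llm-proxy"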

    async def chat(
        self,
        message: str,
        system_prompt: str,
        history: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        temperature: float = 0.7,
        mcp_handler=None,
        mcp_context=None,
    ) -> Dict[str, Any]:
        """Chat with Claude."""
        if not self.is_available():
            return {
                "success": False,
                "response": "Claude not available. Install the anthropic SDK.",
                "model": "error",
                "tools_used": [],
                "stop_reason": "error",
            }

        # Create a client with the SoFi proxy settings
        try:
            client = self._create_client()
        except ValueError as e:
            return {
                "success": False,
                "response": str(e),
                "model": "error",
                "tools_used": [],
                "stop_reason": "error",
            }

        # Build messages from the most recent history turns
        messages = []
        for msg in history[-6:]:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if content and role in ["user", "assistant"]:
                messages.append({"role": role, "content": content})
        messages.append({"role": "user", "content": message})

        # API params
        api_params = {
            "model": self.default_model,
            "max_tokens": 4096,
            "temperature": temperature,
            "system": system_prompt,
            "messages": messages,
        }
        # Only advertise tools when a handler is available to execute them
        if tools and mcp_handler:
            api_params["tools"] = tools

        # Make the API call via the SoFi proxy
        try:
            response = await asyncio.to_thread(
                client.messages.create,
                **api_params,
            )
        except Exception as e:
            return {
                "success": False,
                "response": f"Claude API error: {str(e)}",
                "model": "error",
                "tools_used": [],
                "stop_reason": "error",
            }

        # Handle the tool-use loop
        tools_used = []
        max_iterations = 5
        iteration = 0
        while response.stop_reason == "tool_use" and iteration < max_iterations:
            iteration += 1
            tool_results = []
            for content_block in response.content:
                if content_block.type == "tool_use":
                    tool_name = content_block.name
                    tool_input = content_block.input
                    tool_use_id = content_block.id

                    # Execute the tool via the MCP handler
                    result = await mcp_handler.execute_tool(
                        tool_name=tool_name, arguments=tool_input, context=mcp_context
                    )
                    tools_used.append(
                        {
                            "tool": tool_name,
                            "success": result.success,
                            "duration_ms": result.duration_ms,
                        }
                    )

                    # Format the result for the follow-up message
                    if result.success:
                        tool_result_content = json.dumps(result.result, indent=2)
                    else:
                        tool_result_content = json.dumps({"error": result.error})
                    tool_results.append(
                        {
                            "type": "tool_result",
                            "tool_use_id": tool_use_id,
                            "content": tool_result_content,
                        }
                    )

            # Continue the conversation with the tool results
            messages.append({"role": "assistant", "content": response.content})
            messages.append({"role": "user", "content": tool_results})
            response = await asyncio.to_thread(
                client.messages.create, **{**api_params, "messages": messages}
            )

        # Extract the final response text
        response_text = ""
        for content_block in response.content:
            if hasattr(content_block, "text"):
                response_text += content_block.text

        return {
            "success": True,
            "response": response_text,
            "model": response.model,
            "tools_used": tools_used,
            "stop_reason": response.stop_reason,
        }


class GeminiProvider(AIProvider):
    """Google Gemini provider."""

    def __init__(self):
        self.api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
        self.default_model = "gemini-2.0-flash-exp"

    def is_available(self) -> bool:
        """Check if Gemini is available."""
        try:
            import google.generativeai as genai  # noqa: F401
            return bool(self.api_key)
        except ImportError:
            return False

    def _convert_tools_to_gemini_format(
        self, claude_tools: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Convert Claude tool definitions to Gemini function declarations."""
        gemini_tools = []
        for tool in claude_tools:
            # Convert from Claude's format to Gemini's format
            function_declaration = {
                "name": tool.get("name"),
                "description": tool.get("description", ""),
                "parameters": {"type": "object", "properties": {}, "required": []},
            }
            # Copy over the input schema, if present
            if "input_schema" in tool:
                schema = tool["input_schema"]
                if "properties" in schema:
                    function_declaration["parameters"]["properties"] = schema["properties"]
                if "required" in schema:
                    function_declaration["parameters"]["required"] = schema["required"]
            gemini_tools.append(function_declaration)
        return gemini_tools
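
    # Illustrative example of the mapping above (the tool name and schema are
    # hypothetical, not taken from this codebase). A Claude tool definition:
    #     {"name": "get_weather", "description": "Look up current weather",
    #      "input_schema": {"type": "object",
    #                       "properties": {"city": {"type": "string"}},
    #                       "required": ["city"]}}
    # becomes the Gemini function declaration:
    #     {"name": "get_weather", "description": "Look up current weather",
    #      "parameters": {"type": "object",
    #                     "properties": {"city": {"type": "string"}},
    #                     "required": ["city"]}}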

    async def chat(
        self,
        message: str,
        system_prompt: str,
        history: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        temperature: float = 0.7,
        mcp_handler=None,
        mcp_context=None,
    ) -> Dict[str, Any]:
        """Chat with Gemini."""
        if not self.is_available():
            return {
                "success": False,
                "response": (
                    "Gemini not available. Install the google-generativeai SDK "
                    "or set GOOGLE_API_KEY/GEMINI_API_KEY."
                ),
                "model": "error",
                "tools_used": [],
                "stop_reason": "error",
            }

        import google.generativeai as genai

        genai.configure(api_key=self.api_key)

        # Build chat history in Gemini's role/parts format
        gemini_history = []
        for msg in history[-6:]:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if content and role in ["user", "assistant"]:
                gemini_history.append(
                    {"role": "user" if role == "user" else "model", "parts": [content]}
                )

        # Create the model, attaching tools when available
        model_kwargs = {
            "model_name": self.default_model,
            "generation_config": {
                "temperature": temperature,
                "max_output_tokens": 4096,
            },
            "system_instruction": system_prompt,
        }
        # Convert and add tools if a handler is available to execute them
        if tools and mcp_handler:
            gemini_tools = self._convert_tools_to_gemini_format(tools)
            model_kwargs["tools"] = gemini_tools
        model = genai.GenerativeModel(**model_kwargs)

        # Start the chat session
        chat = model.start_chat(history=gemini_history)

        # Send the message, executing function calls in a bounded loop
        tools_used = []
        max_iterations = 5
        iteration = 0
        current_message = message
        while iteration < max_iterations:
            iteration += 1
            response = await asyncio.to_thread(chat.send_message, current_message)

            # Check for function calls
            if response.candidates and response.candidates[0].content.parts:
                has_function_call = False
                for part in response.candidates[0].content.parts:
                    if hasattr(part, "function_call") and part.function_call:
                        has_function_call = True
                        func_call = part.function_call
                        tool_name = func_call.name
                        tool_args = dict(func_call.args)

                        # Execute the tool via the MCP handler
                        result = await mcp_handler.execute_tool(
                            tool_name=tool_name, arguments=tool_args, context=mcp_context
                        )
                        tools_used.append(
                            {
                                "tool": tool_name,
                                "success": result.success,
                                "duration_ms": result.duration_ms,
                            }
                        )

                        # Format the result for Gemini
                        function_response = {
                            "name": tool_name,
                            "response": result.result
                            if result.success
                            else {"error": result.error},
                        }
                        # Send the function response back as the next message
                        current_message = genai.protos.Content(
                            parts=[
                                genai.protos.Part(
                                    function_response=genai.protos.FunctionResponse(
                                        name=tool_name, response=function_response
                                    )
                                )
                            ]
                        )
                        # Handle one function call per model turn
                        break
                # If there was no function call, we're done
                if not has_function_call:
                    break
            else:
                break

        # Extract the final response text
        response_text = ""
        if response.candidates and response.candidates[0].content.parts:
            for part in response.candidates[0].content.parts:
                if hasattr(part, "text"):
                    response_text += part.text

        return {
            "success": True,
            "response": response_text,
            "model": self.default_model,
            "tools_used": tools_used,
            "stop_reason": "stop" if response.candidates else "error",
        }


# Factory function
def get_ai_provider(model_name: str) -> AIProvider:
    """Get an AI provider by name."""
    if model_name.lower() in ["gemini", "google"]:
        return GeminiProvider()
    return ClaudeProvider()
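

# A minimal usage sketch (illustrative only): assumes ANTHROPIC_API_KEY or the
# llm-proxy-keys helper is configured, and exercises the chat path without any
# MCP tools wired up.
if __name__ == "__main__":
    async def _demo():
        provider = get_ai_provider("claude")
        result = await provider.chat(
            message="Hello!",
            system_prompt="You are a helpful assistant.",
            history=[],
        )
        print(result["model"], result["stop_reason"])
        print(result["response"])

    asyncio.run(_demo())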