Files
dss/apps/api/ai_providers.py
2025-12-11 07:13:06 -03:00

351 lines
11 KiB
Python

"""
AI Provider abstraction for Claude and Gemini.
Handles model-specific API calls and tool execution.
"""
import asyncio
import json
import os
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from typing import Any, Dict, List, Optional
class AIProvider(ABC):
    """Abstract base class for AI providers.

    A concrete provider wraps one vendor SDK behind a single async ``chat``
    coroutine that returns a normalized result dictionary, so callers can swap
    providers without changing how they consume the response.
    """

    def is_available(self) -> bool:
        """Return whether this provider can currently be used.

        The base implementation always returns ``True``; subclasses override it
        to report a missing SDK or missing credentials.
        """
        return True

    @abstractmethod
    async def chat(
        self,
        message: str,
        system_prompt: str,
        history: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        temperature: float = 0.7,
        mcp_handler=None,
        mcp_context=None,
    ) -> Dict[str, Any]:
        """Send a chat message and return the model's reply.

        Args:
            message: The new user message.
            system_prompt: System instructions for the model.
            history: Prior turns as dicts with ``"role"`` and ``"content"`` keys.
            tools: Optional tool definitions (``name``, ``description``,
                ``input_schema``) offered to the model.
            temperature: Sampling temperature.
            mcp_handler: Optional object whose ``execute_tool`` coroutine runs
                the tool calls the model requests.
            mcp_context: Opaque value forwarded as ``context`` to
                ``mcp_handler.execute_tool``.

        Returns:
            A dict with keys ``"success"`` (bool), ``"response"`` (str),
            ``"model"`` (str), ``"tools_used"`` (list of dicts) and
            ``"stop_reason"`` (str).
        """
class ClaudeProvider(AIProvider):
    """Anthropic Claude provider.

    The API key is read from ``ANTHROPIC_API_KEY`` when the provider is
    constructed. Requests go through the synchronous ``anthropic`` client and
    are executed with ``asyncio.to_thread`` so the event loop is not blocked.
    """

    def __init__(self):
        self.api_key = os.getenv("ANTHROPIC_API_KEY")
        self.default_model = "claude-sonnet-4-5-20250929"

    def is_available(self) -> bool:
        """Return True if the ``anthropic`` SDK is installed and an API key is set."""
        try:
            from anthropic import Anthropic  # noqa: F401 -- import probe only
        except ImportError:
            return False
        return bool(self.api_key)

    async def chat(
        self,
        message: str,
        system_prompt: str,
        history: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        temperature: float = 0.7,
        mcp_handler=None,
        mcp_context=None,
    ) -> Dict[str, Any]:
        """Chat with Claude, running any requested tools through ``mcp_handler``.

        Only the last six ``history`` entries are considered, and entries with
        empty content or a role other than ``user``/``assistant`` are dropped.
        Tools are offered to the model only when ``mcp_handler`` is given,
        because nothing could execute a ``tool_use`` request otherwise. At most
        five tool rounds are run. If that limit is hit, the last response is
        returned as-is, and its ``stop_reason`` is still ``"tool_use"``.

        Returns:
            A dict with ``"success"``, ``"response"`` (concatenated text blocks
            of the final response), ``"model"``, ``"tools_used"`` (one entry
            per executed tool call) and ``"stop_reason"``. When the provider
            is unavailable, ``"success"`` is False and ``"response"`` holds
            an explanatory message.
        """
        if not self.is_available():
            return {
                "success": False,
                "response": "Claude not available. Install anthropic SDK or set ANTHROPIC_API_KEY.",
                "model": "error",
                "tools_used": [],
                "stop_reason": "error",
            }

        from anthropic import Anthropic

        client = Anthropic(api_key=self.api_key)

        # Keep only recent, non-empty user/assistant turns, then the new message.
        messages = []
        for msg in history[-6:]:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if content and role in ["user", "assistant"]:
                messages.append({"role": role, "content": content})
        messages.append({"role": "user", "content": message})

        api_params = {
            "model": self.default_model,
            "max_tokens": 4096,
            "temperature": temperature,
            "system": system_prompt,
            "messages": messages,
        }
        # Without a handler, a tool_use reply would crash on
        # ``None.execute_tool``. GeminiProvider also drops tools in that case.
        if tools and mcp_handler is not None:
            api_params["tools"] = tools

        response = await asyncio.to_thread(client.messages.create, **api_params)

        tools_used = []
        max_iterations = 5
        iteration = 0
        while response.stop_reason == "tool_use" and iteration < max_iterations:
            iteration += 1
            tool_results = []
            for content_block in response.content:
                if content_block.type != "tool_use":
                    continue

                result = await mcp_handler.execute_tool(
                    tool_name=content_block.name,
                    arguments=content_block.input,
                    context=mcp_context,
                )
                tools_used.append(
                    {
                        "tool": content_block.name,
                        "success": result.success,
                        "duration_ms": result.duration_ms,
                    }
                )

                if result.success:
                    tool_result_content = json.dumps(result.result, indent=2)
                else:
                    tool_result_content = json.dumps({"error": result.error})
                tool_result = {
                    "type": "tool_result",
                    "tool_use_id": content_block.id,
                    "content": tool_result_content,
                }
                if not result.success:
                    # Lets Claude tell a failed call apart from a normal result.
                    tool_result["is_error"] = True
                tool_results.append(tool_result)

            # The assistant turn, including its tool_use blocks, is echoed
            # back before the user turn carrying the matching tool_result blocks.
            messages.append({"role": "assistant", "content": response.content})
            messages.append({"role": "user", "content": tool_results})
            response = await asyncio.to_thread(
                client.messages.create, **{**api_params, "messages": messages}
            )

        response_text = "".join(
            content_block.text
            for content_block in response.content
            if hasattr(content_block, "text")
        )

        return {
            "success": True,
            "response": response_text,
            "model": response.model,
            "tools_used": tools_used,
            "stop_reason": response.stop_reason,
        }
class GeminiProvider(AIProvider):
    """Google Gemini provider.

    The API key is read from ``GOOGLE_API_KEY``, falling back to
    ``GEMINI_API_KEY``, when the provider is constructed. Requests go through
    the ``google.generativeai`` SDK, and blocking calls are executed with
    ``asyncio.to_thread``.
    """

    def __init__(self):
        self.api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
        self.default_model = "gemini-2.0-flash-exp"

    def is_available(self) -> bool:
        """Return True if ``google.generativeai`` is installed and an API key is set."""
        try:
            import google.generativeai as genai  # noqa: F401 -- import probe only
        except ImportError:
            return False
        return bool(self.api_key)

    def _convert_tools_to_gemini_format(
        self, claude_tools: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Convert Claude-style tool definitions to Gemini function declarations.

        Each tool's ``input_schema.properties`` and ``input_schema.required``
        become the declaration's OBJECT ``parameters``. When a tool has no
        properties, ``parameters`` is omitted entirely, because the Gemini API
        rejects an OBJECT schema whose ``properties`` is empty.
        """
        gemini_tools = []
        for tool in claude_tools:
            declaration = {
                "name": tool.get("name"),
                "description": tool.get("description", ""),
            }
            schema = tool.get("input_schema") or {}
            properties = schema.get("properties")
            if properties:
                declaration["parameters"] = {
                    "type": "object",
                    "properties": properties,
                    "required": schema.get("required", []),
                }
            gemini_tools.append(declaration)
        return gemini_tools

    @staticmethod
    def _to_plain_python(value: Any) -> Any:
        """Recursively turn mapping/sequence wrappers into plain dicts and lists.

        The SDK exposes nested function-call arguments as proto container
        wrappers, which are not JSON-serializable as-is. Strings and bytes are
        returned unchanged, as is any other non-container value.
        """
        if isinstance(value, Mapping):
            return {
                key: GeminiProvider._to_plain_python(item)
                for key, item in value.items()
            }
        if isinstance(value, Sequence) and not isinstance(
            value, (str, bytes, bytearray)
        ):
            return [GeminiProvider._to_plain_python(item) for item in value]
        return value

    @staticmethod
    def _function_calls(response) -> List[Any]:
        """Return every ``function_call`` in the first candidate of *response*."""
        if not response.candidates:
            return []
        return [
            part.function_call
            for part in response.candidates[0].content.parts
            if getattr(part, "function_call", None)
        ]

    async def chat(
        self,
        message: str,
        system_prompt: str,
        history: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        temperature: float = 0.7,
        mcp_handler=None,
        mcp_context=None,
    ) -> Dict[str, Any]:
        """Chat with Gemini, running any requested tools through ``mcp_handler``.

        Only the last six ``history`` entries are considered. Entries with
        empty content or other roles are dropped, and ``assistant`` is mapped
        to Gemini's ``model`` role. Tools are offered only when both ``tools``
        and ``mcp_handler`` are given. At most five tool rounds are run, as in
        ``ClaudeProvider``. In each round, every function call in the latest
        response is executed, and all results are sent back in a single
        message, one ``function_response`` part per call.

        Returns:
            A dict with ``"success"``, ``"response"`` (concatenated text
            parts of the final response), ``"model"`` (the configured model
            name), ``"tools_used"`` and ``"stop_reason"``. When the provider
            is unavailable, ``"success"`` is False and ``"response"`` holds
            an explanatory message.
        """
        if not self.is_available():
            return {
                "success": False,
                "response": "Gemini not available. Install google-generativeai SDK or set GOOGLE_API_KEY/GEMINI_API_KEY.",
                "model": "error",
                "tools_used": [],
                "stop_reason": "error",
            }

        import google.generativeai as genai

        genai.configure(api_key=self.api_key)

        gemini_history = []
        for msg in history[-6:]:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if content and role in ["user", "assistant"]:
                gemini_history.append(
                    {"role": "user" if role == "user" else "model", "parts": [content]}
                )

        model_kwargs = {
            "model_name": self.default_model,
            "generation_config": {
                "temperature": temperature,
                "max_output_tokens": 4096,
            },
            "system_instruction": system_prompt,
        }
        tools_enabled = bool(tools) and mcp_handler is not None
        if tools_enabled:
            model_kwargs["tools"] = self._convert_tools_to_gemini_format(tools)

        model = genai.GenerativeModel(**model_kwargs)
        chat = model.start_chat(history=gemini_history)

        response = await asyncio.to_thread(chat.send_message, message)

        tools_used = []
        max_iterations = 5
        iteration = 0
        while tools_enabled and iteration < max_iterations:
            function_calls = self._function_calls(response)
            if not function_calls:
                break
            iteration += 1

            response_parts = []
            for func_call in function_calls:
                tool_name = func_call.name
                tool_args = self._to_plain_python(func_call.args or {})
                result = await mcp_handler.execute_tool(
                    tool_name=tool_name, arguments=tool_args, context=mcp_context
                )
                tools_used.append(
                    {
                        "tool": tool_name,
                        "success": result.success,
                        "duration_ms": result.duration_ms,
                    }
                )
                function_response = {
                    "name": tool_name,
                    "response": result.result
                    if result.success
                    else {"error": result.error},
                }
                response_parts.append(
                    genai.protos.Part(
                        function_response=genai.protos.FunctionResponse(
                            name=tool_name, response=function_response
                        )
                    )
                )

            # Gemini expects as many function_response parts as there were
            # function_call parts in the preceding model turn.
            response = await asyncio.to_thread(
                chat.send_message, genai.protos.Content(parts=response_parts)
            )

        response_text = ""
        if response.candidates:
            response_text = "".join(
                part.text
                for part in response.candidates[0].content.parts
                if hasattr(part, "text")
            )

        return {
            "success": True,
            "response": response_text,
            "model": self.default_model,
            "tools_used": tools_used,
            "stop_reason": "stop" if response.candidates else "error",
        }
# Factory function
def get_ai_provider(model_name: str) -> AIProvider:
    """Build the provider that matches *model_name*.

    ``"gemini"`` or ``"google"`` (compared case-insensitively) select
    ``GeminiProvider``; every other name falls back to ``ClaudeProvider``.
    """
    wants_gemini = model_name.lower() in ("gemini", "google")
    provider_cls = GeminiProvider if wants_gemini else ClaudeProvider
    return provider_cls()