"""
|
|
AI Provider abstraction for Claude and Gemini.
|
|
|
|
Handles model-specific API calls and tool execution
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import os
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
class AIProvider(ABC):
    """Abstract base class for AI providers."""

    @abstractmethod
    async def chat(
        self,
        message: str,
        system_prompt: str,
        history: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        temperature: float = 0.7,
        mcp_handler=None,
        mcp_context=None,
    ) -> Dict[str, Any]:
        """
        Send a chat message and return the response.

        When tools are supplied, mcp_handler is used to execute any tool
        calls the model makes (with mcp_context passed through).

        Returns: {
            "success": bool,
            "response": str,
            "model": str,
            "tools_used": List[Dict],
            "stop_reason": str
        }
        """
        pass

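# Illustrative chat() return value (hypothetical tool name and timing; the
# shape follows the docstring above):
#   {
#       "success": True,
#       "response": "The file contains ...",
#       "model": "claude-sonnet-4-5-20250929",
#       "tools_used": [{"tool": "read_file", "success": True, "duration_ms": 12}],
#       "stop_reason": "end_turn",
#   }

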
class ClaudeProvider(AIProvider):
    """Anthropic Claude provider."""

    def __init__(self):
        self.api_key = os.getenv("ANTHROPIC_API_KEY")
        self.default_model = "claude-sonnet-4-5-20250929"

    def is_available(self) -> bool:
        """Check that the anthropic SDK is installed and an API key is set."""
        try:
            # Imported only to verify the SDK is installed.
            from anthropic import Anthropic  # noqa: F401

            return bool(self.api_key)
        except ImportError:
            return False

    async def chat(
        self,
        message: str,
        system_prompt: str,
        history: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        temperature: float = 0.7,
        mcp_handler=None,
        mcp_context=None,
    ) -> Dict[str, Any]:
        """Chat with Claude."""

        if not self.is_available():
            return {
                "success": False,
                "response": "Claude not available. Install the anthropic SDK and set ANTHROPIC_API_KEY.",
                "model": "error",
                "tools_used": [],
                "stop_reason": "error",
            }

        from anthropic import Anthropic

        client = Anthropic(api_key=self.api_key)

        # Build messages from the most recent turns of history
        messages = []
        for msg in history[-6:]:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if content and role in ["user", "assistant"]:
                messages.append({"role": role, "content": content})

        messages.append({"role": "user", "content": message})

        # API params
        api_params = {
            "model": self.default_model,
            "max_tokens": 4096,
            "temperature": temperature,
            "system": system_prompt,
            "messages": messages,
        }

        # Only advertise tools when a handler is available to execute them
        if tools and mcp_handler:
            api_params["tools"] = tools

        # Initial call
        response = await asyncio.to_thread(client.messages.create, **api_params)

        # Handle tool use loop
        tools_used = []
        max_iterations = 5
        iteration = 0

        while response.stop_reason == "tool_use" and iteration < max_iterations:
            iteration += 1

            tool_results = []
            for content_block in response.content:
                if content_block.type == "tool_use":
                    tool_name = content_block.name
                    tool_input = content_block.input
                    tool_use_id = content_block.id

                    # Execute tool via MCP handler
                    result = await mcp_handler.execute_tool(
                        tool_name=tool_name, arguments=tool_input, context=mcp_context
                    )

                    tools_used.append(
                        {
                            "tool": tool_name,
                            "success": result.success,
                            "duration_ms": result.duration_ms,
                        }
                    )

                    # Format result (default=str guards against values that
                    # are not JSON-serializable)
                    if result.success:
                        tool_result_content = json.dumps(
                            result.result, indent=2, default=str
                        )
                    else:
                        tool_result_content = json.dumps({"error": result.error})

                    tool_results.append(
                        {
                            "type": "tool_result",
                            "tool_use_id": tool_use_id,
                            "content": tool_result_content,
                        }
                    )

            # Continue the conversation with the tool results
            messages.append({"role": "assistant", "content": response.content})
            messages.append({"role": "user", "content": tool_results})

            response = await asyncio.to_thread(
                client.messages.create, **{**api_params, "messages": messages}
            )

        # Extract final response
        response_text = ""
        for content_block in response.content:
            if hasattr(content_block, "text"):
                response_text += content_block.text

        return {
            "success": True,
            "response": response_text,
            "model": response.model,
            "tools_used": tools_used,
            "stop_reason": response.stop_reason,
        }


class GeminiProvider(AIProvider):
    """Google Gemini provider."""

    def __init__(self):
        self.api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
        self.default_model = "gemini-2.0-flash-exp"

    def is_available(self) -> bool:
        """Check that the google-generativeai SDK is installed and an API key is set."""
        try:
            # Imported only to verify the SDK is installed.
            import google.generativeai as genai  # noqa: F401

            return bool(self.api_key)
        except ImportError:
            return False

    def _convert_tools_to_gemini_format(
        self, claude_tools: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Convert Claude tool definitions to Gemini function declarations."""
        gemini_tools = []

        for tool in claude_tools:
            # Convert from Claude's format to Gemini's format:
            # Gemini uses "parameters" where Claude uses "input_schema"
            function_declaration = {
                "name": tool.get("name"),
                "description": tool.get("description", ""),
                "parameters": {"type": "object", "properties": {}, "required": []},
            }

            # Convert input schema
            if "input_schema" in tool:
                schema = tool["input_schema"]
                if "properties" in schema:
                    function_declaration["parameters"]["properties"] = schema["properties"]
                if "required" in schema:
                    function_declaration["parameters"]["required"] = schema["required"]

            gemini_tools.append(function_declaration)

        return gemini_tools

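    # Example of the conversion above, using a hypothetical "read_file" tool:
    #
    #   Claude format:
    #     {"name": "read_file",
    #      "description": "Read a file",
    #      "input_schema": {"type": "object",
    #                       "properties": {"path": {"type": "string"}},
    #                       "required": ["path"]}}
    #
    #   Gemini function declaration:
    #     {"name": "read_file",
    #      "description": "Read a file",
    #      "parameters": {"type": "object",
    #                     "properties": {"path": {"type": "string"}},
    #                     "required": ["path"]}}
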
    async def chat(
        self,
        message: str,
        system_prompt: str,
        history: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        temperature: float = 0.7,
        mcp_handler=None,
        mcp_context=None,
    ) -> Dict[str, Any]:
        """Chat with Gemini."""

        if not self.is_available():
            return {
                "success": False,
                "response": "Gemini not available. Install the google-generativeai SDK and set GOOGLE_API_KEY or GEMINI_API_KEY.",
                "model": "error",
                "tools_used": [],
                "stop_reason": "error",
            }

        import google.generativeai as genai

        genai.configure(api_key=self.api_key)

        # Build chat history (Gemini uses "model" where Claude uses "assistant")
        gemini_history = []
        for msg in history[-6:]:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if content and role in ["user", "assistant"]:
                gemini_history.append(
                    {"role": "user" if role == "user" else "model", "parts": [content]}
                )

        # Create model with tools if available
        model_kwargs = {
            "model_name": self.default_model,
            "generation_config": {
                "temperature": temperature,
                "max_output_tokens": 4096,
            },
            "system_instruction": system_prompt,
        }

        # Convert and attach tools only when a handler can execute them
        if tools and mcp_handler:
            gemini_tools = self._convert_tools_to_gemini_format(tools)
            model_kwargs["tools"] = gemini_tools

        model = genai.GenerativeModel(**model_kwargs)

        # Start chat
        chat = model.start_chat(history=gemini_history)

        # Send message with tool execution loop
        tools_used = []
        max_iterations = 5
        iteration = 0
        current_message = message

        while iteration < max_iterations:
            iteration += 1

            response = await asyncio.to_thread(chat.send_message, current_message)

            # Check for function calls
            if response.candidates and response.candidates[0].content.parts:
                has_function_call = False

                for part in response.candidates[0].content.parts:
                    if hasattr(part, "function_call") and part.function_call:
                        has_function_call = True
                        func_call = part.function_call
                        tool_name = func_call.name
                        tool_args = dict(func_call.args)

                        # Execute tool via MCP handler
                        result = await mcp_handler.execute_tool(
                            tool_name=tool_name, arguments=tool_args, context=mcp_context
                        )

                        tools_used.append(
                            {
                                "tool": tool_name,
                                "success": result.success,
                                "duration_ms": result.duration_ms,
                            }
                        )

                        # Format result for Gemini
                        function_response = {
                            "name": tool_name,
                            "response": result.result
                            if result.success
                            else {"error": result.error},
                        }

                        # Send the function response back as the next message
                        current_message = genai.protos.Content(
                            parts=[
                                genai.protos.Part(
                                    function_response=genai.protos.FunctionResponse(
                                        name=tool_name, response=function_response
                                    )
                                )
                            ]
                        )
                        break

                # If no function call, we're done
                if not has_function_call:
                    break
            else:
                break

        # Extract final response text
        response_text = ""
        if response.candidates and response.candidates[0].content.parts:
            for part in response.candidates[0].content.parts:
                if hasattr(part, "text"):
                    response_text += part.text

        return {
            "success": True,
            "response": response_text,
            "model": self.default_model,
            "tools_used": tools_used,
            "stop_reason": "stop" if response.candidates else "error",
        }


# Factory function
def get_ai_provider(model_name: str) -> AIProvider:
    """Get an AI provider by name; defaults to Claude."""
    if model_name.lower() in ["gemini", "google"]:
        return GeminiProvider()
    else:
        return ClaudeProvider()
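

if __name__ == "__main__":
    # Minimal usage sketch. Assumes ANTHROPIC_API_KEY is set; without it,
    # chat() returns the "not available" error payload instead of raising,
    # so this is safe to run either way. Tool execution (mcp_handler wiring)
    # is not shown here.
    async def _demo() -> None:
        provider = get_ai_provider("claude")
        result = await provider.chat(
            message="Hello!",
            system_prompt="You are a helpful assistant.",
            history=[],
        )
        print(result["model"], result["stop_reason"])
        print(result["response"])

    asyncio.run(_demo())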