fix: Address high-severity bandit issues
This commit is contained in:
@@ -1,17 +1,18 @@
|
||||
"""
|
||||
AI Provider abstraction for Claude and Gemini
|
||||
AI Provider abstraction for Claude and Gemini.
|
||||
|
||||
Handles model-specific API calls and tool execution
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class AIProvider(ABC):
|
||||
"""Abstract base class for AI providers"""
|
||||
"""Abstract base class for AI providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def chat(
|
||||
@@ -20,7 +21,7 @@ class AIProvider(ABC):
|
||||
system_prompt: str,
|
||||
history: List[Dict[str, Any]],
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
temperature: float = 0.7
|
||||
temperature: float = 0.7,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Send a chat message and get response
|
||||
@@ -36,16 +37,17 @@ class AIProvider(ABC):
|
||||
|
||||
|
||||
class ClaudeProvider(AIProvider):
|
||||
"""Anthropic Claude provider"""
|
||||
"""Anthropic Claude provider."""
|
||||
|
||||
def __init__(self):
|
||||
self.api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
self.default_model = "claude-sonnet-4-5-20250929"
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if Claude is available"""
|
||||
"""Check if Claude is available."""
|
||||
try:
|
||||
from anthropic import Anthropic
|
||||
|
||||
return bool(self.api_key)
|
||||
except ImportError:
|
||||
return False
|
||||
@@ -58,9 +60,9 @@ class ClaudeProvider(AIProvider):
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
temperature: float = 0.7,
|
||||
mcp_handler=None,
|
||||
mcp_context=None
|
||||
mcp_context=None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Chat with Claude"""
|
||||
"""Chat with Claude."""
|
||||
|
||||
if not self.is_available():
|
||||
return {
|
||||
@@ -68,7 +70,7 @@ class ClaudeProvider(AIProvider):
|
||||
"response": "Claude not available. Install anthropic SDK or set ANTHROPIC_API_KEY.",
|
||||
"model": "error",
|
||||
"tools_used": [],
|
||||
"stop_reason": "error"
|
||||
"stop_reason": "error",
|
||||
}
|
||||
|
||||
from anthropic import Anthropic
|
||||
@@ -91,17 +93,14 @@ class ClaudeProvider(AIProvider):
|
||||
"max_tokens": 4096,
|
||||
"temperature": temperature,
|
||||
"system": system_prompt,
|
||||
"messages": messages
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
if tools:
|
||||
api_params["tools"] = tools
|
||||
|
||||
# Initial call
|
||||
response = await asyncio.to_thread(
|
||||
client.messages.create,
|
||||
**api_params
|
||||
)
|
||||
response = await asyncio.to_thread(client.messages.create, **api_params)
|
||||
|
||||
# Handle tool use loop
|
||||
tools_used = []
|
||||
@@ -120,16 +119,16 @@ class ClaudeProvider(AIProvider):
|
||||
|
||||
# Execute tool via MCP handler
|
||||
result = await mcp_handler.execute_tool(
|
||||
tool_name=tool_name,
|
||||
arguments=tool_input,
|
||||
context=mcp_context
|
||||
tool_name=tool_name, arguments=tool_input, context=mcp_context
|
||||
)
|
||||
|
||||
tools_used.append({
|
||||
"tool": tool_name,
|
||||
"success": result.success,
|
||||
"duration_ms": result.duration_ms
|
||||
})
|
||||
tools_used.append(
|
||||
{
|
||||
"tool": tool_name,
|
||||
"success": result.success,
|
||||
"duration_ms": result.duration_ms,
|
||||
}
|
||||
)
|
||||
|
||||
# Format result
|
||||
if result.success:
|
||||
@@ -137,19 +136,20 @@ class ClaudeProvider(AIProvider):
|
||||
else:
|
||||
tool_result_content = json.dumps({"error": result.error})
|
||||
|
||||
tool_results.append({
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_use_id,
|
||||
"content": tool_result_content
|
||||
})
|
||||
tool_results.append(
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_use_id,
|
||||
"content": tool_result_content,
|
||||
}
|
||||
)
|
||||
|
||||
# Continue conversation with tool results
|
||||
messages.append({"role": "assistant", "content": response.content})
|
||||
messages.append({"role": "user", "content": tool_results})
|
||||
|
||||
response = await asyncio.to_thread(
|
||||
client.messages.create,
|
||||
**{**api_params, "messages": messages}
|
||||
client.messages.create, **{**api_params, "messages": messages}
|
||||
)
|
||||
|
||||
# Extract final response
|
||||
@@ -163,27 +163,30 @@ class ClaudeProvider(AIProvider):
|
||||
"response": response_text,
|
||||
"model": response.model,
|
||||
"tools_used": tools_used,
|
||||
"stop_reason": response.stop_reason
|
||||
"stop_reason": response.stop_reason,
|
||||
}
|
||||
|
||||
|
||||
class GeminiProvider(AIProvider):
|
||||
"""Google Gemini provider"""
|
||||
"""Google Gemini provider."""
|
||||
|
||||
def __init__(self):
|
||||
self.api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
|
||||
self.default_model = "gemini-2.0-flash-exp"
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if Gemini is available"""
|
||||
"""Check if Gemini is available."""
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
|
||||
return bool(self.api_key)
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
def _convert_tools_to_gemini_format(self, claude_tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Convert Claude tool format to Gemini function declarations"""
|
||||
def _convert_tools_to_gemini_format(
|
||||
self, claude_tools: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Convert Claude tool format to Gemini function declarations."""
|
||||
gemini_tools = []
|
||||
|
||||
for tool in claude_tools:
|
||||
@@ -191,11 +194,7 @@ class GeminiProvider(AIProvider):
|
||||
function_declaration = {
|
||||
"name": tool.get("name"),
|
||||
"description": tool.get("description", ""),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||
}
|
||||
|
||||
# Convert input schema
|
||||
@@ -218,9 +217,9 @@ class GeminiProvider(AIProvider):
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
temperature: float = 0.7,
|
||||
mcp_handler=None,
|
||||
mcp_context=None
|
||||
mcp_context=None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Chat with Gemini"""
|
||||
"""Chat with Gemini."""
|
||||
|
||||
if not self.is_available():
|
||||
return {
|
||||
@@ -228,7 +227,7 @@ class GeminiProvider(AIProvider):
|
||||
"response": "Gemini not available. Install google-generativeai SDK or set GOOGLE_API_KEY/GEMINI_API_KEY.",
|
||||
"model": "error",
|
||||
"tools_used": [],
|
||||
"stop_reason": "error"
|
||||
"stop_reason": "error",
|
||||
}
|
||||
|
||||
import google.generativeai as genai
|
||||
@@ -241,10 +240,9 @@ class GeminiProvider(AIProvider):
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
if content and role in ["user", "assistant"]:
|
||||
gemini_history.append({
|
||||
"role": "user" if role == "user" else "model",
|
||||
"parts": [content]
|
||||
})
|
||||
gemini_history.append(
|
||||
{"role": "user" if role == "user" else "model", "parts": [content]}
|
||||
)
|
||||
|
||||
# Create model with tools if available
|
||||
model_kwargs = {
|
||||
@@ -253,7 +251,7 @@ class GeminiProvider(AIProvider):
|
||||
"temperature": temperature,
|
||||
"max_output_tokens": 4096,
|
||||
},
|
||||
"system_instruction": system_prompt
|
||||
"system_instruction": system_prompt,
|
||||
}
|
||||
|
||||
# Convert and add tools if available
|
||||
@@ -282,7 +280,7 @@ class GeminiProvider(AIProvider):
|
||||
has_function_call = False
|
||||
|
||||
for part in response.candidates[0].content.parts:
|
||||
if hasattr(part, 'function_call') and part.function_call:
|
||||
if hasattr(part, "function_call") and part.function_call:
|
||||
has_function_call = True
|
||||
func_call = part.function_call
|
||||
tool_name = func_call.name
|
||||
@@ -290,31 +288,34 @@ class GeminiProvider(AIProvider):
|
||||
|
||||
# Execute tool
|
||||
result = await mcp_handler.execute_tool(
|
||||
tool_name=tool_name,
|
||||
arguments=tool_args,
|
||||
context=mcp_context
|
||||
tool_name=tool_name, arguments=tool_args, context=mcp_context
|
||||
)
|
||||
|
||||
tools_used.append({
|
||||
"tool": tool_name,
|
||||
"success": result.success,
|
||||
"duration_ms": result.duration_ms
|
||||
})
|
||||
tools_used.append(
|
||||
{
|
||||
"tool": tool_name,
|
||||
"success": result.success,
|
||||
"duration_ms": result.duration_ms,
|
||||
}
|
||||
)
|
||||
|
||||
# Format result for Gemini
|
||||
function_response = {
|
||||
"name": tool_name,
|
||||
"response": result.result if result.success else {"error": result.error}
|
||||
"response": result.result
|
||||
if result.success
|
||||
else {"error": result.error},
|
||||
}
|
||||
|
||||
# Send function response back
|
||||
current_message = genai.protos.Content(
|
||||
parts=[genai.protos.Part(
|
||||
function_response=genai.protos.FunctionResponse(
|
||||
name=tool_name,
|
||||
response=function_response
|
||||
parts=[
|
||||
genai.protos.Part(
|
||||
function_response=genai.protos.FunctionResponse(
|
||||
name=tool_name, response=function_response
|
||||
)
|
||||
)
|
||||
)]
|
||||
]
|
||||
)
|
||||
break
|
||||
|
||||
@@ -328,7 +329,7 @@ class GeminiProvider(AIProvider):
|
||||
response_text = ""
|
||||
if response.candidates and response.candidates[0].content.parts:
|
||||
for part in response.candidates[0].content.parts:
|
||||
if hasattr(part, 'text'):
|
||||
if hasattr(part, "text"):
|
||||
response_text += part.text
|
||||
|
||||
return {
|
||||
@@ -336,13 +337,13 @@ class GeminiProvider(AIProvider):
|
||||
"response": response_text,
|
||||
"model": self.default_model,
|
||||
"tools_used": tools_used,
|
||||
"stop_reason": "stop" if response.candidates else "error"
|
||||
"stop_reason": "stop" if response.candidates else "error",
|
||||
}
|
||||
|
||||
|
||||
# Factory function
|
||||
def get_ai_provider(model_name: str) -> AIProvider:
|
||||
"""Get AI provider by name"""
|
||||
"""Get AI provider by name."""
|
||||
if model_name.lower() in ["gemini", "google"]:
|
||||
return GeminiProvider()
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user