feat(agent): add current time retrieval functionality and improve agent execution timeout handling
This commit is contained in:
parent
b8a95e047f
commit
5b7b690b20
@ -1,4 +1,5 @@
|
||||
from typing import List, Optional, Tuple
|
||||
from zoneinfo import ZoneInfo
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.agents import SequentialAgent, ParallelAgent, LoopAgent, BaseAgent
|
||||
from google.adk.models.lite_llm import LiteLlm
|
||||
@ -23,6 +24,63 @@ class AgentBuilder:
|
||||
self.custom_tool_builder = CustomToolBuilder()
|
||||
self.mcp_service = MCPService()
|
||||
|
||||
def _get_current_time(self, city: str = "new york") -> dict:
|
||||
"""Get the current time in a city."""
|
||||
city_timezones = {
|
||||
"new york": "America/New_York",
|
||||
"los angeles": "America/Los_Angeles",
|
||||
"chicago": "America/Chicago",
|
||||
"toronto": "America/Toronto",
|
||||
"mexico city": "America/Mexico_City",
|
||||
"sao paulo": "America/Sao_Paulo",
|
||||
"rio de janeiro": "America/Sao_Paulo",
|
||||
"buenos aires": "America/Argentina/Buenos_Aires",
|
||||
"london": "Europe/London",
|
||||
"paris": "Europe/Paris",
|
||||
"berlin": "Europe/Berlin",
|
||||
"rome": "Europe/Rome",
|
||||
"madrid": "Europe/Madrid",
|
||||
"moscow": "Europe/Moscow",
|
||||
"dubai": "Asia/Dubai",
|
||||
"mumbai": "Asia/Kolkata",
|
||||
"delhi": "Asia/Kolkata",
|
||||
"singapore": "Asia/Singapore",
|
||||
"hong kong": "Asia/Hong_Kong",
|
||||
"beijing": "Asia/Shanghai",
|
||||
"shanghai": "Asia/Shanghai",
|
||||
"tokyo": "Asia/Tokyo",
|
||||
"seoul": "Asia/Seoul",
|
||||
"sydney": "Australia/Sydney",
|
||||
"melbourne": "Australia/Melbourne",
|
||||
"auckland": "Pacific/Auckland",
|
||||
"johannesburg": "Africa/Johannesburg",
|
||||
"cairo": "Africa/Cairo",
|
||||
"lagos": "Africa/Lagos",
|
||||
}
|
||||
|
||||
try:
|
||||
if city.lower() in city_timezones:
|
||||
try:
|
||||
tz = ZoneInfo(city_timezones[city.lower()])
|
||||
now = datetime.now(tz)
|
||||
return {
|
||||
"status": "success",
|
||||
"report": f"The current time in {city} is {now.strftime('%Y-%m-%d %H:%M:%S %Z')}",
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"status": "error",
|
||||
"error_message": f"Time information for '{city}' unavailable.",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting current time: {e}")
|
||||
return {
|
||||
"status": "error",
|
||||
"error_message": f"Error getting current time: {e}",
|
||||
}
|
||||
|
||||
async def _create_llm_agent(
|
||||
self, agent
|
||||
) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
|
||||
@ -41,7 +99,7 @@ class AgentBuilder:
|
||||
)
|
||||
|
||||
# Combine all tools
|
||||
all_tools = custom_tools + mcp_tools
|
||||
all_tools = custom_tools + mcp_tools + [self._get_current_time]
|
||||
|
||||
now = datetime.now()
|
||||
current_datetime = now.strftime("%d/%m/%Y %H:%M")
|
||||
@ -57,6 +115,9 @@ class AgentBuilder:
|
||||
current_time=current_time,
|
||||
)
|
||||
|
||||
# Add get_current_time instructions to prompt
|
||||
formatted_prompt += "\n\n<get_current_time_instructions>Use the get_current_time tool to get the current time in a city. The tool is available in the tools section of the configuration. Use 'new york' by default if no city is provided.</get_current_time_instructions>\n\n"
|
||||
|
||||
# Check if load_memory is enabled
|
||||
# before_model_callback_func = None
|
||||
if agent.config.get("load_memory"):
|
||||
|
@ -23,7 +23,9 @@ async def run_agent(
|
||||
memory_service: InMemoryMemoryService,
|
||||
db: Session,
|
||||
session_id: Optional[str] = None,
|
||||
timeout: float = 60.0,
|
||||
):
|
||||
exit_stack = None
|
||||
try:
|
||||
logger.info(f"Starting execution of agent {agent_id} for contact {contact_id}")
|
||||
logger.info(f"Received message: {message}")
|
||||
@ -72,15 +74,61 @@ async def run_agent(
|
||||
|
||||
final_response_text = None
|
||||
try:
|
||||
for event in agent_runner.run(
|
||||
user_id=contact_id,
|
||||
session_id=adk_session_id,
|
||||
new_message=content,
|
||||
):
|
||||
if event.is_final_response() and event.content and event.content.parts:
|
||||
final_response_text = event.content.parts[0].text
|
||||
logger.info(f"Final response received: {final_response_text}")
|
||||
response_queue = asyncio.Queue()
|
||||
execution_completed = asyncio.Event()
|
||||
|
||||
async def process_events():
|
||||
try:
|
||||
events_async = agent_runner.run_async(
|
||||
user_id=contact_id,
|
||||
session_id=adk_session_id,
|
||||
new_message=content,
|
||||
)
|
||||
|
||||
async for event in events_async:
|
||||
if event.is_final_response():
|
||||
if event.content and event.content.parts:
|
||||
# Assuming text response in the first part
|
||||
await response_queue.put(event.content.parts[0].text)
|
||||
elif event.actions and event.actions.escalate:
|
||||
await response_queue.put(
|
||||
f"Agent escalated: {event.error_message or 'No specific message.'}"
|
||||
)
|
||||
|
||||
execution_completed.set()
|
||||
break
|
||||
|
||||
if not execution_completed.is_set():
|
||||
await response_queue.put("Finished without specific response")
|
||||
execution_completed.set()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in process_events: {str(e)}")
|
||||
await response_queue.put(f"Error: {str(e)}")
|
||||
execution_completed.set()
|
||||
|
||||
task = asyncio.create_task(process_events())
|
||||
|
||||
try:
|
||||
wait_task = asyncio.create_task(execution_completed.wait())
|
||||
done, pending = await asyncio.wait({wait_task}, timeout=timeout)
|
||||
|
||||
for p in pending:
|
||||
p.cancel()
|
||||
|
||||
if not execution_completed.is_set():
|
||||
logger.warning(f"Agent execution timed out after {timeout} seconds")
|
||||
await response_queue.put(
|
||||
"The response took too long and was interrupted."
|
||||
)
|
||||
|
||||
final_response_text = await response_queue.get()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error waiting for response: {str(e)}")
|
||||
final_response_text = f"Error processing response: {str(e)}"
|
||||
|
||||
# Add the session to memory after completion
|
||||
completed_session = session_service.get_session(
|
||||
app_name=agent_id,
|
||||
user_id=contact_id,
|
||||
@ -89,10 +137,19 @@ async def run_agent(
|
||||
|
||||
memory_service.add_session_to_memory(completed_session)
|
||||
|
||||
finally:
|
||||
# Ensure the exit_stack is closed correctly
|
||||
if exit_stack:
|
||||
await exit_stack.aclose()
|
||||
# Cancel the processing task if it is still running
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Task cancelled successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error cancelling task: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing request: {str(e)}")
|
||||
raise e
|
||||
|
||||
logger.info("Agent execution completed successfully")
|
||||
return final_response_text
|
||||
@ -102,6 +159,15 @@ async def run_agent(
|
||||
except Exception as e:
|
||||
logger.error(f"Internal error processing request: {str(e)}", exc_info=True)
|
||||
raise InternalServerError(str(e))
|
||||
finally:
|
||||
# Clean up MCP connection - MUST be executed in the same task
|
||||
if exit_stack:
|
||||
logger.info("Closing MCP server connection...")
|
||||
try:
|
||||
await exit_stack.aclose()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing MCP connection: {e}")
|
||||
# Do not raise the exception to not obscure the original error
|
||||
|
||||
|
||||
async def run_agent_stream(
|
||||
@ -163,11 +229,13 @@ async def run_agent_stream(
|
||||
logger.info("Starting agent streaming execution")
|
||||
|
||||
try:
|
||||
for event in agent_runner.run(
|
||||
events_async = agent_runner.run_async(
|
||||
user_id=contact_id,
|
||||
session_id=adk_session_id,
|
||||
new_message=content,
|
||||
):
|
||||
)
|
||||
|
||||
async for event in events_async:
|
||||
if event.content and event.content.parts:
|
||||
text = event.content.parts[0].text
|
||||
if text:
|
||||
@ -181,14 +249,19 @@ async def run_agent_stream(
|
||||
)
|
||||
|
||||
memory_service.add_session_to_memory(completed_session)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing request: {str(e)}")
|
||||
raise e
|
||||
finally:
|
||||
# Ensure the exit_stack is closed correctly
|
||||
# Clean up MCP connection
|
||||
if exit_stack:
|
||||
await exit_stack.aclose()
|
||||
logger.info("Closing MCP server connection...")
|
||||
try:
|
||||
await exit_stack.aclose()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing MCP connection: {e}")
|
||||
|
||||
logger.info("Agent streaming execution completed successfully")
|
||||
|
||||
except AgentNotFoundError as e:
|
||||
logger.error(f"Error processing request: {str(e)}")
|
||||
raise e
|
||||
|
@ -12,6 +12,14 @@ import httpx
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Helper function to generate API keys
|
||||
def generate_api_key() -> str:
|
||||
"""Generate a secure API key"""
|
||||
# Format: sk-proj-{random 64 chars}
|
||||
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def _convert_uuid_to_str(obj):
|
||||
"""
|
||||
Recursively convert all UUID objects to strings in a dictionary, list or scalar value.
|
||||
@ -203,6 +211,11 @@ async def create_agent(db: Session, agent: AgentCreate) -> Agent:
|
||||
for tool in config["tools"]
|
||||
]
|
||||
|
||||
# Generate automatic API key if not provided or empty
|
||||
if not config.get("api_key") or config.get("api_key") == "":
|
||||
logger.info(f"Generating automatic API key for new agent")
|
||||
config["api_key"] = generate_api_key()
|
||||
|
||||
agent.config = config
|
||||
|
||||
# Ensure all config objects are serializable (convert UUIDs to strings)
|
||||
@ -391,6 +404,19 @@ async def update_agent(
|
||||
if "config" in agent_data and agent_data["config"] is not None:
|
||||
agent_data["config"] = _convert_uuid_to_str(agent_data["config"])
|
||||
|
||||
# Check if the agent has API key and generate one if not
|
||||
agent_config = agent.config or {}
|
||||
if "config" not in agent_data:
|
||||
agent_data["config"] = agent_config
|
||||
|
||||
if not agent_config.get("api_key") and (
|
||||
"config" not in agent_data or not agent_data["config"].get("api_key")
|
||||
):
|
||||
logger.info(f"Generating missing API key for existing agent: {agent_id}")
|
||||
if "config" not in agent_data:
|
||||
agent_data["config"] = {}
|
||||
agent_data["config"]["api_key"] = generate_api_key()
|
||||
|
||||
for key, value in agent_data.items():
|
||||
setattr(agent, key, value)
|
||||
|
||||
|
@ -92,61 +92,69 @@ class MCPService:
|
||||
self.tools = []
|
||||
self.exit_stack = AsyncExitStack()
|
||||
|
||||
# Process each MCP server in the configuration
|
||||
for server in mcp_config.get("mcp_servers", []):
|
||||
try:
|
||||
# Search for the MCP server in the database
|
||||
mcp_server = get_mcp_server(db, server["id"])
|
||||
if not mcp_server:
|
||||
logger.warning(f"Servidor MCP não encontrado: {server['id']}")
|
||||
try:
|
||||
# Process each MCP server in the configuration
|
||||
for server in mcp_config.get("mcp_servers", []):
|
||||
try:
|
||||
# Search for the MCP server in the database
|
||||
mcp_server = get_mcp_server(db, server["id"])
|
||||
if not mcp_server:
|
||||
logger.warning(f"MCP Server not found: {server['id']}")
|
||||
continue
|
||||
|
||||
# Prepares the server configuration
|
||||
server_config = mcp_server.config_json.copy()
|
||||
|
||||
# Replaces the environment variables in the config_json
|
||||
if "env" in server_config:
|
||||
for key, value in server_config["env"].items():
|
||||
if value.startswith("env@@"):
|
||||
env_key = value.replace("env@@", "")
|
||||
if env_key in server.get("envs", {}):
|
||||
server_config["env"][key] = server["envs"][env_key]
|
||||
else:
|
||||
logger.warning(
|
||||
f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}"
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"Connecting to MCP server: {mcp_server.name}")
|
||||
tools, exit_stack = await self._connect_to_mcp_server(server_config)
|
||||
|
||||
if tools and exit_stack:
|
||||
# Filters incompatible tools
|
||||
filtered_tools = self._filter_incompatible_tools(tools)
|
||||
|
||||
# Filters tools compatible with the agent
|
||||
agent_tools = server.get("tools", [])
|
||||
filtered_tools = self._filter_tools_by_agent(
|
||||
filtered_tools, agent_tools
|
||||
)
|
||||
self.tools.extend(filtered_tools)
|
||||
|
||||
# Registers the exit_stack with the AsyncExitStack
|
||||
await self.exit_stack.enter_async_context(exit_stack)
|
||||
logger.info(
|
||||
f"Connected successfully. Added {len(filtered_tools)} tools."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to connect or no tools available for {mcp_server.name}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error connecting to MCP server {server['id']}: {e}")
|
||||
continue
|
||||
|
||||
# Prepares the server configuration
|
||||
server_config = mcp_server.config_json.copy()
|
||||
logger.info(
|
||||
f"MCP Toolset created successfully. Total of {len(self.tools)} tools."
|
||||
)
|
||||
|
||||
# Replaces the environment variables in the config_json
|
||||
if "env" in server_config:
|
||||
for key, value in server_config["env"].items():
|
||||
if value.startswith("env@@"):
|
||||
env_key = value.replace("env@@", "")
|
||||
if env_key in server.get("envs", {}):
|
||||
server_config["env"][key] = server["envs"][env_key]
|
||||
else:
|
||||
logger.warning(
|
||||
f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}"
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"Connecting to MCP server: {mcp_server.name}")
|
||||
tools, exit_stack = await self._connect_to_mcp_server(server_config)
|
||||
|
||||
if tools and exit_stack:
|
||||
# Filters incompatible tools
|
||||
filtered_tools = self._filter_incompatible_tools(tools)
|
||||
|
||||
# Filters tools compatible with the agent
|
||||
agent_tools = server.get("tools", [])
|
||||
filtered_tools = self._filter_tools_by_agent(
|
||||
filtered_tools, agent_tools
|
||||
)
|
||||
self.tools.extend(filtered_tools)
|
||||
|
||||
# Registers the exit_stack with the AsyncExitStack
|
||||
await self.exit_stack.enter_async_context(exit_stack)
|
||||
logger.info(
|
||||
f"Connected successfully. Added {len(filtered_tools)} tools."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to connect or no tools available for {mcp_server.name}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error connecting to MCP server {server['id']}: {e}")
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"MCP Toolset created successfully. Total of {len(self.tools)} tools."
|
||||
)
|
||||
except Exception as e:
|
||||
# Ensure cleanup
|
||||
await self.exit_stack.aclose()
|
||||
logger.error(f"Fatal error connecting to MCP servers: {e}")
|
||||
# Recreate an empty exit_stack
|
||||
self.exit_stack = AsyncExitStack()
|
||||
|
||||
return self.tools, self.exit_stack
|
||||
|
Loading…
Reference in New Issue
Block a user