feat(agent): add current time retrieval functionality and improve agent execution timeout handling
This commit is contained in:
parent
b8a95e047f
commit
5b7b690b20
@ -1,4 +1,5 @@
|
|||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
from google.adk.agents.llm_agent import LlmAgent
|
from google.adk.agents.llm_agent import LlmAgent
|
||||||
from google.adk.agents import SequentialAgent, ParallelAgent, LoopAgent, BaseAgent
|
from google.adk.agents import SequentialAgent, ParallelAgent, LoopAgent, BaseAgent
|
||||||
from google.adk.models.lite_llm import LiteLlm
|
from google.adk.models.lite_llm import LiteLlm
|
||||||
@ -23,6 +24,63 @@ class AgentBuilder:
|
|||||||
self.custom_tool_builder = CustomToolBuilder()
|
self.custom_tool_builder = CustomToolBuilder()
|
||||||
self.mcp_service = MCPService()
|
self.mcp_service = MCPService()
|
||||||
|
|
||||||
|
def _get_current_time(self, city: str = "new york") -> dict:
|
||||||
|
"""Get the current time in a city."""
|
||||||
|
city_timezones = {
|
||||||
|
"new york": "America/New_York",
|
||||||
|
"los angeles": "America/Los_Angeles",
|
||||||
|
"chicago": "America/Chicago",
|
||||||
|
"toronto": "America/Toronto",
|
||||||
|
"mexico city": "America/Mexico_City",
|
||||||
|
"sao paulo": "America/Sao_Paulo",
|
||||||
|
"rio de janeiro": "America/Sao_Paulo",
|
||||||
|
"buenos aires": "America/Argentina/Buenos_Aires",
|
||||||
|
"london": "Europe/London",
|
||||||
|
"paris": "Europe/Paris",
|
||||||
|
"berlin": "Europe/Berlin",
|
||||||
|
"rome": "Europe/Rome",
|
||||||
|
"madrid": "Europe/Madrid",
|
||||||
|
"moscow": "Europe/Moscow",
|
||||||
|
"dubai": "Asia/Dubai",
|
||||||
|
"mumbai": "Asia/Kolkata",
|
||||||
|
"delhi": "Asia/Kolkata",
|
||||||
|
"singapore": "Asia/Singapore",
|
||||||
|
"hong kong": "Asia/Hong_Kong",
|
||||||
|
"beijing": "Asia/Shanghai",
|
||||||
|
"shanghai": "Asia/Shanghai",
|
||||||
|
"tokyo": "Asia/Tokyo",
|
||||||
|
"seoul": "Asia/Seoul",
|
||||||
|
"sydney": "Australia/Sydney",
|
||||||
|
"melbourne": "Australia/Melbourne",
|
||||||
|
"auckland": "Pacific/Auckland",
|
||||||
|
"johannesburg": "Africa/Johannesburg",
|
||||||
|
"cairo": "Africa/Cairo",
|
||||||
|
"lagos": "Africa/Lagos",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if city.lower() in city_timezones:
|
||||||
|
try:
|
||||||
|
tz = ZoneInfo(city_timezones[city.lower()])
|
||||||
|
now = datetime.now(tz)
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"report": f"The current time in {city} is {now.strftime('%Y-%m-%d %H:%M:%S %Z')}",
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"error_message": f"Time information for '{city}' unavailable.",
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting current time: {e}")
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"error_message": f"Error getting current time: {e}",
|
||||||
|
}
|
||||||
|
|
||||||
async def _create_llm_agent(
|
async def _create_llm_agent(
|
||||||
self, agent
|
self, agent
|
||||||
) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
|
) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
|
||||||
@ -41,7 +99,7 @@ class AgentBuilder:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Combine all tools
|
# Combine all tools
|
||||||
all_tools = custom_tools + mcp_tools
|
all_tools = custom_tools + mcp_tools + [self._get_current_time]
|
||||||
|
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
current_datetime = now.strftime("%d/%m/%Y %H:%M")
|
current_datetime = now.strftime("%d/%m/%Y %H:%M")
|
||||||
@ -57,6 +115,9 @@ class AgentBuilder:
|
|||||||
current_time=current_time,
|
current_time=current_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add get_current_time instructions to prompt
|
||||||
|
formatted_prompt += "\n\n<get_current_time_instructions>Use the get_current_time tool to get the current time in a city. The tool is available in the tools section of the configuration. Use 'new york' by default if no city is provided.</get_current_time_instructions>\n\n"
|
||||||
|
|
||||||
# Check if load_memory is enabled
|
# Check if load_memory is enabled
|
||||||
# before_model_callback_func = None
|
# before_model_callback_func = None
|
||||||
if agent.config.get("load_memory"):
|
if agent.config.get("load_memory"):
|
||||||
|
@ -23,7 +23,9 @@ async def run_agent(
|
|||||||
memory_service: InMemoryMemoryService,
|
memory_service: InMemoryMemoryService,
|
||||||
db: Session,
|
db: Session,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
|
timeout: float = 60.0,
|
||||||
):
|
):
|
||||||
|
exit_stack = None
|
||||||
try:
|
try:
|
||||||
logger.info(f"Starting execution of agent {agent_id} for contact {contact_id}")
|
logger.info(f"Starting execution of agent {agent_id} for contact {contact_id}")
|
||||||
logger.info(f"Received message: {message}")
|
logger.info(f"Received message: {message}")
|
||||||
@ -72,15 +74,61 @@ async def run_agent(
|
|||||||
|
|
||||||
final_response_text = None
|
final_response_text = None
|
||||||
try:
|
try:
|
||||||
for event in agent_runner.run(
|
response_queue = asyncio.Queue()
|
||||||
|
execution_completed = asyncio.Event()
|
||||||
|
|
||||||
|
async def process_events():
|
||||||
|
try:
|
||||||
|
events_async = agent_runner.run_async(
|
||||||
user_id=contact_id,
|
user_id=contact_id,
|
||||||
session_id=adk_session_id,
|
session_id=adk_session_id,
|
||||||
new_message=content,
|
new_message=content,
|
||||||
):
|
)
|
||||||
if event.is_final_response() and event.content and event.content.parts:
|
|
||||||
final_response_text = event.content.parts[0].text
|
|
||||||
logger.info(f"Final response received: {final_response_text}")
|
|
||||||
|
|
||||||
|
async for event in events_async:
|
||||||
|
if event.is_final_response():
|
||||||
|
if event.content and event.content.parts:
|
||||||
|
# Assuming text response in the first part
|
||||||
|
await response_queue.put(event.content.parts[0].text)
|
||||||
|
elif event.actions and event.actions.escalate:
|
||||||
|
await response_queue.put(
|
||||||
|
f"Agent escalated: {event.error_message or 'No specific message.'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
execution_completed.set()
|
||||||
|
break
|
||||||
|
|
||||||
|
if not execution_completed.is_set():
|
||||||
|
await response_queue.put("Finished without specific response")
|
||||||
|
execution_completed.set()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in process_events: {str(e)}")
|
||||||
|
await response_queue.put(f"Error: {str(e)}")
|
||||||
|
execution_completed.set()
|
||||||
|
|
||||||
|
task = asyncio.create_task(process_events())
|
||||||
|
|
||||||
|
try:
|
||||||
|
wait_task = asyncio.create_task(execution_completed.wait())
|
||||||
|
done, pending = await asyncio.wait({wait_task}, timeout=timeout)
|
||||||
|
|
||||||
|
for p in pending:
|
||||||
|
p.cancel()
|
||||||
|
|
||||||
|
if not execution_completed.is_set():
|
||||||
|
logger.warning(f"Agent execution timed out after {timeout} seconds")
|
||||||
|
await response_queue.put(
|
||||||
|
"The response took too long and was interrupted."
|
||||||
|
)
|
||||||
|
|
||||||
|
final_response_text = await response_queue.get()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error waiting for response: {str(e)}")
|
||||||
|
final_response_text = f"Error processing response: {str(e)}"
|
||||||
|
|
||||||
|
# Add the session to memory after completion
|
||||||
completed_session = session_service.get_session(
|
completed_session = session_service.get_session(
|
||||||
app_name=agent_id,
|
app_name=agent_id,
|
||||||
user_id=contact_id,
|
user_id=contact_id,
|
||||||
@ -89,10 +137,19 @@ async def run_agent(
|
|||||||
|
|
||||||
memory_service.add_session_to_memory(completed_session)
|
memory_service.add_session_to_memory(completed_session)
|
||||||
|
|
||||||
finally:
|
# Cancel the processing task if it is still running
|
||||||
# Ensure the exit_stack is closed correctly
|
if not task.done():
|
||||||
if exit_stack:
|
task.cancel()
|
||||||
await exit_stack.aclose()
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("Task cancelled successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error cancelling task: {str(e)}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing request: {str(e)}")
|
||||||
|
raise e
|
||||||
|
|
||||||
logger.info("Agent execution completed successfully")
|
logger.info("Agent execution completed successfully")
|
||||||
return final_response_text
|
return final_response_text
|
||||||
@ -102,6 +159,15 @@ async def run_agent(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Internal error processing request: {str(e)}", exc_info=True)
|
logger.error(f"Internal error processing request: {str(e)}", exc_info=True)
|
||||||
raise InternalServerError(str(e))
|
raise InternalServerError(str(e))
|
||||||
|
finally:
|
||||||
|
# Clean up MCP connection - MUST be executed in the same task
|
||||||
|
if exit_stack:
|
||||||
|
logger.info("Closing MCP server connection...")
|
||||||
|
try:
|
||||||
|
await exit_stack.aclose()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing MCP connection: {e}")
|
||||||
|
# Do not raise the exception to not obscure the original error
|
||||||
|
|
||||||
|
|
||||||
async def run_agent_stream(
|
async def run_agent_stream(
|
||||||
@ -163,11 +229,13 @@ async def run_agent_stream(
|
|||||||
logger.info("Starting agent streaming execution")
|
logger.info("Starting agent streaming execution")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for event in agent_runner.run(
|
events_async = agent_runner.run_async(
|
||||||
user_id=contact_id,
|
user_id=contact_id,
|
||||||
session_id=adk_session_id,
|
session_id=adk_session_id,
|
||||||
new_message=content,
|
new_message=content,
|
||||||
):
|
)
|
||||||
|
|
||||||
|
async for event in events_async:
|
||||||
if event.content and event.content.parts:
|
if event.content and event.content.parts:
|
||||||
text = event.content.parts[0].text
|
text = event.content.parts[0].text
|
||||||
if text:
|
if text:
|
||||||
@ -181,14 +249,19 @@ async def run_agent_stream(
|
|||||||
)
|
)
|
||||||
|
|
||||||
memory_service.add_session_to_memory(completed_session)
|
memory_service.add_session_to_memory(completed_session)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing request: {str(e)}")
|
||||||
|
raise e
|
||||||
finally:
|
finally:
|
||||||
# Ensure the exit_stack is closed correctly
|
# Clean up MCP connection
|
||||||
if exit_stack:
|
if exit_stack:
|
||||||
|
logger.info("Closing MCP server connection...")
|
||||||
|
try:
|
||||||
await exit_stack.aclose()
|
await exit_stack.aclose()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing MCP connection: {e}")
|
||||||
|
|
||||||
logger.info("Agent streaming execution completed successfully")
|
logger.info("Agent streaming execution completed successfully")
|
||||||
|
|
||||||
except AgentNotFoundError as e:
|
except AgentNotFoundError as e:
|
||||||
logger.error(f"Error processing request: {str(e)}")
|
logger.error(f"Error processing request: {str(e)}")
|
||||||
raise e
|
raise e
|
||||||
|
@ -12,6 +12,14 @@ import httpx
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Helper function to generate API keys
|
||||||
|
def generate_api_key() -> str:
|
||||||
|
"""Generate a secure API key"""
|
||||||
|
# Format: sk-proj-{random 64 chars}
|
||||||
|
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
def _convert_uuid_to_str(obj):
|
def _convert_uuid_to_str(obj):
|
||||||
"""
|
"""
|
||||||
Recursively convert all UUID objects to strings in a dictionary, list or scalar value.
|
Recursively convert all UUID objects to strings in a dictionary, list or scalar value.
|
||||||
@ -203,6 +211,11 @@ async def create_agent(db: Session, agent: AgentCreate) -> Agent:
|
|||||||
for tool in config["tools"]
|
for tool in config["tools"]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Generate automatic API key if not provided or empty
|
||||||
|
if not config.get("api_key") or config.get("api_key") == "":
|
||||||
|
logger.info(f"Generating automatic API key for new agent")
|
||||||
|
config["api_key"] = generate_api_key()
|
||||||
|
|
||||||
agent.config = config
|
agent.config = config
|
||||||
|
|
||||||
# Ensure all config objects are serializable (convert UUIDs to strings)
|
# Ensure all config objects are serializable (convert UUIDs to strings)
|
||||||
@ -391,6 +404,19 @@ async def update_agent(
|
|||||||
if "config" in agent_data and agent_data["config"] is not None:
|
if "config" in agent_data and agent_data["config"] is not None:
|
||||||
agent_data["config"] = _convert_uuid_to_str(agent_data["config"])
|
agent_data["config"] = _convert_uuid_to_str(agent_data["config"])
|
||||||
|
|
||||||
|
# Check if the agent has API key and generate one if not
|
||||||
|
agent_config = agent.config or {}
|
||||||
|
if "config" not in agent_data:
|
||||||
|
agent_data["config"] = agent_config
|
||||||
|
|
||||||
|
if not agent_config.get("api_key") and (
|
||||||
|
"config" not in agent_data or not agent_data["config"].get("api_key")
|
||||||
|
):
|
||||||
|
logger.info(f"Generating missing API key for existing agent: {agent_id}")
|
||||||
|
if "config" not in agent_data:
|
||||||
|
agent_data["config"] = {}
|
||||||
|
agent_data["config"]["api_key"] = generate_api_key()
|
||||||
|
|
||||||
for key, value in agent_data.items():
|
for key, value in agent_data.items():
|
||||||
setattr(agent, key, value)
|
setattr(agent, key, value)
|
||||||
|
|
||||||
|
@ -92,13 +92,14 @@ class MCPService:
|
|||||||
self.tools = []
|
self.tools = []
|
||||||
self.exit_stack = AsyncExitStack()
|
self.exit_stack = AsyncExitStack()
|
||||||
|
|
||||||
|
try:
|
||||||
# Process each MCP server in the configuration
|
# Process each MCP server in the configuration
|
||||||
for server in mcp_config.get("mcp_servers", []):
|
for server in mcp_config.get("mcp_servers", []):
|
||||||
try:
|
try:
|
||||||
# Search for the MCP server in the database
|
# Search for the MCP server in the database
|
||||||
mcp_server = get_mcp_server(db, server["id"])
|
mcp_server = get_mcp_server(db, server["id"])
|
||||||
if not mcp_server:
|
if not mcp_server:
|
||||||
logger.warning(f"Servidor MCP não encontrado: {server['id']}")
|
logger.warning(f"MCP Server not found: {server['id']}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Prepares the server configuration
|
# Prepares the server configuration
|
||||||
@ -149,4 +150,11 @@ class MCPService:
|
|||||||
f"MCP Toolset created successfully. Total of {len(self.tools)} tools."
|
f"MCP Toolset created successfully. Total of {len(self.tools)} tools."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Ensure cleanup
|
||||||
|
await self.exit_stack.aclose()
|
||||||
|
logger.error(f"Fatal error connecting to MCP servers: {e}")
|
||||||
|
# Recreate an empty exit_stack
|
||||||
|
self.exit_stack = AsyncExitStack()
|
||||||
|
|
||||||
return self.tools, self.exit_stack
|
return self.tools, self.exit_stack
|
||||||
|
Loading…
Reference in New Issue
Block a user