Merge branch 'main' into support-async-tool-callbacks

This commit is contained in:
Alankrit Verma 2025-04-29 22:08:18 -04:00 committed by GitHub
commit 926b0ef1a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 164 additions and 70 deletions

View File

@ -39,12 +39,12 @@ class InputFile(BaseModel):
async def run_input_file( async def run_input_file(
app_name: str, app_name: str,
user_id: str,
root_agent: LlmAgent, root_agent: LlmAgent,
artifact_service: BaseArtifactService, artifact_service: BaseArtifactService,
session: Session,
session_service: BaseSessionService, session_service: BaseSessionService,
input_path: str, input_path: str,
) -> None: ) -> Session:
runner = Runner( runner = Runner(
app_name=app_name, app_name=app_name,
agent=root_agent, agent=root_agent,
@ -55,9 +55,11 @@ async def run_input_file(
input_file = InputFile.model_validate_json(f.read()) input_file = InputFile.model_validate_json(f.read())
input_file.state['_time'] = datetime.now() input_file.state['_time'] = datetime.now()
session.state = input_file.state session = session_service.create_session(
app_name=app_name, user_id=user_id, state=input_file.state
)
for query in input_file.queries: for query in input_file.queries:
click.echo(f'user: {query}') click.echo(f'[user]: {query}')
content = types.Content(role='user', parts=[types.Part(text=query)]) content = types.Content(role='user', parts=[types.Part(text=query)])
async for event in runner.run_async( async for event in runner.run_async(
user_id=session.user_id, session_id=session.id, new_message=content user_id=session.user_id, session_id=session.id, new_message=content
@ -65,23 +67,23 @@ async def run_input_file(
if event.content and event.content.parts: if event.content and event.content.parts:
if text := ''.join(part.text or '' for part in event.content.parts): if text := ''.join(part.text or '' for part in event.content.parts):
click.echo(f'[{event.author}]: {text}') click.echo(f'[{event.author}]: {text}')
return session
async def run_interactively( async def run_interactively(
app_name: str,
root_agent: LlmAgent, root_agent: LlmAgent,
artifact_service: BaseArtifactService, artifact_service: BaseArtifactService,
session: Session, session: Session,
session_service: BaseSessionService, session_service: BaseSessionService,
) -> None: ) -> None:
runner = Runner( runner = Runner(
app_name=app_name, app_name=session.app_name,
agent=root_agent, agent=root_agent,
artifact_service=artifact_service, artifact_service=artifact_service,
session_service=session_service, session_service=session_service,
) )
while True: while True:
query = input('user: ') query = input('[user]: ')
if not query or not query.strip(): if not query or not query.strip():
continue continue
if query == 'exit': if query == 'exit':
@ -100,7 +102,8 @@ async def run_cli(
*, *,
agent_parent_dir: str, agent_parent_dir: str,
agent_folder_name: str, agent_folder_name: str,
json_file_path: Optional[str] = None, input_file: Optional[str] = None,
saved_session_file: Optional[str] = None,
save_session: bool, save_session: bool,
) -> None: ) -> None:
"""Runs an interactive CLI for a certain agent. """Runs an interactive CLI for a certain agent.
@ -109,8 +112,11 @@ async def run_cli(
agent_parent_dir: str, the absolute path of the parent folder of the agent agent_parent_dir: str, the absolute path of the parent folder of the agent
folder. folder.
agent_folder_name: str, the name of the agent folder. agent_folder_name: str, the name of the agent folder.
json_file_path: Optional[str], the absolute path to the json file, either input_file: Optional[str], the absolute path to the json file that contains
*.input.json or *.session.json. the initial session state and user queries, exclusive with
saved_session_file.
saved_session_file: Optional[str], the absolute path to the json file that
contains a previously saved session, exclusive with input_file.
save_session: bool, whether to save the session on exit. save_session: bool, whether to save the session on exit.
""" """
if agent_parent_dir not in sys.path: if agent_parent_dir not in sys.path:
@ -118,46 +124,50 @@ async def run_cli(
artifact_service = InMemoryArtifactService() artifact_service = InMemoryArtifactService()
session_service = InMemorySessionService() session_service = InMemorySessionService()
session = session_service.create_session(
app_name=agent_folder_name, user_id='test_user'
)
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name) agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
agent_module = importlib.import_module(agent_folder_name) agent_module = importlib.import_module(agent_folder_name)
user_id = 'test_user'
session = session_service.create_session(
app_name=agent_folder_name, user_id=user_id
)
root_agent = agent_module.agent.root_agent root_agent = agent_module.agent.root_agent
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir) envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
if json_file_path: if input_file:
if json_file_path.endswith('.input.json'): session = await run_input_file(
await run_input_file( app_name=agent_folder_name,
app_name=agent_folder_name, user_id=user_id,
root_agent=root_agent, root_agent=root_agent,
artifact_service=artifact_service, artifact_service=artifact_service,
session=session, session_service=session_service,
session_service=session_service, input_path=input_file,
input_path=json_file_path, )
) elif saved_session_file:
elif json_file_path.endswith('.session.json'):
with open(json_file_path, 'r') as f: loaded_session = None
session = Session.model_validate_json(f.read()) with open(saved_session_file, 'r') as f:
for content in session.get_contents(): loaded_session = Session.model_validate_json(f.read())
if content.role == 'user':
print('user: ', content.parts[0].text) if loaded_session:
for event in loaded_session.events:
session_service.append_event(session, event)
content = event.content
if not content or not content.parts or not content.parts[0].text:
continue
if event.author == 'user':
click.echo(f'[user]: {content.parts[0].text}')
else: else:
print(content.parts[0].text) click.echo(f'[{event.author}]: {content.parts[0].text}')
await run_interactively(
agent_folder_name, await run_interactively(
root_agent, root_agent,
artifact_service, artifact_service,
session, session,
session_service, session_service,
) )
else: else:
print(f'Unsupported file type: {json_file_path}') click.echo(f'Running agent {root_agent.name}, type exit to exit.')
exit(1)
else:
print(f'Running agent {root_agent.name}, type exit to exit.')
await run_interactively( await run_interactively(
agent_folder_name,
root_agent, root_agent,
artifact_service, artifact_service,
session, session,
@ -165,11 +175,8 @@ async def run_cli(
) )
if save_session: if save_session:
if json_file_path: session_id = input('Session ID to save: ')
session_path = json_file_path.replace('.input.json', '.session.json') session_path = f'{agent_module_path}/{session_id}.session.json'
else:
session_id = input('Session ID to save: ')
session_path = f'{agent_module_path}/{session_id}.session.json'
# Fetch the session again to get all the details. # Fetch the session again to get all the details.
session = session_service.get_session( session = session_service.get_session(

View File

@ -96,6 +96,23 @@ def cli_create_cmd(
) )
def validate_exclusive(ctx, param, value):
# Store the validated parameters in the context
if not hasattr(ctx, "exclusive_opts"):
ctx.exclusive_opts = {}
# If this option has a value and we've already seen another exclusive option
if value is not None and any(ctx.exclusive_opts.values()):
exclusive_opt = next(key for key, val in ctx.exclusive_opts.items() if val)
raise click.UsageError(
f"Options '{param.name}' and '{exclusive_opt}' cannot be set together."
)
# Record this option's value
ctx.exclusive_opts[param.name] = value is not None
return value
@main.command("run") @main.command("run")
@click.option( @click.option(
"--save_session", "--save_session",
@ -105,13 +122,43 @@ def cli_create_cmd(
default=False, default=False,
help="Optional. Whether to save the session to a json file on exit.", help="Optional. Whether to save the session to a json file on exit.",
) )
@click.option(
"--replay",
type=click.Path(
exists=True, dir_okay=False, file_okay=True, resolve_path=True
),
help=(
"The json file that contains the initial state of the session and user"
" queries. A new session will be created using this state. And user"
" queries are run againt the newly created session. Users cannot"
" continue to interact with the agent."
),
callback=validate_exclusive,
)
@click.option(
"--resume",
type=click.Path(
exists=True, dir_okay=False, file_okay=True, resolve_path=True
),
help=(
"The json file that contains a previously saved session (by"
"--save_session option). The previous session will be re-displayed. And"
" user can continue to interact with the agent."
),
callback=validate_exclusive,
)
@click.argument( @click.argument(
"agent", "agent",
type=click.Path( type=click.Path(
exists=True, dir_okay=True, file_okay=False, resolve_path=True exists=True, dir_okay=True, file_okay=False, resolve_path=True
), ),
) )
def cli_run(agent: str, save_session: bool): def cli_run(
agent: str,
save_session: bool,
replay: Optional[str],
resume: Optional[str],
):
"""Runs an interactive CLI for a certain agent. """Runs an interactive CLI for a certain agent.
AGENT: The path to the agent source code folder. AGENT: The path to the agent source code folder.
@ -129,6 +176,8 @@ def cli_run(agent: str, save_session: bool):
run_cli( run_cli(
agent_parent_dir=agent_parent_folder, agent_parent_dir=agent_parent_folder,
agent_folder_name=agent_folder_name, agent_folder_name=agent_folder_name,
input_file=replay,
saved_session_file=resume,
save_session=save_session, save_session=save_session,
) )
) )

View File

@ -48,8 +48,13 @@ class EventActions(BaseModel):
"""The agent is escalating to a higher level agent.""" """The agent is escalating to a higher level agent."""
requested_auth_configs: dict[str, AuthConfig] = Field(default_factory=dict) requested_auth_configs: dict[str, AuthConfig] = Field(default_factory=dict)
"""Will only be set by a tool response indicating tool request euc. """Authentication configurations requested by tool responses.
dict key is the function call id since one function call response (from model)
could correspond to multiple function calls. This field will only be set by a tool response event indicating tool request
dict value is the required auth config. auth credential.
- Keys: The function call id. Since one function response event could contain
multiple function responses that correspond to multiple function calls. Each
function call could request different auth configs. This id is used to
identify the function call.
- Values: The requested auth config.
""" """

View File

@ -58,6 +58,8 @@ from .state import State
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_MAX_VARCHAR_LENGTH = 256
class DynamicJSON(TypeDecorator): class DynamicJSON(TypeDecorator):
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON """A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
@ -92,17 +94,25 @@ class DynamicJSON(TypeDecorator):
class Base(DeclarativeBase): class Base(DeclarativeBase):
"""Base class for database tables.""" """Base class for database tables."""
pass pass
class StorageSession(Base): class StorageSession(Base):
"""Represents a session stored in the database.""" """Represents a session stored in the database."""
__tablename__ = "sessions" __tablename__ = "sessions"
app_name: Mapped[str] = mapped_column(String, primary_key=True) app_name: Mapped[str] = mapped_column(
user_id: Mapped[str] = mapped_column(String, primary_key=True) String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
)
id: Mapped[str] = mapped_column( id: Mapped[str] = mapped_column(
String, primary_key=True, default=lambda: str(uuid.uuid4()) String(DEFAULT_MAX_VARCHAR_LENGTH),
primary_key=True,
default=lambda: str(uuid.uuid4()),
) )
state: Mapped[MutableDict[str, Any]] = mapped_column( state: Mapped[MutableDict[str, Any]] = mapped_column(
@ -125,16 +135,27 @@ class StorageSession(Base):
class StorageEvent(Base): class StorageEvent(Base):
"""Represents an event stored in the database.""" """Represents an event stored in the database."""
__tablename__ = "events" __tablename__ = "events"
id: Mapped[str] = mapped_column(String, primary_key=True) id: Mapped[str] = mapped_column(
app_name: Mapped[str] = mapped_column(String, primary_key=True) String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
user_id: Mapped[str] = mapped_column(String, primary_key=True) )
session_id: Mapped[str] = mapped_column(String, primary_key=True) app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
)
session_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
)
invocation_id: Mapped[str] = mapped_column(String) invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
author: Mapped[str] = mapped_column(String) author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
branch: Mapped[str] = mapped_column(String, nullable=True) branch: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
)
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now()) timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True) content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType) actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
@ -147,8 +168,10 @@ class StorageEvent(Base):
) )
partial: Mapped[bool] = mapped_column(Boolean, nullable=True) partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True) turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
error_code: Mapped[str] = mapped_column(String, nullable=True) error_code: Mapped[str] = mapped_column(
error_message: Mapped[str] = mapped_column(String, nullable=True) String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
)
error_message: Mapped[str] = mapped_column(String(1024), nullable=True)
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True) interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
storage_session: Mapped[StorageSession] = relationship( storage_session: Mapped[StorageSession] = relationship(
@ -182,9 +205,12 @@ class StorageEvent(Base):
class StorageAppState(Base): class StorageAppState(Base):
"""Represents an app state stored in the database.""" """Represents an app state stored in the database."""
__tablename__ = "app_states" __tablename__ = "app_states"
app_name: Mapped[str] = mapped_column(String, primary_key=True) app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
)
state: Mapped[MutableDict[str, Any]] = mapped_column( state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={} MutableDict.as_mutable(DynamicJSON), default={}
) )
@ -192,13 +218,20 @@ class StorageAppState(Base):
DateTime(), default=func.now(), onupdate=func.now() DateTime(), default=func.now(), onupdate=func.now()
) )
class StorageUserState(Base): class StorageUserState(Base):
"""Represents a user state stored in the database.""" """Represents a user state stored in the database."""
__tablename__ = "user_states" __tablename__ = "user_states"
app_name: Mapped[str] = mapped_column(String, primary_key=True) app_name: Mapped[str] = mapped_column(
user_id: Mapped[str] = mapped_column(String, primary_key=True) String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
state: Mapped[MutableDict[str, Any]] = mapped_column( state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={} MutableDict.as_mutable(DynamicJSON), default={}
) )

View File

@ -22,7 +22,7 @@ def adk_to_mcp_tool_type(tool: BaseTool) -> mcp_types.Tool:
"""Convert a Tool in ADK into MCP tool type. """Convert a Tool in ADK into MCP tool type.
This function transforms an ADK tool definition into its equivalent This function transforms an ADK tool definition into its equivalent
MCP (Model Context Protocol) representation. representation in the MCP (Model Control Plane) system.
Args: Args:
tool: The ADK tool to convert. It should be an instance of a class derived tool: The ADK tool to convert. It should be an instance of a class derived