Merge branch 'main' into support-async-tool-callbacks

This commit is contained in:
Alankrit Verma 2025-04-29 22:08:18 -04:00 committed by GitHub
commit 926b0ef1a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 164 additions and 70 deletions

View File

@ -39,12 +39,12 @@ class InputFile(BaseModel):
async def run_input_file(
app_name: str,
user_id: str,
root_agent: LlmAgent,
artifact_service: BaseArtifactService,
session: Session,
session_service: BaseSessionService,
input_path: str,
) -> None:
) -> Session:
runner = Runner(
app_name=app_name,
agent=root_agent,
@ -55,9 +55,11 @@ async def run_input_file(
input_file = InputFile.model_validate_json(f.read())
input_file.state['_time'] = datetime.now()
session.state = input_file.state
session = session_service.create_session(
app_name=app_name, user_id=user_id, state=input_file.state
)
for query in input_file.queries:
click.echo(f'user: {query}')
click.echo(f'[user]: {query}')
content = types.Content(role='user', parts=[types.Part(text=query)])
async for event in runner.run_async(
user_id=session.user_id, session_id=session.id, new_message=content
@ -65,23 +67,23 @@ async def run_input_file(
if event.content and event.content.parts:
if text := ''.join(part.text or '' for part in event.content.parts):
click.echo(f'[{event.author}]: {text}')
return session
async def run_interactively(
app_name: str,
root_agent: LlmAgent,
artifact_service: BaseArtifactService,
session: Session,
session_service: BaseSessionService,
) -> None:
runner = Runner(
app_name=app_name,
app_name=session.app_name,
agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
)
while True:
query = input('user: ')
query = input('[user]: ')
if not query or not query.strip():
continue
if query == 'exit':
@ -100,7 +102,8 @@ async def run_cli(
*,
agent_parent_dir: str,
agent_folder_name: str,
json_file_path: Optional[str] = None,
input_file: Optional[str] = None,
saved_session_file: Optional[str] = None,
save_session: bool,
) -> None:
"""Runs an interactive CLI for a certain agent.
@ -109,8 +112,11 @@ async def run_cli(
agent_parent_dir: str, the absolute path of the parent folder of the agent
folder.
agent_folder_name: str, the name of the agent folder.
json_file_path: Optional[str], the absolute path to the json file, either
*.input.json or *.session.json.
input_file: Optional[str], the absolute path to the json file that contains
the initial session state and user queries, exclusive with
saved_session_file.
saved_session_file: Optional[str], the absolute path to the json file that
contains a previously saved session, exclusive with input_file.
save_session: bool, whether to save the session on exit.
"""
if agent_parent_dir not in sys.path:
@ -118,46 +124,50 @@ async def run_cli(
artifact_service = InMemoryArtifactService()
session_service = InMemorySessionService()
session = session_service.create_session(
app_name=agent_folder_name, user_id='test_user'
)
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
agent_module = importlib.import_module(agent_folder_name)
user_id = 'test_user'
session = session_service.create_session(
app_name=agent_folder_name, user_id=user_id
)
root_agent = agent_module.agent.root_agent
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
if json_file_path:
if json_file_path.endswith('.input.json'):
await run_input_file(
app_name=agent_folder_name,
root_agent=root_agent,
artifact_service=artifact_service,
session=session,
session_service=session_service,
input_path=json_file_path,
)
elif json_file_path.endswith('.session.json'):
with open(json_file_path, 'r') as f:
session = Session.model_validate_json(f.read())
for content in session.get_contents():
if content.role == 'user':
print('user: ', content.parts[0].text)
if input_file:
session = await run_input_file(
app_name=agent_folder_name,
user_id=user_id,
root_agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
input_path=input_file,
)
elif saved_session_file:
loaded_session = None
with open(saved_session_file, 'r') as f:
loaded_session = Session.model_validate_json(f.read())
if loaded_session:
for event in loaded_session.events:
session_service.append_event(session, event)
content = event.content
if not content or not content.parts or not content.parts[0].text:
continue
if event.author == 'user':
click.echo(f'[user]: {content.parts[0].text}')
else:
print(content.parts[0].text)
await run_interactively(
agent_folder_name,
root_agent,
artifact_service,
session,
session_service,
)
else:
print(f'Unsupported file type: {json_file_path}')
exit(1)
else:
print(f'Running agent {root_agent.name}, type exit to exit.')
click.echo(f'[{event.author}]: {content.parts[0].text}')
await run_interactively(
root_agent,
artifact_service,
session,
session_service,
)
else:
click.echo(f'Running agent {root_agent.name}, type exit to exit.')
await run_interactively(
agent_folder_name,
root_agent,
artifact_service,
session,
@ -165,11 +175,8 @@ async def run_cli(
)
if save_session:
if json_file_path:
session_path = json_file_path.replace('.input.json', '.session.json')
else:
session_id = input('Session ID to save: ')
session_path = f'{agent_module_path}/{session_id}.session.json'
session_id = input('Session ID to save: ')
session_path = f'{agent_module_path}/{session_id}.session.json'
# Fetch the session again to get all the details.
session = session_service.get_session(

View File

@ -96,6 +96,23 @@ def cli_create_cmd(
)
def validate_exclusive(ctx, param, value):
# Store the validated parameters in the context
if not hasattr(ctx, "exclusive_opts"):
ctx.exclusive_opts = {}
# If this option has a value and we've already seen another exclusive option
if value is not None and any(ctx.exclusive_opts.values()):
exclusive_opt = next(key for key, val in ctx.exclusive_opts.items() if val)
raise click.UsageError(
f"Options '{param.name}' and '{exclusive_opt}' cannot be set together."
)
# Record this option's value
ctx.exclusive_opts[param.name] = value is not None
return value
@main.command("run")
@click.option(
"--save_session",
@ -105,13 +122,43 @@ def cli_create_cmd(
default=False,
help="Optional. Whether to save the session to a json file on exit.",
)
@click.option(
"--replay",
type=click.Path(
exists=True, dir_okay=False, file_okay=True, resolve_path=True
),
help=(
"The json file that contains the initial state of the session and user"
" queries. A new session will be created using this state. And user"
" queries are run againt the newly created session. Users cannot"
" continue to interact with the agent."
),
callback=validate_exclusive,
)
@click.option(
"--resume",
type=click.Path(
exists=True, dir_okay=False, file_okay=True, resolve_path=True
),
help=(
"The json file that contains a previously saved session (by"
"--save_session option). The previous session will be re-displayed. And"
" user can continue to interact with the agent."
),
callback=validate_exclusive,
)
@click.argument(
"agent",
type=click.Path(
exists=True, dir_okay=True, file_okay=False, resolve_path=True
),
)
def cli_run(agent: str, save_session: bool):
def cli_run(
agent: str,
save_session: bool,
replay: Optional[str],
resume: Optional[str],
):
"""Runs an interactive CLI for a certain agent.
AGENT: The path to the agent source code folder.
@ -129,6 +176,8 @@ def cli_run(agent: str, save_session: bool):
run_cli(
agent_parent_dir=agent_parent_folder,
agent_folder_name=agent_folder_name,
input_file=replay,
saved_session_file=resume,
save_session=save_session,
)
)

View File

@ -48,8 +48,13 @@ class EventActions(BaseModel):
"""The agent is escalating to a higher level agent."""
requested_auth_configs: dict[str, AuthConfig] = Field(default_factory=dict)
"""Will only be set by a tool response indicating tool request euc.
dict key is the function call id since one function call response (from model)
could correspond to multiple function calls.
dict value is the required auth config.
"""Authentication configurations requested by tool responses.
This field will only be set by a tool response event indicating tool request
auth credential.
- Keys: The function call id. Since one function response event could contain
multiple function responses that correspond to multiple function calls. Each
function call could request different auth configs. This id is used to
identify the function call.
- Values: The requested auth config.
"""

View File

@ -58,6 +58,8 @@ from .state import State
logger = logging.getLogger(__name__)
DEFAULT_MAX_VARCHAR_LENGTH = 256
class DynamicJSON(TypeDecorator):
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
@ -92,17 +94,25 @@ class DynamicJSON(TypeDecorator):
class Base(DeclarativeBase):
"""Base class for database tables."""
pass
class StorageSession(Base):
"""Represents a session stored in the database."""
__tablename__ = "sessions"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
)
id: Mapped[str] = mapped_column(
String, primary_key=True, default=lambda: str(uuid.uuid4())
String(DEFAULT_MAX_VARCHAR_LENGTH),
primary_key=True,
default=lambda: str(uuid.uuid4()),
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
@ -125,16 +135,27 @@ class StorageSession(Base):
class StorageEvent(Base):
"""Represents an event stored in the database."""
__tablename__ = "events"
id: Mapped[str] = mapped_column(String, primary_key=True)
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
session_id: Mapped[str] = mapped_column(String, primary_key=True)
id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
)
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
)
session_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
)
invocation_id: Mapped[str] = mapped_column(String)
author: Mapped[str] = mapped_column(String)
branch: Mapped[str] = mapped_column(String, nullable=True)
invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
branch: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
)
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
@ -147,8 +168,10 @@ class StorageEvent(Base):
)
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
error_code: Mapped[str] = mapped_column(String, nullable=True)
error_message: Mapped[str] = mapped_column(String, nullable=True)
error_code: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
)
error_message: Mapped[str] = mapped_column(String(1024), nullable=True)
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
storage_session: Mapped[StorageSession] = relationship(
@ -182,9 +205,12 @@ class StorageEvent(Base):
class StorageAppState(Base):
"""Represents an app state stored in the database."""
__tablename__ = "app_states"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
@ -192,13 +218,20 @@ class StorageAppState(Base):
DateTime(), default=func.now(), onupdate=func.now()
)
class StorageUserState(Base):
"""Represents a user state stored in the database."""
__tablename__ = "user_states"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)

View File

@ -22,7 +22,7 @@ def adk_to_mcp_tool_type(tool: BaseTool) -> mcp_types.Tool:
"""Convert a Tool in ADK into MCP tool type.
This function transforms an ADK tool definition into its equivalent
MCP (Model Context Protocol) representation.
representation in the MCP (Model Control Plane) system.
Args:
tool: The ADK tool to convert. It should be an instance of a class derived