diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index 076668c..4802ea4 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -39,12 +39,12 @@ class InputFile(BaseModel): async def run_input_file( app_name: str, + user_id: str, root_agent: LlmAgent, artifact_service: BaseArtifactService, - session: Session, session_service: BaseSessionService, input_path: str, -) -> None: +) -> Session: runner = Runner( app_name=app_name, agent=root_agent, @@ -55,9 +55,11 @@ async def run_input_file( input_file = InputFile.model_validate_json(f.read()) input_file.state['_time'] = datetime.now() - session.state = input_file.state + session = session_service.create_session( + app_name=app_name, user_id=user_id, state=input_file.state + ) for query in input_file.queries: - click.echo(f'user: {query}') + click.echo(f'[user]: {query}') content = types.Content(role='user', parts=[types.Part(text=query)]) async for event in runner.run_async( user_id=session.user_id, session_id=session.id, new_message=content @@ -65,23 +67,23 @@ async def run_input_file( if event.content and event.content.parts: if text := ''.join(part.text or '' for part in event.content.parts): click.echo(f'[{event.author}]: {text}') + return session async def run_interactively( - app_name: str, root_agent: LlmAgent, artifact_service: BaseArtifactService, session: Session, session_service: BaseSessionService, ) -> None: runner = Runner( - app_name=app_name, + app_name=session.app_name, agent=root_agent, artifact_service=artifact_service, session_service=session_service, ) while True: - query = input('user: ') + query = input('[user]: ') if not query or not query.strip(): continue if query == 'exit': @@ -100,7 +102,8 @@ async def run_cli( *, agent_parent_dir: str, agent_folder_name: str, - json_file_path: Optional[str] = None, + input_file: Optional[str] = None, + saved_session_file: Optional[str] = None, save_session: bool, ) -> None: """Runs an interactive CLI for a certain agent. @@ -109,8 +112,11 @@ async def run_cli( agent_parent_dir: str, the absolute path of the parent folder of the agent folder. agent_folder_name: str, the name of the agent folder. - json_file_path: Optional[str], the absolute path to the json file, either - *.input.json or *.session.json. + input_file: Optional[str], the absolute path to the json file that contains + the initial session state and user queries, exclusive with + saved_session_file. + saved_session_file: Optional[str], the absolute path to the json file that + contains a previously saved session, exclusive with input_file. save_session: bool, whether to save the session on exit. """ if agent_parent_dir not in sys.path: @@ -118,46 +124,50 @@ async def run_cli( artifact_service = InMemoryArtifactService() session_service = InMemorySessionService() - session = session_service.create_session( - app_name=agent_folder_name, user_id='test_user' - ) agent_module_path = os.path.join(agent_parent_dir, agent_folder_name) agent_module = importlib.import_module(agent_folder_name) + user_id = 'test_user' + session = session_service.create_session( + app_name=agent_folder_name, user_id=user_id + ) root_agent = agent_module.agent.root_agent envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir) - if json_file_path: - if json_file_path.endswith('.input.json'): - await run_input_file( - app_name=agent_folder_name, - root_agent=root_agent, - artifact_service=artifact_service, - session=session, - session_service=session_service, - input_path=json_file_path, - ) - elif json_file_path.endswith('.session.json'): - with open(json_file_path, 'r') as f: - session = Session.model_validate_json(f.read()) - for content in session.get_contents(): - if content.role == 'user': - print('user: ', content.parts[0].text) + if input_file: + session = await run_input_file( + app_name=agent_folder_name, + user_id=user_id, + root_agent=root_agent, + artifact_service=artifact_service, + session_service=session_service, + input_path=input_file, + ) + elif saved_session_file: + + loaded_session = None + with open(saved_session_file, 'r') as f: + loaded_session = Session.model_validate_json(f.read()) + + if loaded_session: + for event in loaded_session.events: + session_service.append_event(session, event) + content = event.content + if not content or not content.parts or not content.parts[0].text: + continue + if event.author == 'user': + click.echo(f'[user]: {content.parts[0].text}') else: - print(content.parts[0].text) - await run_interactively( - agent_folder_name, - root_agent, - artifact_service, - session, - session_service, - ) - else: - print(f'Unsupported file type: {json_file_path}') - exit(1) - else: - print(f'Running agent {root_agent.name}, type exit to exit.') + click.echo(f'[{event.author}]: {content.parts[0].text}') + + await run_interactively( + root_agent, + artifact_service, + session, + session_service, + ) + else: + click.echo(f'Running agent {root_agent.name}, type exit to exit.') await run_interactively( - agent_folder_name, root_agent, artifact_service, session, @@ -165,11 +175,8 @@ async def run_cli( ) if save_session: - if json_file_path: - session_path = json_file_path.replace('.input.json', '.session.json') - else: - session_id = input('Session ID to save: ') - session_path = f'{agent_module_path}/{session_id}.session.json' + session_id = input('Session ID to save: ') + session_path = f'{agent_module_path}/{session_id}.session.json' # Fetch the session again to get all the details. session = session_service.get_session( diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 425062b..83a2c3f 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -96,6 +96,23 @@ def cli_create_cmd( ) +def validate_exclusive(ctx, param, value): + # Store the validated parameters in the context + if not hasattr(ctx, "exclusive_opts"): + ctx.exclusive_opts = {} + + # If this option has a value and we've already seen another exclusive option + if value is not None and any(ctx.exclusive_opts.values()): + exclusive_opt = next(key for key, val in ctx.exclusive_opts.items() if val) + raise click.UsageError( + f"Options '{param.name}' and '{exclusive_opt}' cannot be set together." + ) + + # Record this option's value + ctx.exclusive_opts[param.name] = value is not None + return value + + @main.command("run") @click.option( "--save_session", @@ -105,13 +122,43 @@ def cli_create_cmd( default=False, help="Optional. Whether to save the session to a json file on exit.", ) +@click.option( + "--replay", + type=click.Path( + exists=True, dir_okay=False, file_okay=True, resolve_path=True + ), + help=( + "The json file that contains the initial state of the session and user" + " queries. A new session will be created using this state. And user" + " queries are run againt the newly created session. Users cannot" + " continue to interact with the agent." + ), + callback=validate_exclusive, +) +@click.option( + "--resume", + type=click.Path( + exists=True, dir_okay=False, file_okay=True, resolve_path=True + ), + help=( + "The json file that contains a previously saved session (by" + "--save_session option). The previous session will be re-displayed. And" + " user can continue to interact with the agent." + ), + callback=validate_exclusive, +) @click.argument( "agent", type=click.Path( exists=True, dir_okay=True, file_okay=False, resolve_path=True ), ) -def cli_run(agent: str, save_session: bool): +def cli_run( + agent: str, + save_session: bool, + replay: Optional[str], + resume: Optional[str], +): """Runs an interactive CLI for a certain agent. AGENT: The path to the agent source code folder. @@ -129,6 +176,8 @@ def cli_run(agent: str, save_session: bool): run_cli( agent_parent_dir=agent_parent_folder, agent_folder_name=agent_folder_name, + input_file=replay, + saved_session_file=resume, save_session=save_session, ) ) diff --git a/src/google/adk/events/event_actions.py b/src/google/adk/events/event_actions.py index 412546e..f4f4078 100644 --- a/src/google/adk/events/event_actions.py +++ b/src/google/adk/events/event_actions.py @@ -48,8 +48,13 @@ class EventActions(BaseModel): """The agent is escalating to a higher level agent.""" requested_auth_configs: dict[str, AuthConfig] = Field(default_factory=dict) - """Will only be set by a tool response indicating tool request euc. - dict key is the function call id since one function call response (from model) - could correspond to multiple function calls. - dict value is the required auth config. + """Authentication configurations requested by tool responses. + + This field will only be set by a tool response event indicating tool request + auth credential. + - Keys: The function call id. Since one function response event could contain + multiple function responses that correspond to multiple function calls. Each + function call could request different auth configs. This id is used to + identify the function call. + - Values: The requested auth config. """ diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 93e66f7..9bfa3cc 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -58,6 +58,8 @@ from .state import State logger = logging.getLogger(__name__) +DEFAULT_MAX_VARCHAR_LENGTH = 256 + class DynamicJSON(TypeDecorator): """A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON @@ -92,17 +94,25 @@ class DynamicJSON(TypeDecorator): class Base(DeclarativeBase): """Base class for database tables.""" + pass class StorageSession(Base): """Represents a session stored in the database.""" + __tablename__ = "sessions" - app_name: Mapped[str] = mapped_column(String, primary_key=True) - user_id: Mapped[str] = mapped_column(String, primary_key=True) + app_name: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + ) + user_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + ) id: Mapped[str] = mapped_column( - String, primary_key=True, default=lambda: str(uuid.uuid4()) + String(DEFAULT_MAX_VARCHAR_LENGTH), + primary_key=True, + default=lambda: str(uuid.uuid4()), ) state: Mapped[MutableDict[str, Any]] = mapped_column( @@ -125,16 +135,27 @@ class StorageSession(Base): class StorageEvent(Base): """Represents an event stored in the database.""" + __tablename__ = "events" - id: Mapped[str] = mapped_column(String, primary_key=True) - app_name: Mapped[str] = mapped_column(String, primary_key=True) - user_id: Mapped[str] = mapped_column(String, primary_key=True) - session_id: Mapped[str] = mapped_column(String, primary_key=True) + id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + ) + app_name: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + ) + user_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + ) + session_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + ) - invocation_id: Mapped[str] = mapped_column(String) - author: Mapped[str] = mapped_column(String) - branch: Mapped[str] = mapped_column(String, nullable=True) + invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) + author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) + branch: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now()) content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True) actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType) @@ -147,8 +168,10 @@ class StorageEvent(Base): ) partial: Mapped[bool] = mapped_column(Boolean, nullable=True) turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True) - error_code: Mapped[str] = mapped_column(String, nullable=True) - error_message: Mapped[str] = mapped_column(String, nullable=True) + error_code: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) + error_message: Mapped[str] = mapped_column(String(1024), nullable=True) interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True) storage_session: Mapped[StorageSession] = relationship( @@ -182,9 +205,12 @@ class StorageEvent(Base): class StorageAppState(Base): """Represents an app state stored in the database.""" + __tablename__ = "app_states" - app_name: Mapped[str] = mapped_column(String, primary_key=True) + app_name: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + ) state: Mapped[MutableDict[str, Any]] = mapped_column( MutableDict.as_mutable(DynamicJSON), default={} ) @@ -192,13 +218,20 @@ class StorageAppState(Base): DateTime(), default=func.now(), onupdate=func.now() ) - class StorageUserState(Base): """Represents a user state stored in the database.""" + __tablename__ = "user_states" - app_name: Mapped[str] = mapped_column(String, primary_key=True) - user_id: Mapped[str] = mapped_column(String, primary_key=True) + app_name: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + ) + user_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + ) + state: Mapped[MutableDict[str, Any]] = mapped_column( + MutableDict.as_mutable(DynamicJSON), default={} + ) state: Mapped[MutableDict[str, Any]] = mapped_column( MutableDict.as_mutable(DynamicJSON), default={} ) diff --git a/src/google/adk/tools/mcp_tool/conversion_utils.py b/src/google/adk/tools/mcp_tool/conversion_utils.py index 9884b77..8afa301 100644 --- a/src/google/adk/tools/mcp_tool/conversion_utils.py +++ b/src/google/adk/tools/mcp_tool/conversion_utils.py @@ -22,7 +22,7 @@ def adk_to_mcp_tool_type(tool: BaseTool) -> mcp_types.Tool: """Convert a Tool in ADK into MCP tool type. This function transforms an ADK tool definition into its equivalent - MCP (Model Context Protocol) representation. + representation in the MCP (Model Control Plane) system. Args: tool: The ADK tool to convert. It should be an instance of a class derived