From dbbeb190b0cf6ce5dc414c607e6c6b35553642c1 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 29 Apr 2025 16:22:43 -0700 Subject: [PATCH] Add two `adk run` options: --replay : a json file that contains the initial session state and user queries, adk run will create a new session based on the state and run the user queries against the session. Users cannot continue to interact with agent. --resume : a json file that contains the previously saved session (by --save_session option), adk run will replay this session and then user can continue to interact with the agent. PiperOrigin-RevId: 752923403 --- src/google/adk/cli/cli.py | 103 ++++++++++++++------------ src/google/adk/cli/cli_tools_click.py | 51 ++++++++++++- 2 files changed, 105 insertions(+), 49 deletions(-) diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index 076668c..4802ea4 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -39,12 +39,12 @@ class InputFile(BaseModel): async def run_input_file( app_name: str, + user_id: str, root_agent: LlmAgent, artifact_service: BaseArtifactService, - session: Session, session_service: BaseSessionService, input_path: str, -) -> None: +) -> Session: runner = Runner( app_name=app_name, agent=root_agent, @@ -55,9 +55,11 @@ async def run_input_file( input_file = InputFile.model_validate_json(f.read()) input_file.state['_time'] = datetime.now() - session.state = input_file.state + session = session_service.create_session( + app_name=app_name, user_id=user_id, state=input_file.state + ) for query in input_file.queries: - click.echo(f'user: {query}') + click.echo(f'[user]: {query}') content = types.Content(role='user', parts=[types.Part(text=query)]) async for event in runner.run_async( user_id=session.user_id, session_id=session.id, new_message=content @@ -65,23 +67,23 @@ async def run_input_file( if event.content and event.content.parts: if text := ''.join(part.text or '' for part in event.content.parts): click.echo(f'[{event.author}]: {text}') + return session async def run_interactively( - app_name: str, root_agent: LlmAgent, artifact_service: BaseArtifactService, session: Session, session_service: BaseSessionService, ) -> None: runner = Runner( - app_name=app_name, + app_name=session.app_name, agent=root_agent, artifact_service=artifact_service, session_service=session_service, ) while True: - query = input('user: ') + query = input('[user]: ') if not query or not query.strip(): continue if query == 'exit': @@ -100,7 +102,8 @@ async def run_cli( *, agent_parent_dir: str, agent_folder_name: str, - json_file_path: Optional[str] = None, + input_file: Optional[str] = None, + saved_session_file: Optional[str] = None, save_session: bool, ) -> None: """Runs an interactive CLI for a certain agent. @@ -109,8 +112,11 @@ async def run_cli( agent_parent_dir: str, the absolute path of the parent folder of the agent folder. agent_folder_name: str, the name of the agent folder. - json_file_path: Optional[str], the absolute path to the json file, either - *.input.json or *.session.json. + input_file: Optional[str], the absolute path to the json file that contains + the initial session state and user queries, exclusive with + saved_session_file. + saved_session_file: Optional[str], the absolute path to the json file that + contains a previously saved session, exclusive with input_file. save_session: bool, whether to save the session on exit. """ if agent_parent_dir not in sys.path: @@ -118,46 +124,50 @@ async def run_cli( artifact_service = InMemoryArtifactService() session_service = InMemorySessionService() - session = session_service.create_session( - app_name=agent_folder_name, user_id='test_user' - ) agent_module_path = os.path.join(agent_parent_dir, agent_folder_name) agent_module = importlib.import_module(agent_folder_name) + user_id = 'test_user' + session = session_service.create_session( + app_name=agent_folder_name, user_id=user_id + ) root_agent = agent_module.agent.root_agent envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir) - if json_file_path: - if json_file_path.endswith('.input.json'): - await run_input_file( - app_name=agent_folder_name, - root_agent=root_agent, - artifact_service=artifact_service, - session=session, - session_service=session_service, - input_path=json_file_path, - ) - elif json_file_path.endswith('.session.json'): - with open(json_file_path, 'r') as f: - session = Session.model_validate_json(f.read()) - for content in session.get_contents(): - if content.role == 'user': - print('user: ', content.parts[0].text) + if input_file: + session = await run_input_file( + app_name=agent_folder_name, + user_id=user_id, + root_agent=root_agent, + artifact_service=artifact_service, + session_service=session_service, + input_path=input_file, + ) + elif saved_session_file: + + loaded_session = None + with open(saved_session_file, 'r') as f: + loaded_session = Session.model_validate_json(f.read()) + + if loaded_session: + for event in loaded_session.events: + session_service.append_event(session, event) + content = event.content + if not content or not content.parts or not content.parts[0].text: + continue + if event.author == 'user': + click.echo(f'[user]: {content.parts[0].text}') else: - print(content.parts[0].text) - await run_interactively( - agent_folder_name, - root_agent, - artifact_service, - session, - session_service, - ) - else: - print(f'Unsupported file type: {json_file_path}') - exit(1) - else: - print(f'Running agent {root_agent.name}, type exit to exit.') + click.echo(f'[{event.author}]: {content.parts[0].text}') + + await run_interactively( + root_agent, + artifact_service, + session, + session_service, + ) + else: + click.echo(f'Running agent {root_agent.name}, type exit to exit.') await run_interactively( - agent_folder_name, root_agent, artifact_service, session, @@ -165,11 +175,8 @@ async def run_cli( ) if save_session: - if json_file_path: - session_path = json_file_path.replace('.input.json', '.session.json') - else: - session_id = input('Session ID to save: ') - session_path = f'{agent_module_path}/{session_id}.session.json' + session_id = input('Session ID to save: ') + session_path = f'{agent_module_path}/{session_id}.session.json' # Fetch the session again to get all the details. session = session_service.get_session( diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 425062b..83a2c3f 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -96,6 +96,23 @@ def cli_create_cmd( ) +def validate_exclusive(ctx, param, value): + # Store the validated parameters in the context + if not hasattr(ctx, "exclusive_opts"): + ctx.exclusive_opts = {} + + # If this option has a value and we've already seen another exclusive option + if value is not None and any(ctx.exclusive_opts.values()): + exclusive_opt = next(key for key, val in ctx.exclusive_opts.items() if val) + raise click.UsageError( + f"Options '{param.name}' and '{exclusive_opt}' cannot be set together." + ) + + # Record this option's value + ctx.exclusive_opts[param.name] = value is not None + return value + + @main.command("run") @click.option( "--save_session", @@ -105,13 +122,43 @@ def cli_create_cmd( default=False, help="Optional. Whether to save the session to a json file on exit.", ) +@click.option( + "--replay", + type=click.Path( + exists=True, dir_okay=False, file_okay=True, resolve_path=True + ), + help=( + "The json file that contains the initial state of the session and user" + " queries. A new session will be created using this state. And user" + " queries are run againt the newly created session. Users cannot" + " continue to interact with the agent." + ), + callback=validate_exclusive, +) +@click.option( + "--resume", + type=click.Path( + exists=True, dir_okay=False, file_okay=True, resolve_path=True + ), + help=( + "The json file that contains a previously saved session (by" + "--save_session option). The previous session will be re-displayed. And" + " user can continue to interact with the agent." + ), + callback=validate_exclusive, +) @click.argument( "agent", type=click.Path( exists=True, dir_okay=True, file_okay=False, resolve_path=True ), ) -def cli_run(agent: str, save_session: bool): +def cli_run( + agent: str, + save_session: bool, + replay: Optional[str], + resume: Optional[str], +): """Runs an interactive CLI for a certain agent. AGENT: The path to the agent source code folder. @@ -129,6 +176,8 @@ def cli_run(agent: str, save_session: bool): run_cli( agent_parent_dir=agent_parent_folder, agent_folder_name=agent_folder_name, + input_file=replay, + saved_session_file=resume, save_session=save_session, ) )