mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 01:41:25 -06:00
Add two adk run
options:
--replay : a json file that contains the initial session state and user queries, adk run will create a new session based on the state and run the user queries against the session. Users cannot continue to interact with agent. --resume : a json file that contains the previously saved session (by --save_session option), adk run will replay this session and then user can continue to interact with the agent. PiperOrigin-RevId: 752923403
This commit is contained in:
parent
2a9ddec7e3
commit
dbbeb190b0
@ -39,12 +39,12 @@ class InputFile(BaseModel):
|
||||
|
||||
async def run_input_file(
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
root_agent: LlmAgent,
|
||||
artifact_service: BaseArtifactService,
|
||||
session: Session,
|
||||
session_service: BaseSessionService,
|
||||
input_path: str,
|
||||
) -> None:
|
||||
) -> Session:
|
||||
runner = Runner(
|
||||
app_name=app_name,
|
||||
agent=root_agent,
|
||||
@ -55,9 +55,11 @@ async def run_input_file(
|
||||
input_file = InputFile.model_validate_json(f.read())
|
||||
input_file.state['_time'] = datetime.now()
|
||||
|
||||
session.state = input_file.state
|
||||
session = session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=input_file.state
|
||||
)
|
||||
for query in input_file.queries:
|
||||
click.echo(f'user: {query}')
|
||||
click.echo(f'[user]: {query}')
|
||||
content = types.Content(role='user', parts=[types.Part(text=query)])
|
||||
async for event in runner.run_async(
|
||||
user_id=session.user_id, session_id=session.id, new_message=content
|
||||
@ -65,23 +67,23 @@ async def run_input_file(
|
||||
if event.content and event.content.parts:
|
||||
if text := ''.join(part.text or '' for part in event.content.parts):
|
||||
click.echo(f'[{event.author}]: {text}')
|
||||
return session
|
||||
|
||||
|
||||
async def run_interactively(
|
||||
app_name: str,
|
||||
root_agent: LlmAgent,
|
||||
artifact_service: BaseArtifactService,
|
||||
session: Session,
|
||||
session_service: BaseSessionService,
|
||||
) -> None:
|
||||
runner = Runner(
|
||||
app_name=app_name,
|
||||
app_name=session.app_name,
|
||||
agent=root_agent,
|
||||
artifact_service=artifact_service,
|
||||
session_service=session_service,
|
||||
)
|
||||
while True:
|
||||
query = input('user: ')
|
||||
query = input('[user]: ')
|
||||
if not query or not query.strip():
|
||||
continue
|
||||
if query == 'exit':
|
||||
@ -100,7 +102,8 @@ async def run_cli(
|
||||
*,
|
||||
agent_parent_dir: str,
|
||||
agent_folder_name: str,
|
||||
json_file_path: Optional[str] = None,
|
||||
input_file: Optional[str] = None,
|
||||
saved_session_file: Optional[str] = None,
|
||||
save_session: bool,
|
||||
) -> None:
|
||||
"""Runs an interactive CLI for a certain agent.
|
||||
@ -109,8 +112,11 @@ async def run_cli(
|
||||
agent_parent_dir: str, the absolute path of the parent folder of the agent
|
||||
folder.
|
||||
agent_folder_name: str, the name of the agent folder.
|
||||
json_file_path: Optional[str], the absolute path to the json file, either
|
||||
*.input.json or *.session.json.
|
||||
input_file: Optional[str], the absolute path to the json file that contains
|
||||
the initial session state and user queries, exclusive with
|
||||
saved_session_file.
|
||||
saved_session_file: Optional[str], the absolute path to the json file that
|
||||
contains a previously saved session, exclusive with input_file.
|
||||
save_session: bool, whether to save the session on exit.
|
||||
"""
|
||||
if agent_parent_dir not in sys.path:
|
||||
@ -118,46 +124,50 @@ async def run_cli(
|
||||
|
||||
artifact_service = InMemoryArtifactService()
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
app_name=agent_folder_name, user_id='test_user'
|
||||
)
|
||||
|
||||
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
|
||||
agent_module = importlib.import_module(agent_folder_name)
|
||||
user_id = 'test_user'
|
||||
session = session_service.create_session(
|
||||
app_name=agent_folder_name, user_id=user_id
|
||||
)
|
||||
root_agent = agent_module.agent.root_agent
|
||||
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
|
||||
if json_file_path:
|
||||
if json_file_path.endswith('.input.json'):
|
||||
await run_input_file(
|
||||
if input_file:
|
||||
session = await run_input_file(
|
||||
app_name=agent_folder_name,
|
||||
user_id=user_id,
|
||||
root_agent=root_agent,
|
||||
artifact_service=artifact_service,
|
||||
session=session,
|
||||
session_service=session_service,
|
||||
input_path=json_file_path,
|
||||
input_path=input_file,
|
||||
)
|
||||
elif json_file_path.endswith('.session.json'):
|
||||
with open(json_file_path, 'r') as f:
|
||||
session = Session.model_validate_json(f.read())
|
||||
for content in session.get_contents():
|
||||
if content.role == 'user':
|
||||
print('user: ', content.parts[0].text)
|
||||
elif saved_session_file:
|
||||
|
||||
loaded_session = None
|
||||
with open(saved_session_file, 'r') as f:
|
||||
loaded_session = Session.model_validate_json(f.read())
|
||||
|
||||
if loaded_session:
|
||||
for event in loaded_session.events:
|
||||
session_service.append_event(session, event)
|
||||
content = event.content
|
||||
if not content or not content.parts or not content.parts[0].text:
|
||||
continue
|
||||
if event.author == 'user':
|
||||
click.echo(f'[user]: {content.parts[0].text}')
|
||||
else:
|
||||
print(content.parts[0].text)
|
||||
click.echo(f'[{event.author}]: {content.parts[0].text}')
|
||||
|
||||
await run_interactively(
|
||||
agent_folder_name,
|
||||
root_agent,
|
||||
artifact_service,
|
||||
session,
|
||||
session_service,
|
||||
)
|
||||
else:
|
||||
print(f'Unsupported file type: {json_file_path}')
|
||||
exit(1)
|
||||
else:
|
||||
print(f'Running agent {root_agent.name}, type exit to exit.')
|
||||
click.echo(f'Running agent {root_agent.name}, type exit to exit.')
|
||||
await run_interactively(
|
||||
agent_folder_name,
|
||||
root_agent,
|
||||
artifact_service,
|
||||
session,
|
||||
@ -165,9 +175,6 @@ async def run_cli(
|
||||
)
|
||||
|
||||
if save_session:
|
||||
if json_file_path:
|
||||
session_path = json_file_path.replace('.input.json', '.session.json')
|
||||
else:
|
||||
session_id = input('Session ID to save: ')
|
||||
session_path = f'{agent_module_path}/{session_id}.session.json'
|
||||
|
||||
|
@ -96,6 +96,23 @@ def cli_create_cmd(
|
||||
)
|
||||
|
||||
|
||||
def validate_exclusive(ctx, param, value):
|
||||
# Store the validated parameters in the context
|
||||
if not hasattr(ctx, "exclusive_opts"):
|
||||
ctx.exclusive_opts = {}
|
||||
|
||||
# If this option has a value and we've already seen another exclusive option
|
||||
if value is not None and any(ctx.exclusive_opts.values()):
|
||||
exclusive_opt = next(key for key, val in ctx.exclusive_opts.items() if val)
|
||||
raise click.UsageError(
|
||||
f"Options '{param.name}' and '{exclusive_opt}' cannot be set together."
|
||||
)
|
||||
|
||||
# Record this option's value
|
||||
ctx.exclusive_opts[param.name] = value is not None
|
||||
return value
|
||||
|
||||
|
||||
@main.command("run")
|
||||
@click.option(
|
||||
"--save_session",
|
||||
@ -105,13 +122,43 @@ def cli_create_cmd(
|
||||
default=False,
|
||||
help="Optional. Whether to save the session to a json file on exit.",
|
||||
)
|
||||
@click.option(
|
||||
"--replay",
|
||||
type=click.Path(
|
||||
exists=True, dir_okay=False, file_okay=True, resolve_path=True
|
||||
),
|
||||
help=(
|
||||
"The json file that contains the initial state of the session and user"
|
||||
" queries. A new session will be created using this state. And user"
|
||||
" queries are run againt the newly created session. Users cannot"
|
||||
" continue to interact with the agent."
|
||||
),
|
||||
callback=validate_exclusive,
|
||||
)
|
||||
@click.option(
|
||||
"--resume",
|
||||
type=click.Path(
|
||||
exists=True, dir_okay=False, file_okay=True, resolve_path=True
|
||||
),
|
||||
help=(
|
||||
"The json file that contains a previously saved session (by"
|
||||
"--save_session option). The previous session will be re-displayed. And"
|
||||
" user can continue to interact with the agent."
|
||||
),
|
||||
callback=validate_exclusive,
|
||||
)
|
||||
@click.argument(
|
||||
"agent",
|
||||
type=click.Path(
|
||||
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
||||
),
|
||||
)
|
||||
def cli_run(agent: str, save_session: bool):
|
||||
def cli_run(
|
||||
agent: str,
|
||||
save_session: bool,
|
||||
replay: Optional[str],
|
||||
resume: Optional[str],
|
||||
):
|
||||
"""Runs an interactive CLI for a certain agent.
|
||||
|
||||
AGENT: The path to the agent source code folder.
|
||||
@ -129,6 +176,8 @@ def cli_run(agent: str, save_session: bool):
|
||||
run_cli(
|
||||
agent_parent_dir=agent_parent_folder,
|
||||
agent_folder_name=agent_folder_name,
|
||||
input_file=replay,
|
||||
saved_session_file=resume,
|
||||
save_session=save_session,
|
||||
)
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user