mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-13 15:14:50 -06:00
Merge branch 'main' into support-async-tool-callbacks
This commit is contained in:
commit
926b0ef1a6
@ -39,12 +39,12 @@ class InputFile(BaseModel):
|
||||
|
||||
async def run_input_file(
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
root_agent: LlmAgent,
|
||||
artifact_service: BaseArtifactService,
|
||||
session: Session,
|
||||
session_service: BaseSessionService,
|
||||
input_path: str,
|
||||
) -> None:
|
||||
) -> Session:
|
||||
runner = Runner(
|
||||
app_name=app_name,
|
||||
agent=root_agent,
|
||||
@ -55,9 +55,11 @@ async def run_input_file(
|
||||
input_file = InputFile.model_validate_json(f.read())
|
||||
input_file.state['_time'] = datetime.now()
|
||||
|
||||
session.state = input_file.state
|
||||
session = session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=input_file.state
|
||||
)
|
||||
for query in input_file.queries:
|
||||
click.echo(f'user: {query}')
|
||||
click.echo(f'[user]: {query}')
|
||||
content = types.Content(role='user', parts=[types.Part(text=query)])
|
||||
async for event in runner.run_async(
|
||||
user_id=session.user_id, session_id=session.id, new_message=content
|
||||
@ -65,23 +67,23 @@ async def run_input_file(
|
||||
if event.content and event.content.parts:
|
||||
if text := ''.join(part.text or '' for part in event.content.parts):
|
||||
click.echo(f'[{event.author}]: {text}')
|
||||
return session
|
||||
|
||||
|
||||
async def run_interactively(
|
||||
app_name: str,
|
||||
root_agent: LlmAgent,
|
||||
artifact_service: BaseArtifactService,
|
||||
session: Session,
|
||||
session_service: BaseSessionService,
|
||||
) -> None:
|
||||
runner = Runner(
|
||||
app_name=app_name,
|
||||
app_name=session.app_name,
|
||||
agent=root_agent,
|
||||
artifact_service=artifact_service,
|
||||
session_service=session_service,
|
||||
)
|
||||
while True:
|
||||
query = input('user: ')
|
||||
query = input('[user]: ')
|
||||
if not query or not query.strip():
|
||||
continue
|
||||
if query == 'exit':
|
||||
@ -100,7 +102,8 @@ async def run_cli(
|
||||
*,
|
||||
agent_parent_dir: str,
|
||||
agent_folder_name: str,
|
||||
json_file_path: Optional[str] = None,
|
||||
input_file: Optional[str] = None,
|
||||
saved_session_file: Optional[str] = None,
|
||||
save_session: bool,
|
||||
) -> None:
|
||||
"""Runs an interactive CLI for a certain agent.
|
||||
@ -109,8 +112,11 @@ async def run_cli(
|
||||
agent_parent_dir: str, the absolute path of the parent folder of the agent
|
||||
folder.
|
||||
agent_folder_name: str, the name of the agent folder.
|
||||
json_file_path: Optional[str], the absolute path to the json file, either
|
||||
*.input.json or *.session.json.
|
||||
input_file: Optional[str], the absolute path to the json file that contains
|
||||
the initial session state and user queries, exclusive with
|
||||
saved_session_file.
|
||||
saved_session_file: Optional[str], the absolute path to the json file that
|
||||
contains a previously saved session, exclusive with input_file.
|
||||
save_session: bool, whether to save the session on exit.
|
||||
"""
|
||||
if agent_parent_dir not in sys.path:
|
||||
@ -118,46 +124,50 @@ async def run_cli(
|
||||
|
||||
artifact_service = InMemoryArtifactService()
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
app_name=agent_folder_name, user_id='test_user'
|
||||
)
|
||||
|
||||
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
|
||||
agent_module = importlib.import_module(agent_folder_name)
|
||||
user_id = 'test_user'
|
||||
session = session_service.create_session(
|
||||
app_name=agent_folder_name, user_id=user_id
|
||||
)
|
||||
root_agent = agent_module.agent.root_agent
|
||||
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
|
||||
if json_file_path:
|
||||
if json_file_path.endswith('.input.json'):
|
||||
await run_input_file(
|
||||
app_name=agent_folder_name,
|
||||
root_agent=root_agent,
|
||||
artifact_service=artifact_service,
|
||||
session=session,
|
||||
session_service=session_service,
|
||||
input_path=json_file_path,
|
||||
)
|
||||
elif json_file_path.endswith('.session.json'):
|
||||
with open(json_file_path, 'r') as f:
|
||||
session = Session.model_validate_json(f.read())
|
||||
for content in session.get_contents():
|
||||
if content.role == 'user':
|
||||
print('user: ', content.parts[0].text)
|
||||
if input_file:
|
||||
session = await run_input_file(
|
||||
app_name=agent_folder_name,
|
||||
user_id=user_id,
|
||||
root_agent=root_agent,
|
||||
artifact_service=artifact_service,
|
||||
session_service=session_service,
|
||||
input_path=input_file,
|
||||
)
|
||||
elif saved_session_file:
|
||||
|
||||
loaded_session = None
|
||||
with open(saved_session_file, 'r') as f:
|
||||
loaded_session = Session.model_validate_json(f.read())
|
||||
|
||||
if loaded_session:
|
||||
for event in loaded_session.events:
|
||||
session_service.append_event(session, event)
|
||||
content = event.content
|
||||
if not content or not content.parts or not content.parts[0].text:
|
||||
continue
|
||||
if event.author == 'user':
|
||||
click.echo(f'[user]: {content.parts[0].text}')
|
||||
else:
|
||||
print(content.parts[0].text)
|
||||
await run_interactively(
|
||||
agent_folder_name,
|
||||
root_agent,
|
||||
artifact_service,
|
||||
session,
|
||||
session_service,
|
||||
)
|
||||
else:
|
||||
print(f'Unsupported file type: {json_file_path}')
|
||||
exit(1)
|
||||
else:
|
||||
print(f'Running agent {root_agent.name}, type exit to exit.')
|
||||
click.echo(f'[{event.author}]: {content.parts[0].text}')
|
||||
|
||||
await run_interactively(
|
||||
root_agent,
|
||||
artifact_service,
|
||||
session,
|
||||
session_service,
|
||||
)
|
||||
else:
|
||||
click.echo(f'Running agent {root_agent.name}, type exit to exit.')
|
||||
await run_interactively(
|
||||
agent_folder_name,
|
||||
root_agent,
|
||||
artifact_service,
|
||||
session,
|
||||
@ -165,11 +175,8 @@ async def run_cli(
|
||||
)
|
||||
|
||||
if save_session:
|
||||
if json_file_path:
|
||||
session_path = json_file_path.replace('.input.json', '.session.json')
|
||||
else:
|
||||
session_id = input('Session ID to save: ')
|
||||
session_path = f'{agent_module_path}/{session_id}.session.json'
|
||||
session_id = input('Session ID to save: ')
|
||||
session_path = f'{agent_module_path}/{session_id}.session.json'
|
||||
|
||||
# Fetch the session again to get all the details.
|
||||
session = session_service.get_session(
|
||||
|
@ -96,6 +96,23 @@ def cli_create_cmd(
|
||||
)
|
||||
|
||||
|
||||
def validate_exclusive(ctx, param, value):
|
||||
# Store the validated parameters in the context
|
||||
if not hasattr(ctx, "exclusive_opts"):
|
||||
ctx.exclusive_opts = {}
|
||||
|
||||
# If this option has a value and we've already seen another exclusive option
|
||||
if value is not None and any(ctx.exclusive_opts.values()):
|
||||
exclusive_opt = next(key for key, val in ctx.exclusive_opts.items() if val)
|
||||
raise click.UsageError(
|
||||
f"Options '{param.name}' and '{exclusive_opt}' cannot be set together."
|
||||
)
|
||||
|
||||
# Record this option's value
|
||||
ctx.exclusive_opts[param.name] = value is not None
|
||||
return value
|
||||
|
||||
|
||||
@main.command("run")
|
||||
@click.option(
|
||||
"--save_session",
|
||||
@ -105,13 +122,43 @@ def cli_create_cmd(
|
||||
default=False,
|
||||
help="Optional. Whether to save the session to a json file on exit.",
|
||||
)
|
||||
@click.option(
|
||||
"--replay",
|
||||
type=click.Path(
|
||||
exists=True, dir_okay=False, file_okay=True, resolve_path=True
|
||||
),
|
||||
help=(
|
||||
"The json file that contains the initial state of the session and user"
|
||||
" queries. A new session will be created using this state. And user"
|
||||
" queries are run againt the newly created session. Users cannot"
|
||||
" continue to interact with the agent."
|
||||
),
|
||||
callback=validate_exclusive,
|
||||
)
|
||||
@click.option(
|
||||
"--resume",
|
||||
type=click.Path(
|
||||
exists=True, dir_okay=False, file_okay=True, resolve_path=True
|
||||
),
|
||||
help=(
|
||||
"The json file that contains a previously saved session (by"
|
||||
"--save_session option). The previous session will be re-displayed. And"
|
||||
" user can continue to interact with the agent."
|
||||
),
|
||||
callback=validate_exclusive,
|
||||
)
|
||||
@click.argument(
|
||||
"agent",
|
||||
type=click.Path(
|
||||
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
||||
),
|
||||
)
|
||||
def cli_run(agent: str, save_session: bool):
|
||||
def cli_run(
|
||||
agent: str,
|
||||
save_session: bool,
|
||||
replay: Optional[str],
|
||||
resume: Optional[str],
|
||||
):
|
||||
"""Runs an interactive CLI for a certain agent.
|
||||
|
||||
AGENT: The path to the agent source code folder.
|
||||
@ -129,6 +176,8 @@ def cli_run(agent: str, save_session: bool):
|
||||
run_cli(
|
||||
agent_parent_dir=agent_parent_folder,
|
||||
agent_folder_name=agent_folder_name,
|
||||
input_file=replay,
|
||||
saved_session_file=resume,
|
||||
save_session=save_session,
|
||||
)
|
||||
)
|
||||
|
@ -48,8 +48,13 @@ class EventActions(BaseModel):
|
||||
"""The agent is escalating to a higher level agent."""
|
||||
|
||||
requested_auth_configs: dict[str, AuthConfig] = Field(default_factory=dict)
|
||||
"""Will only be set by a tool response indicating tool request euc.
|
||||
dict key is the function call id since one function call response (from model)
|
||||
could correspond to multiple function calls.
|
||||
dict value is the required auth config.
|
||||
"""Authentication configurations requested by tool responses.
|
||||
|
||||
This field will only be set by a tool response event indicating tool request
|
||||
auth credential.
|
||||
- Keys: The function call id. Since one function response event could contain
|
||||
multiple function responses that correspond to multiple function calls. Each
|
||||
function call could request different auth configs. This id is used to
|
||||
identify the function call.
|
||||
- Values: The requested auth config.
|
||||
"""
|
||||
|
@ -58,6 +58,8 @@ from .state import State
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_MAX_VARCHAR_LENGTH = 256
|
||||
|
||||
|
||||
class DynamicJSON(TypeDecorator):
|
||||
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
|
||||
@ -92,17 +94,25 @@ class DynamicJSON(TypeDecorator):
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for database tables."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class StorageSession(Base):
|
||||
"""Represents a session stored in the database."""
|
||||
|
||||
__tablename__ = "sessions"
|
||||
|
||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
user_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
app_name: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||
)
|
||||
id: Mapped[str] = mapped_column(
|
||||
String, primary_key=True, default=lambda: str(uuid.uuid4())
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH),
|
||||
primary_key=True,
|
||||
default=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||
@ -125,16 +135,27 @@ class StorageSession(Base):
|
||||
|
||||
class StorageEvent(Base):
|
||||
"""Represents an event stored in the database."""
|
||||
|
||||
__tablename__ = "events"
|
||||
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
user_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
session_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||
)
|
||||
app_name: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||
)
|
||||
session_id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||
)
|
||||
|
||||
invocation_id: Mapped[str] = mapped_column(String)
|
||||
author: Mapped[str] = mapped_column(String)
|
||||
branch: Mapped[str] = mapped_column(String, nullable=True)
|
||||
invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
|
||||
author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
|
||||
branch: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
|
||||
)
|
||||
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
|
||||
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
|
||||
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
|
||||
@ -147,8 +168,10 @@ class StorageEvent(Base):
|
||||
)
|
||||
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||
error_code: Mapped[str] = mapped_column(String, nullable=True)
|
||||
error_message: Mapped[str] = mapped_column(String, nullable=True)
|
||||
error_code: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
|
||||
)
|
||||
error_message: Mapped[str] = mapped_column(String(1024), nullable=True)
|
||||
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||
|
||||
storage_session: Mapped[StorageSession] = relationship(
|
||||
@ -182,9 +205,12 @@ class StorageEvent(Base):
|
||||
|
||||
class StorageAppState(Base):
|
||||
"""Represents an app state stored in the database."""
|
||||
|
||||
__tablename__ = "app_states"
|
||||
|
||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
app_name: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||
)
|
||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
)
|
||||
@ -192,13 +218,20 @@ class StorageAppState(Base):
|
||||
DateTime(), default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class StorageUserState(Base):
|
||||
"""Represents a user state stored in the database."""
|
||||
|
||||
__tablename__ = "user_states"
|
||||
|
||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
user_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
app_name: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||
)
|
||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
)
|
||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
)
|
||||
|
@ -22,7 +22,7 @@ def adk_to_mcp_tool_type(tool: BaseTool) -> mcp_types.Tool:
|
||||
"""Convert a Tool in ADK into MCP tool type.
|
||||
|
||||
This function transforms an ADK tool definition into its equivalent
|
||||
MCP (Model Context Protocol) representation.
|
||||
representation in the MCP (Model Control Plane) system.
|
||||
|
||||
Args:
|
||||
tool: The ADK tool to convert. It should be an instance of a class derived
|
||||
|
Loading…
Reference in New Issue
Block a user