mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-16 04:02:55 -06:00
Merge branch 'main' into support-async-tool-callbacks
This commit is contained in:
commit
926b0ef1a6
@ -39,12 +39,12 @@ class InputFile(BaseModel):
|
|||||||
|
|
||||||
async def run_input_file(
|
async def run_input_file(
|
||||||
app_name: str,
|
app_name: str,
|
||||||
|
user_id: str,
|
||||||
root_agent: LlmAgent,
|
root_agent: LlmAgent,
|
||||||
artifact_service: BaseArtifactService,
|
artifact_service: BaseArtifactService,
|
||||||
session: Session,
|
|
||||||
session_service: BaseSessionService,
|
session_service: BaseSessionService,
|
||||||
input_path: str,
|
input_path: str,
|
||||||
) -> None:
|
) -> Session:
|
||||||
runner = Runner(
|
runner = Runner(
|
||||||
app_name=app_name,
|
app_name=app_name,
|
||||||
agent=root_agent,
|
agent=root_agent,
|
||||||
@ -55,9 +55,11 @@ async def run_input_file(
|
|||||||
input_file = InputFile.model_validate_json(f.read())
|
input_file = InputFile.model_validate_json(f.read())
|
||||||
input_file.state['_time'] = datetime.now()
|
input_file.state['_time'] = datetime.now()
|
||||||
|
|
||||||
session.state = input_file.state
|
session = session_service.create_session(
|
||||||
|
app_name=app_name, user_id=user_id, state=input_file.state
|
||||||
|
)
|
||||||
for query in input_file.queries:
|
for query in input_file.queries:
|
||||||
click.echo(f'user: {query}')
|
click.echo(f'[user]: {query}')
|
||||||
content = types.Content(role='user', parts=[types.Part(text=query)])
|
content = types.Content(role='user', parts=[types.Part(text=query)])
|
||||||
async for event in runner.run_async(
|
async for event in runner.run_async(
|
||||||
user_id=session.user_id, session_id=session.id, new_message=content
|
user_id=session.user_id, session_id=session.id, new_message=content
|
||||||
@ -65,23 +67,23 @@ async def run_input_file(
|
|||||||
if event.content and event.content.parts:
|
if event.content and event.content.parts:
|
||||||
if text := ''.join(part.text or '' for part in event.content.parts):
|
if text := ''.join(part.text or '' for part in event.content.parts):
|
||||||
click.echo(f'[{event.author}]: {text}')
|
click.echo(f'[{event.author}]: {text}')
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
async def run_interactively(
|
async def run_interactively(
|
||||||
app_name: str,
|
|
||||||
root_agent: LlmAgent,
|
root_agent: LlmAgent,
|
||||||
artifact_service: BaseArtifactService,
|
artifact_service: BaseArtifactService,
|
||||||
session: Session,
|
session: Session,
|
||||||
session_service: BaseSessionService,
|
session_service: BaseSessionService,
|
||||||
) -> None:
|
) -> None:
|
||||||
runner = Runner(
|
runner = Runner(
|
||||||
app_name=app_name,
|
app_name=session.app_name,
|
||||||
agent=root_agent,
|
agent=root_agent,
|
||||||
artifact_service=artifact_service,
|
artifact_service=artifact_service,
|
||||||
session_service=session_service,
|
session_service=session_service,
|
||||||
)
|
)
|
||||||
while True:
|
while True:
|
||||||
query = input('user: ')
|
query = input('[user]: ')
|
||||||
if not query or not query.strip():
|
if not query or not query.strip():
|
||||||
continue
|
continue
|
||||||
if query == 'exit':
|
if query == 'exit':
|
||||||
@ -100,7 +102,8 @@ async def run_cli(
|
|||||||
*,
|
*,
|
||||||
agent_parent_dir: str,
|
agent_parent_dir: str,
|
||||||
agent_folder_name: str,
|
agent_folder_name: str,
|
||||||
json_file_path: Optional[str] = None,
|
input_file: Optional[str] = None,
|
||||||
|
saved_session_file: Optional[str] = None,
|
||||||
save_session: bool,
|
save_session: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Runs an interactive CLI for a certain agent.
|
"""Runs an interactive CLI for a certain agent.
|
||||||
@ -109,8 +112,11 @@ async def run_cli(
|
|||||||
agent_parent_dir: str, the absolute path of the parent folder of the agent
|
agent_parent_dir: str, the absolute path of the parent folder of the agent
|
||||||
folder.
|
folder.
|
||||||
agent_folder_name: str, the name of the agent folder.
|
agent_folder_name: str, the name of the agent folder.
|
||||||
json_file_path: Optional[str], the absolute path to the json file, either
|
input_file: Optional[str], the absolute path to the json file that contains
|
||||||
*.input.json or *.session.json.
|
the initial session state and user queries, exclusive with
|
||||||
|
saved_session_file.
|
||||||
|
saved_session_file: Optional[str], the absolute path to the json file that
|
||||||
|
contains a previously saved session, exclusive with input_file.
|
||||||
save_session: bool, whether to save the session on exit.
|
save_session: bool, whether to save the session on exit.
|
||||||
"""
|
"""
|
||||||
if agent_parent_dir not in sys.path:
|
if agent_parent_dir not in sys.path:
|
||||||
@ -118,46 +124,50 @@ async def run_cli(
|
|||||||
|
|
||||||
artifact_service = InMemoryArtifactService()
|
artifact_service = InMemoryArtifactService()
|
||||||
session_service = InMemorySessionService()
|
session_service = InMemorySessionService()
|
||||||
session = session_service.create_session(
|
|
||||||
app_name=agent_folder_name, user_id='test_user'
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
|
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
|
||||||
agent_module = importlib.import_module(agent_folder_name)
|
agent_module = importlib.import_module(agent_folder_name)
|
||||||
|
user_id = 'test_user'
|
||||||
|
session = session_service.create_session(
|
||||||
|
app_name=agent_folder_name, user_id=user_id
|
||||||
|
)
|
||||||
root_agent = agent_module.agent.root_agent
|
root_agent = agent_module.agent.root_agent
|
||||||
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
|
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
|
||||||
if json_file_path:
|
if input_file:
|
||||||
if json_file_path.endswith('.input.json'):
|
session = await run_input_file(
|
||||||
await run_input_file(
|
app_name=agent_folder_name,
|
||||||
app_name=agent_folder_name,
|
user_id=user_id,
|
||||||
root_agent=root_agent,
|
root_agent=root_agent,
|
||||||
artifact_service=artifact_service,
|
artifact_service=artifact_service,
|
||||||
session=session,
|
session_service=session_service,
|
||||||
session_service=session_service,
|
input_path=input_file,
|
||||||
input_path=json_file_path,
|
)
|
||||||
)
|
elif saved_session_file:
|
||||||
elif json_file_path.endswith('.session.json'):
|
|
||||||
with open(json_file_path, 'r') as f:
|
loaded_session = None
|
||||||
session = Session.model_validate_json(f.read())
|
with open(saved_session_file, 'r') as f:
|
||||||
for content in session.get_contents():
|
loaded_session = Session.model_validate_json(f.read())
|
||||||
if content.role == 'user':
|
|
||||||
print('user: ', content.parts[0].text)
|
if loaded_session:
|
||||||
|
for event in loaded_session.events:
|
||||||
|
session_service.append_event(session, event)
|
||||||
|
content = event.content
|
||||||
|
if not content or not content.parts or not content.parts[0].text:
|
||||||
|
continue
|
||||||
|
if event.author == 'user':
|
||||||
|
click.echo(f'[user]: {content.parts[0].text}')
|
||||||
else:
|
else:
|
||||||
print(content.parts[0].text)
|
click.echo(f'[{event.author}]: {content.parts[0].text}')
|
||||||
await run_interactively(
|
|
||||||
agent_folder_name,
|
await run_interactively(
|
||||||
root_agent,
|
root_agent,
|
||||||
artifact_service,
|
artifact_service,
|
||||||
session,
|
session,
|
||||||
session_service,
|
session_service,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print(f'Unsupported file type: {json_file_path}')
|
click.echo(f'Running agent {root_agent.name}, type exit to exit.')
|
||||||
exit(1)
|
|
||||||
else:
|
|
||||||
print(f'Running agent {root_agent.name}, type exit to exit.')
|
|
||||||
await run_interactively(
|
await run_interactively(
|
||||||
agent_folder_name,
|
|
||||||
root_agent,
|
root_agent,
|
||||||
artifact_service,
|
artifact_service,
|
||||||
session,
|
session,
|
||||||
@ -165,11 +175,8 @@ async def run_cli(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if save_session:
|
if save_session:
|
||||||
if json_file_path:
|
session_id = input('Session ID to save: ')
|
||||||
session_path = json_file_path.replace('.input.json', '.session.json')
|
session_path = f'{agent_module_path}/{session_id}.session.json'
|
||||||
else:
|
|
||||||
session_id = input('Session ID to save: ')
|
|
||||||
session_path = f'{agent_module_path}/{session_id}.session.json'
|
|
||||||
|
|
||||||
# Fetch the session again to get all the details.
|
# Fetch the session again to get all the details.
|
||||||
session = session_service.get_session(
|
session = session_service.get_session(
|
||||||
|
@ -96,6 +96,23 @@ def cli_create_cmd(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_exclusive(ctx, param, value):
|
||||||
|
# Store the validated parameters in the context
|
||||||
|
if not hasattr(ctx, "exclusive_opts"):
|
||||||
|
ctx.exclusive_opts = {}
|
||||||
|
|
||||||
|
# If this option has a value and we've already seen another exclusive option
|
||||||
|
if value is not None and any(ctx.exclusive_opts.values()):
|
||||||
|
exclusive_opt = next(key for key, val in ctx.exclusive_opts.items() if val)
|
||||||
|
raise click.UsageError(
|
||||||
|
f"Options '{param.name}' and '{exclusive_opt}' cannot be set together."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Record this option's value
|
||||||
|
ctx.exclusive_opts[param.name] = value is not None
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
@main.command("run")
|
@main.command("run")
|
||||||
@click.option(
|
@click.option(
|
||||||
"--save_session",
|
"--save_session",
|
||||||
@ -105,13 +122,43 @@ def cli_create_cmd(
|
|||||||
default=False,
|
default=False,
|
||||||
help="Optional. Whether to save the session to a json file on exit.",
|
help="Optional. Whether to save the session to a json file on exit.",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--replay",
|
||||||
|
type=click.Path(
|
||||||
|
exists=True, dir_okay=False, file_okay=True, resolve_path=True
|
||||||
|
),
|
||||||
|
help=(
|
||||||
|
"The json file that contains the initial state of the session and user"
|
||||||
|
" queries. A new session will be created using this state. And user"
|
||||||
|
" queries are run againt the newly created session. Users cannot"
|
||||||
|
" continue to interact with the agent."
|
||||||
|
),
|
||||||
|
callback=validate_exclusive,
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--resume",
|
||||||
|
type=click.Path(
|
||||||
|
exists=True, dir_okay=False, file_okay=True, resolve_path=True
|
||||||
|
),
|
||||||
|
help=(
|
||||||
|
"The json file that contains a previously saved session (by"
|
||||||
|
"--save_session option). The previous session will be re-displayed. And"
|
||||||
|
" user can continue to interact with the agent."
|
||||||
|
),
|
||||||
|
callback=validate_exclusive,
|
||||||
|
)
|
||||||
@click.argument(
|
@click.argument(
|
||||||
"agent",
|
"agent",
|
||||||
type=click.Path(
|
type=click.Path(
|
||||||
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
def cli_run(agent: str, save_session: bool):
|
def cli_run(
|
||||||
|
agent: str,
|
||||||
|
save_session: bool,
|
||||||
|
replay: Optional[str],
|
||||||
|
resume: Optional[str],
|
||||||
|
):
|
||||||
"""Runs an interactive CLI for a certain agent.
|
"""Runs an interactive CLI for a certain agent.
|
||||||
|
|
||||||
AGENT: The path to the agent source code folder.
|
AGENT: The path to the agent source code folder.
|
||||||
@ -129,6 +176,8 @@ def cli_run(agent: str, save_session: bool):
|
|||||||
run_cli(
|
run_cli(
|
||||||
agent_parent_dir=agent_parent_folder,
|
agent_parent_dir=agent_parent_folder,
|
||||||
agent_folder_name=agent_folder_name,
|
agent_folder_name=agent_folder_name,
|
||||||
|
input_file=replay,
|
||||||
|
saved_session_file=resume,
|
||||||
save_session=save_session,
|
save_session=save_session,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -48,8 +48,13 @@ class EventActions(BaseModel):
|
|||||||
"""The agent is escalating to a higher level agent."""
|
"""The agent is escalating to a higher level agent."""
|
||||||
|
|
||||||
requested_auth_configs: dict[str, AuthConfig] = Field(default_factory=dict)
|
requested_auth_configs: dict[str, AuthConfig] = Field(default_factory=dict)
|
||||||
"""Will only be set by a tool response indicating tool request euc.
|
"""Authentication configurations requested by tool responses.
|
||||||
dict key is the function call id since one function call response (from model)
|
|
||||||
could correspond to multiple function calls.
|
This field will only be set by a tool response event indicating tool request
|
||||||
dict value is the required auth config.
|
auth credential.
|
||||||
|
- Keys: The function call id. Since one function response event could contain
|
||||||
|
multiple function responses that correspond to multiple function calls. Each
|
||||||
|
function call could request different auth configs. This id is used to
|
||||||
|
identify the function call.
|
||||||
|
- Values: The requested auth config.
|
||||||
"""
|
"""
|
||||||
|
@ -58,6 +58,8 @@ from .state import State
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_MAX_VARCHAR_LENGTH = 256
|
||||||
|
|
||||||
|
|
||||||
class DynamicJSON(TypeDecorator):
|
class DynamicJSON(TypeDecorator):
|
||||||
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
|
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
|
||||||
@ -92,17 +94,25 @@ class DynamicJSON(TypeDecorator):
|
|||||||
|
|
||||||
class Base(DeclarativeBase):
|
class Base(DeclarativeBase):
|
||||||
"""Base class for database tables."""
|
"""Base class for database tables."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class StorageSession(Base):
|
class StorageSession(Base):
|
||||||
"""Represents a session stored in the database."""
|
"""Represents a session stored in the database."""
|
||||||
|
|
||||||
__tablename__ = "sessions"
|
__tablename__ = "sessions"
|
||||||
|
|
||||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
app_name: Mapped[str] = mapped_column(
|
||||||
user_id: Mapped[str] = mapped_column(String, primary_key=True)
|
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||||
|
)
|
||||||
id: Mapped[str] = mapped_column(
|
id: Mapped[str] = mapped_column(
|
||||||
String, primary_key=True, default=lambda: str(uuid.uuid4())
|
String(DEFAULT_MAX_VARCHAR_LENGTH),
|
||||||
|
primary_key=True,
|
||||||
|
default=lambda: str(uuid.uuid4()),
|
||||||
)
|
)
|
||||||
|
|
||||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||||
@ -125,16 +135,27 @@ class StorageSession(Base):
|
|||||||
|
|
||||||
class StorageEvent(Base):
|
class StorageEvent(Base):
|
||||||
"""Represents an event stored in the database."""
|
"""Represents an event stored in the database."""
|
||||||
|
|
||||||
__tablename__ = "events"
|
__tablename__ = "events"
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(String, primary_key=True)
|
id: Mapped[str] = mapped_column(
|
||||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||||
user_id: Mapped[str] = mapped_column(String, primary_key=True)
|
)
|
||||||
session_id: Mapped[str] = mapped_column(String, primary_key=True)
|
app_name: Mapped[str] = mapped_column(
|
||||||
|
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||||
|
)
|
||||||
|
session_id: Mapped[str] = mapped_column(
|
||||||
|
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||||
|
)
|
||||||
|
|
||||||
invocation_id: Mapped[str] = mapped_column(String)
|
invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
|
||||||
author: Mapped[str] = mapped_column(String)
|
author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
|
||||||
branch: Mapped[str] = mapped_column(String, nullable=True)
|
branch: Mapped[str] = mapped_column(
|
||||||
|
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
|
||||||
|
)
|
||||||
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
|
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
|
||||||
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
|
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
|
||||||
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
|
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
|
||||||
@ -147,8 +168,10 @@ class StorageEvent(Base):
|
|||||||
)
|
)
|
||||||
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||||
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||||
error_code: Mapped[str] = mapped_column(String, nullable=True)
|
error_code: Mapped[str] = mapped_column(
|
||||||
error_message: Mapped[str] = mapped_column(String, nullable=True)
|
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
|
||||||
|
)
|
||||||
|
error_message: Mapped[str] = mapped_column(String(1024), nullable=True)
|
||||||
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||||
|
|
||||||
storage_session: Mapped[StorageSession] = relationship(
|
storage_session: Mapped[StorageSession] = relationship(
|
||||||
@ -182,9 +205,12 @@ class StorageEvent(Base):
|
|||||||
|
|
||||||
class StorageAppState(Base):
|
class StorageAppState(Base):
|
||||||
"""Represents an app state stored in the database."""
|
"""Represents an app state stored in the database."""
|
||||||
|
|
||||||
__tablename__ = "app_states"
|
__tablename__ = "app_states"
|
||||||
|
|
||||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
app_name: Mapped[str] = mapped_column(
|
||||||
|
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||||
|
)
|
||||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||||
MutableDict.as_mutable(DynamicJSON), default={}
|
MutableDict.as_mutable(DynamicJSON), default={}
|
||||||
)
|
)
|
||||||
@ -192,13 +218,20 @@ class StorageAppState(Base):
|
|||||||
DateTime(), default=func.now(), onupdate=func.now()
|
DateTime(), default=func.now(), onupdate=func.now()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class StorageUserState(Base):
|
class StorageUserState(Base):
|
||||||
"""Represents a user state stored in the database."""
|
"""Represents a user state stored in the database."""
|
||||||
|
|
||||||
__tablename__ = "user_states"
|
__tablename__ = "user_states"
|
||||||
|
|
||||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
app_name: Mapped[str] = mapped_column(
|
||||||
user_id: Mapped[str] = mapped_column(String, primary_key=True)
|
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||||
|
)
|
||||||
|
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||||
|
MutableDict.as_mutable(DynamicJSON), default={}
|
||||||
|
)
|
||||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||||
MutableDict.as_mutable(DynamicJSON), default={}
|
MutableDict.as_mutable(DynamicJSON), default={}
|
||||||
)
|
)
|
||||||
|
@ -22,7 +22,7 @@ def adk_to_mcp_tool_type(tool: BaseTool) -> mcp_types.Tool:
|
|||||||
"""Convert a Tool in ADK into MCP tool type.
|
"""Convert a Tool in ADK into MCP tool type.
|
||||||
|
|
||||||
This function transforms an ADK tool definition into its equivalent
|
This function transforms an ADK tool definition into its equivalent
|
||||||
MCP (Model Context Protocol) representation.
|
representation in the MCP (Model Control Plane) system.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool: The ADK tool to convert. It should be an instance of a class derived
|
tool: The ADK tool to convert. It should be an instance of a class derived
|
||||||
|
Loading…
Reference in New Issue
Block a user