Update samples for async session changes and remove exp suffix from gemini model

PiperOrigin-RevId: 759364113
This commit is contained in:
Selcuk Gun 2025-05-15 17:11:32 -07:00 committed by Copybara-Service
parent 04820cb0a7
commit 01cf186299
7 changed files with 8 additions and 112 deletions

View File

@ -24,7 +24,7 @@ async def log_query(tool_context: ToolContext, query: str):
root_agent = Agent(
model='gemini-2.0-flash-exp',
model='gemini-2.0-flash',
name='log_agent',
description='Log user query.',
instruction="""Always log the user query and reploy "kk, I've logged."

View File

@ -145,7 +145,7 @@ def after_tool_cb3(tool, args, tool_context, tool_response):
root_agent = Agent(
model='gemini-2.0-flash-exp',
model='gemini-2.0-flash',
name='data_processing_agent',
description=(
'hello world agent that can roll a dice of 8 sides and check prime'

View File

@ -19,7 +19,6 @@ import warnings
import agent
from dotenv import load_dotenv
from google.adk import Runner
from google.adk.agents.run_config import RunConfig
from google.adk.artifacts import InMemoryArtifactService
from google.adk.cli.utils import logs
from google.adk.sessions import InMemorySessionService
@ -42,7 +41,7 @@ async def main():
artifact_service=artifact_service,
session_service=session_service,
)
session_11 = session_service.create_session(
session_11 = await session_service.create_session(
app_name=app_name, user_id=user_id_1
)
@ -59,25 +58,6 @@ async def main():
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
async def run_prompt_bytes(session: Session, new_message: str):
content = types.Content(
role='user',
parts=[
types.Part.from_bytes(
data=str.encode(new_message), mime_type='text/plain'
)
],
)
print('** User says:', content.model_dump(exclude_none=True))
async for event in runner.run_async(
user_id=user_id_1,
session_id=session.id,
new_message=content,
run_config=RunConfig(save_input_blobs_as_artifacts=True),
):
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
start_time = time.time()
print('Start time:', start_time)
print('------------------------------------')
@ -85,7 +65,6 @@ async def main():
await run_prompt(session_11, 'Roll a die with 100 sides')
await run_prompt(session_11, 'Roll a die again with 100 sides.')
await run_prompt(session_11, 'What numbers did I got?')
await run_prompt_bytes(session_11, 'Hi bytes')
print(
await artifact_service.list_artifact_keys(
app_name=app_name, user_id=user_id_1, session_id=session_11.id
@ -97,49 +76,5 @@ async def main():
print('Total time:', end_time - start_time)
def main_sync():
app_name = 'my_app'
user_id_1 = 'user1'
session_service = InMemorySessionService()
artifact_service = InMemoryArtifactService()
runner = Runner(
app_name=app_name,
agent=agent.root_agent,
artifact_service=artifact_service,
session_service=session_service,
)
session_11 = session_service.create_session(
app_name=app_name, user_id=user_id_1
)
def run_prompt(session: Session, new_message: str):
content = types.Content(
role='user', parts=[types.Part.from_text(text=new_message)]
)
print('** User says:', content.model_dump(exclude_none=True))
for event in runner.run(
user_id=user_id_1,
session_id=session.id,
new_message=content,
):
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
start_time = time.time()
print('Start time:', start_time)
print('------------------------------------')
run_prompt(session_11, 'Hi')
run_prompt(session_11, 'Roll a die with 100 sides.')
run_prompt(session_11, 'Roll a die again with 100 sides.')
run_prompt(session_11, 'What numbers did I got?')
end_time = time.time()
print('------------------------------------')
print('End time:', end_time)
print('Total time:', end_time - start_time)
if __name__ == '__main__':
print('--------------ASYNC--------------------')
asyncio.run(main())
print('--------------SYNC--------------------')
main_sync()

View File

@ -41,7 +41,7 @@ async def main():
artifact_service=artifact_service,
session_service=session_service,
)
session_11 = session_service.create_session(app_name, user_id_1)
session_11 = await session_service.create_session(app_name, user_id_1)
async def run_prompt(session: Session, new_message: str):
content = types.Content(
@ -69,44 +69,5 @@ async def main():
print('Total time:', end_time - start_time)
def main_sync():
app_name = 'my_app'
user_id_1 = 'user1'
session_service = InMemorySessionService()
artifact_service = InMemoryArtifactService()
runner = Runner(
app_name=app_name,
agent=agent.root_agent,
artifact_service=artifact_service,
session_service=session_service,
)
session_11 = session_service.create_session(app_name, user_id_1)
def run_prompt(session: Session, new_message: str):
content = types.Content(
role='user', parts=[types.Part.from_text(text=new_message)]
)
print('** User says:', content.model_dump(exclude_none=True))
for event in runner.run_sync(
session=session,
new_message=content,
):
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
start_time = time.time()
print('Start time:', start_time)
print('------------------------------------')
run_prompt(session_11, 'Hi')
run_prompt(session_11, 'Roll a die.')
run_prompt(session_11, 'Roll a die again.')
run_prompt(session_11, 'What numbers did I got?')
end_time = time.time()
print('------------------------------------')
print('End time:', end_time)
print('Total time:', end_time - start_time)
if __name__ == '__main__':
asyncio.run(main())
main_sync()

View File

@ -66,7 +66,7 @@ async def check_prime(nums: list[int]) -> str:
)
root_agent = Agent(
model='gemini-2.0-flash-exp',
model='gemini-2.0-flash',
name='data_processing_agent',
description=(
'hello world agent that can roll a dice of 8 sides and check prime'

View File

@ -28,7 +28,7 @@ def roll_die(sides: int) -> int:
roll_agent = LlmAgent(
name="roll_agent",
description="Handles rolling dice of different sizes.",
model="gemini-2.0-flash-exp",
model="gemini-2.0-flash",
instruction="""
You are responsible for rolling dice based on the user's request.
When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
@ -69,7 +69,7 @@ def check_prime(nums: list[int]) -> str:
prime_agent = LlmAgent(
name="prime_agent",
description="Handles checking if numbers are prime.",
model="gemini-2.0-flash-exp",
model="gemini-2.0-flash",
instruction="""
You are responsible for checking whether numbers are prime.
When asked to check primes, you must call the check_prime tool with a list of integers.

View File

@ -42,7 +42,7 @@ async def main():
artifact_service=artifact_service,
session_service=session_service,
)
session_11 = session_service.create_session(
session_11 = await session_service.create_session(
app_name=app_name, user_id=user_id_1
)