feat: Update for anthropic models

Enable parallel tools for anthropic models, and add agent examples, and also added functional test for anthropic models.

PiperOrigin-RevId: 766703018
This commit is contained in:
Google Team Member 2025-06-03 09:41:44 -07:00 committed by Copybara-Service
parent 44f507895e
commit 16f7d98acf
5 changed files with 312 additions and 9 deletions

View File

@ -0,0 +1,16 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent

View File

@ -0,0 +1,90 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from google.adk import Agent
from google.adk.models.anthropic_llm import Claude
def roll_die(sides: int) -> int:
"""Roll a die and return the rolled result.
Args:
sides: The integer number of sides the die has.
Returns:
An integer of the result of rolling the die.
"""
return random.randint(1, sides)
async def check_prime(nums: list[int]) -> str:
"""Check if a given list of numbers are prime.
Args:
nums: The list of numbers to check.
Returns:
A str indicating which number is prime.
"""
primes = set()
for number in nums:
number = int(number)
if number <= 1:
continue
is_prime = True
for i in range(2, int(number**0.5) + 1):
if number % i == 0:
is_prime = False
break
if is_prime:
primes.add(number)
return (
"No prime numbers found."
if not primes
else f"{', '.join(str(num) for num in primes)} are prime numbers."
)
root_agent = Agent(
model=Claude(model="claude-3-5-sonnet-v2@20241022"),
name="data_processing_agent",
description=(
"hello world agent that can roll a dice of 8 sides and check prime"
" numbers."
),
instruction="""
You roll dice and answer questions about the outcome of the dice rolls.
You can roll dice of different sizes.
You can use multiple tools in parallel by calling functions in parallel(in one request and in one round).
It is ok to discuss previous dice roles, and comment on the dice rolls.
When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
You should never roll a die on your own.
When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
You should not check prime numbers before calling the tool.
When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result.
2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
3. When you respond, you must include the roll_die result from step 1.
You should always perform the previous 3 steps when asking for a roll and checking prime numbers.
You should not rely on the previous history on prime results.
""",
tools=[
roll_die,
check_prime,
],
)

View File

@ -0,0 +1,76 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import time
import agent
from dotenv import load_dotenv
from google.adk import Runner
from google.adk.artifacts import InMemoryArtifactService
from google.adk.cli.utils import logs
from google.adk.sessions import InMemorySessionService
from google.adk.sessions import Session
from google.genai import types
load_dotenv(override=True)
logs.log_to_tmp_folder()
async def main():
app_name = 'my_app'
user_id_1 = 'user1'
session_service = InMemorySessionService()
artifact_service = InMemoryArtifactService()
runner = Runner(
app_name=app_name,
agent=agent.root_agent,
artifact_service=artifact_service,
session_service=session_service,
)
session_11 = await session_service.create_session(
app_name=app_name, user_id=user_id_1
)
async def run_prompt(session: Session, new_message: str):
content = types.Content(
role='user', parts=[types.Part.from_text(text=new_message)]
)
print('** User says:', content.model_dump(exclude_none=True))
async for event in runner.run_async(
user_id=user_id_1,
session_id=session.id,
new_message=content,
):
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
start_time = time.time()
print('Start time:', start_time)
print('------------------------------------')
await run_prompt(session_11, 'Hi, introduce yourself.')
await run_prompt(
session_11,
'Run the following request 10 times: roll a die with 100 sides and check'
' if it is prime',
)
end_time = time.time()
print('------------------------------------')
print('End time:', end_time)
print('Total time:', end_time - start_time)
if __name__ == '__main__':
asyncio.run(main())

View File

@ -135,6 +135,10 @@ def content_block_to_part(
def message_to_generate_content_response(
message: anthropic_types.Message,
) -> LlmResponse:
logger.info(
"Claude response: %s",
message.model_dump_json(indent=2, exclude_none=True),
)
return LlmResponse(
content=types.Content(
@ -229,14 +233,11 @@ class Claude(BaseLlm):
for tool in llm_request.config.tools[0].function_declarations
]
tool_choice = (
anthropic_types.ToolChoiceAutoParam(
type="auto",
# TODO: allow parallel tool use.
disable_parallel_tool_use=True,
)
anthropic_types.ToolChoiceAutoParam(type="auto")
if llm_request.tools_dict
else NOT_GIVEN
)
# TODO(b/421255973): Enable streaming for anthropic models.
message = self._anthropic_client.messages.create(
model=llm_request.model,
system=llm_request.config.system_instruction,
@ -245,10 +246,6 @@ class Claude(BaseLlm):
tool_choice=tool_choice,
max_tokens=MAX_TOKEN,
)
logger.info(
"Claude response: %s",
message.model_dump_json(indent=2, exclude_none=True),
)
yield message_to_generate_content_response(message)
@cached_property

View File

@ -0,0 +1,124 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from unittest import mock
from anthropic import types as anthropic_types
from google.adk import version as adk_version
from google.adk.models import anthropic_llm
from google.adk.models.anthropic_llm import Claude
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.genai import types
from google.genai import version as genai_version
from google.genai.types import Content
from google.genai.types import Part
import pytest
@pytest.fixture
def generate_content_response():
return anthropic_types.Message(
id="msg_vrtx_testid",
content=[
anthropic_types.TextBlock(
citations=None, text="Hi! How can I help you today?", type="text"
)
],
model="claude-3-5-sonnet-v2-20241022",
role="assistant",
stop_reason="end_turn",
stop_sequence=None,
type="message",
usage=anthropic_types.Usage(
cache_creation_input_tokens=0,
cache_read_input_tokens=0,
input_tokens=13,
output_tokens=12,
server_tool_use=None,
service_tier=None,
),
)
@pytest.fixture
def generate_llm_response():
return LlmResponse.create(
types.GenerateContentResponse(
candidates=[
types.Candidate(
content=Content(
role="model",
parts=[Part.from_text(text="Hello, how can I help you?")],
),
finish_reason=types.FinishReason.STOP,
)
]
)
)
@pytest.fixture
def claude_llm():
return Claude(model="claude-3-5-sonnet-v2@20241022")
@pytest.fixture
def llm_request():
return LlmRequest(
model="claude-3-5-sonnet-v2@20241022",
contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
config=types.GenerateContentConfig(
temperature=0.1,
response_modalities=[types.Modality.TEXT],
system_instruction="You are a helpful assistant",
),
)
def test_supported_models():
models = Claude.supported_models()
assert len(models) == 2
assert models[0] == r"claude-3-.*"
assert models[1] == r"claude-.*-4.*"
@pytest.mark.asyncio
async def test_generate_content_async(
claude_llm, llm_request, generate_content_response, generate_llm_response
):
with mock.patch.object(claude_llm, "_anthropic_client") as mock_client:
with mock.patch.object(
anthropic_llm,
"message_to_generate_content_response",
return_value=generate_llm_response,
):
# Create a mock coroutine that returns the generate_content_response.
async def mock_coro():
return generate_content_response
# Assign the coroutine to the mocked method
mock_client.messages.create.return_value = mock_coro()
responses = [
resp
async for resp in claude_llm.generate_content_async(
llm_request, stream=False
)
]
assert len(responses) == 1
assert isinstance(responses[0], LlmResponse)
assert responses[0].content.parts[0].text == "Hello, how can I help you?"