mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-13 15:14:50 -06:00
feat: Update for anthropic models
Enable parallel tool use for Anthropic models, add agent examples, and add a functional test for Anthropic models. PiperOrigin-RevId: 766703018
This commit is contained in:
parent
44f507895e
commit
16f7d98acf
16
contributing/samples/hello_world_anthropic/__init__.py
Normal file
16
contributing/samples/hello_world_anthropic/__init__.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from . import agent
|
90
contributing/samples/hello_world_anthropic/agent.py
Normal file
90
contributing/samples/hello_world_anthropic/agent.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
from google.adk import Agent
|
||||||
|
from google.adk.models.anthropic_llm import Claude
|
||||||
|
|
||||||
|
|
||||||
|
def roll_die(sides: int) -> int:
  """Simulate rolling a single die and return the face shown.

  Args:
    sides: The integer number of sides the die has.

  Returns:
    An integer of the result of rolling the die.
  """
  # randrange(sides) picks uniformly from 0..sides-1; shifting by one gives
  # 1..sides, exactly matching a physical die with that many faces.
  return random.randrange(sides) + 1
|
||||||
|
|
||||||
|
|
||||||
|
async def check_prime(nums: list[int]) -> str:
  """Check which numbers in a given list are prime.

  Args:
    nums: The list of numbers to check.

  Returns:
    A str indicating which numbers are prime.
  """

  def _is_prime(candidate: int) -> bool:
    # Trial division up to the square root; 0, 1 and negatives are not prime.
    if candidate <= 1:
      return False
    return all(candidate % divisor for divisor in range(2, int(candidate**0.5) + 1))

  # Coerce each entry to int (mirrors the tool contract) and deduplicate via a set.
  primes = {value for value in map(int, nums) if _is_prime(value)}
  if not primes:
    return "No prime numbers found."
  return f"{', '.join(str(num) for num in primes)} are prime numbers."
|
||||||
|
|
||||||
|
|
||||||
|
# Sample agent backed by Anthropic's Claude through the ADK Claude model
# wrapper. It exposes two tools (roll_die, check_prime); the instruction
# prompt forces the model to call the tools rather than answer on its own.
root_agent = Agent(
    # Vertex AI-style Claude model id — note the `@` version suffix.
    model=Claude(model="claude-3-5-sonnet-v2@20241022"),
    name="data_processing_agent",
    description=(
        "hello world agent that can roll a dice of 8 sides and check prime"
        " numbers."
    ),
    instruction="""
      You roll dice and answer questions about the outcome of the dice rolls.
      You can roll dice of different sizes.
      You can use multiple tools in parallel by calling functions in parallel(in one request and in one round).
      It is ok to discuss previous dice roles, and comment on the dice rolls.
      When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
      You should never roll a die on your own.
      When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
      You should not check prime numbers before calling the tool.
      When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
      1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
      2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result.
        2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
      3. When you respond, you must include the roll_die result from step 1.
      You should always perform the previous 3 steps when asking for a roll and checking prime numbers.
      You should not rely on the previous history on prime results.
    """,
    tools=[
        roll_die,
        check_prime,
    ],
)
|
76
contributing/samples/hello_world_anthropic/main.py
Normal file
76
contributing/samples/hello_world_anthropic/main.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
import agent
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from google.adk import Runner
|
||||||
|
from google.adk.artifacts import InMemoryArtifactService
|
||||||
|
from google.adk.cli.utils import logs
|
||||||
|
from google.adk.sessions import InMemorySessionService
|
||||||
|
from google.adk.sessions import Session
|
||||||
|
from google.genai import types
|
||||||
|
|
||||||
|
# Load environment variables (e.g. Anthropic/Vertex credentials) from a local
# .env file, letting the file override anything already set in the process env.
load_dotenv(override=True)
# Route ADK logging to a temp folder so console output stays readable.
logs.log_to_tmp_folder()
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
  """Drive the Anthropic sample agent through a short scripted conversation.

  Builds in-memory session/artifact services, creates one session, sends two
  user prompts through the Runner, and prints wall-clock timing for the run.
  """
  app_name = 'my_app'
  user_id_1 = 'user1'
  session_service = InMemorySessionService()
  artifact_service = InMemoryArtifactService()
  runner = Runner(
      app_name=app_name,
      agent=agent.root_agent,
      artifact_service=artifact_service,
      session_service=session_service,
  )
  # The session must exist before runner.run_async can be invoked against it.
  session_11 = await session_service.create_session(
      app_name=app_name, user_id=user_id_1
  )

  async def run_prompt(session: Session, new_message: str):
    # Wrap the plain string in a user-role Content so the runner accepts it.
    content = types.Content(
        role='user', parts=[types.Part.from_text(text=new_message)]
    )
    print('** User says:', content.model_dump(exclude_none=True))
    async for event in runner.run_async(
        user_id=user_id_1,
        session_id=session.id,
        new_message=content,
    ):
      # Only echo events whose first part carries text (skips tool-call-only
      # events that have no printable payload).
      if event.content.parts and event.content.parts[0].text:
        print(f'** {event.author}: {event.content.parts[0].text}')

  start_time = time.time()
  print('Start time:', start_time)
  print('------------------------------------')
  await run_prompt(session_11, 'Hi, introduce yourself.')
  # The repeated roll-and-check request is what exercises (parallel) tool use.
  await run_prompt(
      session_11,
      'Run the following request 10 times: roll a die with 100 sides and check'
      ' if it is prime',
  )
  end_time = time.time()
  print('------------------------------------')
  print('End time:', end_time)
  print('Total time:', end_time - start_time)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
  # Script entry point: run the async driver to completion.
  asyncio.run(main())
|
@ -135,6 +135,10 @@ def content_block_to_part(
|
|||||||
def message_to_generate_content_response(
|
def message_to_generate_content_response(
|
||||||
message: anthropic_types.Message,
|
message: anthropic_types.Message,
|
||||||
) -> LlmResponse:
|
) -> LlmResponse:
|
||||||
|
logger.info(
|
||||||
|
"Claude response: %s",
|
||||||
|
message.model_dump_json(indent=2, exclude_none=True),
|
||||||
|
)
|
||||||
|
|
||||||
return LlmResponse(
|
return LlmResponse(
|
||||||
content=types.Content(
|
content=types.Content(
|
||||||
@ -229,14 +233,11 @@ class Claude(BaseLlm):
|
|||||||
for tool in llm_request.config.tools[0].function_declarations
|
for tool in llm_request.config.tools[0].function_declarations
|
||||||
]
|
]
|
||||||
tool_choice = (
|
tool_choice = (
|
||||||
anthropic_types.ToolChoiceAutoParam(
|
anthropic_types.ToolChoiceAutoParam(type="auto")
|
||||||
type="auto",
|
|
||||||
# TODO: allow parallel tool use.
|
|
||||||
disable_parallel_tool_use=True,
|
|
||||||
)
|
|
||||||
if llm_request.tools_dict
|
if llm_request.tools_dict
|
||||||
else NOT_GIVEN
|
else NOT_GIVEN
|
||||||
)
|
)
|
||||||
|
# TODO(b/421255973): Enable streaming for anthropic models.
|
||||||
message = self._anthropic_client.messages.create(
|
message = self._anthropic_client.messages.create(
|
||||||
model=llm_request.model,
|
model=llm_request.model,
|
||||||
system=llm_request.config.system_instruction,
|
system=llm_request.config.system_instruction,
|
||||||
@ -245,10 +246,6 @@ class Claude(BaseLlm):
|
|||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
max_tokens=MAX_TOKEN,
|
max_tokens=MAX_TOKEN,
|
||||||
)
|
)
|
||||||
logger.info(
|
|
||||||
"Claude response: %s",
|
|
||||||
message.model_dump_json(indent=2, exclude_none=True),
|
|
||||||
)
|
|
||||||
yield message_to_generate_content_response(message)
|
yield message_to_generate_content_response(message)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
|
124
tests/unittests/models/test_anthropic_llm.py
Normal file
124
tests/unittests/models/test_anthropic_llm.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
from anthropic import types as anthropic_types
|
||||||
|
from google.adk import version as adk_version
|
||||||
|
from google.adk.models import anthropic_llm
|
||||||
|
from google.adk.models.anthropic_llm import Claude
|
||||||
|
from google.adk.models.llm_request import LlmRequest
|
||||||
|
from google.adk.models.llm_response import LlmResponse
|
||||||
|
from google.genai import types
|
||||||
|
from google.genai import version as genai_version
|
||||||
|
from google.genai.types import Content
|
||||||
|
from google.genai.types import Part
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def generate_content_response():
  """A canned Anthropic Message mimicking a plain-text Claude reply."""
  return anthropic_types.Message(
      id="msg_vrtx_testid",
      content=[
          anthropic_types.TextBlock(
              citations=None, text="Hi! How can I help you today?", type="text"
          )
      ],
      model="claude-3-5-sonnet-v2-20241022",
      role="assistant",
      stop_reason="end_turn",
      stop_sequence=None,
      type="message",
      # Token accounting fields are required by the Message model; values are
      # arbitrary but realistic for a short exchange.
      usage=anthropic_types.Usage(
          cache_creation_input_tokens=0,
          cache_read_input_tokens=0,
          input_tokens=13,
          output_tokens=12,
          server_tool_use=None,
          service_tier=None,
      ),
  )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def generate_llm_response():
  """The LlmResponse the adapter is expected to yield for a text reply."""
  return LlmResponse.create(
      types.GenerateContentResponse(
          candidates=[
              types.Candidate(
                  content=Content(
                      role="model",
                      parts=[Part.from_text(text="Hello, how can I help you?")],
                  ),
                  finish_reason=types.FinishReason.STOP,
              )
          ]
      )
  )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def claude_llm():
  """A Claude wrapper instance targeting the 3.5 Sonnet v2 model id."""
  return Claude(model="claude-3-5-sonnet-v2@20241022")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def llm_request():
  """A minimal single-turn LlmRequest with a system instruction set."""
  return LlmRequest(
      model="claude-3-5-sonnet-v2@20241022",
      contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
      config=types.GenerateContentConfig(
          temperature=0.1,
          response_modalities=[types.Modality.TEXT],
          system_instruction="You are a helpful assistant",
      ),
  )
|
||||||
|
|
||||||
|
|
||||||
|
def test_supported_models():
  """Claude advertises exactly the claude-3 and claude-4 regex patterns."""
  patterns = Claude.supported_models()
  expected = [r"claude-3-.*", r"claude-.*-4.*"]
  assert len(patterns) == 2
  for actual, wanted in zip(patterns, expected):
    assert actual == wanted
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_generate_content_async(
    claude_llm, llm_request, generate_content_response, generate_llm_response
):
  """Non-streaming generate_content_async yields exactly one LlmResponse.

  The Anthropic client and the message-conversion helper are both patched,
  so this verifies only the async plumbing of Claude.generate_content_async.
  """
  with mock.patch.object(claude_llm, "_anthropic_client") as mock_client:
    with mock.patch.object(
        anthropic_llm,
        "message_to_generate_content_response",
        return_value=generate_llm_response,
    ):
      # Create a mock coroutine that returns the generate_content_response;
      # the production code awaits messages.create, so the return value must
      # be awaitable rather than a plain Message.
      async def mock_coro():
        return generate_content_response

      # Assign the coroutine to the mocked method
      mock_client.messages.create.return_value = mock_coro()

      responses = [
          resp
          async for resp in claude_llm.generate_content_async(
              llm_request, stream=False
          )
      ]
      assert len(responses) == 1
      assert isinstance(responses[0], LlmResponse)
      assert responses[0].content.parts[0].text == "Hello, how can I help you?"
|
Loading…
Reference in New Issue
Block a user