structure saas with tools
@@ -0,0 +1,187 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import requests
import uuid
import os
import tempfile
import hashlib
import dataclasses

import sentencepiece as spm
import functools
from sentencepiece import sentencepiece_model_pb2


@dataclasses.dataclass(frozen=True)
class _TokenizerConfig:
    model_url: str
    model_hash: str


_GEMMA_TOKENIZER = "google/gemma"

# SoT: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models
_GEMINI_MODEL_NAMES = ["gemini-1.0-pro", "gemini-1.5-pro", "gemini-1.5-flash"]
_GEMINI_STABLE_MODEL_NAMES = [
    "gemini-1.0-pro-001",
    "gemini-1.0-pro-002",
    "gemini-1.5-pro-001",
    "gemini-1.5-flash-001",
    "gemini-1.5-flash-002",
    "gemini-1.5-pro-002",
]

_TOKENIZERS = {
    _GEMMA_TOKENIZER: _TokenizerConfig(
        model_url="https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
        model_hash="61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2",
    )
}


def _load_file(file_url_path: str) -> bytes:
    """Loads file bytes from the given file url path."""
    resp = requests.get(file_url_path)
    resp.raise_for_status()
    return resp.content


def _is_valid_model(*, model_data: bytes, expected_hash: str) -> bool:
    """Returns true if the content is valid by checking the hash."""
    if not expected_hash:
        raise ValueError("expected_hash is required")
    return hashlib.sha256(model_data).hexdigest() == expected_hash


def _maybe_remove_file(file_path: str) -> None:
    """Removes the file if it exists."""
    if not os.path.exists(file_path):
        return
    try:
        os.remove(file_path)
    except OSError:
        # Don't raise if we cannot remove the file.
        pass


def _maybe_load_from_cache(*, file_path: str, expected_hash: str) -> bytes:
    """Loads the content from the cache path."""
    if not os.path.exists(file_path):
        return
    with open(file_path, "rb") as f:
        content = f.read()
    if _is_valid_model(model_data=content, expected_hash=expected_hash):
        return content

    # Cached file corrupted.
    _maybe_remove_file(file_path)


def _maybe_save_to_cache(*, cache_dir: str, cache_path: str, content: bytes) -> None:
    """Saves the content to the cache path."""
    try:
        os.makedirs(cache_dir, exist_ok=True)
        tmp_path = cache_dir + "." + str(uuid.uuid4()) + ".tmp"
        with open(tmp_path, "wb") as f:
            f.write(content)
        os.rename(tmp_path, cache_path)
    except OSError:
        # Don't raise if we cannot write the file.
        pass


def _load_from_url(*, file_url: str, expected_hash: str) -> bytes:
    """Loads model bytes from the given file url."""
    content = _load_file(file_url)
    if not _is_valid_model(model_data=content, expected_hash=expected_hash):
        actual_hash = hashlib.sha256(content).hexdigest()
        raise ValueError(
            f"Downloaded model file is corrupted."
            f" Expected hash {expected_hash}. Got file hash {actual_hash}."
        )
    return content


def _load(*, file_url: str, expected_hash: str) -> bytes:
    """Loads model bytes from the given file url.

    1. If a locally cached file is found for the given url and the cached file
       hash matches the expected hash, the cached file is returned.
    2. If the local cached file is not found or its hash does not match, the
       file is downloaded from the given url, written to the local cache, and
       the file bytes are returned.
    3. If the file downloaded from the given url does not match the expected
       hash, a ValueError is raised.

    Args:
        file_url: The url of the file to load.
        expected_hash: The expected hash of the file.

    Returns:
        The file bytes.
    """
    model_dir = os.path.join(tempfile.gettempdir(), "vertexai_tokenizer_model")
    filename = hashlib.sha1(file_url.encode()).hexdigest()
    model_path = os.path.join(model_dir, filename)

    model_data = _maybe_load_from_cache(
        file_path=model_path, expected_hash=expected_hash
    )
    if not model_data:
        model_data = _load_from_url(file_url=file_url, expected_hash=expected_hash)

    _maybe_save_to_cache(cache_dir=model_dir, cache_path=model_path, content=model_data)
    return model_data


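For illustration only (not part of the commit): a minimal sketch of where `_load` places cached model files, following the path derivation above. The URL is hypothetical and used only to show how the cache key is computed.

```python
import hashlib
import os
import tempfile

# Hypothetical URL, used only to show how the cache path is derived.
url = "https://example.com/tokenizer.model"
model_dir = os.path.join(tempfile.gettempdir(), "vertexai_tokenizer_model")
model_path = os.path.join(model_dir, hashlib.sha1(url.encode()).hexdigest())
print(model_path)  # e.g. /tmp/vertexai_tokenizer_model/<sha1-of-url>
```
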
def _load_model_proto_bytes(tokenizer_name: str) -> bytes:
    """Loads model proto bytes from the given tokenizer name."""
    if tokenizer_name not in _TOKENIZERS:
        raise ValueError(
            f"Tokenizer {tokenizer_name} is not supported."
            f" Supported tokenizers: {list(_TOKENIZERS.keys())}"
        )
    return _load(
        file_url=_TOKENIZERS[tokenizer_name].model_url,
        expected_hash=_TOKENIZERS[tokenizer_name].model_hash,
    )


@functools.lru_cache()
def load_model_proto(tokenizer_name: str) -> sentencepiece_model_pb2.ModelProto:
    """Loads model proto from the given tokenizer name."""
    model_proto = sentencepiece_model_pb2.ModelProto()
    model_proto.ParseFromString(_load_model_proto_bytes(tokenizer_name))
    return model_proto


def get_tokenizer_name(model_name: str) -> str:
    """Gets the tokenizer name for the given model name."""
    if model_name in _GEMINI_MODEL_NAMES:
        return _GEMMA_TOKENIZER
    if model_name in _GEMINI_STABLE_MODEL_NAMES:
        return _GEMMA_TOKENIZER
    raise ValueError(
        f"Model {model_name} is not supported. Supported models:"
        f" {', '.join(_GEMINI_STABLE_MODEL_NAMES)}."
    )


@functools.lru_cache()
def get_sentencepiece(tokenizer_name: str) -> spm.SentencePieceProcessor:
    """Loads sentencepiece tokenizer from the given tokenizer name."""
    processor = spm.SentencePieceProcessor()
    processor.LoadFromSerializedProto(_load_model_proto_bytes(tokenizer_name))
    return processor
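
For illustration only (not part of the commit): a hedged sketch of how the loading helpers above compose. It assumes the Gemma tokenizer model can be downloaded and cached on first use; exact token ids depend on the downloaded model file.

```python
# Sketch: resolve a model name to its tokenizer and encode a string locally.
# Assumes this module is importable as vertexai.tokenization._tokenizer_loading.
from vertexai.tokenization._tokenizer_loading import (
    get_sentencepiece,
    get_tokenizer_name,
)

tokenizer_name = get_tokenizer_name("gemini-1.5-pro-001")  # -> "google/gemma"
processor = get_sentencepiece(tokenizer_name)  # downloads and caches on first call
print(processor.encode("Hello world!"))  # list of token ids
```
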
@@ -0,0 +1,581 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import dataclasses
from typing import (
    Iterable,
    List,
    Sequence,
    Optional,
    Union,
)

from vertexai.generative_models._generative_models import (
    ContentsType,
    Image,
    Tool,
    PartsType,
    _validate_contents_type_as_valid_sequence,
    _content_types_to_gapic_contents,
    _to_content,
)

from vertexai.tokenization._tokenizer_loading import (
    get_sentencepiece,
    get_tokenizer_name,
    load_model_proto,
)
from google.cloud.aiplatform_v1beta1.types import (
    content as gapic_content_types,
    tool as gapic_tool_types,
    openapi,
)
from sentencepiece import sentencepiece_model_pb2
from google.protobuf import struct_pb2


@dataclasses.dataclass(frozen=True)
class TokensInfo:
    token_ids: Sequence[int]
    tokens: Sequence[bytes]
    role: str = None


@dataclasses.dataclass(frozen=True)
class ComputeTokensResult:
    """Represents the token string pieces and ids returned by the compute_tokens function.

    Attributes:
        tokens_info: Lists of tokens_info from the input.
            The input `contents: ContentsType` could have
            multiple string instances, and each tokens_info
            item represents one string instance. Each token
            info consists of a tokens list, a token_ids list and
            a role.
    """

    tokens_info: Sequence[TokensInfo]


class PreviewComputeTokensResult(ComputeTokensResult):
    def token_info_list(self) -> Sequence[TokensInfo]:
        import warnings

        message = (
            "PreviewComputeTokensResult.token_info_list is deprecated."
            " Use ComputeTokensResult.tokens_info instead."
        )
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        return self.tokens_info


@dataclasses.dataclass(frozen=True)
class CountTokensResult:
    """Represents the token count returned by the count_tokens function.

    Attributes:
        total_tokens: number of total tokens.
    """

    total_tokens: int


def _parse_hex_byte(token: str) -> int:
    """Parses a hex byte string of the form '<0xXX>' and returns the integer value.

    Raises ValueError if the input is malformed or the byte value is invalid.
    """

    if len(token) != 6:
        raise ValueError(f"Invalid byte length: {token}")
    if not token.startswith("<0x") or not token.endswith(">"):
        raise ValueError(f"Invalid byte format: {token}")

    try:
        val = int(token[3:5], 16)  # Parse the hex part directly
    except ValueError:
        raise ValueError(f"Invalid hex value: {token}")

    if val >= 256:
        raise ValueError(f"Byte value out of range: {token}")

    return val


def _token_str_to_bytes(
    token: str, type: sentencepiece_model_pb2.ModelProto.SentencePiece.Type
) -> bytes:
    if type == sentencepiece_model_pb2.ModelProto.SentencePiece.Type.BYTE:
        return _parse_hex_byte(token).to_bytes(length=1, byteorder="big")
    else:
        return token.replace("▁", " ").encode("utf-8")


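For illustration only (not part of the commit): how the two helpers above map SentencePiece pieces to raw bytes. Byte-fallback pieces look like "<0x41>", while normal pieces use "▁" as the word-boundary marker.

```python
# BYTE pieces: "<0x41>" encodes the single byte 0x41 ("A").
assert _parse_hex_byte("<0x41>") == 0x41

BYTE = sentencepiece_model_pb2.ModelProto.SentencePiece.Type.BYTE
NORMAL = sentencepiece_model_pb2.ModelProto.SentencePiece.Type.NORMAL
assert _token_str_to_bytes("<0x41>", BYTE) == b"A"
# Normal pieces: the "▁" marker is restored to a leading space.
assert _token_str_to_bytes("▁world", NORMAL) == b" world"
```
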
class _SentencePieceAdaptor:
    r"""An internal tokenizer that can parse text input into tokens."""

    def __init__(self, tokenizer_name: str):
        r"""Initializes the tokenizer.

        Args:
            tokenizer_name: The name of the tokenizer.
        """
        self._model_proto = load_model_proto(tokenizer_name)
        self._tokenizer = get_sentencepiece(tokenizer_name)

    def count_tokens(self, contents: Iterable[str]) -> CountTokensResult:
        r"""Counts the number of tokens in the input."""
        tokens_list = self._tokenizer.encode(list(contents))

        return CountTokensResult(
            total_tokens=sum(len(tokens) for tokens in tokens_list)
        )

    def compute_tokens(
        self, *, contents: Iterable[str], roles: Iterable[str]
    ) -> ComputeTokensResult:
        """Computes the token ids and string pieces in the input."""
        content_list = list(contents)
        tokens_protos = self._tokenizer.EncodeAsImmutableProto(content_list)
        roles = list(roles)

        token_infos = []
        for tokens_proto, role in zip(tokens_protos, roles):
            token_infos.append(
                TokensInfo(
                    token_ids=[piece.id for piece in tokens_proto.pieces],
                    tokens=[
                        _token_str_to_bytes(
                            piece.piece, self._model_proto.pieces[piece.id].type
                        )
                        for piece in tokens_proto.pieces
                    ],
                    role=role,
                )
            )
        return ComputeTokensResult(tokens_info=token_infos)


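For illustration only (not part of the commit): a minimal sketch of the internal adaptor on plain strings, assuming the tokenizer model can be downloaded; the exact ids depend on the model file.

```python
adaptor = _SentencePieceAdaptor(get_tokenizer_name("gemini-1.5-flash-001"))
print(adaptor.count_tokens(["hello world"]).total_tokens)
result = adaptor.compute_tokens(contents=["hello world"], roles=["user"])
print(result.tokens_info[0].token_ids, result.tokens_info[0].tokens)
```
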
def _to_gapic_contents(
    contents: ContentsType,
) -> List[gapic_content_types.Content]:
    """Converts a GenerativeModel compatible contents type to a gapic content."""
    _validate_contents_type_as_valid_sequence(contents)
    _assert_no_image_contents_type(contents)
    gapic_contents = _content_types_to_gapic_contents(contents)
    # _assert_text_only_content_types_sequence(gapic_contents)
    return gapic_contents


def _content_types_to_role_iterator(contents: ContentsType) -> Iterable[str]:
    gapic_contents = _to_gapic_contents(contents)
    # Flatten the role across each content's multiple parts.
    for content in gapic_contents:
        for part in content.parts:
            yield content.role


def _assert_no_image_contents_type(contents: ContentsType):
    """Asserts that the contents type does not contain any image content."""
    if isinstance(contents, Image) or (
        isinstance(contents, Sequence)
        and any(isinstance(content, Image) for content in contents)
    ):
        raise ValueError("Tokenizers do not support Image content type.")


def _is_string_inputs(contents: ContentsType) -> bool:
    return (
        isinstance(contents, str)
        or isinstance(contents, Sequence)
        and all(isinstance(content, str) for content in contents)
    )


def _to_canonical_roles(contents: ContentsType) -> Iterable[str]:
    if isinstance(contents, str):
        yield "user"
    elif isinstance(contents, Sequence) and all(
        isinstance(content, str) for content in contents
    ):
        yield from ["user"] * len(contents)
    else:
        yield from _content_types_to_role_iterator(contents)


class _TextsAccumulator:
    """Accumulates texts from contents and tools.

    This class is used to accumulate countable texts from contents and tools.
    When a user passes unsupported fields that are added in the future, the new
    fields might only be counted by the remote tokenizer. In that case, the
    local tokenizer should know that an unsupported new field exists in the
    content or tool instances and raise an error, to avoid returning an
    incorrect result to users.

    The mechanism for detecting unsupported fields introduced in the future:
    when the local tokenizer traverses the input instances, text accumulation
    is allowlist based. When a field is traversed and evaluated to be
    countable, the value of this field is copied to two places: (1)
    self._texts, the input to the sentencepiece token count function, and (2)
    a counted instance object in the recursive function's return value. That
    is to say, after the current recursion is done, the instance (of the same
    type as the input) only keeps the counted values. If the user sets
    unsupported future proto fields, they can be detected by comparing the
    input instances with the counted instances for equality.
    """

    def __init__(self):
        self._texts = []

    def get_texts(self) -> Iterable[str]:
        return self._texts

    def add_texts(self, texts: Union[Iterable[str], str]) -> None:
        if isinstance(texts, str):
            self._texts.append(texts)
        else:
            self._texts.extend(texts)

    def add_content(self, content: gapic_content_types.Content) -> None:
        counted_content = gapic_content_types.Content()
        for part in content.parts:
            counted_part = gapic_content_types.Part()
            if "file_data" in part or "inline_data" in part:
                raise ValueError("Tokenizers do not support non-text content types.")
            if "video_metadata" in part:
                counted_part.video_metadata = part.video_metadata
            if "function_call" in part:
                self.add_function_call(part.function_call)
                counted_part.function_call = part.function_call
            if "function_response" in part:
                self.add_function_response(part.function_response)
                counted_part.function_response = part.function_response
            if "text" in part:
                counted_part.text = part.text
                self._texts.append(part.text)
            counted_content.parts.append(counted_part)
        counted_content.role = content.role
        if content._pb != counted_content._pb:
            raise ValueError(
                f"Content contains unsupported types for token counting. Supported fields {counted_content}. Got {content}."
            )

    def add_function_call(self, function_call: gapic_tool_types.FunctionCall) -> None:
        """Processes a function call and adds relevant text to the accumulator.

        Args:
            function_call: The function call to process.
        """
        self._texts.append(function_call.name)
        counted_function_call = gapic_tool_types.FunctionCall(name=function_call.name)
        counted_struct = self._struct_traverse(function_call._pb.args)
        counted_function_call.args = counted_struct
        if counted_function_call._pb != function_call._pb:
            raise ValueError(
                f"Function call argument contains unsupported types for token counting. Supported fields {counted_function_call}. Got {function_call}."
            )

    def add_function_calls(
        self, function_calls: Iterable[gapic_tool_types.FunctionCall]
    ) -> None:
        for function_call in function_calls:
            self.add_function_call(function_call)

    def add_tool(self, tool: gapic_tool_types.Tool) -> None:
        counted_tool = gapic_tool_types.Tool()
        for function_declaration in tool.function_declarations:
            counted_function_declaration = self._function_declaration_traverse(
                function_declaration
            )
            counted_tool.function_declarations.append(counted_function_declaration)
        if counted_tool._pb != tool._pb:
            raise ValueError(
                f"Tool argument contains unsupported types for token counting. Supported fields {counted_tool}. Got {tool}."
            )

    def add_tools(self, tools: Iterable[gapic_tool_types.Tool]) -> None:
        for tool in tools:
            self.add_tool(tool)

    def add_function_responses(
        self, function_responses: Iterable[gapic_tool_types.FunctionResponse]
    ) -> None:
        for function_response in function_responses:
            self.add_function_response(function_response)

    def add_function_response(
        self, function_response: gapic_tool_types.FunctionResponse
    ) -> None:
        counted_function_response = gapic_tool_types.FunctionResponse()
        self._texts.append(function_response.name)
        counted_struct = self._struct_traverse(function_response._pb.response)
        counted_function_response.name = function_response.name
        counted_function_response.response = counted_struct
        if counted_function_response._pb != function_response._pb:
            raise ValueError(
                f"Function response argument contains unsupported types for token counting. Supported fields {counted_function_response}. Got {function_response}."
            )

    def _function_declaration_traverse(
        self, function_declaration: gapic_tool_types.FunctionDeclaration
    ) -> gapic_tool_types.FunctionDeclaration:
        counted_function_declaration = gapic_tool_types.FunctionDeclaration()
        self._texts.append(function_declaration.name)
        counted_function_declaration.name = function_declaration.name
        if function_declaration.description:
            self._texts.append(function_declaration.description)
            counted_function_declaration.description = function_declaration.description
        if function_declaration.parameters:
            counted_parameters = self._schema_traverse(function_declaration.parameters)
            counted_function_declaration.parameters = counted_parameters
        if function_declaration.response:
            counted_response = self._schema_traverse(function_declaration.response)
            counted_function_declaration.response = counted_response
        return counted_function_declaration

    def _schema_traverse(self, schema: openapi.Schema) -> openapi.Schema:
        """Processes a schema and adds relevant text to the accumulator.

        Args:
            schema: The schema to process.

        Returns:
            The new schema object with only countable fields.
        """
        counted_schema = openapi.Schema()
        if "type_" in schema:
            counted_schema.type = schema.type
        if "title" in schema:
            counted_schema.title = schema.title
        if "default" in schema:
            counted_schema.default = schema.default
        if "format_" in schema:
            self._texts.append(schema.format_)
            counted_schema.format_ = schema.format_
        if "description" in schema:
            self._texts.append(schema.description)
            counted_schema.description = schema.description
        if "enum" in schema:
            self._texts.extend(schema.enum)
            counted_schema.enum = schema.enum
        if "required" in schema:
            self._texts.extend(schema.required)
            counted_schema.required = schema.required
        if "property_ordering" in schema:
            counted_schema.property_ordering = schema.property_ordering
        if "items" in schema:
            counted_schema_items = self._schema_traverse(schema.items)
            counted_schema.items = counted_schema_items
        if "properties" in schema:
            d = {}
            for key, value in schema.properties.items():
                self._texts.append(key)
                counted_value = self._schema_traverse(value)
                d[key] = counted_value
            counted_schema.properties.update(d)
        if "example" in schema:
            counted_schema_example = self._value_traverse(schema._pb.example)
            counted_schema.example = counted_schema_example
        return counted_schema

    def _struct_traverse(self, struct: struct_pb2.Struct) -> struct_pb2.Struct:
        """Processes a struct and adds relevant text to the accumulator.

        Args:
            struct: The struct to process.

        Returns:
            The new struct object with only countable fields.
        """
        counted_struct = struct_pb2.Struct()
        self._texts.extend(list(struct.fields.keys()))
        for key, val in struct.fields.items():
            counted_struct_fields = self._value_traverse(val)
            if isinstance(counted_struct_fields, str):
                counted_struct.fields[key] = counted_struct_fields
            else:
                counted_struct.fields[key].MergeFrom(counted_struct_fields)
        return counted_struct

    def _value_traverse(self, value: struct_pb2.Value) -> struct_pb2.Value:
        """Processes a struct field and adds relevant text to the accumulator.

        Args:
            value: The struct field to process.

        Returns:
            The new struct field object with only countable fields.
        """
        kind = value.WhichOneof("kind")
        counted_value = struct_pb2.Value()
        if kind == "string_value":
            self._texts.append(value.string_value)
            counted_value.string_value = value.string_value
        elif kind == "struct_value":
            counted_struct = self._struct_traverse(value.struct_value)
            counted_value.struct_value.MergeFrom(counted_struct)
        elif kind == "list_value":
            counted_list_value = struct_pb2.ListValue()
            for item in value.list_value.values:
                # Use a separate name so the outer counted_value is not clobbered.
                counted_item = self._value_traverse(item)
                counted_list_value.values.append(counted_item)
            counted_value.list_value.MergeFrom(counted_list_value)
        return counted_value


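For illustration only (not part of the commit): a hedged sketch of the allowlist copy-and-compare check described in the `_TextsAccumulator` docstring. A text-only Content is accepted and its text accumulated, while a Part carrying inline data raises.

```python
text_part = gapic_content_types.Part(text="hello")
content = gapic_content_types.Content(role="user", parts=[text_part])

accumulator = _TextsAccumulator()
accumulator.add_content(content)       # every set field is countable
print(list(accumulator.get_texts()))   # ['hello']

blob_part = gapic_content_types.Part(
    inline_data=gapic_content_types.Blob(mime_type="image/png", data=b"...")
)
try:
    _TextsAccumulator().add_content(
        gapic_content_types.Content(role="user", parts=[blob_part])
    )
except ValueError as e:
    print(e)  # Tokenizers do not support non-text content types.
```
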
class Tokenizer:
    """A tokenizer that can parse text into tokens."""

    def __init__(self, tokenizer_name: str):
        """Initializes the tokenizer.

        Do not use this constructor directly. Use get_tokenizer_for_model instead.

        Args:
            tokenizer_name: The name of the tokenizer.
        """
        self._sentencepiece_adapter = _SentencePieceAdaptor(tokenizer_name)

    def count_tokens(
        self,
        contents: ContentsType,
        *,
        tools: Optional[List["Tool"]] = None,
        system_instruction: Optional[PartsType] = None,
    ) -> CountTokensResult:
        r"""Counts the number of tokens in the text-only contents.

        Args:
            contents: The contents to count tokens for.
                Supports either a list of Content objects (passing a multi-turn
                conversation) or a value that can be converted to a single
                Content object (passing a single message).
                Supports
                * str, Part,
                * List[Union[str, Part]],
                * List[Content]
                Throws an error if the contents contain non-text content.
            tools: A list of tools (functions) that the model can try calling.
            system_instruction: The provided system instructions for the model.
                Note: only text should be used in parts and content in each part
                will be in a separate paragraph.

        Returns:
            A CountTokensResult object containing the total number of tokens in
            the contents.
        """

        text_accumulator = _TextsAccumulator()
        if _is_string_inputs(contents):
            text_accumulator.add_texts(contents)
        else:
            gapic_contents = _to_gapic_contents(contents)
            for content in gapic_contents:
                text_accumulator.add_content(content)

        if tools:
            text_accumulator.add_tools((tool._raw_tool for tool in tools))

        if system_instruction:
            if _is_string_inputs(system_instruction):
                text_accumulator.add_texts(system_instruction)
            else:
                text_accumulator.add_content(_to_content(system_instruction))

        return self._sentencepiece_adapter.count_tokens(text_accumulator.get_texts())

    def compute_tokens(self, contents: ContentsType) -> ComputeTokensResult:
        r"""Computes the token ids and string pieces in the text-only contents.

        Args:
            contents: The contents to count tokens for.
                Supports either a list of Content objects (passing a multi-turn
                conversation) or a value that can be converted to a single
                Content object (passing a single message).
                Supports
                * str, Part,
                * List[Union[str, Part]],
                * List[Content]
                Throws an error if the contents contain non-text content.

        Returns:
            A ComputeTokensResult object containing the token ids and string
            pieces in the contents.

        Examples:
            compute_tokens(["hello world", "Goodbye!"])
            outputs:
            ComputeTokensResult(tokens_info=[TokensInfo(token_ids=[17534, 2134], tokens=[b'hello', b' world'], role='user'), TokensInfo(token_ids=[84264, 235341], tokens=[b'Goodbye', b'!'], role='user')])
        """
        text_accumulator = _TextsAccumulator()
        if _is_string_inputs(contents):
            text_accumulator.add_texts(contents)
        else:
            gapic_contents = _to_gapic_contents(contents)
            for content in gapic_contents:
                text_accumulator.add_content(content)

        return self._sentencepiece_adapter.compute_tokens(
            contents=text_accumulator.get_texts(),
            roles=_to_canonical_roles(contents),
        )


class PreviewTokenizer(Tokenizer):
    def compute_tokens(self, contents: ContentsType) -> PreviewComputeTokensResult:
        return PreviewComputeTokensResult(
            tokens_info=super().compute_tokens(contents).tokens_info
        )


def _get_tokenizer_for_model_preview(model_name: str) -> PreviewTokenizer:
    """Returns a tokenizer for the given model name.

    Usage:
        ```
        tokenizer = get_tokenizer_for_model("gemini-1.5-pro-001")
        print(tokenizer.count_tokens("Hello world!"))
        ```

    Supported models can be found at
    https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models.

    Args:
        model_name: The name of the model whose tokenizer to load.
    """
    if not model_name:
        raise ValueError("model_name must not be empty.")

    return PreviewTokenizer(get_tokenizer_name(model_name))


def get_tokenizer_for_model(model_name: str) -> Tokenizer:
    """Returns a tokenizer for the given model name.

    Usage:
        ```
        tokenizer = get_tokenizer_for_model("gemini-1.5-pro-001")
        print(tokenizer.count_tokens("Hello world!"))
        ```

    Supported models can be found at
    https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models.

    Args:
        model_name: The name of the model whose tokenizer to load.
    """
    if not model_name:
        raise ValueError("model_name must not be empty.")

    return Tokenizer(get_tokenizer_name(model_name))
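
For illustration only (not part of the commit): a hedged end-to-end sketch of the public entry point defined above, assuming the tokenizer model can be downloaded and cached locally; actual token counts depend on the model file.

```python
tokenizer = get_tokenizer_for_model("gemini-1.5-flash-001")

# Local token counting (offline after the first download).
print(tokenizer.count_tokens("Hello world!").total_tokens)

# Counting with a system instruction (text-only, per the docstring above).
print(
    tokenizer.count_tokens(
        "What is the weather today?",
        system_instruction="You are a helpful assistant.",
    ).total_tokens
)

# Token ids and byte pieces per input string.
result = tokenizer.compute_tokens(["hello world"])
print(result.tokens_info[0].token_ids, result.tokens_info[0].tokens)
```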