evo-ai/.venv/lib/python3.10/site-packages/vertexai/tokenization/_tokenizers.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import dataclasses
from typing import (
    Iterable,
    List,
    Sequence,
    Optional,
    Union,
)

from vertexai.generative_models._generative_models import (
    ContentsType,
    Image,
    Tool,
    PartsType,
    _validate_contents_type_as_valid_sequence,
    _content_types_to_gapic_contents,
    _to_content,
)
from vertexai.tokenization._tokenizer_loading import (
    get_sentencepiece,
    get_tokenizer_name,
    load_model_proto,
)
from google.cloud.aiplatform_v1beta1.types import (
    content as gapic_content_types,
    tool as gapic_tool_types,
    openapi,
)
from sentencepiece import sentencepiece_model_pb2
from google.protobuf import struct_pb2


@dataclasses.dataclass(frozen=True)
class TokensInfo:
    token_ids: Sequence[int]
    tokens: Sequence[bytes]
    role: Optional[str] = None


@dataclasses.dataclass(frozen=True)
class ComputeTokensResult:
    """Represents the token string pieces and ids output by the compute_tokens function.

    Attributes:
        tokens_info: A list of tokens_info entries for the input. The input
            `contents: ContentsType` may contain multiple string instances, and
            each tokens_info item corresponds to one string instance. Each
            token info consists of a token_ids list, a tokens list, and a role.
    """

    tokens_info: Sequence[TokensInfo]


class PreviewComputeTokensResult(ComputeTokensResult):
    def token_info_list(self) -> Sequence[TokensInfo]:
        import warnings

        message = (
            "PreviewComputeTokensResult.token_info_list is deprecated. "
            "Use ComputeTokensResult.tokens_info instead."
        )
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        return self.tokens_info


@dataclasses.dataclass(frozen=True)
class CountTokensResult:
    """Represents the token count output by the count_tokens function.

    Attributes:
        total_tokens: The total number of tokens.
    """

    total_tokens: int


def _parse_hex_byte(token: str) -> int:
    """Parses a hex byte string of the form '<0xXX>' and returns the integer value.

    Raises ValueError if the input is malformed or the byte value is invalid.
    """
    if len(token) != 6:
        raise ValueError(f"Invalid byte length: {token}")
    if not token.startswith("<0x") or not token.endswith(">"):
        raise ValueError(f"Invalid byte format: {token}")
    try:
        val = int(token[3:5], 16)  # Parse the hex part directly
    except ValueError:
        raise ValueError(f"Invalid hex value: {token}")
    if val >= 256:
        raise ValueError(f"Byte value out of range: {token}")
    return val


def _token_str_to_bytes(
    token: str, type: sentencepiece_model_pb2.ModelProto.SentencePiece.Type
) -> bytes:
    if type == sentencepiece_model_pb2.ModelProto.SentencePiece.Type.BYTE:
        return _parse_hex_byte(token).to_bytes(length=1, byteorder="big")
    else:
        # SentencePiece uses "▁" (U+2581) as its whitespace meta symbol.
        return token.replace("▁", " ").encode("utf-8")
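
# For illustration only (not part of the original module): SentencePiece emits
# byte-fallback pieces of the form "<0xXX>" for raw bytes, and uses the "▁"
# (U+2581) meta symbol for spaces in ordinary pieces.
#   _parse_hex_byte("<0x41>")  # -> 65
#   _token_str_to_bytes(
#       "<0x41>", sentencepiece_model_pb2.ModelProto.SentencePiece.Type.BYTE
#   )  # -> b"A"
#   _token_str_to_bytes(
#       "▁world", sentencepiece_model_pb2.ModelProto.SentencePiece.Type.NORMAL
#   )  # -> b" world"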


class _SentencePieceAdaptor:
    r"""An internal tokenizer that can parse text input into tokens."""

    def __init__(self, tokenizer_name: str):
        r"""Initializes the tokenizer.

        Args:
            tokenizer_name: The name of the tokenizer.
        """
        self._model_proto = load_model_proto(tokenizer_name)
        self._tokenizer = get_sentencepiece(tokenizer_name)

    def count_tokens(self, contents: Iterable[str]) -> CountTokensResult:
        r"""Counts the number of tokens in the input."""
        tokens_list = self._tokenizer.encode(list(contents))
        return CountTokensResult(
            total_tokens=sum(len(tokens) for tokens in tokens_list)
        )

    def compute_tokens(
        self, *, contents: Iterable[str], roles: Iterable[str]
    ) -> ComputeTokensResult:
        """Computes the token ids and string pieces in the input."""
        content_list = list(contents)
        tokens_protos = self._tokenizer.EncodeAsImmutableProto(content_list)
        roles = list(roles)
        token_infos = []
        for tokens_proto, role in zip(tokens_protos, roles):
            token_infos.append(
                TokensInfo(
                    token_ids=[piece.id for piece in tokens_proto.pieces],
                    tokens=[
                        _token_str_to_bytes(
                            piece.piece, self._model_proto.pieces[piece.id].type
                        )
                        for piece in tokens_proto.pieces
                    ],
                    role=role,
                )
            )
        return ComputeTokensResult(tokens_info=token_infos)
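
# Illustrative sketch (not part of the original module) of how the adaptor is
# used internally; the tokenizer name comes from get_tokenizer_name:
#   adaptor = _SentencePieceAdaptor(get_tokenizer_name("gemini-1.5-pro-001"))
#   adaptor.count_tokens(["hello world"])
#   # -> CountTokensResult(total_tokens=2)
#   adaptor.compute_tokens(contents=["hello world"], roles=["user"])
#   # -> ComputeTokensResult(tokens_info=[TokensInfo(token_ids=[...],
#   #        tokens=[b"hello", b" world"], role="user")])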


def _to_gapic_contents(
    contents: ContentsType,
) -> List[gapic_content_types.Content]:
    """Converts a GenerativeModel compatible contents type to a gapic content."""
    _validate_contents_type_as_valid_sequence(contents)
    _assert_no_image_contents_type(contents)
    gapic_contents = _content_types_to_gapic_contents(contents)
    # _assert_text_only_content_types_sequence(gapic_contents)
    return gapic_contents


def _content_types_to_role_iterator(contents: ContentsType) -> Iterable[str]:
    gapic_contents = _to_gapic_contents(contents)
    # Flatten the roles: yield one role per part of each content.
    for content in gapic_contents:
        for _ in content.parts:
            yield content.role


def _assert_no_image_contents_type(contents: ContentsType):
    """Asserts that the contents type does not contain any image content."""
    if isinstance(contents, Image) or (
        isinstance(contents, Sequence)
        and any(isinstance(content, Image) for content in contents)
    ):
        raise ValueError("Tokenizers do not support Image content type.")


def _is_string_inputs(contents: ContentsType) -> bool:
    return isinstance(contents, str) or (
        isinstance(contents, Sequence)
        and all(isinstance(content, str) for content in contents)
    )


def _to_canonical_roles(contents: ContentsType) -> Iterable[str]:
    if isinstance(contents, str):
        yield "user"
    elif isinstance(contents, Sequence) and all(
        isinstance(content, str) for content in contents
    ):
        yield from ["user"] * len(contents)
    else:
        yield from _content_types_to_role_iterator(contents)
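
# Illustrative examples (not part of the original module) of role derivation:
#   list(_to_canonical_roles("hello"))               # -> ["user"]
#   list(_to_canonical_roles(["hello", "goodbye"]))  # -> ["user", "user"]
#   # For Content inputs, each Content's role is repeated once per Part, so a
#   # Content(role="model") with two parts contributes ["model", "model"].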


class _TextsAccumulator:
    """Accumulates countable texts from contents and tools.

    This class is used to accumulate countable texts from contents and tools.
    When a user passes unsupported fields that are added in the future, the new
    fields might only be counted by the remote tokenizer. In that case, the
    local tokenizer should detect that an unsupported new field exists in the
    content or tool instances and raise an error, rather than return an
    incorrect result to the user.

    The mechanism for detecting unsupported future fields is allowlist-based
    text accumulation. As the local tokenizer traverses the input instances,
    the value of each field that is evaluated as countable is copied to two
    places: (1) self._texts, the input to the sentencepiece token count
    function, and (2) a "counted" instance object in the recursive function's
    return value. After the recursion finishes, the counted instance (of the
    same type as the input) holds only the counted values. If the user set
    unsupported future proto fields, they can be detected by comparing the
    input instance with the counted instance for equality.
    """

    def __init__(self):
        self._texts = []

    def get_texts(self) -> Iterable[str]:
        return self._texts

    def add_texts(self, texts: Union[Iterable[str], str]) -> None:
        if isinstance(texts, str):
            self._texts.append(texts)
        else:
            self._texts.extend(texts)

    def add_content(self, content: gapic_content_types.Content) -> None:
        counted_content = gapic_content_types.Content()
        for part in content.parts:
            counted_part = gapic_content_types.Part()
            if "file_data" in part or "inline_data" in part:
                raise ValueError("Tokenizers do not support non-text content types.")
            if "video_metadata" in part:
                counted_part.video_metadata = part.video_metadata
            if "function_call" in part:
                self.add_function_call(part.function_call)
                counted_part.function_call = part.function_call
            if "function_response" in part:
                self.add_function_response(part.function_response)
                counted_part.function_response = part.function_response
            if "text" in part:
                counted_part.text = part.text
                self._texts.append(part.text)
            counted_content.parts.append(counted_part)
        counted_content.role = content.role
        if content._pb != counted_content._pb:
            raise ValueError(
                f"Content contains unsupported types for token counting. Supported fields {counted_content}. Got {content}."
            )

    def add_function_call(self, function_call: gapic_tool_types.FunctionCall) -> None:
        """Processes a function call and adds relevant text to the accumulator.

        Args:
            function_call: The function call to process.
        """
        self._texts.append(function_call.name)
        counted_function_call = gapic_tool_types.FunctionCall(name=function_call.name)
        counted_struct = self._struct_traverse(function_call._pb.args)
        counted_function_call.args = counted_struct
        if counted_function_call._pb != function_call._pb:
            raise ValueError(
                f"Function call argument contains unsupported types for token counting. Supported fields {counted_function_call}. Got {function_call}."
            )

    def add_function_calls(
        self, function_calls: Iterable[gapic_tool_types.FunctionCall]
    ) -> None:
        for function_call in function_calls:
            self.add_function_call(function_call)

    def add_tool(self, tool: gapic_tool_types.Tool) -> None:
        counted_tool = gapic_tool_types.Tool()
        for function_declaration in tool.function_declarations:
            counted_function_declaration = self._function_declaration_traverse(
                function_declaration
            )
            counted_tool.function_declarations.append(counted_function_declaration)
        if counted_tool._pb != tool._pb:
            raise ValueError(
                f"Tool argument contains unsupported types for token counting. Supported fields {counted_tool}. Got {tool}."
            )

    def add_tools(self, tools: Iterable[gapic_tool_types.Tool]) -> None:
        for tool in tools:
            self.add_tool(tool)

    def add_function_responses(
        self, function_responses: Iterable[gapic_tool_types.FunctionResponse]
    ) -> None:
        for function_response in function_responses:
            self.add_function_response(function_response)

    def add_function_response(
        self, function_response: gapic_tool_types.FunctionResponse
    ) -> None:
        counted_function_response = gapic_tool_types.FunctionResponse()
        self._texts.append(function_response.name)
        counted_struct = self._struct_traverse(function_response._pb.response)
        counted_function_response.name = function_response.name
        counted_function_response.response = counted_struct
        if counted_function_response._pb != function_response._pb:
            raise ValueError(
                f"Function response argument contains unsupported types for token counting. Supported fields {counted_function_response}. Got {function_response}."
            )

    def _function_declaration_traverse(
        self, function_declaration: gapic_tool_types.FunctionDeclaration
    ) -> gapic_tool_types.FunctionDeclaration:
        counted_function_declaration = gapic_tool_types.FunctionDeclaration()
        self._texts.append(function_declaration.name)
        counted_function_declaration.name = function_declaration.name
        if function_declaration.description:
            self._texts.append(function_declaration.description)
            counted_function_declaration.description = function_declaration.description
        if function_declaration.parameters:
            counted_parameters = self._schema_traverse(function_declaration.parameters)
            counted_function_declaration.parameters = counted_parameters
        if function_declaration.response:
            counted_response = self._schema_traverse(function_declaration.response)
            counted_function_declaration.response = counted_response
        return counted_function_declaration

    def _schema_traverse(self, schema: openapi.Schema) -> openapi.Schema:
        """Processes a schema and adds relevant text to the accumulator.

        Args:
            schema: The schema to process.

        Returns:
            The new schema object with only countable fields.
        """
        counted_schema = openapi.Schema()
        if "type_" in schema:
            counted_schema.type_ = schema.type_
        if "title" in schema:
            counted_schema.title = schema.title
        if "default" in schema:
            counted_schema.default = schema.default
        if "format_" in schema:
            self._texts.append(schema.format_)
            counted_schema.format_ = schema.format_
        if "description" in schema:
            self._texts.append(schema.description)
            counted_schema.description = schema.description
        if "enum" in schema:
            self._texts.extend(schema.enum)
            counted_schema.enum = schema.enum
        if "required" in schema:
            self._texts.extend(schema.required)
            counted_schema.required = schema.required
        if "property_ordering" in schema:
            counted_schema.property_ordering = schema.property_ordering
        if "items" in schema:
            counted_schema_items = self._schema_traverse(schema.items)
            counted_schema.items = counted_schema_items
        if "properties" in schema:
            d = {}
            for key, value in schema.properties.items():
                self._texts.append(key)
                counted_value = self._schema_traverse(value)
                d[key] = counted_value
            counted_schema.properties.update(d)
        if "example" in schema:
            counted_schema_example = self._value_traverse(schema._pb.example)
            counted_schema.example = counted_schema_example
        return counted_schema
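
    # Illustrative note (not part of the original module): for a schema such as
    #   openapi.Schema(
    #       type_=openapi.Type.OBJECT,
    #       description="Weather request",
    #       properties={"city": openapi.Schema(type_=openapi.Type.STRING)},
    #   )
    # the accumulated texts would be ["Weather request", "city"]; type_, title,
    # default, and property_ordering are accepted but contribute no text.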

    def _struct_traverse(self, struct: struct_pb2.Struct) -> struct_pb2.Struct:
        """Processes a struct and adds relevant text to the accumulator.

        Args:
            struct: The struct to process.

        Returns:
            The new struct object with only countable fields.
        """
        counted_struct = struct_pb2.Struct()
        self._texts.extend(list(struct.fields.keys()))
        for key, val in struct.fields.items():
            counted_struct_fields = self._value_traverse(val)
            if isinstance(counted_struct_fields, str):
                counted_struct.fields[key] = counted_struct_fields
            else:
                counted_struct.fields[key].MergeFrom(counted_struct_fields)
        return counted_struct

    def _value_traverse(self, value: struct_pb2.Value) -> struct_pb2.Value:
        """Processes a struct field and adds relevant text to the accumulator.

        Args:
            value: The struct field to process.

        Returns:
            The new struct field object with only countable fields.
        """
        kind = value.WhichOneof("kind")
        counted_value = struct_pb2.Value()
        if kind == "string_value":
            self._texts.append(value.string_value)
            counted_value.string_value = value.string_value
        elif kind == "struct_value":
            counted_struct = self._struct_traverse(value.struct_value)
            counted_value.struct_value.MergeFrom(counted_struct)
        elif kind == "list_value":
            counted_list_value = struct_pb2.ListValue()
            for item in value.list_value.values:
                counted_item = self._value_traverse(item)
                counted_list_value.values.append(counted_item)
            counted_value.list_value.MergeFrom(counted_list_value)
        return counted_value
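
# Illustrative sketch (not part of the original module) of the allowlist-based
# detection described in _TextsAccumulator's docstring: a text-only Part equals
# its "counted" copy, while a Part carrying an unsupported payload does not, so
# add_content raises instead of silently under-counting.
#   accumulator = _TextsAccumulator()
#   accumulator.add_content(
#       gapic_content_types.Content(
#           role="user", parts=[gapic_content_types.Part(text="hello")]
#       )
#   )
#   accumulator.get_texts()  # -> ["hello"]
#   # A Part with inline_data (e.g. an image) raises ValueError instead.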


class Tokenizer:
    """A tokenizer that can parse text into tokens."""

    def __init__(self, tokenizer_name: str):
        """Initializes the tokenizer.

        Do not use this constructor directly. Use get_tokenizer_for_model instead.

        Args:
            tokenizer_name: The name of the tokenizer.
        """
        self._sentencepiece_adapter = _SentencePieceAdaptor(tokenizer_name)

    def count_tokens(
        self,
        contents: ContentsType,
        *,
        tools: Optional[List["Tool"]] = None,
        system_instruction: Optional[PartsType] = None,
    ) -> CountTokensResult:
        r"""Counts the number of tokens in the text-only contents.

        Args:
            contents: The contents to count tokens for.
                Supports either a list of Content objects (passing a multi-turn
                conversation) or a value that can be converted to a single
                Content object (passing a single message).
                Supports
                * str, Part,
                * List[Union[str, Part]],
                * List[Content]
                Throws an error if the contents contain non-text content.
            tools: A list of tools (functions) that the model can try calling.
            system_instruction: The provided system instructions for the model.
                Note: only text should be used in parts, and the content in each
                part will be placed in a separate paragraph.

        Returns:
            A CountTokensResult object containing the total number of tokens in
            the contents.
        """
        text_accumulator = _TextsAccumulator()
        if _is_string_inputs(contents):
            text_accumulator.add_texts(contents)
        else:
            gapic_contents = _to_gapic_contents(contents)
            for content in gapic_contents:
                text_accumulator.add_content(content)
        if tools:
            text_accumulator.add_tools(tool._raw_tool for tool in tools)
        if system_instruction:
            if _is_string_inputs(system_instruction):
                text_accumulator.add_texts(system_instruction)
            else:
                text_accumulator.add_content(_to_content(system_instruction))
        return self._sentencepiece_adapter.count_tokens(text_accumulator.get_texts())

    def compute_tokens(self, contents: ContentsType) -> ComputeTokensResult:
        r"""Computes the token ids and string pieces in the text-only contents.

        Args:
            contents: The contents to compute tokens for.
                Supports either a list of Content objects (passing a multi-turn
                conversation) or a value that can be converted to a single
                Content object (passing a single message).
                Supports
                * str, Part,
                * List[Union[str, Part]],
                * List[Content]
                Throws an error if the contents contain non-text content.

        Returns:
            A ComputeTokensResult object containing the token ids and string
            pieces in the contents.

        Examples:
            compute_tokens(["hello world", "Goodbye!"])
            outputs:
            ComputeTokensResult(tokens_info=[
                TokensInfo(token_ids=[17534, 2134], tokens=[b'hello', b' world'], role='user'),
                TokensInfo(token_ids=[84264, 235341], tokens=[b'Goodbye', b'!'], role='user'),
            ])
        """
        text_accumulator = _TextsAccumulator()
        if _is_string_inputs(contents):
            text_accumulator.add_texts(contents)
        else:
            gapic_contents = _to_gapic_contents(contents)
            for content in gapic_contents:
                text_accumulator.add_content(content)
        return self._sentencepiece_adapter.compute_tokens(
            contents=text_accumulator.get_texts(),
            roles=_to_canonical_roles(contents),
        )
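
# Illustrative usage sketch (not part of the original module), assuming the
# tokenizer files for "gemini-1.5-pro-001" are available to download:
#   tokenizer = get_tokenizer_for_model("gemini-1.5-pro-001")
#   result = tokenizer.compute_tokens(["hello world"])
#   result.tokens_info[0].token_ids  # e.g. [17534, 2134]
#   result.tokens_info[0].tokens     # e.g. [b"hello", b" world"]
#   result.tokens_info[0].role       # "user"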


class PreviewTokenizer(Tokenizer):
    def compute_tokens(self, contents: ContentsType) -> PreviewComputeTokensResult:
        return PreviewComputeTokensResult(
            tokens_info=super().compute_tokens(contents).tokens_info
        )


def _get_tokenizer_for_model_preview(model_name: str) -> PreviewTokenizer:
    """Returns a tokenizer for the given model name.

    Usage:
        ```
        tokenizer = get_tokenizer_for_model("gemini-1.5-pro-001")
        print(tokenizer.count_tokens("Hello world!"))
        ```

    Supported models can be found at
    https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models.

    Args:
        model_name: The name of the model whose tokenizer to return.
    """
    if not model_name:
        raise ValueError("model_name must not be empty.")
    return PreviewTokenizer(get_tokenizer_name(model_name))


def get_tokenizer_for_model(model_name: str) -> Tokenizer:
    """Returns a tokenizer for the given model name.

    Usage:
        ```
        tokenizer = get_tokenizer_for_model("gemini-1.5-pro-001")
        print(tokenizer.count_tokens("Hello world!"))
        ```

    Supported models can be found at
    https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models.

    Args:
        model_name: The name of the model whose tokenizer to return.
    """
    if not model_name:
        raise ValueError("model_name must not be empty.")
    return Tokenizer(get_tokenizer_name(model_name))
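
# Illustrative sketch (not part of the original module) of counting tokens for
# a request that also includes tools and a system instruction; the Tool here is
# vertexai.generative_models.Tool and the declaration below is hypothetical:
#   from vertexai.generative_models import FunctionDeclaration, Tool
#
#   get_weather = FunctionDeclaration(
#       name="get_weather",
#       description="Returns the weather for a city.",
#       parameters={"type": "object", "properties": {"city": {"type": "string"}}},
#   )
#   tokenizer = get_tokenizer_for_model("gemini-1.5-pro-001")
#   result = tokenizer.count_tokens(
#       "What's the weather in Boston?",
#       tools=[Tool(function_declarations=[get_weather])],
#       system_instruction="You are a helpful assistant.",
#   )
#   result.total_tokens  # total across contents, tool declarations and system text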