# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import dataclasses
from typing import (
    Iterable,
    List,
    Sequence,
    Optional,
    Union,
)

from vertexai.generative_models._generative_models import (
    ContentsType,
    Image,
    Tool,
    PartsType,
    _validate_contents_type_as_valid_sequence,
    _content_types_to_gapic_contents,
    _to_content,
)
from vertexai.tokenization._tokenizer_loading import (
    get_sentencepiece,
    get_tokenizer_name,
    load_model_proto,
)
from google.cloud.aiplatform_v1beta1.types import (
    content as gapic_content_types,
    tool as gapic_tool_types,
    openapi,
)
from sentencepiece import sentencepiece_model_pb2
from google.protobuf import struct_pb2


@dataclasses.dataclass(frozen=True)
class TokensInfo:
    token_ids: Sequence[int]
    tokens: Sequence[bytes]
    role: Optional[str] = None


@dataclasses.dataclass(frozen=True)
class ComputeTokensResult:
    """Represents the token string pieces and token ids returned by compute_tokens.

    Attributes:
        tokens_info: Lists of tokens_info from the input.
            The input `contents: ContentsType` could have multiple string
            instances, and each tokens_info item represents one string
            instance. Each token info consists of a token_ids list, a tokens
            list and a role.
    """

    tokens_info: Sequence[TokensInfo]


class PreviewComputeTokensResult(ComputeTokensResult):
    def token_info_list(self) -> Sequence[TokensInfo]:
        import warnings

        message = (
            "PreviewComputeTokensResult.token_info_list is deprecated. "
            "Use ComputeTokensResult.tokens_info instead."
        )
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        return self.tokens_info


@dataclasses.dataclass(frozen=True)
class CountTokensResult:
    """Represents the token count returned by count_tokens.

    Attributes:
        total_tokens: The total number of tokens.
    """

    total_tokens: int


def _parse_hex_byte(token: str) -> int:
    """Parses a hex byte string of the form '<0xXX>' and returns the integer value.

    Raises ValueError if the input is malformed or the byte value is invalid.
    """
    if len(token) != 6:
        raise ValueError(f"Invalid byte length: {token}")
    if not token.startswith("<0x") or not token.endswith(">"):
        raise ValueError(f"Invalid byte format: {token}")

    try:
        val = int(token[3:5], 16)  # Parse the hex part directly
    except ValueError:
        raise ValueError(f"Invalid hex value: {token}")

    if val >= 256:
        raise ValueError(f"Byte value out of range: {token}")

    return val


def _token_str_to_bytes(
    token: str, type: sentencepiece_model_pb2.ModelProto.SentencePiece.Type
) -> bytes:
    if type == sentencepiece_model_pb2.ModelProto.SentencePiece.Type.BYTE:
        return _parse_hex_byte(token).to_bytes(length=1, byteorder="big")
    else:
        return token.replace("▁", " ").encode("utf-8")


class _SentencePieceAdaptor:
    r"""An internal tokenizer that can parse text input into tokens."""

    def __init__(self, tokenizer_name: str):
        r"""Initializes the tokenizer.

        Args:
            tokenizer_name: The name of the tokenizer.
        """
        self._model_proto = load_model_proto(tokenizer_name)
        self._tokenizer = get_sentencepiece(tokenizer_name)

    def count_tokens(self, contents: Iterable[str]) -> CountTokensResult:
        r"""Counts the number of tokens in the input."""
        tokens_list = self._tokenizer.encode(list(contents))

        return CountTokensResult(
            total_tokens=sum(len(tokens) for tokens in tokens_list)
        )

    def compute_tokens(
        self, *, contents: Iterable[str], roles: Iterable[str]
    ) -> ComputeTokensResult:
        """Computes the token ids and string pieces in the input."""
        content_list = list(contents)
        tokens_protos = self._tokenizer.EncodeAsImmutableProto(content_list)
        roles = list(roles)

        token_infos = []
        for tokens_proto, role in zip(tokens_protos, roles):
            token_infos.append(
                TokensInfo(
                    token_ids=[piece.id for piece in tokens_proto.pieces],
                    tokens=[
                        _token_str_to_bytes(
                            piece.piece, self._model_proto.pieces[piece.id].type
                        )
                        for piece in tokens_proto.pieces
                    ],
                    role=role,
                )
            )
        return ComputeTokensResult(tokens_info=token_infos)


def _to_gapic_contents(
    contents: ContentsType,
) -> List[gapic_content_types.Content]:
    """Converts a GenerativeModel compatible contents type to gapic contents."""
    _validate_contents_type_as_valid_sequence(contents)
    _assert_no_image_contents_type(contents)
    gapic_contents = _content_types_to_gapic_contents(contents)
    # _assert_text_only_content_types_sequence(gapic_contents)
    return gapic_contents


def _content_types_to_role_iterator(contents: ContentsType) -> Iterable[str]:
    gapic_contents = _to_gapic_contents(contents)
    # Flatten the role across each content's (possibly multiple) parts.
    for content in gapic_contents:
        for part in content.parts:
            yield content.role


def _assert_no_image_contents_type(contents: ContentsType):
    """Asserts that the contents type does not contain any image content."""
    if isinstance(contents, Image) or (
        isinstance(contents, Sequence)
        and any(isinstance(content, Image) for content in contents)
    ):
        raise ValueError("Tokenizers do not support Image content type.")


def _is_string_inputs(contents: ContentsType) -> bool:
    return isinstance(contents, str) or (
        isinstance(contents, Sequence)
        and all(isinstance(content, str) for content in contents)
    )


def _to_canonical_roles(contents: ContentsType) -> Iterable[str]:
    if isinstance(contents, str):
        yield "user"
    elif isinstance(contents, Sequence) and all(
        isinstance(content, str) for content in contents
    ):
        yield from ["user"] * len(contents)
    else:
        yield from _content_types_to_role_iterator(contents)

class _TextsAccumulator:
    """Accumulates the countable texts from contents and tools.

    This class collects the text fields that the local tokenizer can count.
    If a user passes fields that are added to the protos in the future, those
    new fields might only be counted by the remote tokenizer. In that case the
    local tokenizer must detect that an unsupported new field exists in the
    content or tool instances and raise an error instead of silently returning
    an incorrect result.

    The mechanism for detecting unsupported future fields is allowlist-based
    text accumulation: while the local tokenizer traverses the input instances,
    the value of every field that is recognized as countable is copied to two
    places: (1) self._texts, the input to the SentencePiece token counting
    function, and (2) a "counted" instance returned by the recursive traversal.
    After the traversal, the counted instance (of the same type as the input)
    holds only the counted values, so unsupported future proto fields set by
    the user can be detected by comparing the input instance with the counted
    instance for equality.
    """

    def __init__(self):
        self._texts = []

    def get_texts(self) -> Iterable[str]:
        return self._texts

    def add_texts(self, texts: Union[Iterable[str], str]) -> None:
        if isinstance(texts, str):
            self._texts.append(texts)
        else:
            self._texts.extend(texts)

    def add_content(self, content: gapic_content_types.Content) -> None:
        counted_content = gapic_content_types.Content()
        for part in content.parts:
            counted_part = gapic_content_types.Part()
            if "file_data" in part or "inline_data" in part:
                raise ValueError("Tokenizers do not support non-text content types.")
            if "video_metadata" in part:
                counted_part.video_metadata = part.video_metadata
            if "function_call" in part:
                self.add_function_call(part.function_call)
                counted_part.function_call = part.function_call
            if "function_response" in part:
                self.add_function_response(part.function_response)
                counted_part.function_response = part.function_response
            if "text" in part:
                counted_part.text = part.text
                self._texts.append(part.text)
            counted_content.parts.append(counted_part)
        counted_content.role = content.role
        if content._pb != counted_content._pb:
            raise ValueError(
                f"Content contains unsupported types for token counting. Supported fields {counted_content}. Got {content}."
            )

    def add_function_call(self, function_call: gapic_tool_types.FunctionCall) -> None:
        """Processes a function call and adds relevant text to the accumulator.

        Args:
            function_call: The function call to process.
        """
        self._texts.append(function_call.name)
        counted_function_call = gapic_tool_types.FunctionCall(name=function_call.name)
        counted_struct = self._struct_traverse(function_call._pb.args)
        counted_function_call.args = counted_struct
        if counted_function_call._pb != function_call._pb:
            raise ValueError(
                f"Function call argument contains unsupported types for token counting. Supported fields {counted_function_call}. Got {function_call}."
            )

    def add_function_calls(
        self, function_calls: Iterable[gapic_tool_types.FunctionCall]
    ) -> None:
        for function_call in function_calls:
            self.add_function_call(function_call)

    def add_tool(self, tool: gapic_tool_types.Tool) -> None:
        counted_tool = gapic_tool_types.Tool()
        for function_declaration in tool.function_declarations:
            counted_function_declaration = self._function_declaration_traverse(
                function_declaration
            )
            counted_tool.function_declarations.append(counted_function_declaration)
        if counted_tool._pb != tool._pb:
            raise ValueError(
                f"Tool argument contains unsupported types for token counting. Supported fields {counted_tool}. Got {tool}."
            )

    def add_tools(self, tools: Iterable[gapic_tool_types.Tool]) -> None:
        for tool in tools:
            self.add_tool(tool)

    def add_function_responses(
        self, function_responses: Iterable[gapic_tool_types.FunctionResponse]
    ) -> None:
        for function_response in function_responses:
            self.add_function_response(function_response)

    def add_function_response(
        self, function_response: gapic_tool_types.FunctionResponse
    ) -> None:
        counted_function_response = gapic_tool_types.FunctionResponse()
        self._texts.append(function_response.name)
        counted_struct = self._struct_traverse(function_response._pb.response)
        counted_function_response.name = function_response.name
        counted_function_response.response = counted_struct
        if counted_function_response._pb != function_response._pb:
            raise ValueError(
                f"Function response argument contains unsupported types for token counting. Supported fields {counted_function_response}. Got {function_response}."
            )
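    # Illustrative sketch of the allowlist mechanism above (hypothetical field
    # values): for a FunctionCall with name="get_weather" and
    # args={"location": "Boston"}, add_function_call appends "get_weather",
    # "location" and "Boston" to self._texts and rebuilds a "counted"
    # FunctionCall from those fields alone. Because the rebuilt copy equals the
    # input, no error is raised; an input carrying an unrecognized proto field
    # would differ from its counted copy and trigger the ValueError instead of
    # being silently under-counted. add_tool applies the same idea to
    # FunctionDeclarations through _schema_traverse below, which counts
    # descriptions, format, enum values, required fields and property names.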
    def _function_declaration_traverse(
        self, function_declaration: gapic_tool_types.FunctionDeclaration
    ) -> gapic_tool_types.FunctionDeclaration:
        counted_function_declaration = gapic_tool_types.FunctionDeclaration()
        self._texts.append(function_declaration.name)
        counted_function_declaration.name = function_declaration.name
        if function_declaration.description:
            self._texts.append(function_declaration.description)
            counted_function_declaration.description = function_declaration.description
        if function_declaration.parameters:
            counted_parameters = self._schema_traverse(function_declaration.parameters)
            counted_function_declaration.parameters = counted_parameters
        if function_declaration.response:
            counted_response = self._schema_traverse(function_declaration.response)
            counted_function_declaration.response = counted_response
        return counted_function_declaration

    def _schema_traverse(self, schema: openapi.Schema) -> openapi.Schema:
        """Processes a schema and adds relevant text to the accumulator.

        Args:
            schema: The schema to process.

        Returns:
            The new schema object with only countable fields.
        """
        counted_schema = openapi.Schema()
        if "type_" in schema:
            counted_schema.type_ = schema.type_
        if "title" in schema:
            counted_schema.title = schema.title
        if "default" in schema:
            counted_schema.default = schema.default
        if "format_" in schema:
            self._texts.append(schema.format_)
            counted_schema.format_ = schema.format_
        if "description" in schema:
            self._texts.append(schema.description)
            counted_schema.description = schema.description
        if "enum" in schema:
            self._texts.extend(schema.enum)
            counted_schema.enum = schema.enum
        if "required" in schema:
            self._texts.extend(schema.required)
            counted_schema.required = schema.required
        if "property_ordering" in schema:
            counted_schema.property_ordering = schema.property_ordering
        if "items" in schema:
            counted_schema_items = self._schema_traverse(schema.items)
            counted_schema.items = counted_schema_items
        if "properties" in schema:
            d = {}
            for key, value in schema.properties.items():
                self._texts.append(key)
                counted_value = self._schema_traverse(value)
                d[key] = counted_value
            counted_schema.properties.update(d)
        if "example" in schema:
            counted_schema_example = self._value_traverse(schema._pb.example)
            counted_schema.example = counted_schema_example
        return counted_schema

    def _struct_traverse(self, struct: struct_pb2.Struct) -> struct_pb2.Struct:
        """Processes a struct and adds relevant text to the accumulator.

        Args:
            struct: The struct to process.

        Returns:
            The new struct object with only countable fields.
        """
        counted_struct = struct_pb2.Struct()
        self._texts.extend(list(struct.fields.keys()))
        for key, val in struct.fields.items():
            counted_struct_fields = self._value_traverse(val)
            if isinstance(counted_struct_fields, str):
                counted_struct.fields[key] = counted_struct_fields
            else:
                counted_struct.fields[key].MergeFrom(counted_struct_fields)
        return counted_struct

    def _value_traverse(self, value: struct_pb2.Value) -> struct_pb2.Value:
        """Processes a struct field and adds relevant text to the accumulator.

        Args:
            value: The struct field to process.

        Returns:
            The new struct field object with only countable fields.
        """
        kind = value.WhichOneof("kind")
        counted_value = struct_pb2.Value()
        if kind == "string_value":
            self._texts.append(value.string_value)
            counted_value.string_value = value.string_value
        elif kind == "struct_value":
            counted_struct = self._struct_traverse(value.struct_value)
            counted_value.struct_value.MergeFrom(counted_struct)
        elif kind == "list_value":
            counted_list_value = struct_pb2.ListValue()
            for item in value.list_value.values:
                counted_item = self._value_traverse(item)
                counted_list_value.values.append(counted_item)
            counted_value.list_value.MergeFrom(counted_list_value)
        return counted_value


class Tokenizer:
    """A tokenizer that can parse text into tokens."""

    def __init__(self, tokenizer_name: str):
        """Initializes the tokenizer.

        Do not use this constructor directly. Use get_tokenizer_for_model instead.

        Args:
            tokenizer_name: The name of the tokenizer.
        """
        self._sentencepiece_adapter = _SentencePieceAdaptor(tokenizer_name)

    def count_tokens(
        self,
        contents: ContentsType,
        *,
        tools: Optional[List["Tool"]] = None,
        system_instruction: Optional[PartsType] = None,
    ) -> CountTokensResult:
        r"""Counts the number of tokens in the text-only contents.

        Args:
            contents: The contents to count tokens for.
                Supports either a list of Content objects (passing a multi-turn
                conversation) or a value that can be converted to a single
                Content object (passing a single message). Supports
                * str, Part,
                * List[Union[str, Part]],
                * List[Content]
                Throws an error if the contents contain non-text content.
            tools: A list of tools (functions) that the model can try calling.
            system_instruction: The provided system instructions for the model.
                Note: only text should be used in parts, and content in each
                part will be in a separate paragraph.

        Returns:
            A CountTokensResult object containing the total number of tokens in
            the contents.
        """
        text_accumulator = _TextsAccumulator()
        if _is_string_inputs(contents):
            text_accumulator.add_texts(contents)
        else:
            gapic_contents = _to_gapic_contents(contents)
            for content in gapic_contents:
                text_accumulator.add_content(content)

        if tools:
            text_accumulator.add_tools(tool._raw_tool for tool in tools)

        if system_instruction:
            if _is_string_inputs(system_instruction):
                text_accumulator.add_texts(system_instruction)
            else:
                text_accumulator.add_content(_to_content(system_instruction))

        return self._sentencepiece_adapter.count_tokens(text_accumulator.get_texts())
    def compute_tokens(self, contents: ContentsType) -> ComputeTokensResult:
        r"""Computes the token ids and string pieces in the text-only contents.

        Args:
            contents: The contents to compute tokens for.
                Supports either a list of Content objects (passing a multi-turn
                conversation) or a value that can be converted to a single
                Content object (passing a single message). Supports
                * str, Part,
                * List[Union[str, Part]],
                * List[Content]
                Throws an error if the contents contain non-text content.

        Returns:
            A ComputeTokensResult object containing the token ids and string
            pieces in the contents.

        Examples:
            compute_tokens(["hello world", "Goodbye!"])
            outputs:
            ComputeTokensResult(tokens_info=[
                TokensInfo(token_ids=[17534, 2134], tokens=[b'hello', b' world'], role='user'),
                TokensInfo(token_ids=[84264, 235341], tokens=[b'Goodbye', b'!'], role='user'),
            ])
        """
        text_accumulator = _TextsAccumulator()
        if _is_string_inputs(contents):
            text_accumulator.add_texts(contents)
        else:
            gapic_contents = _to_gapic_contents(contents)
            for content in gapic_contents:
                text_accumulator.add_content(content)
        return self._sentencepiece_adapter.compute_tokens(
            contents=text_accumulator.get_texts(),
            roles=_to_canonical_roles(contents),
        )


class PreviewTokenizer(Tokenizer):
    def compute_tokens(self, contents: ContentsType) -> PreviewComputeTokensResult:
        return PreviewComputeTokensResult(
            tokens_info=super().compute_tokens(contents).tokens_info
        )


def _get_tokenizer_for_model_preview(model_name: str) -> PreviewTokenizer:
    """Returns a preview tokenizer for the given model name.

    Usage:
        ```
        tokenizer = get_tokenizer_for_model("gemini-1.5-pro-001")
        print(tokenizer.count_tokens("Hello world!"))
        ```

    Supported models can be found at
    https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models.

    Args:
        model_name: The name of the model whose tokenizer should be returned.
    """
    if not model_name:
        raise ValueError("model_name must not be empty.")

    return PreviewTokenizer(get_tokenizer_name(model_name))


def get_tokenizer_for_model(model_name: str) -> Tokenizer:
    """Returns a tokenizer for the given model name.

    Usage:
        ```
        tokenizer = get_tokenizer_for_model("gemini-1.5-pro-001")
        print(tokenizer.count_tokens("Hello world!"))
        ```

    Supported models can be found at
    https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models.

    Args:
        model_name: The name of the model whose tokenizer should be returned.
    """
    if not model_name:
        raise ValueError("model_name must not be empty.")

    return Tokenizer(get_tokenizer_name(model_name))
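
# Illustrative end-to-end sketch (hypothetical tool declaration; counts are model
# dependent): counting a prompt together with a tool and a system instruction.
#
#   from vertexai.generative_models import FunctionDeclaration, Tool
#
#   tokenizer = get_tokenizer_for_model("gemini-1.5-pro-001")
#   weather_tool = Tool(
#       function_declarations=[
#           FunctionDeclaration(
#               name="get_current_weather",
#               description="Returns the current weather for a city.",
#               parameters={
#                   "type": "object",
#                   "properties": {"location": {"type": "string"}},
#               },
#           )
#       ]
#   )
#   result = tokenizer.count_tokens(
#       "What's the weather in Boston?",
#       tools=[weather_tool],
#       system_instruction="Answer concisely.",
#   )
#   result.total_tokens  # includes the tool name, description and schema text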