structure saas with tools
This commit is contained in:
7
.venv/lib/python3.10/site-packages/google/protobuf/internal/__init__.py
Executable file
7
.venv/lib/python3.10/site-packages/google/protobuf/internal/__init__.py
Executable file
@@ -0,0 +1,7 @@
|
||||
# Protocol Buffers - Google's data interchange format
|
||||
# Copyright 2008 Google Inc. All rights reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style
|
||||
# license that can be found in the LICENSE file or at
|
||||
# https://developers.google.com/open-source/licenses/bsd
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,142 @@
|
||||
# Protocol Buffers - Google's data interchange format
|
||||
# Copyright 2008 Google Inc. All rights reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style
|
||||
# license that can be found in the LICENSE file or at
|
||||
# https://developers.google.com/open-source/licenses/bsd
|
||||
|
||||
"""Determine which implementation of the protobuf API is used in this process.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
_GOOGLE3_PYTHON_UPB_DEFAULT = True
|
||||
|
||||
|
||||
def _ApiVersionToImplementationType(api_version):
|
||||
if api_version == 2:
|
||||
return 'cpp'
|
||||
if api_version == 1:
|
||||
raise ValueError('api_version=1 is no longer supported.')
|
||||
if api_version == 0:
|
||||
return 'python'
|
||||
return None
|
||||
|
||||
|
||||
_implementation_type = None
|
||||
try:
|
||||
# pylint: disable=g-import-not-at-top
|
||||
from google.protobuf.internal import _api_implementation
|
||||
# The compile-time constants in the _api_implementation module can be used to
|
||||
# switch to a certain implementation of the Python API at build time.
|
||||
_implementation_type = _ApiVersionToImplementationType(
|
||||
_api_implementation.api_version)
|
||||
except ImportError:
|
||||
pass # Unspecified by compiler flags.
|
||||
|
||||
|
||||
def _CanImport(mod_name):
|
||||
try:
|
||||
mod = importlib.import_module(mod_name)
|
||||
# Work around a known issue in the classic bootstrap .par import hook.
|
||||
if not mod:
|
||||
raise ImportError(mod_name + ' import succeeded but was None')
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
if _implementation_type is None:
|
||||
if _CanImport('google._upb._message'):
|
||||
_implementation_type = 'upb'
|
||||
elif _CanImport('google.protobuf.pyext._message'):
|
||||
_implementation_type = 'cpp'
|
||||
else:
|
||||
_implementation_type = 'python'
|
||||
|
||||
|
||||
# This environment variable can be used to switch to a certain implementation
|
||||
# of the Python API, overriding the compile-time constants in the
|
||||
# _api_implementation module. Right now only 'python', 'cpp' and 'upb' are
|
||||
# valid values. Any other value will raise error.
|
||||
_implementation_type = os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION',
|
||||
_implementation_type)
|
||||
|
||||
if _implementation_type not in ('python', 'cpp', 'upb'):
|
||||
raise ValueError('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION {0} is not '
|
||||
'supported. Please set to \'python\', \'cpp\' or '
|
||||
'\'upb\'.'.format(_implementation_type))
|
||||
|
||||
if 'PyPy' in sys.version and _implementation_type == 'cpp':
|
||||
warnings.warn('PyPy does not work yet with cpp protocol buffers. '
|
||||
'Falling back to the python implementation.')
|
||||
_implementation_type = 'python'
|
||||
|
||||
_c_module = None
|
||||
|
||||
if _implementation_type == 'cpp':
|
||||
try:
|
||||
# pylint: disable=g-import-not-at-top
|
||||
from google.protobuf.pyext import _message
|
||||
sys.modules['google3.net.proto2.python.internal.cpp._message'] = _message
|
||||
_c_module = _message
|
||||
del _message
|
||||
except ImportError:
|
||||
# TODO: fail back to python
|
||||
warnings.warn(
|
||||
'Selected implementation cpp is not available.')
|
||||
pass
|
||||
|
||||
if _implementation_type == 'upb':
|
||||
try:
|
||||
# pylint: disable=g-import-not-at-top
|
||||
from google._upb import _message
|
||||
_c_module = _message
|
||||
del _message
|
||||
except ImportError:
|
||||
warnings.warn('Selected implementation upb is not available. '
|
||||
'Falling back to the python implementation.')
|
||||
_implementation_type = 'python'
|
||||
pass
|
||||
|
||||
# Detect if serialization should be deterministic by default
|
||||
try:
|
||||
# The presence of this module in a build allows the proto implementation to
|
||||
# be upgraded merely via build deps.
|
||||
#
|
||||
# NOTE: Merely importing this automatically enables deterministic proto
|
||||
# serialization for C++ code, but we still need to export it as a boolean so
|
||||
# that we can do the same for `_implementation_type == 'python'`.
|
||||
#
|
||||
# NOTE2: It is possible for C++ code to enable deterministic serialization by
|
||||
# default _without_ affecting Python code, if the C++ implementation is not in
|
||||
# use by this module. That is intended behavior, so we don't actually expose
|
||||
# this boolean outside of this module.
|
||||
#
|
||||
# pylint: disable=g-import-not-at-top,unused-import
|
||||
from google.protobuf import enable_deterministic_proto_serialization
|
||||
_python_deterministic_proto_serialization = True
|
||||
except ImportError:
|
||||
_python_deterministic_proto_serialization = False
|
||||
|
||||
|
||||
# Usage of this function is discouraged. Clients shouldn't care which
|
||||
# implementation of the API is in use. Note that there is no guarantee
|
||||
# that differences between APIs will be maintained.
|
||||
# Please don't use this function if possible.
|
||||
def Type():
|
||||
return _implementation_type
|
||||
|
||||
|
||||
# See comment on 'Type' above.
|
||||
# TODO: Remove the API, it returns a constant. b/228102101
|
||||
def Version():
|
||||
return 2
|
||||
|
||||
|
||||
# For internal use only
|
||||
def IsPythonDefaultSerializationDeterministic():
|
||||
return _python_deterministic_proto_serialization
|
||||
118
.venv/lib/python3.10/site-packages/google/protobuf/internal/builder.py
Executable file
118
.venv/lib/python3.10/site-packages/google/protobuf/internal/builder.py
Executable file
@@ -0,0 +1,118 @@
|
||||
# Protocol Buffers - Google's data interchange format
|
||||
# Copyright 2008 Google Inc. All rights reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style
|
||||
# license that can be found in the LICENSE file or at
|
||||
# https://developers.google.com/open-source/licenses/bsd
|
||||
|
||||
"""Builds descriptors, message classes and services for generated _pb2.py.
|
||||
|
||||
This file is only called in python generated _pb2.py files. It builds
|
||||
descriptors, message classes and services that users can directly use
|
||||
in generated code.
|
||||
"""
|
||||
|
||||
__author__ = 'jieluo@google.com (Jie Luo)'
|
||||
|
||||
from google.protobuf.internal import enum_type_wrapper
|
||||
from google.protobuf.internal import python_message
|
||||
from google.protobuf import message as _message
|
||||
from google.protobuf import reflection as _reflection
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
def BuildMessageAndEnumDescriptors(file_des, module):
|
||||
"""Builds message and enum descriptors.
|
||||
|
||||
Args:
|
||||
file_des: FileDescriptor of the .proto file
|
||||
module: Generated _pb2 module
|
||||
"""
|
||||
|
||||
def BuildNestedDescriptors(msg_des, prefix):
|
||||
for (name, nested_msg) in msg_des.nested_types_by_name.items():
|
||||
module_name = prefix + name.upper()
|
||||
module[module_name] = nested_msg
|
||||
BuildNestedDescriptors(nested_msg, module_name + '_')
|
||||
for enum_des in msg_des.enum_types:
|
||||
module[prefix + enum_des.name.upper()] = enum_des
|
||||
|
||||
for (name, msg_des) in file_des.message_types_by_name.items():
|
||||
module_name = '_' + name.upper()
|
||||
module[module_name] = msg_des
|
||||
BuildNestedDescriptors(msg_des, module_name + '_')
|
||||
|
||||
|
||||
def BuildTopDescriptorsAndMessages(file_des, module_name, module):
|
||||
"""Builds top level descriptors and message classes.
|
||||
|
||||
Args:
|
||||
file_des: FileDescriptor of the .proto file
|
||||
module_name: str, the name of generated _pb2 module
|
||||
module: Generated _pb2 module
|
||||
"""
|
||||
|
||||
def BuildMessage(msg_des, prefix):
|
||||
create_dict = {}
|
||||
for (name, nested_msg) in msg_des.nested_types_by_name.items():
|
||||
create_dict[name] = BuildMessage(nested_msg, prefix + msg_des.name + '.')
|
||||
create_dict['DESCRIPTOR'] = msg_des
|
||||
create_dict['__module__'] = module_name
|
||||
create_dict['__qualname__'] = prefix + msg_des.name
|
||||
message_class = _reflection.GeneratedProtocolMessageType(
|
||||
msg_des.name, (_message.Message,), create_dict)
|
||||
_sym_db.RegisterMessage(message_class)
|
||||
return message_class
|
||||
|
||||
# top level enums
|
||||
for (name, enum_des) in file_des.enum_types_by_name.items():
|
||||
module['_' + name.upper()] = enum_des
|
||||
module[name] = enum_type_wrapper.EnumTypeWrapper(enum_des)
|
||||
for enum_value in enum_des.values:
|
||||
module[enum_value.name] = enum_value.number
|
||||
|
||||
# top level extensions
|
||||
for (name, extension_des) in file_des.extensions_by_name.items():
|
||||
module[name.upper() + '_FIELD_NUMBER'] = extension_des.number
|
||||
module[name] = extension_des
|
||||
|
||||
# services
|
||||
for (name, service) in file_des.services_by_name.items():
|
||||
module['_' + name.upper()] = service
|
||||
|
||||
# Build messages.
|
||||
for (name, msg_des) in file_des.message_types_by_name.items():
|
||||
module[name] = BuildMessage(msg_des, '')
|
||||
|
||||
|
||||
def AddHelpersToExtensions(file_des):
|
||||
"""no-op to keep old generated code work with new runtime.
|
||||
|
||||
Args:
|
||||
file_des: FileDescriptor of the .proto file
|
||||
"""
|
||||
# TODO: Remove this on-op
|
||||
return
|
||||
|
||||
|
||||
def BuildServices(file_des, module_name, module):
|
||||
"""Builds services classes and services stub class.
|
||||
|
||||
Args:
|
||||
file_des: FileDescriptor of the .proto file
|
||||
module_name: str, the name of generated _pb2 module
|
||||
module: Generated _pb2 module
|
||||
"""
|
||||
# pylint: disable=g-import-not-at-top
|
||||
from google.protobuf import service_reflection
|
||||
# pylint: enable=g-import-not-at-top
|
||||
for (name, service) in file_des.services_by_name.items():
|
||||
module[name] = service_reflection.GeneratedServiceType(
|
||||
name, (),
|
||||
dict(DESCRIPTOR=service, __module__=module_name))
|
||||
stub_name = name + '_Stub'
|
||||
module[stub_name] = service_reflection.GeneratedServiceStubType(
|
||||
stub_name, (module[name],),
|
||||
dict(DESCRIPTOR=service, __module__=module_name))
|
||||
690
.venv/lib/python3.10/site-packages/google/protobuf/internal/containers.py
Executable file
690
.venv/lib/python3.10/site-packages/google/protobuf/internal/containers.py
Executable file
@@ -0,0 +1,690 @@
|
||||
# Protocol Buffers - Google's data interchange format
|
||||
# Copyright 2008 Google Inc. All rights reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style
|
||||
# license that can be found in the LICENSE file or at
|
||||
# https://developers.google.com/open-source/licenses/bsd
|
||||
|
||||
"""Contains container classes to represent different protocol buffer types.
|
||||
|
||||
This file defines container classes which represent categories of protocol
|
||||
buffer field types which need extra maintenance. Currently these categories
|
||||
are:
|
||||
|
||||
- Repeated scalar fields - These are all repeated fields which aren't
|
||||
composite (e.g. they are of simple types like int32, string, etc).
|
||||
- Repeated composite fields - Repeated fields which are composite. This
|
||||
includes groups and nested messages.
|
||||
"""
|
||||
|
||||
import collections.abc
|
||||
import copy
|
||||
import pickle
|
||||
from typing import (
|
||||
Any,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
MutableMapping,
|
||||
MutableSequence,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
|
||||
_T = TypeVar('_T')
|
||||
_K = TypeVar('_K')
|
||||
_V = TypeVar('_V')
|
||||
|
||||
|
||||
class BaseContainer(Sequence[_T]):
|
||||
"""Base container class."""
|
||||
|
||||
# Minimizes memory usage and disallows assignment to other attributes.
|
||||
__slots__ = ['_message_listener', '_values']
|
||||
|
||||
def __init__(self, message_listener: Any) -> None:
|
||||
"""
|
||||
Args:
|
||||
message_listener: A MessageListener implementation.
|
||||
The RepeatedScalarFieldContainer will call this object's
|
||||
Modified() method when it is modified.
|
||||
"""
|
||||
self._message_listener = message_listener
|
||||
self._values = []
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: int) -> _T:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: slice) -> List[_T]:
|
||||
...
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Retrieves item by the specified key."""
|
||||
return self._values[key]
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the number of elements in the container."""
|
||||
return len(self._values)
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
"""Checks if another instance isn't equal to this one."""
|
||||
# The concrete classes should define __eq__.
|
||||
return not self == other
|
||||
|
||||
__hash__ = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(self._values)
|
||||
|
||||
def sort(self, *args, **kwargs) -> None:
|
||||
# Continue to support the old sort_function keyword argument.
|
||||
# This is expected to be a rare occurrence, so use LBYL to avoid
|
||||
# the overhead of actually catching KeyError.
|
||||
if 'sort_function' in kwargs:
|
||||
kwargs['cmp'] = kwargs.pop('sort_function')
|
||||
self._values.sort(*args, **kwargs)
|
||||
|
||||
def reverse(self) -> None:
|
||||
self._values.reverse()
|
||||
|
||||
|
||||
# TODO: Remove this. BaseContainer does *not* conform to
|
||||
# MutableSequence, only its subclasses do.
|
||||
collections.abc.MutableSequence.register(BaseContainer)
|
||||
|
||||
|
||||
class RepeatedScalarFieldContainer(BaseContainer[_T], MutableSequence[_T]):
|
||||
"""Simple, type-checked, list-like container for holding repeated scalars."""
|
||||
|
||||
# Disallows assignment to other attributes.
|
||||
__slots__ = ['_type_checker']
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_listener: Any,
|
||||
type_checker: Any,
|
||||
) -> None:
|
||||
"""Args:
|
||||
|
||||
message_listener: A MessageListener implementation. The
|
||||
RepeatedScalarFieldContainer will call this object's Modified() method
|
||||
when it is modified.
|
||||
type_checker: A type_checkers.ValueChecker instance to run on elements
|
||||
inserted into this container.
|
||||
"""
|
||||
super().__init__(message_listener)
|
||||
self._type_checker = type_checker
|
||||
|
||||
def append(self, value: _T) -> None:
|
||||
"""Appends an item to the list. Similar to list.append()."""
|
||||
self._values.append(self._type_checker.CheckValue(value))
|
||||
if not self._message_listener.dirty:
|
||||
self._message_listener.Modified()
|
||||
|
||||
def insert(self, key: int, value: _T) -> None:
|
||||
"""Inserts the item at the specified position. Similar to list.insert()."""
|
||||
self._values.insert(key, self._type_checker.CheckValue(value))
|
||||
if not self._message_listener.dirty:
|
||||
self._message_listener.Modified()
|
||||
|
||||
def extend(self, elem_seq: Iterable[_T]) -> None:
|
||||
"""Extends by appending the given iterable. Similar to list.extend()."""
|
||||
elem_seq_iter = iter(elem_seq)
|
||||
new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq_iter]
|
||||
if new_values:
|
||||
self._values.extend(new_values)
|
||||
self._message_listener.Modified()
|
||||
|
||||
def MergeFrom(
|
||||
self,
|
||||
other: Union['RepeatedScalarFieldContainer[_T]', Iterable[_T]],
|
||||
) -> None:
|
||||
"""Appends the contents of another repeated field of the same type to this
|
||||
one. We do not check the types of the individual fields.
|
||||
"""
|
||||
self._values.extend(other)
|
||||
self._message_listener.Modified()
|
||||
|
||||
def remove(self, elem: _T):
|
||||
"""Removes an item from the list. Similar to list.remove()."""
|
||||
self._values.remove(elem)
|
||||
self._message_listener.Modified()
|
||||
|
||||
def pop(self, key: Optional[int] = -1) -> _T:
|
||||
"""Removes and returns an item at a given index. Similar to list.pop()."""
|
||||
value = self._values[key]
|
||||
self.__delitem__(key)
|
||||
return value
|
||||
|
||||
@overload
|
||||
def __setitem__(self, key: int, value: _T) -> None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
|
||||
...
|
||||
|
||||
def __setitem__(self, key, value) -> None:
|
||||
"""Sets the item on the specified position."""
|
||||
if isinstance(key, slice):
|
||||
if key.step is not None:
|
||||
raise ValueError('Extended slices not supported')
|
||||
self._values[key] = map(self._type_checker.CheckValue, value)
|
||||
self._message_listener.Modified()
|
||||
else:
|
||||
self._values[key] = self._type_checker.CheckValue(value)
|
||||
self._message_listener.Modified()
|
||||
|
||||
def __delitem__(self, key: Union[int, slice]) -> None:
|
||||
"""Deletes the item at the specified position."""
|
||||
del self._values[key]
|
||||
self._message_listener.Modified()
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Compares the current instance with another one."""
|
||||
if self is other:
|
||||
return True
|
||||
# Special case for the same type which should be common and fast.
|
||||
if isinstance(other, self.__class__):
|
||||
return other._values == self._values
|
||||
# We are presumably comparing against some other sequence type.
|
||||
return other == self._values
|
||||
|
||||
def __deepcopy__(
|
||||
self,
|
||||
unused_memo: Any = None,
|
||||
) -> 'RepeatedScalarFieldContainer[_T]':
|
||||
clone = RepeatedScalarFieldContainer(
|
||||
copy.deepcopy(self._message_listener), self._type_checker)
|
||||
clone.MergeFrom(self)
|
||||
return clone
|
||||
|
||||
def __reduce__(self, **kwargs) -> NoReturn:
|
||||
raise pickle.PickleError(
|
||||
"Can't pickle repeated scalar fields, convert to list first")
|
||||
|
||||
|
||||
# TODO: Constrain T to be a subtype of Message.
|
||||
class RepeatedCompositeFieldContainer(BaseContainer[_T], MutableSequence[_T]):
|
||||
"""Simple, list-like container for holding repeated composite fields."""
|
||||
|
||||
# Disallows assignment to other attributes.
|
||||
__slots__ = ['_message_descriptor']
|
||||
|
||||
def __init__(self, message_listener: Any, message_descriptor: Any) -> None:
|
||||
"""
|
||||
Note that we pass in a descriptor instead of the generated directly,
|
||||
since at the time we construct a _RepeatedCompositeFieldContainer we
|
||||
haven't yet necessarily initialized the type that will be contained in the
|
||||
container.
|
||||
|
||||
Args:
|
||||
message_listener: A MessageListener implementation.
|
||||
The RepeatedCompositeFieldContainer will call this object's
|
||||
Modified() method when it is modified.
|
||||
message_descriptor: A Descriptor instance describing the protocol type
|
||||
that should be present in this container. We'll use the
|
||||
_concrete_class field of this descriptor when the client calls add().
|
||||
"""
|
||||
super().__init__(message_listener)
|
||||
self._message_descriptor = message_descriptor
|
||||
|
||||
def add(self, **kwargs: Any) -> _T:
|
||||
"""Adds a new element at the end of the list and returns it. Keyword
|
||||
arguments may be used to initialize the element.
|
||||
"""
|
||||
new_element = self._message_descriptor._concrete_class(**kwargs)
|
||||
new_element._SetListener(self._message_listener)
|
||||
self._values.append(new_element)
|
||||
if not self._message_listener.dirty:
|
||||
self._message_listener.Modified()
|
||||
return new_element
|
||||
|
||||
def append(self, value: _T) -> None:
|
||||
"""Appends one element by copying the message."""
|
||||
new_element = self._message_descriptor._concrete_class()
|
||||
new_element._SetListener(self._message_listener)
|
||||
new_element.CopyFrom(value)
|
||||
self._values.append(new_element)
|
||||
if not self._message_listener.dirty:
|
||||
self._message_listener.Modified()
|
||||
|
||||
def insert(self, key: int, value: _T) -> None:
|
||||
"""Inserts the item at the specified position by copying."""
|
||||
new_element = self._message_descriptor._concrete_class()
|
||||
new_element._SetListener(self._message_listener)
|
||||
new_element.CopyFrom(value)
|
||||
self._values.insert(key, new_element)
|
||||
if not self._message_listener.dirty:
|
||||
self._message_listener.Modified()
|
||||
|
||||
def extend(self, elem_seq: Iterable[_T]) -> None:
|
||||
"""Extends by appending the given sequence of elements of the same type
|
||||
|
||||
as this one, copying each individual message.
|
||||
"""
|
||||
message_class = self._message_descriptor._concrete_class
|
||||
listener = self._message_listener
|
||||
values = self._values
|
||||
for message in elem_seq:
|
||||
new_element = message_class()
|
||||
new_element._SetListener(listener)
|
||||
new_element.MergeFrom(message)
|
||||
values.append(new_element)
|
||||
listener.Modified()
|
||||
|
||||
def MergeFrom(
|
||||
self,
|
||||
other: Union['RepeatedCompositeFieldContainer[_T]', Iterable[_T]],
|
||||
) -> None:
|
||||
"""Appends the contents of another repeated field of the same type to this
|
||||
one, copying each individual message.
|
||||
"""
|
||||
self.extend(other)
|
||||
|
||||
def remove(self, elem: _T) -> None:
|
||||
"""Removes an item from the list. Similar to list.remove()."""
|
||||
self._values.remove(elem)
|
||||
self._message_listener.Modified()
|
||||
|
||||
def pop(self, key: Optional[int] = -1) -> _T:
|
||||
"""Removes and returns an item at a given index. Similar to list.pop()."""
|
||||
value = self._values[key]
|
||||
self.__delitem__(key)
|
||||
return value
|
||||
|
||||
@overload
|
||||
def __setitem__(self, key: int, value: _T) -> None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
|
||||
...
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
# This method is implemented to make RepeatedCompositeFieldContainer
|
||||
# structurally compatible with typing.MutableSequence. It is
|
||||
# otherwise unsupported and will always raise an error.
|
||||
raise TypeError(
|
||||
f'{self.__class__.__name__} object does not support item assignment')
|
||||
|
||||
def __delitem__(self, key: Union[int, slice]) -> None:
|
||||
"""Deletes the item at the specified position."""
|
||||
del self._values[key]
|
||||
self._message_listener.Modified()
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Compares the current instance with another one."""
|
||||
if self is other:
|
||||
return True
|
||||
if not isinstance(other, self.__class__):
|
||||
raise TypeError('Can only compare repeated composite fields against '
|
||||
'other repeated composite fields.')
|
||||
return self._values == other._values
|
||||
|
||||
|
||||
class ScalarMap(MutableMapping[_K, _V]):
|
||||
"""Simple, type-checked, dict-like container for holding repeated scalars."""
|
||||
|
||||
# Disallows assignment to other attributes.
|
||||
__slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener',
|
||||
'_entry_descriptor']
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_listener: Any,
|
||||
key_checker: Any,
|
||||
value_checker: Any,
|
||||
entry_descriptor: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
message_listener: A MessageListener implementation.
|
||||
The ScalarMap will call this object's Modified() method when it
|
||||
is modified.
|
||||
key_checker: A type_checkers.ValueChecker instance to run on keys
|
||||
inserted into this container.
|
||||
value_checker: A type_checkers.ValueChecker instance to run on values
|
||||
inserted into this container.
|
||||
entry_descriptor: The MessageDescriptor of a map entry: key and value.
|
||||
"""
|
||||
self._message_listener = message_listener
|
||||
self._key_checker = key_checker
|
||||
self._value_checker = value_checker
|
||||
self._entry_descriptor = entry_descriptor
|
||||
self._values = {}
|
||||
|
||||
def __getitem__(self, key: _K) -> _V:
|
||||
try:
|
||||
return self._values[key]
|
||||
except KeyError:
|
||||
key = self._key_checker.CheckValue(key)
|
||||
val = self._value_checker.DefaultValue()
|
||||
self._values[key] = val
|
||||
return val
|
||||
|
||||
def __contains__(self, item: _K) -> bool:
|
||||
# We check the key's type to match the strong-typing flavor of the API.
|
||||
# Also this makes it easier to match the behavior of the C++ implementation.
|
||||
self._key_checker.CheckValue(item)
|
||||
return item in self._values
|
||||
|
||||
@overload
|
||||
def get(self, key: _K) -> Optional[_V]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get(self, key: _K, default: _T) -> Union[_V, _T]:
|
||||
...
|
||||
|
||||
# We need to override this explicitly, because our defaultdict-like behavior
|
||||
# will make the default implementation (from our base class) always insert
|
||||
# the key.
|
||||
def get(self, key, default=None):
|
||||
if key in self:
|
||||
return self[key]
|
||||
else:
|
||||
return default
|
||||
|
||||
def __setitem__(self, key: _K, value: _V) -> _T:
|
||||
checked_key = self._key_checker.CheckValue(key)
|
||||
checked_value = self._value_checker.CheckValue(value)
|
||||
self._values[checked_key] = checked_value
|
||||
self._message_listener.Modified()
|
||||
|
||||
def __delitem__(self, key: _K) -> None:
|
||||
del self._values[key]
|
||||
self._message_listener.Modified()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._values)
|
||||
|
||||
def __iter__(self) -> Iterator[_K]:
|
||||
return iter(self._values)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(self._values)
|
||||
|
||||
def setdefault(self, key: _K, value: Optional[_V] = None) -> _V:
|
||||
if value == None:
|
||||
raise ValueError('The value for scalar map setdefault must be set.')
|
||||
if key not in self._values:
|
||||
self.__setitem__(key, value)
|
||||
return self[key]
|
||||
|
||||
def MergeFrom(self, other: 'ScalarMap[_K, _V]') -> None:
|
||||
self._values.update(other._values)
|
||||
self._message_listener.Modified()
|
||||
|
||||
def InvalidateIterators(self) -> None:
|
||||
# It appears that the only way to reliably invalidate iterators to
|
||||
# self._values is to ensure that its size changes.
|
||||
original = self._values
|
||||
self._values = original.copy()
|
||||
original[None] = None
|
||||
|
||||
# This is defined in the abstract base, but we can do it much more cheaply.
|
||||
def clear(self) -> None:
|
||||
self._values.clear()
|
||||
self._message_listener.Modified()
|
||||
|
||||
def GetEntryClass(self) -> Any:
|
||||
return self._entry_descriptor._concrete_class
|
||||
|
||||
|
||||
class MessageMap(MutableMapping[_K, _V]):
|
||||
"""Simple, type-checked, dict-like container for with submessage values."""
|
||||
|
||||
# Disallows assignment to other attributes.
|
||||
__slots__ = ['_key_checker', '_values', '_message_listener',
|
||||
'_message_descriptor', '_entry_descriptor']
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_listener: Any,
|
||||
message_descriptor: Any,
|
||||
key_checker: Any,
|
||||
entry_descriptor: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
message_listener: A MessageListener implementation.
|
||||
The ScalarMap will call this object's Modified() method when it
|
||||
is modified.
|
||||
key_checker: A type_checkers.ValueChecker instance to run on keys
|
||||
inserted into this container.
|
||||
value_checker: A type_checkers.ValueChecker instance to run on values
|
||||
inserted into this container.
|
||||
entry_descriptor: The MessageDescriptor of a map entry: key and value.
|
||||
"""
|
||||
self._message_listener = message_listener
|
||||
self._message_descriptor = message_descriptor
|
||||
self._key_checker = key_checker
|
||||
self._entry_descriptor = entry_descriptor
|
||||
self._values = {}
|
||||
|
||||
def __getitem__(self, key: _K) -> _V:
|
||||
key = self._key_checker.CheckValue(key)
|
||||
try:
|
||||
return self._values[key]
|
||||
except KeyError:
|
||||
new_element = self._message_descriptor._concrete_class()
|
||||
new_element._SetListener(self._message_listener)
|
||||
self._values[key] = new_element
|
||||
self._message_listener.Modified()
|
||||
return new_element
|
||||
|
||||
def get_or_create(self, key: _K) -> _V:
|
||||
"""get_or_create() is an alias for getitem (ie. map[key]).
|
||||
|
||||
Args:
|
||||
key: The key to get or create in the map.
|
||||
|
||||
This is useful in cases where you want to be explicit that the call is
|
||||
mutating the map. This can avoid lint errors for statements like this
|
||||
that otherwise would appear to be pointless statements:
|
||||
|
||||
msg.my_map[key]
|
||||
"""
|
||||
return self[key]
|
||||
|
||||
@overload
|
||||
def get(self, key: _K) -> Optional[_V]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get(self, key: _K, default: _T) -> Union[_V, _T]:
|
||||
...
|
||||
|
||||
# We need to override this explicitly, because our defaultdict-like behavior
|
||||
# will make the default implementation (from our base class) always insert
|
||||
# the key.
|
||||
def get(self, key, default=None):
|
||||
if key in self:
|
||||
return self[key]
|
||||
else:
|
||||
return default
|
||||
|
||||
def __contains__(self, item: _K) -> bool:
|
||||
item = self._key_checker.CheckValue(item)
|
||||
return item in self._values
|
||||
|
||||
def __setitem__(self, key: _K, value: _V) -> NoReturn:
|
||||
raise ValueError('May not set values directly, call my_map[key].foo = 5')
|
||||
|
||||
def __delitem__(self, key: _K) -> None:
|
||||
key = self._key_checker.CheckValue(key)
|
||||
del self._values[key]
|
||||
self._message_listener.Modified()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._values)
|
||||
|
||||
def __iter__(self) -> Iterator[_K]:
|
||||
return iter(self._values)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(self._values)
|
||||
|
||||
def setdefault(self, key: _K, value: Optional[_V] = None) -> _V:
|
||||
raise NotImplementedError(
|
||||
'Set message map value directly is not supported, call'
|
||||
' my_map[key].foo = 5'
|
||||
)
|
||||
|
||||
def MergeFrom(self, other: 'MessageMap[_K, _V]') -> None:
|
||||
# pylint: disable=protected-access
|
||||
for key in other._values:
|
||||
# According to documentation: "When parsing from the wire or when merging,
|
||||
# if there are duplicate map keys the last key seen is used".
|
||||
if key in self:
|
||||
del self[key]
|
||||
self[key].CopyFrom(other[key])
|
||||
# self._message_listener.Modified() not required here, because
|
||||
# mutations to submessages already propagate.
|
||||
|
||||
def InvalidateIterators(self) -> None:
|
||||
# It appears that the only way to reliably invalidate iterators to
|
||||
# self._values is to ensure that its size changes.
|
||||
original = self._values
|
||||
self._values = original.copy()
|
||||
original[None] = None
|
||||
|
||||
# This is defined in the abstract base, but we can do it much more cheaply.
|
||||
def clear(self) -> None:
|
||||
self._values.clear()
|
||||
self._message_listener.Modified()
|
||||
|
||||
def GetEntryClass(self) -> Any:
|
||||
return self._entry_descriptor._concrete_class
|
||||
|
||||
|
||||
class _UnknownField:
|
||||
"""A parsed unknown field."""
|
||||
|
||||
# Disallows assignment to other attributes.
|
||||
__slots__ = ['_field_number', '_wire_type', '_data']
|
||||
|
||||
def __init__(self, field_number, wire_type, data):
|
||||
self._field_number = field_number
|
||||
self._wire_type = wire_type
|
||||
self._data = data
|
||||
return
|
||||
|
||||
def __lt__(self, other):
|
||||
# pylint: disable=protected-access
|
||||
return self._field_number < other._field_number
|
||||
|
||||
def __eq__(self, other):
|
||||
if self is other:
|
||||
return True
|
||||
# pylint: disable=protected-access
|
||||
return (self._field_number == other._field_number and
|
||||
self._wire_type == other._wire_type and
|
||||
self._data == other._data)
|
||||
|
||||
|
||||
class UnknownFieldRef: # pylint: disable=missing-class-docstring
|
||||
|
||||
def __init__(self, parent, index):
|
||||
self._parent = parent
|
||||
self._index = index
|
||||
|
||||
def _check_valid(self):
|
||||
if not self._parent:
|
||||
raise ValueError('UnknownField does not exist. '
|
||||
'The parent message might be cleared.')
|
||||
if self._index >= len(self._parent):
|
||||
raise ValueError('UnknownField does not exist. '
|
||||
'The parent message might be cleared.')
|
||||
|
||||
@property
|
||||
def field_number(self):
|
||||
self._check_valid()
|
||||
# pylint: disable=protected-access
|
||||
return self._parent._internal_get(self._index)._field_number
|
||||
|
||||
@property
|
||||
def wire_type(self):
|
||||
self._check_valid()
|
||||
# pylint: disable=protected-access
|
||||
return self._parent._internal_get(self._index)._wire_type
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
self._check_valid()
|
||||
# pylint: disable=protected-access
|
||||
return self._parent._internal_get(self._index)._data
|
||||
|
||||
|
||||
class UnknownFieldSet:
|
||||
"""UnknownField container"""
|
||||
|
||||
# Disallows assignment to other attributes.
|
||||
__slots__ = ['_values']
|
||||
|
||||
def __init__(self):
|
||||
self._values = []
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self._values is None:
|
||||
raise ValueError('UnknownFields does not exist. '
|
||||
'The parent message might be cleared.')
|
||||
size = len(self._values)
|
||||
if index < 0:
|
||||
index += size
|
||||
if index < 0 or index >= size:
|
||||
raise IndexError('index %d out of range'.index)
|
||||
|
||||
return UnknownFieldRef(self, index)
|
||||
|
||||
def _internal_get(self, index):
|
||||
return self._values[index]
|
||||
|
||||
def __len__(self):
|
||||
if self._values is None:
|
||||
raise ValueError('UnknownFields does not exist. '
|
||||
'The parent message might be cleared.')
|
||||
return len(self._values)
|
||||
|
||||
def _add(self, field_number, wire_type, data):
|
||||
unknown_field = _UnknownField(field_number, wire_type, data)
|
||||
self._values.append(unknown_field)
|
||||
return unknown_field
|
||||
|
||||
def __iter__(self):
|
||||
for i in range(len(self)):
|
||||
yield UnknownFieldRef(self, i)
|
||||
|
||||
def _extend(self, other):
|
||||
if other is None:
|
||||
return
|
||||
# pylint: disable=protected-access
|
||||
self._values.extend(other._values)
|
||||
|
||||
def __eq__(self, other):
|
||||
if self is other:
|
||||
return True
|
||||
# Sort unknown fields because their order shouldn't
|
||||
# affect equality test.
|
||||
values = list(self._values)
|
||||
if other is None:
|
||||
return not values
|
||||
values.sort()
|
||||
# pylint: disable=protected-access
|
||||
other_values = sorted(other._values)
|
||||
return values == other_values
|
||||
|
||||
def _clear(self):
|
||||
for value in self._values:
|
||||
# pylint: disable=protected-access
|
||||
if isinstance(value._data, UnknownFieldSet):
|
||||
value._data._clear() # pylint: disable=protected-access
|
||||
self._values = None
|
||||
983
.venv/lib/python3.10/site-packages/google/protobuf/internal/decoder.py
Executable file
983
.venv/lib/python3.10/site-packages/google/protobuf/internal/decoder.py
Executable file
@@ -0,0 +1,983 @@
|
||||
# Protocol Buffers - Google's data interchange format
|
||||
# Copyright 2008 Google Inc. All rights reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style
|
||||
# license that can be found in the LICENSE file or at
|
||||
# https://developers.google.com/open-source/licenses/bsd
|
||||
|
||||
"""Code for decoding protocol buffer primitives.
|
||||
|
||||
This code is very similar to encoder.py -- read the docs for that module first.
|
||||
|
||||
A "decoder" is a function with the signature:
|
||||
Decode(buffer, pos, end, message, field_dict)
|
||||
The arguments are:
|
||||
buffer: The string containing the encoded message.
|
||||
pos: The current position in the string.
|
||||
end: The position in the string where the current message ends. May be
|
||||
less than len(buffer) if we're reading a sub-message.
|
||||
message: The message object into which we're parsing.
|
||||
field_dict: message._fields (avoids a hashtable lookup).
|
||||
The decoder reads the field and stores it into field_dict, returning the new
|
||||
buffer position. A decoder for a repeated field may proactively decode all of
|
||||
the elements of that field, if they appear consecutively.
|
||||
|
||||
Note that decoders may throw any of the following:
|
||||
IndexError: Indicates a truncated message.
|
||||
struct.error: Unpacking of a fixed-width field failed.
|
||||
message.DecodeError: Other errors.
|
||||
|
||||
Decoders are expected to raise an exception if they are called with pos > end.
|
||||
This allows callers to be lax about bounds checking: it's fineto read past
|
||||
"end" as long as you are sure that someone else will notice and throw an
|
||||
exception later on.
|
||||
|
||||
Something up the call stack is expected to catch IndexError and struct.error
|
||||
and convert them to message.DecodeError.
|
||||
|
||||
Decoders are constructed using decoder constructors with the signature:
|
||||
MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
|
||||
The arguments are:
|
||||
field_number: The field number of the field we want to decode.
|
||||
is_repeated: Is the field a repeated field? (bool)
|
||||
is_packed: Is the field a packed field? (bool)
|
||||
key: The key to use when looking up the field within field_dict.
|
||||
(This is actually the FieldDescriptor but nothing in this
|
||||
file should depend on that.)
|
||||
new_default: A function which takes a message object as a parameter and
|
||||
returns a new instance of the default value for this field.
|
||||
(This is called for repeated fields and sub-messages, when an
|
||||
instance does not already exist.)
|
||||
|
||||
As with encoders, we define a decoder constructor for every type of field.
|
||||
Then, for every field of every message class we construct an actual decoder.
|
||||
That decoder goes into a dict indexed by tag, so when we decode a message
|
||||
we repeatedly read a tag, look up the corresponding decoder, and invoke it.
|
||||
"""
|
||||
|
||||
__author__ = 'kenton@google.com (Kenton Varda)'
|
||||
|
||||
import math
|
||||
import numbers
|
||||
import struct
|
||||
|
||||
from google.protobuf import message
|
||||
from google.protobuf.internal import containers
|
||||
from google.protobuf.internal import encoder
|
||||
from google.protobuf.internal import wire_format
|
||||
|
||||
|
||||
# This is not for optimization, but rather to avoid conflicts with local
|
||||
# variables named "message".
|
||||
_DecodeError = message.DecodeError
|
||||
|
||||
|
||||
def IsDefaultScalarValue(value):
|
||||
"""Returns whether or not a scalar value is the default value of its type.
|
||||
|
||||
Specifically, this should be used to determine presence of implicit-presence
|
||||
fields, where we disallow custom defaults.
|
||||
|
||||
Args:
|
||||
value: A scalar value to check.
|
||||
|
||||
Returns:
|
||||
True if the value is equivalent to a default value, False otherwise.
|
||||
"""
|
||||
if isinstance(value, numbers.Number) and math.copysign(1.0, value) < 0:
|
||||
# Special case for negative zero, where "truthiness" fails to give the right
|
||||
# answer.
|
||||
return False
|
||||
|
||||
# Normally, we can just use Python's boolean conversion.
|
||||
return not value
|
||||
|
||||
|
||||
def _VarintDecoder(mask, result_type):
|
||||
"""Return an encoder for a basic varint value (does not include tag).
|
||||
|
||||
Decoded values will be bitwise-anded with the given mask before being
|
||||
returned, e.g. to limit them to 32 bits. The returned decoder does not
|
||||
take the usual "end" parameter -- the caller is expected to do bounds checking
|
||||
after the fact (often the caller can defer such checking until later). The
|
||||
decoder returns a (value, new_pos) pair.
|
||||
"""
|
||||
|
||||
def DecodeVarint(buffer, pos: int=None):
|
||||
result = 0
|
||||
shift = 0
|
||||
while 1:
|
||||
if pos is None:
|
||||
# Read from BytesIO
|
||||
try:
|
||||
b = buffer.read(1)[0]
|
||||
except IndexError as e:
|
||||
if shift == 0:
|
||||
# End of BytesIO.
|
||||
return None
|
||||
else:
|
||||
raise ValueError('Fail to read varint %s' % str(e))
|
||||
else:
|
||||
b = buffer[pos]
|
||||
pos += 1
|
||||
result |= ((b & 0x7f) << shift)
|
||||
if not (b & 0x80):
|
||||
result &= mask
|
||||
result = result_type(result)
|
||||
return result if pos is None else (result, pos)
|
||||
shift += 7
|
||||
if shift >= 64:
|
||||
raise _DecodeError('Too many bytes when decoding varint.')
|
||||
|
||||
return DecodeVarint
|
||||
|
||||
|
||||
def _SignedVarintDecoder(bits, result_type):
|
||||
"""Like _VarintDecoder() but decodes signed values."""
|
||||
|
||||
signbit = 1 << (bits - 1)
|
||||
mask = (1 << bits) - 1
|
||||
|
||||
def DecodeVarint(buffer, pos):
|
||||
result = 0
|
||||
shift = 0
|
||||
while 1:
|
||||
b = buffer[pos]
|
||||
result |= ((b & 0x7f) << shift)
|
||||
pos += 1
|
||||
if not (b & 0x80):
|
||||
result &= mask
|
||||
result = (result ^ signbit) - signbit
|
||||
result = result_type(result)
|
||||
return (result, pos)
|
||||
shift += 7
|
||||
if shift >= 64:
|
||||
raise _DecodeError('Too many bytes when decoding varint.')
|
||||
return DecodeVarint
|
||||
|
||||
# All 32-bit and 64-bit values are represented as int.
|
||||
_DecodeVarint = _VarintDecoder((1 << 64) - 1, int)
|
||||
_DecodeSignedVarint = _SignedVarintDecoder(64, int)
|
||||
|
||||
# Use these versions for values which must be limited to 32 bits.
|
||||
_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int)
|
||||
_DecodeSignedVarint32 = _SignedVarintDecoder(32, int)
|
||||
|
||||
|
||||
def ReadTag(buffer, pos):
|
||||
"""Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple.
|
||||
|
||||
We return the raw bytes of the tag rather than decoding them. The raw
|
||||
bytes can then be used to look up the proper decoder. This effectively allows
|
||||
us to trade some work that would be done in pure-python (decoding a varint)
|
||||
for work that is done in C (searching for a byte string in a hash table).
|
||||
In a low-level language it would be much cheaper to decode the varint and
|
||||
use that, but not in Python.
|
||||
|
||||
Args:
|
||||
buffer: memoryview object of the encoded bytes
|
||||
pos: int of the current position to start from
|
||||
|
||||
Returns:
|
||||
Tuple[bytes, int] of the tag data and new position.
|
||||
"""
|
||||
start = pos
|
||||
while buffer[pos] & 0x80:
|
||||
pos += 1
|
||||
pos += 1
|
||||
|
||||
tag_bytes = buffer[start:pos].tobytes()
|
||||
return tag_bytes, pos
|
||||
|
||||
|
||||
def DecodeTag(tag_bytes):
|
||||
"""Decode a tag from the bytes.
|
||||
|
||||
Args:
|
||||
tag_bytes: the bytes of the tag
|
||||
|
||||
Returns:
|
||||
Tuple[int, int] of the tag field number and wire type.
|
||||
"""
|
||||
(tag, _) = _DecodeVarint(tag_bytes, 0)
|
||||
return wire_format.UnpackTag(tag)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
|
||||
def _SimpleDecoder(wire_type, decode_value):
|
||||
"""Return a constructor for a decoder for fields of a particular type.
|
||||
|
||||
Args:
|
||||
wire_type: The field's wire type.
|
||||
decode_value: A function which decodes an individual value, e.g.
|
||||
_DecodeVarint()
|
||||
"""
|
||||
|
||||
def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
|
||||
clear_if_default=False):
|
||||
if is_packed:
|
||||
local_DecodeVarint = _DecodeVarint
|
||||
def DecodePackedField(buffer, pos, end, message, field_dict):
|
||||
value = field_dict.get(key)
|
||||
if value is None:
|
||||
value = field_dict.setdefault(key, new_default(message))
|
||||
(endpoint, pos) = local_DecodeVarint(buffer, pos)
|
||||
endpoint += pos
|
||||
if endpoint > end:
|
||||
raise _DecodeError('Truncated message.')
|
||||
while pos < endpoint:
|
||||
(element, pos) = decode_value(buffer, pos)
|
||||
value.append(element)
|
||||
if pos > endpoint:
|
||||
del value[-1] # Discard corrupt value.
|
||||
raise _DecodeError('Packed element was truncated.')
|
||||
return pos
|
||||
return DecodePackedField
|
||||
elif is_repeated:
|
||||
tag_bytes = encoder.TagBytes(field_number, wire_type)
|
||||
tag_len = len(tag_bytes)
|
||||
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
|
||||
value = field_dict.get(key)
|
||||
if value is None:
|
||||
value = field_dict.setdefault(key, new_default(message))
|
||||
while 1:
|
||||
(element, new_pos) = decode_value(buffer, pos)
|
||||
value.append(element)
|
||||
# Predict that the next tag is another copy of the same repeated
|
||||
# field.
|
||||
pos = new_pos + tag_len
|
||||
if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
|
||||
# Prediction failed. Return.
|
||||
if new_pos > end:
|
||||
raise _DecodeError('Truncated message.')
|
||||
return new_pos
|
||||
return DecodeRepeatedField
|
||||
else:
|
||||
def DecodeField(buffer, pos, end, message, field_dict):
|
||||
(new_value, pos) = decode_value(buffer, pos)
|
||||
if pos > end:
|
||||
raise _DecodeError('Truncated message.')
|
||||
if clear_if_default and IsDefaultScalarValue(new_value):
|
||||
field_dict.pop(key, None)
|
||||
else:
|
||||
field_dict[key] = new_value
|
||||
return pos
|
||||
return DecodeField
|
||||
|
||||
return SpecificDecoder
|
||||
|
||||
|
||||
def _ModifiedDecoder(wire_type, decode_value, modify_value):
|
||||
"""Like SimpleDecoder but additionally invokes modify_value on every value
|
||||
before storing it. Usually modify_value is ZigZagDecode.
|
||||
"""
|
||||
|
||||
# Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
|
||||
# not enough to make a significant difference.
|
||||
|
||||
def InnerDecode(buffer, pos):
|
||||
(result, new_pos) = decode_value(buffer, pos)
|
||||
return (modify_value(result), new_pos)
|
||||
return _SimpleDecoder(wire_type, InnerDecode)
|
||||
|
||||
|
||||
def _StructPackDecoder(wire_type, format):
|
||||
"""Return a constructor for a decoder for a fixed-width field.
|
||||
|
||||
Args:
|
||||
wire_type: The field's wire type.
|
||||
format: The format string to pass to struct.unpack().
|
||||
"""
|
||||
|
||||
value_size = struct.calcsize(format)
|
||||
local_unpack = struct.unpack
|
||||
|
||||
# Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
|
||||
# not enough to make a significant difference.
|
||||
|
||||
# Note that we expect someone up-stack to catch struct.error and convert
|
||||
# it to _DecodeError -- this way we don't have to set up exception-
|
||||
# handling blocks every time we parse one value.
|
||||
|
||||
def InnerDecode(buffer, pos):
|
||||
new_pos = pos + value_size
|
||||
result = local_unpack(format, buffer[pos:new_pos])[0]
|
||||
return (result, new_pos)
|
||||
return _SimpleDecoder(wire_type, InnerDecode)
|
||||
|
||||
|
||||
def _FloatDecoder():
|
||||
"""Returns a decoder for a float field.
|
||||
|
||||
This code works around a bug in struct.unpack for non-finite 32-bit
|
||||
floating-point values.
|
||||
"""
|
||||
|
||||
local_unpack = struct.unpack
|
||||
|
||||
def InnerDecode(buffer, pos):
|
||||
"""Decode serialized float to a float and new position.
|
||||
|
||||
Args:
|
||||
buffer: memoryview of the serialized bytes
|
||||
pos: int, position in the memory view to start at.
|
||||
|
||||
Returns:
|
||||
Tuple[float, int] of the deserialized float value and new position
|
||||
in the serialized data.
|
||||
"""
|
||||
# We expect a 32-bit value in little-endian byte order. Bit 1 is the sign
|
||||
# bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
|
||||
new_pos = pos + 4
|
||||
float_bytes = buffer[pos:new_pos].tobytes()
|
||||
|
||||
# If this value has all its exponent bits set, then it's non-finite.
|
||||
# In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
|
||||
# To avoid that, we parse it specially.
|
||||
if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'):
|
||||
# If at least one significand bit is set...
|
||||
if float_bytes[0:3] != b'\x00\x00\x80':
|
||||
return (math.nan, new_pos)
|
||||
# If sign bit is set...
|
||||
if float_bytes[3:4] == b'\xFF':
|
||||
return (-math.inf, new_pos)
|
||||
return (math.inf, new_pos)
|
||||
|
||||
# Note that we expect someone up-stack to catch struct.error and convert
|
||||
# it to _DecodeError -- this way we don't have to set up exception-
|
||||
# handling blocks every time we parse one value.
|
||||
result = local_unpack('<f', float_bytes)[0]
|
||||
return (result, new_pos)
|
||||
return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
|
||||
|
||||
|
||||
def _DoubleDecoder():
|
||||
"""Returns a decoder for a double field.
|
||||
|
||||
This code works around a bug in struct.unpack for not-a-number.
|
||||
"""
|
||||
|
||||
local_unpack = struct.unpack
|
||||
|
||||
def InnerDecode(buffer, pos):
|
||||
"""Decode serialized double to a double and new position.
|
||||
|
||||
Args:
|
||||
buffer: memoryview of the serialized bytes.
|
||||
pos: int, position in the memory view to start at.
|
||||
|
||||
Returns:
|
||||
Tuple[float, int] of the decoded double value and new position
|
||||
in the serialized data.
|
||||
"""
|
||||
# We expect a 64-bit value in little-endian byte order. Bit 1 is the sign
|
||||
# bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
|
||||
new_pos = pos + 8
|
||||
double_bytes = buffer[pos:new_pos].tobytes()
|
||||
|
||||
# If this value has all its exponent bits set and at least one significand
|
||||
# bit set, it's not a number. In Python 2.4, struct.unpack will treat it
|
||||
# as inf or -inf. To avoid that, we treat it specially.
|
||||
if ((double_bytes[7:8] in b'\x7F\xFF')
|
||||
and (double_bytes[6:7] >= b'\xF0')
|
||||
and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')):
|
||||
return (math.nan, new_pos)
|
||||
|
||||
# Note that we expect someone up-stack to catch struct.error and convert
|
||||
# it to _DecodeError -- this way we don't have to set up exception-
|
||||
# handling blocks every time we parse one value.
|
||||
result = local_unpack('<d', double_bytes)[0]
|
||||
return (result, new_pos)
|
||||
return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
|
||||
|
||||
|
||||
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
|
||||
clear_if_default=False):
|
||||
"""Returns a decoder for enum field."""
|
||||
enum_type = key.enum_type
|
||||
if is_packed:
|
||||
local_DecodeVarint = _DecodeVarint
|
||||
def DecodePackedField(buffer, pos, end, message, field_dict):
|
||||
"""Decode serialized packed enum to its value and a new position.
|
||||
|
||||
Args:
|
||||
buffer: memoryview of the serialized bytes.
|
||||
pos: int, position in the memory view to start at.
|
||||
end: int, end position of serialized data
|
||||
message: Message object to store unknown fields in
|
||||
field_dict: Map[Descriptor, Any] to store decoded values in.
|
||||
|
||||
Returns:
|
||||
int, new position in serialized data.
|
||||
"""
|
||||
value = field_dict.get(key)
|
||||
if value is None:
|
||||
value = field_dict.setdefault(key, new_default(message))
|
||||
(endpoint, pos) = local_DecodeVarint(buffer, pos)
|
||||
endpoint += pos
|
||||
if endpoint > end:
|
||||
raise _DecodeError('Truncated message.')
|
||||
while pos < endpoint:
|
||||
value_start_pos = pos
|
||||
(element, pos) = _DecodeSignedVarint32(buffer, pos)
|
||||
# pylint: disable=protected-access
|
||||
if element in enum_type.values_by_number:
|
||||
value.append(element)
|
||||
else:
|
||||
if not message._unknown_fields:
|
||||
message._unknown_fields = []
|
||||
tag_bytes = encoder.TagBytes(field_number,
|
||||
wire_format.WIRETYPE_VARINT)
|
||||
|
||||
message._unknown_fields.append(
|
||||
(tag_bytes, buffer[value_start_pos:pos].tobytes()))
|
||||
# pylint: enable=protected-access
|
||||
if pos > endpoint:
|
||||
if element in enum_type.values_by_number:
|
||||
del value[-1] # Discard corrupt value.
|
||||
else:
|
||||
del message._unknown_fields[-1]
|
||||
# pylint: enable=protected-access
|
||||
raise _DecodeError('Packed element was truncated.')
|
||||
return pos
|
||||
return DecodePackedField
|
||||
elif is_repeated:
|
||||
tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
|
||||
tag_len = len(tag_bytes)
|
||||
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
|
||||
"""Decode serialized repeated enum to its value and a new position.
|
||||
|
||||
Args:
|
||||
buffer: memoryview of the serialized bytes.
|
||||
pos: int, position in the memory view to start at.
|
||||
end: int, end position of serialized data
|
||||
message: Message object to store unknown fields in
|
||||
field_dict: Map[Descriptor, Any] to store decoded values in.
|
||||
|
||||
Returns:
|
||||
int, new position in serialized data.
|
||||
"""
|
||||
value = field_dict.get(key)
|
||||
if value is None:
|
||||
value = field_dict.setdefault(key, new_default(message))
|
||||
while 1:
|
||||
(element, new_pos) = _DecodeSignedVarint32(buffer, pos)
|
||||
# pylint: disable=protected-access
|
||||
if element in enum_type.values_by_number:
|
||||
value.append(element)
|
||||
else:
|
||||
if not message._unknown_fields:
|
||||
message._unknown_fields = []
|
||||
message._unknown_fields.append(
|
||||
(tag_bytes, buffer[pos:new_pos].tobytes()))
|
||||
# pylint: enable=protected-access
|
||||
# Predict that the next tag is another copy of the same repeated
|
||||
# field.
|
||||
pos = new_pos + tag_len
|
||||
if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
|
||||
# Prediction failed. Return.
|
||||
if new_pos > end:
|
||||
raise _DecodeError('Truncated message.')
|
||||
return new_pos
|
||||
return DecodeRepeatedField
|
||||
else:
|
||||
def DecodeField(buffer, pos, end, message, field_dict):
|
||||
"""Decode serialized repeated enum to its value and a new position.
|
||||
|
||||
Args:
|
||||
buffer: memoryview of the serialized bytes.
|
||||
pos: int, position in the memory view to start at.
|
||||
end: int, end position of serialized data
|
||||
message: Message object to store unknown fields in
|
||||
field_dict: Map[Descriptor, Any] to store decoded values in.
|
||||
|
||||
Returns:
|
||||
int, new position in serialized data.
|
||||
"""
|
||||
value_start_pos = pos
|
||||
(enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
|
||||
if pos > end:
|
||||
raise _DecodeError('Truncated message.')
|
||||
if clear_if_default and IsDefaultScalarValue(enum_value):
|
||||
field_dict.pop(key, None)
|
||||
return pos
|
||||
# pylint: disable=protected-access
|
||||
if enum_value in enum_type.values_by_number:
|
||||
field_dict[key] = enum_value
|
||||
else:
|
||||
if not message._unknown_fields:
|
||||
message._unknown_fields = []
|
||||
tag_bytes = encoder.TagBytes(field_number,
|
||||
wire_format.WIRETYPE_VARINT)
|
||||
message._unknown_fields.append(
|
||||
(tag_bytes, buffer[value_start_pos:pos].tobytes()))
|
||||
# pylint: enable=protected-access
|
||||
return pos
|
||||
return DecodeField
|
||||
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
|
||||
Int32Decoder = _SimpleDecoder(
|
||||
wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)
|
||||
|
||||
Int64Decoder = _SimpleDecoder(
|
||||
wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)
|
||||
|
||||
UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
|
||||
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)
|
||||
|
||||
SInt32Decoder = _ModifiedDecoder(
|
||||
wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
|
||||
SInt64Decoder = _ModifiedDecoder(
|
||||
wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)
|
||||
|
||||
# Note that Python conveniently guarantees that when using the '<' prefix on
|
||||
# formats, they will also have the same size across all platforms (as opposed
|
||||
# to without the prefix, where their sizes depend on the C compiler's basic
|
||||
# type sizes).
|
||||
Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
|
||||
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
|
||||
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
|
||||
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
|
||||
FloatDecoder = _FloatDecoder()
|
||||
DoubleDecoder = _DoubleDecoder()
|
||||
|
||||
BoolDecoder = _ModifiedDecoder(
|
||||
wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
|
||||
|
||||
|
||||
def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
|
||||
clear_if_default=False):
|
||||
"""Returns a decoder for a string field."""
|
||||
|
||||
local_DecodeVarint = _DecodeVarint
|
||||
|
||||
def _ConvertToUnicode(memview):
|
||||
"""Convert byte to unicode."""
|
||||
byte_str = memview.tobytes()
|
||||
try:
|
||||
value = str(byte_str, 'utf-8')
|
||||
except UnicodeDecodeError as e:
|
||||
# add more information to the error message and re-raise it.
|
||||
e.reason = '%s in field: %s' % (e, key.full_name)
|
||||
raise
|
||||
|
||||
return value
|
||||
|
||||
assert not is_packed
|
||||
if is_repeated:
|
||||
tag_bytes = encoder.TagBytes(field_number,
|
||||
wire_format.WIRETYPE_LENGTH_DELIMITED)
|
||||
tag_len = len(tag_bytes)
|
||||
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
|
||||
value = field_dict.get(key)
|
||||
if value is None:
|
||||
value = field_dict.setdefault(key, new_default(message))
|
||||
while 1:
|
||||
(size, pos) = local_DecodeVarint(buffer, pos)
|
||||
new_pos = pos + size
|
||||
if new_pos > end:
|
||||
raise _DecodeError('Truncated string.')
|
||||
value.append(_ConvertToUnicode(buffer[pos:new_pos]))
|
||||
# Predict that the next tag is another copy of the same repeated field.
|
||||
pos = new_pos + tag_len
|
||||
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
|
||||
# Prediction failed. Return.
|
||||
return new_pos
|
||||
return DecodeRepeatedField
|
||||
else:
|
||||
def DecodeField(buffer, pos, end, message, field_dict):
|
||||
(size, pos) = local_DecodeVarint(buffer, pos)
|
||||
new_pos = pos + size
|
||||
if new_pos > end:
|
||||
raise _DecodeError('Truncated string.')
|
||||
if clear_if_default and IsDefaultScalarValue(size):
|
||||
field_dict.pop(key, None)
|
||||
else:
|
||||
field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
|
||||
return new_pos
|
||||
return DecodeField
|
||||
|
||||
|
||||
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
|
||||
clear_if_default=False):
|
||||
"""Returns a decoder for a bytes field."""
|
||||
|
||||
local_DecodeVarint = _DecodeVarint
|
||||
|
||||
assert not is_packed
|
||||
if is_repeated:
|
||||
tag_bytes = encoder.TagBytes(field_number,
|
||||
wire_format.WIRETYPE_LENGTH_DELIMITED)
|
||||
tag_len = len(tag_bytes)
|
||||
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
|
||||
value = field_dict.get(key)
|
||||
if value is None:
|
||||
value = field_dict.setdefault(key, new_default(message))
|
||||
while 1:
|
||||
(size, pos) = local_DecodeVarint(buffer, pos)
|
||||
new_pos = pos + size
|
||||
if new_pos > end:
|
||||
raise _DecodeError('Truncated string.')
|
||||
value.append(buffer[pos:new_pos].tobytes())
|
||||
# Predict that the next tag is another copy of the same repeated field.
|
||||
pos = new_pos + tag_len
|
||||
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
|
||||
# Prediction failed. Return.
|
||||
return new_pos
|
||||
return DecodeRepeatedField
|
||||
else:
|
||||
def DecodeField(buffer, pos, end, message, field_dict):
|
||||
(size, pos) = local_DecodeVarint(buffer, pos)
|
||||
new_pos = pos + size
|
||||
if new_pos > end:
|
||||
raise _DecodeError('Truncated string.')
|
||||
if clear_if_default and IsDefaultScalarValue(size):
|
||||
field_dict.pop(key, None)
|
||||
else:
|
||||
field_dict[key] = buffer[pos:new_pos].tobytes()
|
||||
return new_pos
|
||||
return DecodeField
|
||||
|
||||
|
||||
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
|
||||
"""Returns a decoder for a group field."""
|
||||
|
||||
end_tag_bytes = encoder.TagBytes(field_number,
|
||||
wire_format.WIRETYPE_END_GROUP)
|
||||
end_tag_len = len(end_tag_bytes)
|
||||
|
||||
assert not is_packed
|
||||
if is_repeated:
|
||||
tag_bytes = encoder.TagBytes(field_number,
|
||||
wire_format.WIRETYPE_START_GROUP)
|
||||
tag_len = len(tag_bytes)
|
||||
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
|
||||
value = field_dict.get(key)
|
||||
if value is None:
|
||||
value = field_dict.setdefault(key, new_default(message))
|
||||
while 1:
|
||||
value = field_dict.get(key)
|
||||
if value is None:
|
||||
value = field_dict.setdefault(key, new_default(message))
|
||||
# Read sub-message.
|
||||
pos = value.add()._InternalParse(buffer, pos, end)
|
||||
# Read end tag.
|
||||
new_pos = pos+end_tag_len
|
||||
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
|
||||
raise _DecodeError('Missing group end tag.')
|
||||
# Predict that the next tag is another copy of the same repeated field.
|
||||
pos = new_pos + tag_len
|
||||
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
|
||||
# Prediction failed. Return.
|
||||
return new_pos
|
||||
return DecodeRepeatedField
|
||||
else:
|
||||
def DecodeField(buffer, pos, end, message, field_dict):
|
||||
value = field_dict.get(key)
|
||||
if value is None:
|
||||
value = field_dict.setdefault(key, new_default(message))
|
||||
# Read sub-message.
|
||||
pos = value._InternalParse(buffer, pos, end)
|
||||
# Read end tag.
|
||||
new_pos = pos+end_tag_len
|
||||
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
|
||||
raise _DecodeError('Missing group end tag.')
|
||||
return new_pos
|
||||
return DecodeField
|
||||
|
||||
|
||||
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
|
||||
"""Returns a decoder for a message field."""
|
||||
|
||||
local_DecodeVarint = _DecodeVarint
|
||||
|
||||
assert not is_packed
|
||||
if is_repeated:
|
||||
tag_bytes = encoder.TagBytes(field_number,
|
||||
wire_format.WIRETYPE_LENGTH_DELIMITED)
|
||||
tag_len = len(tag_bytes)
|
||||
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
|
||||
value = field_dict.get(key)
|
||||
if value is None:
|
||||
value = field_dict.setdefault(key, new_default(message))
|
||||
while 1:
|
||||
# Read length.
|
||||
(size, pos) = local_DecodeVarint(buffer, pos)
|
||||
new_pos = pos + size
|
||||
if new_pos > end:
|
||||
raise _DecodeError('Truncated message.')
|
||||
# Read sub-message.
|
||||
if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
|
||||
# The only reason _InternalParse would return early is if it
|
||||
# encountered an end-group tag.
|
||||
raise _DecodeError('Unexpected end-group tag.')
|
||||
# Predict that the next tag is another copy of the same repeated field.
|
||||
pos = new_pos + tag_len
|
||||
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
|
||||
# Prediction failed. Return.
|
||||
return new_pos
|
||||
return DecodeRepeatedField
|
||||
else:
|
||||
def DecodeField(buffer, pos, end, message, field_dict):
|
||||
value = field_dict.get(key)
|
||||
if value is None:
|
||||
value = field_dict.setdefault(key, new_default(message))
|
||||
# Read length.
|
||||
(size, pos) = local_DecodeVarint(buffer, pos)
|
||||
new_pos = pos + size
|
||||
if new_pos > end:
|
||||
raise _DecodeError('Truncated message.')
|
||||
# Read sub-message.
|
||||
if value._InternalParse(buffer, pos, new_pos) != new_pos:
|
||||
# The only reason _InternalParse would return early is if it encountered
|
||||
# an end-group tag.
|
||||
raise _DecodeError('Unexpected end-group tag.')
|
||||
return new_pos
|
||||
return DecodeField
|
||||
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)
|
||||
|
||||
def MessageSetItemDecoder(descriptor):
|
||||
"""Returns a decoder for a MessageSet item.
|
||||
|
||||
The parameter is the message Descriptor.
|
||||
|
||||
The message set message looks like this:
|
||||
message MessageSet {
|
||||
repeated group Item = 1 {
|
||||
required int32 type_id = 2;
|
||||
required string message = 3;
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
|
||||
message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
|
||||
item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)
|
||||
|
||||
local_ReadTag = ReadTag
|
||||
local_DecodeVarint = _DecodeVarint
|
||||
|
||||
def DecodeItem(buffer, pos, end, message, field_dict):
|
||||
"""Decode serialized message set to its value and new position.
|
||||
|
||||
Args:
|
||||
buffer: memoryview of the serialized bytes.
|
||||
pos: int, position in the memory view to start at.
|
||||
end: int, end position of serialized data
|
||||
message: Message object to store unknown fields in
|
||||
field_dict: Map[Descriptor, Any] to store decoded values in.
|
||||
|
||||
Returns:
|
||||
int, new position in serialized data.
|
||||
"""
|
||||
message_set_item_start = pos
|
||||
type_id = -1
|
||||
message_start = -1
|
||||
message_end = -1
|
||||
|
||||
# Technically, type_id and message can appear in any order, so we need
|
||||
# a little loop here.
|
||||
while 1:
|
||||
(tag_bytes, pos) = local_ReadTag(buffer, pos)
|
||||
if tag_bytes == type_id_tag_bytes:
|
||||
(type_id, pos) = local_DecodeVarint(buffer, pos)
|
||||
elif tag_bytes == message_tag_bytes:
|
||||
(size, message_start) = local_DecodeVarint(buffer, pos)
|
||||
pos = message_end = message_start + size
|
||||
elif tag_bytes == item_end_tag_bytes:
|
||||
break
|
||||
else:
|
||||
field_number, wire_type = DecodeTag(tag_bytes)
|
||||
_, pos = _DecodeUnknownField(buffer, pos, end, field_number, wire_type)
|
||||
if pos == -1:
|
||||
raise _DecodeError('Unexpected end-group tag.')
|
||||
|
||||
if pos > end:
|
||||
raise _DecodeError('Truncated message.')
|
||||
|
||||
if type_id == -1:
|
||||
raise _DecodeError('MessageSet item missing type_id.')
|
||||
if message_start == -1:
|
||||
raise _DecodeError('MessageSet item missing message.')
|
||||
|
||||
extension = message.Extensions._FindExtensionByNumber(type_id)
|
||||
# pylint: disable=protected-access
|
||||
if extension is not None:
|
||||
value = field_dict.get(extension)
|
||||
if value is None:
|
||||
message_type = extension.message_type
|
||||
if not hasattr(message_type, '_concrete_class'):
|
||||
message_factory.GetMessageClass(message_type)
|
||||
value = field_dict.setdefault(
|
||||
extension, message_type._concrete_class())
|
||||
if value._InternalParse(buffer, message_start,message_end) != message_end:
|
||||
# The only reason _InternalParse would return early is if it encountered
|
||||
# an end-group tag.
|
||||
raise _DecodeError('Unexpected end-group tag.')
|
||||
else:
|
||||
if not message._unknown_fields:
|
||||
message._unknown_fields = []
|
||||
message._unknown_fields.append(
|
||||
(MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
|
||||
# pylint: enable=protected-access
|
||||
|
||||
return pos
|
||||
|
||||
return DecodeItem
|
||||
|
||||
|
||||
def UnknownMessageSetItemDecoder():
|
||||
"""Returns a decoder for a Unknown MessageSet item."""
|
||||
|
||||
type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
|
||||
message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
|
||||
item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)
|
||||
|
||||
def DecodeUnknownItem(buffer):
|
||||
pos = 0
|
||||
end = len(buffer)
|
||||
message_start = -1
|
||||
message_end = -1
|
||||
while 1:
|
||||
(tag_bytes, pos) = ReadTag(buffer, pos)
|
||||
if tag_bytes == type_id_tag_bytes:
|
||||
(type_id, pos) = _DecodeVarint(buffer, pos)
|
||||
elif tag_bytes == message_tag_bytes:
|
||||
(size, message_start) = _DecodeVarint(buffer, pos)
|
||||
pos = message_end = message_start + size
|
||||
elif tag_bytes == item_end_tag_bytes:
|
||||
break
|
||||
else:
|
||||
field_number, wire_type = DecodeTag(tag_bytes)
|
||||
_, pos = _DecodeUnknownField(buffer, pos, end, field_number, wire_type)
|
||||
if pos == -1:
|
||||
raise _DecodeError('Unexpected end-group tag.')
|
||||
|
||||
if pos > end:
|
||||
raise _DecodeError('Truncated message.')
|
||||
|
||||
if type_id == -1:
|
||||
raise _DecodeError('MessageSet item missing type_id.')
|
||||
if message_start == -1:
|
||||
raise _DecodeError('MessageSet item missing message.')
|
||||
|
||||
return (type_id, buffer[message_start:message_end].tobytes())
|
||||
|
||||
return DecodeUnknownItem
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
def MapDecoder(field_descriptor, new_default, is_message_map):
|
||||
"""Returns a decoder for a map field."""
|
||||
|
||||
key = field_descriptor
|
||||
tag_bytes = encoder.TagBytes(field_descriptor.number,
|
||||
wire_format.WIRETYPE_LENGTH_DELIMITED)
|
||||
tag_len = len(tag_bytes)
|
||||
local_DecodeVarint = _DecodeVarint
|
||||
# Can't read _concrete_class yet; might not be initialized.
|
||||
message_type = field_descriptor.message_type
|
||||
|
||||
def DecodeMap(buffer, pos, end, message, field_dict):
|
||||
submsg = message_type._concrete_class()
|
||||
value = field_dict.get(key)
|
||||
if value is None:
|
||||
value = field_dict.setdefault(key, new_default(message))
|
||||
while 1:
|
||||
# Read length.
|
||||
(size, pos) = local_DecodeVarint(buffer, pos)
|
||||
new_pos = pos + size
|
||||
if new_pos > end:
|
||||
raise _DecodeError('Truncated message.')
|
||||
# Read sub-message.
|
||||
submsg.Clear()
|
||||
if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
|
||||
# The only reason _InternalParse would return early is if it
|
||||
# encountered an end-group tag.
|
||||
raise _DecodeError('Unexpected end-group tag.')
|
||||
|
||||
if is_message_map:
|
||||
value[submsg.key].CopyFrom(submsg.value)
|
||||
else:
|
||||
value[submsg.key] = submsg.value
|
||||
|
||||
# Predict that the next tag is another copy of the same repeated field.
|
||||
pos = new_pos + tag_len
|
||||
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
|
||||
# Prediction failed. Return.
|
||||
return new_pos
|
||||
|
||||
return DecodeMap
|
||||
|
||||
|
||||
def _DecodeFixed64(buffer, pos):
|
||||
"""Decode a fixed64."""
|
||||
new_pos = pos + 8
|
||||
return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos)
|
||||
|
||||
|
||||
def _DecodeFixed32(buffer, pos):
|
||||
"""Decode a fixed32."""
|
||||
|
||||
new_pos = pos + 4
|
||||
return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos)
|
||||
|
||||
|
||||
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
|
||||
"""Decode UnknownFieldSet. Returns the UnknownFieldSet and new position."""
|
||||
|
||||
unknown_field_set = containers.UnknownFieldSet()
|
||||
while end_pos is None or pos < end_pos:
|
||||
(tag_bytes, pos) = ReadTag(buffer, pos)
|
||||
(tag, _) = _DecodeVarint(tag_bytes, 0)
|
||||
field_number, wire_type = wire_format.UnpackTag(tag)
|
||||
if wire_type == wire_format.WIRETYPE_END_GROUP:
|
||||
break
|
||||
(data, pos) = _DecodeUnknownField(
|
||||
buffer, pos, end_pos, field_number, wire_type
|
||||
)
|
||||
# pylint: disable=protected-access
|
||||
unknown_field_set._add(field_number, wire_type, data)
|
||||
|
||||
return (unknown_field_set, pos)
|
||||
|
||||
|
||||
def _DecodeUnknownField(buffer, pos, end_pos, field_number, wire_type):
|
||||
"""Decode a unknown field. Returns the UnknownField and new position."""
|
||||
|
||||
if wire_type == wire_format.WIRETYPE_VARINT:
|
||||
(data, pos) = _DecodeVarint(buffer, pos)
|
||||
elif wire_type == wire_format.WIRETYPE_FIXED64:
|
||||
(data, pos) = _DecodeFixed64(buffer, pos)
|
||||
elif wire_type == wire_format.WIRETYPE_FIXED32:
|
||||
(data, pos) = _DecodeFixed32(buffer, pos)
|
||||
elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
|
||||
(size, pos) = _DecodeVarint(buffer, pos)
|
||||
data = buffer[pos:pos+size].tobytes()
|
||||
pos += size
|
||||
elif wire_type == wire_format.WIRETYPE_START_GROUP:
|
||||
end_tag_bytes = encoder.TagBytes(
|
||||
field_number, wire_format.WIRETYPE_END_GROUP
|
||||
)
|
||||
data, pos = _DecodeUnknownFieldSet(buffer, pos, end_pos)
|
||||
# Check end tag.
|
||||
if buffer[pos - len(end_tag_bytes) : pos] != end_tag_bytes:
|
||||
raise _DecodeError('Missing group end tag.')
|
||||
elif wire_type == wire_format.WIRETYPE_END_GROUP:
|
||||
return (0, -1)
|
||||
else:
|
||||
raise _DecodeError('Wrong wire type in tag.')
|
||||
|
||||
if pos > end_pos:
|
||||
raise _DecodeError('Truncated message.')
|
||||
|
||||
return (data, pos)
|
||||
806
.venv/lib/python3.10/site-packages/google/protobuf/internal/encoder.py
Executable file
806
.venv/lib/python3.10/site-packages/google/protobuf/internal/encoder.py
Executable file
@@ -0,0 +1,806 @@
|
||||
# Protocol Buffers - Google's data interchange format
|
||||
# Copyright 2008 Google Inc. All rights reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style
|
||||
# license that can be found in the LICENSE file or at
|
||||
# https://developers.google.com/open-source/licenses/bsd
|
||||
|
||||
"""Code for encoding protocol message primitives.
|
||||
|
||||
Contains the logic for encoding every logical protocol field type
|
||||
into one of the 5 physical wire types.
|
||||
|
||||
This code is designed to push the Python interpreter's performance to the
|
||||
limits.
|
||||
|
||||
The basic idea is that at startup time, for every field (i.e. every
|
||||
FieldDescriptor) we construct two functions: a "sizer" and an "encoder". The
|
||||
sizer takes a value of this field's type and computes its byte size. The
|
||||
encoder takes a writer function and a value. It encodes the value into byte
|
||||
strings and invokes the writer function to write those strings. Typically the
|
||||
writer function is the write() method of a BytesIO.
|
||||
|
||||
We try to do as much work as possible when constructing the writer and the
|
||||
sizer rather than when calling them. In particular:
|
||||
* We copy any needed global functions to local variables, so that we do not need
|
||||
to do costly global table lookups at runtime.
|
||||
* Similarly, we try to do any attribute lookups at startup time if possible.
|
||||
* Every field's tag is encoded to bytes at startup, since it can't change at
|
||||
runtime.
|
||||
* Whatever component of the field size we can compute at startup, we do.
|
||||
* We *avoid* sharing code if doing so would make the code slower and not sharing
|
||||
does not burden us too much. For example, encoders for repeated fields do
|
||||
not just call the encoders for singular fields in a loop because this would
|
||||
add an extra function call overhead for every loop iteration; instead, we
|
||||
manually inline the single-value encoder into the loop.
|
||||
* If a Python function lacks a return statement, Python actually generates
|
||||
instructions to pop the result of the last statement off the stack, push
|
||||
None onto the stack, and then return that. If we really don't care what
|
||||
value is returned, then we can save two instructions by returning the
|
||||
result of the last statement. It looks funny but it helps.
|
||||
* We assume that type and bounds checking has happened at a higher level.
|
||||
"""
|
||||
|
||||
__author__ = 'kenton@google.com (Kenton Varda)'
|
||||
|
||||
import struct
|
||||
|
||||
from google.protobuf.internal import wire_format
|
||||
|
||||
|
||||
# This will overflow and thus become IEEE-754 "infinity". We would use
|
||||
# "float('inf')" but it doesn't work on Windows pre-Python-2.6.
|
||||
_POS_INF = 1e10000
|
||||
_NEG_INF = -_POS_INF
|
||||
|
||||
|
||||
def _VarintSize(value):
|
||||
"""Compute the size of a varint value."""
|
||||
if value <= 0x7f: return 1
|
||||
if value <= 0x3fff: return 2
|
||||
if value <= 0x1fffff: return 3
|
||||
if value <= 0xfffffff: return 4
|
||||
if value <= 0x7ffffffff: return 5
|
||||
if value <= 0x3ffffffffff: return 6
|
||||
if value <= 0x1ffffffffffff: return 7
|
||||
if value <= 0xffffffffffffff: return 8
|
||||
if value <= 0x7fffffffffffffff: return 9
|
||||
return 10
|
||||
|
||||
|
||||
def _SignedVarintSize(value):
|
||||
"""Compute the size of a signed varint value."""
|
||||
if value < 0: return 10
|
||||
if value <= 0x7f: return 1
|
||||
if value <= 0x3fff: return 2
|
||||
if value <= 0x1fffff: return 3
|
||||
if value <= 0xfffffff: return 4
|
||||
if value <= 0x7ffffffff: return 5
|
||||
if value <= 0x3ffffffffff: return 6
|
||||
if value <= 0x1ffffffffffff: return 7
|
||||
if value <= 0xffffffffffffff: return 8
|
||||
if value <= 0x7fffffffffffffff: return 9
|
||||
return 10
|
||||
|
||||
|
||||
def _TagSize(field_number):
|
||||
"""Returns the number of bytes required to serialize a tag with this field
|
||||
number."""
|
||||
# Just pass in type 0, since the type won't affect the tag+type size.
|
||||
return _VarintSize(wire_format.PackTag(field_number, 0))
|
||||
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# In this section we define some generic sizers. Each of these functions
|
||||
# takes parameters specific to a particular field type, e.g. int32 or fixed64.
|
||||
# It returns another function which in turn takes parameters specific to a
|
||||
# particular field, e.g. the field number and whether it is repeated or packed.
|
||||
# Look at the next section to see how these are used.
|
||||
|
||||
|
||||
def _SimpleSizer(compute_value_size):
|
||||
"""A sizer which uses the function compute_value_size to compute the size of
|
||||
each value. Typically compute_value_size is _VarintSize."""
|
||||
|
||||
def SpecificSizer(field_number, is_repeated, is_packed):
|
||||
tag_size = _TagSize(field_number)
|
||||
if is_packed:
|
||||
local_VarintSize = _VarintSize
|
||||
def PackedFieldSize(value):
|
||||
result = 0
|
||||
for element in value:
|
||||
result += compute_value_size(element)
|
||||
return result + local_VarintSize(result) + tag_size
|
||||
return PackedFieldSize
|
||||
elif is_repeated:
|
||||
def RepeatedFieldSize(value):
|
||||
result = tag_size * len(value)
|
||||
for element in value:
|
||||
result += compute_value_size(element)
|
||||
return result
|
||||
return RepeatedFieldSize
|
||||
else:
|
||||
def FieldSize(value):
|
||||
return tag_size + compute_value_size(value)
|
||||
return FieldSize
|
||||
|
||||
return SpecificSizer
|
||||
|
||||
|
||||
def _ModifiedSizer(compute_value_size, modify_value):
|
||||
"""Like SimpleSizer, but modify_value is invoked on each value before it is
|
||||
passed to compute_value_size. modify_value is typically ZigZagEncode."""
|
||||
|
||||
def SpecificSizer(field_number, is_repeated, is_packed):
|
||||
tag_size = _TagSize(field_number)
|
||||
if is_packed:
|
||||
local_VarintSize = _VarintSize
|
||||
def PackedFieldSize(value):
|
||||
result = 0
|
||||
for element in value:
|
||||
result += compute_value_size(modify_value(element))
|
||||
return result + local_VarintSize(result) + tag_size
|
||||
return PackedFieldSize
|
||||
elif is_repeated:
|
||||
def RepeatedFieldSize(value):
|
||||
result = tag_size * len(value)
|
||||
for element in value:
|
||||
result += compute_value_size(modify_value(element))
|
||||
return result
|
||||
return RepeatedFieldSize
|
||||
else:
|
||||
def FieldSize(value):
|
||||
return tag_size + compute_value_size(modify_value(value))
|
||||
return FieldSize
|
||||
|
||||
return SpecificSizer
|
||||
|
||||
|
||||
def _FixedSizer(value_size):
|
||||
"""Like _SimpleSizer except for a fixed-size field. The input is the size
|
||||
of one value."""
|
||||
|
||||
def SpecificSizer(field_number, is_repeated, is_packed):
|
||||
tag_size = _TagSize(field_number)
|
||||
if is_packed:
|
||||
local_VarintSize = _VarintSize
|
||||
def PackedFieldSize(value):
|
||||
result = len(value) * value_size
|
||||
return result + local_VarintSize(result) + tag_size
|
||||
return PackedFieldSize
|
||||
elif is_repeated:
|
||||
element_size = value_size + tag_size
|
||||
def RepeatedFieldSize(value):
|
||||
return len(value) * element_size
|
||||
return RepeatedFieldSize
|
||||
else:
|
||||
field_size = value_size + tag_size
|
||||
def FieldSize(value):
|
||||
return field_size
|
||||
return FieldSize
|
||||
|
||||
return SpecificSizer
|
||||
|
||||
|
||||
# ====================================================================
|
||||
# Here we declare a sizer constructor for each field type. Each "sizer
|
||||
# constructor" is a function that takes (field_number, is_repeated, is_packed)
|
||||
# as parameters and returns a sizer, which in turn takes a field value as
|
||||
# a parameter and returns its encoded size.
|
||||
|
||||
|
||||
Int32Sizer = Int64Sizer = EnumSizer = _SimpleSizer(_SignedVarintSize)
|
||||
|
||||
UInt32Sizer = UInt64Sizer = _SimpleSizer(_VarintSize)
|
||||
|
||||
SInt32Sizer = SInt64Sizer = _ModifiedSizer(
|
||||
_SignedVarintSize, wire_format.ZigZagEncode)
|
||||
|
||||
Fixed32Sizer = SFixed32Sizer = FloatSizer = _FixedSizer(4)
|
||||
Fixed64Sizer = SFixed64Sizer = DoubleSizer = _FixedSizer(8)
|
||||
|
||||
BoolSizer = _FixedSizer(1)
|
||||
|
||||
|
||||
def StringSizer(field_number, is_repeated, is_packed):
|
||||
"""Returns a sizer for a string field."""
|
||||
|
||||
tag_size = _TagSize(field_number)
|
||||
local_VarintSize = _VarintSize
|
||||
local_len = len
|
||||
assert not is_packed
|
||||
if is_repeated:
|
||||
def RepeatedFieldSize(value):
|
||||
result = tag_size * len(value)
|
||||
for element in value:
|
||||
l = local_len(element.encode('utf-8'))
|
||||
result += local_VarintSize(l) + l
|
||||
return result
|
||||
return RepeatedFieldSize
|
||||
else:
|
||||
def FieldSize(value):
|
||||
l = local_len(value.encode('utf-8'))
|
||||
return tag_size + local_VarintSize(l) + l
|
||||
return FieldSize
|
||||
|
||||
|
||||
def BytesSizer(field_number, is_repeated, is_packed):
|
||||
"""Returns a sizer for a bytes field."""
|
||||
|
||||
tag_size = _TagSize(field_number)
|
||||
local_VarintSize = _VarintSize
|
||||
local_len = len
|
||||
assert not is_packed
|
||||
if is_repeated:
|
||||
def RepeatedFieldSize(value):
|
||||
result = tag_size * len(value)
|
||||
for element in value:
|
||||
l = local_len(element)
|
||||
result += local_VarintSize(l) + l
|
||||
return result
|
||||
return RepeatedFieldSize
|
||||
else:
|
||||
def FieldSize(value):
|
||||
l = local_len(value)
|
||||
return tag_size + local_VarintSize(l) + l
|
||||
return FieldSize
|
||||
|
||||
|
||||
def GroupSizer(field_number, is_repeated, is_packed):
|
||||
"""Returns a sizer for a group field."""
|
||||
|
||||
tag_size = _TagSize(field_number) * 2
|
||||
assert not is_packed
|
||||
if is_repeated:
|
||||
def RepeatedFieldSize(value):
|
||||
result = tag_size * len(value)
|
||||
for element in value:
|
||||
result += element.ByteSize()
|
||||
return result
|
||||
return RepeatedFieldSize
|
||||
else:
|
||||
def FieldSize(value):
|
||||
return tag_size + value.ByteSize()
|
||||
return FieldSize
|
||||
|
||||
|
||||
def MessageSizer(field_number, is_repeated, is_packed):
|
||||
"""Returns a sizer for a message field."""
|
||||
|
||||
tag_size = _TagSize(field_number)
|
||||
local_VarintSize = _VarintSize
|
||||
assert not is_packed
|
||||
if is_repeated:
|
||||
def RepeatedFieldSize(value):
|
||||
result = tag_size * len(value)
|
||||
for element in value:
|
||||
l = element.ByteSize()
|
||||
result += local_VarintSize(l) + l
|
||||
return result
|
||||
return RepeatedFieldSize
|
||||
else:
|
||||
def FieldSize(value):
|
||||
l = value.ByteSize()
|
||||
return tag_size + local_VarintSize(l) + l
|
||||
return FieldSize
|
||||
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# MessageSet is special: it needs custom logic to compute its size properly.
|
||||
|
||||
|
||||
def MessageSetItemSizer(field_number):
|
||||
"""Returns a sizer for extensions of MessageSet.
|
||||
|
||||
The message set message looks like this:
|
||||
message MessageSet {
|
||||
repeated group Item = 1 {
|
||||
required int32 type_id = 2;
|
||||
required string message = 3;
|
||||
}
|
||||
}
|
||||
"""
|
||||
static_size = (_TagSize(1) * 2 + _TagSize(2) + _VarintSize(field_number) +
|
||||
_TagSize(3))
|
||||
local_VarintSize = _VarintSize
|
||||
|
||||
def FieldSize(value):
|
||||
l = value.ByteSize()
|
||||
return static_size + local_VarintSize(l) + l
|
||||
|
||||
return FieldSize
|
||||
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# Map is special: it needs custom logic to compute its size properly.
|
||||
|
||||
|
||||
def MapSizer(field_descriptor, is_message_map):
|
||||
"""Returns a sizer for a map field."""
|
||||
|
||||
# Can't look at field_descriptor.message_type._concrete_class because it may
|
||||
# not have been initialized yet.
|
||||
message_type = field_descriptor.message_type
|
||||
message_sizer = MessageSizer(field_descriptor.number, False, False)
|
||||
|
||||
def FieldSize(map_value):
|
||||
total = 0
|
||||
for key in map_value:
|
||||
value = map_value[key]
|
||||
# It's wasteful to create the messages and throw them away one second
|
||||
# later since we'll do the same for the actual encode. But there's not an
|
||||
# obvious way to avoid this within the current design without tons of code
|
||||
# duplication. For message map, value.ByteSize() should be called to
|
||||
# update the status.
|
||||
entry_msg = message_type._concrete_class(key=key, value=value)
|
||||
total += message_sizer(entry_msg)
|
||||
if is_message_map:
|
||||
value.ByteSize()
|
||||
return total
|
||||
|
||||
return FieldSize
|
||||
|
||||
# ====================================================================
|
||||
# Encoders!
|
||||
|
||||
|
||||
def _VarintEncoder():
|
||||
"""Return an encoder for a basic varint value (does not include tag)."""
|
||||
|
||||
local_int2byte = struct.Struct('>B').pack
|
||||
|
||||
def EncodeVarint(write, value, unused_deterministic=None):
|
||||
bits = value & 0x7f
|
||||
value >>= 7
|
||||
while value:
|
||||
write(local_int2byte(0x80|bits))
|
||||
bits = value & 0x7f
|
||||
value >>= 7
|
||||
return write(local_int2byte(bits))
|
||||
|
||||
return EncodeVarint
|
||||
|
||||
|
||||
def _SignedVarintEncoder():
|
||||
"""Return an encoder for a basic signed varint value (does not include
|
||||
tag)."""
|
||||
|
||||
local_int2byte = struct.Struct('>B').pack
|
||||
|
||||
def EncodeSignedVarint(write, value, unused_deterministic=None):
|
||||
if value < 0:
|
||||
value += (1 << 64)
|
||||
bits = value & 0x7f
|
||||
value >>= 7
|
||||
while value:
|
||||
write(local_int2byte(0x80|bits))
|
||||
bits = value & 0x7f
|
||||
value >>= 7
|
||||
return write(local_int2byte(bits))
|
||||
|
||||
return EncodeSignedVarint
|
||||
|
||||
|
||||
_EncodeVarint = _VarintEncoder()
|
||||
_EncodeSignedVarint = _SignedVarintEncoder()
|
||||
|
||||
|
||||
def _VarintBytes(value):
|
||||
"""Encode the given integer as a varint and return the bytes. This is only
|
||||
called at startup time so it doesn't need to be fast."""
|
||||
|
||||
pieces = []
|
||||
_EncodeVarint(pieces.append, value, True)
|
||||
return b"".join(pieces)
|
||||
|
||||
|
||||
def TagBytes(field_number, wire_type):
|
||||
"""Encode the given tag and return the bytes. Only called at startup."""
|
||||
|
||||
return bytes(_VarintBytes(wire_format.PackTag(field_number, wire_type)))
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# As with sizers (see above), we have a number of common encoder
|
||||
# implementations.
|
||||
|
||||
|
||||
def _SimpleEncoder(wire_type, encode_value, compute_value_size):
|
||||
"""Return a constructor for an encoder for fields of a particular type.
|
||||
|
||||
Args:
|
||||
wire_type: The field's wire type, for encoding tags.
|
||||
encode_value: A function which encodes an individual value, e.g.
|
||||
_EncodeVarint().
|
||||
compute_value_size: A function which computes the size of an individual
|
||||
value, e.g. _VarintSize().
|
||||
"""
|
||||
|
||||
def SpecificEncoder(field_number, is_repeated, is_packed):
|
||||
if is_packed:
|
||||
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
|
||||
local_EncodeVarint = _EncodeVarint
|
||||
def EncodePackedField(write, value, deterministic):
|
||||
write(tag_bytes)
|
||||
size = 0
|
||||
for element in value:
|
||||
size += compute_value_size(element)
|
||||
local_EncodeVarint(write, size, deterministic)
|
||||
for element in value:
|
||||
encode_value(write, element, deterministic)
|
||||
return EncodePackedField
|
||||
elif is_repeated:
|
||||
tag_bytes = TagBytes(field_number, wire_type)
|
||||
def EncodeRepeatedField(write, value, deterministic):
|
||||
for element in value:
|
||||
write(tag_bytes)
|
||||
encode_value(write, element, deterministic)
|
||||
return EncodeRepeatedField
|
||||
else:
|
||||
tag_bytes = TagBytes(field_number, wire_type)
|
||||
def EncodeField(write, value, deterministic):
|
||||
write(tag_bytes)
|
||||
return encode_value(write, value, deterministic)
|
||||
return EncodeField
|
||||
|
||||
return SpecificEncoder
|
||||
|
||||
|
||||
def _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value):
|
||||
"""Like SimpleEncoder but additionally invokes modify_value on every value
|
||||
before passing it to encode_value. Usually modify_value is ZigZagEncode."""
|
||||
|
||||
def SpecificEncoder(field_number, is_repeated, is_packed):
|
||||
if is_packed:
|
||||
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
|
||||
local_EncodeVarint = _EncodeVarint
|
||||
def EncodePackedField(write, value, deterministic):
|
||||
write(tag_bytes)
|
||||
size = 0
|
||||
for element in value:
|
||||
size += compute_value_size(modify_value(element))
|
||||
local_EncodeVarint(write, size, deterministic)
|
||||
for element in value:
|
||||
encode_value(write, modify_value(element), deterministic)
|
||||
return EncodePackedField
|
||||
elif is_repeated:
|
||||
tag_bytes = TagBytes(field_number, wire_type)
|
||||
def EncodeRepeatedField(write, value, deterministic):
|
||||
for element in value:
|
||||
write(tag_bytes)
|
||||
encode_value(write, modify_value(element), deterministic)
|
||||
return EncodeRepeatedField
|
||||
else:
|
||||
tag_bytes = TagBytes(field_number, wire_type)
|
||||
def EncodeField(write, value, deterministic):
|
||||
write(tag_bytes)
|
||||
return encode_value(write, modify_value(value), deterministic)
|
||||
return EncodeField
|
||||
|
||||
return SpecificEncoder
|
||||
|
||||
|
||||
def _StructPackEncoder(wire_type, format):
|
||||
"""Return a constructor for an encoder for a fixed-width field.
|
||||
|
||||
Args:
|
||||
wire_type: The field's wire type, for encoding tags.
|
||||
format: The format string to pass to struct.pack().
|
||||
"""
|
||||
|
||||
value_size = struct.calcsize(format)
|
||||
|
||||
def SpecificEncoder(field_number, is_repeated, is_packed):
|
||||
local_struct_pack = struct.pack
|
||||
if is_packed:
|
||||
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
|
||||
local_EncodeVarint = _EncodeVarint
|
||||
def EncodePackedField(write, value, deterministic):
|
||||
write(tag_bytes)
|
||||
local_EncodeVarint(write, len(value) * value_size, deterministic)
|
||||
for element in value:
|
||||
write(local_struct_pack(format, element))
|
||||
return EncodePackedField
|
||||
elif is_repeated:
|
||||
tag_bytes = TagBytes(field_number, wire_type)
|
||||
def EncodeRepeatedField(write, value, unused_deterministic=None):
|
||||
for element in value:
|
||||
write(tag_bytes)
|
||||
write(local_struct_pack(format, element))
|
||||
return EncodeRepeatedField
|
||||
else:
|
||||
tag_bytes = TagBytes(field_number, wire_type)
|
||||
def EncodeField(write, value, unused_deterministic=None):
|
||||
write(tag_bytes)
|
||||
return write(local_struct_pack(format, value))
|
||||
return EncodeField
|
||||
|
||||
return SpecificEncoder
|
||||
|
||||
|
||||
def _FloatingPointEncoder(wire_type, format):
|
||||
"""Return a constructor for an encoder for float fields.
|
||||
|
||||
This is like StructPackEncoder, but catches errors that may be due to
|
||||
passing non-finite floating-point values to struct.pack, and makes a
|
||||
second attempt to encode those values.
|
||||
|
||||
Args:
|
||||
wire_type: The field's wire type, for encoding tags.
|
||||
format: The format string to pass to struct.pack().
|
||||
"""
|
||||
|
||||
value_size = struct.calcsize(format)
|
||||
if value_size == 4:
|
||||
def EncodeNonFiniteOrRaise(write, value):
|
||||
# Remember that the serialized form uses little-endian byte order.
|
||||
if value == _POS_INF:
|
||||
write(b'\x00\x00\x80\x7F')
|
||||
elif value == _NEG_INF:
|
||||
write(b'\x00\x00\x80\xFF')
|
||||
elif value != value: # NaN
|
||||
write(b'\x00\x00\xC0\x7F')
|
||||
else:
|
||||
raise
|
||||
elif value_size == 8:
|
||||
def EncodeNonFiniteOrRaise(write, value):
|
||||
if value == _POS_INF:
|
||||
write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F')
|
||||
elif value == _NEG_INF:
|
||||
write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF')
|
||||
elif value != value: # NaN
|
||||
write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F')
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
raise ValueError('Can\'t encode floating-point values that are '
|
||||
'%d bytes long (only 4 or 8)' % value_size)
|
||||
|
||||
def SpecificEncoder(field_number, is_repeated, is_packed):
|
||||
local_struct_pack = struct.pack
|
||||
if is_packed:
|
||||
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
|
||||
local_EncodeVarint = _EncodeVarint
|
||||
def EncodePackedField(write, value, deterministic):
|
||||
write(tag_bytes)
|
||||
local_EncodeVarint(write, len(value) * value_size, deterministic)
|
||||
for element in value:
|
||||
# This try/except block is going to be faster than any code that
|
||||
# we could write to check whether element is finite.
|
||||
try:
|
||||
write(local_struct_pack(format, element))
|
||||
except SystemError:
|
||||
EncodeNonFiniteOrRaise(write, element)
|
||||
return EncodePackedField
|
||||
elif is_repeated:
|
||||
tag_bytes = TagBytes(field_number, wire_type)
|
||||
def EncodeRepeatedField(write, value, unused_deterministic=None):
|
||||
for element in value:
|
||||
write(tag_bytes)
|
||||
try:
|
||||
write(local_struct_pack(format, element))
|
||||
except SystemError:
|
||||
EncodeNonFiniteOrRaise(write, element)
|
||||
return EncodeRepeatedField
|
||||
else:
|
||||
tag_bytes = TagBytes(field_number, wire_type)
|
||||
def EncodeField(write, value, unused_deterministic=None):
|
||||
write(tag_bytes)
|
||||
try:
|
||||
write(local_struct_pack(format, value))
|
||||
except SystemError:
|
||||
EncodeNonFiniteOrRaise(write, value)
|
||||
return EncodeField
|
||||
|
||||
return SpecificEncoder
|
||||
|
||||
|
||||
# ====================================================================
|
||||
# Here we declare an encoder constructor for each field type. These work
|
||||
# very similarly to sizer constructors, described earlier.
|
||||
|
||||
|
||||
Int32Encoder = Int64Encoder = EnumEncoder = _SimpleEncoder(
|
||||
wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize)
|
||||
|
||||
UInt32Encoder = UInt64Encoder = _SimpleEncoder(
|
||||
wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize)
|
||||
|
||||
SInt32Encoder = SInt64Encoder = _ModifiedEncoder(
|
||||
wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize,
|
||||
wire_format.ZigZagEncode)
|
||||
|
||||
# Note that Python conveniently guarantees that when using the '<' prefix on
|
||||
# formats, they will also have the same size across all platforms (as opposed
|
||||
# to without the prefix, where their sizes depend on the C compiler's basic
|
||||
# type sizes).
|
||||
Fixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<I')
|
||||
Fixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<Q')
|
||||
SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<i')
|
||||
SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<q')
|
||||
FloatEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED32, '<f')
|
||||
DoubleEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '<d')
|
||||
|
||||
|
||||
def BoolEncoder(field_number, is_repeated, is_packed):
|
||||
"""Returns an encoder for a boolean field."""
|
||||
|
||||
false_byte = b'\x00'
|
||||
true_byte = b'\x01'
|
||||
if is_packed:
|
||||
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
|
||||
local_EncodeVarint = _EncodeVarint
|
||||
def EncodePackedField(write, value, deterministic):
|
||||
write(tag_bytes)
|
||||
local_EncodeVarint(write, len(value), deterministic)
|
||||
for element in value:
|
||||
if element:
|
||||
write(true_byte)
|
||||
else:
|
||||
write(false_byte)
|
||||
return EncodePackedField
|
||||
elif is_repeated:
|
||||
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
|
||||
def EncodeRepeatedField(write, value, unused_deterministic=None):
|
||||
for element in value:
|
||||
write(tag_bytes)
|
||||
if element:
|
||||
write(true_byte)
|
||||
else:
|
||||
write(false_byte)
|
||||
return EncodeRepeatedField
|
||||
else:
|
||||
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
|
||||
def EncodeField(write, value, unused_deterministic=None):
|
||||
write(tag_bytes)
|
||||
if value:
|
||||
return write(true_byte)
|
||||
return write(false_byte)
|
||||
return EncodeField
|
||||
|
||||
|
||||
def StringEncoder(field_number, is_repeated, is_packed):
|
||||
"""Returns an encoder for a string field."""
|
||||
|
||||
tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
|
||||
local_EncodeVarint = _EncodeVarint
|
||||
local_len = len
|
||||
assert not is_packed
|
||||
if is_repeated:
|
||||
def EncodeRepeatedField(write, value, deterministic):
|
||||
for element in value:
|
||||
encoded = element.encode('utf-8')
|
||||
write(tag)
|
||||
local_EncodeVarint(write, local_len(encoded), deterministic)
|
||||
write(encoded)
|
||||
return EncodeRepeatedField
|
||||
else:
|
||||
def EncodeField(write, value, deterministic):
|
||||
encoded = value.encode('utf-8')
|
||||
write(tag)
|
||||
local_EncodeVarint(write, local_len(encoded), deterministic)
|
||||
return write(encoded)
|
||||
return EncodeField
|
||||
|
||||
|
||||
def BytesEncoder(field_number, is_repeated, is_packed):
|
||||
"""Returns an encoder for a bytes field."""
|
||||
|
||||
tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
|
||||
local_EncodeVarint = _EncodeVarint
|
||||
local_len = len
|
||||
assert not is_packed
|
||||
if is_repeated:
|
||||
def EncodeRepeatedField(write, value, deterministic):
|
||||
for element in value:
|
||||
write(tag)
|
||||
local_EncodeVarint(write, local_len(element), deterministic)
|
||||
write(element)
|
||||
return EncodeRepeatedField
|
||||
else:
|
||||
def EncodeField(write, value, deterministic):
|
||||
write(tag)
|
||||
local_EncodeVarint(write, local_len(value), deterministic)
|
||||
return write(value)
|
||||
return EncodeField
|
||||
|
||||
|
||||
def GroupEncoder(field_number, is_repeated, is_packed):
|
||||
"""Returns an encoder for a group field."""
|
||||
|
||||
start_tag = TagBytes(field_number, wire_format.WIRETYPE_START_GROUP)
|
||||
end_tag = TagBytes(field_number, wire_format.WIRETYPE_END_GROUP)
|
||||
assert not is_packed
|
||||
if is_repeated:
|
||||
def EncodeRepeatedField(write, value, deterministic):
|
||||
for element in value:
|
||||
write(start_tag)
|
||||
element._InternalSerialize(write, deterministic)
|
||||
write(end_tag)
|
||||
return EncodeRepeatedField
|
||||
else:
|
||||
def EncodeField(write, value, deterministic):
|
||||
write(start_tag)
|
||||
value._InternalSerialize(write, deterministic)
|
||||
return write(end_tag)
|
||||
return EncodeField
|
||||
|
||||
|
||||
def MessageEncoder(field_number, is_repeated, is_packed):
|
||||
"""Returns an encoder for a message field."""
|
||||
|
||||
tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
|
||||
local_EncodeVarint = _EncodeVarint
|
||||
assert not is_packed
|
||||
if is_repeated:
|
||||
def EncodeRepeatedField(write, value, deterministic):
|
||||
for element in value:
|
||||
write(tag)
|
||||
local_EncodeVarint(write, element.ByteSize(), deterministic)
|
||||
element._InternalSerialize(write, deterministic)
|
||||
return EncodeRepeatedField
|
||||
else:
|
||||
def EncodeField(write, value, deterministic):
|
||||
write(tag)
|
||||
local_EncodeVarint(write, value.ByteSize(), deterministic)
|
||||
return value._InternalSerialize(write, deterministic)
|
||||
return EncodeField
|
||||
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# As before, MessageSet is special.
|
||||
|
||||
|
||||
def MessageSetItemEncoder(field_number):
|
||||
"""Encoder for extensions of MessageSet.
|
||||
|
||||
The message set message looks like this:
|
||||
message MessageSet {
|
||||
repeated group Item = 1 {
|
||||
required int32 type_id = 2;
|
||||
required string message = 3;
|
||||
}
|
||||
}
|
||||
"""
|
||||
start_bytes = b"".join([
|
||||
TagBytes(1, wire_format.WIRETYPE_START_GROUP),
|
||||
TagBytes(2, wire_format.WIRETYPE_VARINT),
|
||||
_VarintBytes(field_number),
|
||||
TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)])
|
||||
end_bytes = TagBytes(1, wire_format.WIRETYPE_END_GROUP)
|
||||
local_EncodeVarint = _EncodeVarint
|
||||
|
||||
def EncodeField(write, value, deterministic):
|
||||
write(start_bytes)
|
||||
local_EncodeVarint(write, value.ByteSize(), deterministic)
|
||||
value._InternalSerialize(write, deterministic)
|
||||
return write(end_bytes)
|
||||
|
||||
return EncodeField
|
||||
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# As before, Map is special.
|
||||
|
||||
|
||||
def MapEncoder(field_descriptor):
|
||||
"""Encoder for extensions of MessageSet.
|
||||
|
||||
Maps always have a wire format like this:
|
||||
message MapEntry {
|
||||
key_type key = 1;
|
||||
value_type value = 2;
|
||||
}
|
||||
repeated MapEntry map = N;
|
||||
"""
|
||||
# Can't look at field_descriptor.message_type._concrete_class because it may
|
||||
# not have been initialized yet.
|
||||
message_type = field_descriptor.message_type
|
||||
encode_message = MessageEncoder(field_descriptor.number, False, False)
|
||||
|
||||
def EncodeField(write, value, deterministic):
|
||||
value_keys = sorted(value.keys()) if deterministic else value
|
||||
for key in value_keys:
|
||||
entry_msg = message_type._concrete_class(key=key, value=value[key])
|
||||
encode_message(write, entry_msg, deterministic)
|
||||
|
||||
return EncodeField
|
||||
112
.venv/lib/python3.10/site-packages/google/protobuf/internal/enum_type_wrapper.py
Executable file
112
.venv/lib/python3.10/site-packages/google/protobuf/internal/enum_type_wrapper.py
Executable file
@@ -0,0 +1,112 @@
|
||||
# Protocol Buffers - Google's data interchange format
|
||||
# Copyright 2008 Google Inc. All rights reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style
|
||||
# license that can be found in the LICENSE file or at
|
||||
# https://developers.google.com/open-source/licenses/bsd
|
||||
|
||||
"""A simple wrapper around enum types to expose utility functions.
|
||||
|
||||
Instances are created as properties with the same name as the enum they wrap
|
||||
on proto classes. For usage, see:
|
||||
reflection_test.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
__author__ = 'rabsatt@google.com (Kevin Rabsatt)'
|
||||
|
||||
|
||||
class EnumTypeWrapper(object):
|
||||
"""A utility for finding the names of enum values."""
|
||||
|
||||
DESCRIPTOR = None
|
||||
|
||||
# This is a type alias, which mypy typing stubs can type as
|
||||
# a genericized parameter constrained to an int, allowing subclasses
|
||||
# to be typed with more constraint in .pyi stubs
|
||||
# Eg.
|
||||
# def MyGeneratedEnum(Message):
|
||||
# ValueType = NewType('ValueType', int)
|
||||
# def Name(self, number: MyGeneratedEnum.ValueType) -> str
|
||||
ValueType = int
|
||||
|
||||
def __init__(self, enum_type):
|
||||
"""Inits EnumTypeWrapper with an EnumDescriptor."""
|
||||
self._enum_type = enum_type
|
||||
self.DESCRIPTOR = enum_type # pylint: disable=invalid-name
|
||||
|
||||
def Name(self, number): # pylint: disable=invalid-name
|
||||
"""Returns a string containing the name of an enum value."""
|
||||
try:
|
||||
return self._enum_type.values_by_number[number].name
|
||||
except KeyError:
|
||||
pass # fall out to break exception chaining
|
||||
|
||||
if not isinstance(number, int):
|
||||
raise TypeError(
|
||||
'Enum value for {} must be an int, but got {} {!r}.'.format(
|
||||
self._enum_type.name, type(number), number))
|
||||
else:
|
||||
# repr here to handle the odd case when you pass in a boolean.
|
||||
raise ValueError('Enum {} has no name defined for value {!r}'.format(
|
||||
self._enum_type.name, number))
|
||||
|
||||
def Value(self, name): # pylint: disable=invalid-name
|
||||
"""Returns the value corresponding to the given enum name."""
|
||||
try:
|
||||
return self._enum_type.values_by_name[name].number
|
||||
except KeyError:
|
||||
pass # fall out to break exception chaining
|
||||
raise ValueError('Enum {} has no value defined for name {!r}'.format(
|
||||
self._enum_type.name, name))
|
||||
|
||||
def keys(self):
|
||||
"""Return a list of the string names in the enum.
|
||||
|
||||
Returns:
|
||||
A list of strs, in the order they were defined in the .proto file.
|
||||
"""
|
||||
|
||||
return [value_descriptor.name
|
||||
for value_descriptor in self._enum_type.values]
|
||||
|
||||
def values(self):
|
||||
"""Return a list of the integer values in the enum.
|
||||
|
||||
Returns:
|
||||
A list of ints, in the order they were defined in the .proto file.
|
||||
"""
|
||||
|
||||
return [value_descriptor.number
|
||||
for value_descriptor in self._enum_type.values]
|
||||
|
||||
def items(self):
|
||||
"""Return a list of the (name, value) pairs of the enum.
|
||||
|
||||
Returns:
|
||||
A list of (str, int) pairs, in the order they were defined
|
||||
in the .proto file.
|
||||
"""
|
||||
return [(value_descriptor.name, value_descriptor.number)
|
||||
for value_descriptor in self._enum_type.values]
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Returns the value corresponding to the given enum name."""
|
||||
try:
|
||||
return super(
|
||||
EnumTypeWrapper,
|
||||
self).__getattribute__('_enum_type').values_by_name[name].number
|
||||
except KeyError:
|
||||
pass # fall out to break exception chaining
|
||||
raise AttributeError('Enum {} has no value defined for name {!r}'.format(
|
||||
self._enum_type.name, name))
|
||||
|
||||
def __or__(self, other):
|
||||
"""Returns the union type of self and other."""
|
||||
if sys.version_info >= (3, 10):
|
||||
return type(self) | other
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'You may not use | on EnumTypes (or classes) below python 3.10'
|
||||
)
|
||||
194
.venv/lib/python3.10/site-packages/google/protobuf/internal/extension_dict.py
Executable file
194
.venv/lib/python3.10/site-packages/google/protobuf/internal/extension_dict.py
Executable file
@@ -0,0 +1,194 @@
|
||||
# Protocol Buffers - Google's data interchange format
|
||||
# Copyright 2008 Google Inc. All rights reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style
|
||||
# license that can be found in the LICENSE file or at
|
||||
# https://developers.google.com/open-source/licenses/bsd
|
||||
|
||||
"""Contains _ExtensionDict class to represent extensions.
|
||||
"""
|
||||
|
||||
from google.protobuf.internal import type_checkers
|
||||
from google.protobuf.descriptor import FieldDescriptor
|
||||
|
||||
|
||||
def _VerifyExtensionHandle(message, extension_handle):
|
||||
"""Verify that the given extension handle is valid."""
|
||||
|
||||
if not isinstance(extension_handle, FieldDescriptor):
|
||||
raise KeyError('HasExtension() expects an extension handle, got: %s' %
|
||||
extension_handle)
|
||||
|
||||
if not extension_handle.is_extension:
|
||||
raise KeyError('"%s" is not an extension.' % extension_handle.full_name)
|
||||
|
||||
if not extension_handle.containing_type:
|
||||
raise KeyError('"%s" is missing a containing_type.'
|
||||
% extension_handle.full_name)
|
||||
|
||||
if extension_handle.containing_type is not message.DESCRIPTOR:
|
||||
raise KeyError('Extension "%s" extends message type "%s", but this '
|
||||
'message is of type "%s".' %
|
||||
(extension_handle.full_name,
|
||||
extension_handle.containing_type.full_name,
|
||||
message.DESCRIPTOR.full_name))
|
||||
|
||||
|
||||
# TODO: Unify error handling of "unknown extension" crap.
|
||||
# TODO: Support iteritems()-style iteration over all
|
||||
# extensions with the "has" bits turned on?
|
||||
class _ExtensionDict(object):
|
||||
|
||||
"""Dict-like container for Extension fields on proto instances.
|
||||
|
||||
Note that in all cases we expect extension handles to be
|
||||
FieldDescriptors.
|
||||
"""
|
||||
|
||||
def __init__(self, extended_message):
|
||||
"""
|
||||
Args:
|
||||
extended_message: Message instance for which we are the Extensions dict.
|
||||
"""
|
||||
self._extended_message = extended_message
|
||||
|
||||
def __getitem__(self, extension_handle):
|
||||
"""Returns the current value of the given extension handle."""
|
||||
|
||||
_VerifyExtensionHandle(self._extended_message, extension_handle)
|
||||
|
||||
result = self._extended_message._fields.get(extension_handle)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
if extension_handle.label == FieldDescriptor.LABEL_REPEATED:
|
||||
result = extension_handle._default_constructor(self._extended_message)
|
||||
elif extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
|
||||
message_type = extension_handle.message_type
|
||||
if not hasattr(message_type, '_concrete_class'):
|
||||
# pylint: disable=g-import-not-at-top
|
||||
from google.protobuf import message_factory
|
||||
message_factory.GetMessageClass(message_type)
|
||||
if not hasattr(extension_handle.message_type, '_concrete_class'):
|
||||
from google.protobuf import message_factory
|
||||
message_factory.GetMessageClass(extension_handle.message_type)
|
||||
result = extension_handle.message_type._concrete_class()
|
||||
try:
|
||||
result._SetListener(self._extended_message._listener_for_children)
|
||||
except ReferenceError:
|
||||
pass
|
||||
else:
|
||||
# Singular scalar -- just return the default without inserting into the
|
||||
# dict.
|
||||
return extension_handle.default_value
|
||||
|
||||
# Atomically check if another thread has preempted us and, if not, swap
|
||||
# in the new object we just created. If someone has preempted us, we
|
||||
# take that object and discard ours.
|
||||
# WARNING: We are relying on setdefault() being atomic. This is true
|
||||
# in CPython but we haven't investigated others. This warning appears
|
||||
# in several other locations in this file.
|
||||
result = self._extended_message._fields.setdefault(
|
||||
extension_handle, result)
|
||||
|
||||
return result
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
|
||||
my_fields = self._extended_message.ListFields()
|
||||
other_fields = other._extended_message.ListFields()
|
||||
|
||||
# Get rid of non-extension fields.
|
||||
my_fields = [field for field in my_fields if field.is_extension]
|
||||
other_fields = [field for field in other_fields if field.is_extension]
|
||||
|
||||
return my_fields == other_fields
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self == other
|
||||
|
||||
def __len__(self):
|
||||
fields = self._extended_message.ListFields()
|
||||
# Get rid of non-extension fields.
|
||||
extension_fields = [field for field in fields if field[0].is_extension]
|
||||
return len(extension_fields)
|
||||
|
||||
def __hash__(self):
|
||||
raise TypeError('unhashable object')
|
||||
|
||||
# Note that this is only meaningful for non-repeated, scalar extension
|
||||
# fields. Note also that we may have to call _Modified() when we do
|
||||
# successfully set a field this way, to set any necessary "has" bits in the
|
||||
# ancestors of the extended message.
|
||||
def __setitem__(self, extension_handle, value):
|
||||
"""If extension_handle specifies a non-repeated, scalar extension
|
||||
field, sets the value of that field.
|
||||
"""
|
||||
|
||||
_VerifyExtensionHandle(self._extended_message, extension_handle)
|
||||
|
||||
if (extension_handle.label == FieldDescriptor.LABEL_REPEATED or
|
||||
extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE):
|
||||
raise TypeError(
|
||||
'Cannot assign to extension "%s" because it is a repeated or '
|
||||
'composite type.' % extension_handle.full_name)
|
||||
|
||||
# It's slightly wasteful to lookup the type checker each time,
|
||||
# but we expect this to be a vanishingly uncommon case anyway.
|
||||
type_checker = type_checkers.GetTypeChecker(extension_handle)
|
||||
# pylint: disable=protected-access
|
||||
self._extended_message._fields[extension_handle] = (
|
||||
type_checker.CheckValue(value))
|
||||
self._extended_message._Modified()
|
||||
|
||||
def __delitem__(self, extension_handle):
|
||||
self._extended_message.ClearExtension(extension_handle)
|
||||
|
||||
def _FindExtensionByName(self, name):
|
||||
"""Tries to find a known extension with the specified name.
|
||||
|
||||
Args:
|
||||
name: Extension full name.
|
||||
|
||||
Returns:
|
||||
Extension field descriptor.
|
||||
"""
|
||||
descriptor = self._extended_message.DESCRIPTOR
|
||||
extensions = descriptor.file.pool._extensions_by_name[descriptor]
|
||||
return extensions.get(name, None)
|
||||
|
||||
def _FindExtensionByNumber(self, number):
|
||||
"""Tries to find a known extension with the field number.
|
||||
|
||||
Args:
|
||||
number: Extension field number.
|
||||
|
||||
Returns:
|
||||
Extension field descriptor.
|
||||
"""
|
||||
descriptor = self._extended_message.DESCRIPTOR
|
||||
extensions = descriptor.file.pool._extensions_by_number[descriptor]
|
||||
return extensions.get(number, None)
|
||||
|
||||
def __iter__(self):
|
||||
# Return a generator over the populated extension fields
|
||||
return (f[0] for f in self._extended_message.ListFields()
|
||||
if f[0].is_extension)
|
||||
|
||||
def __contains__(self, extension_handle):
|
||||
_VerifyExtensionHandle(self._extended_message, extension_handle)
|
||||
|
||||
if extension_handle not in self._extended_message._fields:
|
||||
return False
|
||||
|
||||
if extension_handle.label == FieldDescriptor.LABEL_REPEATED:
|
||||
return bool(self._extended_message._fields.get(extension_handle))
|
||||
|
||||
if extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
|
||||
value = self._extended_message._fields.get(extension_handle)
|
||||
# pylint: disable=protected-access
|
||||
return value is not None and value._is_present_in_parent
|
||||
|
||||
return True
|
||||
310
.venv/lib/python3.10/site-packages/google/protobuf/internal/field_mask.py
Executable file
310
.venv/lib/python3.10/site-packages/google/protobuf/internal/field_mask.py
Executable file
@@ -0,0 +1,310 @@
|
||||
# Protocol Buffers - Google's data interchange format
|
||||
# Copyright 2008 Google Inc. All rights reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style
|
||||
# license that can be found in the LICENSE file or at
|
||||
# https://developers.google.com/open-source/licenses/bsd
|
||||
|
||||
"""Contains FieldMask class."""
|
||||
|
||||
from google.protobuf.descriptor import FieldDescriptor
|
||||
|
||||
|
||||
class FieldMask(object):
|
||||
"""Class for FieldMask message type."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def ToJsonString(self):
|
||||
"""Converts FieldMask to string according to proto3 JSON spec."""
|
||||
camelcase_paths = []
|
||||
for path in self.paths:
|
||||
camelcase_paths.append(_SnakeCaseToCamelCase(path))
|
||||
return ','.join(camelcase_paths)
|
||||
|
||||
def FromJsonString(self, value):
|
||||
"""Converts string to FieldMask according to proto3 JSON spec."""
|
||||
if not isinstance(value, str):
|
||||
raise ValueError('FieldMask JSON value not a string: {!r}'.format(value))
|
||||
self.Clear()
|
||||
if value:
|
||||
for path in value.split(','):
|
||||
self.paths.append(_CamelCaseToSnakeCase(path))
|
||||
|
||||
def IsValidForDescriptor(self, message_descriptor):
|
||||
"""Checks whether the FieldMask is valid for Message Descriptor."""
|
||||
for path in self.paths:
|
||||
if not _IsValidPath(message_descriptor, path):
|
||||
return False
|
||||
return True
|
||||
|
||||
def AllFieldsFromDescriptor(self, message_descriptor):
|
||||
"""Gets all direct fields of Message Descriptor to FieldMask."""
|
||||
self.Clear()
|
||||
for field in message_descriptor.fields:
|
||||
self.paths.append(field.name)
|
||||
|
||||
def CanonicalFormFromMask(self, mask):
|
||||
"""Converts a FieldMask to the canonical form.
|
||||
|
||||
Removes paths that are covered by another path. For example,
|
||||
"foo.bar" is covered by "foo" and will be removed if "foo"
|
||||
is also in the FieldMask. Then sorts all paths in alphabetical order.
|
||||
|
||||
Args:
|
||||
mask: The original FieldMask to be converted.
|
||||
"""
|
||||
tree = _FieldMaskTree(mask)
|
||||
tree.ToFieldMask(self)
|
||||
|
||||
def Union(self, mask1, mask2):
|
||||
"""Merges mask1 and mask2 into this FieldMask."""
|
||||
_CheckFieldMaskMessage(mask1)
|
||||
_CheckFieldMaskMessage(mask2)
|
||||
tree = _FieldMaskTree(mask1)
|
||||
tree.MergeFromFieldMask(mask2)
|
||||
tree.ToFieldMask(self)
|
||||
|
||||
def Intersect(self, mask1, mask2):
|
||||
"""Intersects mask1 and mask2 into this FieldMask."""
|
||||
_CheckFieldMaskMessage(mask1)
|
||||
_CheckFieldMaskMessage(mask2)
|
||||
tree = _FieldMaskTree(mask1)
|
||||
intersection = _FieldMaskTree()
|
||||
for path in mask2.paths:
|
||||
tree.IntersectPath(path, intersection)
|
||||
intersection.ToFieldMask(self)
|
||||
|
||||
def MergeMessage(
|
||||
self, source, destination,
|
||||
replace_message_field=False, replace_repeated_field=False):
|
||||
"""Merges fields specified in FieldMask from source to destination.
|
||||
|
||||
Args:
|
||||
source: Source message.
|
||||
destination: The destination message to be merged into.
|
||||
replace_message_field: Replace message field if True. Merge message
|
||||
field if False.
|
||||
replace_repeated_field: Replace repeated field if True. Append
|
||||
elements of repeated field if False.
|
||||
"""
|
||||
tree = _FieldMaskTree(self)
|
||||
tree.MergeMessage(
|
||||
source, destination, replace_message_field, replace_repeated_field)
|
||||
|
||||
|
||||
def _IsValidPath(message_descriptor, path):
|
||||
"""Checks whether the path is valid for Message Descriptor."""
|
||||
parts = path.split('.')
|
||||
last = parts.pop()
|
||||
for name in parts:
|
||||
field = message_descriptor.fields_by_name.get(name)
|
||||
if (field is None or
|
||||
field.label == FieldDescriptor.LABEL_REPEATED or
|
||||
field.type != FieldDescriptor.TYPE_MESSAGE):
|
||||
return False
|
||||
message_descriptor = field.message_type
|
||||
return last in message_descriptor.fields_by_name
|
||||
|
||||
|
||||
def _CheckFieldMaskMessage(message):
|
||||
"""Raises ValueError if message is not a FieldMask."""
|
||||
message_descriptor = message.DESCRIPTOR
|
||||
if (message_descriptor.name != 'FieldMask' or
|
||||
message_descriptor.file.name != 'google/protobuf/field_mask.proto'):
|
||||
raise ValueError('Message {0} is not a FieldMask.'.format(
|
||||
message_descriptor.full_name))
|
||||
|
||||
|
||||
def _SnakeCaseToCamelCase(path_name):
|
||||
"""Converts a path name from snake_case to camelCase."""
|
||||
result = []
|
||||
after_underscore = False
|
||||
for c in path_name:
|
||||
if c.isupper():
|
||||
raise ValueError(
|
||||
'Fail to print FieldMask to Json string: Path name '
|
||||
'{0} must not contain uppercase letters.'.format(path_name))
|
||||
if after_underscore:
|
||||
if c.islower():
|
||||
result.append(c.upper())
|
||||
after_underscore = False
|
||||
else:
|
||||
raise ValueError(
|
||||
'Fail to print FieldMask to Json string: The '
|
||||
'character after a "_" must be a lowercase letter '
|
||||
'in path name {0}.'.format(path_name))
|
||||
elif c == '_':
|
||||
after_underscore = True
|
||||
else:
|
||||
result += c
|
||||
|
||||
if after_underscore:
|
||||
raise ValueError('Fail to print FieldMask to Json string: Trailing "_" '
|
||||
'in path name {0}.'.format(path_name))
|
||||
return ''.join(result)
|
||||
|
||||
|
||||
def _CamelCaseToSnakeCase(path_name):
|
||||
"""Converts a field name from camelCase to snake_case."""
|
||||
result = []
|
||||
for c in path_name:
|
||||
if c == '_':
|
||||
raise ValueError('Fail to parse FieldMask: Path name '
|
||||
'{0} must not contain "_"s.'.format(path_name))
|
||||
if c.isupper():
|
||||
result += '_'
|
||||
result += c.lower()
|
||||
else:
|
||||
result += c
|
||||
return ''.join(result)
|
||||
|
||||
|
||||
class _FieldMaskTree(object):
|
||||
"""Represents a FieldMask in a tree structure.
|
||||
|
||||
For example, given a FieldMask "foo.bar,foo.baz,bar.baz",
|
||||
the FieldMaskTree will be:
|
||||
[_root] -+- foo -+- bar
|
||||
| |
|
||||
| +- baz
|
||||
|
|
||||
+- bar --- baz
|
||||
In the tree, each leaf node represents a field path.
|
||||
"""
|
||||
|
||||
__slots__ = ('_root',)
|
||||
|
||||
def __init__(self, field_mask=None):
|
||||
"""Initializes the tree by FieldMask."""
|
||||
self._root = {}
|
||||
if field_mask:
|
||||
self.MergeFromFieldMask(field_mask)
|
||||
|
||||
def MergeFromFieldMask(self, field_mask):
|
||||
"""Merges a FieldMask to the tree."""
|
||||
for path in field_mask.paths:
|
||||
self.AddPath(path)
|
||||
|
||||
def AddPath(self, path):
|
||||
"""Adds a field path into the tree.
|
||||
|
||||
If the field path to add is a sub-path of an existing field path
|
||||
in the tree (i.e., a leaf node), it means the tree already matches
|
||||
the given path so nothing will be added to the tree. If the path
|
||||
matches an existing non-leaf node in the tree, that non-leaf node
|
||||
will be turned into a leaf node with all its children removed because
|
||||
the path matches all the node's children. Otherwise, a new path will
|
||||
be added.
|
||||
|
||||
Args:
|
||||
path: The field path to add.
|
||||
"""
|
||||
node = self._root
|
||||
for name in path.split('.'):
|
||||
if name not in node:
|
||||
node[name] = {}
|
||||
elif not node[name]:
|
||||
# Pre-existing empty node implies we already have this entire tree.
|
||||
return
|
||||
node = node[name]
|
||||
# Remove any sub-trees we might have had.
|
||||
node.clear()
|
||||
|
||||
def ToFieldMask(self, field_mask):
|
||||
"""Converts the tree to a FieldMask."""
|
||||
field_mask.Clear()
|
||||
_AddFieldPaths(self._root, '', field_mask)
|
||||
|
||||
def IntersectPath(self, path, intersection):
|
||||
"""Calculates the intersection part of a field path with this tree.
|
||||
|
||||
Args:
|
||||
path: The field path to calculates.
|
||||
intersection: The out tree to record the intersection part.
|
||||
"""
|
||||
node = self._root
|
||||
for name in path.split('.'):
|
||||
if name not in node:
|
||||
return
|
||||
elif not node[name]:
|
||||
intersection.AddPath(path)
|
||||
return
|
||||
node = node[name]
|
||||
intersection.AddLeafNodes(path, node)
|
||||
|
||||
def AddLeafNodes(self, prefix, node):
|
||||
"""Adds leaf nodes begin with prefix to this tree."""
|
||||
if not node:
|
||||
self.AddPath(prefix)
|
||||
for name in node:
|
||||
child_path = prefix + '.' + name
|
||||
self.AddLeafNodes(child_path, node[name])
|
||||
|
||||
def MergeMessage(
|
||||
self, source, destination,
|
||||
replace_message, replace_repeated):
|
||||
"""Merge all fields specified by this tree from source to destination."""
|
||||
_MergeMessage(
|
||||
self._root, source, destination, replace_message, replace_repeated)
|
||||
|
||||
|
||||
def _StrConvert(value):
|
||||
"""Converts value to str if it is not."""
|
||||
# This file is imported by c extension and some methods like ClearField
|
||||
# requires string for the field name. py2/py3 has different text
|
||||
# type and may use unicode.
|
||||
if not isinstance(value, str):
|
||||
return value.encode('utf-8')
|
||||
return value
|
||||
|
||||
|
||||
def _MergeMessage(
|
||||
node, source, destination, replace_message, replace_repeated):
|
||||
"""Merge all fields specified by a sub-tree from source to destination."""
|
||||
source_descriptor = source.DESCRIPTOR
|
||||
for name in node:
|
||||
child = node[name]
|
||||
field = source_descriptor.fields_by_name[name]
|
||||
if field is None:
|
||||
raise ValueError('Error: Can\'t find field {0} in message {1}.'.format(
|
||||
name, source_descriptor.full_name))
|
||||
if child:
|
||||
# Sub-paths are only allowed for singular message fields.
|
||||
if (field.label == FieldDescriptor.LABEL_REPEATED or
|
||||
field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
|
||||
raise ValueError('Error: Field {0} in message {1} is not a singular '
|
||||
'message field and cannot have sub-fields.'.format(
|
||||
name, source_descriptor.full_name))
|
||||
if source.HasField(name):
|
||||
_MergeMessage(
|
||||
child, getattr(source, name), getattr(destination, name),
|
||||
replace_message, replace_repeated)
|
||||
continue
|
||||
if field.label == FieldDescriptor.LABEL_REPEATED:
|
||||
if replace_repeated:
|
||||
destination.ClearField(_StrConvert(name))
|
||||
repeated_source = getattr(source, name)
|
||||
repeated_destination = getattr(destination, name)
|
||||
repeated_destination.MergeFrom(repeated_source)
|
||||
else:
|
||||
if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
|
||||
if replace_message:
|
||||
destination.ClearField(_StrConvert(name))
|
||||
if source.HasField(name):
|
||||
getattr(destination, name).MergeFrom(getattr(source, name))
|
||||
else:
|
||||
setattr(destination, name, getattr(source, name))
|
||||
|
||||
|
||||
def _AddFieldPaths(node, prefix, field_mask):
|
||||
"""Adds the field paths descended from node to field_mask."""
|
||||
if not node and prefix:
|
||||
field_mask.paths.append(prefix)
|
||||
return
|
||||
for name in sorted(node):
|
||||
if prefix:
|
||||
child_path = prefix + '.' + name
|
||||
else:
|
||||
child_path = name
|
||||
_AddFieldPaths(node[name], child_path, field_mask)
|
||||
@@ -0,0 +1,55 @@
|
||||
# Protocol Buffers - Google's data interchange format
|
||||
# Copyright 2008 Google Inc. All rights reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style
|
||||
# license that can be found in the LICENSE file or at
|
||||
# https://developers.google.com/open-source/licenses/bsd
|
||||
|
||||
"""Defines a listener interface for observing certain
|
||||
state transitions on Message objects.
|
||||
|
||||
Also defines a null implementation of this interface.
|
||||
"""
|
||||
|
||||
__author__ = 'robinson@google.com (Will Robinson)'
|
||||
|
||||
|
||||
class MessageListener(object):
|
||||
|
||||
"""Listens for modifications made to a message. Meant to be registered via
|
||||
Message._SetListener().
|
||||
|
||||
Attributes:
|
||||
dirty: If True, then calling Modified() would be a no-op. This can be
|
||||
used to avoid these calls entirely in the common case.
|
||||
"""
|
||||
|
||||
def Modified(self):
|
||||
"""Called every time the message is modified in such a way that the parent
|
||||
message may need to be updated. This currently means either:
|
||||
(a) The message was modified for the first time, so the parent message
|
||||
should henceforth mark the message as present.
|
||||
(b) The message's cached byte size became dirty -- i.e. the message was
|
||||
modified for the first time after a previous call to ByteSize().
|
||||
Therefore the parent should also mark its byte size as dirty.
|
||||
Note that (a) implies (b), since new objects start out with a client cached
|
||||
size (zero). However, we document (a) explicitly because it is important.
|
||||
|
||||
Modified() will *only* be called in response to one of these two events --
|
||||
not every time the sub-message is modified.
|
||||
|
||||
Note that if the listener's |dirty| attribute is true, then calling
|
||||
Modified at the moment would be a no-op, so it can be skipped. Performance-
|
||||
sensitive callers should check this attribute directly before calling since
|
||||
it will be true most of the time.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class NullMessageListener(object):
|
||||
|
||||
"""No-op MessageListener implementation."""
|
||||
|
||||
def Modified(self):
|
||||
pass
|
||||
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
This file contains the serialized FeatureSetDefaults object corresponding to
|
||||
the Pure Python runtime. This is used for feature resolution under Editions.
|
||||
"""
|
||||
_PROTOBUF_INTERNAL_PYTHON_EDITION_DEFAULTS = b"\n\025\030\204\007\"\000*\016\010\001\020\002\030\002 \003(\0010\0028\002\n\025\030\347\007\"\000*\016\010\002\020\001\030\001 \002(\0010\0018\002\n\025\030\350\007\"\014\010\001\020\001\030\001 \002(\0010\001*\0028\002 \346\007(\350\007"
|
||||
1580
.venv/lib/python3.10/site-packages/google/protobuf/internal/python_message.py
Executable file
1580
.venv/lib/python3.10/site-packages/google/protobuf/internal/python_message.py
Executable file
File diff suppressed because it is too large
Load Diff
122
.venv/lib/python3.10/site-packages/google/protobuf/internal/testing_refleaks.py
Executable file
122
.venv/lib/python3.10/site-packages/google/protobuf/internal/testing_refleaks.py
Executable file
@@ -0,0 +1,122 @@
|
||||
# Protocol Buffers - Google's data interchange format
|
||||
# Copyright 2008 Google Inc. All rights reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style
|
||||
# license that can be found in the LICENSE file or at
|
||||
# https://developers.google.com/open-source/licenses/bsd
|
||||
|
||||
"""A subclass of unittest.TestCase which checks for reference leaks.
|
||||
|
||||
To use:
|
||||
- Use testing_refleak.BaseTestCase instead of unittest.TestCase
|
||||
- Configure and compile Python with --with-pydebug
|
||||
|
||||
If sys.gettotalrefcount() is not available (because Python was built without
|
||||
the Py_DEBUG option), then this module is a no-op and tests will run normally.
|
||||
"""
|
||||
|
||||
import copyreg
|
||||
import gc
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
|
||||
class LocalTestResult(unittest.TestResult):
|
||||
"""A TestResult which forwards events to a parent object, except for Skips."""
|
||||
|
||||
def __init__(self, parent_result):
|
||||
unittest.TestResult.__init__(self)
|
||||
self.parent_result = parent_result
|
||||
|
||||
def addError(self, test, error):
|
||||
self.parent_result.addError(test, error)
|
||||
|
||||
def addFailure(self, test, error):
|
||||
self.parent_result.addFailure(test, error)
|
||||
|
||||
def addSkip(self, test, reason):
|
||||
pass
|
||||
|
||||
def addDuration(self, test, duration):
|
||||
pass
|
||||
|
||||
|
||||
class ReferenceLeakCheckerMixin(object):
|
||||
"""A mixin class for TestCase, which checks reference counts."""
|
||||
|
||||
NB_RUNS = 3
|
||||
|
||||
def run(self, result=None):
|
||||
testMethod = getattr(self, self._testMethodName)
|
||||
expecting_failure_method = getattr(testMethod, "__unittest_expecting_failure__", False)
|
||||
expecting_failure_class = getattr(self, "__unittest_expecting_failure__", False)
|
||||
if expecting_failure_class or expecting_failure_method:
|
||||
return
|
||||
|
||||
# python_message.py registers all Message classes to some pickle global
|
||||
# registry, which makes the classes immortal.
|
||||
# We save a copy of this registry, and reset it before we could references.
|
||||
self._saved_pickle_registry = copyreg.dispatch_table.copy()
|
||||
|
||||
# Run the test twice, to warm up the instance attributes.
|
||||
super(ReferenceLeakCheckerMixin, self).run(result=result)
|
||||
super(ReferenceLeakCheckerMixin, self).run(result=result)
|
||||
|
||||
oldrefcount = 0 # pylint: disable=unused-variable but needed for refcounts.
|
||||
local_result = LocalTestResult(result)
|
||||
num_flakes = 0
|
||||
|
||||
refcount_deltas = []
|
||||
while len(refcount_deltas) < self.NB_RUNS:
|
||||
oldrefcount = self._getRefcounts()
|
||||
super(ReferenceLeakCheckerMixin, self).run(result=local_result)
|
||||
newrefcount = self._getRefcounts()
|
||||
# If the GC was able to collect some objects after the call to run() that
|
||||
# it could not collect before the call, then the counts won't match.
|
||||
if newrefcount < oldrefcount and num_flakes < 2:
|
||||
# This result is (probably) a flake -- garbage collectors aren't very
|
||||
# predictable, but a lower ending refcount is the opposite of the
|
||||
# failure we are testing for. If the result is repeatable, then we will
|
||||
# eventually report it, but not after trying to eliminate it.
|
||||
num_flakes += 1
|
||||
continue
|
||||
num_flakes = 0
|
||||
refcount_deltas.append(newrefcount - oldrefcount)
|
||||
print(refcount_deltas, self)
|
||||
|
||||
try:
|
||||
self.assertEqual(refcount_deltas, [0] * self.NB_RUNS)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
result.addError(self, sys.exc_info())
|
||||
|
||||
def _getRefcounts(self):
|
||||
copyreg.dispatch_table.clear()
|
||||
copyreg.dispatch_table.update(self._saved_pickle_registry)
|
||||
# It is sometimes necessary to gc.collect() multiple times, to ensure
|
||||
# that all objects can be collected.
|
||||
gc.collect()
|
||||
gc.collect()
|
||||
gc.collect()
|
||||
return sys.gettotalrefcount()
|
||||
|
||||
|
||||
if hasattr(sys, 'gettotalrefcount'):
|
||||
|
||||
def TestCase(test_class):
|
||||
new_bases = (ReferenceLeakCheckerMixin,) + test_class.__bases__
|
||||
new_class = type(test_class)(
|
||||
test_class.__name__, new_bases, dict(test_class.__dict__))
|
||||
return new_class
|
||||
SkipReferenceLeakChecker = unittest.skip
|
||||
|
||||
else:
|
||||
# When PyDEBUG is not enabled, run the tests normally.
|
||||
|
||||
def TestCase(test_class):
|
||||
return test_class
|
||||
|
||||
def SkipReferenceLeakChecker(reason):
|
||||
del reason # Don't skip, so don't need a reason.
|
||||
def Same(func):
|
||||
return func
|
||||
return Same
|
||||
413
.venv/lib/python3.10/site-packages/google/protobuf/internal/type_checkers.py
Executable file
413
.venv/lib/python3.10/site-packages/google/protobuf/internal/type_checkers.py
Executable file
@@ -0,0 +1,413 @@
|
||||
# Protocol Buffers - Google's data interchange format
|
||||
# Copyright 2008 Google Inc. All rights reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style
|
||||
# license that can be found in the LICENSE file or at
|
||||
# https://developers.google.com/open-source/licenses/bsd
|
||||
|
||||
"""Provides type checking routines.
|
||||
|
||||
This module defines type checking utilities in the forms of dictionaries:
|
||||
|
||||
VALUE_CHECKERS: A dictionary of field types and a value validation object.
|
||||
TYPE_TO_BYTE_SIZE_FN: A dictionary with field types and a size computing
|
||||
function.
|
||||
TYPE_TO_SERIALIZE_METHOD: A dictionary with field types and serialization
|
||||
function.
|
||||
FIELD_TYPE_TO_WIRE_TYPE: A dictionary with field typed and their
|
||||
corresponding wire types.
|
||||
TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization
|
||||
function.
|
||||
"""
|
||||
|
||||
__author__ = 'robinson@google.com (Will Robinson)'
|
||||
|
||||
import struct
|
||||
import numbers
|
||||
|
||||
from google.protobuf.internal import decoder
|
||||
from google.protobuf.internal import encoder
|
||||
from google.protobuf.internal import wire_format
|
||||
from google.protobuf import descriptor
|
||||
|
||||
_FieldDescriptor = descriptor.FieldDescriptor
|
||||
|
||||
|
||||
def TruncateToFourByteFloat(original):
|
||||
return struct.unpack('<f', struct.pack('<f', original))[0]
|
||||
|
||||
|
||||
def ToShortestFloat(original):
|
||||
"""Returns the shortest float that has same value in wire."""
|
||||
# All 4 byte floats have between 6 and 9 significant digits, so we
|
||||
# start with 6 as the lower bound.
|
||||
# It has to be iterative because use '.9g' directly can not get rid
|
||||
# of the noises for most values. For example if set a float_field=0.9
|
||||
# use '.9g' will print 0.899999976.
|
||||
precision = 6
|
||||
rounded = float('{0:.{1}g}'.format(original, precision))
|
||||
while TruncateToFourByteFloat(rounded) != original:
|
||||
precision += 1
|
||||
rounded = float('{0:.{1}g}'.format(original, precision))
|
||||
return rounded
|
||||
|
||||
|
||||
def GetTypeChecker(field):
|
||||
"""Returns a type checker for a message field of the specified types.
|
||||
|
||||
Args:
|
||||
field: FieldDescriptor object for this field.
|
||||
|
||||
Returns:
|
||||
An instance of TypeChecker which can be used to verify the types
|
||||
of values assigned to a field of the specified type.
|
||||
"""
|
||||
if (field.cpp_type == _FieldDescriptor.CPPTYPE_STRING and
|
||||
field.type == _FieldDescriptor.TYPE_STRING):
|
||||
return UnicodeValueChecker()
|
||||
if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
|
||||
if field.enum_type.is_closed:
|
||||
return EnumValueChecker(field.enum_type)
|
||||
else:
|
||||
# When open enums are supported, any int32 can be assigned.
|
||||
return _VALUE_CHECKERS[_FieldDescriptor.CPPTYPE_INT32]
|
||||
return _VALUE_CHECKERS[field.cpp_type]
|
||||
|
||||
|
||||
# None of the typecheckers below make any attempt to guard against people
|
||||
# subclassing builtin types and doing weird things. We're not trying to
|
||||
# protect against malicious clients here, just people accidentally shooting
|
||||
# themselves in the foot in obvious ways.
|
||||
class TypeChecker(object):
|
||||
|
||||
"""Type checker used to catch type errors as early as possible
|
||||
when the client is setting scalar fields in protocol messages.
|
||||
"""
|
||||
|
||||
def __init__(self, *acceptable_types):
|
||||
self._acceptable_types = acceptable_types
|
||||
|
||||
def CheckValue(self, proposed_value):
|
||||
"""Type check the provided value and return it.
|
||||
|
||||
The returned value might have been normalized to another type.
|
||||
"""
|
||||
if not isinstance(proposed_value, self._acceptable_types):
|
||||
message = ('%.1024r has type %s, but expected one of: %s' %
|
||||
(proposed_value, type(proposed_value), self._acceptable_types))
|
||||
raise TypeError(message)
|
||||
return proposed_value
|
||||
|
||||
|
||||
class TypeCheckerWithDefault(TypeChecker):
|
||||
|
||||
def __init__(self, default_value, *acceptable_types):
|
||||
TypeChecker.__init__(self, *acceptable_types)
|
||||
self._default_value = default_value
|
||||
|
||||
def DefaultValue(self):
|
||||
return self._default_value
|
||||
|
||||
|
||||
class BoolValueChecker(object):
|
||||
"""Type checker used for bool fields."""
|
||||
|
||||
def CheckValue(self, proposed_value):
|
||||
if not hasattr(proposed_value, '__index__') or (
|
||||
type(proposed_value).__module__ == 'numpy' and
|
||||
type(proposed_value).__name__ == 'ndarray'):
|
||||
message = ('%.1024r has type %s, but expected one of: %s' %
|
||||
(proposed_value, type(proposed_value), (bool, int)))
|
||||
raise TypeError(message)
|
||||
return bool(proposed_value)
|
||||
|
||||
def DefaultValue(self):
|
||||
return False
|
||||
|
||||
|
||||
# IntValueChecker and its subclasses perform integer type-checks
|
||||
# and bounds-checks.
|
||||
class IntValueChecker(object):
|
||||
|
||||
"""Checker used for integer fields. Performs type-check and range check."""
|
||||
|
||||
def CheckValue(self, proposed_value):
|
||||
if not hasattr(proposed_value, '__index__') or (
|
||||
type(proposed_value).__module__ == 'numpy' and
|
||||
type(proposed_value).__name__ == 'ndarray'):
|
||||
message = ('%.1024r has type %s, but expected one of: %s' %
|
||||
(proposed_value, type(proposed_value), (int,)))
|
||||
raise TypeError(message)
|
||||
|
||||
if not self._MIN <= int(proposed_value) <= self._MAX:
|
||||
raise ValueError('Value out of range: %d' % proposed_value)
|
||||
# We force all values to int to make alternate implementations where the
|
||||
# distinction is more significant (e.g. the C++ implementation) simpler.
|
||||
proposed_value = int(proposed_value)
|
||||
return proposed_value
|
||||
|
||||
def DefaultValue(self):
|
||||
return 0
|
||||
|
||||
|
||||
class EnumValueChecker(object):
|
||||
|
||||
"""Checker used for enum fields. Performs type-check and range check."""
|
||||
|
||||
def __init__(self, enum_type):
|
||||
self._enum_type = enum_type
|
||||
|
||||
def CheckValue(self, proposed_value):
|
||||
if not isinstance(proposed_value, numbers.Integral):
|
||||
message = ('%.1024r has type %s, but expected one of: %s' %
|
||||
(proposed_value, type(proposed_value), (int,)))
|
||||
raise TypeError(message)
|
||||
if int(proposed_value) not in self._enum_type.values_by_number:
|
||||
raise ValueError('Unknown enum value: %d' % proposed_value)
|
||||
return proposed_value
|
||||
|
||||
def DefaultValue(self):
|
||||
return self._enum_type.values[0].number
|
||||
|
||||
|
||||
class UnicodeValueChecker(object):
|
||||
|
||||
"""Checker used for string fields.
|
||||
|
||||
Always returns a unicode value, even if the input is of type str.
|
||||
"""
|
||||
|
||||
def CheckValue(self, proposed_value):
|
||||
if not isinstance(proposed_value, (bytes, str)):
|
||||
message = ('%.1024r has type %s, but expected one of: %s' %
|
||||
(proposed_value, type(proposed_value), (bytes, str)))
|
||||
raise TypeError(message)
|
||||
|
||||
# If the value is of type 'bytes' make sure that it is valid UTF-8 data.
|
||||
if isinstance(proposed_value, bytes):
|
||||
try:
|
||||
proposed_value = proposed_value.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
raise ValueError('%.1024r has type bytes, but isn\'t valid UTF-8 '
|
||||
'encoding. Non-UTF-8 strings must be converted to '
|
||||
'unicode objects before being added.' %
|
||||
(proposed_value))
|
||||
else:
|
||||
try:
|
||||
proposed_value.encode('utf8')
|
||||
except UnicodeEncodeError:
|
||||
raise ValueError('%.1024r isn\'t a valid unicode string and '
|
||||
'can\'t be encoded in UTF-8.'%
|
||||
(proposed_value))
|
||||
|
||||
return proposed_value
|
||||
|
||||
def DefaultValue(self):
|
||||
return u""
|
||||
|
||||
|
||||
class Int32ValueChecker(IntValueChecker):
|
||||
# We're sure to use ints instead of longs here since comparison may be more
|
||||
# efficient.
|
||||
_MIN = -2147483648
|
||||
_MAX = 2147483647
|
||||
|
||||
|
||||
class Uint32ValueChecker(IntValueChecker):
|
||||
_MIN = 0
|
||||
_MAX = (1 << 32) - 1
|
||||
|
||||
|
||||
class Int64ValueChecker(IntValueChecker):
|
||||
_MIN = -(1 << 63)
|
||||
_MAX = (1 << 63) - 1
|
||||
|
||||
|
||||
class Uint64ValueChecker(IntValueChecker):
|
||||
_MIN = 0
|
||||
_MAX = (1 << 64) - 1
|
||||
|
||||
|
||||
# The max 4 bytes float is about 3.4028234663852886e+38
|
||||
_FLOAT_MAX = float.fromhex('0x1.fffffep+127')
|
||||
_FLOAT_MIN = -_FLOAT_MAX
|
||||
_MAX_FLOAT_AS_DOUBLE_ROUNDED = 3.4028235677973366e38
|
||||
_INF = float('inf')
|
||||
_NEG_INF = float('-inf')
|
||||
|
||||
|
||||
class DoubleValueChecker(object):
|
||||
"""Checker used for double fields.
|
||||
|
||||
Performs type-check and range check.
|
||||
"""
|
||||
|
||||
def CheckValue(self, proposed_value):
|
||||
"""Check and convert proposed_value to float."""
|
||||
if (not hasattr(proposed_value, '__float__') and
|
||||
not hasattr(proposed_value, '__index__')) or (
|
||||
type(proposed_value).__module__ == 'numpy' and
|
||||
type(proposed_value).__name__ == 'ndarray'):
|
||||
message = ('%.1024r has type %s, but expected one of: int, float' %
|
||||
(proposed_value, type(proposed_value)))
|
||||
raise TypeError(message)
|
||||
return float(proposed_value)
|
||||
|
||||
def DefaultValue(self):
|
||||
return 0.0
|
||||
|
||||
|
||||
class FloatValueChecker(DoubleValueChecker):
|
||||
"""Checker used for float fields.
|
||||
|
||||
Performs type-check and range check.
|
||||
|
||||
Values exceeding a 32-bit float will be converted to inf/-inf.
|
||||
"""
|
||||
|
||||
def CheckValue(self, proposed_value):
|
||||
"""Check and convert proposed_value to float."""
|
||||
converted_value = super().CheckValue(proposed_value)
|
||||
# This inf rounding matches the C++ proto SafeDoubleToFloat logic.
|
||||
if converted_value > _FLOAT_MAX:
|
||||
if converted_value <= _MAX_FLOAT_AS_DOUBLE_ROUNDED:
|
||||
return _FLOAT_MAX
|
||||
return _INF
|
||||
if converted_value < _FLOAT_MIN:
|
||||
if converted_value >= -_MAX_FLOAT_AS_DOUBLE_ROUNDED:
|
||||
return _FLOAT_MIN
|
||||
return _NEG_INF
|
||||
|
||||
return TruncateToFourByteFloat(converted_value)
|
||||
|
||||
# Type-checkers for all scalar CPPTYPEs.
|
||||
_VALUE_CHECKERS = {
|
||||
_FieldDescriptor.CPPTYPE_INT32: Int32ValueChecker(),
|
||||
_FieldDescriptor.CPPTYPE_INT64: Int64ValueChecker(),
|
||||
_FieldDescriptor.CPPTYPE_UINT32: Uint32ValueChecker(),
|
||||
_FieldDescriptor.CPPTYPE_UINT64: Uint64ValueChecker(),
|
||||
_FieldDescriptor.CPPTYPE_DOUBLE: DoubleValueChecker(),
|
||||
_FieldDescriptor.CPPTYPE_FLOAT: FloatValueChecker(),
|
||||
_FieldDescriptor.CPPTYPE_BOOL: BoolValueChecker(),
|
||||
_FieldDescriptor.CPPTYPE_STRING: TypeCheckerWithDefault(b'', bytes),
|
||||
}
|
||||
|
||||
|
||||
# Map from field type to a function F, such that F(field_num, value)
|
||||
# gives the total byte size for a value of the given type. This
|
||||
# byte size includes tag information and any other additional space
|
||||
# associated with serializing "value".
|
||||
TYPE_TO_BYTE_SIZE_FN = {
|
||||
_FieldDescriptor.TYPE_DOUBLE: wire_format.DoubleByteSize,
|
||||
_FieldDescriptor.TYPE_FLOAT: wire_format.FloatByteSize,
|
||||
_FieldDescriptor.TYPE_INT64: wire_format.Int64ByteSize,
|
||||
_FieldDescriptor.TYPE_UINT64: wire_format.UInt64ByteSize,
|
||||
_FieldDescriptor.TYPE_INT32: wire_format.Int32ByteSize,
|
||||
_FieldDescriptor.TYPE_FIXED64: wire_format.Fixed64ByteSize,
|
||||
_FieldDescriptor.TYPE_FIXED32: wire_format.Fixed32ByteSize,
|
||||
_FieldDescriptor.TYPE_BOOL: wire_format.BoolByteSize,
|
||||
_FieldDescriptor.TYPE_STRING: wire_format.StringByteSize,
|
||||
_FieldDescriptor.TYPE_GROUP: wire_format.GroupByteSize,
|
||||
_FieldDescriptor.TYPE_MESSAGE: wire_format.MessageByteSize,
|
||||
_FieldDescriptor.TYPE_BYTES: wire_format.BytesByteSize,
|
||||
_FieldDescriptor.TYPE_UINT32: wire_format.UInt32ByteSize,
|
||||
_FieldDescriptor.TYPE_ENUM: wire_format.EnumByteSize,
|
||||
_FieldDescriptor.TYPE_SFIXED32: wire_format.SFixed32ByteSize,
|
||||
_FieldDescriptor.TYPE_SFIXED64: wire_format.SFixed64ByteSize,
|
||||
_FieldDescriptor.TYPE_SINT32: wire_format.SInt32ByteSize,
|
||||
_FieldDescriptor.TYPE_SINT64: wire_format.SInt64ByteSize
|
||||
}
|
||||
|
||||
|
||||
# Maps from field types to encoder constructors.
|
||||
TYPE_TO_ENCODER = {
|
||||
_FieldDescriptor.TYPE_DOUBLE: encoder.DoubleEncoder,
|
||||
_FieldDescriptor.TYPE_FLOAT: encoder.FloatEncoder,
|
||||
_FieldDescriptor.TYPE_INT64: encoder.Int64Encoder,
|
||||
_FieldDescriptor.TYPE_UINT64: encoder.UInt64Encoder,
|
||||
_FieldDescriptor.TYPE_INT32: encoder.Int32Encoder,
|
||||
_FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Encoder,
|
||||
_FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Encoder,
|
||||
_FieldDescriptor.TYPE_BOOL: encoder.BoolEncoder,
|
||||
_FieldDescriptor.TYPE_STRING: encoder.StringEncoder,
|
||||
_FieldDescriptor.TYPE_GROUP: encoder.GroupEncoder,
|
||||
_FieldDescriptor.TYPE_MESSAGE: encoder.MessageEncoder,
|
||||
_FieldDescriptor.TYPE_BYTES: encoder.BytesEncoder,
|
||||
_FieldDescriptor.TYPE_UINT32: encoder.UInt32Encoder,
|
||||
_FieldDescriptor.TYPE_ENUM: encoder.EnumEncoder,
|
||||
_FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Encoder,
|
||||
_FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Encoder,
|
||||
_FieldDescriptor.TYPE_SINT32: encoder.SInt32Encoder,
|
||||
_FieldDescriptor.TYPE_SINT64: encoder.SInt64Encoder,
|
||||
}
|
||||
|
||||
|
||||
# Maps from field types to sizer constructors.
|
||||
TYPE_TO_SIZER = {
|
||||
_FieldDescriptor.TYPE_DOUBLE: encoder.DoubleSizer,
|
||||
_FieldDescriptor.TYPE_FLOAT: encoder.FloatSizer,
|
||||
_FieldDescriptor.TYPE_INT64: encoder.Int64Sizer,
|
||||
_FieldDescriptor.TYPE_UINT64: encoder.UInt64Sizer,
|
||||
_FieldDescriptor.TYPE_INT32: encoder.Int32Sizer,
|
||||
_FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Sizer,
|
||||
_FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Sizer,
|
||||
_FieldDescriptor.TYPE_BOOL: encoder.BoolSizer,
|
||||
_FieldDescriptor.TYPE_STRING: encoder.StringSizer,
|
||||
_FieldDescriptor.TYPE_GROUP: encoder.GroupSizer,
|
||||
_FieldDescriptor.TYPE_MESSAGE: encoder.MessageSizer,
|
||||
_FieldDescriptor.TYPE_BYTES: encoder.BytesSizer,
|
||||
_FieldDescriptor.TYPE_UINT32: encoder.UInt32Sizer,
|
||||
_FieldDescriptor.TYPE_ENUM: encoder.EnumSizer,
|
||||
_FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Sizer,
|
||||
_FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Sizer,
|
||||
_FieldDescriptor.TYPE_SINT32: encoder.SInt32Sizer,
|
||||
_FieldDescriptor.TYPE_SINT64: encoder.SInt64Sizer,
|
||||
}
|
||||
|
||||
|
||||
# Maps from field type to a decoder constructor.
|
||||
TYPE_TO_DECODER = {
|
||||
_FieldDescriptor.TYPE_DOUBLE: decoder.DoubleDecoder,
|
||||
_FieldDescriptor.TYPE_FLOAT: decoder.FloatDecoder,
|
||||
_FieldDescriptor.TYPE_INT64: decoder.Int64Decoder,
|
||||
_FieldDescriptor.TYPE_UINT64: decoder.UInt64Decoder,
|
||||
_FieldDescriptor.TYPE_INT32: decoder.Int32Decoder,
|
||||
_FieldDescriptor.TYPE_FIXED64: decoder.Fixed64Decoder,
|
||||
_FieldDescriptor.TYPE_FIXED32: decoder.Fixed32Decoder,
|
||||
_FieldDescriptor.TYPE_BOOL: decoder.BoolDecoder,
|
||||
_FieldDescriptor.TYPE_STRING: decoder.StringDecoder,
|
||||
_FieldDescriptor.TYPE_GROUP: decoder.GroupDecoder,
|
||||
_FieldDescriptor.TYPE_MESSAGE: decoder.MessageDecoder,
|
||||
_FieldDescriptor.TYPE_BYTES: decoder.BytesDecoder,
|
||||
_FieldDescriptor.TYPE_UINT32: decoder.UInt32Decoder,
|
||||
_FieldDescriptor.TYPE_ENUM: decoder.EnumDecoder,
|
||||
_FieldDescriptor.TYPE_SFIXED32: decoder.SFixed32Decoder,
|
||||
_FieldDescriptor.TYPE_SFIXED64: decoder.SFixed64Decoder,
|
||||
_FieldDescriptor.TYPE_SINT32: decoder.SInt32Decoder,
|
||||
_FieldDescriptor.TYPE_SINT64: decoder.SInt64Decoder,
|
||||
}
|
||||
|
||||
# Maps from field type to expected wiretype.
|
||||
FIELD_TYPE_TO_WIRE_TYPE = {
|
||||
_FieldDescriptor.TYPE_DOUBLE: wire_format.WIRETYPE_FIXED64,
|
||||
_FieldDescriptor.TYPE_FLOAT: wire_format.WIRETYPE_FIXED32,
|
||||
_FieldDescriptor.TYPE_INT64: wire_format.WIRETYPE_VARINT,
|
||||
_FieldDescriptor.TYPE_UINT64: wire_format.WIRETYPE_VARINT,
|
||||
_FieldDescriptor.TYPE_INT32: wire_format.WIRETYPE_VARINT,
|
||||
_FieldDescriptor.TYPE_FIXED64: wire_format.WIRETYPE_FIXED64,
|
||||
_FieldDescriptor.TYPE_FIXED32: wire_format.WIRETYPE_FIXED32,
|
||||
_FieldDescriptor.TYPE_BOOL: wire_format.WIRETYPE_VARINT,
|
||||
_FieldDescriptor.TYPE_STRING:
|
||||
wire_format.WIRETYPE_LENGTH_DELIMITED,
|
||||
_FieldDescriptor.TYPE_GROUP: wire_format.WIRETYPE_START_GROUP,
|
||||
_FieldDescriptor.TYPE_MESSAGE:
|
||||
wire_format.WIRETYPE_LENGTH_DELIMITED,
|
||||
_FieldDescriptor.TYPE_BYTES:
|
||||
wire_format.WIRETYPE_LENGTH_DELIMITED,
|
||||
_FieldDescriptor.TYPE_UINT32: wire_format.WIRETYPE_VARINT,
|
||||
_FieldDescriptor.TYPE_ENUM: wire_format.WIRETYPE_VARINT,
|
||||
_FieldDescriptor.TYPE_SFIXED32: wire_format.WIRETYPE_FIXED32,
|
||||
_FieldDescriptor.TYPE_SFIXED64: wire_format.WIRETYPE_FIXED64,
|
||||
_FieldDescriptor.TYPE_SINT32: wire_format.WIRETYPE_VARINT,
|
||||
_FieldDescriptor.TYPE_SINT64: wire_format.WIRETYPE_VARINT,
|
||||
}
|
||||
678
.venv/lib/python3.10/site-packages/google/protobuf/internal/well_known_types.py
Executable file
678
.venv/lib/python3.10/site-packages/google/protobuf/internal/well_known_types.py
Executable file
@@ -0,0 +1,678 @@
|
||||
# Protocol Buffers - Google's data interchange format
|
||||
# Copyright 2008 Google Inc. All rights reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style
|
||||
# license that can be found in the LICENSE file or at
|
||||
# https://developers.google.com/open-source/licenses/bsd
|
||||
|
||||
"""Contains well known classes.
|
||||
|
||||
This files defines well known classes which need extra maintenance including:
|
||||
- Any
|
||||
- Duration
|
||||
- FieldMask
|
||||
- Struct
|
||||
- Timestamp
|
||||
"""
|
||||
|
||||
__author__ = 'jieluo@google.com (Jie Luo)'
|
||||
|
||||
import calendar
|
||||
import collections.abc
|
||||
import datetime
|
||||
import warnings
|
||||
from google.protobuf.internal import field_mask
|
||||
from typing import Union
|
||||
|
||||
FieldMask = field_mask.FieldMask
|
||||
|
||||
_TIMESTAMPFORMAT = '%Y-%m-%dT%H:%M:%S'
|
||||
_NANOS_PER_SECOND = 1000000000
|
||||
_NANOS_PER_MILLISECOND = 1000000
|
||||
_NANOS_PER_MICROSECOND = 1000
|
||||
_MILLIS_PER_SECOND = 1000
|
||||
_MICROS_PER_SECOND = 1000000
|
||||
_SECONDS_PER_DAY = 24 * 3600
|
||||
_DURATION_SECONDS_MAX = 315576000000
|
||||
_TIMESTAMP_SECONDS_MIN = -62135596800
|
||||
_TIMESTAMP_SECONDS_MAX = 253402300799
|
||||
|
||||
_EPOCH_DATETIME_NAIVE = datetime.datetime(1970, 1, 1, tzinfo=None)
|
||||
_EPOCH_DATETIME_AWARE = _EPOCH_DATETIME_NAIVE.replace(
|
||||
tzinfo=datetime.timezone.utc
|
||||
)
|
||||
|
||||
|
||||
class Any(object):
|
||||
"""Class for Any Message type."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def Pack(self, msg, type_url_prefix='type.googleapis.com/',
|
||||
deterministic=None):
|
||||
"""Packs the specified message into current Any message."""
|
||||
if len(type_url_prefix) < 1 or type_url_prefix[-1] != '/':
|
||||
self.type_url = '%s/%s' % (type_url_prefix, msg.DESCRIPTOR.full_name)
|
||||
else:
|
||||
self.type_url = '%s%s' % (type_url_prefix, msg.DESCRIPTOR.full_name)
|
||||
self.value = msg.SerializeToString(deterministic=deterministic)
|
||||
|
||||
def Unpack(self, msg):
|
||||
"""Unpacks the current Any message into specified message."""
|
||||
descriptor = msg.DESCRIPTOR
|
||||
if not self.Is(descriptor):
|
||||
return False
|
||||
msg.ParseFromString(self.value)
|
||||
return True
|
||||
|
||||
def TypeName(self):
|
||||
"""Returns the protobuf type name of the inner message."""
|
||||
# Only last part is to be used: b/25630112
|
||||
return self.type_url.rpartition('/')[2]
|
||||
|
||||
def Is(self, descriptor):
|
||||
"""Checks if this Any represents the given protobuf type."""
|
||||
return '/' in self.type_url and self.TypeName() == descriptor.full_name
|
||||
|
||||
|
||||
class Timestamp(object):
|
||||
"""Class for Timestamp message type."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def ToJsonString(self):
|
||||
"""Converts Timestamp to RFC 3339 date string format.
|
||||
|
||||
Returns:
|
||||
A string converted from timestamp. The string is always Z-normalized
|
||||
and uses 3, 6 or 9 fractional digits as required to represent the
|
||||
exact time. Example of the return format: '1972-01-01T10:00:20.021Z'
|
||||
"""
|
||||
_CheckTimestampValid(self.seconds, self.nanos)
|
||||
nanos = self.nanos
|
||||
seconds = self.seconds % _SECONDS_PER_DAY
|
||||
days = (self.seconds - seconds) // _SECONDS_PER_DAY
|
||||
dt = datetime.datetime(1970, 1, 1) + datetime.timedelta(days, seconds)
|
||||
|
||||
result = dt.isoformat()
|
||||
if (nanos % 1e9) == 0:
|
||||
# If there are 0 fractional digits, the fractional
|
||||
# point '.' should be omitted when serializing.
|
||||
return result + 'Z'
|
||||
if (nanos % 1e6) == 0:
|
||||
# Serialize 3 fractional digits.
|
||||
return result + '.%03dZ' % (nanos / 1e6)
|
||||
if (nanos % 1e3) == 0:
|
||||
# Serialize 6 fractional digits.
|
||||
return result + '.%06dZ' % (nanos / 1e3)
|
||||
# Serialize 9 fractional digits.
|
||||
return result + '.%09dZ' % nanos
|
||||
|
||||
def FromJsonString(self, value):
|
||||
"""Parse a RFC 3339 date string format to Timestamp.
|
||||
|
||||
Args:
|
||||
value: A date string. Any fractional digits (or none) and any offset are
|
||||
accepted as long as they fit into nano-seconds precision.
|
||||
Example of accepted format: '1972-01-01T10:00:20.021-05:00'
|
||||
|
||||
Raises:
|
||||
ValueError: On parsing problems.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
raise ValueError('Timestamp JSON value not a string: {!r}'.format(value))
|
||||
timezone_offset = value.find('Z')
|
||||
if timezone_offset == -1:
|
||||
timezone_offset = value.find('+')
|
||||
if timezone_offset == -1:
|
||||
timezone_offset = value.rfind('-')
|
||||
if timezone_offset == -1:
|
||||
raise ValueError(
|
||||
'Failed to parse timestamp: missing valid timezone offset.')
|
||||
time_value = value[0:timezone_offset]
|
||||
# Parse datetime and nanos.
|
||||
point_position = time_value.find('.')
|
||||
if point_position == -1:
|
||||
second_value = time_value
|
||||
nano_value = ''
|
||||
else:
|
||||
second_value = time_value[:point_position]
|
||||
nano_value = time_value[point_position + 1:]
|
||||
if 't' in second_value:
|
||||
raise ValueError(
|
||||
'time data \'{0}\' does not match format \'%Y-%m-%dT%H:%M:%S\', '
|
||||
'lowercase \'t\' is not accepted'.format(second_value))
|
||||
date_object = datetime.datetime.strptime(second_value, _TIMESTAMPFORMAT)
|
||||
td = date_object - datetime.datetime(1970, 1, 1)
|
||||
seconds = td.seconds + td.days * _SECONDS_PER_DAY
|
||||
if len(nano_value) > 9:
|
||||
raise ValueError(
|
||||
'Failed to parse Timestamp: nanos {0} more than '
|
||||
'9 fractional digits.'.format(nano_value))
|
||||
if nano_value:
|
||||
nanos = round(float('0.' + nano_value) * 1e9)
|
||||
else:
|
||||
nanos = 0
|
||||
# Parse timezone offsets.
|
||||
if value[timezone_offset] == 'Z':
|
||||
if len(value) != timezone_offset + 1:
|
||||
raise ValueError('Failed to parse timestamp: invalid trailing'
|
||||
' data {0}.'.format(value))
|
||||
else:
|
||||
timezone = value[timezone_offset:]
|
||||
pos = timezone.find(':')
|
||||
if pos == -1:
|
||||
raise ValueError(
|
||||
'Invalid timezone offset value: {0}.'.format(timezone))
|
||||
if timezone[0] == '+':
|
||||
seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
|
||||
else:
|
||||
seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
|
||||
# Set seconds and nanos
|
||||
_CheckTimestampValid(seconds, nanos)
|
||||
self.seconds = int(seconds)
|
||||
self.nanos = int(nanos)
|
||||
|
||||
def GetCurrentTime(self):
|
||||
"""Get the current UTC into Timestamp."""
|
||||
self.FromDatetime(datetime.datetime.utcnow())
|
||||
|
||||
def ToNanoseconds(self):
|
||||
"""Converts Timestamp to nanoseconds since epoch."""
|
||||
_CheckTimestampValid(self.seconds, self.nanos)
|
||||
return self.seconds * _NANOS_PER_SECOND + self.nanos
|
||||
|
||||
def ToMicroseconds(self):
|
||||
"""Converts Timestamp to microseconds since epoch."""
|
||||
_CheckTimestampValid(self.seconds, self.nanos)
|
||||
return (self.seconds * _MICROS_PER_SECOND +
|
||||
self.nanos // _NANOS_PER_MICROSECOND)
|
||||
|
||||
def ToMilliseconds(self):
|
||||
"""Converts Timestamp to milliseconds since epoch."""
|
||||
_CheckTimestampValid(self.seconds, self.nanos)
|
||||
return (self.seconds * _MILLIS_PER_SECOND +
|
||||
self.nanos // _NANOS_PER_MILLISECOND)
|
||||
|
||||
def ToSeconds(self):
|
||||
"""Converts Timestamp to seconds since epoch."""
|
||||
_CheckTimestampValid(self.seconds, self.nanos)
|
||||
return self.seconds
|
||||
|
||||
def FromNanoseconds(self, nanos):
|
||||
"""Converts nanoseconds since epoch to Timestamp."""
|
||||
seconds = nanos // _NANOS_PER_SECOND
|
||||
nanos = nanos % _NANOS_PER_SECOND
|
||||
_CheckTimestampValid(seconds, nanos)
|
||||
self.seconds = seconds
|
||||
self.nanos = nanos
|
||||
|
||||
def FromMicroseconds(self, micros):
|
||||
"""Converts microseconds since epoch to Timestamp."""
|
||||
seconds = micros // _MICROS_PER_SECOND
|
||||
nanos = (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND
|
||||
_CheckTimestampValid(seconds, nanos)
|
||||
self.seconds = seconds
|
||||
self.nanos = nanos
|
||||
|
||||
def FromMilliseconds(self, millis):
|
||||
"""Converts milliseconds since epoch to Timestamp."""
|
||||
seconds = millis // _MILLIS_PER_SECOND
|
||||
nanos = (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND
|
||||
_CheckTimestampValid(seconds, nanos)
|
||||
self.seconds = seconds
|
||||
self.nanos = nanos
|
||||
|
||||
def FromSeconds(self, seconds):
|
||||
"""Converts seconds since epoch to Timestamp."""
|
||||
_CheckTimestampValid(seconds, 0)
|
||||
self.seconds = seconds
|
||||
self.nanos = 0
|
||||
|
||||
def ToDatetime(self, tzinfo=None):
|
||||
"""Converts Timestamp to a datetime.
|
||||
|
||||
Args:
|
||||
tzinfo: A datetime.tzinfo subclass; defaults to None.
|
||||
|
||||
Returns:
|
||||
If tzinfo is None, returns a timezone-naive UTC datetime (with no timezone
|
||||
information, i.e. not aware that it's UTC).
|
||||
|
||||
Otherwise, returns a timezone-aware datetime in the input timezone.
|
||||
"""
|
||||
# Using datetime.fromtimestamp for this would avoid constructing an extra
|
||||
# timedelta object and possibly an extra datetime. Unfortunately, that has
|
||||
# the disadvantage of not handling the full precision (on all platforms, see
|
||||
# https://github.com/python/cpython/issues/109849) or full range (on some
|
||||
# platforms, see https://github.com/python/cpython/issues/110042) of
|
||||
# datetime.
|
||||
_CheckTimestampValid(self.seconds, self.nanos)
|
||||
delta = datetime.timedelta(
|
||||
seconds=self.seconds,
|
||||
microseconds=_RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND),
|
||||
)
|
||||
if tzinfo is None:
|
||||
return _EPOCH_DATETIME_NAIVE + delta
|
||||
else:
|
||||
# Note the tz conversion has to come after the timedelta arithmetic.
|
||||
return (_EPOCH_DATETIME_AWARE + delta).astimezone(tzinfo)
|
||||
|
||||
def FromDatetime(self, dt):
|
||||
"""Converts datetime to Timestamp.
|
||||
|
||||
Args:
|
||||
dt: A datetime. If it's timezone-naive, it's assumed to be in UTC.
|
||||
"""
|
||||
# Using this guide: http://wiki.python.org/moin/WorkingWithTime
|
||||
# And this conversion guide: http://docs.python.org/library/time.html
|
||||
|
||||
# Turn the date parameter into a tuple (struct_time) that can then be
|
||||
# manipulated into a long value of seconds. During the conversion from
|
||||
# struct_time to long, the source date in UTC, and so it follows that the
|
||||
# correct transformation is calendar.timegm()
|
||||
try:
|
||||
seconds = calendar.timegm(dt.utctimetuple())
|
||||
nanos = dt.microsecond * _NANOS_PER_MICROSECOND
|
||||
except AttributeError as e:
|
||||
raise AttributeError(
|
||||
'Fail to convert to Timestamp. Expected a datetime like '
|
||||
'object got {0} : {1}'.format(type(dt).__name__, e)
|
||||
) from e
|
||||
_CheckTimestampValid(seconds, nanos)
|
||||
self.seconds = seconds
|
||||
self.nanos = nanos
|
||||
|
||||
def _internal_assign(self, dt):
|
||||
self.FromDatetime(dt)
|
||||
|
||||
def __add__(self, value) -> datetime.datetime:
|
||||
if isinstance(value, Duration):
|
||||
return self.ToDatetime() + value.ToTimedelta()
|
||||
return self.ToDatetime() + value
|
||||
|
||||
__radd__ = __add__
|
||||
|
||||
def __sub__(self, value) -> Union[datetime.datetime, datetime.timedelta]:
|
||||
if isinstance(value, Timestamp):
|
||||
return self.ToDatetime() - value.ToDatetime()
|
||||
elif isinstance(value, Duration):
|
||||
return self.ToDatetime() - value.ToTimedelta()
|
||||
return self.ToDatetime() - value
|
||||
|
||||
def __rsub__(self, dt) -> datetime.timedelta:
|
||||
return dt - self.ToDatetime()
|
||||
|
||||
|
||||
def _CheckTimestampValid(seconds, nanos):
|
||||
if seconds < _TIMESTAMP_SECONDS_MIN or seconds > _TIMESTAMP_SECONDS_MAX:
|
||||
raise ValueError(
|
||||
'Timestamp is not valid: Seconds {0} must be in range '
|
||||
'[-62135596800, 253402300799].'.format(seconds))
|
||||
if nanos < 0 or nanos >= _NANOS_PER_SECOND:
|
||||
raise ValueError(
|
||||
'Timestamp is not valid: Nanos {} must be in a range '
|
||||
'[0, 999999].'.format(nanos))
|
||||
|
||||
|
||||
class Duration(object):
|
||||
"""Class for Duration message type."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def ToJsonString(self):
|
||||
"""Converts Duration to string format.
|
||||
|
||||
Returns:
|
||||
A string converted from self. The string format will contains
|
||||
3, 6, or 9 fractional digits depending on the precision required to
|
||||
represent the exact Duration value. For example: "1s", "1.010s",
|
||||
"1.000000100s", "-3.100s"
|
||||
"""
|
||||
_CheckDurationValid(self.seconds, self.nanos)
|
||||
if self.seconds < 0 or self.nanos < 0:
|
||||
result = '-'
|
||||
seconds = - self.seconds + int((0 - self.nanos) // 1e9)
|
||||
nanos = (0 - self.nanos) % 1e9
|
||||
else:
|
||||
result = ''
|
||||
seconds = self.seconds + int(self.nanos // 1e9)
|
||||
nanos = self.nanos % 1e9
|
||||
result += '%d' % seconds
|
||||
if (nanos % 1e9) == 0:
|
||||
# If there are 0 fractional digits, the fractional
|
||||
# point '.' should be omitted when serializing.
|
||||
return result + 's'
|
||||
if (nanos % 1e6) == 0:
|
||||
# Serialize 3 fractional digits.
|
||||
return result + '.%03ds' % (nanos / 1e6)
|
||||
if (nanos % 1e3) == 0:
|
||||
# Serialize 6 fractional digits.
|
||||
return result + '.%06ds' % (nanos / 1e3)
|
||||
# Serialize 9 fractional digits.
|
||||
return result + '.%09ds' % nanos
|
||||
|
||||
def FromJsonString(self, value):
|
||||
"""Converts a string to Duration.
|
||||
|
||||
Args:
|
||||
value: A string to be converted. The string must end with 's'. Any
|
||||
fractional digits (or none) are accepted as long as they fit into
|
||||
precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s
|
||||
|
||||
Raises:
|
||||
ValueError: On parsing problems.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
raise ValueError('Duration JSON value not a string: {!r}'.format(value))
|
||||
if len(value) < 1 or value[-1] != 's':
|
||||
raise ValueError(
|
||||
'Duration must end with letter "s": {0}.'.format(value))
|
||||
try:
|
||||
pos = value.find('.')
|
||||
if pos == -1:
|
||||
seconds = int(value[:-1])
|
||||
nanos = 0
|
||||
else:
|
||||
seconds = int(value[:pos])
|
||||
if value[0] == '-':
|
||||
nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9))
|
||||
else:
|
||||
nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9))
|
||||
_CheckDurationValid(seconds, nanos)
|
||||
self.seconds = seconds
|
||||
self.nanos = nanos
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
'Couldn\'t parse duration: {0} : {1}.'.format(value, e))
|
||||
|
||||
def ToNanoseconds(self):
|
||||
"""Converts a Duration to nanoseconds."""
|
||||
return self.seconds * _NANOS_PER_SECOND + self.nanos
|
||||
|
||||
def ToMicroseconds(self):
|
||||
"""Converts a Duration to microseconds."""
|
||||
micros = _RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND)
|
||||
return self.seconds * _MICROS_PER_SECOND + micros
|
||||
|
||||
def ToMilliseconds(self):
|
||||
"""Converts a Duration to milliseconds."""
|
||||
millis = _RoundTowardZero(self.nanos, _NANOS_PER_MILLISECOND)
|
||||
return self.seconds * _MILLIS_PER_SECOND + millis
|
||||
|
||||
def ToSeconds(self):
|
||||
"""Converts a Duration to seconds."""
|
||||
return self.seconds
|
||||
|
||||
def FromNanoseconds(self, nanos):
|
||||
"""Converts nanoseconds to Duration."""
|
||||
self._NormalizeDuration(nanos // _NANOS_PER_SECOND,
|
||||
nanos % _NANOS_PER_SECOND)
|
||||
|
||||
def FromMicroseconds(self, micros):
|
||||
"""Converts microseconds to Duration."""
|
||||
self._NormalizeDuration(
|
||||
micros // _MICROS_PER_SECOND,
|
||||
(micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND)
|
||||
|
||||
def FromMilliseconds(self, millis):
|
||||
"""Converts milliseconds to Duration."""
|
||||
self._NormalizeDuration(
|
||||
millis // _MILLIS_PER_SECOND,
|
||||
(millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND)
|
||||
|
||||
def FromSeconds(self, seconds):
|
||||
"""Converts seconds to Duration."""
|
||||
self.seconds = seconds
|
||||
self.nanos = 0
|
||||
|
||||
def ToTimedelta(self) -> datetime.timedelta:
|
||||
"""Converts Duration to timedelta."""
|
||||
return datetime.timedelta(
|
||||
seconds=self.seconds, microseconds=_RoundTowardZero(
|
||||
self.nanos, _NANOS_PER_MICROSECOND))
|
||||
|
||||
def FromTimedelta(self, td):
|
||||
"""Converts timedelta to Duration."""
|
||||
try:
|
||||
self._NormalizeDuration(
|
||||
td.seconds + td.days * _SECONDS_PER_DAY,
|
||||
td.microseconds * _NANOS_PER_MICROSECOND,
|
||||
)
|
||||
except AttributeError as e:
|
||||
raise AttributeError(
|
||||
'Fail to convert to Duration. Expected a timedelta like '
|
||||
'object got {0}: {1}'.format(type(td).__name__, e)
|
||||
) from e
|
||||
|
||||
def _internal_assign(self, td):
|
||||
self.FromTimedelta(td)
|
||||
|
||||
def _NormalizeDuration(self, seconds, nanos):
|
||||
"""Set Duration by seconds and nanos."""
|
||||
# Force nanos to be negative if the duration is negative.
|
||||
if seconds < 0 and nanos > 0:
|
||||
seconds += 1
|
||||
nanos -= _NANOS_PER_SECOND
|
||||
self.seconds = seconds
|
||||
self.nanos = nanos
|
||||
|
||||
def __add__(self, value) -> Union[datetime.datetime, datetime.timedelta]:
|
||||
if isinstance(value, Timestamp):
|
||||
return self.ToTimedelta() + value.ToDatetime()
|
||||
return self.ToTimedelta() + value
|
||||
|
||||
__radd__ = __add__
|
||||
|
||||
def __rsub__(self, dt) -> Union[datetime.datetime, datetime.timedelta]:
|
||||
return dt - self.ToTimedelta()
|
||||
|
||||
|
||||
def _CheckDurationValid(seconds, nanos):
|
||||
if seconds < -_DURATION_SECONDS_MAX or seconds > _DURATION_SECONDS_MAX:
|
||||
raise ValueError(
|
||||
'Duration is not valid: Seconds {0} must be in range '
|
||||
'[-315576000000, 315576000000].'.format(seconds))
|
||||
if nanos <= -_NANOS_PER_SECOND or nanos >= _NANOS_PER_SECOND:
|
||||
raise ValueError(
|
||||
'Duration is not valid: Nanos {0} must be in range '
|
||||
'[-999999999, 999999999].'.format(nanos))
|
||||
if (nanos < 0 and seconds > 0) or (nanos > 0 and seconds < 0):
|
||||
raise ValueError(
|
||||
'Duration is not valid: Sign mismatch.')
|
||||
|
||||
|
||||
def _RoundTowardZero(value, divider):
|
||||
"""Truncates the remainder part after division."""
|
||||
# For some languages, the sign of the remainder is implementation
|
||||
# dependent if any of the operands is negative. Here we enforce
|
||||
# "rounded toward zero" semantics. For example, for (-5) / 2 an
|
||||
# implementation may give -3 as the result with the remainder being
|
||||
# 1. This function ensures we always return -2 (closer to zero).
|
||||
result = value // divider
|
||||
remainder = value % divider
|
||||
if result < 0 and remainder > 0:
|
||||
return result + 1
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
def _SetStructValue(struct_value, value):
|
||||
if value is None:
|
||||
struct_value.null_value = 0
|
||||
elif isinstance(value, bool):
|
||||
# Note: this check must come before the number check because in Python
|
||||
# True and False are also considered numbers.
|
||||
struct_value.bool_value = value
|
||||
elif isinstance(value, str):
|
||||
struct_value.string_value = value
|
||||
elif isinstance(value, (int, float)):
|
||||
struct_value.number_value = value
|
||||
elif isinstance(value, (dict, Struct)):
|
||||
struct_value.struct_value.Clear()
|
||||
struct_value.struct_value.update(value)
|
||||
elif isinstance(value, (list, tuple, ListValue)):
|
||||
struct_value.list_value.Clear()
|
||||
struct_value.list_value.extend(value)
|
||||
else:
|
||||
raise ValueError('Unexpected type')
|
||||
|
||||
|
||||
def _GetStructValue(struct_value):
|
||||
which = struct_value.WhichOneof('kind')
|
||||
if which == 'struct_value':
|
||||
return struct_value.struct_value
|
||||
elif which == 'null_value':
|
||||
return None
|
||||
elif which == 'number_value':
|
||||
return struct_value.number_value
|
||||
elif which == 'string_value':
|
||||
return struct_value.string_value
|
||||
elif which == 'bool_value':
|
||||
return struct_value.bool_value
|
||||
elif which == 'list_value':
|
||||
return struct_value.list_value
|
||||
elif which is None:
|
||||
raise ValueError('Value not set')
|
||||
|
||||
|
||||
class Struct(object):
|
||||
"""Class for Struct message type."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __getitem__(self, key):
|
||||
return _GetStructValue(self.fields[key])
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
_SetStructValue(self.fields[key], value)
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self.fields[key]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.fields)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.fields)
|
||||
|
||||
def _internal_assign(self, dictionary):
|
||||
self.Clear()
|
||||
self.update(dictionary)
|
||||
|
||||
def _internal_compare(self, other):
|
||||
size = len(self)
|
||||
if size != len(other):
|
||||
return False
|
||||
for key, value in self.items():
|
||||
if key not in other:
|
||||
return False
|
||||
if isinstance(other[key], (dict, list)):
|
||||
if not value._internal_compare(other[key]):
|
||||
return False
|
||||
elif value != other[key]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def keys(self): # pylint: disable=invalid-name
|
||||
return self.fields.keys()
|
||||
|
||||
def values(self): # pylint: disable=invalid-name
|
||||
return [self[key] for key in self]
|
||||
|
||||
def items(self): # pylint: disable=invalid-name
|
||||
return [(key, self[key]) for key in self]
|
||||
|
||||
def get_or_create_list(self, key):
|
||||
"""Returns a list for this key, creating if it didn't exist already."""
|
||||
if not self.fields[key].HasField('list_value'):
|
||||
# Clear will mark list_value modified which will indeed create a list.
|
||||
self.fields[key].list_value.Clear()
|
||||
return self.fields[key].list_value
|
||||
|
||||
def get_or_create_struct(self, key):
|
||||
"""Returns a struct for this key, creating if it didn't exist already."""
|
||||
if not self.fields[key].HasField('struct_value'):
|
||||
# Clear will mark struct_value modified which will indeed create a struct.
|
||||
self.fields[key].struct_value.Clear()
|
||||
return self.fields[key].struct_value
|
||||
|
||||
def update(self, dictionary): # pylint: disable=invalid-name
|
||||
for key, value in dictionary.items():
|
||||
_SetStructValue(self.fields[key], value)
|
||||
|
||||
collections.abc.MutableMapping.register(Struct)
|
||||
|
||||
|
||||
class ListValue(object):
|
||||
"""Class for ListValue message type."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.values)
|
||||
|
||||
def append(self, value):
|
||||
_SetStructValue(self.values.add(), value)
|
||||
|
||||
def extend(self, elem_seq):
|
||||
for value in elem_seq:
|
||||
self.append(value)
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Retrieves item by the specified index."""
|
||||
return _GetStructValue(self.values.__getitem__(index))
|
||||
|
||||
def __setitem__(self, index, value):
|
||||
_SetStructValue(self.values.__getitem__(index), value)
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self.values[key]
|
||||
|
||||
def _internal_assign(self, elem_seq):
|
||||
self.Clear()
|
||||
self.extend(elem_seq)
|
||||
|
||||
def _internal_compare(self, other):
|
||||
size = len(self)
|
||||
if size != len(other):
|
||||
return False
|
||||
for i in range(size):
|
||||
if isinstance(other[i], (dict, list)):
|
||||
if not self[i]._internal_compare(other[i]):
|
||||
return False
|
||||
elif self[i] != other[i]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def items(self):
|
||||
for i in range(len(self)):
|
||||
yield self[i]
|
||||
|
||||
def add_struct(self):
|
||||
"""Appends and returns a struct value as the next value in the list."""
|
||||
struct_value = self.values.add().struct_value
|
||||
# Clear will mark struct_value modified which will indeed create a struct.
|
||||
struct_value.Clear()
|
||||
return struct_value
|
||||
|
||||
def add_list(self):
|
||||
"""Appends and returns a list value as the next value in the list."""
|
||||
list_value = self.values.add().list_value
|
||||
# Clear will mark list_value modified which will indeed create a list.
|
||||
list_value.Clear()
|
||||
return list_value
|
||||
|
||||
collections.abc.MutableSequence.register(ListValue)
|
||||
|
||||
|
||||
# LINT.IfChange(wktbases)
|
||||
WKTBASES = {
|
||||
'google.protobuf.Any': Any,
|
||||
'google.protobuf.Duration': Duration,
|
||||
'google.protobuf.FieldMask': FieldMask,
|
||||
'google.protobuf.ListValue': ListValue,
|
||||
'google.protobuf.Struct': Struct,
|
||||
'google.protobuf.Timestamp': Timestamp,
|
||||
}
|
||||
# LINT.ThenChange(//depot/google.protobuf/compiler/python/pyi_generator.cc:wktbases)
|
||||
245
.venv/lib/python3.10/site-packages/google/protobuf/internal/wire_format.py
Executable file
245
.venv/lib/python3.10/site-packages/google/protobuf/internal/wire_format.py
Executable file
@@ -0,0 +1,245 @@
|
||||
# Protocol Buffers - Google's data interchange format
|
||||
# Copyright 2008 Google Inc. All rights reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style
|
||||
# license that can be found in the LICENSE file or at
|
||||
# https://developers.google.com/open-source/licenses/bsd
|
||||
|
||||
"""Constants and static functions to support protocol buffer wire format."""
|
||||
|
||||
__author__ = 'robinson@google.com (Will Robinson)'
|
||||
|
||||
import struct
|
||||
from google.protobuf import descriptor
|
||||
from google.protobuf import message
|
||||
|
||||
|
||||
TAG_TYPE_BITS = 3 # Number of bits used to hold type info in a proto tag.
|
||||
TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1 # 0x7
|
||||
|
||||
# These numbers identify the wire type of a protocol buffer value.
|
||||
# We use the least-significant TAG_TYPE_BITS bits of the varint-encoded
|
||||
# tag-and-type to store one of these WIRETYPE_* constants.
|
||||
# These values must match WireType enum in //google/protobuf/wire_format.h.
|
||||
WIRETYPE_VARINT = 0
|
||||
WIRETYPE_FIXED64 = 1
|
||||
WIRETYPE_LENGTH_DELIMITED = 2
|
||||
WIRETYPE_START_GROUP = 3
|
||||
WIRETYPE_END_GROUP = 4
|
||||
WIRETYPE_FIXED32 = 5
|
||||
_WIRETYPE_MAX = 5
|
||||
|
||||
|
||||
# Bounds for various integer types.
|
||||
INT32_MAX = int((1 << 31) - 1)
|
||||
INT32_MIN = int(-(1 << 31))
|
||||
UINT32_MAX = (1 << 32) - 1
|
||||
|
||||
INT64_MAX = (1 << 63) - 1
|
||||
INT64_MIN = -(1 << 63)
|
||||
UINT64_MAX = (1 << 64) - 1
|
||||
|
||||
# "struct" format strings that will encode/decode the specified formats.
|
||||
FORMAT_UINT32_LITTLE_ENDIAN = '<I'
|
||||
FORMAT_UINT64_LITTLE_ENDIAN = '<Q'
|
||||
FORMAT_FLOAT_LITTLE_ENDIAN = '<f'
|
||||
FORMAT_DOUBLE_LITTLE_ENDIAN = '<d'
|
||||
|
||||
|
||||
# We'll have to provide alternate implementations of AppendLittleEndian*() on
|
||||
# any architectures where these checks fail.
|
||||
if struct.calcsize(FORMAT_UINT32_LITTLE_ENDIAN) != 4:
|
||||
raise AssertionError('Format "I" is not a 32-bit number.')
|
||||
if struct.calcsize(FORMAT_UINT64_LITTLE_ENDIAN) != 8:
|
||||
raise AssertionError('Format "Q" is not a 64-bit number.')
|
||||
|
||||
|
||||
def PackTag(field_number, wire_type):
|
||||
"""Returns an unsigned 32-bit integer that encodes the field number and
|
||||
wire type information in standard protocol message wire format.
|
||||
|
||||
Args:
|
||||
field_number: Expected to be an integer in the range [1, 1 << 29)
|
||||
wire_type: One of the WIRETYPE_* constants.
|
||||
"""
|
||||
if not 0 <= wire_type <= _WIRETYPE_MAX:
|
||||
raise message.EncodeError('Unknown wire type: %d' % wire_type)
|
||||
return (field_number << TAG_TYPE_BITS) | wire_type
|
||||
|
||||
|
||||
def UnpackTag(tag):
|
||||
"""The inverse of PackTag(). Given an unsigned 32-bit number,
|
||||
returns a (field_number, wire_type) tuple.
|
||||
"""
|
||||
return (tag >> TAG_TYPE_BITS), (tag & TAG_TYPE_MASK)
|
||||
|
||||
|
||||
def ZigZagEncode(value):
|
||||
"""ZigZag Transform: Encodes signed integers so that they can be
|
||||
effectively used with varint encoding. See wire_format.h for
|
||||
more details.
|
||||
"""
|
||||
if value >= 0:
|
||||
return value << 1
|
||||
return (value << 1) ^ (~0)
|
||||
|
||||
|
||||
def ZigZagDecode(value):
|
||||
"""Inverse of ZigZagEncode()."""
|
||||
if not value & 0x1:
|
||||
return value >> 1
|
||||
return (value >> 1) ^ (~0)
|
||||
|
||||
|
||||
|
||||
# The *ByteSize() functions below return the number of bytes required to
|
||||
# serialize "field number + type" information and then serialize the value.
|
||||
|
||||
|
||||
def Int32ByteSize(field_number, int32):
|
||||
return Int64ByteSize(field_number, int32)
|
||||
|
||||
|
||||
def Int32ByteSizeNoTag(int32):
|
||||
return _VarUInt64ByteSizeNoTag(0xffffffffffffffff & int32)
|
||||
|
||||
|
||||
def Int64ByteSize(field_number, int64):
|
||||
# Have to convert to uint before calling UInt64ByteSize().
|
||||
return UInt64ByteSize(field_number, 0xffffffffffffffff & int64)
|
||||
|
||||
|
||||
def UInt32ByteSize(field_number, uint32):
|
||||
return UInt64ByteSize(field_number, uint32)
|
||||
|
||||
|
||||
def UInt64ByteSize(field_number, uint64):
|
||||
return TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(uint64)
|
||||
|
||||
|
||||
def SInt32ByteSize(field_number, int32):
|
||||
return UInt32ByteSize(field_number, ZigZagEncode(int32))
|
||||
|
||||
|
||||
def SInt64ByteSize(field_number, int64):
|
||||
return UInt64ByteSize(field_number, ZigZagEncode(int64))
|
||||
|
||||
|
||||
def Fixed32ByteSize(field_number, fixed32):
|
||||
return TagByteSize(field_number) + 4
|
||||
|
||||
|
||||
def Fixed64ByteSize(field_number, fixed64):
|
||||
return TagByteSize(field_number) + 8
|
||||
|
||||
|
||||
def SFixed32ByteSize(field_number, sfixed32):
|
||||
return TagByteSize(field_number) + 4
|
||||
|
||||
|
||||
def SFixed64ByteSize(field_number, sfixed64):
|
||||
return TagByteSize(field_number) + 8
|
||||
|
||||
|
||||
def FloatByteSize(field_number, flt):
|
||||
return TagByteSize(field_number) + 4
|
||||
|
||||
|
||||
def DoubleByteSize(field_number, double):
|
||||
return TagByteSize(field_number) + 8
|
||||
|
||||
|
||||
def BoolByteSize(field_number, b):
|
||||
return TagByteSize(field_number) + 1
|
||||
|
||||
|
||||
def EnumByteSize(field_number, enum):
|
||||
return UInt32ByteSize(field_number, enum)
|
||||
|
||||
|
||||
def StringByteSize(field_number, string):
|
||||
return BytesByteSize(field_number, string.encode('utf-8'))
|
||||
|
||||
|
||||
def BytesByteSize(field_number, b):
|
||||
return (TagByteSize(field_number)
|
||||
+ _VarUInt64ByteSizeNoTag(len(b))
|
||||
+ len(b))
|
||||
|
||||
|
||||
def GroupByteSize(field_number, message):
|
||||
return (2 * TagByteSize(field_number) # START and END group.
|
||||
+ message.ByteSize())
|
||||
|
||||
|
||||
def MessageByteSize(field_number, message):
|
||||
return (TagByteSize(field_number)
|
||||
+ _VarUInt64ByteSizeNoTag(message.ByteSize())
|
||||
+ message.ByteSize())
|
||||
|
||||
|
||||
def MessageSetItemByteSize(field_number, msg):
|
||||
# First compute the sizes of the tags.
|
||||
# There are 2 tags for the beginning and ending of the repeated group, that
|
||||
# is field number 1, one with field number 2 (type_id) and one with field
|
||||
# number 3 (message).
|
||||
total_size = (2 * TagByteSize(1) + TagByteSize(2) + TagByteSize(3))
|
||||
|
||||
# Add the number of bytes for type_id.
|
||||
total_size += _VarUInt64ByteSizeNoTag(field_number)
|
||||
|
||||
message_size = msg.ByteSize()
|
||||
|
||||
# The number of bytes for encoding the length of the message.
|
||||
total_size += _VarUInt64ByteSizeNoTag(message_size)
|
||||
|
||||
# The size of the message.
|
||||
total_size += message_size
|
||||
return total_size
|
||||
|
||||
|
||||
def TagByteSize(field_number):
|
||||
"""Returns the bytes required to serialize a tag with this field number."""
|
||||
# Just pass in type 0, since the type won't affect the tag+type size.
|
||||
return _VarUInt64ByteSizeNoTag(PackTag(field_number, 0))
|
||||
|
||||
|
||||
# Private helper function for the *ByteSize() functions above.
|
||||
|
||||
def _VarUInt64ByteSizeNoTag(uint64):
|
||||
"""Returns the number of bytes required to serialize a single varint
|
||||
using boundary value comparisons. (unrolled loop optimization -WPierce)
|
||||
uint64 must be unsigned.
|
||||
"""
|
||||
if uint64 <= 0x7f: return 1
|
||||
if uint64 <= 0x3fff: return 2
|
||||
if uint64 <= 0x1fffff: return 3
|
||||
if uint64 <= 0xfffffff: return 4
|
||||
if uint64 <= 0x7ffffffff: return 5
|
||||
if uint64 <= 0x3ffffffffff: return 6
|
||||
if uint64 <= 0x1ffffffffffff: return 7
|
||||
if uint64 <= 0xffffffffffffff: return 8
|
||||
if uint64 <= 0x7fffffffffffffff: return 9
|
||||
if uint64 > UINT64_MAX:
|
||||
raise message.EncodeError('Value out of range: %d' % uint64)
|
||||
return 10
|
||||
|
||||
|
||||
NON_PACKABLE_TYPES = (
|
||||
descriptor.FieldDescriptor.TYPE_STRING,
|
||||
descriptor.FieldDescriptor.TYPE_GROUP,
|
||||
descriptor.FieldDescriptor.TYPE_MESSAGE,
|
||||
descriptor.FieldDescriptor.TYPE_BYTES
|
||||
)
|
||||
|
||||
|
||||
def IsTypePackable(field_type):
|
||||
"""Return true iff packable = true is valid for fields of this type.
|
||||
|
||||
Args:
|
||||
field_type: a FieldDescriptor::Type value.
|
||||
|
||||
Returns:
|
||||
True iff fields of this type are packable.
|
||||
"""
|
||||
return field_type not in NON_PACKABLE_TYPES
|
||||
Reference in New Issue
Block a user