structure saas with tools

Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions

View File

@@ -0,0 +1 @@
f2318883e549f69de597009a914603b0f1b10381e265ef5d98af499ad973fb98 /home/runner/work/aiohttp/aiohttp/aiohttp/_cparser.pxd

View File

@@ -0,0 +1 @@
d067f01423cddb3c442933b5fcc039b18ab651fcec1bc91c577693aafc25cf78 /home/runner/work/aiohttp/aiohttp/aiohttp/_find_header.pxd

View File

@@ -0,0 +1 @@
c107400e3e4b8b3c02ffb9c51abf2722593a1a9a1a41e434df9f47d0730a1ae3 /home/runner/work/aiohttp/aiohttp/aiohttp/_http_parser.pyx

View File

@@ -0,0 +1 @@
f7ab1e2628277b82772d59c1dc3033c13495d769df67b1d1d49b1a474a75dd52 /home/runner/work/aiohttp/aiohttp/aiohttp/_http_writer.pyx

View File

@@ -0,0 +1 @@
dab8f933203eeb245d60f856e542a45b888d5a110094620e4811f90f816628d1 /home/runner/work/aiohttp/aiohttp/aiohttp/hdrs.py

View File

@@ -0,0 +1,264 @@
__version__ = "3.11.18"
from typing import TYPE_CHECKING, Tuple
from . import hdrs as hdrs
from .client import (
BaseConnector,
ClientConnectionError,
ClientConnectionResetError,
ClientConnectorCertificateError,
ClientConnectorDNSError,
ClientConnectorError,
ClientConnectorSSLError,
ClientError,
ClientHttpProxyError,
ClientOSError,
ClientPayloadError,
ClientProxyConnectionError,
ClientRequest,
ClientResponse,
ClientResponseError,
ClientSession,
ClientSSLError,
ClientTimeout,
ClientWebSocketResponse,
ClientWSTimeout,
ConnectionTimeoutError,
ContentTypeError,
Fingerprint,
InvalidURL,
InvalidUrlClientError,
InvalidUrlRedirectClientError,
NamedPipeConnector,
NonHttpUrlClientError,
NonHttpUrlRedirectClientError,
RedirectClientError,
RequestInfo,
ServerConnectionError,
ServerDisconnectedError,
ServerFingerprintMismatch,
ServerTimeoutError,
SocketTimeoutError,
TCPConnector,
TooManyRedirects,
UnixConnector,
WSMessageTypeError,
WSServerHandshakeError,
request,
)
from .cookiejar import CookieJar as CookieJar, DummyCookieJar as DummyCookieJar
from .formdata import FormData as FormData
from .helpers import BasicAuth, ChainMapProxy, ETag
from .http import (
HttpVersion as HttpVersion,
HttpVersion10 as HttpVersion10,
HttpVersion11 as HttpVersion11,
WebSocketError as WebSocketError,
WSCloseCode as WSCloseCode,
WSMessage as WSMessage,
WSMsgType as WSMsgType,
)
from .multipart import (
BadContentDispositionHeader as BadContentDispositionHeader,
BadContentDispositionParam as BadContentDispositionParam,
BodyPartReader as BodyPartReader,
MultipartReader as MultipartReader,
MultipartWriter as MultipartWriter,
content_disposition_filename as content_disposition_filename,
parse_content_disposition as parse_content_disposition,
)
from .payload import (
PAYLOAD_REGISTRY as PAYLOAD_REGISTRY,
AsyncIterablePayload as AsyncIterablePayload,
BufferedReaderPayload as BufferedReaderPayload,
BytesIOPayload as BytesIOPayload,
BytesPayload as BytesPayload,
IOBasePayload as IOBasePayload,
JsonPayload as JsonPayload,
Payload as Payload,
StringIOPayload as StringIOPayload,
StringPayload as StringPayload,
TextIOPayload as TextIOPayload,
get_payload as get_payload,
payload_type as payload_type,
)
from .payload_streamer import streamer as streamer
from .resolver import (
AsyncResolver as AsyncResolver,
DefaultResolver as DefaultResolver,
ThreadedResolver as ThreadedResolver,
)
from .streams import (
EMPTY_PAYLOAD as EMPTY_PAYLOAD,
DataQueue as DataQueue,
EofStream as EofStream,
FlowControlDataQueue as FlowControlDataQueue,
StreamReader as StreamReader,
)
from .tracing import (
TraceConfig as TraceConfig,
TraceConnectionCreateEndParams as TraceConnectionCreateEndParams,
TraceConnectionCreateStartParams as TraceConnectionCreateStartParams,
TraceConnectionQueuedEndParams as TraceConnectionQueuedEndParams,
TraceConnectionQueuedStartParams as TraceConnectionQueuedStartParams,
TraceConnectionReuseconnParams as TraceConnectionReuseconnParams,
TraceDnsCacheHitParams as TraceDnsCacheHitParams,
TraceDnsCacheMissParams as TraceDnsCacheMissParams,
TraceDnsResolveHostEndParams as TraceDnsResolveHostEndParams,
TraceDnsResolveHostStartParams as TraceDnsResolveHostStartParams,
TraceRequestChunkSentParams as TraceRequestChunkSentParams,
TraceRequestEndParams as TraceRequestEndParams,
TraceRequestExceptionParams as TraceRequestExceptionParams,
TraceRequestHeadersSentParams as TraceRequestHeadersSentParams,
TraceRequestRedirectParams as TraceRequestRedirectParams,
TraceRequestStartParams as TraceRequestStartParams,
TraceResponseChunkReceivedParams as TraceResponseChunkReceivedParams,
)
if TYPE_CHECKING:
# At runtime these are lazy-loaded at the bottom of the file.
from .worker import (
GunicornUVLoopWebWorker as GunicornUVLoopWebWorker,
GunicornWebWorker as GunicornWebWorker,
)
__all__: Tuple[str, ...] = (
"hdrs",
# client
"BaseConnector",
"ClientConnectionError",
"ClientConnectionResetError",
"ClientConnectorCertificateError",
"ClientConnectorDNSError",
"ClientConnectorError",
"ClientConnectorSSLError",
"ClientError",
"ClientHttpProxyError",
"ClientOSError",
"ClientPayloadError",
"ClientProxyConnectionError",
"ClientResponse",
"ClientRequest",
"ClientResponseError",
"ClientSSLError",
"ClientSession",
"ClientTimeout",
"ClientWebSocketResponse",
"ClientWSTimeout",
"ConnectionTimeoutError",
"ContentTypeError",
"Fingerprint",
"FlowControlDataQueue",
"InvalidURL",
"InvalidUrlClientError",
"InvalidUrlRedirectClientError",
"NonHttpUrlClientError",
"NonHttpUrlRedirectClientError",
"RedirectClientError",
"RequestInfo",
"ServerConnectionError",
"ServerDisconnectedError",
"ServerFingerprintMismatch",
"ServerTimeoutError",
"SocketTimeoutError",
"TCPConnector",
"TooManyRedirects",
"UnixConnector",
"NamedPipeConnector",
"WSServerHandshakeError",
"request",
# cookiejar
"CookieJar",
"DummyCookieJar",
# formdata
"FormData",
# helpers
"BasicAuth",
"ChainMapProxy",
"ETag",
# http
"HttpVersion",
"HttpVersion10",
"HttpVersion11",
"WSMsgType",
"WSCloseCode",
"WSMessage",
"WebSocketError",
# multipart
"BadContentDispositionHeader",
"BadContentDispositionParam",
"BodyPartReader",
"MultipartReader",
"MultipartWriter",
"content_disposition_filename",
"parse_content_disposition",
# payload
"AsyncIterablePayload",
"BufferedReaderPayload",
"BytesIOPayload",
"BytesPayload",
"IOBasePayload",
"JsonPayload",
"PAYLOAD_REGISTRY",
"Payload",
"StringIOPayload",
"StringPayload",
"TextIOPayload",
"get_payload",
"payload_type",
# payload_streamer
"streamer",
# resolver
"AsyncResolver",
"DefaultResolver",
"ThreadedResolver",
# streams
"DataQueue",
"EMPTY_PAYLOAD",
"EofStream",
"StreamReader",
# tracing
"TraceConfig",
"TraceConnectionCreateEndParams",
"TraceConnectionCreateStartParams",
"TraceConnectionQueuedEndParams",
"TraceConnectionQueuedStartParams",
"TraceConnectionReuseconnParams",
"TraceDnsCacheHitParams",
"TraceDnsCacheMissParams",
"TraceDnsResolveHostEndParams",
"TraceDnsResolveHostStartParams",
"TraceRequestChunkSentParams",
"TraceRequestEndParams",
"TraceRequestExceptionParams",
"TraceRequestHeadersSentParams",
"TraceRequestRedirectParams",
"TraceRequestStartParams",
"TraceResponseChunkReceivedParams",
# workers (imported lazily with __getattr__)
"GunicornUVLoopWebWorker",
"GunicornWebWorker",
"WSMessageTypeError",
)
def __dir__() -> Tuple[str, ...]:
return __all__ + ("__doc__",)
def __getattr__(name: str) -> object:
global GunicornUVLoopWebWorker, GunicornWebWorker
# Importing gunicorn takes a long time (>100ms), so only import if actually needed.
if name in ("GunicornUVLoopWebWorker", "GunicornWebWorker"):
try:
from .worker import GunicornUVLoopWebWorker as guv, GunicornWebWorker as gw
except ImportError:
return None
GunicornUVLoopWebWorker = guv # type: ignore[misc]
GunicornWebWorker = gw # type: ignore[misc]
return guv if name == "GunicornUVLoopWebWorker" else gw
raise AttributeError(f"module {__name__} has no attribute {name}")
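The `__getattr__` hook above is the PEP 562 module-level lazy-import pattern: the first access to a worker class pays the gunicorn import cost once and caches the result as a module global. A minimal usage sketch (illustrative, not part of the module):

import aiohttp

aiohttp.GunicornWebWorker  # first access: __getattr__ runs and imports gunicorn (None if not installed)
aiohttp.GunicornWebWorker  # later accesses hit the cached module global; __getattr__ is bypassed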

View File

@@ -0,0 +1,158 @@
from libc.stdint cimport int32_t, uint8_t, uint16_t, uint64_t
cdef extern from "../vendor/llhttp/build/llhttp.h":
struct llhttp__internal_s:
int32_t _index
void* _span_pos0
void* _span_cb0
int32_t error
const char* reason
const char* error_pos
void* data
void* _current
uint64_t content_length
uint8_t type
uint8_t method
uint8_t http_major
uint8_t http_minor
uint8_t header_state
uint8_t lenient_flags
uint8_t upgrade
uint8_t finish
uint16_t flags
uint16_t status_code
void* settings
ctypedef llhttp__internal_s llhttp__internal_t
ctypedef llhttp__internal_t llhttp_t
ctypedef int (*llhttp_data_cb)(llhttp_t*, const char *at, size_t length) except -1
ctypedef int (*llhttp_cb)(llhttp_t*) except -1
struct llhttp_settings_s:
llhttp_cb on_message_begin
llhttp_data_cb on_url
llhttp_data_cb on_status
llhttp_data_cb on_header_field
llhttp_data_cb on_header_value
llhttp_cb on_headers_complete
llhttp_data_cb on_body
llhttp_cb on_message_complete
llhttp_cb on_chunk_header
llhttp_cb on_chunk_complete
llhttp_cb on_url_complete
llhttp_cb on_status_complete
llhttp_cb on_header_field_complete
llhttp_cb on_header_value_complete
ctypedef llhttp_settings_s llhttp_settings_t
enum llhttp_errno:
HPE_OK,
HPE_INTERNAL,
HPE_STRICT,
HPE_LF_EXPECTED,
HPE_UNEXPECTED_CONTENT_LENGTH,
HPE_CLOSED_CONNECTION,
HPE_INVALID_METHOD,
HPE_INVALID_URL,
HPE_INVALID_CONSTANT,
HPE_INVALID_VERSION,
HPE_INVALID_HEADER_TOKEN,
HPE_INVALID_CONTENT_LENGTH,
HPE_INVALID_CHUNK_SIZE,
HPE_INVALID_STATUS,
HPE_INVALID_EOF_STATE,
HPE_INVALID_TRANSFER_ENCODING,
HPE_CB_MESSAGE_BEGIN,
HPE_CB_HEADERS_COMPLETE,
HPE_CB_MESSAGE_COMPLETE,
HPE_CB_CHUNK_HEADER,
HPE_CB_CHUNK_COMPLETE,
HPE_PAUSED,
HPE_PAUSED_UPGRADE,
HPE_USER
ctypedef llhttp_errno llhttp_errno_t
enum llhttp_flags:
F_CHUNKED,
F_CONTENT_LENGTH
enum llhttp_type:
HTTP_REQUEST,
HTTP_RESPONSE,
HTTP_BOTH
enum llhttp_method:
HTTP_DELETE,
HTTP_GET,
HTTP_HEAD,
HTTP_POST,
HTTP_PUT,
HTTP_CONNECT,
HTTP_OPTIONS,
HTTP_TRACE,
HTTP_COPY,
HTTP_LOCK,
HTTP_MKCOL,
HTTP_MOVE,
HTTP_PROPFIND,
HTTP_PROPPATCH,
HTTP_SEARCH,
HTTP_UNLOCK,
HTTP_BIND,
HTTP_REBIND,
HTTP_UNBIND,
HTTP_ACL,
HTTP_REPORT,
HTTP_MKACTIVITY,
HTTP_CHECKOUT,
HTTP_MERGE,
HTTP_MSEARCH,
HTTP_NOTIFY,
HTTP_SUBSCRIBE,
HTTP_UNSUBSCRIBE,
HTTP_PATCH,
HTTP_PURGE,
HTTP_MKCALENDAR,
HTTP_LINK,
HTTP_UNLINK,
HTTP_SOURCE,
HTTP_PRI,
HTTP_DESCRIBE,
HTTP_ANNOUNCE,
HTTP_SETUP,
HTTP_PLAY,
HTTP_PAUSE,
HTTP_TEARDOWN,
HTTP_GET_PARAMETER,
HTTP_SET_PARAMETER,
HTTP_REDIRECT,
HTTP_RECORD,
HTTP_FLUSH
    ctypedef llhttp_method llhttp_method_t
void llhttp_settings_init(llhttp_settings_t* settings)
void llhttp_init(llhttp_t* parser, llhttp_type type,
const llhttp_settings_t* settings)
llhttp_errno_t llhttp_execute(llhttp_t* parser, const char* data, size_t len)
int llhttp_should_keep_alive(const llhttp_t* parser)
void llhttp_resume_after_upgrade(llhttp_t* parser)
llhttp_errno_t llhttp_get_errno(const llhttp_t* parser)
const char* llhttp_get_error_reason(const llhttp_t* parser)
const char* llhttp_get_error_pos(const llhttp_t* parser)
const char* llhttp_method_name(llhttp_method_t method)
void llhttp_set_lenient_headers(llhttp_t* parser, int enabled)
void llhttp_set_lenient_optional_cr_before_lf(llhttp_t* parser, int enabled)
void llhttp_set_lenient_spaces_after_chunk_size(llhttp_t* parser, int enabled)

View File

@@ -0,0 +1,2 @@
cdef extern from "_find_header.h":
int find_header(char *, int)

View File

@@ -0,0 +1,83 @@
# This file is autogenerated from aiohttp/hdrs.py
# Run ./tools/gen.py to update it after the origin changes.
from . import hdrs
cdef tuple headers = (
hdrs.ACCEPT,
hdrs.ACCEPT_CHARSET,
hdrs.ACCEPT_ENCODING,
hdrs.ACCEPT_LANGUAGE,
hdrs.ACCEPT_RANGES,
hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
hdrs.ACCESS_CONTROL_ALLOW_HEADERS,
hdrs.ACCESS_CONTROL_ALLOW_METHODS,
hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
hdrs.ACCESS_CONTROL_MAX_AGE,
hdrs.ACCESS_CONTROL_REQUEST_HEADERS,
hdrs.ACCESS_CONTROL_REQUEST_METHOD,
hdrs.AGE,
hdrs.ALLOW,
hdrs.AUTHORIZATION,
hdrs.CACHE_CONTROL,
hdrs.CONNECTION,
hdrs.CONTENT_DISPOSITION,
hdrs.CONTENT_ENCODING,
hdrs.CONTENT_LANGUAGE,
hdrs.CONTENT_LENGTH,
hdrs.CONTENT_LOCATION,
hdrs.CONTENT_MD5,
hdrs.CONTENT_RANGE,
hdrs.CONTENT_TRANSFER_ENCODING,
hdrs.CONTENT_TYPE,
hdrs.COOKIE,
hdrs.DATE,
hdrs.DESTINATION,
hdrs.DIGEST,
hdrs.ETAG,
hdrs.EXPECT,
hdrs.EXPIRES,
hdrs.FORWARDED,
hdrs.FROM,
hdrs.HOST,
hdrs.IF_MATCH,
hdrs.IF_MODIFIED_SINCE,
hdrs.IF_NONE_MATCH,
hdrs.IF_RANGE,
hdrs.IF_UNMODIFIED_SINCE,
hdrs.KEEP_ALIVE,
hdrs.LAST_EVENT_ID,
hdrs.LAST_MODIFIED,
hdrs.LINK,
hdrs.LOCATION,
hdrs.MAX_FORWARDS,
hdrs.ORIGIN,
hdrs.PRAGMA,
hdrs.PROXY_AUTHENTICATE,
hdrs.PROXY_AUTHORIZATION,
hdrs.RANGE,
hdrs.REFERER,
hdrs.RETRY_AFTER,
hdrs.SEC_WEBSOCKET_ACCEPT,
hdrs.SEC_WEBSOCKET_EXTENSIONS,
hdrs.SEC_WEBSOCKET_KEY,
hdrs.SEC_WEBSOCKET_KEY1,
hdrs.SEC_WEBSOCKET_PROTOCOL,
hdrs.SEC_WEBSOCKET_VERSION,
hdrs.SERVER,
hdrs.SET_COOKIE,
hdrs.TE,
hdrs.TRAILER,
hdrs.TRANSFER_ENCODING,
hdrs.URI,
hdrs.UPGRADE,
hdrs.USER_AGENT,
hdrs.VARY,
hdrs.VIA,
hdrs.WWW_AUTHENTICATE,
hdrs.WANT_DIGEST,
hdrs.WARNING,
hdrs.X_FORWARDED_FOR,
hdrs.X_FORWARDED_HOST,
hdrs.X_FORWARDED_PROTO,
)

View File

@@ -0,0 +1,837 @@
#cython: language_level=3
#
# Based on https://github.com/MagicStack/httptools
#
from cpython cimport (
Py_buffer,
PyBUF_SIMPLE,
PyBuffer_Release,
PyBytes_AsString,
PyBytes_AsStringAndSize,
PyObject_GetBuffer,
)
from cpython.mem cimport PyMem_Free, PyMem_Malloc
from libc.limits cimport ULLONG_MAX
from libc.string cimport memcpy
from multidict import CIMultiDict as _CIMultiDict, CIMultiDictProxy as _CIMultiDictProxy
from yarl import URL as _URL
from aiohttp import hdrs
from aiohttp.helpers import DEBUG, set_exception
from .http_exceptions import (
BadHttpMessage,
BadHttpMethod,
BadStatusLine,
ContentLengthError,
InvalidHeader,
InvalidURLError,
LineTooLong,
PayloadEncodingError,
TransferEncodingError,
)
from .http_parser import DeflateBuffer as _DeflateBuffer
from .http_writer import (
HttpVersion as _HttpVersion,
HttpVersion10 as _HttpVersion10,
HttpVersion11 as _HttpVersion11,
)
from .streams import EMPTY_PAYLOAD as _EMPTY_PAYLOAD, StreamReader as _StreamReader
cimport cython
from aiohttp cimport _cparser as cparser
include "_headers.pxi"
from aiohttp cimport _find_header
ALLOWED_UPGRADES = frozenset({"websocket"})
DEF DEFAULT_FREELIST_SIZE = 250
cdef extern from "Python.h":
int PyByteArray_Resize(object, Py_ssize_t) except -1
Py_ssize_t PyByteArray_Size(object) except -1
char* PyByteArray_AsString(object)
__all__ = ('HttpRequestParser', 'HttpResponseParser',
'RawRequestMessage', 'RawResponseMessage')
cdef object URL = _URL
cdef object URL_build = URL.build
cdef object CIMultiDict = _CIMultiDict
cdef object CIMultiDictProxy = _CIMultiDictProxy
cdef object HttpVersion = _HttpVersion
cdef object HttpVersion10 = _HttpVersion10
cdef object HttpVersion11 = _HttpVersion11
cdef object SEC_WEBSOCKET_KEY1 = hdrs.SEC_WEBSOCKET_KEY1
cdef object CONTENT_ENCODING = hdrs.CONTENT_ENCODING
cdef object EMPTY_PAYLOAD = _EMPTY_PAYLOAD
cdef object StreamReader = _StreamReader
cdef object DeflateBuffer = _DeflateBuffer
cdef bytes EMPTY_BYTES = b""
cdef inline object extend(object buf, const char* at, size_t length):
cdef Py_ssize_t s
cdef char* ptr
s = PyByteArray_Size(buf)
PyByteArray_Resize(buf, s + length)
ptr = PyByteArray_AsString(buf)
memcpy(ptr + s, at, length)
DEF METHODS_COUNT = 46
cdef list _http_method = []
for i in range(METHODS_COUNT):
_http_method.append(
cparser.llhttp_method_name(<cparser.llhttp_method_t> i).decode('ascii'))
cdef inline str http_method_str(int i):
if i < METHODS_COUNT:
return <str>_http_method[i]
else:
return "<unknown>"
cdef inline object find_header(bytes raw_header):
cdef Py_ssize_t size
cdef char *buf
cdef int idx
PyBytes_AsStringAndSize(raw_header, &buf, &size)
idx = _find_header.find_header(buf, size)
if idx == -1:
return raw_header.decode('utf-8', 'surrogateescape')
return headers[idx]
@cython.freelist(DEFAULT_FREELIST_SIZE)
cdef class RawRequestMessage:
cdef readonly str method
cdef readonly str path
cdef readonly object version # HttpVersion
cdef readonly object headers # CIMultiDict
cdef readonly object raw_headers # tuple
cdef readonly object should_close
cdef readonly object compression
cdef readonly object upgrade
cdef readonly object chunked
cdef readonly object url # yarl.URL
def __init__(self, method, path, version, headers, raw_headers,
should_close, compression, upgrade, chunked, url):
self.method = method
self.path = path
self.version = version
self.headers = headers
self.raw_headers = raw_headers
self.should_close = should_close
self.compression = compression
self.upgrade = upgrade
self.chunked = chunked
self.url = url
def __repr__(self):
info = []
info.append(("method", self.method))
info.append(("path", self.path))
info.append(("version", self.version))
info.append(("headers", self.headers))
info.append(("raw_headers", self.raw_headers))
info.append(("should_close", self.should_close))
info.append(("compression", self.compression))
info.append(("upgrade", self.upgrade))
info.append(("chunked", self.chunked))
info.append(("url", self.url))
sinfo = ', '.join(name + '=' + repr(val) for name, val in info)
return '<RawRequestMessage(' + sinfo + ')>'
def _replace(self, **dct):
cdef RawRequestMessage ret
ret = _new_request_message(self.method,
self.path,
self.version,
self.headers,
self.raw_headers,
self.should_close,
self.compression,
self.upgrade,
self.chunked,
self.url)
if "method" in dct:
ret.method = dct["method"]
if "path" in dct:
ret.path = dct["path"]
if "version" in dct:
ret.version = dct["version"]
if "headers" in dct:
ret.headers = dct["headers"]
if "raw_headers" in dct:
ret.raw_headers = dct["raw_headers"]
if "should_close" in dct:
ret.should_close = dct["should_close"]
if "compression" in dct:
ret.compression = dct["compression"]
if "upgrade" in dct:
ret.upgrade = dct["upgrade"]
if "chunked" in dct:
ret.chunked = dct["chunked"]
if "url" in dct:
ret.url = dct["url"]
return ret
cdef _new_request_message(str method,
str path,
object version,
object headers,
object raw_headers,
bint should_close,
object compression,
bint upgrade,
bint chunked,
object url):
cdef RawRequestMessage ret
ret = RawRequestMessage.__new__(RawRequestMessage)
ret.method = method
ret.path = path
ret.version = version
ret.headers = headers
ret.raw_headers = raw_headers
ret.should_close = should_close
ret.compression = compression
ret.upgrade = upgrade
ret.chunked = chunked
ret.url = url
return ret
@cython.freelist(DEFAULT_FREELIST_SIZE)
cdef class RawResponseMessage:
cdef readonly object version # HttpVersion
cdef readonly int code
cdef readonly str reason
cdef readonly object headers # CIMultiDict
cdef readonly object raw_headers # tuple
cdef readonly object should_close
cdef readonly object compression
cdef readonly object upgrade
cdef readonly object chunked
def __init__(self, version, code, reason, headers, raw_headers,
should_close, compression, upgrade, chunked):
self.version = version
self.code = code
self.reason = reason
self.headers = headers
self.raw_headers = raw_headers
self.should_close = should_close
self.compression = compression
self.upgrade = upgrade
self.chunked = chunked
def __repr__(self):
info = []
info.append(("version", self.version))
info.append(("code", self.code))
info.append(("reason", self.reason))
info.append(("headers", self.headers))
info.append(("raw_headers", self.raw_headers))
info.append(("should_close", self.should_close))
info.append(("compression", self.compression))
info.append(("upgrade", self.upgrade))
info.append(("chunked", self.chunked))
sinfo = ', '.join(name + '=' + repr(val) for name, val in info)
return '<RawResponseMessage(' + sinfo + ')>'
cdef _new_response_message(object version,
int code,
str reason,
object headers,
object raw_headers,
bint should_close,
object compression,
bint upgrade,
bint chunked):
cdef RawResponseMessage ret
ret = RawResponseMessage.__new__(RawResponseMessage)
ret.version = version
ret.code = code
ret.reason = reason
ret.headers = headers
ret.raw_headers = raw_headers
ret.should_close = should_close
ret.compression = compression
ret.upgrade = upgrade
ret.chunked = chunked
return ret
@cython.internal
cdef class HttpParser:
cdef:
cparser.llhttp_t* _cparser
cparser.llhttp_settings_t* _csettings
bytes _raw_name
object _name
bytes _raw_value
bint _has_value
object _protocol
object _loop
object _timer
size_t _max_line_size
size_t _max_field_size
size_t _max_headers
bint _response_with_body
bint _read_until_eof
bint _started
object _url
bytearray _buf
str _path
str _reason
list _headers
list _raw_headers
bint _upgraded
list _messages
object _payload
bint _payload_error
object _payload_exception
object _last_error
bint _auto_decompress
int _limit
str _content_encoding
Py_buffer py_buf
def __cinit__(self):
self._cparser = <cparser.llhttp_t*> \
PyMem_Malloc(sizeof(cparser.llhttp_t))
if self._cparser is NULL:
raise MemoryError()
self._csettings = <cparser.llhttp_settings_t*> \
PyMem_Malloc(sizeof(cparser.llhttp_settings_t))
if self._csettings is NULL:
raise MemoryError()
def __dealloc__(self):
PyMem_Free(self._cparser)
PyMem_Free(self._csettings)
cdef _init(
self, cparser.llhttp_type mode,
object protocol, object loop, int limit,
object timer=None,
size_t max_line_size=8190, size_t max_headers=32768,
size_t max_field_size=8190, payload_exception=None,
bint response_with_body=True, bint read_until_eof=False,
bint auto_decompress=True,
):
cparser.llhttp_settings_init(self._csettings)
cparser.llhttp_init(self._cparser, mode, self._csettings)
self._cparser.data = <void*>self
self._cparser.content_length = 0
self._protocol = protocol
self._loop = loop
self._timer = timer
self._buf = bytearray()
self._payload = None
self._payload_error = 0
self._payload_exception = payload_exception
self._messages = []
self._raw_name = EMPTY_BYTES
self._raw_value = EMPTY_BYTES
self._has_value = False
self._max_line_size = max_line_size
self._max_headers = max_headers
self._max_field_size = max_field_size
self._response_with_body = response_with_body
self._read_until_eof = read_until_eof
self._upgraded = False
self._auto_decompress = auto_decompress
self._content_encoding = None
self._csettings.on_url = cb_on_url
self._csettings.on_status = cb_on_status
self._csettings.on_header_field = cb_on_header_field
self._csettings.on_header_value = cb_on_header_value
self._csettings.on_headers_complete = cb_on_headers_complete
self._csettings.on_body = cb_on_body
self._csettings.on_message_begin = cb_on_message_begin
self._csettings.on_message_complete = cb_on_message_complete
self._csettings.on_chunk_header = cb_on_chunk_header
self._csettings.on_chunk_complete = cb_on_chunk_complete
self._last_error = None
self._limit = limit
cdef _process_header(self):
cdef str value
if self._raw_name is not EMPTY_BYTES:
name = find_header(self._raw_name)
value = self._raw_value.decode('utf-8', 'surrogateescape')
self._headers.append((name, value))
if name is CONTENT_ENCODING:
self._content_encoding = value
self._has_value = False
self._raw_headers.append((self._raw_name, self._raw_value))
self._raw_name = EMPTY_BYTES
self._raw_value = EMPTY_BYTES
cdef _on_header_field(self, char* at, size_t length):
if self._has_value:
self._process_header()
if self._raw_name is EMPTY_BYTES:
self._raw_name = at[:length]
else:
self._raw_name += at[:length]
cdef _on_header_value(self, char* at, size_t length):
if self._raw_value is EMPTY_BYTES:
self._raw_value = at[:length]
else:
self._raw_value += at[:length]
self._has_value = True
cdef _on_headers_complete(self):
self._process_header()
should_close = not cparser.llhttp_should_keep_alive(self._cparser)
upgrade = self._cparser.upgrade
chunked = self._cparser.flags & cparser.F_CHUNKED
raw_headers = tuple(self._raw_headers)
headers = CIMultiDictProxy(CIMultiDict(self._headers))
if self._cparser.type == cparser.HTTP_REQUEST:
allowed = upgrade and headers.get("upgrade", "").lower() in ALLOWED_UPGRADES
if allowed or self._cparser.method == cparser.HTTP_CONNECT:
self._upgraded = True
else:
if upgrade and self._cparser.status_code == 101:
self._upgraded = True
# do not support old websocket spec
if SEC_WEBSOCKET_KEY1 in headers:
raise InvalidHeader(SEC_WEBSOCKET_KEY1)
encoding = None
enc = self._content_encoding
if enc is not None:
self._content_encoding = None
enc = enc.lower()
if enc in ('gzip', 'deflate', 'br'):
encoding = enc
if self._cparser.type == cparser.HTTP_REQUEST:
method = http_method_str(self._cparser.method)
msg = _new_request_message(
method, self._path,
self.http_version(), headers, raw_headers,
should_close, encoding, upgrade, chunked, self._url)
else:
msg = _new_response_message(
self.http_version(), self._cparser.status_code, self._reason,
headers, raw_headers, should_close, encoding,
upgrade, chunked)
if (
ULLONG_MAX > self._cparser.content_length > 0 or chunked or
self._cparser.method == cparser.HTTP_CONNECT or
(self._cparser.status_code >= 199 and
self._cparser.content_length == 0 and
self._read_until_eof)
):
payload = StreamReader(
self._protocol, timer=self._timer, loop=self._loop,
limit=self._limit)
else:
payload = EMPTY_PAYLOAD
self._payload = payload
if encoding is not None and self._auto_decompress:
self._payload = DeflateBuffer(payload, encoding)
if not self._response_with_body:
payload = EMPTY_PAYLOAD
self._messages.append((msg, payload))
cdef _on_message_complete(self):
self._payload.feed_eof()
self._payload = None
cdef _on_chunk_header(self):
self._payload.begin_http_chunk_receiving()
cdef _on_chunk_complete(self):
self._payload.end_http_chunk_receiving()
cdef object _on_status_complete(self):
pass
cdef inline http_version(self):
cdef cparser.llhttp_t* parser = self._cparser
if parser.http_major == 1:
if parser.http_minor == 0:
return HttpVersion10
elif parser.http_minor == 1:
return HttpVersion11
return HttpVersion(parser.http_major, parser.http_minor)
### Public API ###
def feed_eof(self):
cdef bytes desc
if self._payload is not None:
if self._cparser.flags & cparser.F_CHUNKED:
raise TransferEncodingError(
"Not enough data for satisfy transfer length header.")
elif self._cparser.flags & cparser.F_CONTENT_LENGTH:
raise ContentLengthError(
"Not enough data for satisfy content length header.")
elif cparser.llhttp_get_errno(self._cparser) != cparser.HPE_OK:
desc = cparser.llhttp_get_error_reason(self._cparser)
raise PayloadEncodingError(desc.decode('latin-1'))
else:
self._payload.feed_eof()
elif self._started:
self._on_headers_complete()
if self._messages:
return self._messages[-1][0]
def feed_data(self, data):
cdef:
size_t data_len
size_t nb
cdef cparser.llhttp_errno_t errno
PyObject_GetBuffer(data, &self.py_buf, PyBUF_SIMPLE)
data_len = <size_t>self.py_buf.len
errno = cparser.llhttp_execute(
self._cparser,
<char*>self.py_buf.buf,
data_len)
if errno is cparser.HPE_PAUSED_UPGRADE:
cparser.llhttp_resume_after_upgrade(self._cparser)
nb = cparser.llhttp_get_error_pos(self._cparser) - <char*>self.py_buf.buf
PyBuffer_Release(&self.py_buf)
if errno not in (cparser.HPE_OK, cparser.HPE_PAUSED_UPGRADE):
if self._payload_error == 0:
if self._last_error is not None:
ex = self._last_error
self._last_error = None
else:
after = cparser.llhttp_get_error_pos(self._cparser)
before = data[:after - <char*>self.py_buf.buf]
after_b = after.split(b"\r\n", 1)[0]
before = before.rsplit(b"\r\n", 1)[-1]
data = before + after_b
pointer = " " * (len(repr(before))-1) + "^"
ex = parser_error_from_errno(self._cparser, data, pointer)
self._payload = None
raise ex
if self._messages:
messages = self._messages
self._messages = []
else:
messages = ()
if self._upgraded:
return messages, True, data[nb:]
else:
return messages, False, b""
def set_upgraded(self, val):
self._upgraded = val
cdef class HttpRequestParser(HttpParser):
def __init__(
self, protocol, loop, int limit, timer=None,
size_t max_line_size=8190, size_t max_headers=32768,
size_t max_field_size=8190, payload_exception=None,
bint response_with_body=True, bint read_until_eof=False,
bint auto_decompress=True,
):
self._init(cparser.HTTP_REQUEST, protocol, loop, limit, timer,
max_line_size, max_headers, max_field_size,
payload_exception, response_with_body, read_until_eof,
auto_decompress)
cdef object _on_status_complete(self):
cdef int idx1, idx2
if not self._buf:
return
self._path = self._buf.decode('utf-8', 'surrogateescape')
try:
idx3 = len(self._path)
if self._cparser.method == cparser.HTTP_CONNECT:
# authority-form,
# https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.3
self._url = URL.build(authority=self._path, encoded=True)
elif idx3 > 1 and self._path[0] == '/':
# origin-form,
# https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.1
idx1 = self._path.find("?")
if idx1 == -1:
query = ""
idx2 = self._path.find("#")
if idx2 == -1:
path = self._path
fragment = ""
else:
path = self._path[0: idx2]
fragment = self._path[idx2+1:]
else:
path = self._path[0:idx1]
idx1 += 1
idx2 = self._path.find("#", idx1+1)
if idx2 == -1:
query = self._path[idx1:]
fragment = ""
else:
query = self._path[idx1: idx2]
fragment = self._path[idx2+1:]
self._url = URL.build(
path=path,
query_string=query,
fragment=fragment,
encoded=True,
)
else:
                # absolute-form (most likely a proxied request),
                # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.2
self._url = URL(self._path, encoded=True)
finally:
PyByteArray_Resize(self._buf, 0)
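# A sketch (not part of the original module) of the three RFC 7230
# request-target forms handled above, built with the same yarl calls
# the branches use:
_origin = URL.build(path="/pages", query_string="q=1", fragment="frag", encoded=True)  # "GET /pages?q=1#frag HTTP/1.1"
_authority = URL.build(authority="example.com:443", encoded=True)  # "CONNECT example.com:443 HTTP/1.1"
_absolute = URL("http://example.com/pages", encoded=True)  # "GET http://example.com/pages HTTP/1.1" (via proxy)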
cdef class HttpResponseParser(HttpParser):
def __init__(
self, protocol, loop, int limit, timer=None,
size_t max_line_size=8190, size_t max_headers=32768,
size_t max_field_size=8190, payload_exception=None,
bint response_with_body=True, bint read_until_eof=False,
bint auto_decompress=True
):
self._init(cparser.HTTP_RESPONSE, protocol, loop, limit, timer,
max_line_size, max_headers, max_field_size,
payload_exception, response_with_body, read_until_eof,
auto_decompress)
        # Use strict parsing in dev mode, so users are warned about broken servers.
if not DEBUG:
cparser.llhttp_set_lenient_headers(self._cparser, 1)
cparser.llhttp_set_lenient_optional_cr_before_lf(self._cparser, 1)
cparser.llhttp_set_lenient_spaces_after_chunk_size(self._cparser, 1)
cdef object _on_status_complete(self):
if self._buf:
self._reason = self._buf.decode('utf-8', 'surrogateescape')
PyByteArray_Resize(self._buf, 0)
else:
self._reason = self._reason or ''
cdef int cb_on_message_begin(cparser.llhttp_t* parser) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
pyparser._started = True
pyparser._headers = []
pyparser._raw_headers = []
PyByteArray_Resize(pyparser._buf, 0)
pyparser._path = None
pyparser._reason = None
return 0
cdef int cb_on_url(cparser.llhttp_t* parser,
const char *at, size_t length) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
try:
if length > pyparser._max_line_size:
raise LineTooLong(
'Status line is too long', pyparser._max_line_size, length)
extend(pyparser._buf, at, length)
except BaseException as ex:
pyparser._last_error = ex
return -1
else:
return 0
cdef int cb_on_status(cparser.llhttp_t* parser,
const char *at, size_t length) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
cdef str reason
try:
if length > pyparser._max_line_size:
raise LineTooLong(
'Status line is too long', pyparser._max_line_size, length)
extend(pyparser._buf, at, length)
except BaseException as ex:
pyparser._last_error = ex
return -1
else:
return 0
cdef int cb_on_header_field(cparser.llhttp_t* parser,
const char *at, size_t length) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
cdef Py_ssize_t size
try:
pyparser._on_status_complete()
size = len(pyparser._raw_name) + length
if size > pyparser._max_field_size:
raise LineTooLong(
'Header name is too long', pyparser._max_field_size, size)
pyparser._on_header_field(at, length)
except BaseException as ex:
pyparser._last_error = ex
return -1
else:
return 0
cdef int cb_on_header_value(cparser.llhttp_t* parser,
const char *at, size_t length) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
cdef Py_ssize_t size
try:
size = len(pyparser._raw_value) + length
if size > pyparser._max_field_size:
raise LineTooLong(
'Header value is too long', pyparser._max_field_size, size)
pyparser._on_header_value(at, length)
except BaseException as ex:
pyparser._last_error = ex
return -1
else:
return 0
cdef int cb_on_headers_complete(cparser.llhttp_t* parser) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
try:
pyparser._on_status_complete()
pyparser._on_headers_complete()
except BaseException as exc:
pyparser._last_error = exc
return -1
else:
if pyparser._upgraded or pyparser._cparser.method == cparser.HTTP_CONNECT:
return 2
else:
return 0
cdef int cb_on_body(cparser.llhttp_t* parser,
const char *at, size_t length) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
cdef bytes body = at[:length]
try:
pyparser._payload.feed_data(body, length)
except BaseException as underlying_exc:
reraised_exc = underlying_exc
if pyparser._payload_exception is not None:
reraised_exc = pyparser._payload_exception(str(underlying_exc))
set_exception(pyparser._payload, reraised_exc, underlying_exc)
pyparser._payload_error = 1
return -1
else:
return 0
cdef int cb_on_message_complete(cparser.llhttp_t* parser) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
try:
pyparser._started = False
pyparser._on_message_complete()
except BaseException as exc:
pyparser._last_error = exc
return -1
else:
return 0
cdef int cb_on_chunk_header(cparser.llhttp_t* parser) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
try:
pyparser._on_chunk_header()
except BaseException as exc:
pyparser._last_error = exc
return -1
else:
return 0
cdef int cb_on_chunk_complete(cparser.llhttp_t* parser) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
try:
pyparser._on_chunk_complete()
except BaseException as exc:
pyparser._last_error = exc
return -1
else:
return 0
cdef parser_error_from_errno(cparser.llhttp_t* parser, data, pointer):
cdef cparser.llhttp_errno_t errno = cparser.llhttp_get_errno(parser)
cdef bytes desc = cparser.llhttp_get_error_reason(parser)
err_msg = "{}:\n\n {!r}\n {}".format(desc.decode("latin-1"), data, pointer)
if errno in {cparser.HPE_CB_MESSAGE_BEGIN,
cparser.HPE_CB_HEADERS_COMPLETE,
cparser.HPE_CB_MESSAGE_COMPLETE,
cparser.HPE_CB_CHUNK_HEADER,
cparser.HPE_CB_CHUNK_COMPLETE,
cparser.HPE_INVALID_CONSTANT,
cparser.HPE_INVALID_HEADER_TOKEN,
cparser.HPE_INVALID_CONTENT_LENGTH,
cparser.HPE_INVALID_CHUNK_SIZE,
cparser.HPE_INVALID_EOF_STATE,
cparser.HPE_INVALID_TRANSFER_ENCODING}:
return BadHttpMessage(err_msg)
elif errno == cparser.HPE_INVALID_METHOD:
return BadHttpMethod(error=err_msg)
elif errno in {cparser.HPE_INVALID_STATUS,
cparser.HPE_INVALID_VERSION}:
return BadStatusLine(error=err_msg)
elif errno == cparser.HPE_INVALID_URL:
return InvalidURLError(err_msg)
return BadHttpMessage(err_msg)

View File

@@ -0,0 +1,160 @@
from cpython.bytes cimport PyBytes_FromStringAndSize
from cpython.exc cimport PyErr_NoMemory
from cpython.mem cimport PyMem_Free, PyMem_Malloc, PyMem_Realloc
from cpython.object cimport PyObject_Str
from libc.stdint cimport uint8_t, uint64_t
from libc.string cimport memcpy
from multidict import istr
DEF BUF_SIZE = 16 * 1024 # 16KiB
cdef char BUFFER[BUF_SIZE]
cdef object _istr = istr
# ----------------- writer ---------------------------
cdef struct Writer:
char *buf
Py_ssize_t size
Py_ssize_t pos
cdef inline void _init_writer(Writer* writer):
writer.buf = &BUFFER[0]
writer.size = BUF_SIZE
writer.pos = 0
cdef inline void _release_writer(Writer* writer):
if writer.buf != BUFFER:
PyMem_Free(writer.buf)
cdef inline int _write_byte(Writer* writer, uint8_t ch):
cdef char * buf
cdef Py_ssize_t size
if writer.pos == writer.size:
# reallocate
size = writer.size + BUF_SIZE
if writer.buf == BUFFER:
buf = <char*>PyMem_Malloc(size)
if buf == NULL:
PyErr_NoMemory()
return -1
memcpy(buf, writer.buf, writer.size)
else:
buf = <char*>PyMem_Realloc(writer.buf, size)
if buf == NULL:
PyErr_NoMemory()
return -1
writer.buf = buf
writer.size = size
writer.buf[writer.pos] = <char>ch
writer.pos += 1
return 0
cdef inline int _write_utf8(Writer* writer, Py_UCS4 symbol):
cdef uint64_t utf = <uint64_t> symbol
if utf < 0x80:
return _write_byte(writer, <uint8_t>utf)
elif utf < 0x800:
if _write_byte(writer, <uint8_t>(0xc0 | (utf >> 6))) < 0:
return -1
return _write_byte(writer, <uint8_t>(0x80 | (utf & 0x3f)))
elif 0xD800 <= utf <= 0xDFFF:
        # surrogate code point, ignored
return 0
elif utf < 0x10000:
if _write_byte(writer, <uint8_t>(0xe0 | (utf >> 12))) < 0:
return -1
if _write_byte(writer, <uint8_t>(0x80 | ((utf >> 6) & 0x3f))) < 0:
return -1
return _write_byte(writer, <uint8_t>(0x80 | (utf & 0x3f)))
elif utf > 0x10FFFF:
# symbol is too large
return 0
else:
if _write_byte(writer, <uint8_t>(0xf0 | (utf >> 18))) < 0:
return -1
if _write_byte(writer,
<uint8_t>(0x80 | ((utf >> 12) & 0x3f))) < 0:
return -1
if _write_byte(writer,
<uint8_t>(0x80 | ((utf >> 6) & 0x3f))) < 0:
return -1
return _write_byte(writer, <uint8_t>(0x80 | (utf & 0x3f)))
cdef inline int _write_str(Writer* writer, str s):
cdef Py_UCS4 ch
for ch in s:
if _write_utf8(writer, ch) < 0:
return -1
cdef inline int _write_str_raise_on_nlcr(Writer* writer, object s):
cdef Py_UCS4 ch
cdef str out_str
if type(s) is str:
out_str = <str>s
elif type(s) is _istr:
out_str = PyObject_Str(s)
elif not isinstance(s, str):
raise TypeError("Cannot serialize non-str key {!r}".format(s))
else:
out_str = str(s)
for ch in out_str:
if ch == 0x0D or ch == 0x0A:
raise ValueError(
"Newline or carriage return detected in headers. "
"Potential header injection attack."
)
if _write_utf8(writer, ch) < 0:
return -1
# --------------- _serialize_headers ----------------------
def _serialize_headers(str status_line, headers):
cdef Writer writer
cdef object key
cdef object val
_init_writer(&writer)
try:
if _write_str(&writer, status_line) < 0:
raise
if _write_byte(&writer, b'\r') < 0:
raise
if _write_byte(&writer, b'\n') < 0:
raise
for key, val in headers.items():
if _write_str_raise_on_nlcr(&writer, key) < 0:
raise
if _write_byte(&writer, b':') < 0:
raise
if _write_byte(&writer, b' ') < 0:
raise
if _write_str_raise_on_nlcr(&writer, val) < 0:
raise
if _write_byte(&writer, b'\r') < 0:
raise
if _write_byte(&writer, b'\n') < 0:
raise
if _write_byte(&writer, b'\r') < 0:
raise
if _write_byte(&writer, b'\n') < 0:
raise
return PyBytes_FromStringAndSize(writer.buf, writer.pos)
finally:
_release_writer(&writer)
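Assuming the compiled extension is available, `_serialize_headers` turns a status line plus a header mapping into the complete wire-format header block; a hedged usage sketch:

from multidict import CIMultiDict

from aiohttp._http_writer import _serialize_headers  # compiled extension

hdr_map = CIMultiDict({"Host": "example.com", "Content-Length": "0"})
raw = _serialize_headers("GET / HTTP/1.1", hdr_map)
assert raw == b"GET / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 0\r\n\r\n"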

View File

@@ -0,0 +1 @@
b01999d409b29bd916e067bc963d5f2d9ee63cfc9ae0bccb769910131417bf93 /home/runner/work/aiohttp/aiohttp/aiohttp/_websocket/mask.pxd

View File

@@ -0,0 +1 @@
0478ceb55d0ed30ef1a7da742cd003449bc69a07cf9fdb06789bd2b347cbfffe /home/runner/work/aiohttp/aiohttp/aiohttp/_websocket/mask.pyx

View File

@@ -0,0 +1 @@
9e5fe78ed0ebce5414d2b8e01868d90c1facc20b84d2d5ff6c23e86e44a155ae /home/runner/work/aiohttp/aiohttp/aiohttp/_websocket/reader_c.pxd

View File

@@ -0,0 +1 @@
"""WebSocket protocol versions 13 and 8."""

View File

@@ -0,0 +1,147 @@
"""Helpers for WebSocket protocol versions 13 and 8."""
import functools
import re
from struct import Struct
from typing import TYPE_CHECKING, Final, List, Optional, Pattern, Tuple
from ..helpers import NO_EXTENSIONS
from .models import WSHandshakeError
UNPACK_LEN3 = Struct("!Q").unpack_from
UNPACK_CLOSE_CODE = Struct("!H").unpack
PACK_LEN1 = Struct("!BB").pack
PACK_LEN2 = Struct("!BBH").pack
PACK_LEN3 = Struct("!BBQ").pack
PACK_CLOSE_CODE = Struct("!H").pack
PACK_RANDBITS = Struct("!L").pack
MSG_SIZE: Final[int] = 2**14
MASK_LEN: Final[int] = 4
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
# Used by _websocket_mask_python
@functools.lru_cache
def _xor_table() -> List[bytes]:
return [bytes(a ^ b for a in range(256)) for b in range(256)]
def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
"""Websocket masking function.
`mask` is a `bytes` object of length 4; `data` is a `bytearray`
object of any length. The contents of `data` are masked with `mask`,
as specified in section 5.3 of RFC 6455.
Note that this function mutates the `data` argument.
This pure-python implementation may be replaced by an optimized
version when available.
"""
assert isinstance(data, bytearray), data
assert len(mask) == 4, mask
if data:
_XOR_TABLE = _xor_table()
a, b, c, d = (_XOR_TABLE[n] for n in mask)
data[::4] = data[::4].translate(a)
data[1::4] = data[1::4].translate(b)
data[2::4] = data[2::4].translate(c)
data[3::4] = data[3::4].translate(d)
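# Because XOR is its own inverse, masking and unmasking are the same
# operation; a quick round trip with the pure-Python masker above
# (illustrative sketch, not part of the original module):
_demo = bytearray(b"hello websocket")
_websocket_mask_python(b"\x01\x02\x03\x04", _demo)  # masked in place
_websocket_mask_python(b"\x01\x02\x03\x04", _demo)  # the same call restores the bytes
assert _demo == bytearray(b"hello websocket")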
if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover
websocket_mask = _websocket_mask_python
else:
try:
from .mask import _websocket_mask_cython # type: ignore[import-not-found]
websocket_mask = _websocket_mask_cython
except ImportError: # pragma: no cover
websocket_mask = _websocket_mask_python
_WS_EXT_RE: Final[Pattern[str]] = re.compile(
r"^(?:;\s*(?:"
r"(server_no_context_takeover)|"
r"(client_no_context_takeover)|"
r"(server_max_window_bits(?:=(\d+))?)|"
r"(client_max_window_bits(?:=(\d+))?)))*$"
)
_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?")
def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]:
if not extstr:
return 0, False
compress = 0
notakeover = False
for ext in _WS_EXT_RE_SPLIT.finditer(extstr):
defext = ext.group(1)
        # Return compress = 15 when we get bare `permessage-deflate`
if not defext:
compress = 15
break
match = _WS_EXT_RE.match(defext)
if match:
compress = 15
if isserver:
                # The server never fails to detect the compress handshake.
                # The server does not need to send max window bits to the client.
if match.group(4):
compress = int(match.group(4))
                    # Group 3 must match if group 4 matches.
                    # zlib does not support a compression window of 8 bits.
                    # If the compress level is not supported,
                    # CONTINUE to the next extension.
if compress > 15 or compress < 9:
compress = 0
continue
if match.group(1):
notakeover = True
# Ignore regex group 5 & 6 for client_max_window_bits
break
else:
if match.group(6):
compress = int(match.group(6))
                    # Group 5 must match if group 6 matches.
                    # zlib does not support a compression window of 8 bits.
                    # If the compress level is not supported,
                    # FAIL the parse progress.
if compress > 15 or compress < 9:
raise WSHandshakeError("Invalid window size")
if match.group(2):
notakeover = True
# Ignore regex group 5 & 6 for client_max_window_bits
break
        # Fail if on the client side and the extension did not match
elif not isserver:
raise WSHandshakeError("Extension for deflate not supported" + ext.group(1))
return compress, notakeover
def ws_ext_gen(
compress: int = 15, isserver: bool = False, server_notakeover: bool = False
) -> str:
    # client_notakeover=False is not used for the server
    # zlib does not support a compression window of 8 bits
if compress < 9 or compress > 15:
raise ValueError(
"Compress wbits must between 9 and 15, zlib does not support wbits=8"
)
enabledext = ["permessage-deflate"]
if not isserver:
enabledext.append("client_max_window_bits")
if compress < 15:
enabledext.append("server_max_window_bits=" + str(compress))
if server_notakeover:
enabledext.append("server_no_context_takeover")
# if client_notakeover:
# enabledext.append('client_no_context_takeover')
return "; ".join(enabledext)

View File

@@ -0,0 +1,3 @@
"""Cython declarations for websocket masking."""
cpdef void _websocket_mask_cython(bytes mask, bytearray data)

View File

@@ -0,0 +1,48 @@
from cpython cimport PyBytes_AsString
# from cpython cimport PyByteArray_AsString  # Cython still does not export that
cdef extern from "Python.h":
char* PyByteArray_AsString(bytearray ba) except NULL
from libc.stdint cimport uint32_t, uint64_t, uintmax_t
cpdef void _websocket_mask_cython(bytes mask, bytearray data):
"""Note, this function mutates its `data` argument
"""
cdef:
Py_ssize_t data_len, i
# bit operations on signed integers are implementation-specific
unsigned char * in_buf
const unsigned char * mask_buf
uint32_t uint32_msk
uint64_t uint64_msk
assert len(mask) == 4
data_len = len(data)
in_buf = <unsigned char*>PyByteArray_AsString(data)
mask_buf = <const unsigned char*>PyBytes_AsString(mask)
uint32_msk = (<uint32_t*>mask_buf)[0]
    # TODO: align the input pointer to achieve even faster speeds
    # is that needed in Python?! malloc() always aligns to sizeof(long) bytes
if sizeof(size_t) >= 8:
uint64_msk = uint32_msk
uint64_msk = (uint64_msk << 32) | uint32_msk
while data_len >= 8:
(<uint64_t*>in_buf)[0] ^= uint64_msk
in_buf += 8
data_len -= 8
while data_len >= 4:
(<uint32_t*>in_buf)[0] ^= uint32_msk
in_buf += 4
data_len -= 4
for i in range(0, data_len):
in_buf[i] ^= mask_buf[i]
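The widening above lets the loop XOR eight payload bytes per iteration instead of one. A pure-Python sketch of the same word-at-a-time idea (illustrative only; the compiled version operates on raw pointers):

def mask_words(mask: bytes, data: bytearray) -> None:
    # Repeat the 4-byte mask into a 64-bit word and XOR 8 bytes at a time.
    msk64 = int.from_bytes(mask * 2, "little")
    n = len(data) - len(data) % 8
    for i in range(0, n, 8):
        word = int.from_bytes(data[i:i + 8], "little") ^ msk64
        data[i:i + 8] = word.to_bytes(8, "little")
    for j in range(n, len(data)):  # remaining tail bytes, one at a time
        data[j] ^= mask[j % 4]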

View File

@@ -0,0 +1,84 @@
"""Models for WebSocket protocol versions 13 and 8."""
import json
from enum import IntEnum
from typing import Any, Callable, Final, NamedTuple, Optional, cast
WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF])
class WSCloseCode(IntEnum):
OK = 1000
GOING_AWAY = 1001
PROTOCOL_ERROR = 1002
UNSUPPORTED_DATA = 1003
ABNORMAL_CLOSURE = 1006
INVALID_TEXT = 1007
POLICY_VIOLATION = 1008
MESSAGE_TOO_BIG = 1009
MANDATORY_EXTENSION = 1010
INTERNAL_ERROR = 1011
SERVICE_RESTART = 1012
TRY_AGAIN_LATER = 1013
BAD_GATEWAY = 1014
class WSMsgType(IntEnum):
# websocket spec types
CONTINUATION = 0x0
TEXT = 0x1
BINARY = 0x2
PING = 0x9
PONG = 0xA
CLOSE = 0x8
# aiohttp specific types
CLOSING = 0x100
CLOSED = 0x101
ERROR = 0x102
text = TEXT
binary = BINARY
ping = PING
pong = PONG
close = CLOSE
closing = CLOSING
closed = CLOSED
error = ERROR
class WSMessage(NamedTuple):
type: WSMsgType
# To type correctly, this would need some kind of tagged union for each type.
data: Any
extra: Optional[str]
def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
"""Return parsed JSON data.
.. versionadded:: 0.22
"""
return loads(self.data)
# Construct the tuples directly to avoid the overhead of the lambda and
# argument processing, since NamedTuples are constructed with a
# run-time-built lambda:
# https://github.com/python/cpython/blob/d83fcf8371f2f33c7797bc8f5423a8bca8c46e5c/Lib/collections/__init__.py#L441
WS_CLOSED_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSED, None, None))
WS_CLOSING_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSING, None, None))
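# The fast path builds the same value the generated constructor would;
# tuple.__new__ merely skips the keyword-processing wrapper (a sketch,
# not part of the original module):
assert WS_CLOSED_MESSAGE == WSMessage(WSMsgType.CLOSED, None, None)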
class WebSocketError(Exception):
"""WebSocket protocol parser error."""
def __init__(self, code: int, message: str) -> None:
self.code = code
super().__init__(code, message)
def __str__(self) -> str:
return cast(str, self.args[1])
class WSHandshakeError(Exception):
"""WebSocket protocol handshake error."""

View File

@@ -0,0 +1,31 @@
"""Reader for WebSocket protocol versions 13 and 8."""
from typing import TYPE_CHECKING
from ..helpers import NO_EXTENSIONS
if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover
from .reader_py import (
WebSocketDataQueue as WebSocketDataQueuePython,
WebSocketReader as WebSocketReaderPython,
)
WebSocketReader = WebSocketReaderPython
WebSocketDataQueue = WebSocketDataQueuePython
else:
try:
from .reader_c import ( # type: ignore[import-not-found]
WebSocketDataQueue as WebSocketDataQueueCython,
WebSocketReader as WebSocketReaderCython,
)
WebSocketReader = WebSocketReaderCython
WebSocketDataQueue = WebSocketDataQueueCython
except ImportError: # pragma: no cover
from .reader_py import (
WebSocketDataQueue as WebSocketDataQueuePython,
WebSocketReader as WebSocketReaderPython,
)
WebSocketReader = WebSocketReaderPython
WebSocketDataQueue = WebSocketDataQueuePython

View File

@@ -0,0 +1,110 @@
import cython
from .mask cimport _websocket_mask_cython as websocket_mask
cdef unsigned int READ_HEADER
cdef unsigned int READ_PAYLOAD_LENGTH
cdef unsigned int READ_PAYLOAD_MASK
cdef unsigned int READ_PAYLOAD
cdef int OP_CODE_NOT_SET
cdef int OP_CODE_CONTINUATION
cdef int OP_CODE_TEXT
cdef int OP_CODE_BINARY
cdef int OP_CODE_CLOSE
cdef int OP_CODE_PING
cdef int OP_CODE_PONG
cdef int COMPRESSED_NOT_SET
cdef int COMPRESSED_FALSE
cdef int COMPRESSED_TRUE
cdef object UNPACK_LEN3
cdef object UNPACK_CLOSE_CODE
cdef object TUPLE_NEW
cdef object WSMsgType
cdef object WSMessage
cdef object WS_MSG_TYPE_TEXT
cdef object WS_MSG_TYPE_BINARY
cdef set ALLOWED_CLOSE_CODES
cdef set MESSAGE_TYPES_WITH_CONTENT
cdef tuple EMPTY_FRAME
cdef tuple EMPTY_FRAME_ERROR
cdef class WebSocketDataQueue:
cdef unsigned int _size
cdef public object _protocol
cdef unsigned int _limit
cdef object _loop
cdef bint _eof
cdef object _waiter
cdef object _exception
cdef public object _buffer
cdef object _get_buffer
cdef object _put_buffer
cdef void _release_waiter(self)
cpdef void feed_data(self, object data, unsigned int size)
@cython.locals(size="unsigned int")
cdef _read_from_buffer(self)
cdef class WebSocketReader:
cdef WebSocketDataQueue queue
cdef unsigned int _max_msg_size
cdef Exception _exc
cdef bytearray _partial
cdef unsigned int _state
cdef int _opcode
cdef bint _frame_fin
cdef int _frame_opcode
cdef list _payload_fragments
cdef Py_ssize_t _frame_payload_len
cdef bytes _tail
cdef bint _has_mask
cdef bytes _frame_mask
cdef Py_ssize_t _payload_bytes_to_read
cdef unsigned int _payload_len_flag
cdef int _compressed
cdef object _decompressobj
cdef bint _compress
cpdef tuple feed_data(self, object data)
@cython.locals(
is_continuation=bint,
fin=bint,
has_partial=bint,
payload_merged=bytes,
)
cpdef void _handle_frame(self, bint fin, int opcode, object payload, int compressed) except *
@cython.locals(
start_pos=Py_ssize_t,
data_len=Py_ssize_t,
length=Py_ssize_t,
chunk_size=Py_ssize_t,
chunk_len=Py_ssize_t,
data_cstr="const unsigned char *",
first_byte="unsigned char",
second_byte="unsigned char",
f_start_pos=Py_ssize_t,
f_end_pos=Py_ssize_t,
has_mask=bint,
fin=bint,
had_fragments=Py_ssize_t,
payload_bytearray=bytearray,
)
cpdef void _feed_data(self, bytes data) except *

View File

@@ -0,0 +1,468 @@
"""Reader for WebSocket protocol versions 13 and 8."""
import asyncio
import builtins
from collections import deque
from typing import Deque, Final, Optional, Set, Tuple, Union
from ..base_protocol import BaseProtocol
from ..compression_utils import ZLibDecompressor
from ..helpers import _EXC_SENTINEL, set_exception
from ..streams import EofStream
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
from .models import (
WS_DEFLATE_TRAILING,
WebSocketError,
WSCloseCode,
WSMessage,
WSMsgType,
)
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}
# States for the reader, used to parse the WebSocket frame
# integer values are used so they can be cythonized
READ_HEADER = 1
READ_PAYLOAD_LENGTH = 2
READ_PAYLOAD_MASK = 3
READ_PAYLOAD = 4
WS_MSG_TYPE_BINARY = WSMsgType.BINARY
WS_MSG_TYPE_TEXT = WSMsgType.TEXT
# WSMsgType values unpacked so they can be cythonized to ints
OP_CODE_NOT_SET = -1
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
OP_CODE_TEXT = WSMsgType.TEXT.value
OP_CODE_BINARY = WSMsgType.BINARY.value
OP_CODE_CLOSE = WSMsgType.CLOSE.value
OP_CODE_PING = WSMsgType.PING.value
OP_CODE_PONG = WSMsgType.PONG.value
EMPTY_FRAME_ERROR = (True, b"")
EMPTY_FRAME = (False, b"")
COMPRESSED_NOT_SET = -1
COMPRESSED_FALSE = 0
COMPRESSED_TRUE = 1
TUPLE_NEW = tuple.__new__
cython_int = int  # Typed as int in Python, but Cython will use a signed int in the pxd
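# For orientation, a minimal masked client text frame as the parser
# below expects it: bit 7 of the first byte is FIN, the low nibble is
# the opcode, bit 7 of the second byte is the mask flag, and the low
# 7 bits are the payload length (a sketch, not part of the module):
_demo_mask = b"\x01\x02\x03\x04"
_demo_payload = bytearray(b"hello")
websocket_mask(_demo_mask, _demo_payload)  # clients must mask the payload
_demo_frame = bytes([0x80 | OP_CODE_TEXT, 0x80 | len(_demo_payload)]) + _demo_mask + bytes(_demo_payload)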
class WebSocketDataQueue:
"""WebSocketDataQueue resumes and pauses an underlying stream.
It is a destination for WebSocket data.
"""
def __init__(
self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
) -> None:
self._size = 0
self._protocol = protocol
self._limit = limit * 2
self._loop = loop
self._eof = False
self._waiter: Optional[asyncio.Future[None]] = None
self._exception: Union[BaseException, None] = None
self._buffer: Deque[Tuple[WSMessage, int]] = deque()
self._get_buffer = self._buffer.popleft
self._put_buffer = self._buffer.append
def is_eof(self) -> bool:
return self._eof
def exception(self) -> Optional[BaseException]:
return self._exception
def set_exception(
self,
exc: "BaseException",
exc_cause: builtins.BaseException = _EXC_SENTINEL,
) -> None:
self._eof = True
self._exception = exc
if (waiter := self._waiter) is not None:
self._waiter = None
set_exception(waiter, exc, exc_cause)
def _release_waiter(self) -> None:
if (waiter := self._waiter) is None:
return
self._waiter = None
if not waiter.done():
waiter.set_result(None)
def feed_eof(self) -> None:
self._eof = True
self._release_waiter()
self._exception = None # Break cyclic references
def feed_data(self, data: "WSMessage", size: "cython_int") -> None:
self._size += size
self._put_buffer((data, size))
self._release_waiter()
if self._size > self._limit and not self._protocol._reading_paused:
self._protocol.pause_reading()
async def read(self) -> WSMessage:
if not self._buffer and not self._eof:
assert not self._waiter
self._waiter = self._loop.create_future()
try:
await self._waiter
except (asyncio.CancelledError, asyncio.TimeoutError):
self._waiter = None
raise
return self._read_from_buffer()
def _read_from_buffer(self) -> WSMessage:
if self._buffer:
data, size = self._get_buffer()
self._size -= size
if self._size < self._limit and self._protocol._reading_paused:
self._protocol.resume_reading()
return data
if self._exception is not None:
raise self._exception
raise EofStream
class WebSocketReader:
def __init__(
self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
) -> None:
self.queue = queue
self._max_msg_size = max_msg_size
self._exc: Optional[Exception] = None
self._partial = bytearray()
self._state = READ_HEADER
self._opcode: int = OP_CODE_NOT_SET
self._frame_fin = False
self._frame_opcode: int = OP_CODE_NOT_SET
self._payload_fragments: list[bytes] = []
self._frame_payload_len = 0
self._tail: bytes = b""
self._has_mask = False
self._frame_mask: Optional[bytes] = None
self._payload_bytes_to_read = 0
self._payload_len_flag = 0
self._compressed: int = COMPRESSED_NOT_SET
self._decompressobj: Optional[ZLibDecompressor] = None
self._compress = compress
def feed_eof(self) -> None:
self.queue.feed_eof()
    # data can be a bytearray on Windows because the proactor event loop uses
    # bytearray, and asyncio types it as Union[bytes, bytearray, memoryview],
    # so we need to coerce data to bytes if it is not
def feed_data(
self, data: Union[bytes, bytearray, memoryview]
) -> Tuple[bool, bytes]:
if type(data) is not bytes:
data = bytes(data)
if self._exc is not None:
return True, data
try:
self._feed_data(data)
except Exception as exc:
self._exc = exc
set_exception(self.queue, exc)
return EMPTY_FRAME_ERROR
return EMPTY_FRAME
def _handle_frame(
self,
fin: bool,
opcode: Union[int, cython_int], # Union intended: Cython pxd uses C int
payload: Union[bytes, bytearray],
compressed: Union[int, cython_int], # Union intended: Cython pxd uses C int
) -> None:
msg: WSMessage
if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
# load text/binary
if not fin:
# got partial frame payload
if opcode != OP_CODE_CONTINUATION:
self._opcode = opcode
self._partial += payload
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
f"Message size {len(self._partial)} "
f"exceeds limit {self._max_msg_size}",
)
return
has_partial = bool(self._partial)
if opcode == OP_CODE_CONTINUATION:
if self._opcode == OP_CODE_NOT_SET:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Continuation frame for non started message",
)
opcode = self._opcode
self._opcode = OP_CODE_NOT_SET
# previous frame was non finished
# we should get continuation opcode
elif has_partial:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"The opcode in non-fin frame is expected "
f"to be zero, got {opcode!r}",
)
assembled_payload: Union[bytes, bytearray]
if has_partial:
assembled_payload = self._partial + payload
self._partial.clear()
else:
assembled_payload = payload
if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
f"Message size {len(assembled_payload)} "
f"exceeds limit {self._max_msg_size}",
)
            # Decompression must be done after all frames of a message
            # have been received.
if compressed:
if not self._decompressobj:
self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
payload_merged = self._decompressobj.decompress_sync(
assembled_payload + WS_DEFLATE_TRAILING, self._max_msg_size
)
if self._decompressobj.unconsumed_tail:
left = len(self._decompressobj.unconsumed_tail)
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
f"Decompressed message size {self._max_msg_size + left}"
f" exceeds limit {self._max_msg_size}",
)
elif type(assembled_payload) is bytes:
payload_merged = assembled_payload
else:
payload_merged = bytes(assembled_payload)
if opcode == OP_CODE_TEXT:
try:
text = payload_merged.decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
# XXX: The Text and Binary messages here can be a performance
# bottleneck, so we use tuple.__new__ to improve performance.
# This is not type safe, but many tests should fail in
# test_client_ws_functional.py if this is wrong.
self.queue.feed_data(
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
len(payload_merged),
)
else:
self.queue.feed_data(
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
len(payload_merged),
)
elif opcode == OP_CODE_CLOSE:
if len(payload) >= 2:
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
f"Invalid close code: {close_code}",
)
try:
close_message = payload[2:].decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message))
elif payload:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
f"Invalid close frame: {fin} {opcode} {payload!r}",
)
else:
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))
self.queue.feed_data(msg, 0)
elif opcode == OP_CODE_PING:
msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
self.queue.feed_data(msg, len(payload))
elif opcode == OP_CODE_PONG:
msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
self.queue.feed_data(msg, len(payload))
else:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
)
def _feed_data(self, data: bytes) -> None:
"""Return the next frame from the socket."""
if self._tail:
data, self._tail = self._tail + data, b""
start_pos: int = 0
data_len = len(data)
data_cstr = data
while True:
# read header
if self._state == READ_HEADER:
if data_len - start_pos < 2:
break
first_byte = data_cstr[start_pos]
second_byte = data_cstr[start_pos + 1]
start_pos += 2
fin = (first_byte >> 7) & 1
rsv1 = (first_byte >> 6) & 1
rsv2 = (first_byte >> 5) & 1
rsv3 = (first_byte >> 4) & 1
opcode = first_byte & 0xF
# frame-fin = %x0 ; more frames of this message follow
# / %x1 ; final frame of this message
# frame-rsv1 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv2 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv3 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
#
# Remove rsv1 from this test for deflate development
if rsv2 or rsv3 or (rsv1 and not self._compress):
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)
if opcode > 0x7 and fin == 0:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received fragmented control frame",
)
has_mask = (second_byte >> 7) & 1
length = second_byte & 0x7F
# Control frames MUST have a payload
# length of 125 bytes or less
if opcode > 0x7 and length > 125:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Control frame payload cannot be larger than 125 bytes",
)
# Set compress status if last package is FIN
# OR set compress status if this is first fragment
# Raise error if not first fragment with rsv1 = 0x1
if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
elif rsv1:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)
self._frame_fin = bool(fin)
self._frame_opcode = opcode
self._has_mask = bool(has_mask)
self._payload_len_flag = length
self._state = READ_PAYLOAD_LENGTH
# read payload length
if self._state == READ_PAYLOAD_LENGTH:
len_flag = self._payload_len_flag
if len_flag == 126:
if data_len - start_pos < 2:
break
first_byte = data_cstr[start_pos]
second_byte = data_cstr[start_pos + 1]
start_pos += 2
self._payload_bytes_to_read = first_byte << 8 | second_byte
elif len_flag > 126:
if data_len - start_pos < 8:
break
self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
start_pos += 8
else:
self._payload_bytes_to_read = len_flag
self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD
# read payload mask
if self._state == READ_PAYLOAD_MASK:
if data_len - start_pos < 4:
break
self._frame_mask = data_cstr[start_pos : start_pos + 4]
start_pos += 4
self._state = READ_PAYLOAD
if self._state == READ_PAYLOAD:
chunk_len = data_len - start_pos
if self._payload_bytes_to_read >= chunk_len:
f_end_pos = data_len
self._payload_bytes_to_read -= chunk_len
else:
f_end_pos = start_pos + self._payload_bytes_to_read
self._payload_bytes_to_read = 0
had_fragments = self._frame_payload_len
self._frame_payload_len += f_end_pos - start_pos
f_start_pos = start_pos
start_pos = f_end_pos
if self._payload_bytes_to_read != 0:
# If we don't have a complete frame, we need to save the
# data for the next call to feed_data.
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
break
payload: Union[bytes, bytearray]
if had_fragments:
# We have to join the payload fragments to get the payload
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
if self._has_mask:
assert self._frame_mask is not None
payload_bytearray = bytearray(b"".join(self._payload_fragments))
websocket_mask(self._frame_mask, payload_bytearray)
payload = payload_bytearray
else:
payload = b"".join(self._payload_fragments)
self._payload_fragments.clear()
elif self._has_mask:
assert self._frame_mask is not None
payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment]
if type(payload_bytearray) is not bytearray: # pragma: no branch
# Cython will do the conversion for us,
# but in pure Python we must do it ourselves,
# and we always reach this branch in Python
payload_bytearray = bytearray(payload_bytearray)
websocket_mask(self._frame_mask, payload_bytearray)
payload = payload_bytearray
else:
payload = data_cstr[f_start_pos:f_end_pos]
self._handle_frame(
self._frame_fin, self._frame_opcode, payload, self._compressed
)
self._frame_payload_len = 0
self._state = READ_HEADER
# XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""
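# Worked example (illustrative only, not part of the file): decoding the two
# fixed header bytes of a masked TEXT frame with a 5-byte payload, exactly as
# the READ_HEADER state above does. 0x81 is FIN set, RSV bits clear, opcode
# 0x1 (TEXT); 0x85 is the mask bit plus a 7-bit length flag of 5.
_first_byte, _second_byte = 0x81, 0x85
assert (_first_byte >> 7) & 1 == 1  # fin
assert (_first_byte >> 4) & 0x7 == 0  # rsv1..rsv3 clear
assert _first_byte & 0xF == 0x1  # opcode: TEXT
assert (_second_byte >> 7) & 1 == 1  # has_mask
assert _second_byte & 0x7F == 5  # payload length flag (< 126)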

View File

@@ -0,0 +1,468 @@
"""Reader for WebSocket protocol versions 13 and 8."""
import asyncio
import builtins
from collections import deque
from typing import Deque, Final, Optional, Set, Tuple, Union
from ..base_protocol import BaseProtocol
from ..compression_utils import ZLibDecompressor
from ..helpers import _EXC_SENTINEL, set_exception
from ..streams import EofStream
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
from .models import (
WS_DEFLATE_TRAILING,
WebSocketError,
WSCloseCode,
WSMessage,
WSMsgType,
)
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}
# States for the reader, used to parse the WebSocket frame
# integer values are used so they can be cythonized
READ_HEADER = 1
READ_PAYLOAD_LENGTH = 2
READ_PAYLOAD_MASK = 3
READ_PAYLOAD = 4
WS_MSG_TYPE_BINARY = WSMsgType.BINARY
WS_MSG_TYPE_TEXT = WSMsgType.TEXT
# WSMsgType values unpacked so they can be cythonized to ints
OP_CODE_NOT_SET = -1
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
OP_CODE_TEXT = WSMsgType.TEXT.value
OP_CODE_BINARY = WSMsgType.BINARY.value
OP_CODE_CLOSE = WSMsgType.CLOSE.value
OP_CODE_PING = WSMsgType.PING.value
OP_CODE_PONG = WSMsgType.PONG.value
EMPTY_FRAME_ERROR = (True, b"")
EMPTY_FRAME = (False, b"")
COMPRESSED_NOT_SET = -1
COMPRESSED_FALSE = 0
COMPRESSED_TRUE = 1
TUPLE_NEW = tuple.__new__
cython_int = int # Typed to int in Python, but Cython will use a signed int in the pxd
class WebSocketDataQueue:
"""WebSocketDataQueue resumes and pauses an underlying stream.
It is a destination for WebSocket data.
"""
def __init__(
self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
) -> None:
self._size = 0
self._protocol = protocol
self._limit = limit * 2
self._loop = loop
self._eof = False
self._waiter: Optional[asyncio.Future[None]] = None
self._exception: Union[BaseException, None] = None
self._buffer: Deque[Tuple[WSMessage, int]] = deque()
self._get_buffer = self._buffer.popleft
self._put_buffer = self._buffer.append
def is_eof(self) -> bool:
return self._eof
def exception(self) -> Optional[BaseException]:
return self._exception
def set_exception(
self,
exc: "BaseException",
exc_cause: builtins.BaseException = _EXC_SENTINEL,
) -> None:
self._eof = True
self._exception = exc
if (waiter := self._waiter) is not None:
self._waiter = None
set_exception(waiter, exc, exc_cause)
def _release_waiter(self) -> None:
if (waiter := self._waiter) is None:
return
self._waiter = None
if not waiter.done():
waiter.set_result(None)
def feed_eof(self) -> None:
self._eof = True
self._release_waiter()
self._exception = None # Break cyclic references
def feed_data(self, data: "WSMessage", size: "cython_int") -> None:
self._size += size
self._put_buffer((data, size))
self._release_waiter()
if self._size > self._limit and not self._protocol._reading_paused:
self._protocol.pause_reading()
async def read(self) -> WSMessage:
if not self._buffer and not self._eof:
assert not self._waiter
self._waiter = self._loop.create_future()
try:
await self._waiter
except (asyncio.CancelledError, asyncio.TimeoutError):
self._waiter = None
raise
return self._read_from_buffer()
def _read_from_buffer(self) -> WSMessage:
if self._buffer:
data, size = self._get_buffer()
self._size -= size
if self._size < self._limit and self._protocol._reading_paused:
self._protocol.resume_reading()
return data
if self._exception is not None:
raise self._exception
raise EofStream
class WebSocketReader:
def __init__(
self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
) -> None:
self.queue = queue
self._max_msg_size = max_msg_size
self._exc: Optional[Exception] = None
self._partial = bytearray()
self._state = READ_HEADER
self._opcode: int = OP_CODE_NOT_SET
self._frame_fin = False
self._frame_opcode: int = OP_CODE_NOT_SET
self._payload_fragments: list[bytes] = []
self._frame_payload_len = 0
self._tail: bytes = b""
self._has_mask = False
self._frame_mask: Optional[bytes] = None
self._payload_bytes_to_read = 0
self._payload_len_flag = 0
self._compressed: int = COMPRESSED_NOT_SET
self._decompressobj: Optional[ZLibDecompressor] = None
self._compress = compress
def feed_eof(self) -> None:
self.queue.feed_eof()
# data can be bytearray on Windows because proactor event loop uses bytearray
# and asyncio types this to Union[bytes, bytearray, memoryview] so we need
# to coerce data to bytes if it is not
def feed_data(
self, data: Union[bytes, bytearray, memoryview]
) -> Tuple[bool, bytes]:
if type(data) is not bytes:
data = bytes(data)
if self._exc is not None:
return True, data
try:
self._feed_data(data)
except Exception as exc:
self._exc = exc
set_exception(self.queue, exc)
return EMPTY_FRAME_ERROR
return EMPTY_FRAME
def _handle_frame(
self,
fin: bool,
opcode: Union[int, cython_int], # Union intended: Cython pxd uses C int
payload: Union[bytes, bytearray],
compressed: Union[int, cython_int], # Union intended: Cython pxd uses C int
) -> None:
msg: WSMessage
if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
# load text/binary
if not fin:
# got partial frame payload
if opcode != OP_CODE_CONTINUATION:
self._opcode = opcode
self._partial += payload
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
f"Message size {len(self._partial)} "
f"exceeds limit {self._max_msg_size}",
)
return
has_partial = bool(self._partial)
if opcode == OP_CODE_CONTINUATION:
if self._opcode == OP_CODE_NOT_SET:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Continuation frame for non started message",
)
opcode = self._opcode
self._opcode = OP_CODE_NOT_SET
# previous frame was not finished,
# so we should get a continuation opcode
elif has_partial:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"The opcode in non-fin frame is expected "
f"to be zero, got {opcode!r}",
)
assembled_payload: Union[bytes, bytearray]
if has_partial:
assembled_payload = self._partial + payload
self._partial.clear()
else:
assembled_payload = payload
if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
f"Message size {len(assembled_payload)} "
f"exceeds limit {self._max_msg_size}",
)
# Decompression must be done after all
# frames have been received.
if compressed:
if not self._decompressobj:
self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
payload_merged = self._decompressobj.decompress_sync(
assembled_payload + WS_DEFLATE_TRAILING, self._max_msg_size
)
if self._decompressobj.unconsumed_tail:
left = len(self._decompressobj.unconsumed_tail)
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
f"Decompressed message size {self._max_msg_size + left}"
f" exceeds limit {self._max_msg_size}",
)
elif type(assembled_payload) is bytes:
payload_merged = assembled_payload
else:
payload_merged = bytes(assembled_payload)
if opcode == OP_CODE_TEXT:
try:
text = payload_merged.decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
# XXX: The Text and Binary messages here can be a performance
# bottleneck, so we use tuple.__new__ to improve performance.
# This is not type safe, but many tests should fail in
# test_client_ws_functional.py if this is wrong.
self.queue.feed_data(
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
len(payload_merged),
)
else:
self.queue.feed_data(
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
len(payload_merged),
)
elif opcode == OP_CODE_CLOSE:
if len(payload) >= 2:
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
f"Invalid close code: {close_code}",
)
try:
close_message = payload[2:].decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message))
elif payload:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
f"Invalid close frame: {fin} {opcode} {payload!r}",
)
else:
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))
self.queue.feed_data(msg, 0)
elif opcode == OP_CODE_PING:
msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
self.queue.feed_data(msg, len(payload))
elif opcode == OP_CODE_PONG:
msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
self.queue.feed_data(msg, len(payload))
else:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
)
def _feed_data(self, data: bytes) -> None:
"""Return the next frame from the socket."""
if self._tail:
data, self._tail = self._tail + data, b""
start_pos: int = 0
data_len = len(data)
data_cstr = data
while True:
# read header
if self._state == READ_HEADER:
if data_len - start_pos < 2:
break
first_byte = data_cstr[start_pos]
second_byte = data_cstr[start_pos + 1]
start_pos += 2
fin = (first_byte >> 7) & 1
rsv1 = (first_byte >> 6) & 1
rsv2 = (first_byte >> 5) & 1
rsv3 = (first_byte >> 4) & 1
opcode = first_byte & 0xF
# frame-fin = %x0 ; more frames of this message follow
# / %x1 ; final frame of this message
# frame-rsv1 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv2 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv3 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
#
# Remove rsv1 from this test for deflate development
if rsv2 or rsv3 or (rsv1 and not self._compress):
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)
if opcode > 0x7 and fin == 0:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received fragmented control frame",
)
has_mask = (second_byte >> 7) & 1
length = second_byte & 0x7F
# Control frames MUST have a payload
# length of 125 bytes or less
if opcode > 0x7 and length > 125:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Control frame payload cannot be larger than 125 bytes",
)
# Set compress status if last package is FIN
# OR set compress status if this is first fragment
# Raise error if not first fragment with rsv1 = 0x1
if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
elif rsv1:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)
self._frame_fin = bool(fin)
self._frame_opcode = opcode
self._has_mask = bool(has_mask)
self._payload_len_flag = length
self._state = READ_PAYLOAD_LENGTH
# read payload length
if self._state == READ_PAYLOAD_LENGTH:
len_flag = self._payload_len_flag
if len_flag == 126:
if data_len - start_pos < 2:
break
first_byte = data_cstr[start_pos]
second_byte = data_cstr[start_pos + 1]
start_pos += 2
self._payload_bytes_to_read = first_byte << 8 | second_byte
elif len_flag > 126:
if data_len - start_pos < 8:
break
self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
start_pos += 8
else:
self._payload_bytes_to_read = len_flag
self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD
# read payload mask
if self._state == READ_PAYLOAD_MASK:
if data_len - start_pos < 4:
break
self._frame_mask = data_cstr[start_pos : start_pos + 4]
start_pos += 4
self._state = READ_PAYLOAD
if self._state == READ_PAYLOAD:
chunk_len = data_len - start_pos
if self._payload_bytes_to_read >= chunk_len:
f_end_pos = data_len
self._payload_bytes_to_read -= chunk_len
else:
f_end_pos = start_pos + self._payload_bytes_to_read
self._payload_bytes_to_read = 0
had_fragments = self._frame_payload_len
self._frame_payload_len += f_end_pos - start_pos
f_start_pos = start_pos
start_pos = f_end_pos
if self._payload_bytes_to_read != 0:
# If we don't have a complete frame, we need to save the
# data for the next call to feed_data.
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
break
payload: Union[bytes, bytearray]
if had_fragments:
# We have to join the payload fragments to get the payload
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
if self._has_mask:
assert self._frame_mask is not None
payload_bytearray = bytearray(b"".join(self._payload_fragments))
websocket_mask(self._frame_mask, payload_bytearray)
payload = payload_bytearray
else:
payload = b"".join(self._payload_fragments)
self._payload_fragments.clear()
elif self._has_mask:
assert self._frame_mask is not None
payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment]
if type(payload_bytearray) is not bytearray: # pragma: no branch
# Cython will do the conversion for us,
# but in pure Python we must do it ourselves,
# and we always reach this branch in Python
payload_bytearray = bytearray(payload_bytearray)
websocket_mask(self._frame_mask, payload_bytearray)
payload = payload_bytearray
else:
payload = data_cstr[f_start_pos:f_end_pos]
self._handle_frame(
self._frame_fin, self._frame_opcode, payload, self._compressed
)
self._frame_payload_len = 0
self._state = READ_HEADER
# XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""
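# Minimal pure-Python sketch (illustrative only) of what websocket_mask()
# does per RFC 6455 section 5.3: every payload byte is XORed with the mask
# byte at index i % 4. The real helper is Cython/C accelerated.
def _xor_mask_sketch(mask: bytes, payload: bytearray) -> None:
    for i in range(len(payload)):
        payload[i] ^= mask[i % 4]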

View File

@@ -0,0 +1,177 @@
"""WebSocket protocol versions 13 and 8."""
import asyncio
import random
import zlib
from functools import partial
from typing import Any, Final, Optional, Union
from ..base_protocol import BaseProtocol
from ..client_exceptions import ClientConnectionResetError
from ..compression_utils import ZLibCompressor
from .helpers import (
MASK_LEN,
MSG_SIZE,
PACK_CLOSE_CODE,
PACK_LEN1,
PACK_LEN2,
PACK_LEN3,
PACK_RANDBITS,
websocket_mask,
)
from .models import WS_DEFLATE_TRAILING, WSMsgType
DEFAULT_LIMIT: Final[int] = 2**16
# For websockets, keeping latency low is extremely important as implementations
# generally expect to be able to send and receive messages quickly. We use a
# larger chunk size than the default to reduce the number of executor calls
# since the executor is a significant source of latency and overhead when
# the chunks are small. A size of 5KiB was chosen because it is also the
# same value python-zlib-ng chose as the threshold to release the GIL.
WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 5 * 1024
class WebSocketWriter:
"""WebSocket writer.
The writer is responsible for sending messages to the client. It is
created by the protocol when a connection is established. The writer
should avoid implementing any application logic and should only be
concerned with the low-level details of the WebSocket protocol.
"""
def __init__(
self,
protocol: BaseProtocol,
transport: asyncio.Transport,
*,
use_mask: bool = False,
limit: int = DEFAULT_LIMIT,
random: random.Random = random.Random(),
compress: int = 0,
notakeover: bool = False,
) -> None:
"""Initialize a WebSocket writer."""
self.protocol = protocol
self.transport = transport
self.use_mask = use_mask
self.get_random_bits = partial(random.getrandbits, 32)
self.compress = compress
self.notakeover = notakeover
self._closing = False
self._limit = limit
self._output_size = 0
self._compressobj: Any = None # actually compressobj
async def send_frame(
self, message: bytes, opcode: int, compress: Optional[int] = None
) -> None:
"""Send a frame over the websocket with message as its payload."""
if self._closing and not (opcode & WSMsgType.CLOSE):
raise ClientConnectionResetError("Cannot write to closing transport")
# RSV are the reserved bits in the frame header. They are used to
# indicate that the frame is using an extension.
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
rsv = 0
# Only compress larger packets (disabled)
# Do small packets need to be compressed?
# if self.compress and opcode < 8 and len(message) > 124:
if (compress or self.compress) and opcode < 8:
# RSV1 (rsv = 0x40) is set for compressed frames
# https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
rsv = 0x40
if compress:
# Do not set self._compress if compressing is for this frame
compressobj = self._make_compress_obj(compress)
else: # self.compress
if not self._compressobj:
self._compressobj = self._make_compress_obj(self.compress)
compressobj = self._compressobj
message = (
await compressobj.compress(message)
+ compressobj.flush(
zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
)
).removesuffix(WS_DEFLATE_TRAILING)
# It's critical that we do not return control to the event
# loop until we have finished sending all the compressed
# data. Otherwise we could end up mixing compressed frames
# if there are multiple coroutines compressing data.
msg_length = len(message)
use_mask = self.use_mask
mask_bit = 0x80 if use_mask else 0
# Depending on the message length, the header is assembled differently.
# The first byte is reserved for the opcode and the RSV bits.
first_byte = 0x80 | rsv | opcode
if msg_length < 126:
header = PACK_LEN1(first_byte, msg_length | mask_bit)
header_len = 2
elif msg_length < 65536:
header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length)
header_len = 4
else:
header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length)
header_len = 10
if self.transport.is_closing():
raise ClientConnectionResetError("Cannot write to closing transport")
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.3
# If we are using a mask, we need to generate it randomly
# and apply it to the message before sending it. A mask is
# a 32-bit value that is applied to the message using a
# bitwise XOR operation. It is used to prevent certain types
# of attacks on the websocket protocol. The mask is only used
# when aiohttp is acting as a client. Servers do not use a mask.
if use_mask:
mask = PACK_RANDBITS(self.get_random_bits())
message = bytearray(message)
websocket_mask(mask, message)
self.transport.write(header + mask + message)
self._output_size += MASK_LEN
elif msg_length > MSG_SIZE:
self.transport.write(header)
self.transport.write(message)
else:
self.transport.write(header + message)
self._output_size += header_len + msg_length
# It is safe to return control to the event loop when using compression
# after this point as we have already sent or buffered all the data.
# Once we have written output_size up to the limit, we call the
# drain helper which waits for the transport to be ready to accept
# more data. This is a flow control mechanism to prevent the buffer
# from growing too large. The drain helper will return right away
# if the writer is not paused.
if self._output_size > self._limit:
self._output_size = 0
if self.protocol._paused:
await self.protocol._drain_helper()
def _make_compress_obj(self, compress: int) -> ZLibCompressor:
return ZLibCompressor(
level=zlib.Z_BEST_SPEED,
wbits=-compress,
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
)
async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
"""Close the websocket, sending the specified code and message."""
if isinstance(message, str):
message = message.encode("utf-8")
try:
await self.send_frame(
PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE
)
finally:
self._closing = True
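# Hedged sketch of the three header layouts send_frame() chooses between
# (RFC 6455 section 5.2); PACK_LEN1/2/3 are assumed to be struct.pack
# wrappers for these network-order formats.
import struct

def _pack_header_sketch(first_byte: int, msg_length: int, mask_bit: int = 0) -> bytes:
    if msg_length < 126:
        return struct.pack("!BB", first_byte, msg_length | mask_bit)  # 2-byte header
    if msg_length < 65536:
        return struct.pack("!BBH", first_byte, 126 | mask_bit, msg_length)  # 4-byte header
    return struct.pack("!BBQ", first_byte, 127 | mask_bit, msg_length)  # 10-byte header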

View File

@@ -0,0 +1,253 @@
import asyncio
import logging
import socket
import zlib
from abc import ABC, abstractmethod
from collections.abc import Sized
from http.cookies import BaseCookie, Morsel
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Tuple,
TypedDict,
Union,
)
from multidict import CIMultiDict
from yarl import URL
from .typedefs import LooseCookies
if TYPE_CHECKING:
from .web_app import Application
from .web_exceptions import HTTPException
from .web_request import BaseRequest, Request
from .web_response import StreamResponse
else:
BaseRequest = Request = Application = StreamResponse = None
HTTPException = None
class AbstractRouter(ABC):
def __init__(self) -> None:
self._frozen = False
def post_init(self, app: Application) -> None:
"""Post init stage.
Not an abstract method for the sake of backward compatibility,
but if the router wants to be aware of the application
it can override this.
"""
@property
def frozen(self) -> bool:
return self._frozen
def freeze(self) -> None:
"""Freeze router."""
self._frozen = True
@abstractmethod
async def resolve(self, request: Request) -> "AbstractMatchInfo":
"""Return MATCH_INFO for given request"""
class AbstractMatchInfo(ABC):
__slots__ = ()
@property # pragma: no branch
@abstractmethod
def handler(self) -> Callable[[Request], Awaitable[StreamResponse]]:
"""Execute matched request handler"""
@property
@abstractmethod
def expect_handler(
self,
) -> Callable[[Request], Awaitable[Optional[StreamResponse]]]:
"""Expect handler for 100-continue processing"""
@property # pragma: no branch
@abstractmethod
def http_exception(self) -> Optional[HTTPException]:
"""HTTPException instance raised on router's resolving, or None"""
@abstractmethod # pragma: no branch
def get_info(self) -> Dict[str, Any]:
"""Return a dict with additional info useful for introspection"""
@property # pragma: no branch
@abstractmethod
def apps(self) -> Tuple[Application, ...]:
"""Stack of nested applications.
Top level application is left-most element.
"""
@abstractmethod
def add_app(self, app: Application) -> None:
"""Add application to the nested apps stack."""
@abstractmethod
def freeze(self) -> None:
"""Freeze the match info.
The method is called after route resolution.
After the call .add_app() is forbidden.
"""
class AbstractView(ABC):
"""Abstract class based view."""
def __init__(self, request: Request) -> None:
self._request = request
@property
def request(self) -> Request:
"""Request instance."""
return self._request
@abstractmethod
def __await__(self) -> Generator[Any, None, StreamResponse]:
"""Execute the view handler."""
class ResolveResult(TypedDict):
"""Resolve result.
This is the result returned from an AbstractResolver's
resolve method.
:param hostname: The hostname that was provided.
:param host: The IP address that was resolved.
:param port: The port that was resolved.
:param family: The address family that was resolved.
:param proto: The protocol that was resolved.
:param flags: The flags that were resolved.
"""
hostname: str
host: str
port: int
family: int
proto: int
flags: int
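# Illustrative only: the shape of one entry a resolver is expected to return
# (the host/port/family/proto values here are hypothetical).
_EXAMPLE_RESOLVE_RESULT: ResolveResult = {
    "hostname": "example.com",
    "host": "93.184.216.34",
    "port": 443,
    "family": socket.AF_INET,
    "proto": socket.IPPROTO_TCP,
    "flags": 0,
}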
class AbstractResolver(ABC):
"""Abstract DNS resolver."""
@abstractmethod
async def resolve(
self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
) -> List[ResolveResult]:
"""Return IP address for given hostname"""
@abstractmethod
async def close(self) -> None:
"""Release resolver"""
if TYPE_CHECKING:
IterableBase = Iterable[Morsel[str]]
else:
IterableBase = Iterable
ClearCookiePredicate = Callable[["Morsel[str]"], bool]
class AbstractCookieJar(Sized, IterableBase):
"""Abstract Cookie Jar."""
def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self._loop = loop or asyncio.get_running_loop()
@property
@abstractmethod
def quote_cookie(self) -> bool:
"""Return True if cookies should be quoted."""
@abstractmethod
def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
"""Clear all cookies if no predicate is passed."""
@abstractmethod
def clear_domain(self, domain: str) -> None:
"""Clear all cookies for domain and all subdomains."""
@abstractmethod
def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
"""Update cookies."""
@abstractmethod
def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
"""Return the jar's cookies filtered by their attributes."""
class AbstractStreamWriter(ABC):
"""Abstract stream writer."""
buffer_size: int = 0
output_size: int = 0
length: Optional[int] = 0
@abstractmethod
async def write(self, chunk: Union[bytes, bytearray, memoryview]) -> None:
"""Write chunk into stream."""
@abstractmethod
async def write_eof(self, chunk: bytes = b"") -> None:
"""Write last chunk."""
@abstractmethod
async def drain(self) -> None:
"""Flush the write buffer."""
@abstractmethod
def enable_compression(
self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
) -> None:
"""Enable HTTP body compression"""
@abstractmethod
def enable_chunking(self) -> None:
"""Enable HTTP chunked mode"""
@abstractmethod
async def write_headers(
self, status_line: str, headers: "CIMultiDict[str]"
) -> None:
"""Write HTTP headers"""
class AbstractAccessLogger(ABC):
"""Abstract writer to access log."""
__slots__ = ("logger", "log_format")
def __init__(self, logger: logging.Logger, log_format: str) -> None:
self.logger = logger
self.log_format = log_format
@abstractmethod
def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None:
"""Emit log to logger."""
@property
def enabled(self) -> bool:
"""Check if logger is enabled."""
return True
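# Hedged usage sketch: a concrete access logger. aiohttp accepts a subclass
# like this via the access_log_class argument of web.run_app()/AppRunner.
class _PlainAccessLogger(AbstractAccessLogger):
    """Log 'METHOD path -> status in N.NNNs' lines (illustrative only)."""
    def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None:
        self.logger.info(
            "%s %s -> %s in %.3fs",
            request.method, request.path, response.status, time,
        )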

View File

@@ -0,0 +1,100 @@
import asyncio
from typing import Optional, cast
from .client_exceptions import ClientConnectionResetError
from .helpers import set_exception
from .tcp_helpers import tcp_nodelay
class BaseProtocol(asyncio.Protocol):
__slots__ = (
"_loop",
"_paused",
"_drain_waiter",
"_connection_lost",
"_reading_paused",
"transport",
)
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop: asyncio.AbstractEventLoop = loop
self._paused = False
self._drain_waiter: Optional[asyncio.Future[None]] = None
self._reading_paused = False
self.transport: Optional[asyncio.Transport] = None
@property
def connected(self) -> bool:
"""Return True if the connection is open."""
return self.transport is not None
@property
def writing_paused(self) -> bool:
return self._paused
def pause_writing(self) -> None:
assert not self._paused
self._paused = True
def resume_writing(self) -> None:
assert self._paused
self._paused = False
waiter = self._drain_waiter
if waiter is not None:
self._drain_waiter = None
if not waiter.done():
waiter.set_result(None)
def pause_reading(self) -> None:
if not self._reading_paused and self.transport is not None:
try:
self.transport.pause_reading()
except (AttributeError, NotImplementedError, RuntimeError):
pass
self._reading_paused = True
def resume_reading(self) -> None:
if self._reading_paused and self.transport is not None:
try:
self.transport.resume_reading()
except (AttributeError, NotImplementedError, RuntimeError):
pass
self._reading_paused = False
def connection_made(self, transport: asyncio.BaseTransport) -> None:
tr = cast(asyncio.Transport, transport)
tcp_nodelay(tr, True)
self.transport = tr
def connection_lost(self, exc: Optional[BaseException]) -> None:
# Wake up the writer if currently paused.
self.transport = None
if not self._paused:
return
waiter = self._drain_waiter
if waiter is None:
return
self._drain_waiter = None
if waiter.done():
return
if exc is None:
waiter.set_result(None)
else:
set_exception(
waiter,
ConnectionError("Connection lost"),
exc,
)
async def _drain_helper(self) -> None:
if self.transport is None:
raise ClientConnectionResetError("Connection lost")
if not self._paused:
return
waiter = self._drain_waiter
if waiter is None:
waiter = self._loop.create_future()
self._drain_waiter = waiter
await asyncio.shield(waiter)
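# Minimal sketch (illustrative, test-style, not aiohttp API) of the
# pause/drain handshake above: pause_writing() makes _drain_helper() wait on
# a future that a later resume_writing() completes.
async def _demo_drain() -> None:
    loop = asyncio.get_running_loop()
    proto = BaseProtocol(loop)
    proto.transport = object()  # type: ignore[assignment]  # stand-in transport
    proto.pause_writing()
    loop.call_later(0.01, proto.resume_writing)
    await proto._drain_helper()  # returns once resume_writing() has run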

File diff suppressed because it is too large

View File

@@ -0,0 +1,421 @@
"""HTTP related errors."""
import asyncio
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
from multidict import MultiMapping
from .typedefs import StrOrURL
if TYPE_CHECKING:
import ssl
SSLContext = ssl.SSLContext
else:
try:
import ssl
SSLContext = ssl.SSLContext
except ImportError: # pragma: no cover
ssl = SSLContext = None # type: ignore[assignment]
if TYPE_CHECKING:
from .client_reqrep import ClientResponse, ConnectionKey, Fingerprint, RequestInfo
from .http_parser import RawResponseMessage
else:
RequestInfo = ClientResponse = ConnectionKey = RawResponseMessage = None
__all__ = (
"ClientError",
"ClientConnectionError",
"ClientConnectionResetError",
"ClientOSError",
"ClientConnectorError",
"ClientProxyConnectionError",
"ClientSSLError",
"ClientConnectorDNSError",
"ClientConnectorSSLError",
"ClientConnectorCertificateError",
"ConnectionTimeoutError",
"SocketTimeoutError",
"ServerConnectionError",
"ServerTimeoutError",
"ServerDisconnectedError",
"ServerFingerprintMismatch",
"ClientResponseError",
"ClientHttpProxyError",
"WSServerHandshakeError",
"ContentTypeError",
"ClientPayloadError",
"InvalidURL",
"InvalidUrlClientError",
"RedirectClientError",
"NonHttpUrlClientError",
"InvalidUrlRedirectClientError",
"NonHttpUrlRedirectClientError",
"WSMessageTypeError",
)
class ClientError(Exception):
"""Base class for client connection errors."""
class ClientResponseError(ClientError):
"""Base class for exceptions that occur after getting a response.
request_info: An instance of RequestInfo.
history: A sequence of responses, if redirects occurred.
status: HTTP status code.
message: Error message.
headers: Response headers.
"""
def __init__(
self,
request_info: RequestInfo,
history: Tuple[ClientResponse, ...],
*,
code: Optional[int] = None,
status: Optional[int] = None,
message: str = "",
headers: Optional[MultiMapping[str]] = None,
) -> None:
self.request_info = request_info
if code is not None:
if status is not None:
raise ValueError(
"Both code and status arguments are provided; "
"code is deprecated, use status instead"
)
warnings.warn(
"code argument is deprecated, use status instead",
DeprecationWarning,
stacklevel=2,
)
if status is not None:
self.status = status
elif code is not None:
self.status = code
else:
self.status = 0
self.message = message
self.headers = headers
self.history = history
self.args = (request_info, history)
def __str__(self) -> str:
return "{}, message={!r}, url={!r}".format(
self.status,
self.message,
str(self.request_info.real_url),
)
def __repr__(self) -> str:
args = f"{self.request_info!r}, {self.history!r}"
if self.status != 0:
args += f", status={self.status!r}"
if self.message != "":
args += f", message={self.message!r}"
if self.headers is not None:
args += f", headers={self.headers!r}"
return f"{type(self).__name__}({args})"
@property
def code(self) -> int:
warnings.warn(
"code property is deprecated, use status instead",
DeprecationWarning,
stacklevel=2,
)
return self.status
@code.setter
def code(self, value: int) -> None:
warnings.warn(
"code property is deprecated, use status instead",
DeprecationWarning,
stacklevel=2,
)
self.status = value
class ContentTypeError(ClientResponseError):
"""ContentType found is not valid."""
class WSServerHandshakeError(ClientResponseError):
"""websocket server handshake error."""
class ClientHttpProxyError(ClientResponseError):
"""HTTP proxy error.
Raised in :class:`aiohttp.connector.TCPConnector` if
proxy responds with status other than ``200 OK``
on ``CONNECT`` request.
"""
class TooManyRedirects(ClientResponseError):
"""Client was redirected too many times."""
class ClientConnectionError(ClientError):
"""Base class for client socket errors."""
class ClientConnectionResetError(ClientConnectionError, ConnectionResetError):
"""ConnectionResetError"""
class ClientOSError(ClientConnectionError, OSError):
"""OSError error."""
class ClientConnectorError(ClientOSError):
"""Client connector error.
Raised in :class:`aiohttp.connector.TCPConnector` if
a connection can not be established.
"""
def __init__(self, connection_key: ConnectionKey, os_error: OSError) -> None:
self._conn_key = connection_key
self._os_error = os_error
super().__init__(os_error.errno, os_error.strerror)
self.args = (connection_key, os_error)
@property
def os_error(self) -> OSError:
return self._os_error
@property
def host(self) -> str:
return self._conn_key.host
@property
def port(self) -> Optional[int]:
return self._conn_key.port
@property
def ssl(self) -> Union[SSLContext, bool, "Fingerprint"]:
return self._conn_key.ssl
def __str__(self) -> str:
return "Cannot connect to host {0.host}:{0.port} ssl:{1} [{2}]".format(
self, "default" if self.ssl is True else self.ssl, self.strerror
)
# OSError.__reduce__ does too much black magic
__reduce__ = BaseException.__reduce__
class ClientConnectorDNSError(ClientConnectorError):
"""DNS resolution failed during client connection.
Raised in :class:`aiohttp.connector.TCPConnector` if
DNS resolution fails.
"""
class ClientProxyConnectionError(ClientConnectorError):
"""Proxy connection error.
Raised in :class:`aiohttp.connector.TCPConnector` if
connection to proxy can not be established.
"""
class UnixClientConnectorError(ClientConnectorError):
"""Unix connector error.
Raised in :py:class:`aiohttp.connector.UnixConnector`
if connection to unix socket can not be established.
"""
def __init__(
self, path: str, connection_key: ConnectionKey, os_error: OSError
) -> None:
self._path = path
super().__init__(connection_key, os_error)
@property
def path(self) -> str:
return self._path
def __str__(self) -> str:
return "Cannot connect to unix socket {0.path} ssl:{1} [{2}]".format(
self, "default" if self.ssl is True else self.ssl, self.strerror
)
class ServerConnectionError(ClientConnectionError):
"""Server connection errors."""
class ServerDisconnectedError(ServerConnectionError):
"""Server disconnected."""
def __init__(self, message: Union[RawResponseMessage, str, None] = None) -> None:
if message is None:
message = "Server disconnected"
self.args = (message,)
self.message = message
class ServerTimeoutError(ServerConnectionError, asyncio.TimeoutError):
"""Server timeout error."""
class ConnectionTimeoutError(ServerTimeoutError):
"""Connection timeout error."""
class SocketTimeoutError(ServerTimeoutError):
"""Socket timeout error."""
class ServerFingerprintMismatch(ServerConnectionError):
"""SSL certificate does not match expected fingerprint."""
def __init__(self, expected: bytes, got: bytes, host: str, port: int) -> None:
self.expected = expected
self.got = got
self.host = host
self.port = port
self.args = (expected, got, host, port)
def __repr__(self) -> str:
return "<{} expected={!r} got={!r} host={!r} port={!r}>".format(
self.__class__.__name__, self.expected, self.got, self.host, self.port
)
class ClientPayloadError(ClientError):
"""Response payload error."""
class InvalidURL(ClientError, ValueError):
"""Invalid URL.
URL used for fetching is malformed, e.g. it doesn't contain a host
part.
"""
# Derive from ValueError for backward compatibility
def __init__(self, url: StrOrURL, description: Union[str, None] = None) -> None:
# The type of url is not yarl.URL because the exception can be raised
# on URL(url) call
self._url = url
self._description = description
if description:
super().__init__(url, description)
else:
super().__init__(url)
@property
def url(self) -> StrOrURL:
return self._url
@property
def description(self) -> "str | None":
return self._description
def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self}>"
def __str__(self) -> str:
if self._description:
return f"{self._url} - {self._description}"
return str(self._url)
class InvalidUrlClientError(InvalidURL):
"""Invalid URL client error."""
class RedirectClientError(ClientError):
"""Client redirect error."""
class NonHttpUrlClientError(ClientError):
"""Non http URL client error."""
class InvalidUrlRedirectClientError(InvalidUrlClientError, RedirectClientError):
"""Invalid URL redirect client error."""
class NonHttpUrlRedirectClientError(NonHttpUrlClientError, RedirectClientError):
"""Non http URL redirect client error."""
class ClientSSLError(ClientConnectorError):
"""Base error for ssl.*Errors."""
if ssl is not None:
cert_errors = (ssl.CertificateError,)
cert_errors_bases = (
ClientSSLError,
ssl.CertificateError,
)
ssl_errors = (ssl.SSLError,)
ssl_error_bases = (ClientSSLError, ssl.SSLError)
else: # pragma: no cover
cert_errors = tuple()
cert_errors_bases = (
ClientSSLError,
ValueError,
)
ssl_errors = tuple()
ssl_error_bases = (ClientSSLError,)
class ClientConnectorSSLError(*ssl_error_bases): # type: ignore[misc]
"""Response ssl error."""
class ClientConnectorCertificateError(*cert_errors_bases): # type: ignore[misc]
"""Response certificate error."""
def __init__(
self, connection_key: ConnectionKey, certificate_error: Exception
) -> None:
self._conn_key = connection_key
self._certificate_error = certificate_error
self.args = (connection_key, certificate_error)
@property
def certificate_error(self) -> Exception:
return self._certificate_error
@property
def host(self) -> str:
return self._conn_key.host
@property
def port(self) -> Optional[int]:
return self._conn_key.port
@property
def ssl(self) -> bool:
return self._conn_key.is_ssl
def __str__(self) -> str:
return (
"Cannot connect to host {0.host}:{0.port} ssl:{0.ssl} "
"[{0.certificate_error.__class__.__name__}: "
"{0.certificate_error.args}]".format(self)
)
class WSMessageTypeError(TypeError):
"""WebSocket message type is not valid."""

View File

@@ -0,0 +1,308 @@
import asyncio
from contextlib import suppress
from typing import Any, Optional, Tuple
from .base_protocol import BaseProtocol
from .client_exceptions import (
ClientOSError,
ClientPayloadError,
ServerDisconnectedError,
SocketTimeoutError,
)
from .helpers import (
_EXC_SENTINEL,
EMPTY_BODY_STATUS_CODES,
BaseTimerContext,
set_exception,
)
from .http import HttpResponseParser, RawResponseMessage
from .http_exceptions import HttpProcessingError
from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader
class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamReader]]):
"""Helper class to adapt between Protocol and StreamReader."""
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
BaseProtocol.__init__(self, loop=loop)
DataQueue.__init__(self, loop)
self._should_close = False
self._payload: Optional[StreamReader] = None
self._skip_payload = False
self._payload_parser = None
self._timer = None
self._tail = b""
self._upgraded = False
self._parser: Optional[HttpResponseParser] = None
self._read_timeout: Optional[float] = None
self._read_timeout_handle: Optional[asyncio.TimerHandle] = None
self._timeout_ceil_threshold: Optional[float] = 5
@property
def upgraded(self) -> bool:
return self._upgraded
@property
def should_close(self) -> bool:
return bool(
self._should_close
or (self._payload is not None and not self._payload.is_eof())
or self._upgraded
or self._exception is not None
or self._payload_parser is not None
or self._buffer
or self._tail
)
def force_close(self) -> None:
self._should_close = True
def close(self) -> None:
self._exception = None # Break cyclic references
transport = self.transport
if transport is not None:
transport.close()
self.transport = None
self._payload = None
self._drop_timeout()
def is_connected(self) -> bool:
return self.transport is not None and not self.transport.is_closing()
def connection_lost(self, exc: Optional[BaseException]) -> None:
self._drop_timeout()
original_connection_error = exc
reraised_exc = original_connection_error
connection_closed_cleanly = original_connection_error is None
if self._payload_parser is not None:
with suppress(Exception): # FIXME: log this somehow?
self._payload_parser.feed_eof()
uncompleted = None
if self._parser is not None:
try:
uncompleted = self._parser.feed_eof()
except Exception as underlying_exc:
if self._payload is not None:
client_payload_exc_msg = (
f"Response payload is not completed: {underlying_exc !r}"
)
if not connection_closed_cleanly:
client_payload_exc_msg = (
f"{client_payload_exc_msg !s}. "
f"{original_connection_error !r}"
)
set_exception(
self._payload,
ClientPayloadError(client_payload_exc_msg),
underlying_exc,
)
if not self.is_eof():
if isinstance(original_connection_error, OSError):
reraised_exc = ClientOSError(*original_connection_error.args)
if connection_closed_cleanly:
reraised_exc = ServerDisconnectedError(uncompleted)
# assigns self._should_close to True as a side effect,
# but we do it explicitly below anyway
underlying_non_eof_exc = (
_EXC_SENTINEL
if connection_closed_cleanly
else original_connection_error
)
assert underlying_non_eof_exc is not None
assert reraised_exc is not None
self.set_exception(reraised_exc, underlying_non_eof_exc)
self._should_close = True
self._parser = None
self._payload = None
self._payload_parser = None
self._reading_paused = False
super().connection_lost(reraised_exc)
def eof_received(self) -> None:
# should call parser.feed_eof() most likely
self._drop_timeout()
def pause_reading(self) -> None:
super().pause_reading()
self._drop_timeout()
def resume_reading(self) -> None:
super().resume_reading()
self._reschedule_timeout()
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
self._should_close = True
self._drop_timeout()
super().set_exception(exc, exc_cause)
def set_parser(self, parser: Any, payload: Any) -> None:
# TODO: actual types are:
# parser: WebSocketReader
# payload: WebSocketDataQueue
# but they are not generic enough
# Need an ABC for both types
self._payload = payload
self._payload_parser = parser
self._drop_timeout()
if self._tail:
data, self._tail = self._tail, b""
self.data_received(data)
def set_response_params(
self,
*,
timer: Optional[BaseTimerContext] = None,
skip_payload: bool = False,
read_until_eof: bool = False,
auto_decompress: bool = True,
read_timeout: Optional[float] = None,
read_bufsize: int = 2**16,
timeout_ceil_threshold: float = 5,
max_line_size: int = 8190,
max_field_size: int = 8190,
) -> None:
self._skip_payload = skip_payload
self._read_timeout = read_timeout
self._timeout_ceil_threshold = timeout_ceil_threshold
self._parser = HttpResponseParser(
self,
self._loop,
read_bufsize,
timer=timer,
payload_exception=ClientPayloadError,
response_with_body=not skip_payload,
read_until_eof=read_until_eof,
auto_decompress=auto_decompress,
max_line_size=max_line_size,
max_field_size=max_field_size,
)
if self._tail:
data, self._tail = self._tail, b""
self.data_received(data)
def _drop_timeout(self) -> None:
if self._read_timeout_handle is not None:
self._read_timeout_handle.cancel()
self._read_timeout_handle = None
def _reschedule_timeout(self) -> None:
timeout = self._read_timeout
if self._read_timeout_handle is not None:
self._read_timeout_handle.cancel()
if timeout:
self._read_timeout_handle = self._loop.call_later(
timeout, self._on_read_timeout
)
else:
self._read_timeout_handle = None
def start_timeout(self) -> None:
self._reschedule_timeout()
@property
def read_timeout(self) -> Optional[float]:
return self._read_timeout
@read_timeout.setter
def read_timeout(self, read_timeout: Optional[float]) -> None:
self._read_timeout = read_timeout
def _on_read_timeout(self) -> None:
exc = SocketTimeoutError("Timeout on reading data from socket")
self.set_exception(exc)
if self._payload is not None:
set_exception(self._payload, exc)
def data_received(self, data: bytes) -> None:
self._reschedule_timeout()
if not data:
return
# custom payload parser - currently always WebSocketReader
if self._payload_parser is not None:
eof, tail = self._payload_parser.feed_data(data)
if eof:
self._payload = None
self._payload_parser = None
if tail:
self.data_received(tail)
return
if self._upgraded or self._parser is None:
# i.e. websocket connection, websocket parser is not set yet
self._tail += data
return
# parse http messages
try:
messages, upgraded, tail = self._parser.feed_data(data)
except BaseException as underlying_exc:
if self.transport is not None:
# connection.release() could be called BEFORE
# data_received(); the transport is already
# closed in this case
self.transport.close()
# should_close is True after the call
if isinstance(underlying_exc, HttpProcessingError):
exc = HttpProcessingError(
code=underlying_exc.code,
message=underlying_exc.message,
headers=underlying_exc.headers,
)
else:
exc = HttpProcessingError()
self.set_exception(exc, underlying_exc)
return
self._upgraded = upgraded
payload: Optional[StreamReader] = None
for message, payload in messages:
if message.should_close:
self._should_close = True
self._payload = payload
if self._skip_payload or message.code in EMPTY_BODY_STATUS_CODES:
self.feed_data((message, EMPTY_PAYLOAD), 0)
else:
self.feed_data((message, payload), 0)
if payload is not None:
# New message(s) processed: drop the read timeout
# handler either on end-of-stream or immediately
# for EMPTY_PAYLOAD
if payload is not EMPTY_PAYLOAD:
payload.on_eof(self._drop_timeout)
else:
self._drop_timeout()
if upgraded and tail:
self.data_received(tail)

File diff suppressed because it is too large

View File

@@ -0,0 +1,428 @@
"""WebSocket client for asyncio."""
import asyncio
import sys
from types import TracebackType
from typing import Any, Optional, Type, cast
import attr
from ._websocket.reader import WebSocketDataQueue
from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError
from .client_reqrep import ClientResponse
from .helpers import calculate_timeout_when, set_result
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
WebSocketError,
WSCloseCode,
WSMessage,
WSMsgType,
)
from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter
from .streams import EofStream
from .typedefs import (
DEFAULT_JSON_DECODER,
DEFAULT_JSON_ENCODER,
JSONDecoder,
JSONEncoder,
)
if sys.version_info >= (3, 11):
import asyncio as async_timeout
else:
import async_timeout
@attr.s(frozen=True, slots=True)
class ClientWSTimeout:
ws_receive = attr.ib(type=Optional[float], default=None)
ws_close = attr.ib(type=Optional[float], default=None)
DEFAULT_WS_CLIENT_TIMEOUT = ClientWSTimeout(ws_receive=None, ws_close=10.0)
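# Hedged usage sketch: ws_receive bounds each receive() call and ws_close
# bounds the closing handshake in close(); a ClientWSTimeout can be passed
# straight to ClientSession.ws_connect():
#
#     timeout = ClientWSTimeout(ws_receive=30.0, ws_close=5.0)
#     async with session.ws_connect(url, timeout=timeout) as ws:
#         msg = await ws.receive()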
class ClientWebSocketResponse:
def __init__(
self,
reader: WebSocketDataQueue,
writer: WebSocketWriter,
protocol: Optional[str],
response: ClientResponse,
timeout: ClientWSTimeout,
autoclose: bool,
autoping: bool,
loop: asyncio.AbstractEventLoop,
*,
heartbeat: Optional[float] = None,
compress: int = 0,
client_notakeover: bool = False,
) -> None:
self._response = response
self._conn = response.connection
self._writer = writer
self._reader = reader
self._protocol = protocol
self._closed = False
self._closing = False
self._close_code: Optional[int] = None
self._timeout = timeout
self._autoclose = autoclose
self._autoping = autoping
self._heartbeat = heartbeat
self._heartbeat_cb: Optional[asyncio.TimerHandle] = None
self._heartbeat_when: float = 0.0
if heartbeat is not None:
self._pong_heartbeat = heartbeat / 2.0
self._pong_response_cb: Optional[asyncio.TimerHandle] = None
self._loop = loop
self._waiting: bool = False
self._close_wait: Optional[asyncio.Future[None]] = None
self._exception: Optional[BaseException] = None
self._compress = compress
self._client_notakeover = client_notakeover
self._ping_task: Optional[asyncio.Task[None]] = None
self._reset_heartbeat()
def _cancel_heartbeat(self) -> None:
self._cancel_pong_response_cb()
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
if self._ping_task is not None:
self._ping_task.cancel()
self._ping_task = None
def _cancel_pong_response_cb(self) -> None:
if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
self._pong_response_cb = None
def _reset_heartbeat(self) -> None:
if self._heartbeat is None:
return
self._cancel_pong_response_cb()
loop = self._loop
assert loop is not None
conn = self._conn
timeout_ceil_threshold = (
conn._connector._timeout_ceil_threshold if conn is not None else 5
)
now = loop.time()
when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
self._heartbeat_when = when
if self._heartbeat_cb is None:
# We do not cancel the previous heartbeat_cb here because
# it generates a significant amount of TimerHandle churn
# which causes asyncio to rebuild the heap frequently.
# Instead _send_heartbeat() will reschedule the next
# heartbeat if it fires too early.
self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)
def _send_heartbeat(self) -> None:
self._heartbeat_cb = None
loop = self._loop
now = loop.time()
if now < self._heartbeat_when:
# Heartbeat fired too early, reschedule
self._heartbeat_cb = loop.call_at(
self._heartbeat_when, self._send_heartbeat
)
return
conn = self._conn
timeout_ceil_threshold = (
conn._connector._timeout_ceil_threshold if conn is not None else 5
)
when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
self._cancel_pong_response_cb()
self._pong_response_cb = loop.call_at(when, self._pong_not_received)
coro = self._writer.send_frame(b"", WSMsgType.PING)
if sys.version_info >= (3, 12):
# Optimization for Python 3.12: try to send the ping
# immediately to avoid having to schedule
# the task on the event loop.
ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
else:
ping_task = loop.create_task(coro)
if not ping_task.done():
self._ping_task = ping_task
ping_task.add_done_callback(self._ping_task_done)
else:
self._ping_task_done(ping_task)
def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
"""Callback for when the ping task completes."""
if not task.cancelled() and (exc := task.exception()):
self._handle_ping_pong_exception(exc)
self._ping_task = None
def _pong_not_received(self) -> None:
self._handle_ping_pong_exception(
ServerTimeoutError(f"No PONG received after {self._pong_heartbeat} seconds")
)
def _handle_ping_pong_exception(self, exc: BaseException) -> None:
"""Handle exceptions raised during ping/pong processing."""
if self._closed:
return
self._set_closed()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
self._response.close()
if self._waiting and not self._closing:
self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None), 0)
def _set_closed(self) -> None:
"""Set the connection to closed.
Cancel any heartbeat timers and set the closed flag.
"""
self._closed = True
self._cancel_heartbeat()
def _set_closing(self) -> None:
"""Set the connection to closing.
Cancel any heartbeat timers and set the closing flag.
"""
self._closing = True
self._cancel_heartbeat()
@property
def closed(self) -> bool:
return self._closed
@property
def close_code(self) -> Optional[int]:
return self._close_code
@property
def protocol(self) -> Optional[str]:
return self._protocol
@property
def compress(self) -> int:
return self._compress
@property
def client_notakeover(self) -> bool:
return self._client_notakeover
def get_extra_info(self, name: str, default: Any = None) -> Any:
"""extra info from connection transport"""
conn = self._response.connection
if conn is None:
return default
transport = conn.transport
if transport is None:
return default
return transport.get_extra_info(name, default)
def exception(self) -> Optional[BaseException]:
return self._exception
async def ping(self, message: bytes = b"") -> None:
await self._writer.send_frame(message, WSMsgType.PING)
async def pong(self, message: bytes = b"") -> None:
await self._writer.send_frame(message, WSMsgType.PONG)
async def send_frame(
self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None
) -> None:
"""Send a frame over the websocket."""
await self._writer.send_frame(message, opcode, compress)
async def send_str(self, data: str, compress: Optional[int] = None) -> None:
if not isinstance(data, str):
raise TypeError("data argument must be str (%r)" % type(data))
await self._writer.send_frame(
data.encode("utf-8"), WSMsgType.TEXT, compress=compress
)
async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None:
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError("data argument must be byte-ish (%r)" % type(data))
await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)
async def send_json(
self,
data: Any,
compress: Optional[int] = None,
*,
dumps: JSONEncoder = DEFAULT_JSON_ENCODER,
) -> None:
await self.send_str(dumps(data), compress=compress)
async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
# we need to break the `receive()` cycle first,
# as `close()` may be called from a different task
if self._waiting and not self._closing:
assert self._loop is not None
self._close_wait = self._loop.create_future()
self._set_closing()
self._reader.feed_data(WS_CLOSING_MESSAGE, 0)
await self._close_wait
if self._closed:
return False
self._set_closed()
try:
await self._writer.close(code, message)
except asyncio.CancelledError:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._response.close()
raise
except Exception as exc:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
self._response.close()
return True
if self._close_code:
self._response.close()
return True
while True:
try:
async with async_timeout.timeout(self._timeout.ws_close):
msg = await self._reader.read()
except asyncio.CancelledError:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._response.close()
raise
except Exception as exc:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
self._response.close()
return True
if msg.type is WSMsgType.CLOSE:
self._close_code = msg.data
self._response.close()
return True
async def receive(self, timeout: Optional[float] = None) -> WSMessage:
receive_timeout = timeout or self._timeout.ws_receive
while True:
if self._waiting:
raise RuntimeError("Concurrent call to receive() is not allowed")
if self._closed:
return WS_CLOSED_MESSAGE
elif self._closing:
await self.close()
return WS_CLOSED_MESSAGE
try:
self._waiting = True
try:
if receive_timeout:
                        # Entering the context manager and creating
                        # the Timeout() object can take almost 50% of the
# run time in this loop so we avoid it if
# there is no read timeout.
async with async_timeout.timeout(receive_timeout):
msg = await self._reader.read()
else:
msg = await self._reader.read()
self._reset_heartbeat()
finally:
self._waiting = False
if self._close_wait:
set_result(self._close_wait, None)
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
raise
except EofStream:
self._close_code = WSCloseCode.OK
await self.close()
return WSMessage(WSMsgType.CLOSED, None, None)
except ClientError:
# Likely ServerDisconnectedError when connection is lost
self._set_closed()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
return WS_CLOSED_MESSAGE
except WebSocketError as exc:
self._close_code = exc.code
await self.close(code=exc.code)
return WSMessage(WSMsgType.ERROR, exc, None)
except Exception as exc:
self._exception = exc
self._set_closing()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
await self.close()
return WSMessage(WSMsgType.ERROR, exc, None)
if msg.type not in _INTERNAL_RECEIVE_TYPES:
                # If it's not a close/closing/ping/pong message
                # we can return it immediately.
return msg
if msg.type is WSMsgType.CLOSE:
self._set_closing()
self._close_code = msg.data
if not self._closed and self._autoclose:
await self.close()
elif msg.type is WSMsgType.CLOSING:
self._set_closing()
elif msg.type is WSMsgType.PING and self._autoping:
await self.pong(msg.data)
continue
elif msg.type is WSMsgType.PONG and self._autoping:
continue
return msg
async def receive_str(self, *, timeout: Optional[float] = None) -> str:
msg = await self.receive(timeout)
if msg.type is not WSMsgType.TEXT:
raise WSMessageTypeError(
f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
)
return cast(str, msg.data)
async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
msg = await self.receive(timeout)
if msg.type is not WSMsgType.BINARY:
raise WSMessageTypeError(
f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY"
)
return cast(bytes, msg.data)
async def receive_json(
self,
*,
loads: JSONDecoder = DEFAULT_JSON_DECODER,
timeout: Optional[float] = None,
) -> Any:
data = await self.receive_str(timeout=timeout)
return loads(data)
def __aiter__(self) -> "ClientWebSocketResponse":
return self
async def __anext__(self) -> WSMessage:
msg = await self.receive()
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
raise StopAsyncIteration
return msg
async def __aenter__(self) -> "ClientWebSocketResponse":
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
await self.close()
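# --- Illustrative usage sketch (editorial addition, not aiohttp source) ---
# A minimal example of driving the ClientWebSocketResponse defined above via
# ClientSession.ws_connect().  The endpoint URL is a placeholder assumption;
# point it at any echo-style websocket server.
async def _example_ws_round_trip(url: str = "ws://localhost:8080/ws") -> None:
    from aiohttp import ClientSession  # imported lazily to avoid an import cycle

    async with ClientSession() as session:
        # ws_connect() returns the ClientWebSocketResponse implemented above.
        async with session.ws_connect(url) as ws:
            await ws.send_str("hello")
            # __aiter__/__anext__ end the loop on CLOSE/CLOSING/CLOSED.
            async for msg in ws:
                if msg.type is WSMsgType.TEXT:
                    print("received:", msg.data)
                    await ws.close()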

View File

@@ -0,0 +1,173 @@
import asyncio
import zlib
from concurrent.futures import Executor
from typing import Optional, cast
try:
try:
import brotlicffi as brotli
except ImportError:
import brotli
HAS_BROTLI = True
except ImportError: # pragma: no cover
HAS_BROTLI = False
MAX_SYNC_CHUNK_SIZE = 1024
def encoding_to_mode(
encoding: Optional[str] = None,
suppress_deflate_header: bool = False,
) -> int:
if encoding == "gzip":
return 16 + zlib.MAX_WBITS
return -zlib.MAX_WBITS if suppress_deflate_header else zlib.MAX_WBITS
class ZlibBaseHandler:
def __init__(
self,
mode: int,
executor: Optional[Executor] = None,
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
):
self._mode = mode
self._executor = executor
self._max_sync_chunk_size = max_sync_chunk_size
class ZLibCompressor(ZlibBaseHandler):
def __init__(
self,
encoding: Optional[str] = None,
suppress_deflate_header: bool = False,
level: Optional[int] = None,
wbits: Optional[int] = None,
strategy: int = zlib.Z_DEFAULT_STRATEGY,
executor: Optional[Executor] = None,
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
):
super().__init__(
mode=(
encoding_to_mode(encoding, suppress_deflate_header)
if wbits is None
else wbits
),
executor=executor,
max_sync_chunk_size=max_sync_chunk_size,
)
if level is None:
self._compressor = zlib.compressobj(wbits=self._mode, strategy=strategy)
else:
self._compressor = zlib.compressobj(
wbits=self._mode, strategy=strategy, level=level
)
self._compress_lock = asyncio.Lock()
def compress_sync(self, data: bytes) -> bytes:
return self._compressor.compress(data)
async def compress(self, data: bytes) -> bytes:
"""Compress the data and returned the compressed bytes.
Note that flush() must be called after the last call to compress()
If the data size is large than the max_sync_chunk_size, the compression
will be done in the executor. Otherwise, the compression will be done
in the event loop.
"""
async with self._compress_lock:
# To ensure the stream is consistent in the event
# there are multiple writers, we need to lock
# the compressor so that only one writer can
# compress at a time.
if (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
):
return await asyncio.get_running_loop().run_in_executor(
self._executor, self._compressor.compress, data
)
return self.compress_sync(data)
def flush(self, mode: int = zlib.Z_FINISH) -> bytes:
return self._compressor.flush(mode)
class ZLibDecompressor(ZlibBaseHandler):
def __init__(
self,
encoding: Optional[str] = None,
suppress_deflate_header: bool = False,
executor: Optional[Executor] = None,
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
):
super().__init__(
mode=encoding_to_mode(encoding, suppress_deflate_header),
executor=executor,
max_sync_chunk_size=max_sync_chunk_size,
)
self._decompressor = zlib.decompressobj(wbits=self._mode)
def decompress_sync(self, data: bytes, max_length: int = 0) -> bytes:
return self._decompressor.decompress(data, max_length)
async def decompress(self, data: bytes, max_length: int = 0) -> bytes:
"""Decompress the data and return the decompressed bytes.
        If the data size is larger than the max_sync_chunk_size, the decompression
will be done in the executor. Otherwise, the decompression will be done
in the event loop.
"""
if (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
):
return await asyncio.get_running_loop().run_in_executor(
self._executor, self._decompressor.decompress, data, max_length
)
return self.decompress_sync(data, max_length)
def flush(self, length: int = 0) -> bytes:
return (
self._decompressor.flush(length)
if length > 0
else self._decompressor.flush()
)
@property
def eof(self) -> bool:
return self._decompressor.eof
@property
def unconsumed_tail(self) -> bytes:
return self._decompressor.unconsumed_tail
@property
def unused_data(self) -> bytes:
return self._decompressor.unused_data
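# --- Illustrative round-trip sketch (editorial addition) ---
# Shows the intended pairing of ZLibCompressor and ZLibDecompressor with gzip
# framing.  The payload is an arbitrary demo value sized to exercise the
# executor path (len > MAX_SYNC_CHUNK_SIZE).
async def _example_gzip_round_trip() -> None:
    data = b"hello world " * 1024
    compressor = ZLibCompressor(encoding="gzip")
    compressed = await compressor.compress(data) + compressor.flush()
    decompressor = ZLibDecompressor(encoding="gzip")
    restored = await decompressor.decompress(compressed) + decompressor.flush()
    assert restored == data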
class BrotliDecompressor:
# Supports both 'brotlipy' and 'Brotli' packages
# since they share an import name. The top branches
# are for 'brotlipy' and bottom branches for 'Brotli'
def __init__(self) -> None:
if not HAS_BROTLI:
raise RuntimeError(
"The brotli decompression is not available. "
"Please install `Brotli` module"
)
self._obj = brotli.Decompressor()
def decompress_sync(self, data: bytes) -> bytes:
if hasattr(self._obj, "decompress"):
return cast(bytes, self._obj.decompress(data))
return cast(bytes, self._obj.process(data))
def flush(self) -> bytes:
if hasattr(self._obj, "flush"):
return cast(bytes, self._obj.flush())
return b""

File diff suppressed because it is too large

View File

@@ -0,0 +1,495 @@
import asyncio
import calendar
import contextlib
import datetime
import heapq
import itertools
import os # noqa
import pathlib
import pickle
import re
import time
import warnings
from collections import defaultdict
from http.cookies import BaseCookie, Morsel, SimpleCookie
from typing import (
DefaultDict,
Dict,
Iterable,
Iterator,
List,
Mapping,
Optional,
Set,
Tuple,
Union,
cast,
)
from yarl import URL
from .abc import AbstractCookieJar, ClearCookiePredicate
from .helpers import is_ip_address
from .typedefs import LooseCookies, PathLike, StrOrURL
__all__ = ("CookieJar", "DummyCookieJar")
CookieItem = Union[str, "Morsel[str]"]
# We cache these string methods here as their use is in performance critical code.
_FORMAT_PATH = "{}/{}".format
_FORMAT_DOMAIN_REVERSED = "{1}.{0}".format
# The minimum number of scheduled cookie expirations before we start cleaning up
# the expiration heap. This is a performance optimization to avoid cleaning up the
# heap too often when there are only a few scheduled expirations.
_MIN_SCHEDULED_COOKIE_EXPIRATION = 100
class CookieJar(AbstractCookieJar):
"""Implements cookie storage adhering to RFC 6265."""
DATE_TOKENS_RE = re.compile(
r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)"
)
DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")
DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")
DATE_MONTH_RE = re.compile(
"(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|(aug)|(sep)|(oct)|(nov)|(dec)",
re.I,
)
DATE_YEAR_RE = re.compile(r"(\d{2,4})")
# calendar.timegm() fails for timestamps after datetime.datetime.max
# Minus one as a loss of precision occurs when timestamp() is called.
MAX_TIME = (
int(datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()) - 1
)
try:
calendar.timegm(time.gmtime(MAX_TIME))
except (OSError, ValueError):
# Hit the maximum representable time on Windows
# https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64
# Throws ValueError on PyPy 3.9, OSError elsewhere
MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1))
except OverflowError:
# #4515: datetime.max may not be representable on 32-bit platforms
MAX_TIME = 2**31 - 1
# Avoid minuses in the future, 3x faster
SUB_MAX_TIME = MAX_TIME - 1
def __init__(
self,
*,
unsafe: bool = False,
quote_cookie: bool = True,
treat_as_secure_origin: Union[StrOrURL, List[StrOrURL], None] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
super().__init__(loop=loop)
self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict(
SimpleCookie
)
self._morsel_cache: DefaultDict[Tuple[str, str], Dict[str, Morsel[str]]] = (
defaultdict(dict)
)
self._host_only_cookies: Set[Tuple[str, str]] = set()
self._unsafe = unsafe
self._quote_cookie = quote_cookie
if treat_as_secure_origin is None:
treat_as_secure_origin = []
elif isinstance(treat_as_secure_origin, URL):
treat_as_secure_origin = [treat_as_secure_origin.origin()]
elif isinstance(treat_as_secure_origin, str):
treat_as_secure_origin = [URL(treat_as_secure_origin).origin()]
else:
treat_as_secure_origin = [
URL(url).origin() if isinstance(url, str) else url.origin()
for url in treat_as_secure_origin
]
self._treat_as_secure_origin = treat_as_secure_origin
self._expire_heap: List[Tuple[float, Tuple[str, str, str]]] = []
self._expirations: Dict[Tuple[str, str, str], float] = {}
@property
def quote_cookie(self) -> bool:
return self._quote_cookie
def save(self, file_path: PathLike) -> None:
file_path = pathlib.Path(file_path)
with file_path.open(mode="wb") as f:
pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)
def load(self, file_path: PathLike) -> None:
file_path = pathlib.Path(file_path)
with file_path.open(mode="rb") as f:
self._cookies = pickle.load(f)
def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
if predicate is None:
self._expire_heap.clear()
self._cookies.clear()
self._morsel_cache.clear()
self._host_only_cookies.clear()
self._expirations.clear()
return
now = time.time()
to_del = [
key
for (domain, path), cookie in self._cookies.items()
for name, morsel in cookie.items()
if (
(key := (domain, path, name)) in self._expirations
and self._expirations[key] <= now
)
or predicate(morsel)
]
if to_del:
self._delete_cookies(to_del)
def clear_domain(self, domain: str) -> None:
self.clear(lambda x: self._is_domain_match(domain, x["domain"]))
def __iter__(self) -> "Iterator[Morsel[str]]":
self._do_expiration()
for val in self._cookies.values():
yield from val.values()
def __len__(self) -> int:
"""Return number of cookies.
This function does not iterate self to avoid unnecessary expiration
checks.
"""
return sum(len(cookie.values()) for cookie in self._cookies.values())
def _do_expiration(self) -> None:
"""Remove expired cookies."""
if not (expire_heap_len := len(self._expire_heap)):
return
        # If the expiration heap grows larger than the number of expirations
# times two, we clean it up to avoid keeping expired entries in
# the heap and consuming memory. We guard this with a minimum
# threshold to avoid cleaning up the heap too often when there are
# only a few scheduled expirations.
if (
expire_heap_len > _MIN_SCHEDULED_COOKIE_EXPIRATION
and expire_heap_len > len(self._expirations) * 2
):
# Remove any expired entries from the expiration heap
# that do not match the expiration time in the expirations
# as it means the cookie has been re-added to the heap
# with a different expiration time.
self._expire_heap = [
entry
for entry in self._expire_heap
if self._expirations.get(entry[1]) == entry[0]
]
heapq.heapify(self._expire_heap)
now = time.time()
to_del: List[Tuple[str, str, str]] = []
# Find any expired cookies and add them to the to-delete list
while self._expire_heap:
when, cookie_key = self._expire_heap[0]
if when > now:
break
heapq.heappop(self._expire_heap)
# Check if the cookie hasn't been re-added to the heap
# with a different expiration time as it will be removed
# later when it reaches the top of the heap and its
# expiration time is met.
if self._expirations.get(cookie_key) == when:
to_del.append(cookie_key)
if to_del:
self._delete_cookies(to_del)
def _delete_cookies(self, to_del: List[Tuple[str, str, str]]) -> None:
for domain, path, name in to_del:
self._host_only_cookies.discard((domain, name))
self._cookies[(domain, path)].pop(name, None)
self._morsel_cache[(domain, path)].pop(name, None)
self._expirations.pop((domain, path, name), None)
def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None:
cookie_key = (domain, path, name)
if self._expirations.get(cookie_key) == when:
# Avoid adding duplicates to the heap
return
heapq.heappush(self._expire_heap, (when, cookie_key))
self._expirations[cookie_key] = when
def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
"""Update cookies."""
hostname = response_url.raw_host
if not self._unsafe and is_ip_address(hostname):
# Don't accept cookies from IPs
return
if isinstance(cookies, Mapping):
cookies = cookies.items()
for name, cookie in cookies:
if not isinstance(cookie, Morsel):
tmp = SimpleCookie()
tmp[name] = cookie # type: ignore[assignment]
cookie = tmp[name]
domain = cookie["domain"]
# ignore domains with trailing dots
if domain and domain[-1] == ".":
domain = ""
del cookie["domain"]
if not domain and hostname is not None:
# Set the cookie's domain to the response hostname
# and set its host-only-flag
self._host_only_cookies.add((hostname, name))
domain = cookie["domain"] = hostname
if domain and domain[0] == ".":
# Remove leading dot
domain = domain[1:]
cookie["domain"] = domain
if hostname and not self._is_domain_match(domain, hostname):
# Setting cookies for different domains is not allowed
continue
path = cookie["path"]
if not path or path[0] != "/":
# Set the cookie's path to the response path
path = response_url.path
if not path.startswith("/"):
path = "/"
else:
# Cut everything from the last slash to the end
path = "/" + path[1 : path.rfind("/")]
cookie["path"] = path
path = path.rstrip("/")
if max_age := cookie["max-age"]:
try:
delta_seconds = int(max_age)
max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME)
self._expire_cookie(max_age_expiration, domain, path, name)
except ValueError:
cookie["max-age"] = ""
elif expires := cookie["expires"]:
if expire_time := self._parse_date(expires):
self._expire_cookie(expire_time, domain, path, name)
else:
cookie["expires"] = ""
key = (domain, path)
if self._cookies[key].get(name) != cookie:
# Don't blow away the cache if the same
# cookie gets set again
self._cookies[key][name] = cookie
self._morsel_cache[key].pop(name, None)
self._do_expiration()
def filter_cookies(self, request_url: URL = URL()) -> "BaseCookie[str]":
"""Returns this jar's cookies filtered by their attributes."""
filtered: Union[SimpleCookie, "BaseCookie[str]"] = (
SimpleCookie() if self._quote_cookie else BaseCookie()
)
if not self._cookies:
# Skip do_expiration() if there are no cookies.
return filtered
self._do_expiration()
if not self._cookies:
# Skip rest of function if no non-expired cookies.
return filtered
if type(request_url) is not URL:
warnings.warn(
"filter_cookies expects yarl.URL instances only,"
f"and will stop working in 4.x, got {type(request_url)}",
DeprecationWarning,
stacklevel=2,
)
request_url = URL(request_url)
hostname = request_url.raw_host or ""
is_not_secure = request_url.scheme not in ("https", "wss")
if is_not_secure and self._treat_as_secure_origin:
request_origin = URL()
with contextlib.suppress(ValueError):
request_origin = request_url.origin()
is_not_secure = request_origin not in self._treat_as_secure_origin
# Send shared cookie
for c in self._cookies[("", "")].values():
filtered[c.key] = c.value
if is_ip_address(hostname):
if not self._unsafe:
return filtered
domains: Iterable[str] = (hostname,)
else:
# Get all the subdomains that might match a cookie (e.g. "foo.bar.com", "bar.com", "com")
domains = itertools.accumulate(
reversed(hostname.split(".")), _FORMAT_DOMAIN_REVERSED
)
# Get all the path prefixes that might match a cookie (e.g. "", "/foo", "/foo/bar")
paths = itertools.accumulate(request_url.path.split("/"), _FORMAT_PATH)
# Create every combination of (domain, path) pairs.
pairs = itertools.product(domains, paths)
path_len = len(request_url.path)
# Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4
for p in pairs:
for name, cookie in self._cookies[p].items():
domain = cookie["domain"]
if (domain, name) in self._host_only_cookies and domain != hostname:
continue
# Skip edge case when the cookie has a trailing slash but request doesn't.
if len(cookie["path"]) > path_len:
continue
if is_not_secure and cookie["secure"]:
continue
# We already built the Morsel so reuse it here
if name in self._morsel_cache[p]:
filtered[name] = self._morsel_cache[p][name]
continue
# It's critical we use the Morsel so the coded_value
# (based on cookie version) is preserved
mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel()))
mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
self._morsel_cache[p][name] = mrsl_val
filtered[name] = mrsl_val
return filtered
@staticmethod
def _is_domain_match(domain: str, hostname: str) -> bool:
"""Implements domain matching adhering to RFC 6265."""
if hostname == domain:
return True
if not hostname.endswith(domain):
return False
non_matching = hostname[: -len(domain)]
if not non_matching.endswith("."):
return False
return not is_ip_address(hostname)
@classmethod
def _parse_date(cls, date_str: str) -> Optional[int]:
"""Implements date string parsing adhering to RFC 6265."""
if not date_str:
return None
found_time = False
found_day = False
found_month = False
found_year = False
hour = minute = second = 0
day = 0
month = 0
year = 0
for token_match in cls.DATE_TOKENS_RE.finditer(date_str):
token = token_match.group("token")
if not found_time:
time_match = cls.DATE_HMS_TIME_RE.match(token)
if time_match:
found_time = True
hour, minute, second = (int(s) for s in time_match.groups())
continue
if not found_day:
day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
if day_match:
found_day = True
day = int(day_match.group())
continue
if not found_month:
month_match = cls.DATE_MONTH_RE.match(token)
if month_match:
found_month = True
assert month_match.lastindex is not None
month = month_match.lastindex
continue
if not found_year:
year_match = cls.DATE_YEAR_RE.match(token)
if year_match:
found_year = True
year = int(year_match.group())
if 70 <= year <= 99:
year += 1900
elif 0 <= year <= 69:
year += 2000
if False in (found_day, found_month, found_year, found_time):
return None
if not 1 <= day <= 31:
return None
if year < 1601 or hour > 23 or minute > 59 or second > 59:
return None
return calendar.timegm((year, month, day, hour, minute, second, -1, -1, -1))
class DummyCookieJar(AbstractCookieJar):
"""Implements a dummy cookie storage.
It can be used with the ClientSession when no cookie processing is needed.
"""
def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(loop=loop)
def __iter__(self) -> "Iterator[Morsel[str]]":
while False:
yield None
def __len__(self) -> int:
return 0
@property
def quote_cookie(self) -> bool:
return True
def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
pass
def clear_domain(self, domain: str) -> None:
pass
def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
pass
def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
return SimpleCookie()
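# --- Illustrative usage sketch (editorial addition) ---
# A minimal round trip through the CookieJar defined above: store a cookie
# for a response URL, then filter it back for a matching request URL.  The
# domain and values are made-up demo data.
async def _example_cookie_round_trip() -> None:
    jar = CookieJar()
    jar.update_cookies(
        {"session": "abc123"}, response_url=URL("https://example.com/login")
    )
    # filter_cookies() applies the RFC 6265 domain/path/secure matching rules.
    outgoing = jar.filter_cookies(URL("https://example.com/profile"))
    assert outgoing["session"].value == "abc123"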

View File

@@ -0,0 +1,182 @@
import io
import warnings
from typing import Any, Iterable, List, Optional
from urllib.parse import urlencode
from multidict import MultiDict, MultiDictProxy
from . import hdrs, multipart, payload
from .helpers import guess_filename
from .payload import Payload
__all__ = ("FormData",)
class FormData:
"""Helper class for form body generation.
Supports multipart/form-data and application/x-www-form-urlencoded.
"""
def __init__(
self,
fields: Iterable[Any] = (),
quote_fields: bool = True,
charset: Optional[str] = None,
*,
default_to_multipart: bool = False,
) -> None:
self._writer = multipart.MultipartWriter("form-data")
self._fields: List[Any] = []
self._is_multipart = default_to_multipart
self._is_processed = False
self._quote_fields = quote_fields
self._charset = charset
if isinstance(fields, dict):
fields = list(fields.items())
elif not isinstance(fields, (list, tuple)):
fields = (fields,)
self.add_fields(*fields)
@property
def is_multipart(self) -> bool:
return self._is_multipart
def add_field(
self,
name: str,
value: Any,
*,
content_type: Optional[str] = None,
filename: Optional[str] = None,
content_transfer_encoding: Optional[str] = None,
) -> None:
if isinstance(value, io.IOBase):
self._is_multipart = True
elif isinstance(value, (bytes, bytearray, memoryview)):
msg = (
"In v4, passing bytes will no longer create a file field. "
"Please explicitly use the filename parameter or pass a BytesIO object."
)
if filename is None and content_transfer_encoding is None:
warnings.warn(msg, DeprecationWarning)
filename = name
type_options: MultiDict[str] = MultiDict({"name": name})
if filename is not None and not isinstance(filename, str):
raise TypeError("filename must be an instance of str. Got: %s" % filename)
if filename is None and isinstance(value, io.IOBase):
filename = guess_filename(value, name)
if filename is not None:
type_options["filename"] = filename
self._is_multipart = True
headers = {}
if content_type is not None:
if not isinstance(content_type, str):
raise TypeError(
"content_type must be an instance of str. Got: %s" % content_type
)
headers[hdrs.CONTENT_TYPE] = content_type
self._is_multipart = True
if content_transfer_encoding is not None:
if not isinstance(content_transfer_encoding, str):
raise TypeError(
"content_transfer_encoding must be an instance"
" of str. Got: %s" % content_transfer_encoding
)
msg = (
"content_transfer_encoding is deprecated. "
"To maintain compatibility with v4 please pass a BytesPayload."
)
warnings.warn(msg, DeprecationWarning)
self._is_multipart = True
self._fields.append((type_options, headers, value))
def add_fields(self, *fields: Any) -> None:
to_add = list(fields)
while to_add:
rec = to_add.pop(0)
if isinstance(rec, io.IOBase):
k = guess_filename(rec, "unknown")
self.add_field(k, rec) # type: ignore[arg-type]
elif isinstance(rec, (MultiDictProxy, MultiDict)):
to_add.extend(rec.items())
elif isinstance(rec, (list, tuple)) and len(rec) == 2:
k, fp = rec
self.add_field(k, fp) # type: ignore[arg-type]
else:
raise TypeError(
"Only io.IOBase, multidict and (name, file) "
"pairs allowed, use .add_field() for passing "
"more complex parameters, got {!r}".format(rec)
)
def _gen_form_urlencoded(self) -> payload.BytesPayload:
# form data (x-www-form-urlencoded)
data = []
for type_options, _, value in self._fields:
data.append((type_options["name"], value))
charset = self._charset if self._charset is not None else "utf-8"
if charset == "utf-8":
content_type = "application/x-www-form-urlencoded"
else:
content_type = "application/x-www-form-urlencoded; charset=%s" % charset
return payload.BytesPayload(
urlencode(data, doseq=True, encoding=charset).encode(),
content_type=content_type,
)
def _gen_form_data(self) -> multipart.MultipartWriter:
"""Encode a list of fields using the multipart/form-data MIME format"""
if self._is_processed:
raise RuntimeError("Form data has been processed already")
for dispparams, headers, value in self._fields:
try:
if hdrs.CONTENT_TYPE in headers:
part = payload.get_payload(
value,
content_type=headers[hdrs.CONTENT_TYPE],
headers=headers,
encoding=self._charset,
)
else:
part = payload.get_payload(
value, headers=headers, encoding=self._charset
)
except Exception as exc:
raise TypeError(
"Can not serialize value type: %r\n "
"headers: %r\n value: %r" % (type(value), headers, value)
) from exc
if dispparams:
part.set_content_disposition(
"form-data", quote_fields=self._quote_fields, **dispparams
)
            # FIXME cgi.FieldStorage doesn't like body parts with
            # Content-Length which were sent via chunked transfer encoding
assert part.headers is not None
part.headers.popall(hdrs.CONTENT_LENGTH, None)
self._writer.append_payload(part)
self._is_processed = True
return self._writer
def __call__(self) -> Payload:
if self._is_multipart:
return self._gen_form_data()
else:
return self._gen_form_urlencoded()
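# --- Illustrative usage sketch (editorial addition) ---
# Builds a form with the class above.  A plain string field alone keeps the
# body urlencoded; the file-like field switches the writer to
# multipart/form-data, which is what __call__() dispatches on.  Field names
# and bytes are made-up demo data.
def _example_form_data() -> Payload:
    form = FormData()
    form.add_field("username", "alice")
    form.add_field(
        "avatar",
        io.BytesIO(b"fake png bytes"),
        filename="avatar.png",
        content_type="image/png",
    )
    assert form.is_multipart  # the file field forced multipart mode
    return form()  # a MultipartWriter payload ready to be sent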

View File

@@ -0,0 +1,121 @@
"""HTTP Headers constants."""
# After changing the file content call ./tools/gen.py
# to regenerate the headers parser
import itertools
from typing import Final, Set
from multidict import istr
METH_ANY: Final[str] = "*"
METH_CONNECT: Final[str] = "CONNECT"
METH_HEAD: Final[str] = "HEAD"
METH_GET: Final[str] = "GET"
METH_DELETE: Final[str] = "DELETE"
METH_OPTIONS: Final[str] = "OPTIONS"
METH_PATCH: Final[str] = "PATCH"
METH_POST: Final[str] = "POST"
METH_PUT: Final[str] = "PUT"
METH_TRACE: Final[str] = "TRACE"
METH_ALL: Final[Set[str]] = {
METH_CONNECT,
METH_HEAD,
METH_GET,
METH_DELETE,
METH_OPTIONS,
METH_PATCH,
METH_POST,
METH_PUT,
METH_TRACE,
}
ACCEPT: Final[istr] = istr("Accept")
ACCEPT_CHARSET: Final[istr] = istr("Accept-Charset")
ACCEPT_ENCODING: Final[istr] = istr("Accept-Encoding")
ACCEPT_LANGUAGE: Final[istr] = istr("Accept-Language")
ACCEPT_RANGES: Final[istr] = istr("Accept-Ranges")
ACCESS_CONTROL_MAX_AGE: Final[istr] = istr("Access-Control-Max-Age")
ACCESS_CONTROL_ALLOW_CREDENTIALS: Final[istr] = istr("Access-Control-Allow-Credentials")
ACCESS_CONTROL_ALLOW_HEADERS: Final[istr] = istr("Access-Control-Allow-Headers")
ACCESS_CONTROL_ALLOW_METHODS: Final[istr] = istr("Access-Control-Allow-Methods")
ACCESS_CONTROL_ALLOW_ORIGIN: Final[istr] = istr("Access-Control-Allow-Origin")
ACCESS_CONTROL_EXPOSE_HEADERS: Final[istr] = istr("Access-Control-Expose-Headers")
ACCESS_CONTROL_REQUEST_HEADERS: Final[istr] = istr("Access-Control-Request-Headers")
ACCESS_CONTROL_REQUEST_METHOD: Final[istr] = istr("Access-Control-Request-Method")
AGE: Final[istr] = istr("Age")
ALLOW: Final[istr] = istr("Allow")
AUTHORIZATION: Final[istr] = istr("Authorization")
CACHE_CONTROL: Final[istr] = istr("Cache-Control")
CONNECTION: Final[istr] = istr("Connection")
CONTENT_DISPOSITION: Final[istr] = istr("Content-Disposition")
CONTENT_ENCODING: Final[istr] = istr("Content-Encoding")
CONTENT_LANGUAGE: Final[istr] = istr("Content-Language")
CONTENT_LENGTH: Final[istr] = istr("Content-Length")
CONTENT_LOCATION: Final[istr] = istr("Content-Location")
CONTENT_MD5: Final[istr] = istr("Content-MD5")
CONTENT_RANGE: Final[istr] = istr("Content-Range")
CONTENT_TRANSFER_ENCODING: Final[istr] = istr("Content-Transfer-Encoding")
CONTENT_TYPE: Final[istr] = istr("Content-Type")
COOKIE: Final[istr] = istr("Cookie")
DATE: Final[istr] = istr("Date")
DESTINATION: Final[istr] = istr("Destination")
DIGEST: Final[istr] = istr("Digest")
ETAG: Final[istr] = istr("Etag")
EXPECT: Final[istr] = istr("Expect")
EXPIRES: Final[istr] = istr("Expires")
FORWARDED: Final[istr] = istr("Forwarded")
FROM: Final[istr] = istr("From")
HOST: Final[istr] = istr("Host")
IF_MATCH: Final[istr] = istr("If-Match")
IF_MODIFIED_SINCE: Final[istr] = istr("If-Modified-Since")
IF_NONE_MATCH: Final[istr] = istr("If-None-Match")
IF_RANGE: Final[istr] = istr("If-Range")
IF_UNMODIFIED_SINCE: Final[istr] = istr("If-Unmodified-Since")
KEEP_ALIVE: Final[istr] = istr("Keep-Alive")
LAST_EVENT_ID: Final[istr] = istr("Last-Event-ID")
LAST_MODIFIED: Final[istr] = istr("Last-Modified")
LINK: Final[istr] = istr("Link")
LOCATION: Final[istr] = istr("Location")
MAX_FORWARDS: Final[istr] = istr("Max-Forwards")
ORIGIN: Final[istr] = istr("Origin")
PRAGMA: Final[istr] = istr("Pragma")
PROXY_AUTHENTICATE: Final[istr] = istr("Proxy-Authenticate")
PROXY_AUTHORIZATION: Final[istr] = istr("Proxy-Authorization")
RANGE: Final[istr] = istr("Range")
REFERER: Final[istr] = istr("Referer")
RETRY_AFTER: Final[istr] = istr("Retry-After")
SEC_WEBSOCKET_ACCEPT: Final[istr] = istr("Sec-WebSocket-Accept")
SEC_WEBSOCKET_VERSION: Final[istr] = istr("Sec-WebSocket-Version")
SEC_WEBSOCKET_PROTOCOL: Final[istr] = istr("Sec-WebSocket-Protocol")
SEC_WEBSOCKET_EXTENSIONS: Final[istr] = istr("Sec-WebSocket-Extensions")
SEC_WEBSOCKET_KEY: Final[istr] = istr("Sec-WebSocket-Key")
SEC_WEBSOCKET_KEY1: Final[istr] = istr("Sec-WebSocket-Key1")
SERVER: Final[istr] = istr("Server")
SET_COOKIE: Final[istr] = istr("Set-Cookie")
TE: Final[istr] = istr("TE")
TRAILER: Final[istr] = istr("Trailer")
TRANSFER_ENCODING: Final[istr] = istr("Transfer-Encoding")
UPGRADE: Final[istr] = istr("Upgrade")
URI: Final[istr] = istr("URI")
USER_AGENT: Final[istr] = istr("User-Agent")
VARY: Final[istr] = istr("Vary")
VIA: Final[istr] = istr("Via")
WANT_DIGEST: Final[istr] = istr("Want-Digest")
WARNING: Final[istr] = istr("Warning")
WWW_AUTHENTICATE: Final[istr] = istr("WWW-Authenticate")
X_FORWARDED_FOR: Final[istr] = istr("X-Forwarded-For")
X_FORWARDED_HOST: Final[istr] = istr("X-Forwarded-Host")
X_FORWARDED_PROTO: Final[istr] = istr("X-Forwarded-Proto")
# These are the upper/lower case variants of the headers/methods
# Example: {'hOst', 'host', 'HoST', 'HOSt', 'hOsT', 'HosT', 'hoSt', ...}
METH_HEAD_ALL: Final = frozenset(
map("".join, itertools.product(*zip(METH_HEAD.upper(), METH_HEAD.lower())))
)
METH_CONNECT_ALL: Final = frozenset(
map("".join, itertools.product(*zip(METH_CONNECT.upper(), METH_CONNECT.lower())))
)
HOST_ALL: Final = frozenset(
map("".join, itertools.product(*zip(HOST.upper(), HOST.lower())))
)
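# --- Illustrative sketch (editorial addition) ---
# The constants above are istr rather than str so that lookups in multidict
# containers compare case-insensitively, and the *_ALL frozensets allow O(1)
# matching of any casing seen on the wire without normalizing it first.
def _example_case_insensitive_headers() -> None:
    from multidict import CIMultiDict

    headers = CIMultiDict({CONTENT_TYPE: "application/json"})
    assert headers["content-type"] == "application/json"  # any casing matches
    assert "hEaD" in METH_HEAD_ALL and "hOsT" in HOST_ALL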

View File

@@ -0,0 +1,958 @@
"""Various helper functions"""
import asyncio
import base64
import binascii
import contextlib
import datetime
import enum
import functools
import inspect
import netrc
import os
import platform
import re
import sys
import time
import weakref
from collections import namedtuple
from contextlib import suppress
from email.parser import HeaderParser
from email.utils import parsedate
from math import ceil
from pathlib import Path
from types import MappingProxyType, TracebackType
from typing import (
Any,
Callable,
ContextManager,
Dict,
Generator,
Generic,
Iterable,
Iterator,
List,
Mapping,
Optional,
Protocol,
Tuple,
Type,
TypeVar,
Union,
get_args,
overload,
)
from urllib.parse import quote
from urllib.request import getproxies, proxy_bypass
import attr
from multidict import MultiDict, MultiDictProxy, MultiMapping
from propcache.api import under_cached_property as reify
from yarl import URL
from . import hdrs
from .log import client_logger
if sys.version_info >= (3, 11):
import asyncio as async_timeout
else:
import async_timeout
__all__ = ("BasicAuth", "ChainMapProxy", "ETag", "reify")
IS_MACOS = platform.system() == "Darwin"
IS_WINDOWS = platform.system() == "Windows"
PY_310 = sys.version_info >= (3, 10)
PY_311 = sys.version_info >= (3, 11)
_T = TypeVar("_T")
_S = TypeVar("_S")
_SENTINEL = enum.Enum("_SENTINEL", "sentinel")
sentinel = _SENTINEL.sentinel
NO_EXTENSIONS = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS"))
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
EMPTY_BODY_STATUS_CODES = frozenset((204, 304, *range(100, 200)))
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2
EMPTY_BODY_METHODS = hdrs.METH_HEAD_ALL
DEBUG = sys.flags.dev_mode or (
not sys.flags.ignore_environment and bool(os.environ.get("PYTHONASYNCIODEBUG"))
)
CHAR = {chr(i) for i in range(0, 128)}
CTL = {chr(i) for i in range(0, 32)} | {
chr(127),
}
SEPARATORS = {
"(",
")",
"<",
">",
"@",
",",
";",
":",
"\\",
'"',
"/",
"[",
"]",
"?",
"=",
"{",
"}",
" ",
chr(9),
}
TOKEN = CHAR ^ CTL ^ SEPARATORS
class noop:
def __await__(self) -> Generator[None, None, None]:
yield
class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
"""Http basic authentication helper."""
def __new__(
cls, login: str, password: str = "", encoding: str = "latin1"
) -> "BasicAuth":
if login is None:
raise ValueError("None is not allowed as login value")
if password is None:
raise ValueError("None is not allowed as password value")
if ":" in login:
raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')
return super().__new__(cls, login, password, encoding)
@classmethod
def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth":
"""Create a BasicAuth object from an Authorization HTTP header."""
try:
auth_type, encoded_credentials = auth_header.split(" ", 1)
except ValueError:
raise ValueError("Could not parse authorization header.")
if auth_type.lower() != "basic":
raise ValueError("Unknown authorization method %s" % auth_type)
try:
decoded = base64.b64decode(
encoded_credentials.encode("ascii"), validate=True
).decode(encoding)
except binascii.Error:
raise ValueError("Invalid base64 encoding.")
try:
# RFC 2617 HTTP Authentication
# https://www.ietf.org/rfc/rfc2617.txt
# the colon must be present, but the username and password may be
# otherwise blank.
username, password = decoded.split(":", 1)
except ValueError:
raise ValueError("Invalid credentials.")
return cls(username, password, encoding=encoding)
@classmethod
def from_url(cls, url: URL, *, encoding: str = "latin1") -> Optional["BasicAuth"]:
"""Create BasicAuth from url."""
if not isinstance(url, URL):
raise TypeError("url should be yarl.URL instance")
# Check raw_user and raw_password first as yarl is likely
# to already have these values parsed from the netloc in the cache.
if url.raw_user is None and url.raw_password is None:
return None
return cls(url.user or "", url.password or "", encoding=encoding)
def encode(self) -> str:
"""Encode credentials."""
creds = (f"{self.login}:{self.password}").encode(self.encoding)
return "Basic %s" % base64.b64encode(creds).decode(self.encoding)
def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
"""Remove user and password from URL if present and return BasicAuth object."""
# Check raw_user and raw_password first as yarl is likely
# to already have these values parsed from the netloc in the cache.
if url.raw_user is None and url.raw_password is None:
return url, None
return url.with_user(None), BasicAuth(url.user or "", url.password or "")
def netrc_from_env() -> Optional[netrc.netrc]:
"""Load netrc from file.
Attempt to load it from the path specified by the env-var
NETRC or in the default location in the user's home directory.
Returns None if it couldn't be found or fails to parse.
"""
netrc_env = os.environ.get("NETRC")
if netrc_env is not None:
netrc_path = Path(netrc_env)
else:
try:
home_dir = Path.home()
except RuntimeError as e: # pragma: no cover
# if pathlib can't resolve home, it may raise a RuntimeError
client_logger.debug(
"Could not resolve home directory when "
"trying to look for .netrc file: %s",
e,
)
return None
netrc_path = home_dir / ("_netrc" if IS_WINDOWS else ".netrc")
try:
return netrc.netrc(str(netrc_path))
except netrc.NetrcParseError as e:
client_logger.warning("Could not parse .netrc file: %s", e)
except OSError as e:
netrc_exists = False
with contextlib.suppress(OSError):
netrc_exists = netrc_path.is_file()
# we couldn't read the file (doesn't exist, permissions, etc.)
if netrc_env or netrc_exists:
# only warn if the environment wanted us to load it,
# or it appears like the default file does actually exist
client_logger.warning("Could not read .netrc file: %s", e)
return None
@attr.s(auto_attribs=True, frozen=True, slots=True)
class ProxyInfo:
proxy: URL
proxy_auth: Optional[BasicAuth]
def basicauth_from_netrc(netrc_obj: Optional[netrc.netrc], host: str) -> BasicAuth:
"""
Return :py:class:`~aiohttp.BasicAuth` credentials for ``host`` from ``netrc_obj``.
:raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no
entry is found for the ``host``.
"""
if netrc_obj is None:
raise LookupError("No .netrc file found")
auth_from_netrc = netrc_obj.authenticators(host)
if auth_from_netrc is None:
raise LookupError(f"No entry for {host!s} found in the `.netrc` file.")
login, account, password = auth_from_netrc
# TODO(PY311): username = login or account
# Up to python 3.10, account could be None if not specified,
# and login will be empty string if not specified. From 3.11,
# login and account will be empty string if not specified.
username = login if (login or account is None) else account
# TODO(PY311): Remove this, as password will be empty string
# if not specified
if password is None:
password = ""
return BasicAuth(username, password)
def proxies_from_env() -> Dict[str, ProxyInfo]:
proxy_urls = {
k: URL(v)
for k, v in getproxies().items()
if k in ("http", "https", "ws", "wss")
}
netrc_obj = netrc_from_env()
stripped = {k: strip_auth_from_url(v) for k, v in proxy_urls.items()}
ret = {}
for proto, val in stripped.items():
proxy, auth = val
if proxy.scheme in ("https", "wss"):
client_logger.warning(
"%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy
)
continue
if netrc_obj and auth is None:
if proxy.host is not None:
try:
auth = basicauth_from_netrc(netrc_obj, proxy.host)
except LookupError:
auth = None
ret[proto] = ProxyInfo(proxy, auth)
return ret
def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
"""Get a permitted proxy for the given URL from the env."""
if url.host is not None and proxy_bypass(url.host):
raise LookupError(f"Proxying is disallowed for `{url.host!r}`")
proxies_in_env = proxies_from_env()
try:
proxy_info = proxies_in_env[url.scheme]
except KeyError:
raise LookupError(f"No proxies found for `{url!s}` in the env")
else:
return proxy_info.proxy, proxy_info.proxy_auth
@attr.s(auto_attribs=True, frozen=True, slots=True)
class MimeType:
type: str
subtype: str
suffix: str
parameters: "MultiDictProxy[str]"
@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
"""Parses a MIME type into its components.
mimetype is a MIME type string.
Returns a MimeType object.
Example:
>>> parse_mimetype('text/html; charset=utf-8')
MimeType(type='text', subtype='html', suffix='',
parameters={'charset': 'utf-8'})
"""
if not mimetype:
return MimeType(
type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
)
parts = mimetype.split(";")
params: MultiDict[str] = MultiDict()
for item in parts[1:]:
if not item:
continue
key, _, value = item.partition("=")
params.add(key.lower().strip(), value.strip(' "'))
fulltype = parts[0].strip().lower()
if fulltype == "*":
fulltype = "*/*"
mtype, _, stype = fulltype.partition("/")
stype, _, suffix = stype.partition("+")
return MimeType(
type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params)
)
@functools.lru_cache(maxsize=56)
def parse_content_type(raw: str) -> Tuple[str, MappingProxyType[str, str]]:
"""Parse Content-Type header.
Returns a tuple of the parsed content type and a
MappingProxyType of parameters.
"""
msg = HeaderParser().parsestr(f"Content-Type: {raw}")
content_type = msg.get_content_type()
params = msg.get_params(())
content_dict = dict(params[1:]) # First element is content type again
return content_type, MappingProxyType(content_dict)
def guess_filename(obj: Any, default: Optional[str] = None) -> Optional[str]:
name = getattr(obj, "name", None)
if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">":
return Path(name).name
return default
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")
QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"}
def quoted_string(content: str) -> str:
"""Return 7-bit content as quoted-string.
Format content into a quoted-string as defined in RFC5322 for
Internet Message Format. Notice that this is not the 8-bit HTTP
format, but the 7-bit email format. Content must be in usascii or
a ValueError is raised.
"""
if not (QCONTENT > set(content)):
raise ValueError(f"bad content for quoted-string {content!r}")
return not_qtext_re.sub(lambda x: "\\" + x.group(0), content)
def content_disposition_header(
disptype: str, quote_fields: bool = True, _charset: str = "utf-8", **params: str
) -> str:
"""Sets ``Content-Disposition`` header for MIME.
    This is the MIME payload Content-Disposition header from RFC 2183
    and RFC 7578 section 4.2, not the HTTP Content-Disposition from
    RFC 6266.
    disptype is a disposition type: inline, attachment, form-data.
    It should be a valid extension token (see RFC 2183).
    quote_fields performs value quoting to 7-bit MIME headers
    according to RFC 7578. Set quote_fields to False if the recipient
    can take 8-bit file names and field values.
    _charset specifies the charset to use when quote_fields is True.
    params is a dict with disposition params.
"""
if not disptype or not (TOKEN > set(disptype)):
raise ValueError(f"bad content disposition type {disptype!r}")
value = disptype
if params:
lparams = []
for key, val in params.items():
if not key or not (TOKEN > set(key)):
raise ValueError(f"bad content disposition parameter {key!r}={val!r}")
if quote_fields:
if key.lower() == "filename":
qval = quote(val, "", encoding=_charset)
lparams.append((key, '"%s"' % qval))
else:
try:
qval = quoted_string(val)
except ValueError:
qval = "".join(
(_charset, "''", quote(val, "", encoding=_charset))
)
lparams.append((key + "*", qval))
else:
lparams.append((key, '"%s"' % qval))
else:
qval = val.replace("\\", "\\\\").replace('"', '\\"')
lparams.append((key, '"%s"' % qval))
sparams = "; ".join("=".join(pair) for pair in lparams)
value = "; ".join((value, sparams))
return value
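# --- Illustrative sketch (editorial addition) ---
# Building a multipart part header with the helper above; the field name and
# filename are arbitrary demo values.  Note how the space in the filename is
# percent-encoded because quote_fields defaults to True.
def _example_content_disposition() -> None:
    header = content_disposition_header("form-data", name="file", filename="a b.txt")
    assert header == 'form-data; name="file"; filename="a%20b.txt"'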
def is_ip_address(host: Optional[str]) -> bool:
"""Check if host looks like an IP Address.
This check is only meant as a heuristic to ensure that
a host is not a domain name.
"""
if not host:
return False
# For a host to be an ipv4 address, it must be all numeric.
# The host must contain a colon to be an IPv6 address.
return ":" in host or host.replace(".", "").isdigit()
_cached_current_datetime: Optional[int] = None
_cached_formatted_datetime = ""
def rfc822_formatted_time() -> str:
global _cached_current_datetime
global _cached_formatted_datetime
now = int(time.time())
if now != _cached_current_datetime:
# Weekday and month names for HTTP date/time formatting;
# always English!
# Tuples are constants stored in codeobject!
_weekdayname = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
_monthname = (
"", # Dummy so we can use 1-based month numbers
"Jan",
"Feb",
"Mar",
"Apr",
"May",
"Jun",
"Jul",
"Aug",
"Sep",
"Oct",
"Nov",
"Dec",
)
year, month, day, hh, mm, ss, wd, *tail = time.gmtime(now)
_cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
_weekdayname[wd],
day,
_monthname[month],
year,
hh,
mm,
ss,
)
_cached_current_datetime = now
return _cached_formatted_datetime
def _weakref_handle(info: "Tuple[weakref.ref[object], str]") -> None:
ref, name = info
ob = ref()
if ob is not None:
with suppress(Exception):
getattr(ob, name)()
def weakref_handle(
ob: object,
name: str,
timeout: float,
loop: asyncio.AbstractEventLoop,
timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
if timeout is not None and timeout > 0:
when = loop.time() + timeout
if timeout >= timeout_ceil_threshold:
when = ceil(when)
return loop.call_at(when, _weakref_handle, (weakref.ref(ob), name))
return None
def call_later(
cb: Callable[[], Any],
timeout: float,
loop: asyncio.AbstractEventLoop,
timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
if timeout is None or timeout <= 0:
return None
now = loop.time()
when = calculate_timeout_when(now, timeout, timeout_ceil_threshold)
return loop.call_at(when, cb)
def calculate_timeout_when(
loop_time: float,
timeout: float,
timeout_ceiling_threshold: float,
) -> float:
"""Calculate when to execute a timeout."""
when = loop_time + timeout
if timeout > timeout_ceiling_threshold:
return ceil(when)
return when
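# --- Illustrative arithmetic sketch (editorial addition) ---
# Deadlines whose timeout exceeds the ceiling threshold are rounded up to a
# whole loop-time second so that nearby timers can share one wakeup.
def _example_timeout_when() -> None:
    assert calculate_timeout_when(100.0, 2.5, 5) == 102.5  # short timeout: exact
    assert calculate_timeout_when(100.2, 10.0, 5) == 111  # long timeout: ceil(110.2)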
class TimeoutHandle:
"""Timeout handle"""
__slots__ = ("_timeout", "_loop", "_ceil_threshold", "_callbacks")
def __init__(
self,
loop: asyncio.AbstractEventLoop,
timeout: Optional[float],
ceil_threshold: float = 5,
) -> None:
self._timeout = timeout
self._loop = loop
self._ceil_threshold = ceil_threshold
self._callbacks: List[
Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
] = []
def register(
self, callback: Callable[..., None], *args: Any, **kwargs: Any
) -> None:
self._callbacks.append((callback, args, kwargs))
def close(self) -> None:
self._callbacks.clear()
def start(self) -> Optional[asyncio.TimerHandle]:
timeout = self._timeout
if timeout is not None and timeout > 0:
when = self._loop.time() + timeout
if timeout >= self._ceil_threshold:
when = ceil(when)
return self._loop.call_at(when, self.__call__)
else:
return None
def timer(self) -> "BaseTimerContext":
if self._timeout is not None and self._timeout > 0:
timer = TimerContext(self._loop)
self.register(timer.timeout)
return timer
else:
return TimerNoop()
def __call__(self) -> None:
for cb, args, kwargs in self._callbacks:
with suppress(Exception):
cb(*args, **kwargs)
self._callbacks.clear()
class BaseTimerContext(ContextManager["BaseTimerContext"]):
__slots__ = ()
def assert_timeout(self) -> None:
"""Raise TimeoutError if timeout has been exceeded."""
class TimerNoop(BaseTimerContext):
__slots__ = ()
def __enter__(self) -> BaseTimerContext:
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
return
class TimerContext(BaseTimerContext):
"""Low resolution timeout context manager"""
__slots__ = ("_loop", "_tasks", "_cancelled", "_cancelling")
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._tasks: List[asyncio.Task[Any]] = []
self._cancelled = False
self._cancelling = 0
def assert_timeout(self) -> None:
"""Raise TimeoutError if timer has already been cancelled."""
if self._cancelled:
raise asyncio.TimeoutError from None
def __enter__(self) -> BaseTimerContext:
task = asyncio.current_task(loop=self._loop)
if task is None:
raise RuntimeError("Timeout context manager should be used inside a task")
if sys.version_info >= (3, 11):
# Remember if the task was already cancelling
# so when we __exit__ we can decide if we should
# raise asyncio.TimeoutError or let the cancellation propagate
self._cancelling = task.cancelling()
if self._cancelled:
raise asyncio.TimeoutError from None
self._tasks.append(task)
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
enter_task: Optional[asyncio.Task[Any]] = None
if self._tasks:
enter_task = self._tasks.pop()
if exc_type is asyncio.CancelledError and self._cancelled:
assert enter_task is not None
# The timeout was hit, and the task was cancelled
# so we need to uncancel the last task that entered the context manager
# since the cancellation should not leak out of the context manager
if sys.version_info >= (3, 11):
# If the task was already cancelling don't raise
# asyncio.TimeoutError and instead return None
# to allow the cancellation to propagate
if enter_task.uncancel() > self._cancelling:
return None
raise asyncio.TimeoutError from exc_val
return None
def timeout(self) -> None:
if not self._cancelled:
for task in set(self._tasks):
task.cancel()
self._cancelled = True
def ceil_timeout(
delay: Optional[float], ceil_threshold: float = 5
) -> async_timeout.Timeout:
if delay is None or delay <= 0:
return async_timeout.timeout(None)
loop = asyncio.get_running_loop()
now = loop.time()
when = now + delay
if delay > ceil_threshold:
when = ceil(when)
return async_timeout.timeout_at(when)
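# --- Illustrative usage sketch (editorial addition) ---
# ceil_timeout() wraps async-timeout with the same ceiling behaviour; the
# sleep below is a stand-in for any awaited I/O that overruns its deadline.
async def _example_ceil_timeout() -> None:
    try:
        async with ceil_timeout(0.01):
            await asyncio.sleep(1)  # deliberately slower than the deadline
    except asyncio.TimeoutError:
        print("timed out as expected")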
class HeadersMixin:
"""Mixin for handling headers."""
ATTRS = frozenset(["_content_type", "_content_dict", "_stored_content_type"])
_headers: MultiMapping[str]
_content_type: Optional[str] = None
_content_dict: Optional[Dict[str, str]] = None
_stored_content_type: Union[str, None, _SENTINEL] = sentinel
def _parse_content_type(self, raw: Optional[str]) -> None:
self._stored_content_type = raw
if raw is None:
# default value according to RFC 2616
self._content_type = "application/octet-stream"
self._content_dict = {}
else:
content_type, content_mapping_proxy = parse_content_type(raw)
self._content_type = content_type
# _content_dict needs to be mutable so we can update it
self._content_dict = content_mapping_proxy.copy()
@property
def content_type(self) -> str:
"""The value of content part for Content-Type HTTP header."""
raw = self._headers.get(hdrs.CONTENT_TYPE)
if self._stored_content_type != raw:
self._parse_content_type(raw)
assert self._content_type is not None
return self._content_type
@property
def charset(self) -> Optional[str]:
"""The value of charset part for Content-Type HTTP header."""
raw = self._headers.get(hdrs.CONTENT_TYPE)
if self._stored_content_type != raw:
self._parse_content_type(raw)
assert self._content_dict is not None
return self._content_dict.get("charset")
@property
def content_length(self) -> Optional[int]:
"""The value of Content-Length HTTP header."""
content_length = self._headers.get(hdrs.CONTENT_LENGTH)
return None if content_length is None else int(content_length)
def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
if not fut.done():
fut.set_result(result)
_EXC_SENTINEL = BaseException()
class ErrorableProtocol(Protocol):
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = ...,
) -> None: ... # pragma: no cover
def set_exception(
fut: "asyncio.Future[_T] | ErrorableProtocol",
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
"""Set future exception.
If the future is marked as complete, this function is a no-op.
:param exc_cause: An exception that is a direct cause of ``exc``.
Only set if provided.
"""
if asyncio.isfuture(fut) and fut.done():
return
exc_is_sentinel = exc_cause is _EXC_SENTINEL
exc_causes_itself = exc is exc_cause
if not exc_is_sentinel and not exc_causes_itself:
exc.__cause__ = exc_cause
fut.set_exception(exc)
@functools.total_ordering
class AppKey(Generic[_T]):
"""Keys for static typing support in Application."""
__slots__ = ("_name", "_t", "__orig_class__")
# This may be set by Python when instantiating with a generic type. We need to
# support this, in order to support types that are not concrete classes,
# like Iterable, which can't be passed as the second parameter to __init__.
__orig_class__: Type[object]
def __init__(self, name: str, t: Optional[Type[_T]] = None):
# Prefix with module name to help deduplicate key names.
frame = inspect.currentframe()
while frame:
if frame.f_code.co_name == "<module>":
module: str = frame.f_globals["__name__"]
break
frame = frame.f_back
self._name = module + "." + name
self._t = t
def __lt__(self, other: object) -> bool:
if isinstance(other, AppKey):
return self._name < other._name
return True # Order AppKey above other types.
def __repr__(self) -> str:
t = self._t
if t is None:
with suppress(AttributeError):
# Set to type arg.
t = get_args(self.__orig_class__)[0]
if t is None:
t_repr = "<<Unknown>>"
elif isinstance(t, type):
if t.__module__ == "builtins":
t_repr = t.__qualname__
else:
t_repr = f"{t.__module__}.{t.__qualname__}"
else:
t_repr = repr(t)
return f"<AppKey({self._name}, type={t_repr})>"
class ChainMapProxy(Mapping[Union[str, AppKey[Any]], Any]):
__slots__ = ("_maps",)
def __init__(self, maps: Iterable[Mapping[Union[str, AppKey[Any]], Any]]) -> None:
self._maps = tuple(maps)
def __init_subclass__(cls) -> None:
raise TypeError(
"Inheritance class {} from ChainMapProxy "
"is forbidden".format(cls.__name__)
)
@overload # type: ignore[override]
def __getitem__(self, key: AppKey[_T]) -> _T: ...
@overload
def __getitem__(self, key: str) -> Any: ...
def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any:
for mapping in self._maps:
try:
return mapping[key]
except KeyError:
pass
raise KeyError(key)
@overload # type: ignore[override]
def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]: ...
@overload
def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ...
@overload
def get(self, key: str, default: Any = ...) -> Any: ...
def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any:
try:
return self[key]
except KeyError:
return default
def __len__(self) -> int:
# reuses stored hash values if possible
return len(set().union(*self._maps))
def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]:
d: Dict[Union[str, AppKey[Any]], Any] = {}
for mapping in reversed(self._maps):
# reuses stored hash values if possible
d.update(mapping)
return iter(d)
def __contains__(self, key: object) -> bool:
return any(key in m for m in self._maps)
def __bool__(self) -> bool:
return any(self._maps)
def __repr__(self) -> str:
content = ", ".join(map(repr, self._maps))
return f"ChainMapProxy({content})"
# https://tools.ietf.org/html/rfc7232#section-2.3
_ETAGC = r"[!\x23-\x7E\x80-\xff]+"
_ETAGC_RE = re.compile(_ETAGC)
_QUOTED_ETAG = rf'(W/)?"({_ETAGC})"'
QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG)
LIST_QUOTED_ETAG_RE = re.compile(rf"({_QUOTED_ETAG})(?:\s*,\s*|$)|(.)")
ETAG_ANY = "*"
@attr.s(auto_attribs=True, frozen=True, slots=True)
class ETag:
value: str
is_weak: bool = False
def validate_etag_value(value: str) -> None:
if value != ETAG_ANY and not _ETAGC_RE.fullmatch(value):
raise ValueError(
f"Value {value!r} is not a valid etag. Maybe it contains '\"'?"
)
def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]:
"""Process a date string, return a datetime object"""
if date_str is not None:
timetuple = parsedate(date_str)
if timetuple is not None:
with suppress(ValueError):
return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc)
return None
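# --- Illustrative sketch (editorial addition) ---
# Parsing an RFC 822 style date as carried by Date/Last-Modified headers.
def _example_parse_http_date() -> None:
    dt = parse_http_date("Mon, 01 Jan 2024 00:00:00 GMT")
    assert dt is not None and (dt.year, dt.month) == (2024, 1)
    assert dt.tzinfo is datetime.timezone.utc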
@functools.lru_cache
def must_be_empty_body(method: str, code: int) -> bool:
"""Check if a request must return an empty body."""
return (
code in EMPTY_BODY_STATUS_CODES
or method in EMPTY_BODY_METHODS
or (200 <= code < 300 and method in hdrs.METH_CONNECT_ALL)
)
def should_remove_content_length(method: str, code: int) -> bool:
"""Check if a Content-Length header should be removed.
This should always be a subset of must_be_empty_body
"""
# https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-8
# https://www.rfc-editor.org/rfc/rfc9110.html#section-15.4.5-4
return code in EMPTY_BODY_STATUS_CODES or (
200 <= code < 300 and method in hdrs.METH_CONNECT_ALL
)
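# Illustrative sketch (not part of the module; assumes the EMPTY_BODY_*
# constants defined earlier follow RFC 9110, i.e. 1xx/204/304 responses and
# HEAD requests carry no body).
assert must_be_empty_body("HEAD", 200)
assert must_be_empty_body("GET", 304)
assert not must_be_empty_body("GET", 200)
assert should_remove_content_length("CONNECT", 200)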

View File

@@ -0,0 +1,72 @@
import sys
from http import HTTPStatus
from typing import Mapping, Tuple
from . import __version__
from .http_exceptions import HttpProcessingError as HttpProcessingError
from .http_parser import (
HeadersParser as HeadersParser,
HttpParser as HttpParser,
HttpRequestParser as HttpRequestParser,
HttpResponseParser as HttpResponseParser,
RawRequestMessage as RawRequestMessage,
RawResponseMessage as RawResponseMessage,
)
from .http_websocket import (
WS_CLOSED_MESSAGE as WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE as WS_CLOSING_MESSAGE,
WS_KEY as WS_KEY,
WebSocketError as WebSocketError,
WebSocketReader as WebSocketReader,
WebSocketWriter as WebSocketWriter,
WSCloseCode as WSCloseCode,
WSMessage as WSMessage,
WSMsgType as WSMsgType,
ws_ext_gen as ws_ext_gen,
ws_ext_parse as ws_ext_parse,
)
from .http_writer import (
HttpVersion as HttpVersion,
HttpVersion10 as HttpVersion10,
HttpVersion11 as HttpVersion11,
StreamWriter as StreamWriter,
)
__all__ = (
"HttpProcessingError",
"RESPONSES",
"SERVER_SOFTWARE",
# .http_writer
"StreamWriter",
"HttpVersion",
"HttpVersion10",
"HttpVersion11",
# .http_parser
"HeadersParser",
"HttpParser",
"HttpRequestParser",
"HttpResponseParser",
"RawRequestMessage",
"RawResponseMessage",
# .http_websocket
"WS_CLOSED_MESSAGE",
"WS_CLOSING_MESSAGE",
"WS_KEY",
"WebSocketReader",
"WebSocketWriter",
"ws_ext_gen",
"ws_ext_parse",
"WSMessage",
"WebSocketError",
"WSMsgType",
"WSCloseCode",
)
SERVER_SOFTWARE: str = "Python/{0[0]}.{0[1]} aiohttp/{1}".format(
sys.version_info, __version__
)
RESPONSES: Mapping[int, Tuple[str, str]] = {
v: (v.phrase, v.description) for v in HTTPStatus.__members__.values()
}
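# Illustrative sketch (not part of the module): RESPONSES maps each status
# code to its (phrase, description) pair from http.HTTPStatus.
_phrase, _description = RESPONSES[404]
assert _phrase == "Not Found"
print(SERVER_SOFTWARE)  # e.g. "Python/3.12 aiohttp/3.11.18"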

View File

@@ -0,0 +1,112 @@
"""Low-level http related exceptions."""
from textwrap import indent
from typing import Optional, Union
from .typedefs import _CIMultiDict
__all__ = ("HttpProcessingError",)
class HttpProcessingError(Exception):
"""HTTP error.
Shortcut for raising HTTP errors with custom code, message and headers.
code: HTTP Error code.
message: (optional) Error message.
headers: (optional) Headers to be sent in response, a list of pairs
"""
code = 0
message = ""
headers = None
def __init__(
self,
*,
code: Optional[int] = None,
message: str = "",
headers: Optional[_CIMultiDict] = None,
) -> None:
if code is not None:
self.code = code
self.headers = headers
self.message = message
def __str__(self) -> str:
msg = indent(self.message, " ")
return f"{self.code}, message:\n{msg}"
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: {self.code}, message={self.message!r}>"
class BadHttpMessage(HttpProcessingError):
code = 400
message = "Bad Request"
def __init__(self, message: str, *, headers: Optional[_CIMultiDict] = None) -> None:
super().__init__(message=message, headers=headers)
self.args = (message,)
class HttpBadRequest(BadHttpMessage):
code = 400
message = "Bad Request"
class PayloadEncodingError(BadHttpMessage):
"""Base class for payload errors"""
class ContentEncodingError(PayloadEncodingError):
"""Content encoding error."""
class TransferEncodingError(PayloadEncodingError):
"""transfer encoding error."""
class ContentLengthError(PayloadEncodingError):
"""Not enough data for satisfy content length header."""
class LineTooLong(BadHttpMessage):
def __init__(
self, line: str, limit: str = "Unknown", actual_size: str = "Unknown"
) -> None:
super().__init__(
f"Got more than {limit} bytes ({actual_size}) when reading {line}."
)
self.args = (line, limit, actual_size)
class InvalidHeader(BadHttpMessage):
def __init__(self, hdr: Union[bytes, str]) -> None:
hdr_s = hdr.decode(errors="backslashreplace") if isinstance(hdr, bytes) else hdr
super().__init__(f"Invalid HTTP header: {hdr!r}")
self.hdr = hdr_s
self.args = (hdr,)
class BadStatusLine(BadHttpMessage):
def __init__(self, line: str = "", error: Optional[str] = None) -> None:
if not isinstance(line, str):
line = repr(line)
super().__init__(error or f"Bad status line {line!r}")
self.args = (line,)
self.line = line
class BadHttpMethod(BadStatusLine):
"""Invalid HTTP method in status line."""
def __init__(self, line: str = "", error: Optional[str] = None) -> None:
super().__init__(line, error or f"Bad HTTP method in status line {line!r}")
class InvalidURLError(BadHttpMessage):
pass
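# Illustrative sketch (not part of the module): subclasses preset code and
# message, so call sites only supply the offending detail.
try:
    raise BadHttpMethod("GE T /index HTTP/1.1")
except BadStatusLine as exc:
    assert exc.code == 400
    assert "Bad HTTP method" in exc.message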

File diff suppressed because it is too large

View File

@@ -0,0 +1,36 @@
"""WebSocket protocol versions 13 and 8."""
from ._websocket.helpers import WS_KEY, ws_ext_gen, ws_ext_parse
from ._websocket.models import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
WebSocketError,
WSCloseCode,
WSHandshakeError,
WSMessage,
WSMsgType,
)
from ._websocket.reader import WebSocketReader
from ._websocket.writer import WebSocketWriter
# Messages that the WebSocketResponse.receive needs to handle internally
_INTERNAL_RECEIVE_TYPES = frozenset(
(WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.PING, WSMsgType.PONG)
)
__all__ = (
"WS_CLOSED_MESSAGE",
"WS_CLOSING_MESSAGE",
"WS_KEY",
"WebSocketReader",
"WebSocketWriter",
"WSMessage",
"WebSocketError",
"WSMsgType",
"WSCloseCode",
"ws_ext_gen",
"ws_ext_parse",
"WSHandshakeError",
"WSMessage",
)

View File

@@ -0,0 +1,249 @@
"""Http related parsers and protocol."""
import asyncio
import sys
import zlib
from typing import ( # noqa
Any,
Awaitable,
Callable,
Iterable,
List,
NamedTuple,
Optional,
Union,
)
from multidict import CIMultiDict
from .abc import AbstractStreamWriter
from .base_protocol import BaseProtocol
from .client_exceptions import ClientConnectionResetError
from .compression_utils import ZLibCompressor
from .helpers import NO_EXTENSIONS
__all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11")
MIN_PAYLOAD_FOR_WRITELINES = 2048
IS_PY313_BEFORE_313_2 = (3, 13, 0) <= sys.version_info < (3, 13, 2)
IS_PY_BEFORE_312_9 = sys.version_info < (3, 12, 9)
SKIP_WRITELINES = IS_PY313_BEFORE_313_2 or IS_PY_BEFORE_312_9
# writelines is not safe for use
# on Python 3.12+ until 3.12.9
# on Python 3.13+ until 3.13.2
# and on older versions it is not any faster than write
# CVE-2024-12254: https://github.com/python/cpython/pull/127656
class HttpVersion(NamedTuple):
major: int
minor: int
HttpVersion10 = HttpVersion(1, 0)
HttpVersion11 = HttpVersion(1, 1)
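# Illustrative sketch (not part of the module): NamedTuple gives HttpVersion
# plain tuple semantics, so versions compare and unpack lexicographically.
assert HttpVersion11 > HttpVersion10
_major, _minor = HttpVersion11
assert (_major, _minor) == (1, 1)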
_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
_T_OnHeadersSent = Optional[Callable[["CIMultiDict[str]"], Awaitable[None]]]
class StreamWriter(AbstractStreamWriter):
length: Optional[int] = None
chunked: bool = False
_eof: bool = False
_compress: Optional[ZLibCompressor] = None
def __init__(
self,
protocol: BaseProtocol,
loop: asyncio.AbstractEventLoop,
on_chunk_sent: _T_OnChunkSent = None,
on_headers_sent: _T_OnHeadersSent = None,
) -> None:
self._protocol = protocol
self.loop = loop
self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent
self._on_headers_sent: _T_OnHeadersSent = on_headers_sent
@property
def transport(self) -> Optional[asyncio.Transport]:
return self._protocol.transport
@property
def protocol(self) -> BaseProtocol:
return self._protocol
def enable_chunking(self) -> None:
self.chunked = True
def enable_compression(
self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
) -> None:
self._compress = ZLibCompressor(encoding=encoding, strategy=strategy)
def _write(self, chunk: Union[bytes, bytearray, memoryview]) -> None:
size = len(chunk)
self.buffer_size += size
self.output_size += size
transport = self._protocol.transport
if transport is None or transport.is_closing():
raise ClientConnectionResetError("Cannot write to closing transport")
transport.write(chunk)
def _writelines(self, chunks: Iterable[bytes]) -> None:
size = 0
for chunk in chunks:
size += len(chunk)
self.buffer_size += size
self.output_size += size
transport = self._protocol.transport
if transport is None or transport.is_closing():
raise ClientConnectionResetError("Cannot write to closing transport")
if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES:
transport.write(b"".join(chunks))
else:
transport.writelines(chunks)
async def write(
self,
chunk: Union[bytes, bytearray, memoryview],
*,
drain: bool = True,
LIMIT: int = 0x10000,
) -> None:
"""Writes chunk of data to a stream.
write_eof() indicates end of stream.
writer can't be used after write_eof() method being called.
write() return drain future.
"""
if self._on_chunk_sent is not None:
await self._on_chunk_sent(chunk)
if isinstance(chunk, memoryview):
if chunk.nbytes != len(chunk):
# just reshape it
chunk = chunk.cast("c")
if self._compress is not None:
chunk = await self._compress.compress(chunk)
if not chunk:
return
if self.length is not None:
chunk_len = len(chunk)
if self.length >= chunk_len:
self.length = self.length - chunk_len
else:
chunk = chunk[: self.length]
self.length = 0
if not chunk:
return
if chunk:
if self.chunked:
self._writelines(
(f"{len(chunk):x}\r\n".encode("ascii"), chunk, b"\r\n")
)
else:
self._write(chunk)
if self.buffer_size > LIMIT and drain:
self.buffer_size = 0
await self.drain()
async def write_headers(
self, status_line: str, headers: "CIMultiDict[str]"
) -> None:
"""Write request/response status and headers."""
if self._on_headers_sent is not None:
await self._on_headers_sent(headers)
# status + headers
buf = _serialize_headers(status_line, headers)
self._write(buf)
def set_eof(self) -> None:
"""Indicate that the message is complete."""
self._eof = True
async def write_eof(self, chunk: bytes = b"") -> None:
if self._eof:
return
if chunk and self._on_chunk_sent is not None:
await self._on_chunk_sent(chunk)
if self._compress:
chunks: List[bytes] = []
chunks_len = 0
if chunk and (compressed_chunk := await self._compress.compress(chunk)):
chunks_len = len(compressed_chunk)
chunks.append(compressed_chunk)
flush_chunk = self._compress.flush()
chunks_len += len(flush_chunk)
chunks.append(flush_chunk)
assert chunks_len
if self.chunked:
chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii")
self._writelines((chunk_len_pre, *chunks, b"\r\n0\r\n\r\n"))
elif len(chunks) > 1:
self._writelines(chunks)
else:
self._write(chunks[0])
elif self.chunked:
if chunk:
chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii")
self._writelines((chunk_len_pre, chunk, b"\r\n0\r\n\r\n"))
else:
self._write(b"0\r\n\r\n")
elif chunk:
self._write(chunk)
await self.drain()
self._eof = True
async def drain(self) -> None:
"""Flush the write buffer.
The intended use is to write
await w.write(data)
await w.drain()
"""
protocol = self._protocol
if protocol.transport is not None and protocol._paused:
await protocol._drain_helper()
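# Illustrative sketch (not part of the module; _frame_chunk is a hypothetical
# helper): the chunked framing produced by write() and write_eof() above
# prefixes each chunk with its hex length and terminates the body with
# "0\r\n\r\n".
def _frame_chunk(chunk: bytes) -> bytes:
    return f"{len(chunk):x}\r\n".encode("ascii") + chunk + b"\r\n"
assert _frame_chunk(b"hello") == b"5\r\nhello\r\n"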
def _safe_header(string: str) -> str:
if "\r" in string or "\n" in string:
raise ValueError(
"Newline or carriage return detected in headers. "
"Potential header injection attack."
)
return string
def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes:
headers_gen = (_safe_header(k) + ": " + _safe_header(v) for k, v in headers.items())
line = status_line + "\r\n" + "\r\n".join(headers_gen) + "\r\n\r\n"
return line.encode("utf-8")
_serialize_headers = _py_serialize_headers
try:
import aiohttp._http_writer as _http_writer # type: ignore[import-not-found]
_c_serialize_headers = _http_writer._serialize_headers
if not NO_EXTENSIONS:
_serialize_headers = _c_serialize_headers
except ImportError:
pass
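# Illustrative sketch (not part of the module; CIMultiDict is already imported
# above): headers serialize to a CRLF-delimited block, and _safe_header rejects
# header-injection attempts with ValueError.
_buf = _py_serialize_headers("GET / HTTP/1.1", CIMultiDict({"Host": "example.com"}))
assert _buf == b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"
try:
    _py_serialize_headers("GET / HTTP/1.1", CIMultiDict({"X-Evil": "a\r\nb"}))
except ValueError:
    pass  # injection detected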

View File

@@ -0,0 +1,8 @@
import logging
access_logger = logging.getLogger("aiohttp.access")
client_logger = logging.getLogger("aiohttp.client")
internal_logger = logging.getLogger("aiohttp.internal")
server_logger = logging.getLogger("aiohttp.server")
web_logger = logging.getLogger("aiohttp.web")
ws_logger = logging.getLogger("aiohttp.websocket")

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff