feat: WebSocket service for event and connection management

Davidson Gomes 2025-03-06 09:22:02 -03:00
parent 9c22621876
commit ad3d0de564
184 changed files with 29850 additions and 8 deletions

README.md

@@ -289,4 +289,219 @@ config = HandleLabel(
)
response = client.label.handle_label(instance_id, config, instance_token)
```
## WebSocket
The Evolution API client supports a WebSocket connection for receiving events in real time. Here is a guide on how to use it:

### Basic Configuration
```python
from evolutionapi.services.websocket import WebSocketManager
import logging

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

# WebSocket configuration
websocket = WebSocketManager(
    base_url="http://localhost:8081",  # URL of your Evolution API
    instance_id="teste",               # Your instance ID
    api_token="seu-token-aqui"         # Authentication token
)
```
### Registering Event Handlers
You can register handlers for different event types:
```python
def handle_message(data):
    print(f"New message received: {data}")

def handle_qrcode(data):
    print(f"QR code updated: {data}")

# Register the handlers
websocket.on("messages.upsert", handle_message)
websocket.on("qrcode.updated", handle_qrcode)
```
### Available Events
The available events are:

#### Instance Events
- `application.startup`: Fired when the application starts
- `instance.create`: Fired when a new instance is created
- `instance.delete`: Fired when an instance is deleted
- `remove.instance`: Fired when an instance is removed
- `logout.instance`: Fired when an instance logs out

#### Connection and QR Code Events
- `qrcode.updated`: Fired when the QR code is updated
- `connection.update`: Fired when the connection status changes
- `status.instance`: Fired when the instance status changes
- `creds.update`: Fired when credentials are updated

#### Message Events
- `messages.set`: Fired when messages are set
- `messages.upsert`: Fired when new messages are received
- `messages.edited`: Fired when messages are edited
- `messages.update`: Fired when messages are updated
- `messages.delete`: Fired when messages are deleted
- `send.message`: Fired when a message is sent
- `messaging-history.set`: Fired when the message history is set

#### Contact Events
- `contacts.set`: Fired when contacts are set
- `contacts.upsert`: Fired when new contacts are added
- `contacts.update`: Fired when contacts are updated

#### Chat Events
- `chats.set`: Fired when chats are set
- `chats.update`: Fired when chats are updated
- `chats.upsert`: Fired when new chats are added
- `chats.delete`: Fired when chats are deleted

#### Group Events
- `groups.upsert`: Fired when groups are created or updated
- `groups.update`: Fired when groups are updated
- `group-participants.update`: Fired when a group's participants are updated

#### Presence Events
- `presence.update`: Fired when the presence status is updated

#### Call Events
- `call`: Fired when there is a call

#### Typebot Events
- `typebot.start`: Fired when a typebot starts
- `typebot.change-status`: Fired when the typebot status changes

#### Label Events
- `labels.edit`: Fired when labels are edited
- `labels.association`: Fired when labels are associated or dissociated
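
If you want the same callback for several of these events (for example, to log everything while exploring the API), you can register it in a loop. This is a minimal sketch that assumes only the `websocket.on` registration call shown above; the event list is illustrative:

```python
import logging

logger = logging.getLogger(__name__)

# Events to observe; any of the names listed above would work here.
EVENTS_TO_LOG = [
    "messages.upsert",
    "connection.update",
    "qrcode.updated",
    "presence.update",
]

def log_event(data):
    # Generic handler: handlers receive only the event payload,
    # so this simply logs whatever arrives.
    logger.info(f"Event payload: {data}")

# Register the same handler for every event in the list
# (assumes `websocket` from the Basic Configuration section).
for event in EVENTS_TO_LOG:
    websocket.on(event, log_event)
```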
### Usage Example with Specific Events
```python
import logging

logger = logging.getLogger(__name__)

def handle_messages(data):
    logger.info(f"New message: {data}")

def handle_contacts(data):
    logger.info(f"Contacts updated: {data}")

def handle_groups(data):
    logger.info(f"Groups updated: {data}")

def handle_presence(data):
    logger.info(f"Presence status: {data}")

# Register handlers for the different events
websocket.on("messages.upsert", handle_messages)
websocket.on("contacts.upsert", handle_contacts)
websocket.on("groups.upsert", handle_groups)
websocket.on("presence.update", handle_presence)
```
### Complete Example
```python
from evolutionapi.services.websocket import WebSocketManager
import logging
import time

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

logger = logging.getLogger(__name__)

def handle_message(data):
    logger.info(f"New message received: {data}")

def handle_qrcode(data):
    logger.info(f"QR code updated: {data}")

def handle_connection(data):
    logger.info(f"Connection status: {data}")

def main():
    # Initialize the WebSocket manager
    websocket = WebSocketManager(
        base_url="http://localhost:8081",
        instance_id="teste",
        api_token="seu-token-aqui"
    )

    # Register the handlers
    websocket.on("messages.upsert", handle_message)
    websocket.on("qrcode.updated", handle_qrcode)
    websocket.on("connection.update", handle_connection)

    try:
        # Connect to the WebSocket
        websocket.connect()
        logger.info("Connected to the WebSocket. Waiting for events...")

        # Keep the program running
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        logger.info("Closing the connection...")
        websocket.disconnect()
    except Exception as e:
        logger.error(f"Error: {e}")
        websocket.disconnect()

if __name__ == "__main__":
    main()
```
### Additional Features

#### Automatic Reconnection
The WebSocket manager reconnects automatically, using exponential backoff:
```python
websocket = WebSocketManager(
    base_url="http://localhost:8081",
    instance_id="teste",
    api_token="seu-token-aqui",
    max_retries=5,     # Maximum number of reconnection attempts
    retry_delay=1.0    # Initial delay between attempts, in seconds
)
```
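
For intuition, exponential backoff typically doubles the wait after each failed attempt. A minimal sketch of how `max_retries` and `retry_delay` could combine under that assumption (the real retry logic lives inside `WebSocketManager`, so treat this purely as an illustration):

```python
max_retries = 5
retry_delay = 1.0  # seconds

# Hypothetical schedule: the delay doubles after each failed attempt.
for attempt in range(max_retries):
    delay = retry_delay * (2 ** attempt)
    print(f"Attempt {attempt + 1}: wait {delay:.1f}s before reconnecting")
# Prints waits of 1.0s, 2.0s, 4.0s, 8.0s, and 16.0s.
```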
#### Logging
The WebSocket manager uses Python's standard logging system. You can adjust the log level as needed:
```python
# For more detail
logging.getLogger("evolutionapi.services.websocket").setLevel(logging.DEBUG)
```
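
Because this is the standard library `logging` module, you can also attach your own handlers to that same logger. A small sketch using only the stdlib, writing the client's debug output to a file:

```python
import logging

ws_logger = logging.getLogger("evolutionapi.services.websocket")
ws_logger.setLevel(logging.DEBUG)

# Write the client's DEBUG-level records to a dedicated file.
file_handler = logging.FileHandler("websocket.log")
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
ws_logger.addHandler(file_handler)
```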
### Error Handling
The WebSocket manager includes robust error handling:
- Automatic reconnection on disconnect
- Detailed error logging
- Handling of invalid events
- Validation of received data
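
Even so, it is worth making your own handlers defensive, since an exception raised inside a callback could otherwise interrupt your event processing. A minimal sketch of that pattern; `handle_message` is the handler from the complete example above, and the dict check is an illustrative guard rather than a documented payload schema:

```python
import logging

logger = logging.getLogger(__name__)

def safe_handler(handler):
    # Wrap a handler so one bad payload cannot break event processing.
    def wrapper(data):
        try:
            if not isinstance(data, dict):
                logger.warning(f"Ignoring unexpected payload: {data!r}")
                return
            handler(data)
        except Exception:
            logger.exception("Handler raised while processing an event")
    return wrapper

websocket.on("messages.upsert", safe_handler(handle_message))
```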
### Usage Tips
1. Always use try/except when connecting to the WebSocket
2. Implement handlers for every event you need to monitor
3. Use logging for debugging and monitoring
4. Consider implementing a heartbeat mechanism if needed (see the sketch after this list)
5. Keep your API token secure and never expose it in logs
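
On tip 4, one simple option is an application-level heartbeat in a background thread. This sketch uses only the standard library and deliberately avoids assuming any client API beyond what is shown above; the interval is an arbitrary example value:

```python
import logging
import threading
import time

logger = logging.getLogger(__name__)

def start_heartbeat(interval_seconds=30.0):
    # Log a liveness message on a fixed interval from a daemon thread.
    def beat():
        while True:
            logger.info("heartbeat: client still running")
            time.sleep(interval_seconds)

    thread = threading.Thread(target=beat, daemon=True)
    thread.start()
    return thread

# Call after websocket.connect() so the periodic log confirms the
# process is alive while it waits for events.
start_heartbeat()
```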

env/bin/wsdump (vendored executable file)

@@ -0,0 +1,8 @@
#!/home/davidson/Projects/evolution_client/python/env/bin/python
# -*- coding: utf-8 -*-
import re
import sys
from websocket._wsdump import main
if __name__ == '__main__':
    sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
    sys.exit(main())

bidict-0.23.1.dist-info/INSTALLER

@@ -0,0 +1 @@
pip

bidict-0.23.1.dist-info/LICENSE

@@ -0,0 +1,376 @@
Mozilla Public License Version 2.0
==================================
Copyright 2009-2024 Joshua Bronson. All rights reserved.
1. Definitions
--------------
1.1. "Contributor"
means each individual or legal entity that creates, contributes to
the creation of, or owns Covered Software.
1.2. "Contributor Version"
means the combination of the Contributions of others (if any) used
by a Contributor and that particular Contributor's Contribution.
1.3. "Contribution"
means Covered Software of a particular Contributor.
1.4. "Covered Software"
means Source Code Form to which the initial Contributor has attached
the notice in Exhibit A, the Executable Form of such Source Code
Form, and Modifications of such Source Code Form, in each case
including portions thereof.
1.5. "Incompatible With Secondary Licenses"
means
(a) that the initial Contributor has attached the notice described
in Exhibit B to the Covered Software; or
(b) that the Covered Software was made available under the terms of
version 1.1 or earlier of the License, but not also under the
terms of a Secondary License.
1.6. "Executable Form"
means any form of the work other than Source Code Form.
1.7. "Larger Work"
means a work that combines Covered Software with other material, in
a separate file or files, that is not Covered Software.
1.8. "License"
means this document.
1.9. "Licensable"
means having the right to grant, to the maximum extent possible,
whether at the time of the initial grant or subsequently, any and
all of the rights conveyed by this License.
1.10. "Modifications"
means any of the following:
(a) any file in Source Code Form that results from an addition to,
deletion from, or modification of the contents of Covered
Software; or
(b) any new file in Source Code Form that contains any Covered
Software.
1.11. "Patent Claims" of a Contributor
means any patent claim(s), including without limitation, method,
process, and apparatus claims, in any patent Licensable by such
Contributor that would be infringed, but for the grant of the
License, by the making, using, selling, offering for sale, having
made, import, or transfer of either its Contributions or its
Contributor Version.
1.12. "Secondary License"
means either the GNU General Public License, Version 2.0, the GNU
Lesser General Public License, Version 2.1, the GNU Affero General
Public License, Version 3.0, or any later versions of those
licenses.
1.13. "Source Code Form"
means the form of the work preferred for making modifications.
1.14. "You" (or "Your")
means an individual or a legal entity exercising rights under this
License. For legal entities, "You" includes any entity that
controls, is controlled by, or is under common control with You. For
purposes of this definition, "control" means (a) the power, direct
or indirect, to cause the direction or management of such entity,
whether by contract or otherwise, or (b) ownership of more than
fifty percent (50%) of the outstanding shares or beneficial
ownership of such entity.
2. License Grants and Conditions
--------------------------------
2.1. Grants
Each Contributor hereby grants You a world-wide, royalty-free,
non-exclusive license:
(a) under intellectual property rights (other than patent or trademark)
Licensable by such Contributor to use, reproduce, make available,
modify, display, perform, distribute, and otherwise exploit its
Contributions, either on an unmodified basis, with Modifications, or
as part of a Larger Work; and
(b) under Patent Claims of such Contributor to make, use, sell, offer
for sale, have made, import, and otherwise transfer either its
Contributions or its Contributor Version.
2.2. Effective Date
The licenses granted in Section 2.1 with respect to any Contribution
become effective for each Contribution on the date the Contributor first
distributes such Contribution.
2.3. Limitations on Grant Scope
The licenses granted in this Section 2 are the only rights granted under
this License. No additional rights or licenses will be implied from the
distribution or licensing of Covered Software under this License.
Notwithstanding Section 2.1(b) above, no patent license is granted by a
Contributor:
(a) for any code that a Contributor has removed from Covered Software;
or
(b) for infringements caused by: (i) Your and any other third party's
modifications of Covered Software, or (ii) the combination of its
Contributions with other software (except as part of its Contributor
Version); or
(c) under Patent Claims infringed by Covered Software in the absence of
its Contributions.
This License does not grant any rights in the trademarks, service marks,
or logos of any Contributor (except as may be necessary to comply with
the notice requirements in Section 3.4).
2.4. Subsequent Licenses
No Contributor makes additional grants as a result of Your choice to
distribute the Covered Software under a subsequent version of this
License (see Section 10.2) or under the terms of a Secondary License (if
permitted under the terms of Section 3.3).
2.5. Representation
Each Contributor represents that the Contributor believes its
Contributions are its original creation(s) or it has sufficient rights
to grant the rights to its Contributions conveyed by this License.
2.6. Fair Use
This License is not intended to limit any rights You have under
applicable copyright doctrines of fair use, fair dealing, or other
equivalents.
2.7. Conditions
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
in Section 2.1.
3. Responsibilities
-------------------
3.1. Distribution of Source Form
All distribution of Covered Software in Source Code Form, including any
Modifications that You create or to which You contribute, must be under
the terms of this License. You must inform recipients that the Source
Code Form of the Covered Software is governed by the terms of this
License, and how they can obtain a copy of this License. You may not
attempt to alter or restrict the recipients' rights in the Source Code
Form.
3.2. Distribution of Executable Form
If You distribute Covered Software in Executable Form then:
(a) such Covered Software must also be made available in Source Code
Form, as described in Section 3.1, and You must inform recipients of
the Executable Form how they can obtain a copy of such Source Code
Form by reasonable means in a timely manner, at a charge no more
than the cost of distribution to the recipient; and
(b) You may distribute such Executable Form under the terms of this
License, or sublicense it under different terms, provided that the
license for the Executable Form does not attempt to limit or alter
the recipients' rights in the Source Code Form under this License.
3.3. Distribution of a Larger Work
You may create and distribute a Larger Work under terms of Your choice,
provided that You also comply with the requirements of this License for
the Covered Software. If the Larger Work is a combination of Covered
Software with a work governed by one or more Secondary Licenses, and the
Covered Software is not Incompatible With Secondary Licenses, this
License permits You to additionally distribute such Covered Software
under the terms of such Secondary License(s), so that the recipient of
the Larger Work may, at their option, further distribute the Covered
Software under the terms of either this License or such Secondary
License(s).
3.4. Notices
You may not remove or alter the substance of any license notices
(including copyright notices, patent notices, disclaimers of warranty,
or limitations of liability) contained within the Source Code Form of
the Covered Software, except that You may alter any license notices to
the extent required to remedy known factual inaccuracies.
3.5. Application of Additional Terms
You may choose to offer, and to charge a fee for, warranty, support,
indemnity or liability obligations to one or more recipients of Covered
Software. However, You may do so only on Your own behalf, and not on
behalf of any Contributor. You must make it absolutely clear that any
such warranty, support, indemnity, or liability obligation is offered by
You alone, and You hereby agree to indemnify every Contributor for any
liability incurred by such Contributor as a result of warranty, support,
indemnity or liability terms You offer. You may include additional
disclaimers of warranty and limitations of liability specific to any
jurisdiction.
4. Inability to Comply Due to Statute or Regulation
---------------------------------------------------
If it is impossible for You to comply with any of the terms of this
License with respect to some or all of the Covered Software due to
statute, judicial order, or regulation then You must: (a) comply with
the terms of this License to the maximum extent possible; and (b)
describe the limitations and the code they affect. Such description must
be placed in a text file included with all distributions of the Covered
Software under this License. Except to the extent prohibited by statute
or regulation, such description must be sufficiently detailed for a
recipient of ordinary skill to be able to understand it.
5. Termination
--------------
5.1. The rights granted under this License will terminate automatically
if You fail to comply with any of its terms. However, if You become
compliant, then the rights granted under this License from a particular
Contributor are reinstated (a) provisionally, unless and until such
Contributor explicitly and finally terminates Your grants, and (b) on an
ongoing basis, if such Contributor fails to notify You of the
non-compliance by some reasonable means prior to 60 days after You have
come back into compliance. Moreover, Your grants from a particular
Contributor are reinstated on an ongoing basis if such Contributor
notifies You of the non-compliance by some reasonable means, this is the
first time You have received notice of non-compliance with this License
from such Contributor, and You become compliant prior to 30 days after
Your receipt of the notice.
5.2. If You initiate litigation against any entity by asserting a patent
infringement claim (excluding declaratory judgment actions,
counter-claims, and cross-claims) alleging that a Contributor Version
directly or indirectly infringes any patent, then the rights granted to
You by any and all Contributors for the Covered Software under Section
2.1 of this License shall terminate.
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
end user license agreements (excluding distributors and resellers) which
have been validly granted by You or Your distributors under this License
prior to termination shall survive termination.
************************************************************************
* *
* 6. Disclaimer of Warranty *
* ------------------------- *
* *
* Covered Software is provided under this License on an "as is" *
* basis, without warranty of any kind, either expressed, implied, or *
* statutory, including, without limitation, warranties that the *
* Covered Software is free of defects, merchantable, fit for a *
* particular purpose or non-infringing. The entire risk as to the *
* quality and performance of the Covered Software is with You. *
* Should any Covered Software prove defective in any respect, You *
* (not any Contributor) assume the cost of any necessary servicing, *
* repair, or correction. This disclaimer of warranty constitutes an *
* essential part of this License. No use of any Covered Software is *
* authorized under this License except under this disclaimer. *
* *
************************************************************************
************************************************************************
* *
* 7. Limitation of Liability *
* -------------------------- *
* *
* Under no circumstances and under no legal theory, whether tort *
* (including negligence), contract, or otherwise, shall any *
* Contributor, or anyone who distributes Covered Software as *
* permitted above, be liable to You for any direct, indirect, *
* special, incidental, or consequential damages of any character *
* including, without limitation, damages for lost profits, loss of *
* goodwill, work stoppage, computer failure or malfunction, or any *
* and all other commercial damages or losses, even if such party *
* shall have been informed of the possibility of such damages. This *
* limitation of liability shall not apply to liability for death or *
* personal injury resulting from such party's negligence to the *
* extent applicable law prohibits such limitation. Some *
* jurisdictions do not allow the exclusion or limitation of *
* incidental or consequential damages, so this exclusion and *
* limitation may not apply to You. *
* *
************************************************************************
8. Litigation
-------------
Any litigation relating to this License may be brought only in the
courts of a jurisdiction where the defendant maintains its principal
place of business and such litigation shall be governed by laws of that
jurisdiction, without reference to its conflict-of-law provisions.
Nothing in this Section shall prevent a party's ability to bring
cross-claims or counter-claims.
9. Miscellaneous
----------------
This License represents the complete agreement concerning the subject
matter hereof. If any provision of this License is held to be
unenforceable, such provision shall be reformed only to the extent
necessary to make it enforceable. Any law or regulation which provides
that the language of a contract shall be construed against the drafter
shall not be used to construe this License against a Contributor.
10. Versions of the License
---------------------------
10.1. New Versions
Mozilla Foundation is the license steward. Except as provided in Section
10.3, no one other than the license steward has the right to modify or
publish new versions of this License. Each version will be given a
distinguishing version number.
10.2. Effect of New Versions
You may distribute the Covered Software under the terms of the version
of the License under which You originally received the Covered Software,
or under the terms of any subsequent version published by the license
steward.
10.3. Modified Versions
If you create software not governed by this License, and you want to
create a new license for such software, you may create and use a
modified version of this License if you rename the license and remove
any references to the name of the license steward (except to note that
such modified license differs from this License).
10.4. Distributing Source Code Form that is Incompatible With Secondary
Licenses
If You choose to distribute Source Code Form that is Incompatible With
Secondary Licenses under the terms of this version of the License, the
notice described in Exhibit B of this License must be attached.
Exhibit A - Source Code Form License Notice
-------------------------------------------
This Source Code Form is subject to the terms of the Mozilla Public
License, v. 2.0. If a copy of the MPL was not distributed with this
file, You can obtain one at http://mozilla.org/MPL/2.0/.
If it is not possible or desirable to put the notice in a particular
file, then You may include the notice in a location (such as a LICENSE
file in a relevant directory) where a recipient would be likely to look
for such a notice.
You may add additional accurate notices of copyright ownership.
Exhibit B - "Incompatible With Secondary Licenses" Notice
---------------------------------------------------------
This Source Code Form is "Incompatible With Secondary Licenses", as
defined by the Mozilla Public License, v. 2.0.

bidict-0.23.1.dist-info/METADATA

@@ -0,0 +1,260 @@
Metadata-Version: 2.1
Name: bidict
Version: 0.23.1
Summary: The bidirectional mapping library for Python.
Author-email: Joshua Bronson <jabronson@gmail.com>
License: MPL 2.0
Project-URL: Changelog, https://bidict.readthedocs.io/changelog.html
Project-URL: Documentation, https://bidict.readthedocs.io
Project-URL: Funding, https://bidict.readthedocs.io/#sponsoring
Project-URL: Repository, https://github.com/jab/bidict
Keywords: bidict,bimap,bidirectional,dict,dictionary,mapping,collections
Classifier: License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: Implementation :: CPython
Classifier: Programming Language :: Python :: Implementation :: PyPy
Classifier: Typing :: Typed
Requires-Python: >=3.8
Description-Content-Type: text/x-rst
License-File: LICENSE
.. role:: doc
.. (Forward declaration for the "doc" role that Sphinx defines for interop with renderers that
are often used to show this doc and that are unaware of Sphinx (GitHub.com, PyPI.org, etc.).
Use :doc: rather than :ref: here for better interop as well.)
bidict
======
*The bidirectional mapping library for Python.*
Status
------
.. image:: https://img.shields.io/pypi/v/bidict.svg
:target: https://pypi.org/project/bidict
:alt: Latest release
.. image:: https://img.shields.io/readthedocs/bidict/main.svg
:target: https://bidict.readthedocs.io/en/main/
:alt: Documentation
.. image:: https://github.com/jab/bidict/actions/workflows/test.yml/badge.svg
:target: https://github.com/jab/bidict/actions/workflows/test.yml?query=branch%3Amain
:alt: GitHub Actions CI status
.. image:: https://img.shields.io/pypi/l/bidict.svg
:target: https://raw.githubusercontent.com/jab/bidict/main/LICENSE
:alt: License
.. image:: https://static.pepy.tech/badge/bidict
:target: https://pepy.tech/project/bidict
:alt: PyPI Downloads
.. image:: https://img.shields.io/badge/GitHub-sponsor-ff69b4
:target: https://github.com/sponsors/jab
:alt: Sponsor
Features
--------
- Mature: Depended on by
Google, Venmo, CERN, Baidu, Tencent,
and teams across the world since 2009
- Familiar, Pythonic APIs
that are carefully designed for
safety, simplicity, flexibility, and ergonomics
- Lightweight, with no runtime dependencies
outside Python's standard library
- Implemented in
concise, well-factored, fully type-hinted Python code
that is optimized for running efficiently
as well as for long-term maintenance and stability
(as well as `joy <#learning-from-bidict>`__)
- Extensively `documented <https://bidict.readthedocs.io>`__
- 100% test coverage
running continuously across all supported Python versions
(including property-based tests and benchmarks)
Installation
------------
``pip install bidict``
Quick Start
-----------
.. code:: python
>>> from bidict import bidict
>>> element_by_symbol = bidict({'H': 'hydrogen'})
>>> element_by_symbol['H']
'hydrogen'
>>> element_by_symbol.inverse['hydrogen']
'H'
For more usage documentation,
head to the :doc:`intro` [#fn-intro]_
and proceed from there.
Enterprise Support
------------------
Enterprise-level support for bidict can be obtained via the
`Tidelift subscription <https://tidelift.com/subscription/pkg/pypi-bidict?utm_source=pypi-bidict&utm_medium=referral&utm_campaign=readme>`__
or by `contacting me directly <mailto:jabronson@gmail.com>`__.
I have a US-based LLC set up for invoicing,
and I have 15+ years of professional experience
delivering software and support to companies successfully.
You can also sponsor my work through several platforms, including GitHub Sponsors.
See the `Sponsoring <#sponsoring>`__ section below for details,
including rationale and examples of companies
supporting the open source projects they depend on.
Voluntary Community Support
---------------------------
Please search through already-asked questions and answers
in `GitHub Discussions <https://github.com/jab/bidict/discussions>`__
and the `issue tracker <https://github.com/jab/bidict/issues?q=is%3Aissue>`__
in case your question has already been addressed.
Otherwise, please feel free to
`start a new discussion <https://github.com/jab/bidict/discussions>`__
or `create a new issue <https://github.com/jab/bidict/issues/new>`__ on GitHub
for voluntary community support.
Notice of Usage
---------------
If you use bidict,
and especially if your usage or your organization is significant in some way,
please let me know in any of the following ways:
- `star bidict on GitHub <https://github.com/jab/bidict>`__
- post in `GitHub Discussions <https://github.com/jab/bidict/discussions>`__
- `email me <mailto:jabronson@gmail.com>`__
Changelog
---------
For bidict release notes, see the :doc:`changelog`. [#fn-changelog]_
Release Notifications
---------------------
.. duplicated in CHANGELOG.rst:
(Would use `.. include::` but GitHub's renderer doesn't support it.)
Watch `bidict releases on GitHub <https://github.com/jab/bidict/releases>`__
to be notified when new versions of bidict are published.
Click the "Watch" dropdown, choose "Custom", and then choose "Releases".
Learning from bidict
--------------------
One of the best things about bidict
is that it touches a surprising number of
interesting Python corners,
especially given its small size and scope.
Check out :doc:`learning-from-bidict` [#fn-learning]_
if you're interested in learning more.
Contributing
------------
I have been bidict's sole maintainer
and `active contributor <https://github.com/jab/bidict/graphs/contributors>`__
since I started the project ~15 years ago.
Your help would be most welcome!
See the :doc:`contributors-guide` [#fn-contributing]_
for more information.
Sponsoring
----------
.. duplicated in CONTRIBUTING.rst
(Would use `.. include::` but GitHub's renderer doesn't support it.)
.. image:: https://img.shields.io/badge/GitHub-sponsor-ff69b4
:target: https://github.com/sponsors/jab
:alt: Sponsor through GitHub
Bidict is the product of thousands of hours of my unpaid work
over the 15+ years that I've been the sole maintainer.
If bidict has helped you or your company accomplish your work,
please sponsor my work through one of the following,
and/or ask your company to do the same:
- `GitHub <https://github.com/sponsors/jab>`__
- `PayPal <https://www.paypal.com/cgi-bin/webscr?cmd=_xclick&business=jabronson%40gmail%2ecom&lc=US&item_name=Sponsor%20bidict>`__
- `Tidelift <https://tidelift.com>`__
- `thanks.dev <https://thanks.dev>`__
- `Gumroad <https://gumroad.com/l/bidict>`__
- `a support engagement with my LLC <#enterprise-support>`__
If you're not sure which to use, GitHub is an easy option,
especially if you already have a GitHub account.
Just choose a monthly or one-time amount, and GitHub handles everything else.
Your bidict sponsorship on GitHub will automatically go
on the same regular bill as any other GitHub charges you pay for.
PayPal is another easy option for one-time contributions.
See the following for rationale and examples of companies
supporting the open source projects they depend on
in this manner:
- `<https://engineering.atspotify.com/2022/04/announcing-the-spotify-foss-fund/>`__
- `<https://blog.sentry.io/2021/10/21/we-just-gave-154-999-dollars-and-89-cents-to-open-source-maintainers>`__
- `<https://engineering.indeedblog.com/blog/2019/07/foss-fund-six-months-in/>`__
.. - `<https://sethmlarson.dev/blog/people-in-your-software-supply-chain>`__
.. - `<https://www.cognitect.com/blog/supporting-open-source-developers>`__
.. - `<https://vorpus.org/blog/the-unreasonable-effectiveness-of-investment-in-open-source-infrastructure/>`__
Finding Documentation
---------------------
If you're viewing this on `<https://bidict.readthedocs.io>`__,
note that multiple versions of the documentation are available,
and you can choose a different version using the popup menu at the bottom-right.
Please make sure you're viewing the version of the documentation
that corresponds to the version of bidict you'd like to use.
If you're viewing this on GitHub, PyPI, or some other place
that can't render and link this documentation properly
and are seeing broken links,
try these alternate links instead:
.. [#fn-intro] `<https://bidict.readthedocs.io/intro.html>`__ | `<docs/intro.rst>`__
.. [#fn-changelog] `<https://bidict.readthedocs.io/changelog.html>`__ | `<CHANGELOG.rst>`__
.. [#fn-learning] `<https://bidict.readthedocs.io/learning-from-bidict.html>`__ | `<docs/learning-from-bidict.rst>`__
.. [#fn-contributing] `<https://bidict.readthedocs.io/contributors-guide.html>`__ | `<CONTRIBUTING.rst>`__

bidict-0.23.1.dist-info/RECORD

@@ -0,0 +1,31 @@
bidict-0.23.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
bidict-0.23.1.dist-info/LICENSE,sha256=8_U63OyqSNc6ZuI4-lupBstBh2eDtF0ooTRrMULuvZo,16784
bidict-0.23.1.dist-info/METADATA,sha256=2ovIRm6Df8gdwAMekGqkeBSF5TWj2mv1jpmh4W4ks7o,8704
bidict-0.23.1.dist-info/RECORD,,
bidict-0.23.1.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
bidict-0.23.1.dist-info/top_level.txt,sha256=WuQO02jp0ODioS7sJoaHg3JJ5_3h6Sxo9RITvNGPYmc,7
bidict/__init__.py,sha256=pL87KsrDpBsl3AG09LQk1t1TSFt0hIJVYa2POMdErN8,4398
bidict/__pycache__/__init__.cpython-310.pyc,,
bidict/__pycache__/_abc.cpython-310.pyc,,
bidict/__pycache__/_base.cpython-310.pyc,,
bidict/__pycache__/_bidict.cpython-310.pyc,,
bidict/__pycache__/_dup.cpython-310.pyc,,
bidict/__pycache__/_exc.cpython-310.pyc,,
bidict/__pycache__/_frozen.cpython-310.pyc,,
bidict/__pycache__/_iter.cpython-310.pyc,,
bidict/__pycache__/_orderedbase.cpython-310.pyc,,
bidict/__pycache__/_orderedbidict.cpython-310.pyc,,
bidict/__pycache__/_typing.cpython-310.pyc,,
bidict/__pycache__/metadata.cpython-310.pyc,,
bidict/_abc.py,sha256=SMCNdCsmqSWg0OGnMZtnnXY8edjXcyZup5tva4HBm_c,3172
bidict/_base.py,sha256=YiauA0aj52fNB6cfZ4gBt6OV-CRQoZm7WVhuw1nT-Cg,24439
bidict/_bidict.py,sha256=Sr-RoEzWOaxpnDRbDJ7ngaGRIsyGnqZgzvR-NyT4jl4,6923
bidict/_dup.py,sha256=YAn5gWA6lwMBA5A6ebVF19UTZyambGS8WxmbK4TN1Ww,2079
bidict/_exc.py,sha256=HnD_WgteI5PrXa3zBx9RUiGlgnZTO6CF4nIU9p3-njk,1066
bidict/_frozen.py,sha256=p4TaRHKeyTs0KmlpwSnZiTlN_CR4J97kAgBpNdZHQMs,1771
bidict/_iter.py,sha256=zVUx-hJ1M4YuJROoFWRjPKlcaFnyo1AAuRpOaKAFhOQ,1530
bidict/_orderedbase.py,sha256=M7v5rHa7vrym9Z3DxQBFQDxjnrr39Z8p26V0c1PggoE,8942
bidict/_orderedbidict.py,sha256=pPnmC19mIISrj8_yjnb-4r_ti1B74tD5eTd08DETNuI,7080
bidict/_typing.py,sha256=AylMZpBhEFTQegfziPSxfKkKLk7oUsH6o3awDIg2z_k,1289
bidict/metadata.py,sha256=BMIKu6fBY_OKeV_q48EpumE7MdmFw8rFcdaUz8kcIYk,573
bidict/py.typed,sha256=RJao5SVFYIp8IfbxhL_SpZkBQYe3XXzPlobSRdh4B_c,16

bidict-0.23.1.dist-info/WHEEL

@@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: bdist_wheel (0.42.0)
Root-Is-Purelib: true
Tag: py3-none-any

bidict-0.23.1.dist-info/top_level.txt

@@ -0,0 +1 @@
bidict

bidict/__init__.py

@@ -0,0 +1,103 @@
# Copyright 2009-2024 Joshua Bronson. All rights reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# ============================================================================
# * Welcome to the bidict source code *
# ============================================================================
# Reading through the code? You'll find a "Code review nav" comment like the one
# below at the top and bottom of the key source files. Follow these cues to take
# a path through the code that's optimized for familiarizing yourself with it.
#
# If you're not reading this on https://github.com/jab/bidict already, go there
# to ensure you have the latest version of the code. While there, you can also
# star the project, watch it for updates, fork the code, and submit an issue or
# pull request with any proposed changes. More information can be found linked
# from README.rst, which is also shown on https://github.com/jab/bidict.
# * Code review nav *
# ============================================================================
# Current: __init__.py Next: _abc.py →
# ============================================================================
"""The bidirectional mapping library for Python.
----
bidict by example:
.. code-block:: python
>>> from bidict import bidict
>>> element_by_symbol = bidict({'H': 'hydrogen'})
>>> element_by_symbol['H']
'hydrogen'
>>> element_by_symbol.inverse['hydrogen']
'H'
Please see https://github.com/jab/bidict for the most up-to-date code and
https://bidict.readthedocs.io for the most up-to-date documentation
if you are reading this elsewhere.
----
.. :copyright: (c) 2009-2024 Joshua Bronson.
.. :license: MPLv2. See LICENSE for details.
"""
# Use private aliases to not re-export these publicly (for Sphinx automodule with imported-members).
from __future__ import annotations as _annotations
from contextlib import suppress as _suppress
from ._abc import BidirectionalMapping as BidirectionalMapping
from ._abc import MutableBidirectionalMapping as MutableBidirectionalMapping
from ._base import BidictBase as BidictBase
from ._base import BidictKeysView as BidictKeysView
from ._base import GeneratedBidictInverse as GeneratedBidictInverse
from ._bidict import MutableBidict as MutableBidict
from ._bidict import bidict as bidict
from ._dup import DROP_NEW as DROP_NEW
from ._dup import DROP_OLD as DROP_OLD
from ._dup import ON_DUP_DEFAULT as ON_DUP_DEFAULT
from ._dup import ON_DUP_DROP_OLD as ON_DUP_DROP_OLD
from ._dup import ON_DUP_RAISE as ON_DUP_RAISE
from ._dup import RAISE as RAISE
from ._dup import OnDup as OnDup
from ._dup import OnDupAction as OnDupAction
from ._exc import BidictException as BidictException
from ._exc import DuplicationError as DuplicationError
from ._exc import KeyAndValueDuplicationError as KeyAndValueDuplicationError
from ._exc import KeyDuplicationError as KeyDuplicationError
from ._exc import ValueDuplicationError as ValueDuplicationError
from ._frozen import frozenbidict as frozenbidict
from ._iter import inverted as inverted
from ._orderedbase import OrderedBidictBase as OrderedBidictBase
from ._orderedbidict import OrderedBidict as OrderedBidict
from .metadata import __author__ as __author__
from .metadata import __copyright__ as __copyright__
from .metadata import __description__ as __description__
from .metadata import __license__ as __license__
from .metadata import __url__ as __url__
from .metadata import __version__ as __version__
# Set __module__ of re-exported classes to the 'bidict' top-level module, so that e.g.
# 'bidict.bidict' shows up as 'bidict.bidict' rather than 'bidict._bidict.bidict'.
for _obj in tuple(locals().values()):  # pragma: no cover
    if not getattr(_obj, '__module__', '').startswith('bidict.'):
        continue
    with _suppress(AttributeError):
        _obj.__module__ = 'bidict'
# * Code review nav *
# ============================================================================
# Current: __init__.py Next: _abc.py →
# ============================================================================

bidict/_abc.py

@@ -0,0 +1,79 @@
# Copyright 2009-2024 Joshua Bronson. All rights reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# * Code review nav *
# (see comments in __init__.py)
# ============================================================================
# ← Prev: __init__.py Current: _abc.py Next: _base.py →
# ============================================================================
"""Provide the :class:`BidirectionalMapping` abstract base class."""
from __future__ import annotations
import typing as t
from abc import abstractmethod
from ._typing import KT
from ._typing import VT
class BidirectionalMapping(t.Mapping[KT, VT]):
    """Abstract base class for bidirectional mapping types.

    Extends :class:`collections.abc.Mapping` primarily by adding the
    (abstract) :attr:`inverse` property,
    which implementers of :class:`BidirectionalMapping`
    should override to return a reference to the inverse
    :class:`BidirectionalMapping` instance.
    """

    __slots__ = ()

    @property
    @abstractmethod
    def inverse(self) -> BidirectionalMapping[VT, KT]:
        """The inverse of this bidirectional mapping instance.

        *See also* :attr:`bidict.BidictBase.inverse`, :attr:`bidict.BidictBase.inv`

        :raises NotImplementedError: Meant to be overridden in subclasses.
        """
        # The @abstractmethod decorator prevents subclasses from being instantiated unless they
        # override this method. But an overriding implementation may merely return super().inverse,
        # in which case this implementation is used. Raise NotImplementedError to indicate that
        # subclasses must actually provide their own implementation.
        raise NotImplementedError

    def __inverted__(self) -> t.Iterator[tuple[VT, KT]]:
        """Get an iterator over the items in :attr:`inverse`.

        This is functionally equivalent to iterating over the items in the
        forward mapping and inverting each one on the fly, but this provides a
        more efficient implementation: Assuming the already-inverted items
        are stored in :attr:`inverse`, just return an iterator over them directly.

        Providing this default implementation enables external functions,
        particularly :func:`~bidict.inverted`, to use this optimized
        implementation when available, instead of having to invert on the fly.

        *See also* :func:`bidict.inverted`
        """
        return iter(self.inverse.items())


class MutableBidirectionalMapping(BidirectionalMapping[KT, VT], t.MutableMapping[KT, VT]):
    """Abstract base class for mutable bidirectional mapping types."""

    __slots__ = ()
# * Code review nav *
# ============================================================================
# ← Prev: __init__.py Current: _abc.py Next: _base.py →
# ============================================================================

bidict/_base.py

@@ -0,0 +1,556 @@
# Copyright 2009-2024 Joshua Bronson. All rights reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# * Code review nav *
# (see comments in __init__.py)
# ============================================================================
# ← Prev: _abc.py Current: _base.py Next: _frozen.py →
# ============================================================================
"""Provide :class:`BidictBase`."""
from __future__ import annotations
import typing as t
import weakref
from itertools import starmap
from operator import eq
from types import MappingProxyType
from ._abc import BidirectionalMapping
from ._dup import DROP_NEW
from ._dup import DROP_OLD
from ._dup import ON_DUP_DEFAULT
from ._dup import RAISE
from ._dup import OnDup
from ._exc import DuplicationError
from ._exc import KeyAndValueDuplicationError
from ._exc import KeyDuplicationError
from ._exc import ValueDuplicationError
from ._iter import inverted
from ._iter import iteritems
from ._typing import KT
from ._typing import MISSING
from ._typing import OKT
from ._typing import OVT
from ._typing import VT
from ._typing import Maplike
from ._typing import MapOrItems
OldKV = t.Tuple[OKT[KT], OVT[VT]]
DedupResult = t.Optional[OldKV[KT, VT]]
Unwrites = t.List[t.Tuple[t.Any, ...]]
BT = t.TypeVar('BT', bound='BidictBase[t.Any, t.Any]')
class BidictKeysView(t.KeysView[KT], t.ValuesView[KT]):
"""Since the keys of a bidict are the values of its inverse (and vice versa),
the :class:`~collections.abc.ValuesView` result of calling *bi.values()*
is also a :class:`~collections.abc.KeysView` of *bi.inverse*.
"""
class BidictBase(BidirectionalMapping[KT, VT]):
"""Base class implementing :class:`BidirectionalMapping`."""
#: The default :class:`~bidict.OnDup`
#: that governs behavior when a provided item
#: duplicates the key or value of other item(s).
#:
#: *See also*
#: :ref:`basic-usage:Values Must Be Unique` (https://bidict.rtfd.io/basic-usage.html#values-must-be-unique),
#: :doc:`extending` (https://bidict.rtfd.io/extending.html)
on_dup = ON_DUP_DEFAULT
_fwdm: t.MutableMapping[KT, VT] #: the backing forward mapping (*key* → *val*)
_invm: t.MutableMapping[VT, KT] #: the backing inverse mapping (*val* → *key*)
# Use Any rather than KT/VT in the following to avoid "ClassVar cannot contain type variables" errors:
_fwdm_cls: t.ClassVar[type[t.MutableMapping[t.Any, t.Any]]] = dict #: class of the backing forward mapping
_invm_cls: t.ClassVar[type[t.MutableMapping[t.Any, t.Any]]] = dict #: class of the backing inverse mapping
#: The class of the inverse bidict instance.
_inv_cls: t.ClassVar[type[BidictBase[t.Any, t.Any]]]
def __init_subclass__(cls) -> None:
super().__init_subclass__()
cls._init_class()
@classmethod
def _init_class(cls) -> None:
cls._ensure_inv_cls()
cls._set_reversed()
__reversed__: t.ClassVar[t.Any]
@classmethod
def _set_reversed(cls) -> None:
"""Set __reversed__ for subclasses that do not set it explicitly
according to whether backing mappings are reversible.
"""
if cls is not BidictBase:
resolved = cls.__reversed__
overridden = resolved is not BidictBase.__reversed__
if overridden: # E.g. OrderedBidictBase, OrderedBidict
return
backing_reversible = all(issubclass(i, t.Reversible) for i in (cls._fwdm_cls, cls._invm_cls))
cls.__reversed__ = _fwdm_reversed if backing_reversible else None
@classmethod
def _ensure_inv_cls(cls) -> None:
"""Ensure :attr:`_inv_cls` is set, computing it dynamically if necessary.
All subclasses provided in :mod:`bidict` are their own inverse classes,
i.e., their backing forward and inverse mappings are both the same type,
but users may define subclasses where this is not the case.
This method ensures that the inverse class is computed correctly regardless.
See: :ref:`extending:Dynamic Inverse Class Generation`
(https://bidict.rtfd.io/extending.html#dynamic-inverse-class-generation)
"""
# This _ensure_inv_cls() method is (indirectly) corecursive with _make_inv_cls() below
# in the case that we need to dynamically generate the inverse class:
# 1. _ensure_inv_cls() calls cls._make_inv_cls()
# 2. cls._make_inv_cls() calls type(..., (cls, ...), ...) to dynamically generate inv_cls
# 3. Our __init_subclass__ hook (see above) is automatically called on inv_cls
# 4. inv_cls.__init_subclass__() calls inv_cls._ensure_inv_cls()
# 5. inv_cls._ensure_inv_cls() resolves to this implementation
# (inv_cls deliberately does not override this), so we're back where we started.
# But since the _make_inv_cls() call will have set inv_cls.__dict__._inv_cls,
# just check if it's already set before calling _make_inv_cls() to prevent infinite recursion.
if getattr(cls, '__dict__', {}).get('_inv_cls'): # Don't assume cls.__dict__ (e.g. mypyc native class)
return
cls._inv_cls = cls._make_inv_cls()
@classmethod
def _make_inv_cls(cls: type[BT]) -> type[BT]:
diff = cls._inv_cls_dict_diff()
cls_is_own_inv = all(getattr(cls, k, MISSING) == v for (k, v) in diff.items())
if cls_is_own_inv:
return cls
# Suppress auto-calculation of _inv_cls's _inv_cls since we know it already.
# Works with the guard in BidictBase._ensure_inv_cls() to prevent infinite recursion.
diff['_inv_cls'] = cls
inv_cls = type(f'{cls.__name__}Inv', (cls, GeneratedBidictInverse), diff)
inv_cls.__module__ = cls.__module__
return t.cast(t.Type[BT], inv_cls)
@classmethod
def _inv_cls_dict_diff(cls) -> dict[str, t.Any]:
return {
'_fwdm_cls': cls._invm_cls,
'_invm_cls': cls._fwdm_cls,
}
def __init__(self, arg: MapOrItems[KT, VT] = (), /, **kw: VT) -> None:
"""Make a new bidirectional mapping.
The signature behaves like that of :class:`dict`.
Items passed via the positional arg are processed first,
followed by any items passed via keyword argument.
Any duplication encountered along the way
is handled as per :attr:`on_dup`.
"""
self._fwdm = self._fwdm_cls()
self._invm = self._invm_cls()
self._update(arg, kw, rollback=False)
# If Python ever adds support for higher-kinded types, `inverse` could use them, e.g.
# def inverse(self: BT[KT, VT]) -> BT[VT, KT]:
# Ref: https://github.com/python/typing/issues/548#issuecomment-621571821
@property
def inverse(self) -> BidictBase[VT, KT]:
"""The inverse of this bidirectional mapping instance."""
# When `bi.inverse` is called for the first time, this method
# computes the inverse instance, stores it for subsequent use, and then
# returns it. It also stores a reference on `bi.inverse` back to `bi`,
# but uses a weakref to avoid creating a reference cycle. Strong references
# to inverse instances are stored in ._inv, and weak references are stored
# in ._invweak.
# First check if a strong reference is already stored.
inv: BidictBase[VT, KT] | None = getattr(self, '_inv', None)
if inv is not None:
return inv
# Next check if a weak reference is already stored.
invweak = getattr(self, '_invweak', None)
if invweak is not None:
inv = invweak() # Try to resolve a strong reference and return it.
if inv is not None:
return inv
# No luck. Compute the inverse reference and store it for subsequent use.
inv = self._make_inverse()
self._inv: BidictBase[VT, KT] | None = inv
self._invweak: weakref.ReferenceType[BidictBase[VT, KT]] | None = None
# Also store a weak reference back to `instance` on its inverse instance, so that
# the second `.inverse` access in `bi.inverse.inverse` hits the cached weakref.
inv._inv = None
inv._invweak = weakref.ref(self)
# In e.g. `bidict().inverse.inverse`, this design ensures that a strong reference
# back to the original instance is retained before its refcount drops to zero,
# avoiding an unintended potential deallocation.
return inv
def _make_inverse(self) -> BidictBase[VT, KT]:
inv: BidictBase[VT, KT] = self._inv_cls()
inv._fwdm = self._invm
inv._invm = self._fwdm
return inv
@property
def inv(self) -> BidictBase[VT, KT]:
"""Alias for :attr:`inverse`."""
return self.inverse
def __repr__(self) -> str:
"""See :func:`repr`."""
clsname = self.__class__.__name__
items = dict(self.items()) if self else ''
return f'{clsname}({items})'
def values(self) -> BidictKeysView[VT]:
"""A set-like object providing a view on the contained values.
Since the values of a bidict are equivalent to the keys of its inverse,
this method returns a set-like object for this bidict's values
rather than just a collections.abc.ValuesView.
This object supports set operations like union and difference,
and constant- rather than linear-time containment checks,
and is no more expensive to provide than the less capable
collections.abc.ValuesView would be.
See :meth:`keys` for more information.
"""
return t.cast(BidictKeysView[VT], self.inverse.keys())
def keys(self) -> t.KeysView[KT]:
"""A set-like object providing a view on the contained keys.
When *b._fwdm* is a :class:`dict`, *b.keys()* returns a
*dict_keys* object that behaves exactly the same as
*collections.abc.KeysView(b)*, except for
- offering better performance
- being reversible on Python 3.8+
- having a .mapping attribute in Python 3.10+
that exposes a mappingproxy to *b._fwdm*.
"""
fwdm, fwdm_cls = self._fwdm, self._fwdm_cls
return fwdm.keys() if fwdm_cls is dict else BidictKeysView(self)
def items(self) -> t.ItemsView[KT, VT]:
"""A set-like object providing a view on the contained items.
When *b._fwdm* is a :class:`dict`, *b.items()* returns a
*dict_items* object that behaves exactly the same as
*collections.abc.ItemsView(b)*, except for:
- offering better performance
- being reversible on Python 3.8+
- having a .mapping attribute in Python 3.10+
that exposes a mappingproxy to *b._fwdm*.
"""
return self._fwdm.items() if self._fwdm_cls is dict else super().items()
# The inherited collections.abc.Mapping.__contains__() method is implemented by doing a `try`
# `except KeyError` around `self[key]`. The following implementation is much faster,
# especially in the missing case.
def __contains__(self, key: t.Any) -> bool:
"""True if the mapping contains the specified key, else False."""
return key in self._fwdm
# The inherited collections.abc.Mapping.__eq__() method is implemented in terms of an inefficient
# `dict(self.items()) == dict(other.items())` comparison, so override it with a
# more efficient implementation.
def __eq__(self, other: object) -> bool:
"""*x.__eq__(other) ⟺ x == other*
Equivalent to *dict(x.items()) == dict(other.items())*
but more efficient.
Note that :meth:`bidict's __eq__() <bidict.BidictBase.__eq__>` implementation
is inherited by subclasses,
in particular by the ordered bidict subclasses,
so even with ordered bidicts,
:ref:`== comparison is order-insensitive <eq-order-insensitive>`
(https://bidict.rtfd.io/other-bidict-types.html#eq-is-order-insensitive).
*See also* :meth:`equals_order_sensitive`
"""
if isinstance(other, t.Mapping):
return self._fwdm.items() == other.items()
# Ref: https://docs.python.org/3/library/constants.html#NotImplemented
return NotImplemented
def equals_order_sensitive(self, other: object) -> bool:
"""Order-sensitive equality check.
*See also* :ref:`eq-order-insensitive`
(https://bidict.rtfd.io/other-bidict-types.html#eq-is-order-insensitive)
"""
if not isinstance(other, t.Mapping) or len(self) != len(other):
return False
return all(starmap(eq, zip(self.items(), other.items())))
def _dedup(self, key: KT, val: VT, on_dup: OnDup) -> DedupResult[KT, VT]:
"""Check *key* and *val* for any duplication in self.
Handle any duplication as per the passed in *on_dup*.
If (key, val) is already present, return None
since writing (key, val) would be a no-op.
If duplication is found and the corresponding :class:`~bidict.OnDupAction` is
:attr:`~bidict.DROP_NEW`, return None.
If duplication is found and the corresponding :class:`~bidict.OnDupAction` is
:attr:`~bidict.RAISE`, raise the appropriate exception.
If duplication is found and the corresponding :class:`~bidict.OnDupAction` is
:attr:`~bidict.DROP_OLD`, or if no duplication is found,
return *(oldkey, oldval)*.
"""
fwdm, invm = self._fwdm, self._invm
oldval: OVT[VT] = fwdm.get(key, MISSING)
oldkey: OKT[KT] = invm.get(val, MISSING)
isdupkey, isdupval = oldval is not MISSING, oldkey is not MISSING
if isdupkey and isdupval:
if key == oldkey:
assert val == oldval
# (key, val) duplicates an existing item -> no-op.
return None
# key and val each duplicate a different existing item.
if on_dup.val is RAISE:
raise KeyAndValueDuplicationError(key, val)
if on_dup.val is DROP_NEW:
return None
assert on_dup.val is DROP_OLD
# Fall through to the return statement on the last line.
elif isdupkey:
if on_dup.key is RAISE:
raise KeyDuplicationError(key)
if on_dup.key is DROP_NEW:
return None
assert on_dup.key is DROP_OLD
# Fall through to the return statement on the last line.
elif isdupval:
if on_dup.val is RAISE:
raise ValueDuplicationError(val)
if on_dup.val is DROP_NEW:
return None
assert on_dup.val is DROP_OLD
# Fall through to the return statement on the last line.
# else neither isdupkey nor isdupval.
return oldkey, oldval
def _write(self, newkey: KT, newval: VT, oldkey: OKT[KT], oldval: OVT[VT], unwrites: Unwrites | None) -> None:
"""Insert (newkey, newval), extending *unwrites* with associated inverse operations if provided.
*oldkey* and *oldval* are as returned by :meth:`_dedup`.
If *unwrites* is not None, it is extended with the inverse operations necessary to undo the write.
This design allows :meth:`_update` to roll back a partially applied update that fails part-way through
when necessary.
This design also allows subclasses that require additional operations to easily extend this implementation.
For example, :class:`bidict.OrderedBidictBase` calls this inherited implementation, and then extends *unwrites*
with additional operations needed to keep its internal linked list nodes consistent with its items' order
as changes are made.
"""
fwdm, invm = self._fwdm, self._invm
fwdm_set, invm_set = fwdm.__setitem__, invm.__setitem__
fwdm_del, invm_del = fwdm.__delitem__, invm.__delitem__
# Always perform the following writes regardless of duplication.
fwdm_set(newkey, newval)
invm_set(newval, newkey)
if oldval is MISSING and oldkey is MISSING: # no key or value duplication
# {0: 1, 2: 3} | {4: 5} => {0: 1, 2: 3, 4: 5}
if unwrites is not None:
unwrites.extend((
(fwdm_del, newkey),
(invm_del, newval),
))
elif oldval is not MISSING and oldkey is not MISSING: # key and value duplication across two different items
# {0: 1, 2: 3} | {0: 3} => {0: 3}
fwdm_del(oldkey)
invm_del(oldval)
if unwrites is not None:
unwrites.extend((
(fwdm_set, newkey, oldval),
(invm_set, oldval, newkey),
(fwdm_set, oldkey, newval),
(invm_set, newval, oldkey),
))
elif oldval is not MISSING: # just key duplication
# {0: 1, 2: 3} | {2: 4} => {0: 1, 2: 4}
invm_del(oldval)
if unwrites is not None:
unwrites.extend((
(fwdm_set, newkey, oldval),
(invm_set, oldval, newkey),
(invm_del, newval),
))
else:
assert oldkey is not MISSING # just value duplication
# {0: 1, 2: 3} | {4: 3} => {0: 1, 4: 3}
fwdm_del(oldkey)
if unwrites is not None:
unwrites.extend((
(fwdm_set, oldkey, newval),
(invm_set, newval, oldkey),
(fwdm_del, newkey),
))
def _update(
self,
arg: MapOrItems[KT, VT],
kw: t.Mapping[str, VT] = MappingProxyType({}),
*,
rollback: bool | None = None,
on_dup: OnDup | None = None,
) -> None:
"""Update with the items from *arg* and *kw*, maybe failing and rolling back as per *on_dup* and *rollback*."""
# Note: We must process input in a single pass, since arg may be a generator.
if not isinstance(arg, (t.Iterable, Maplike)):
raise TypeError(f"'{arg.__class__.__name__}' object is not iterable")
if not arg and not kw:
return
if on_dup is None:
on_dup = self.on_dup
if rollback is None:
rollback = RAISE in on_dup
# Fast path when we're empty and updating only from another bidict (i.e. no dup vals in new items).
if not self and not kw and isinstance(arg, BidictBase):
self._init_from(arg)
return
# Fast path when we're adding more items than we contain already and rollback is enabled:
# Update a copy of self with rollback disabled. Fail if that fails, otherwise become the copy.
if rollback and isinstance(arg, t.Sized) and len(arg) + len(kw) > len(self):
tmp = self.copy()
tmp._update(arg, kw, rollback=False, on_dup=on_dup)
self._init_from(tmp)
return
# In all other cases, benchmarking has indicated that the update is best implemented as follows:
# For each new item, perform a dup check (raising if necessary), and apply the associated writes we need to
# perform on our backing _fwdm and _invm mappings. If rollback is enabled, also compute the associated unwrites
# as we go. If the update results in a DuplicationError and rollback is enabled, apply the accumulated unwrites
# before raising, to ensure that we fail clean.
write = self._write
unwrites: Unwrites | None = [] if rollback else None
for key, val in iteritems(arg, **kw):
try:
dedup_result = self._dedup(key, val, on_dup)
except DuplicationError:
if unwrites is not None:
for fn, *args in reversed(unwrites):
fn(*args)
raise
if dedup_result is not None:
write(key, val, *dedup_result, unwrites=unwrites)
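# Illustrative usage sketch (not from the upstream source; assumes only the
# public bidict API): the rollback machinery above, exercised via putall():
#
#   >>> from bidict import bidict, ValueDuplicationError
#   >>> b = bidict({1: 'one', 2: 'two', 3: 'three'})
#   >>> try:
#   ...     b.putall([(4, 'four'), (5, 'one')])  # 2nd item dups the value 'one'
#   ... except ValueDuplicationError:
#   ...     pass
#   >>> 4 in b  # the already-applied write of (4, 'four') was rolled back
#   False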
def __copy__(self: BT) -> BT:
"""Used for the copy protocol. See the :mod:`copy` module."""
return self.copy()
def copy(self: BT) -> BT:
"""Make a (shallow) copy of this bidict."""
# Could just `return self.__class__(self)` here, but the below is faster. The former
# would copy this bidict's items into a new instance one at a time (checking for duplication
# for each item), whereas the below copies from the backing mappings all at once, and foregoes
# item-by-item duplication checking since the backing mappings have been checked already.
return self._from_other(self.__class__, self)
@staticmethod
def _from_other(bt: type[BT], other: MapOrItems[KT, VT], inv: bool = False) -> BT:
"""Fast, private constructor based on :meth:`_init_from`.
If *inv* is true, return the inverse of the instance instead of the instance itself.
(Useful for pickling with dynamically-generated inverse classes -- see :meth:`__reduce__`.)
"""
inst = bt()
inst._init_from(other)
return t.cast(BT, inst.inverse) if inv else inst
def _init_from(self, other: MapOrItems[KT, VT]) -> None:
"""Fast init from *other*, bypassing item-by-item duplication checking."""
self._fwdm.clear()
self._invm.clear()
self._fwdm.update(other)
# If other is a bidict, use its existing backing inverse mapping, otherwise
# other could be a generator that's now exhausted, so invert self._fwdm on the fly.
inv = other.inverse if isinstance(other, BidictBase) else inverted(self._fwdm)
self._invm.update(inv)
# other's type is Mapping rather than Maplike since bidict() | SupportsKeysAndGetItem({})
# raises a TypeError, just like dict() | SupportsKeysAndGetItem({}) does.
def __or__(self: BT, other: t.Mapping[KT, VT]) -> BT:
"""Return self|other."""
if not isinstance(other, t.Mapping):
return NotImplemented
new = self.copy()
new._update(other, rollback=False)
return new
def __ror__(self: BT, other: t.Mapping[KT, VT]) -> BT:
"""Return other|self."""
if not isinstance(other, t.Mapping):
return NotImplemented
new = self.__class__(other)
new._update(self, rollback=False)
return new
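# Illustrative usage sketch (not from the upstream source): the copy-then-update
# pattern above gives bidicts the same ``|`` semantics as dict on Python 3.9+:
#
#   >>> from bidict import bidict
#   >>> bidict({1: 'one'}) | {2: 'two'}
#   bidict({1: 'one', 2: 'two'})
#   >>> {2: 'two'} | bidict({1: 'one'})  # dict.__or__ -> NotImplemented -> __ror__
#   bidict({2: 'two', 1: 'one'})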
def __len__(self) -> int:
"""The number of contained items."""
return len(self._fwdm)
def __iter__(self) -> t.Iterator[KT]:
"""Iterator over the contained keys."""
return iter(self._fwdm)
def __getitem__(self, key: KT) -> VT:
"""*x.__getitem__(key) ⟺ x[key]*"""
return self._fwdm[key]
def __reduce__(self) -> tuple[t.Any, ...]:
"""Return state information for pickling."""
cls = self.__class__
inst: t.Mapping[t.Any, t.Any] = self
# If this bidict's class is dynamically generated, pickle the inverse instead, whose (presumably not
# dynamically generated) class the caller is more likely to have a reference to somewhere in sys.modules
# that pickle can discover.
if should_invert := isinstance(self, GeneratedBidictInverse):
cls = self._inv_cls
inst = self.inverse
return self._from_other, (cls, dict(inst), should_invert)
# See BidictBase._set_reversed() above.
def _fwdm_reversed(self: BidictBase[KT, t.Any]) -> t.Iterator[KT]:
"""Iterator over the contained keys in reverse order."""
assert isinstance(self._fwdm, t.Reversible)
return reversed(self._fwdm)
BidictBase._init_class()
class GeneratedBidictInverse:
"""Base class for dynamically-generated inverse bidict classes."""
# * Code review nav *
# ============================================================================
# ← Prev: _abc.py Current: _base.py Next: _frozen.py →
# ============================================================================

@@ -0,0 +1,194 @@
# Copyright 2009-2024 Joshua Bronson. All rights reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# * Code review nav *
# (see comments in __init__.py)
# ============================================================================
# ← Prev: _frozen.py Current: _bidict.py Next: _orderedbase.py →
# ============================================================================
"""Provide :class:`MutableBidict` and :class:`bidict`."""
from __future__ import annotations
import typing as t
from ._abc import MutableBidirectionalMapping
from ._base import BidictBase
from ._dup import ON_DUP_DROP_OLD
from ._dup import ON_DUP_RAISE
from ._dup import OnDup
from ._typing import DT
from ._typing import KT
from ._typing import MISSING
from ._typing import ODT
from ._typing import VT
from ._typing import MapOrItems
class MutableBidict(BidictBase[KT, VT], MutableBidirectionalMapping[KT, VT]):
"""Base class for mutable bidirectional mappings."""
if t.TYPE_CHECKING:
@property
def inverse(self) -> MutableBidict[VT, KT]: ...
@property
def inv(self) -> MutableBidict[VT, KT]: ...
def _pop(self, key: KT) -> VT:
val = self._fwdm.pop(key)
del self._invm[val]
return val
def __delitem__(self, key: KT) -> None:
"""*x.__delitem__(y)  del x[y]*"""
self._pop(key)
def __setitem__(self, key: KT, val: VT) -> None:
"""Set the value for *key* to *val*.
If *key* is already associated with *val*, this is a no-op.
If *key* is already associated with a different value,
the old value will be replaced with *val*,
as with dict's :meth:`__setitem__`.
If *val* is already associated with a different key,
an exception is raised
to protect against accidental removal of the key
that's currently associated with *val*.
Use :meth:`put` instead if you want to specify different behavior in
the case that the provided key or value duplicates an existing one.
Or use :meth:`forceput` to unconditionally associate *key* with *val*,
replacing any existing items as necessary to preserve uniqueness.
:raises bidict.ValueDuplicationError: if *val* duplicates that of an
existing item.
:raises bidict.KeyAndValueDuplicationError: if *key* duplicates the key of an
existing item and *val* duplicates the value of a different
existing item.
"""
self.put(key, val, on_dup=self.on_dup)
def put(self, key: KT, val: VT, on_dup: OnDup = ON_DUP_RAISE) -> None:
"""Associate *key* with *val*, honoring the :class:`OnDup` given in *on_dup*.
For example, if *on_dup* is :attr:`~bidict.ON_DUP_RAISE`,
then *key* will be associated with *val* if and only if
*key* is not already associated with an existing value and
*val* is not already associated with an existing key,
otherwise an exception will be raised.
If *key* is already associated with *val*, this is a no-op.
:raises bidict.KeyDuplicationError: if attempting to insert an item
whose key only duplicates an existing item's, and *on_dup.key* is
:attr:`~bidict.RAISE`.
:raises bidict.ValueDuplicationError: if attempting to insert an item
whose value only duplicates an existing item's, and *on_dup.val* is
:attr:`~bidict.RAISE`.
:raises bidict.KeyAndValueDuplicationError: if attempting to insert an
item whose key duplicates one existing item's, and whose value
duplicates another existing item's, and *on_dup.val* is
:attr:`~bidict.RAISE`.
"""
self._update(((key, val),), on_dup=on_dup)
def forceput(self, key: KT, val: VT) -> None:
"""Associate *key* with *val* unconditionally.
Replace any existing mappings containing key *key* or value *val*
as necessary to preserve uniqueness.
"""
self.put(key, val, on_dup=ON_DUP_DROP_OLD)
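# Illustrative usage sketch (not from the upstream source) contrasting the three
# write styles defined above:
#
#   >>> from bidict import bidict, ValueDuplicationError
#   >>> b = bidict({'H': 'hydrogen'})
#   >>> try:
#   ...     b['He'] = 'hydrogen'  # __setitem__: dup value -> raises
#   ... except ValueDuplicationError:
#   ...     pass
#   >>> b.put('He', 'helium')       # strict put: no duplication, so this succeeds
#   >>> b.forceput('Li', 'helium')  # unconditional: drops ('He', 'helium')
#   >>> dict(b)
#   {'H': 'hydrogen', 'Li': 'helium'}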
def clear(self) -> None:
"""Remove all items."""
self._fwdm.clear()
self._invm.clear()
@t.overload
def pop(self, key: KT, /) -> VT: ...
@t.overload
def pop(self, key: KT, default: DT = ..., /) -> VT | DT: ...
def pop(self, key: KT, default: ODT[DT] = MISSING, /) -> VT | DT:
"""*x.pop(k[, d]) → v*
Remove specified key and return the corresponding value.
:raises KeyError: if *key* is not found and no *default* is provided.
"""
try:
return self._pop(key)
except KeyError:
if default is MISSING:
raise
return default
def popitem(self) -> tuple[KT, VT]:
"""*x.popitem() → (k, v)*
Remove and return some item as a (key, value) pair.
:raises KeyError: if *x* is empty.
"""
key, val = self._fwdm.popitem()
del self._invm[val]
return key, val
def update(self, arg: MapOrItems[KT, VT] = (), /, **kw: VT) -> None:
"""Like calling :meth:`putall` with *self.on_dup* passed for *on_dup*."""
self._update(arg, kw=kw)
def forceupdate(self, arg: MapOrItems[KT, VT] = (), /, **kw: VT) -> None:
"""Like a bulk :meth:`forceput`."""
self._update(arg, kw=kw, on_dup=ON_DUP_DROP_OLD)
def putall(self, items: MapOrItems[KT, VT], on_dup: OnDup = ON_DUP_RAISE) -> None:
"""Like a bulk :meth:`put`.
If one of the given items causes an exception to be raised,
none of the items is inserted.
"""
self._update(items, on_dup=on_dup)
# other's type is Mapping rather than Maplike since bidict() |= SupportsKeysAndGetItem({})
# raises a TypeError, just like dict() |= SupportsKeysAndGetItem({}) does.
def __ior__(self, other: t.Mapping[KT, VT]) -> MutableBidict[KT, VT]:
"""Return self|=other."""
self.update(other)
return self
class bidict(MutableBidict[KT, VT]):
"""The main bidirectional mapping type.
See :ref:`intro:Introduction` and :ref:`basic-usage:Basic Usage`
to get started (also available at https://bidict.rtfd.io).
"""
if t.TYPE_CHECKING:
@property
def inverse(self) -> bidict[VT, KT]: ...
@property
def inv(self) -> bidict[VT, KT]: ...
# * Code review nav *
# ============================================================================
# ← Prev: _frozen.py Current: _bidict.py Next: _orderedbase.py →
# ============================================================================

@@ -0,0 +1,61 @@
# Copyright 2009-2024 Joshua Bronson. All rights reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Provide :class:`OnDup` and related functionality."""
from __future__ import annotations
import typing as t
from enum import Enum
class OnDupAction(Enum):
"""An action to take to prevent duplication from occurring."""
#: Raise a :class:`~bidict.DuplicationError`.
RAISE = 'RAISE'
#: Overwrite existing items with new items.
DROP_OLD = 'DROP_OLD'
#: Keep existing items and drop new items.
DROP_NEW = 'DROP_NEW'
def __repr__(self) -> str:
return f'{self.__class__.__name__}.{self.name}'
RAISE: t.Final[OnDupAction] = OnDupAction.RAISE
DROP_OLD: t.Final[OnDupAction] = OnDupAction.DROP_OLD
DROP_NEW: t.Final[OnDupAction] = OnDupAction.DROP_NEW
class OnDup(t.NamedTuple):
r"""A combination of :class:`~bidict.OnDupAction`\s specifying how to handle various types of duplication.
The :attr:`~OnDup.key` field specifies what action to take when a duplicate key is encountered.
The :attr:`~OnDup.val` field specifies what action to take when a duplicate value is encountered.
In the case of both key and value duplication across two different items,
only :attr:`~OnDup.val` is used.
*See also* :ref:`basic-usage:Values Must Be Unique`
(https://bidict.rtfd.io/basic-usage.html#values-must-be-unique)
"""
key: OnDupAction = DROP_OLD
val: OnDupAction = RAISE
#: Default :class:`OnDup` used for the
#: :meth:`~bidict.bidict.__init__`,
#: :meth:`~bidict.bidict.__setitem__`, and
#: :meth:`~bidict.bidict.update` methods.
ON_DUP_DEFAULT: t.Final[OnDup] = OnDup(key=DROP_OLD, val=RAISE)
#: An :class:`OnDup` whose members are all :obj:`RAISE`.
ON_DUP_RAISE: t.Final[OnDup] = OnDup(key=RAISE, val=RAISE)
#: An :class:`OnDup` whose members are all :obj:`DROP_OLD`.
ON_DUP_DROP_OLD: t.Final[OnDup] = OnDup(key=DROP_OLD, val=DROP_OLD)
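# Illustrative usage sketch (not from the upstream source): a custom OnDup that
# silently keeps existing items rather than raising:
#
#   >>> from bidict import DROP_NEW, OnDup, bidict
#   >>> b = bidict({1: 'one'})
#   >>> b.putall([(1, 'uno'), (2, 'two')], on_dup=OnDup(key=DROP_NEW, val=DROP_NEW))
#   >>> dict(b)
#   {1: 'one', 2: 'two'}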

@@ -0,0 +1,36 @@
# Copyright 2009-2024 Joshua Bronson. All rights reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Provide all bidict exceptions."""
from __future__ import annotations
class BidictException(Exception):
"""Base class for bidict exceptions."""
class DuplicationError(BidictException):
"""Base class for exceptions raised when uniqueness is violated
as per the :attr:`~bidict.RAISE` :class:`~bidict.OnDupAction`.
"""
class KeyDuplicationError(DuplicationError):
"""Raised when a given key is not unique."""
class ValueDuplicationError(DuplicationError):
"""Raised when a given value is not unique."""
class KeyAndValueDuplicationError(KeyDuplicationError, ValueDuplicationError):
"""Raised when a given item's key and value are not unique.
That is, its key duplicates that of another item,
and its value duplicates that of a different other item.
"""

@@ -0,0 +1,50 @@
# Copyright 2009-2024 Joshua Bronson. All rights reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# * Code review nav *
# (see comments in __init__.py)
# ============================================================================
# ← Prev: _base.py Current: _frozen.py Next: _bidict.py →
# ============================================================================
"""Provide :class:`frozenbidict`, an immutable, hashable bidirectional mapping type."""
from __future__ import annotations
import typing as t
from ._base import BidictBase
from ._typing import KT
from ._typing import VT
class frozenbidict(BidictBase[KT, VT]):
"""Immutable, hashable bidict type."""
_hash: int
if t.TYPE_CHECKING:
@property
def inverse(self) -> frozenbidict[VT, KT]: ...
@property
def inv(self) -> frozenbidict[VT, KT]: ...
def __hash__(self) -> int:
"""The hash of this bidict as determined by its items."""
if getattr(self, '_hash', None) is None:
# The following is like hash(frozenset(self.items()))
# but more memory efficient. See also: https://bugs.python.org/issue46684
self._hash = t.ItemsView(self)._hash()
return self._hash
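# Illustrative usage sketch (not from the upstream source): hashability makes
# frozenbidict usable as a dict key or set member, with the hash cached above:
#
#   >>> from bidict import frozenbidict
#   >>> f = frozenbidict({1: 'one'})
#   >>> {f: 'usable as a key'}[f]
#   'usable as a key'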
# * Code review nav *
# ============================================================================
# ← Prev: _base.py Current: _frozen.py Next: _bidict.py →
# ============================================================================

@@ -0,0 +1,51 @@
# Copyright 2009-2024 Joshua Bronson. All rights reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Functions for iterating over items in a mapping."""
from __future__ import annotations
import typing as t
from operator import itemgetter
from ._typing import KT
from ._typing import VT
from ._typing import ItemsIter
from ._typing import Maplike
from ._typing import MapOrItems
def iteritems(arg: MapOrItems[KT, VT] = (), /, **kw: VT) -> ItemsIter[KT, VT]:
"""Yield the items from *arg* and *kw* in the order given."""
if isinstance(arg, t.Mapping):
yield from arg.items()
elif isinstance(arg, Maplike):
yield from ((k, arg[k]) for k in arg.keys())
else:
yield from arg
yield from t.cast(ItemsIter[KT, VT], kw.items())
swap: t.Final = itemgetter(1, 0)
def inverted(arg: MapOrItems[KT, VT]) -> ItemsIter[VT, KT]:
"""Yield the inverse items of the provided object.
If *arg* has a :func:`callable` ``__inverted__`` attribute,
return the result of calling it.
Otherwise, return an iterator over the items in `arg`,
inverting each item on the fly.
*See also* :attr:`bidict.BidirectionalMapping.__inverted__`
"""
invattr = getattr(arg, '__inverted__', None)
if callable(invattr):
inv: ItemsIter[VT, KT] = invattr()
return inv
return map(swap, iteritems(arg))
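# Illustrative module-level doctests (not from the upstream source): both
# helpers accept mappings, Maplike objects, and iterables of pairs:
#
#   >>> list(iteritems({1: 'one'}, two=2))
#   [(1, 'one'), ('two', 2)]
#   >>> list(inverted([(1, 'one'), (2, 'two')]))
#   [('one', 1), ('two', 2)]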

@@ -0,0 +1,238 @@
# Copyright 2009-2024 Joshua Bronson. All rights reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# * Code review nav *
# (see comments in __init__.py)
# ============================================================================
# ← Prev: _bidict.py Current: _orderedbase.py Next: _orderedbidict.py →
# ============================================================================
"""Provide :class:`OrderedBidictBase`."""
from __future__ import annotations
import typing as t
from weakref import ref as weakref
from ._base import BidictBase
from ._base import Unwrites
from ._bidict import bidict
from ._iter import iteritems
from ._typing import KT
from ._typing import MISSING
from ._typing import OKT
from ._typing import OVT
from ._typing import VT
from ._typing import MapOrItems
AT = t.TypeVar('AT') # attr type
class WeakAttr(t.Generic[AT]):
"""Descriptor to automatically manage (de)referencing the given slot as a weakref.
See https://docs.python.org/3/howto/descriptor.html#managed-attributes
for an intro to using descriptors like this for managed attributes.
"""
def __init__(self, *, slot: str) -> None:
self.slot = slot
def __set__(self, instance: t.Any, value: AT) -> None:
setattr(instance, self.slot, weakref(value))
def __get__(self, instance: t.Any, __owner: t.Any = None) -> AT:
return t.cast(AT, getattr(instance, self.slot)())
class Node:
"""A node in a circular doubly-linked list
used to encode the order of items in an ordered bidict.
A weak reference to the previous node is stored
to avoid creating strong reference cycles.
Referencing/dereferencing the weakref is handled automatically by :class:`WeakAttr`.
"""
prv: WeakAttr[Node] = WeakAttr(slot='_prv_weak')
__slots__ = ('__weakref__', '_prv_weak', 'nxt')
nxt: Node | WeakAttr[Node] # Allow subclasses to use a WeakAttr for nxt too (see SentinelNode)
def __init__(self, prv: Node, nxt: Node) -> None:
self.prv = prv
self.nxt = nxt
def unlink(self) -> None:
"""Remove self from in between prv and nxt.
Self's references to prv and nxt are retained so it can be relinked (see below).
"""
self.prv.nxt = self.nxt
self.nxt.prv = self.prv
def relink(self) -> None:
"""Restore self between prv and nxt after unlinking (see above)."""
self.prv.nxt = self.nxt.prv = self
class SentinelNode(Node):
"""Special node in a circular doubly-linked list
that links the first node with the last node.
When its next and previous references point back to itself
it represents an empty list.
"""
nxt: WeakAttr[Node] = WeakAttr(slot='_nxt_weak')
__slots__ = ('_nxt_weak',)
def __init__(self) -> None:
super().__init__(self, self)
def iternodes(self, *, reverse: bool = False) -> t.Iterator[Node]:
"""Iterator yielding nodes in the requested order."""
attr = 'prv' if reverse else 'nxt'
node = getattr(self, attr)
while node is not self:
yield node
node = getattr(node, attr)
def new_last_node(self) -> Node:
"""Create and return a new terminal node."""
old_last = self.prv
new_last = Node(old_last, self)
old_last.nxt = self.prv = new_last
return new_last
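# Illustrative sketch (not from the upstream source) of the list machinery
# above: appends are O(1) via the sentinel, and unlink() splices a node out:
#
#   >>> s = SentinelNode()
#   >>> n1, n2 = s.new_last_node(), s.new_last_node()
#   >>> list(s.iternodes()) == [n1, n2]
#   True
#   >>> n1.unlink()
#   >>> list(s.iternodes()) == [n2]
#   True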
class OrderedBidictBase(BidictBase[KT, VT]):
"""Base class implementing an ordered :class:`BidirectionalMapping`."""
_node_by_korv: bidict[t.Any, Node]
_bykey: bool
def __init__(self, arg: MapOrItems[KT, VT] = (), /, **kw: VT) -> None:
"""Make a new ordered bidirectional mapping.
The signature behaves like that of :class:`dict`.
Items passed in are added in the order they are passed,
respecting the :attr:`~bidict.BidictBase.on_dup`
class attribute in the process.
The order in which items are inserted is remembered,
similar to :class:`collections.OrderedDict`.
"""
self._sntl = SentinelNode()
self._node_by_korv = bidict()
self._bykey = True
super().__init__(arg, **kw)
if t.TYPE_CHECKING:
@property
def inverse(self) -> OrderedBidictBase[VT, KT]: ...
@property
def inv(self) -> OrderedBidictBase[VT, KT]: ...
def _make_inverse(self) -> OrderedBidictBase[VT, KT]:
inv = t.cast(OrderedBidictBase[VT, KT], super()._make_inverse())
inv._sntl = self._sntl
inv._node_by_korv = self._node_by_korv
inv._bykey = not self._bykey
return inv
def _assoc_node(self, node: Node, key: KT, val: VT) -> None:
korv = key if self._bykey else val
self._node_by_korv.forceput(korv, node)
def _dissoc_node(self, node: Node) -> None:
del self._node_by_korv.inverse[node]
node.unlink()
def _init_from(self, other: MapOrItems[KT, VT]) -> None:
"""See :meth:`BidictBase._init_from`."""
super()._init_from(other)
bykey = self._bykey
korv_by_node = self._node_by_korv.inverse
korv_by_node.clear()
korv_by_node_set = korv_by_node.__setitem__
self._sntl.nxt = self._sntl.prv = self._sntl
new_node = self._sntl.new_last_node
for k, v in iteritems(other):
korv_by_node_set(new_node(), k if bykey else v)
def _write(self, newkey: KT, newval: VT, oldkey: OKT[KT], oldval: OVT[VT], unwrites: Unwrites | None) -> None:
"""See :meth:`bidict.BidictBase._spec_write`."""
super()._write(newkey, newval, oldkey, oldval, unwrites)
assoc, dissoc = self._assoc_node, self._dissoc_node
node_by_korv, bykey = self._node_by_korv, self._bykey
if oldval is MISSING and oldkey is MISSING: # no key or value duplication
# {0: 1, 2: 3} | {4: 5} => {0: 1, 2: 3, 4: 5}
newnode = self._sntl.new_last_node()
assoc(newnode, newkey, newval)
if unwrites is not None:
unwrites.append((dissoc, newnode))
elif oldval is not MISSING and oldkey is not MISSING: # key and value duplication across two different items
# {0: 1, 2: 3} | {0: 3} => {0: 3}
# n1, n2 => n1 (collapse n1 and n2 into n1)
# oldkey: 2, oldval: 1, oldnode: n2, newkey: 0, newval: 3, newnode: n1
if bykey:
oldnode = node_by_korv[oldkey]
newnode = node_by_korv[newkey]
else:
oldnode = node_by_korv[newval]
newnode = node_by_korv[oldval]
dissoc(oldnode)
assoc(newnode, newkey, newval)
if unwrites is not None:
unwrites.extend((
(assoc, newnode, newkey, oldval),
(assoc, oldnode, oldkey, newval),
(oldnode.relink,),
))
elif oldval is not MISSING: # just key duplication
# {0: 1, 2: 3} | {2: 4} => {0: 1, 2: 4}
# oldkey: MISSING, oldval: 3, newkey: 2, newval: 4
node = node_by_korv[newkey if bykey else oldval]
assoc(node, newkey, newval)
if unwrites is not None:
unwrites.append((assoc, node, newkey, oldval))
else:
assert oldkey is not MISSING # just value duplication
# {0: 1, 2: 3} | {4: 3} => {0: 1, 4: 3}
# oldkey: 2, oldval: MISSING, newkey: 4, newval: 3
node = node_by_korv[oldkey if bykey else newval]
assoc(node, newkey, newval)
if unwrites is not None:
unwrites.append((assoc, node, oldkey, newval))
def __iter__(self) -> t.Iterator[KT]:
"""Iterator over the contained keys in insertion order."""
return self._iter(reverse=False)
def __reversed__(self) -> t.Iterator[KT]:
"""Iterator over the contained keys in reverse insertion order."""
return self._iter(reverse=True)
def _iter(self, *, reverse: bool = False) -> t.Iterator[KT]:
nodes = self._sntl.iternodes(reverse=reverse)
korv_by_node = self._node_by_korv.inverse
if self._bykey:
for node in nodes:
yield korv_by_node[node]
else:
key_by_val = self._invm
for node in nodes:
val = korv_by_node[node]
yield key_by_val[val]
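# Illustrative usage sketch (not from the upstream source): the node
# bookkeeping above is what preserves an item's position when it is overwritten:
#
#   >>> from bidict import OrderedBidict
#   >>> ob = OrderedBidict([(1, 'one'), (2, 'two')])
#   >>> ob.forceput(3, 'one')  # collapses into the slot of (1, 'one')
#   >>> list(ob.items())
#   [(3, 'one'), (2, 'two')]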
# * Code review nav *
# ============================================================================
# ← Prev: _bidict.py Current: _orderedbase.py Next: _orderedbidict.py →
# ============================================================================

@@ -0,0 +1,172 @@
# Copyright 2009-2024 Joshua Bronson. All rights reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# * Code review nav *
# (see comments in __init__.py)
# ============================================================================
# ← Prev: _orderedbase.py Current: _orderedbidict.py <FIN>
# ============================================================================
"""Provide :class:`OrderedBidict`."""
from __future__ import annotations
import typing as t
from collections.abc import Set
from ._base import BidictKeysView
from ._bidict import MutableBidict
from ._orderedbase import OrderedBidictBase
from ._typing import KT
from ._typing import VT
class OrderedBidict(OrderedBidictBase[KT, VT], MutableBidict[KT, VT]):
"""Mutable bidict type that maintains items in insertion order."""
if t.TYPE_CHECKING:
@property
def inverse(self) -> OrderedBidict[VT, KT]: ...
@property
def inv(self) -> OrderedBidict[VT, KT]: ...
def clear(self) -> None:
"""Remove all items."""
super().clear()
self._node_by_korv.clear()
self._sntl.nxt = self._sntl.prv = self._sntl
def _pop(self, key: KT) -> VT:
val = super()._pop(key)
node = self._node_by_korv[key if self._bykey else val]
self._dissoc_node(node)
return val
def popitem(self, last: bool = True) -> tuple[KT, VT]:
"""*b.popitem() → (k, v)*
If *last* is true,
remove and return the most recently added item as a (key, value) pair.
Otherwise, remove and return the least recently added item.
:raises KeyError: if *b* is empty.
"""
if not self:
raise KeyError('OrderedBidict is empty')
node = getattr(self._sntl, 'prv' if last else 'nxt')
korv = self._node_by_korv.inverse[node]
if self._bykey:
return korv, self._pop(korv)
return self.inverse._pop(korv), korv
def move_to_end(self, key: KT, last: bool = True) -> None:
"""Move the item with the given key to the end if *last* is true, else to the beginning.
:raises KeyError: if *key* is missing
"""
korv = key if self._bykey else self._fwdm[key]
node = self._node_by_korv[korv]
node.prv.nxt = node.nxt
node.nxt.prv = node.prv
sntl = self._sntl
if last:
lastnode = sntl.prv
node.prv = lastnode
node.nxt = sntl
sntl.prv = lastnode.nxt = node
else:
firstnode = sntl.nxt
node.prv = sntl
node.nxt = firstnode
sntl.nxt = firstnode.prv = node
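# Illustrative usage sketch (not from the upstream source): OrderedDict-style
# reordering and ordered removal:
#
#   >>> from bidict import OrderedBidict
#   >>> ob = OrderedBidict([(1, 'one'), (2, 'two'), (3, 'three')])
#   >>> ob.move_to_end(1)
#   >>> ob.popitem(last=False)
#   (2, 'two')
#   >>> list(ob)
#   [3, 1]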
# Override the keys() and items() implementations inherited from BidictBase,
# which may delegate to the backing _fwdm dict, since this is a mutable ordered bidict,
# and therefore the ordering of items can get out of sync with the backing mappings
# after mutation. (Need not override values() because it delegates to .inverse.keys().)
def keys(self) -> t.KeysView[KT]:
"""A set-like object providing a view on the contained keys."""
return _OrderedBidictKeysView(self)
def items(self) -> t.ItemsView[KT, VT]:
"""A set-like object providing a view on the contained items."""
return _OrderedBidictItemsView(self)
# The following MappingView implementations use the __iter__ implementations
# inherited from their superclass counterparts in collections.abc, so they
# continue to yield items in the correct order even after an OrderedBidict
# is mutated. They also provide a __reversed__ implementation, which is not
# provided by the collections.abc superclasses.
class _OrderedBidictKeysView(BidictKeysView[KT]):
_mapping: OrderedBidict[KT, t.Any]
def __reversed__(self) -> t.Iterator[KT]:
return reversed(self._mapping)
class _OrderedBidictItemsView(t.ItemsView[KT, VT]):
_mapping: OrderedBidict[KT, VT]
def __reversed__(self) -> t.Iterator[tuple[KT, VT]]:
ob = self._mapping
for key in reversed(ob):
yield key, ob[key]
# For better performance, make _OrderedBidictKeysView and _OrderedBidictItemsView delegate
# to backing dicts for the methods they inherit from collections.abc.Set. (Cannot delegate
# for __iter__ and __reversed__ since they are order-sensitive.) See also: https://bugs.python.org/issue46713
_OView = t.Union[t.Type[_OrderedBidictKeysView[KT]], t.Type[_OrderedBidictItemsView[KT, t.Any]]]
_setmethodnames: t.Iterable[str] = (
'__lt__ __le__ __gt__ __ge__ __eq__ __ne__ __sub__ __rsub__ '
'__or__ __ror__ __xor__ __rxor__ __and__ __rand__ isdisjoint'
).split()
def _override_set_methods_to_use_backing_dict(cls: _OView[KT], viewname: str) -> None:
def make_proxy_method(methodname: str) -> t.Any:
def method(self: _OrderedBidictKeysView[KT] | _OrderedBidictItemsView[KT, t.Any], *args: t.Any) -> t.Any:
fwdm = self._mapping._fwdm
if not isinstance(fwdm, dict): # dict view speedup not available, fall back to Set's implementation.
return getattr(Set, methodname)(self, *args)
fwdm_dict_view = getattr(fwdm, viewname)()
fwdm_dict_view_method = getattr(fwdm_dict_view, methodname)
if (
len(args) != 1
or not isinstance((arg := args[0]), self.__class__)
or not isinstance(arg._mapping._fwdm, dict)
):
return fwdm_dict_view_method(*args)
# self and arg are both _OrderedBidictKeysViews or _OrderedBidictItemsViews whose bidicts are backed by
# a dict. Use arg's backing dict's corresponding view instead of arg. Otherwise, e.g. `ob1.keys()
# < ob2.keys()` would give "TypeError: '<' not supported between instances of '_OrderedBidictKeysView' and
# '_OrderedBidictKeysView'", because both `dict_keys(ob1).__lt__(ob2.keys()) is NotImplemented` and
# `dict_keys(ob2).__gt__(ob1.keys()) is NotImplemented`.
arg_dict = arg._mapping._fwdm
arg_dict_view = getattr(arg_dict, viewname)()
return fwdm_dict_view_method(arg_dict_view)
method.__name__ = methodname
method.__qualname__ = f'{cls.__qualname__}.{methodname}'
return method
for name in _setmethodnames:
setattr(cls, name, make_proxy_method(name))
_override_set_methods_to_use_backing_dict(_OrderedBidictKeysView, 'keys')
_override_set_methods_to_use_backing_dict(_OrderedBidictItemsView, 'items')
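# Illustrative usage sketch (not from the upstream source): with the proxy
# methods installed, set operations on ordered views hit the fast C-implemented
# dict views whenever both operands are dict-backed:
#
#   >>> from bidict import OrderedBidict
#   >>> ob1 = OrderedBidict({1: 'one'})
#   >>> ob2 = OrderedBidict({1: 'one', 2: 'two'})
#   >>> ob1.keys() < ob2.keys()
#   True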
# * Code review nav *
# ============================================================================
# ← Prev: _orderedbase.py Current: _orderedbidict.py <FIN>
# ============================================================================

@@ -0,0 +1,49 @@
# Copyright 2009-2024 Joshua Bronson. All rights reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Provide typing-related objects."""
from __future__ import annotations
import typing as t
from enum import Enum
KT = t.TypeVar('KT')
VT = t.TypeVar('VT')
VT_co = t.TypeVar('VT_co', covariant=True)
Items = t.Iterable[t.Tuple[KT, VT]]
@t.runtime_checkable
class Maplike(t.Protocol[KT, VT_co]):
"""Like typeshed's SupportsKeysAndGetItem, but usable at runtime."""
def keys(self) -> t.Iterable[KT]: ...
def __getitem__(self, __key: KT) -> VT_co: ...
MapOrItems = t.Union[Maplike[KT, VT], Items[KT, VT]]
MappingOrItems = t.Union[t.Mapping[KT, VT], Items[KT, VT]]
ItemsIter = t.Iterator[t.Tuple[KT, VT]]
class MissingT(Enum):
"""Sentinel used to represent none/missing when None itself can't be used."""
MISSING = 'MISSING'
MISSING: t.Final[t.Literal[MissingT.MISSING]] = MissingT.MISSING
OKT = t.Union[KT, MissingT] #: optional key type
OVT = t.Union[VT, MissingT] #: optional value type
DT = t.TypeVar('DT') #: for default arguments
ODT = t.Union[DT, MissingT] #: optional default arg type
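# Illustrative sketch (not from the upstream source): the single-member enum
# gives a well-typed, identity-comparable sentinel for when None is a valid value:
#
#   >>> def first(xs, default=MISSING):
#   ...     for x in xs:
#   ...         return x
#   ...     if default is MISSING:
#   ...         raise ValueError('empty and no default given')
#   ...     return default
#   >>> first([], default=None) is None  # None works fine as a default
#   True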

@@ -0,0 +1,14 @@
# Copyright 2009-2024 Joshua Bronson. All rights reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Define bidict package metadata."""
__version__ = '0.23.1'
__author__ = {'name': 'Joshua Bronson', 'email': 'jabronson@gmail.com'}
__copyright__ = '© 2009-2024 Joshua Bronson'
__description__ = 'The bidirectional mapping library for Python.'
__license__ = 'MPL 2.0'
__url__ = 'https://bidict.readthedocs.io'

@@ -0,0 +1 @@
PEP-561 marker.

@@ -0,0 +1 @@
/home/davidson/Projects/evolution_client/python

@@ -0,0 +1,13 @@
from .client import Client
from .middleware import WSGIApp, Middleware
from .server import Server
from .async_server import AsyncServer
from .async_client import AsyncClient
from .async_drivers.asgi import ASGIApp
try:
from .async_drivers.tornado import get_tornado_handler
except ImportError: # pragma: no cover
get_tornado_handler = None
__all__ = ['Server', 'WSGIApp', 'Middleware', 'Client',
'AsyncServer', 'ASGIApp', 'get_tornado_handler', 'AsyncClient']

@@ -0,0 +1,680 @@
import asyncio
import signal
import ssl
import threading
try:
import aiohttp
except ImportError: # pragma: no cover
aiohttp = None
from . import base_client
from . import exceptions
from . import packet
from . import payload
async_signal_handler_set = False
# this set is used to keep references to background tasks to prevent them from
# being garbage collected mid-execution. Solution taken from
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
task_reference_holder = set()
def async_signal_handler():
"""SIGINT handler.
Disconnect all active async clients.
"""
async def _handler(): # pragma: no cover
for c in base_client.connected_clients[:]:
if c.is_asyncio_based():
await c.disconnect()
# cancel all running tasks
tasks = [task for task in asyncio.all_tasks() if task is not
asyncio.current_task()]
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
asyncio.get_running_loop().stop()
asyncio.ensure_future(_handler())
class AsyncClient(base_client.BaseClient):
"""An Engine.IO client for asyncio.
This class implements a fully compliant Engine.IO web client with support
for websocket and long-polling transports, compatible with the asyncio
framework on Python 3.5 or newer.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``. The default is
``False``. Note that fatal errors are logged even when
``logger`` is ``False``.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
:param request_timeout: A timeout in seconds for requests. The default is
5 seconds.
:param http_session: an initialized ``aiohttp.ClientSession`` object to be
used when sending requests to the server. Use it if
you need to add special client options such as proxy
servers, SSL certificates, custom CA bundle, etc.
:param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to
skip SSL certificate verification, allowing
connections to servers with self signed certificates.
The default is ``True``.
:param handle_sigint: Set to ``True`` to automatically handle disconnection
when the process is interrupted, or to ``False`` to
leave interrupt handling to the calling application.
Interrupt handling can only be enabled when the
client instance is created in the main thread.
:param websocket_extra_options: Dictionary containing additional keyword
arguments passed to
``aiohttp.ws_connect()``.
:param timestamp_requests: If ``True`` a timestamp is added to the query
string of Socket.IO requests as a cache-busting
measure. Set to ``False`` to disable.
"""
def is_asyncio_based(self):
return True
async def connect(self, url, headers=None, transports=None,
engineio_path='engine.io'):
"""Connect to an Engine.IO server.
:param url: The URL of the Engine.IO server. It can include custom
query string parameters if required by the server.
:param headers: A dictionary with custom headers to send with the
connection request.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. If not
given, the polling transport is connected first,
then an upgrade to websocket is attempted.
:param engineio_path: The endpoint where the Engine.IO server is
installed. The default value is appropriate for
most cases.
Note: this method is a coroutine.
Example usage::
eio = engineio.Client()
await eio.connect('http://localhost:5000')
"""
global async_signal_handler_set
if self.handle_sigint and not async_signal_handler_set and \
threading.current_thread() == threading.main_thread():
try:
asyncio.get_running_loop().add_signal_handler(
signal.SIGINT, async_signal_handler)
except NotImplementedError: # pragma: no cover
self.logger.warning('Signal handler is unsupported')
async_signal_handler_set = True
if self.state != 'disconnected':
raise ValueError('Client is not in a disconnected state')
valid_transports = ['polling', 'websocket']
if transports is not None:
if isinstance(transports, str):
transports = [transports]
transports = [transport for transport in transports
if transport in valid_transports]
if not transports:
raise ValueError('No valid transports provided')
self.transports = transports or valid_transports
self.queue = self.create_queue()
return await getattr(self, '_connect_' + self.transports[0])(
url, headers or {}, engineio_path)
async def wait(self):
"""Wait until the connection with the server ends.
Client applications can use this function to block the main thread
during the life of the connection.
Note: this method is a coroutine.
"""
if self.read_loop_task:
await self.read_loop_task
async def send(self, data):
"""Send a message to the server.
:param data: The data to send to the server. Data can be of type
``str``, ``bytes``, ``list`` or ``dict``. If a ``list``
or ``dict``, the data will be serialized as JSON.
Note: this method is a coroutine.
"""
await self._send_packet(packet.Packet(packet.MESSAGE, data=data))
async def disconnect(self, abort=False, reason=None):
"""Disconnect from the server.
:param abort: If set to ``True``, do not wait for background tasks
associated with the connection to end.
Note: this method is a coroutine.
"""
if self.state == 'connected':
await self._send_packet(packet.Packet(packet.CLOSE))
await self.queue.put(None)
self.state = 'disconnecting'
await self._trigger_event('disconnect',
reason or self.reason.CLIENT_DISCONNECT,
run_async=False)
if self.current_transport == 'websocket':
await self.ws.close()
if not abort:
await self.read_loop_task
self.state = 'disconnected'
try:
base_client.connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
await self._reset()
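# Illustrative usage sketch (not from the upstream source; assumes an
# Engine.IO server is reachable at the given URL): a minimal client session
# built from the coroutines above:
#
#   import asyncio
#   import engineio
#
#   async def main():
#       eio = engineio.AsyncClient()
#       eio.on('message', lambda data: print('received:', data))
#       await eio.connect('http://localhost:5000')
#       await eio.send('hello')
#       await eio.wait()  # block until the connection ends
#
#   asyncio.run(main())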
def start_background_task(self, target, *args, **kwargs):
"""Start a background task.
This is a utility function that applications can use to start a
background task.
:param target: the target function to execute.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
The return value is a ``asyncio.Task`` object.
"""
return asyncio.ensure_future(target(*args, **kwargs))
async def sleep(self, seconds=0):
"""Sleep for the requested amount of time.
Note: this method is a coroutine.
"""
return await asyncio.sleep(seconds)
def create_queue(self):
"""Create a queue object."""
q = asyncio.Queue()
q.Empty = asyncio.QueueEmpty
return q
def create_event(self):
"""Create an event object."""
return asyncio.Event()
async def _reset(self):
super()._reset()
if not self.external_http: # pragma: no cover
if self.http and not self.http.closed:
await self.http.close()
def __del__(self): # pragma: no cover
# try to close the aiohttp session if it is still open
if self.http and not self.http.closed:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
asyncio.ensure_future(self.http.close())
else:
loop.run_until_complete(self.http.close())
except:
pass
async def _connect_polling(self, url, headers, engineio_path):
"""Establish a long-polling connection to the Engine.IO server."""
if aiohttp is None: # pragma: no cover
self.logger.error('aiohttp not installed -- cannot make HTTP '
'requests!')
return
self.base_url = self._get_engineio_url(url, engineio_path, 'polling')
self.logger.info('Attempting polling connection to ' + self.base_url)
r = await self._send_request(
'GET', self.base_url + self._get_url_timestamp(), headers=headers,
timeout=self.request_timeout)
if r is None or isinstance(r, str):
await self._reset()
raise exceptions.ConnectionError(
r or 'Connection refused by the server')
if r.status < 200 or r.status >= 300:
await self._reset()
try:
arg = await r.json()
except aiohttp.ClientError:
arg = None
raise exceptions.ConnectionError(
'Unexpected status code {} in server response'.format(
r.status), arg)
try:
p = payload.Payload(encoded_payload=(await r.read()).decode(
'utf-8'))
except ValueError:
raise exceptions.ConnectionError(
'Unexpected response from server') from None
open_packet = p.packets[0]
if open_packet.packet_type != packet.OPEN:
raise exceptions.ConnectionError(
'OPEN packet not returned by server')
self.logger.info(
'Polling connection accepted with ' + str(open_packet.data))
self.sid = open_packet.data['sid']
self.upgrades = open_packet.data['upgrades']
self.ping_interval = int(open_packet.data['pingInterval']) / 1000.0
self.ping_timeout = int(open_packet.data['pingTimeout']) / 1000.0
self.current_transport = 'polling'
self.base_url += '&sid=' + self.sid
self.state = 'connected'
base_client.connected_clients.append(self)
await self._trigger_event('connect', run_async=False)
for pkt in p.packets[1:]:
await self._receive_packet(pkt)
if 'websocket' in self.upgrades and 'websocket' in self.transports:
# attempt to upgrade to websocket
if await self._connect_websocket(url, headers, engineio_path):
# upgrade to websocket succeeded, we're done here
return
self.write_loop_task = self.start_background_task(self._write_loop)
self.read_loop_task = self.start_background_task(
self._read_loop_polling)
async def _connect_websocket(self, url, headers, engineio_path):
"""Establish or upgrade to a WebSocket connection with the server."""
if aiohttp is None: # pragma: no cover
self.logger.error('aiohttp package not installed')
return False
websocket_url = self._get_engineio_url(url, engineio_path,
'websocket')
if self.sid:
self.logger.info(
'Attempting WebSocket upgrade to ' + websocket_url)
upgrade = True
websocket_url += '&sid=' + self.sid
else:
upgrade = False
self.base_url = websocket_url
self.logger.info(
'Attempting WebSocket connection to ' + websocket_url)
if self.http is None or self.http.closed: # pragma: no cover
self.http = aiohttp.ClientSession()
# extract any new cookies passed in a header so that they can also be
# sent to the WebSocket route
cookies = {}
for header, value in headers.items():
if header.lower() == 'cookie':
cookies = dict(
[cookie.split('=', 1) for cookie in value.split('; ')])
del headers[header]
break
self.http.cookie_jar.update_cookies(cookies)
extra_options = {'timeout': self.request_timeout}
if not self.ssl_verify:
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
extra_options['ssl'] = ssl_context
# combine internally generated options with the ones supplied by the
# caller. The caller's options take precedence.
headers.update(self.websocket_extra_options.pop('headers', {}))
extra_options['headers'] = headers
extra_options.update(self.websocket_extra_options)
try:
ws = await self.http.ws_connect(
websocket_url + self._get_url_timestamp(), **extra_options)
except (aiohttp.client_exceptions.WSServerHandshakeError,
aiohttp.client_exceptions.ServerConnectionError,
aiohttp.client_exceptions.ClientConnectionError):
if upgrade:
self.logger.warning(
'WebSocket upgrade failed: connection error')
return False
else:
raise exceptions.ConnectionError('Connection error')
if upgrade:
p = packet.Packet(packet.PING, data='probe').encode()
try:
await ws.send_str(p)
except Exception as e: # pragma: no cover
self.logger.warning(
'WebSocket upgrade failed: unexpected send exception: %s',
str(e))
return False
try:
p = (await ws.receive()).data
except Exception as e: # pragma: no cover
self.logger.warning(
'WebSocket upgrade failed: unexpected recv exception: %s',
str(e))
return False
pkt = packet.Packet(encoded_packet=p)
if pkt.packet_type != packet.PONG or pkt.data != 'probe':
self.logger.warning(
'WebSocket upgrade failed: no PONG packet')
return False
p = packet.Packet(packet.UPGRADE).encode()
try:
await ws.send_str(p)
except Exception as e: # pragma: no cover
self.logger.warning(
'WebSocket upgrade failed: unexpected send exception: %s',
str(e))
return False
self.current_transport = 'websocket'
self.logger.info('WebSocket upgrade was successful')
else:
try:
p = (await ws.receive()).data
except Exception as e: # pragma: no cover
raise exceptions.ConnectionError(
'Unexpected recv exception: ' + str(e))
open_packet = packet.Packet(encoded_packet=p)
if open_packet.packet_type != packet.OPEN:
raise exceptions.ConnectionError('no OPEN packet')
self.logger.info(
'WebSocket connection accepted with ' + str(open_packet.data))
self.sid = open_packet.data['sid']
self.upgrades = open_packet.data['upgrades']
self.ping_interval = int(open_packet.data['pingInterval']) / 1000.0
self.ping_timeout = int(open_packet.data['pingTimeout']) / 1000.0
self.current_transport = 'websocket'
self.state = 'connected'
base_client.connected_clients.append(self)
await self._trigger_event('connect', run_async=False)
self.ws = ws
self.write_loop_task = self.start_background_task(self._write_loop)
self.read_loop_task = self.start_background_task(
self._read_loop_websocket)
return True
async def _receive_packet(self, pkt):
"""Handle incoming packets from the server."""
packet_name = packet.packet_names[pkt.packet_type] \
if pkt.packet_type < len(packet.packet_names) else 'UNKNOWN'
self.logger.info(
'Received packet %s data %s', packet_name,
pkt.data if not isinstance(pkt.data, bytes) else '<binary>')
if pkt.packet_type == packet.MESSAGE:
await self._trigger_event('message', pkt.data, run_async=True)
elif pkt.packet_type == packet.PING:
await self._send_packet(packet.Packet(packet.PONG, pkt.data))
elif pkt.packet_type == packet.CLOSE:
await self.disconnect(abort=True,
reason=self.reason.SERVER_DISCONNECT)
elif pkt.packet_type == packet.NOOP:
pass
else:
self.logger.error('Received unexpected packet of type %s',
pkt.packet_type)
async def _send_packet(self, pkt):
"""Queue a packet to be sent to the server."""
if self.state != 'connected':
return
await self.queue.put(pkt)
self.logger.info(
'Sending packet %s data %s',
packet.packet_names[pkt.packet_type],
pkt.data if not isinstance(pkt.data, bytes) else '<binary>')
async def _send_request(
self, method, url, headers=None, body=None,
timeout=None): # pragma: no cover
if self.http is None or self.http.closed:
self.http = aiohttp.ClientSession()
http_method = getattr(self.http, method.lower())
try:
if not self.ssl_verify:
return await http_method(
url, headers=headers, data=body,
timeout=aiohttp.ClientTimeout(total=timeout), ssl=False)
else:
return await http_method(
url, headers=headers, data=body,
timeout=aiohttp.ClientTimeout(total=timeout))
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
self.logger.info('HTTP %s request to %s failed with error %s.',
method, url, exc)
return str(exc)
async def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
run_async = kwargs.pop('run_async', False)
ret = None
if event in self.handlers:
if asyncio.iscoroutinefunction(self.handlers[event]) is True:
if run_async:
task = self.start_background_task(self.handlers[event],
*args)
task_reference_holder.add(task)
task.add_done_callback(task_reference_holder.discard)
return task
else:
try:
try:
ret = await self.handlers[event](*args)
except TypeError:
if event == 'disconnect' and \
len(args) == 1: # pragma: no branch
# legacy disconnect events do not have a reason
# argument
return await self.handlers[event]()
else: # pragma: no cover
raise
except asyncio.CancelledError: # pragma: no cover
pass
except:
self.logger.exception(event + ' async handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
else:
if run_async:
async def async_handler():
return self.handlers[event](*args)
task = self.start_background_task(async_handler)
task_reference_holder.add(task)
task.add_done_callback(task_reference_holder.discard)
return task
else:
try:
try:
ret = self.handlers[event](*args)
except TypeError:
if event == 'disconnect' and \
len(args) == 1: # pragma: no branch
# legacy disconnect events do not have a reason
# argument
ret = self.handlers[event]()
else: # pragma: no cover
raise
except:
self.logger.exception(event + ' handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
return ret
async def _read_loop_polling(self):
"""Read packets by polling the Engine.IO server."""
while self.state == 'connected' and self.write_loop_task:
self.logger.info(
'Sending polling GET request to ' + self.base_url)
r = await self._send_request(
'GET', self.base_url + self._get_url_timestamp(),
timeout=max(self.ping_interval, self.ping_timeout) + 5)
if r is None or isinstance(r, str):
self.logger.warning(
r or 'Connection refused by the server, aborting')
await self.queue.put(None)
break
if r.status < 200 or r.status >= 300:
self.logger.warning('Unexpected status code %s in server '
'response, aborting', r.status)
await self.queue.put(None)
break
try:
p = payload.Payload(encoded_payload=(await r.read()).decode(
'utf-8'))
except ValueError:
self.logger.warning(
'Unexpected packet from server, aborting')
await self.queue.put(None)
break
for pkt in p.packets:
await self._receive_packet(pkt)
if self.write_loop_task: # pragma: no branch
self.logger.info('Waiting for write loop task to end')
await self.write_loop_task
if self.state == 'connected':
await self._trigger_event(
'disconnect', self.reason.TRANSPORT_ERROR, run_async=False)
try:
base_client.connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
await self._reset()
self.logger.info('Exiting read loop task')
async def _read_loop_websocket(self):
"""Read packets from the Engine.IO WebSocket connection."""
while self.state == 'connected':
p = None
try:
p = await asyncio.wait_for(
self.ws.receive(),
timeout=self.ping_interval + self.ping_timeout)
if not isinstance(p.data, (str, bytes)): # pragma: no cover
self.logger.warning(
'Server sent %s packet data %s, aborting',
'close' if p.type in [aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING]
else str(p.type), str(p.data))
await self.queue.put(None)
break # the connection is broken
p = p.data
except asyncio.TimeoutError:
self.logger.warning(
'Server has stopped communicating, aborting')
await self.queue.put(None)
break
except aiohttp.client_exceptions.ServerDisconnectedError:
self.logger.info(
'Read loop: WebSocket connection was closed, aborting')
await self.queue.put(None)
break
except Exception as e:
self.logger.info(
'Unexpected error receiving packet: "%s", aborting',
str(e))
await self.queue.put(None)
break
try:
pkt = packet.Packet(encoded_packet=p)
except Exception as e: # pragma: no cover
self.logger.info(
'Unexpected error decoding packet: "%s", aborting', str(e))
await self.queue.put(None)
break
await self._receive_packet(pkt)
if self.write_loop_task: # pragma: no branch
self.logger.info('Waiting for write loop task to end')
await self.write_loop_task
if self.state == 'connected':
await self._trigger_event(
'disconnect', self.reason.TRANSPORT_ERROR, run_async=False)
try:
base_client.connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
await self._reset()
self.logger.info('Exiting read loop task')
async def _write_loop(self):
"""This background task sends packages to the server as they are
pushed to the send queue.
"""
while self.state == 'connected':
# to simplify the timeout handling, use the maximum of the
# ping interval and ping timeout as timeout, with an extra 5
# seconds grace period
timeout = max(self.ping_interval, self.ping_timeout) + 5
packets = None
try:
packets = [await asyncio.wait_for(self.queue.get(), timeout)]
except (self.queue.Empty, asyncio.TimeoutError):
self.logger.error('packet queue is empty, aborting')
break
except asyncio.CancelledError: # pragma: no cover
break
if packets == [None]:
self.queue.task_done()
packets = []
else:
while True:
try:
packets.append(self.queue.get_nowait())
except self.queue.Empty:
break
if packets[-1] is None:
packets = packets[:-1]
self.queue.task_done()
break
if not packets:
# empty packet list returned -> connection closed
break
if self.current_transport == 'polling':
p = payload.Payload(packets=packets)
r = await self._send_request(
'POST', self.base_url, body=p.encode(),
headers={'Content-Type': 'text/plain'},
timeout=self.request_timeout)
for pkt in packets:
self.queue.task_done()
if r is None or isinstance(r, str):
self.logger.warning(
r or 'Connection refused by the server, aborting')
break
if r.status < 200 or r.status >= 300:
self.logger.warning('Unexpected status code %s in server '
'response, aborting', r.status)
self.write_loop_task = None
break
else:
# websocket
try:
for pkt in packets:
if pkt.binary:
await self.ws.send_bytes(pkt.encode())
else:
await self.ws.send_str(pkt.encode())
self.queue.task_done()
except (aiohttp.client_exceptions.ServerDisconnectedError,
BrokenPipeError, OSError):
self.logger.info(
'Write loop: WebSocket connection was closed, '
'aborting')
break
self.logger.info('Exiting write loop task')

@@ -0,0 +1,34 @@
import simple_websocket
class SimpleWebSocketWSGI: # pragma: no cover
"""
This wrapper class provides a threading WebSocket interface that is
compatible with eventlet's implementation.
"""
def __init__(self, handler, server, **kwargs):
self.app = handler
self.server_args = kwargs
def __call__(self, environ, start_response):
self.ws = simple_websocket.Server(environ, **self.server_args)
ret = self.app(self)
if self.ws.mode == 'gunicorn':
raise StopIteration()
return ret
def close(self):
if self.ws.connected:
self.ws.close()
def send(self, message):
try:
return self.ws.send(message)
except simple_websocket.ConnectionClosed:
raise OSError()
def wait(self):
try:
return self.ws.receive()
except simple_websocket.ConnectionClosed:
return None

@@ -0,0 +1,127 @@
import asyncio
import sys
from urllib.parse import urlsplit
from aiohttp.web import Response, WebSocketResponse
def create_route(app, engineio_server, engineio_endpoint):
"""This function sets up the engine.io endpoint as a route for the
application.
Note that both GET and POST requests must be hooked up on the engine.io
endpoint.
"""
app.router.add_get(engineio_endpoint, engineio_server.handle_request)
app.router.add_post(engineio_endpoint, engineio_server.handle_request)
app.router.add_route('OPTIONS', engineio_endpoint,
engineio_server.handle_request)
def translate_request(request):
"""This function takes the arguments passed to the request handler and
uses them to generate a WSGI compatible environ dictionary.
"""
message = request._message
payload = request._payload
uri_parts = urlsplit(message.path)
environ = {
'wsgi.input': payload,
'wsgi.errors': sys.stderr,
'wsgi.version': (1, 0),
'wsgi.async': True,
'wsgi.multithread': False,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'SERVER_SOFTWARE': 'aiohttp',
'REQUEST_METHOD': message.method,
'QUERY_STRING': uri_parts.query or '',
'RAW_URI': message.path,
'SERVER_PROTOCOL': 'HTTP/%s.%s' % message.version,
'REMOTE_ADDR': '127.0.0.1',
'REMOTE_PORT': '0',
'SERVER_NAME': 'aiohttp',
'SERVER_PORT': '0',
'aiohttp.request': request
}
for hdr_name, hdr_value in message.headers.items():
hdr_name = hdr_name.upper()
if hdr_name == 'CONTENT-TYPE':
environ['CONTENT_TYPE'] = hdr_value
continue
elif hdr_name == 'CONTENT-LENGTH':
environ['CONTENT_LENGTH'] = hdr_value
continue
key = 'HTTP_%s' % hdr_name.replace('-', '_')
if key in environ:
hdr_value = f'{environ[key]},{hdr_value}'
environ[key] = hdr_value
environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
path_info = uri_parts.path
environ['PATH_INFO'] = path_info
environ['SCRIPT_NAME'] = ''
return environ
def make_response(status, headers, payload, environ):
"""This function generates an appropriate response object for this async
mode.
"""
return Response(body=payload, status=int(status.split()[0]),
headers=headers)
class WebSocket: # pragma: no cover
"""
This wrapper class provides an aiohttp WebSocket interface that is
somewhat compatible with eventlet's implementation.
"""
def __init__(self, handler, server):
self.handler = handler
self._sock = None
async def __call__(self, environ):
request = environ['aiohttp.request']
self._sock = WebSocketResponse(max_msg_size=0)
await self._sock.prepare(request)
self.environ = environ
await self.handler(self)
return self._sock
async def close(self):
await self._sock.close()
async def send(self, message):
if isinstance(message, bytes):
f = self._sock.send_bytes
else:
f = self._sock.send_str
if asyncio.iscoroutinefunction(f):
await f(message)
else:
f(message)
async def wait(self):
msg = await self._sock.receive()
if not isinstance(msg.data, bytes) and \
not isinstance(msg.data, str):
raise OSError()
return msg.data
_async = {
'asyncio': True,
'create_route': create_route,
'translate_request': translate_request,
'make_response': make_response,
'websocket': WebSocket,
}

@@ -0,0 +1,291 @@
import os
import sys
import asyncio
from engineio.static_files import get_static_file
class ASGIApp:
"""ASGI application middleware for Engine.IO.
This middleware dispatches traffic to an Engine.IO application. It can
also serve a list of static files to the client, or forward unrelated
HTTP traffic to another ASGI application.
:param engineio_server: The Engine.IO server. Must be an instance of the
``engineio.AsyncServer`` class.
:param static_files: A dictionary with static file mapping rules. See the
documentation for details on this argument.
:param other_asgi_app: A separate ASGI app that receives all other traffic.
:param engineio_path: The endpoint where the Engine.IO application should
be installed. The default value is appropriate for
most cases. With a value of ``None``, all incoming
traffic is directed to the Engine.IO server, with the
assumption that routing, if necessary, is handled by
a different layer. When this option is set to
``None``, ``static_files`` and ``other_asgi_app`` are
ignored.
:param on_startup: function to be called on application startup; can be
coroutine
:param on_shutdown: function to be called on application shutdown; can be
coroutine
Example usage::
import engineio
import uvicorn
eio = engineio.AsyncServer()
app = engineio.ASGIApp(eio, static_files={
'/': {'content_type': 'text/html', 'filename': 'index.html'},
'/index.html': {'content_type': 'text/html',
'filename': 'index.html'},
})
        uvicorn.run(app, host='127.0.0.1', port=5000)
"""
def __init__(self, engineio_server, other_asgi_app=None,
static_files=None, engineio_path='engine.io',
on_startup=None, on_shutdown=None):
self.engineio_server = engineio_server
self.other_asgi_app = other_asgi_app
self.engineio_path = engineio_path
if self.engineio_path is not None:
if not self.engineio_path.startswith('/'):
self.engineio_path = '/' + self.engineio_path
if not self.engineio_path.endswith('/'):
self.engineio_path += '/'
self.static_files = static_files or {}
self.on_startup = on_startup
self.on_shutdown = on_shutdown
async def __call__(self, scope, receive, send):
if scope['type'] == 'lifespan':
await self.lifespan(scope, receive, send)
elif scope['type'] in ['http', 'websocket'] and (
self.engineio_path is None
or self._ensure_trailing_slash(scope['path']).startswith(
self.engineio_path)):
await self.engineio_server.handle_request(scope, receive, send)
else:
static_file = get_static_file(scope['path'], self.static_files) \
if scope['type'] == 'http' and self.static_files else None
if static_file and os.path.exists(static_file['filename']):
await self.serve_static_file(static_file, receive, send)
elif self.other_asgi_app is not None:
await self.other_asgi_app(scope, receive, send)
else:
await self.not_found(receive, send)
async def serve_static_file(self, static_file, receive,
send): # pragma: no cover
event = await receive()
if event['type'] == 'http.request':
with open(static_file['filename'], 'rb') as f:
payload = f.read()
await send({'type': 'http.response.start',
'status': 200,
'headers': [(b'Content-Type', static_file[
'content_type'].encode('utf-8'))]})
await send({'type': 'http.response.body',
'body': payload})
async def lifespan(self, scope, receive, send):
if self.other_asgi_app is not None and self.on_startup is None and \
self.on_shutdown is None:
# let the other ASGI app handle lifespan events
await self.other_asgi_app(scope, receive, send)
return
while True:
event = await receive()
if event['type'] == 'lifespan.startup':
if self.on_startup:
try:
await self.on_startup() \
if asyncio.iscoroutinefunction(self.on_startup) \
else self.on_startup()
except:
await send({'type': 'lifespan.startup.failed'})
return
await send({'type': 'lifespan.startup.complete'})
elif event['type'] == 'lifespan.shutdown':
if self.on_shutdown:
try:
await self.on_shutdown() \
if asyncio.iscoroutinefunction(self.on_shutdown) \
else self.on_shutdown()
except:
await send({'type': 'lifespan.shutdown.failed'})
return
await send({'type': 'lifespan.shutdown.complete'})
return
async def not_found(self, receive, send):
"""Return a 404 Not Found error to the client."""
await send({'type': 'http.response.start',
'status': 404,
'headers': [(b'Content-Type', b'text/plain')]})
await send({'type': 'http.response.body',
'body': b'Not Found'})
def _ensure_trailing_slash(self, path):
if not path.endswith('/'):
path += '/'
return path
async def translate_request(scope, receive, send):
class AwaitablePayload: # pragma: no cover
def __init__(self, payload):
self.payload = payload or b''
async def read(self, length=None):
if length is None:
r = self.payload
self.payload = b''
else:
r = self.payload[:length]
self.payload = self.payload[length:]
return r
event = await receive()
payload = b''
if event['type'] == 'http.request':
payload += event.get('body') or b''
while event.get('more_body'):
event = await receive()
if event['type'] == 'http.request':
payload += event.get('body') or b''
elif event['type'] == 'websocket.connect':
pass
else:
return {}
raw_uri = scope['path']
query_string = ''
if 'query_string' in scope and scope['query_string']:
try:
query_string = scope['query_string'].decode('utf-8')
except UnicodeDecodeError:
pass
else:
raw_uri += '?' + query_string
environ = {
'wsgi.input': AwaitablePayload(payload),
'wsgi.errors': sys.stderr,
'wsgi.version': (1, 0),
'wsgi.async': True,
'wsgi.multithread': False,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'SERVER_SOFTWARE': 'asgi',
'REQUEST_METHOD': scope.get('method', 'GET'),
'PATH_INFO': scope['path'],
'QUERY_STRING': query_string,
'RAW_URI': raw_uri,
'SCRIPT_NAME': '',
'SERVER_PROTOCOL': 'HTTP/1.1',
'REMOTE_ADDR': '127.0.0.1',
'REMOTE_PORT': '0',
'SERVER_NAME': 'asgi',
'SERVER_PORT': '0',
'asgi.receive': receive,
'asgi.send': send,
'asgi.scope': scope,
}
for hdr_name, hdr_value in scope['headers']:
try:
hdr_name = hdr_name.upper().decode('utf-8')
hdr_value = hdr_value.decode('utf-8')
except UnicodeDecodeError:
# skip header if it cannot be decoded
continue
if hdr_name == 'CONTENT-TYPE':
environ['CONTENT_TYPE'] = hdr_value
continue
elif hdr_name == 'CONTENT-LENGTH':
environ['CONTENT_LENGTH'] = hdr_value
continue
key = 'HTTP_%s' % hdr_name.replace('-', '_')
if key in environ:
hdr_value = f'{environ[key]},{hdr_value}'
environ[key] = hdr_value
environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
return environ
async def make_response(status, headers, payload, environ):
headers = [(h[0].encode('utf-8'), h[1].encode('utf-8')) for h in headers]
if environ['asgi.scope']['type'] == 'websocket':
if status.startswith('200 '):
await environ['asgi.send']({'type': 'websocket.accept',
'headers': headers})
else:
if payload:
reason = payload.decode('utf-8') \
if isinstance(payload, bytes) else str(payload)
await environ['asgi.send']({'type': 'websocket.close',
'reason': reason})
else:
await environ['asgi.send']({'type': 'websocket.close'})
return
await environ['asgi.send']({'type': 'http.response.start',
'status': int(status.split(' ')[0]),
'headers': headers})
await environ['asgi.send']({'type': 'http.response.body',
'body': payload})
class WebSocket: # pragma: no cover
"""
This wrapper class provides an asgi WebSocket interface that is
somewhat compatible with eventlet's implementation.
"""
def __init__(self, handler, server):
self.handler = handler
self.asgi_receive = None
self.asgi_send = None
async def __call__(self, environ):
self.asgi_receive = environ['asgi.receive']
self.asgi_send = environ['asgi.send']
await self.asgi_send({'type': 'websocket.accept'})
await self.handler(self)
return '' # send nothing as response
async def close(self):
try:
await self.asgi_send({'type': 'websocket.close'})
except Exception:
            # if the socket is already closed we don't care
pass
async def send(self, message):
msg_bytes = None
msg_text = None
if isinstance(message, bytes):
msg_bytes = message
else:
msg_text = message
await self.asgi_send({'type': 'websocket.send',
'bytes': msg_bytes,
'text': msg_text})
async def wait(self):
event = await self.asgi_receive()
if event['type'] != 'websocket.receive':
raise OSError()
return event.get('bytes') or event.get('text')
_async = {
'asyncio': True,
'translate_request': translate_request,
'make_response': make_response,
'websocket': WebSocket,
}
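As a usage note, the `on_startup`/`on_shutdown` hooks documented above can be wired like this; a minimal sketch assuming uvicorn as the ASGI server, with illustrative handler names:

```python
import engineio
import uvicorn

eio = engineio.AsyncServer(async_mode='asgi')

# the lifespan handler above awaits these only when they are coroutine
# functions; plain functions are called synchronously
async def startup():
    print('application starting')

async def shutdown():
    print('application stopping')

app = engineio.ASGIApp(eio, on_startup=startup, on_shutdown=shutdown)

if __name__ == '__main__':
    uvicorn.run(app, host='127.0.0.1', port=5000)
```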

View File

@ -0,0 +1,52 @@
from eventlet.green.threading import Event
from eventlet import queue, sleep, spawn
from eventlet.websocket import WebSocketWSGI as _WebSocketWSGI
class EventletThread: # pragma: no cover
"""Thread class that uses eventlet green threads.
Eventlet's own Thread class has a strange bug that causes _DummyThread
objects to be created and leaked, since they are never garbage collected.
"""
def __init__(self, target, args=None, kwargs=None):
self.target = target
self.args = args or ()
self.kwargs = kwargs or {}
self.g = None
def start(self):
self.g = spawn(self.target, *self.args, **self.kwargs)
def join(self):
if self.g:
return self.g.wait()
class WebSocketWSGI(_WebSocketWSGI): # pragma: no cover
def __init__(self, handler, server):
try:
super().__init__(
handler, max_frame_length=int(server.max_http_buffer_size))
except TypeError: # pragma: no cover
# older versions of eventlet do not support a max frame size
super().__init__(handler)
self._sock = None
def __call__(self, environ, start_response):
if 'eventlet.input' not in environ:
raise RuntimeError('You need to use the eventlet server. '
'See the Deployment section of the '
'documentation for more information.')
self._sock = environ['eventlet.input'].get_socket()
return super().__call__(environ, start_response)
_async = {
'thread': EventletThread,
'queue': queue.Queue,
'queue_empty': queue.Empty,
'event': Event,
'websocket': WebSocketWSGI,
'sleep': sleep,
}
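A minimal deployment sketch for this driver (port and module layout assumed): the engineio WSGI app must be served by eventlet's own WSGI server, which is what the `'eventlet.input'` check in `WebSocketWSGI.__call__` enforces.

```python
import engineio
from eventlet import listen, wsgi

eio = engineio.Server(async_mode='eventlet')
app = engineio.WSGIApp(eio)

if __name__ == '__main__':
    # eventlet's WSGI server provides the 'eventlet.input' key
    # that this driver requires
    wsgi.server(listen(('', 5000)), app)
```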

View File

@ -0,0 +1,83 @@
import gevent
from gevent import queue
from gevent.event import Event
try:
# use gevent-websocket if installed
import geventwebsocket # noqa
SimpleWebSocketWSGI = None
except ImportError: # pragma: no cover
# fallback to simple_websocket when gevent-websocket is not installed
from engineio.async_drivers._websocket_wsgi import SimpleWebSocketWSGI
class Thread(gevent.Greenlet): # pragma: no cover
"""
    This wrapper class provides a gevent Greenlet interface that is compatible
with the standard library's Thread class.
"""
def __init__(self, target, args=[], kwargs={}):
super().__init__(target, *args, **kwargs)
def _run(self):
return self.run()
if SimpleWebSocketWSGI is not None:
class WebSocketWSGI(SimpleWebSocketWSGI): # pragma: no cover
"""
This wrapper class provides a gevent WebSocket interface that is
compatible with eventlet's implementation, using the simple-websocket
package.
"""
def __init__(self, handler, server):
# to avoid the requirement that the standard library is
# monkey-patched, here we pass the gevent versions of the
# concurrency and networking classes required by simple-websocket
import gevent.event
import gevent.selectors
super().__init__(handler, server,
thread_class=Thread,
event_class=gevent.event.Event,
selector_class=gevent.selectors.DefaultSelector)
else:
class WebSocketWSGI: # pragma: no cover
"""
This wrapper class provides a gevent WebSocket interface that is
compatible with eventlet's implementation, using the gevent-websocket
package.
"""
def __init__(self, handler, server):
self.app = handler
def __call__(self, environ, start_response):
if 'wsgi.websocket' not in environ:
raise RuntimeError('The gevent-websocket server is not '
'configured appropriately. '
'See the Deployment section of the '
'documentation for more information.')
self._sock = environ['wsgi.websocket']
self.environ = environ
self.version = self._sock.version
self.path = self._sock.path
self.origin = self._sock.origin
self.protocol = self._sock.protocol
return self.app(self)
def close(self):
return self._sock.close()
def send(self, message):
return self._sock.send(message)
def wait(self):
return self._sock.receive()
_async = {
'thread': Thread,
'queue': queue.JoinableQueue,
'queue_empty': queue.Empty,
'event': Event,
'websocket': WebSocketWSGI,
'sleep': gevent.sleep,
}
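A minimal deployment sketch under gevent's pywsgi server (port assumed). Which WebSocket path is taken depends on whether gevent-websocket is installed, mirroring the import fallback at the top of this file.

```python
import engineio
from gevent import pywsgi

eio = engineio.Server(async_mode='gevent')
app = engineio.WSGIApp(eio)

if __name__ == '__main__':
    # with gevent-websocket installed, pass
    # handler_class=geventwebsocket.handler.WebSocketHandler here;
    # otherwise the simple-websocket fallback above is used
    pywsgi.WSGIServer(('', 5000), app).serve_forever()
```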

View File

@ -0,0 +1,168 @@
import gevent
from gevent import queue
from gevent.event import Event
from gevent import selectors
import uwsgi
_websocket_available = hasattr(uwsgi, 'websocket_handshake')
class Thread(gevent.Greenlet): # pragma: no cover
"""
    This wrapper class provides a gevent Greenlet interface that is compatible
with the standard library's Thread class.
"""
def __init__(self, target, args=[], kwargs={}):
super().__init__(target, *args, **kwargs)
def _run(self):
return self.run()
class uWSGIWebSocket: # pragma: no cover
"""
This wrapper class provides a uWSGI WebSocket interface that is
compatible with eventlet's implementation.
"""
def __init__(self, handler, server):
self.app = handler
self._sock = None
self.received_messages = []
def __call__(self, environ, start_response):
self._sock = uwsgi.connection_fd()
self.environ = environ
uwsgi.websocket_handshake()
self._req_ctx = None
if hasattr(uwsgi, 'request_context'):
            # uWSGI >= 2.1.x with support for API access across greenlets
self._req_ctx = uwsgi.request_context()
else:
# use event and queue for sending messages
self._event = Event()
self._send_queue = queue.Queue()
# spawn a select greenlet
def select_greenlet_runner(fd, event):
"""Sets event when data becomes available to read on fd."""
sel = selectors.DefaultSelector()
sel.register(fd, selectors.EVENT_READ)
try:
while True:
sel.select()
event.set()
except gevent.GreenletExit:
sel.unregister(fd)
self._select_greenlet = gevent.spawn(
select_greenlet_runner,
self._sock,
self._event)
self.app(self)
uwsgi.disconnect()
return '' # send nothing as response
def close(self):
"""Disconnects uWSGI from the client."""
if self._req_ctx is None:
# better kill it here in case wait() is not called again
self._select_greenlet.kill()
self._event.set()
def _send(self, msg):
"""Transmits message either in binary or UTF-8 text mode,
depending on its type."""
if isinstance(msg, bytes):
method = uwsgi.websocket_send_binary
else:
method = uwsgi.websocket_send
if self._req_ctx is not None:
method(msg, request_context=self._req_ctx)
else:
method(msg)
def _decode_received(self, msg):
"""Returns either bytes or str, depending on message type."""
if not isinstance(msg, bytes):
# already decoded - do nothing
return msg
# only decode from utf-8 if message is not binary data
type = ord(msg[0:1])
if type >= 48: # no binary
return msg.decode('utf-8')
# binary message, don't try to decode
return msg
def send(self, msg):
"""Queues a message for sending. Real transmission is done in
wait method.
Sends directly if uWSGI version is new enough."""
if self._req_ctx is not None:
self._send(msg)
else:
self._send_queue.put(msg)
self._event.set()
def wait(self):
"""Waits and returns received messages.
If running in compatibility mode for older uWSGI versions,
it also sends messages that have been queued by send().
        A return value of None means that the connection was closed.
This must be called repeatedly. For uWSGI < 2.1.x it must
be called from the main greenlet."""
while True:
if self._req_ctx is not None:
try:
msg = uwsgi.websocket_recv(request_context=self._req_ctx)
except OSError: # connection closed
self.close()
return None
return self._decode_received(msg)
else:
if self.received_messages:
return self.received_messages.pop(0)
# we wake up at least every 3 seconds to let uWSGI
# do its ping/ponging
event_set = self._event.wait(timeout=3)
if event_set:
self._event.clear()
# maybe there is something to send
msgs = []
while True:
try:
msgs.append(self._send_queue.get(block=False))
except gevent.queue.Empty:
break
for msg in msgs:
try:
self._send(msg)
except OSError:
self.close()
return None
# maybe there is something to receive, if not, at least
# ensure uWSGI does its ping/ponging
while True:
try:
msg = uwsgi.websocket_recv_nb()
except OSError: # connection closed
self.close()
return None
if msg: # message available
self.received_messages.append(
self._decode_received(msg))
else:
break
if self.received_messages:
return self.received_messages.pop(0)
_async = {
'thread': Thread,
'queue': queue.JoinableQueue,
'queue_empty': queue.Empty,
'event': Event,
'websocket': uWSGIWebSocket if _websocket_available else None,
'sleep': gevent.sleep,
}
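A sketch of the app module this driver expects; uWSGI supplies the server itself and is launched from the shell (file name, port, and worker count are illustrative):

```python
# app.py -- served by uWSGI with the gevent loop enabled, e.g.:
#   uwsgi --http :5000 --gevent 1000 --http-websockets \
#         --wsgi-file app.py --callable app
import engineio

eio = engineio.Server(async_mode='gevent_uwsgi')
app = engineio.WSGIApp(eio)
```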

View File

@ -0,0 +1,148 @@
import sys
from urllib.parse import urlsplit
try: # pragma: no cover
from sanic.response import HTTPResponse
try:
from sanic.server.protocols.websocket_protocol import WebSocketProtocol
except ImportError:
from sanic.websocket import WebSocketProtocol
except ImportError:
HTTPResponse = None
WebSocketProtocol = None
def create_route(app, engineio_server, engineio_endpoint): # pragma: no cover
"""This function sets up the engine.io endpoint as a route for the
application.
Note that both GET and POST requests must be hooked up on the engine.io
endpoint.
"""
app.add_route(engineio_server.handle_request, engineio_endpoint,
methods=['GET', 'POST', 'OPTIONS'])
try:
app.enable_websocket()
except AttributeError:
# ignore, this version does not support websocket
pass
def translate_request(request): # pragma: no cover
"""This function takes the arguments passed to the request handler and
uses them to generate a WSGI compatible environ dictionary.
"""
class AwaitablePayload:
def __init__(self, payload):
self.payload = payload or b''
async def read(self, length=None):
if length is None:
r = self.payload
self.payload = b''
else:
r = self.payload[:length]
self.payload = self.payload[length:]
return r
uri_parts = urlsplit(request.url)
environ = {
'wsgi.input': AwaitablePayload(request.body),
'wsgi.errors': sys.stderr,
'wsgi.version': (1, 0),
'wsgi.async': True,
'wsgi.multithread': False,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'SERVER_SOFTWARE': 'sanic',
'REQUEST_METHOD': request.method,
'QUERY_STRING': uri_parts.query or '',
'RAW_URI': request.url,
'SERVER_PROTOCOL': 'HTTP/' + request.version,
'REMOTE_ADDR': '127.0.0.1',
'REMOTE_PORT': '0',
'SERVER_NAME': 'sanic',
'SERVER_PORT': '0',
'sanic.request': request
}
for hdr_name, hdr_value in request.headers.items():
hdr_name = hdr_name.upper()
if hdr_name == 'CONTENT-TYPE':
environ['CONTENT_TYPE'] = hdr_value
continue
elif hdr_name == 'CONTENT-LENGTH':
environ['CONTENT_LENGTH'] = hdr_value
continue
key = 'HTTP_%s' % hdr_name.replace('-', '_')
if key in environ:
hdr_value = f'{environ[key]},{hdr_value}'
environ[key] = hdr_value
environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
path_info = uri_parts.path
environ['PATH_INFO'] = path_info
environ['SCRIPT_NAME'] = ''
return environ
def make_response(status, headers, payload, environ): # pragma: no cover
"""This function generates an appropriate response object for this async
mode.
"""
headers_dict = {}
content_type = None
for h in headers:
if h[0].lower() == 'content-type':
content_type = h[1]
else:
headers_dict[h[0]] = h[1]
return HTTPResponse(body=payload, content_type=content_type,
status=int(status.split()[0]), headers=headers_dict)
class WebSocket: # pragma: no cover
"""
This wrapper class provides a sanic WebSocket interface that is
somewhat compatible with eventlet's implementation.
"""
def __init__(self, handler, server):
self.handler = handler
self.server = server
self._sock = None
async def __call__(self, environ):
request = environ['sanic.request']
protocol = request.transport.get_protocol()
self._sock = await protocol.websocket_handshake(request)
self.environ = environ
await self.handler(self)
return self.server._ok()
async def close(self):
await self._sock.close()
async def send(self, message):
await self._sock.send(message)
async def wait(self):
data = await self._sock.recv()
if not isinstance(data, bytes) and \
not isinstance(data, str):
raise OSError()
return data
_async = {
'asyncio': True,
'create_route': create_route,
'translate_request': translate_request,
'make_response': make_response,
'websocket': WebSocket if WebSocketProtocol else None,
}
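A minimal sketch of attaching the server to a Sanic application (app name and port are illustrative assumptions; Sanic support depends on the installed Sanic version):

```python
import engineio
from sanic import Sanic

eio = engineio.AsyncServer(async_mode='sanic')
app = Sanic('engineio_app')  # recent Sanic versions require an app name
eio.attach(app)              # registers the route via create_route() above

if __name__ == '__main__':
    app.run(host='127.0.0.1', port=5000)
```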

View File

@ -0,0 +1,19 @@
import queue
import threading
import time
from engineio.async_drivers._websocket_wsgi import SimpleWebSocketWSGI
class DaemonThread(threading.Thread): # pragma: no cover
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, daemon=True)
_async = {
'thread': DaemonThread,
'queue': queue.Queue,
'queue_empty': queue.Empty,
'event': threading.Event,
'websocket': SimpleWebSocketWSGI,
'sleep': time.sleep,
}
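This driver needs nothing beyond the standard library plus simple-websocket, so any multithreaded WSGI server works; a sketch assuming Werkzeug's development server:

```python
import engineio
from werkzeug.serving import run_simple

eio = engineio.Server(async_mode='threading')
app = engineio.WSGIApp(eio)

if __name__ == '__main__':
    # threaded=True so long-polling requests do not block each other
    run_simple('127.0.0.1', 5000, app, threaded=True)
```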

View File

@ -0,0 +1,182 @@
import asyncio
import sys
from urllib.parse import urlsplit
from .. import exceptions
import tornado.web
import tornado.websocket
def get_tornado_handler(engineio_server):
class Handler(tornado.websocket.WebSocketHandler): # pragma: no cover
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if isinstance(engineio_server.cors_allowed_origins, str):
if engineio_server.cors_allowed_origins == '*':
self.allowed_origins = None
else:
self.allowed_origins = [
engineio_server.cors_allowed_origins]
else:
self.allowed_origins = engineio_server.cors_allowed_origins
self.receive_queue = asyncio.Queue()
async def get(self, *args, **kwargs):
if self.request.headers.get('Upgrade', '').lower() == 'websocket':
ret = super().get(*args, **kwargs)
if asyncio.iscoroutine(ret):
await ret
else:
await engineio_server.handle_request(self)
async def open(self, *args, **kwargs):
# this is the handler for the websocket request
asyncio.ensure_future(engineio_server.handle_request(self))
async def post(self, *args, **kwargs):
await engineio_server.handle_request(self)
async def options(self, *args, **kwargs):
await engineio_server.handle_request(self)
async def on_message(self, message):
await self.receive_queue.put(message)
async def get_next_message(self):
return await self.receive_queue.get()
def on_close(self):
self.receive_queue.put_nowait(None)
def check_origin(self, origin):
if self.allowed_origins is None or origin in self.allowed_origins:
return True
return super().check_origin(origin)
def get_compression_options(self):
# enable compression
return {}
return Handler
def translate_request(handler):
"""This function takes the arguments passed to the request handler and
uses them to generate a WSGI compatible environ dictionary.
"""
class AwaitablePayload:
def __init__(self, payload):
self.payload = payload or b''
async def read(self, length=None):
if length is None:
r = self.payload
self.payload = b''
else:
r = self.payload[:length]
self.payload = self.payload[length:]
return r
payload = handler.request.body
uri_parts = urlsplit(handler.request.path)
full_uri = handler.request.path
if handler.request.query: # pragma: no cover
full_uri += '?' + handler.request.query
environ = {
'wsgi.input': AwaitablePayload(payload),
'wsgi.errors': sys.stderr,
'wsgi.version': (1, 0),
'wsgi.async': True,
'wsgi.multithread': False,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
        'SERVER_SOFTWARE': 'tornado',
'REQUEST_METHOD': handler.request.method,
'QUERY_STRING': handler.request.query or '',
'RAW_URI': full_uri,
'SERVER_PROTOCOL': 'HTTP/%s' % handler.request.version,
'REMOTE_ADDR': '127.0.0.1',
'REMOTE_PORT': '0',
        'SERVER_NAME': 'tornado',
'SERVER_PORT': '0',
'tornado.handler': handler
}
for hdr_name, hdr_value in handler.request.headers.items():
hdr_name = hdr_name.upper()
if hdr_name == 'CONTENT-TYPE':
environ['CONTENT_TYPE'] = hdr_value
continue
elif hdr_name == 'CONTENT-LENGTH':
environ['CONTENT_LENGTH'] = hdr_value
continue
key = 'HTTP_%s' % hdr_name.replace('-', '_')
environ[key] = hdr_value
environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
path_info = uri_parts.path
environ['PATH_INFO'] = path_info
environ['SCRIPT_NAME'] = ''
return environ
def make_response(status, headers, payload, environ):
"""This function generates an appropriate response object for this async
mode.
"""
tornado_handler = environ['tornado.handler']
try:
tornado_handler.set_status(int(status.split()[0]))
except RuntimeError: # pragma: no cover
# for websocket connections Tornado does not accept a response, since
# it already emitted the 101 status code
return
for header, value in headers:
tornado_handler.set_header(header, value)
tornado_handler.write(payload)
tornado_handler.finish()
class WebSocket: # pragma: no cover
"""
This wrapper class provides a tornado WebSocket interface that is
somewhat compatible with eventlet's implementation.
"""
def __init__(self, handler, server):
self.handler = handler
self.tornado_handler = None
async def __call__(self, environ):
self.tornado_handler = environ['tornado.handler']
self.environ = environ
await self.handler(self)
async def close(self):
self.tornado_handler.close()
async def send(self, message):
try:
self.tornado_handler.write_message(
message, binary=isinstance(message, bytes))
except tornado.websocket.WebSocketClosedError:
raise exceptions.EngineIOError()
async def wait(self):
msg = await self.tornado_handler.get_next_message()
if not isinstance(msg, bytes) and \
not isinstance(msg, str):
raise OSError()
return msg
_async = {
'asyncio': True,
'translate_request': translate_request,
'make_response': make_response,
'websocket': WebSocket,
}
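A minimal sketch of mounting the handler returned by `get_tornado_handler()` in a Tornado application (URL pattern follows the default engine.io path; port assumed):

```python
import engineio
import tornado.ioloop
import tornado.web

eio = engineio.AsyncServer(async_mode='tornado')
app = tornado.web.Application([
    # one handler serves both the long-polling and websocket halves
    (r'/engine.io/', engineio.get_tornado_handler(eio)),
])

if __name__ == '__main__':
    app.listen(5000)
    tornado.ioloop.IOLoop.current().start()
```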

View File

@ -0,0 +1,611 @@
import asyncio
import urllib
from . import base_server
from . import exceptions
from . import packet
from . import async_socket
# this set is used to keep references to background tasks to prevent them from
# being garbage collected mid-execution. Solution taken from
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
task_reference_holder = set()
class AsyncServer(base_server.BaseServer):
"""An Engine.IO server for asyncio.
This class implements a fully compliant Engine.IO web server with support
for websocket and long-polling transports, compatible with the asyncio
framework on Python 3.5 or newer.
:param async_mode: The asynchronous model to use. See the Deployment
section in the documentation for a description of the
available options. Valid async modes are "aiohttp",
"sanic", "tornado" and "asgi". If this argument is not
given, "aiohttp" is tried first, followed by "sanic",
"tornado", and finally "asgi". The first async mode that
has all its dependencies installed is the one that is
chosen.
:param ping_interval: The interval in seconds at which the server pings
the client. The default is 25 seconds. For advanced
control, a two element tuple can be given, where
the first number is the ping interval and the second
is a grace period added by the server.
:param ping_timeout: The time in seconds that the client waits for the
server to respond before disconnecting. The default
is 20 seconds.
:param max_http_buffer_size: The maximum size that is accepted for incoming
messages. The default is 1,000,000 bytes. In
spite of its name, the value set in this
argument is enforced for HTTP long-polling and
WebSocket connections.
:param allow_upgrades: Whether to allow transport upgrades or not.
    :param http_compression: Whether to compress packets when using the
polling transport.
:param compression_threshold: Only compress messages when their byte size
is greater than this value.
:param cookie: If set to a string, it is the name of the HTTP cookie the
server sends back tot he client containing the client
session id. If set to a dictionary, the ``'name'`` key
contains the cookie name and other keys define cookie
attributes, where the value of each attribute can be a
string, a callable with no arguments, or a boolean. If set
to ``None`` (the default), a cookie is not sent to the
client.
:param cors_allowed_origins: Origin or list of origins that are allowed to
connect to this server. Only the same origin
is allowed by default. Set this argument to
``'*'`` to allow all origins, or to ``[]`` to
disable CORS handling.
:param cors_credentials: Whether credentials (cookies, authentication) are
allowed in requests to this server.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``. Note that fatal
errors are logged even when ``logger`` is ``False``.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
:param async_handlers: If set to ``True``, run message event handlers in
non-blocking threads. To run handlers synchronously,
set to ``False``. The default is ``True``.
:param monitor_clients: If set to ``True``, a background task will ensure
inactive clients are closed. Set to ``False`` to
disable the monitoring task (not recommended). The
default is ``True``.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. Defaults to
``['polling', 'websocket']``.
:param kwargs: Reserved for future extensions, any additional parameters
given as keyword arguments will be silently ignored.
"""
def is_asyncio_based(self):
return True
def async_modes(self):
return ['aiohttp', 'sanic', 'tornado', 'asgi']
def attach(self, app, engineio_path='engine.io'):
"""Attach the Engine.IO server to an application."""
engineio_path = engineio_path.strip('/')
self._async['create_route'](app, self, f'/{engineio_path}/')
async def send(self, sid, data):
"""Send a message to a client.
:param sid: The session id of the recipient client.
:param data: The data to send to the client. Data can be of type
``str``, ``bytes``, ``list`` or ``dict``. If a ``list``
or ``dict``, the data will be serialized as JSON.
Note: this method is a coroutine.
"""
await self.send_packet(sid, packet.Packet(packet.MESSAGE, data=data))
async def send_packet(self, sid, pkt):
"""Send a raw packet to a client.
:param sid: The session id of the recipient client.
:param pkt: The packet to send to the client.
Note: this method is a coroutine.
"""
try:
socket = self._get_socket(sid)
except KeyError:
# the socket is not available
self.logger.warning('Cannot send to sid %s', sid)
return
await socket.send(pkt)
async def get_session(self, sid):
"""Return the user session for a client.
:param sid: The session id of the client.
The return value is a dictionary. Modifications made to this
dictionary are not guaranteed to be preserved. If you want to modify
the user session, use the ``session`` context manager instead.
"""
socket = self._get_socket(sid)
return socket.session
async def save_session(self, sid, session):
"""Store the user session for a client.
:param sid: The session id of the client.
:param session: The session dictionary.
"""
socket = self._get_socket(sid)
socket.session = session
def session(self, sid):
"""Return the user session for a client with context manager syntax.
:param sid: The session id of the client.
This is a context manager that returns the user session dictionary for
the client. Any changes that are made to this dictionary inside the
context manager block are saved back to the session. Example usage::
            @eio.on('connect')
            async def on_connect(sid, environ):
                username = authenticate_user(environ)
                if not username:
                    return False
                async with eio.session(sid) as session:
                    session['username'] = username
            @eio.on('message')
            async def on_message(sid, msg):
                async with eio.session(sid) as session:
                    print('received message from ', session['username'])
"""
class _session_context_manager:
def __init__(self, server, sid):
self.server = server
self.sid = sid
self.session = None
async def __aenter__(self):
self.session = await self.server.get_session(sid)
return self.session
async def __aexit__(self, *args):
await self.server.save_session(sid, self.session)
return _session_context_manager(self, sid)
async def disconnect(self, sid=None):
"""Disconnect a client.
:param sid: The session id of the client to close. If this parameter
is not given, then all clients are closed.
Note: this method is a coroutine.
"""
if sid is not None:
try:
socket = self._get_socket(sid)
except KeyError: # pragma: no cover
# the socket was already closed or gone
pass
else:
await socket.close(reason=self.reason.SERVER_DISCONNECT)
if sid in self.sockets: # pragma: no cover
del self.sockets[sid]
else:
await asyncio.wait([
asyncio.create_task(client.close(
reason=self.reason.SERVER_DISCONNECT))
for client in self.sockets.values()
])
self.sockets = {}
async def handle_request(self, *args, **kwargs):
"""Handle an HTTP request from the client.
This is the entry point of the Engine.IO application. This function
returns the HTTP response to deliver to the client.
Note: this method is a coroutine.
"""
translate_request = self._async['translate_request']
if asyncio.iscoroutinefunction(translate_request):
environ = await translate_request(*args, **kwargs)
else:
environ = translate_request(*args, **kwargs)
if self.cors_allowed_origins != []:
# Validate the origin header if present
# This is important for WebSocket more than for HTTP, since
# browsers only apply CORS controls to HTTP.
origin = environ.get('HTTP_ORIGIN')
if origin:
allowed_origins = self._cors_allowed_origins(environ)
if allowed_origins is not None and origin not in \
allowed_origins:
self._log_error_once(
origin + ' is not an accepted origin.', 'bad-origin')
return await self._make_response(
self._bad_request(
origin + ' is not an accepted origin.'),
environ)
method = environ['REQUEST_METHOD']
query = urllib.parse.parse_qs(environ.get('QUERY_STRING', ''))
sid = query['sid'][0] if 'sid' in query else None
jsonp = False
jsonp_index = None
# make sure the client uses an allowed transport
transport = query.get('transport', ['polling'])[0]
if transport not in self.transports:
self._log_error_once('Invalid transport', 'bad-transport')
return await self._make_response(
self._bad_request('Invalid transport'), environ)
# make sure the client speaks a compatible Engine.IO version
if sid is None and query.get('EIO') != ['4']:
self._log_error_once(
'The client is using an unsupported version of the Socket.IO '
'or Engine.IO protocols', 'bad-version'
)
return await self._make_response(self._bad_request(
'The client is using an unsupported version of the Socket.IO '
'or Engine.IO protocols'
), environ)
if 'j' in query:
jsonp = True
try:
jsonp_index = int(query['j'][0])
except (ValueError, KeyError, IndexError):
# Invalid JSONP index number
pass
if jsonp and jsonp_index is None:
self._log_error_once('Invalid JSONP index number',
'bad-jsonp-index')
r = self._bad_request('Invalid JSONP index number')
elif method == 'GET':
upgrade_header = environ.get('HTTP_UPGRADE').lower() \
if 'HTTP_UPGRADE' in environ else None
if sid is None:
# transport must be one of 'polling' or 'websocket'.
# if 'websocket', the HTTP_UPGRADE header must match.
if transport == 'polling' \
or transport == upgrade_header == 'websocket':
r = await self._handle_connect(environ, transport,
jsonp_index)
else:
self._log_error_once('Invalid websocket upgrade',
'bad-upgrade')
r = self._bad_request('Invalid websocket upgrade')
else:
if sid not in self.sockets:
self._log_error_once(f'Invalid session {sid}', 'bad-sid')
r = self._bad_request(f'Invalid session {sid}')
else:
try:
socket = self._get_socket(sid)
except KeyError as e: # pragma: no cover
self._log_error_once(f'{e} {sid}', 'bad-sid')
r = self._bad_request(f'{e} {sid}')
else:
if self.transport(sid) != transport and \
transport != upgrade_header:
self._log_error_once(
f'Invalid transport for session {sid}',
'bad-transport')
r = self._bad_request('Invalid transport')
else:
try:
packets = await socket.handle_get_request(
environ)
if isinstance(packets, list):
r = self._ok(packets,
jsonp_index=jsonp_index)
else:
r = packets
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
await self.disconnect(sid)
r = self._bad_request()
if sid in self.sockets and \
self.sockets[sid].closed:
del self.sockets[sid]
elif method == 'POST':
if sid is None or sid not in self.sockets:
self._log_error_once(f'Invalid session {sid}', 'bad-sid')
r = self._bad_request(f'Invalid session {sid}')
else:
socket = self._get_socket(sid)
try:
await socket.handle_post_request(environ)
r = self._ok(jsonp_index=jsonp_index)
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
await self.disconnect(sid)
r = self._bad_request()
except: # pragma: no cover
# for any other unexpected errors, we log the error
# and keep going
self.logger.exception('post request handler error')
r = self._ok(jsonp_index=jsonp_index)
elif method == 'OPTIONS':
r = self._ok()
else:
self.logger.warning('Method %s not supported', method)
r = self._method_not_found()
if not isinstance(r, dict):
return r
if self.http_compression and \
len(r['response']) >= self.compression_threshold:
encodings = [e.split(';')[0].strip() for e in
environ.get('HTTP_ACCEPT_ENCODING', '').split(',')]
for encoding in encodings:
if encoding in self.compression_methods:
r['response'] = \
getattr(self, '_' + encoding)(r['response'])
r['headers'] += [('Content-Encoding', encoding)]
break
return await self._make_response(r, environ)
async def shutdown(self):
"""Stop Socket.IO background tasks.
This method stops background activity initiated by the Socket.IO
server. It must be called before shutting down the web server.
"""
self.logger.info('Socket.IO is shutting down')
if self.service_task_event: # pragma: no cover
self.service_task_event.set()
await self.service_task_handle
self.service_task_handle = None
def start_background_task(self, target, *args, **kwargs):
"""Start a background task using the appropriate async model.
This is a utility function that applications can use to start a
background task using the method that is compatible with the
selected async mode.
:param target: the target function to execute.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
The return value is a ``asyncio.Task`` object.
"""
return asyncio.ensure_future(target(*args, **kwargs))
async def sleep(self, seconds=0):
"""Sleep for the requested amount of time using the appropriate async
model.
This is a utility function that applications can use to put a task to
sleep without having to worry about using the correct call for the
selected async mode.
Note: this method is a coroutine.
"""
return await asyncio.sleep(seconds)
def create_queue(self, *args, **kwargs):
"""Create a queue object using the appropriate async model.
This is a utility function that applications can use to create a queue
without having to worry about using the correct call for the selected
async mode. For asyncio based async modes, this returns an instance of
``asyncio.Queue``.
"""
return asyncio.Queue(*args, **kwargs)
def get_queue_empty_exception(self):
"""Return the queue empty exception for the appropriate async model.
This is a utility function that applications can use to work with a
queue without having to worry about using the correct call for the
selected async mode. For asyncio based async modes, this returns an
instance of ``asyncio.QueueEmpty``.
"""
return asyncio.QueueEmpty
def create_event(self, *args, **kwargs):
"""Create an event object using the appropriate async model.
This is a utility function that applications can use to create an
event without having to worry about using the correct call for the
selected async mode. For asyncio based async modes, this returns
an instance of ``asyncio.Event``.
"""
return asyncio.Event(*args, **kwargs)
async def _make_response(self, response_dict, environ):
cors_headers = self._cors_headers(environ)
make_response = self._async['make_response']
if asyncio.iscoroutinefunction(make_response):
response = await make_response(
response_dict['status'],
response_dict['headers'] + cors_headers,
response_dict['response'], environ)
else:
response = make_response(
response_dict['status'],
response_dict['headers'] + cors_headers,
response_dict['response'], environ)
return response
async def _handle_connect(self, environ, transport, jsonp_index=None):
"""Handle a client connection request."""
if self.start_service_task:
# start the service task to monitor connected clients
self.start_service_task = False
self.service_task_handle = self.start_background_task(
self._service_task)
sid = self.generate_id()
s = async_socket.AsyncSocket(self, sid)
self.sockets[sid] = s
pkt = packet.Packet(packet.OPEN, {
'sid': sid,
'upgrades': self._upgrades(sid, transport),
'pingTimeout': int(self.ping_timeout * 1000),
'pingInterval': int(
self.ping_interval + self.ping_interval_grace_period) * 1000,
'maxPayload': self.max_http_buffer_size,
})
await s.send(pkt)
s.schedule_ping()
ret = await self._trigger_event('connect', sid, environ,
run_async=False)
if ret is not None and ret is not True:
del self.sockets[sid]
self.logger.warning('Application rejected connection')
return self._unauthorized(ret or None)
if transport == 'websocket':
ret = await s.handle_get_request(environ)
if s.closed and sid in self.sockets:
# websocket connection ended, so we are done
del self.sockets[sid]
return ret
else:
s.connected = True
headers = None
if self.cookie:
if isinstance(self.cookie, dict):
headers = [(
'Set-Cookie',
self._generate_sid_cookie(sid, self.cookie)
)]
else:
headers = [(
'Set-Cookie',
self._generate_sid_cookie(sid, {
'name': self.cookie, 'path': '/', 'SameSite': 'Lax'
})
)]
try:
return self._ok(await s.poll(), headers=headers,
jsonp_index=jsonp_index)
except exceptions.QueueEmpty:
return self._bad_request()
async def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
run_async = kwargs.pop('run_async', False)
ret = None
if event in self.handlers:
if asyncio.iscoroutinefunction(self.handlers[event]):
async def run_async_handler():
try:
try:
return await self.handlers[event](*args)
except TypeError:
if event == 'disconnect' and \
len(args) == 2: # pragma: no branch
# legacy disconnect events do not have a reason
# argument
return await self.handlers[event](args[0])
else: # pragma: no cover
raise
except asyncio.CancelledError: # pragma: no cover
pass
except:
self.logger.exception(event + ' async handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
if run_async:
ret = self.start_background_task(run_async_handler)
task_reference_holder.add(ret)
ret.add_done_callback(task_reference_holder.discard)
else:
ret = await run_async_handler()
else:
async def run_sync_handler():
try:
try:
return self.handlers[event](*args)
except TypeError:
if event == 'disconnect' and \
len(args) == 2: # pragma: no branch
# legacy disconnect events do not have a reason
# argument
return self.handlers[event](args[0])
else: # pragma: no cover
raise
except:
self.logger.exception(event + ' handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
if run_async:
ret = self.start_background_task(run_sync_handler)
task_reference_holder.add(ret)
ret.add_done_callback(task_reference_holder.discard)
else:
ret = await run_sync_handler()
return ret
async def _service_task(self): # pragma: no cover
"""Monitor connected clients and clean up those that time out."""
loop = asyncio.get_running_loop()
self.service_task_event = self.create_event()
while not self.service_task_event.is_set():
if len(self.sockets) == 0:
# nothing to do
try:
await asyncio.wait_for(self.service_task_event.wait(),
timeout=self.ping_timeout)
break
except asyncio.TimeoutError:
continue
# go through the entire client list in a ping interval cycle
sleep_interval = self.ping_timeout / len(self.sockets)
try:
# iterate over the current clients
for s in self.sockets.copy().values():
if s.closed:
try:
del self.sockets[s.sid]
except KeyError:
# the socket could have also been removed by
# the _get_socket() method from another thread
pass
elif not s.closing:
await s.check_ping_timeout()
try:
await asyncio.wait_for(self.service_task_event.wait(),
timeout=sleep_interval)
raise KeyboardInterrupt()
except asyncio.TimeoutError:
continue
except (
SystemExit,
KeyboardInterrupt,
asyncio.CancelledError,
GeneratorExit,
):
self.logger.info('service task canceled')
break
except:
if loop.is_closed():
self.logger.info('event loop is closed, exiting service '
'task')
break
# an unexpected exception has occurred, log it and continue
self.logger.exception('service task exception')
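To make the event surface of `AsyncServer` concrete, here is a minimal sketch of the three supported events (handler bodies are illustrative; the `disconnect` handler receives the reason argument handled by `_trigger_event` above):

```python
import engineio

eio = engineio.AsyncServer(async_mode='asgi')

@eio.on('connect')
async def on_connect(sid, environ):
    print('connected:', sid)

@eio.on('message')
async def on_message(sid, data):
    await eio.send(sid, data)  # echo back

@eio.on('disconnect')
async def on_disconnect(sid, reason):
    print('disconnected:', sid, reason)

app = engineio.ASGIApp(eio)
```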

View File

@ -0,0 +1,261 @@
import asyncio
import sys
import time
from . import base_socket
from . import exceptions
from . import packet
from . import payload
class AsyncSocket(base_socket.BaseSocket):
async def poll(self):
"""Wait for packets to send to the client."""
try:
packets = [await asyncio.wait_for(
self.queue.get(),
self.server.ping_interval + self.server.ping_timeout)]
self.queue.task_done()
except (asyncio.TimeoutError, asyncio.CancelledError):
raise exceptions.QueueEmpty()
if packets == [None]:
return []
while True:
try:
pkt = self.queue.get_nowait()
self.queue.task_done()
if pkt is None:
self.queue.put_nowait(None)
break
packets.append(pkt)
except asyncio.QueueEmpty:
break
return packets
async def receive(self, pkt):
"""Receive packet from the client."""
self.server.logger.info('%s: Received packet %s data %s',
self.sid, packet.packet_names[pkt.packet_type],
pkt.data if not isinstance(pkt.data, bytes)
else '<binary>')
if pkt.packet_type == packet.PONG:
self.schedule_ping()
elif pkt.packet_type == packet.MESSAGE:
await self.server._trigger_event(
'message', self.sid, pkt.data,
run_async=self.server.async_handlers)
elif pkt.packet_type == packet.UPGRADE:
await self.send(packet.Packet(packet.NOOP))
elif pkt.packet_type == packet.CLOSE:
await self.close(wait=False, abort=True,
reason=self.server.reason.CLIENT_DISCONNECT)
else:
raise exceptions.UnknownPacketError()
async def check_ping_timeout(self):
"""Make sure the client is still sending pings."""
if self.closed:
raise exceptions.SocketIsClosedError()
if self.last_ping and \
time.time() - self.last_ping > self.server.ping_timeout:
self.server.logger.info('%s: Client is gone, closing socket',
self.sid)
# Passing abort=False here will cause close() to write a
# CLOSE packet. This has the effect of updating half-open sockets
# to their correct state of disconnected
await self.close(wait=False, abort=False,
reason=self.server.reason.PING_TIMEOUT)
return False
return True
async def send(self, pkt):
"""Send a packet to the client."""
if not await self.check_ping_timeout():
return
else:
await self.queue.put(pkt)
self.server.logger.info('%s: Sending packet %s data %s',
self.sid, packet.packet_names[pkt.packet_type],
pkt.data if not isinstance(pkt.data, bytes)
else '<binary>')
async def handle_get_request(self, environ):
"""Handle a long-polling GET request from the client."""
connections = [
s.strip()
for s in environ.get('HTTP_CONNECTION', '').lower().split(',')]
transport = environ.get('HTTP_UPGRADE', '').lower()
if 'upgrade' in connections and transport in self.upgrade_protocols:
self.server.logger.info('%s: Received request to upgrade to %s',
self.sid, transport)
return await getattr(self, '_upgrade_' + transport)(environ)
if self.upgrading or self.upgraded:
# we are upgrading to WebSocket, do not return any more packets
# through the polling endpoint
return [packet.Packet(packet.NOOP)]
try:
packets = await self.poll()
except exceptions.QueueEmpty:
exc = sys.exc_info()
await self.close(wait=False,
reason=self.server.reason.TRANSPORT_ERROR)
raise exc[1].with_traceback(exc[2])
return packets
async def handle_post_request(self, environ):
"""Handle a long-polling POST request from the client."""
length = int(environ.get('CONTENT_LENGTH', '0'))
if length > self.server.max_http_buffer_size:
raise exceptions.ContentTooLongError()
else:
body = (await environ['wsgi.input'].read(length)).decode('utf-8')
p = payload.Payload(encoded_payload=body)
for pkt in p.packets:
await self.receive(pkt)
async def close(self, wait=True, abort=False, reason=None):
"""Close the socket connection."""
if not self.closed and not self.closing:
self.closing = True
await self.server._trigger_event(
'disconnect', self.sid,
reason or self.server.reason.SERVER_DISCONNECT,
run_async=False)
if not abort:
await self.send(packet.Packet(packet.CLOSE))
self.closed = True
if wait:
await self.queue.join()
def schedule_ping(self):
self.server.start_background_task(self._send_ping)
async def _send_ping(self):
self.last_ping = None
await asyncio.sleep(self.server.ping_interval)
if not self.closing and not self.closed:
self.last_ping = time.time()
await self.send(packet.Packet(packet.PING))
async def _upgrade_websocket(self, environ):
"""Upgrade the connection from polling to websocket."""
if self.upgraded:
raise OSError('Socket has been upgraded already')
if self.server._async['websocket'] is None:
# the selected async mode does not support websocket
return self.server._bad_request()
ws = self.server._async['websocket'](
self._websocket_handler, self.server)
return await ws(environ)
async def _websocket_handler(self, ws):
"""Engine.IO handler for websocket transport."""
async def websocket_wait():
data = await ws.wait()
if data and len(data) > self.server.max_http_buffer_size:
raise ValueError('packet is too large')
return data
if self.connected:
# the socket was already connected, so this is an upgrade
self.upgrading = True # hold packet sends during the upgrade
try:
pkt = await websocket_wait()
except OSError: # pragma: no cover
return
decoded_pkt = packet.Packet(encoded_packet=pkt)
if decoded_pkt.packet_type != packet.PING or \
decoded_pkt.data != 'probe':
self.server.logger.info(
'%s: Failed websocket upgrade, no PING packet', self.sid)
self.upgrading = False
return
await ws.send(packet.Packet(packet.PONG, data='probe').encode())
await self.queue.put(packet.Packet(packet.NOOP)) # end poll
try:
pkt = await websocket_wait()
except OSError: # pragma: no cover
self.upgrading = False
return
decoded_pkt = packet.Packet(encoded_packet=pkt)
if decoded_pkt.packet_type != packet.UPGRADE:
self.upgraded = False
self.server.logger.info(
('%s: Failed websocket upgrade, expected UPGRADE packet, '
'received %s instead.'),
self.sid, pkt)
self.upgrading = False
return
self.upgraded = True
self.upgrading = False
else:
self.connected = True
self.upgraded = True
# start separate writer thread
async def writer():
while True:
packets = None
try:
packets = await self.poll()
except exceptions.QueueEmpty:
break
if not packets:
# empty packet list returned -> connection closed
break
try:
for pkt in packets:
await ws.send(pkt.encode())
except:
break
await ws.close()
writer_task = asyncio.ensure_future(writer())
self.server.logger.info(
'%s: Upgrade to websocket successful', self.sid)
while True:
p = None
wait_task = asyncio.ensure_future(websocket_wait())
try:
p = await asyncio.wait_for(
wait_task,
self.server.ping_interval + self.server.ping_timeout)
except asyncio.CancelledError: # pragma: no cover
# there is a bug (https://bugs.python.org/issue30508) in
# asyncio that causes a "Task exception never retrieved" error
# to appear when wait_task raises an exception before it gets
# cancelled. Calling wait_task.exception() prevents the error
# from being issued in Python 3.6, but causes other errors in
# other versions, so we run it with all errors suppressed and
# hope for the best.
try:
wait_task.exception()
except:
pass
break
except:
break
if p is None:
# connection closed by client
break
pkt = packet.Packet(encoded_packet=p)
try:
await self.receive(pkt)
except exceptions.UnknownPacketError: # pragma: no cover
pass
except exceptions.SocketIsClosedError: # pragma: no cover
self.server.logger.info('Receive error -- socket is closed')
break
except: # pragma: no cover
# if we get an unexpected exception we log the error and exit
# the connection properly
self.server.logger.exception('Unknown receive error')
await self.queue.put(None) # unlock the writer task so it can exit
await asyncio.wait_for(writer_task, timeout=None)
await self.close(wait=False, abort=True,
reason=self.server.reason.TRANSPORT_CLOSE)
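For reference, the upgrade handshake that `_websocket_handler` implements exchanges the packets below; a small illustration using the engineio packet classes (encodings follow the Engine.IO v4 framing):

```python
from engineio import packet

probe = packet.Packet(packet.PING, data='probe')    # client -> server
answer = packet.Packet(packet.PONG, data='probe')   # server -> client
noop = packet.Packet(packet.NOOP)                   # flushes the open poll
upgrade = packet.Packet(packet.UPGRADE)             # client -> server, final

print(probe.encode(), answer.encode(), noop.encode(), upgrade.encode())
# prints: 2probe 3probe 6 5
```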

View File

@ -0,0 +1,158 @@
import logging
import signal
import threading
import time
import urllib
from . import packet
default_logger = logging.getLogger('engineio.client')
connected_clients = []
def signal_handler(sig, frame):
"""SIGINT handler.
Disconnect all active clients and then invoke the original signal handler.
"""
for client in connected_clients[:]:
if not client.is_asyncio_based():
client.disconnect()
if callable(original_signal_handler):
return original_signal_handler(sig, frame)
else: # pragma: no cover
# Handle case where no original SIGINT handler was present.
return signal.default_int_handler(sig, frame)
original_signal_handler = None
class BaseClient:
event_names = ['connect', 'disconnect', 'message']
class reason:
"""Disconnection reasons."""
#: Client-initiated disconnection.
CLIENT_DISCONNECT = 'client disconnect'
#: Server-initiated disconnection.
SERVER_DISCONNECT = 'server disconnect'
#: Transport error.
TRANSPORT_ERROR = 'transport error'
def __init__(self, logger=False, json=None, request_timeout=5,
http_session=None, ssl_verify=True, handle_sigint=True,
websocket_extra_options=None, timestamp_requests=True):
global original_signal_handler
if handle_sigint and original_signal_handler is None and \
threading.current_thread() == threading.main_thread():
original_signal_handler = signal.signal(signal.SIGINT,
signal_handler)
self.handlers = {}
self.base_url = None
self.transports = None
self.current_transport = None
self.sid = None
self.upgrades = None
self.ping_interval = None
self.ping_timeout = None
self.http = http_session
self.external_http = http_session is not None
self.handle_sigint = handle_sigint
self.ws = None
self.read_loop_task = None
self.write_loop_task = None
self.queue = None
self.state = 'disconnected'
self.ssl_verify = ssl_verify
self.websocket_extra_options = websocket_extra_options or {}
self.timestamp_requests = timestamp_requests
if json is not None:
packet.Packet.json = json
if not isinstance(logger, bool):
self.logger = logger
else:
self.logger = default_logger
if self.logger.level == logging.NOTSET:
if logger:
self.logger.setLevel(logging.INFO)
else:
self.logger.setLevel(logging.ERROR)
self.logger.addHandler(logging.StreamHandler())
self.request_timeout = request_timeout
def is_asyncio_based(self):
return False
def on(self, event, handler=None):
"""Register an event handler.
:param event: The event name. Can be ``'connect'``, ``'message'`` or
``'disconnect'``.
:param handler: The function that should be invoked to handle the
event. When this parameter is not given, the method
acts as a decorator for the handler function.
Example usage::
# as a decorator:
@eio.on('connect')
def connect_handler():
print('Connection request')
# as a method:
def message_handler(msg):
print('Received message: ', msg)
eio.send('response')
eio.on('message', message_handler)
"""
if event not in self.event_names:
raise ValueError('Invalid event')
def set_handler(handler):
self.handlers[event] = handler
return handler
if handler is None:
return set_handler
set_handler(handler)
def transport(self):
"""Return the name of the transport currently in use.
The possible values returned by this function are ``'polling'`` and
``'websocket'``.
"""
return self.current_transport
def _reset(self):
self.state = 'disconnected'
self.sid = None
def _get_engineio_url(self, url, engineio_path, transport):
"""Generate the Engine.IO connection URL."""
engineio_path = engineio_path.strip('/')
parsed_url = urllib.parse.urlparse(url)
if transport == 'polling':
scheme = 'http'
elif transport == 'websocket':
scheme = 'ws'
else: # pragma: no cover
raise ValueError('invalid transport')
if parsed_url.scheme in ['https', 'wss']:
scheme += 's'
return ('{scheme}://{netloc}/{path}/?{query}'
'{sep}transport={transport}&EIO=4').format(
scheme=scheme, netloc=parsed_url.netloc,
path=engineio_path, query=parsed_url.query,
sep='&' if parsed_url.query else '',
transport=transport)
def _get_url_timestamp(self):
"""Generate the Engine.IO query string timestamp."""
if not self.timestamp_requests:
return ''
return '&t=' + str(time.time())
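A minimal sketch of the synchronous client built on this base class (URL is an assumption; the async client mirrors the same API with coroutines):

```python
import engineio

eio = engineio.Client()

@eio.on('connect')
def on_connect():
    print('connected, sid =', eio.sid)

@eio.on('message')
def on_message(data):
    print('received:', data)

eio.connect('http://localhost:5000')
eio.wait()  # block until the connection ends
```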

View File

@ -0,0 +1,351 @@
import base64
import gzip
import importlib
import io
import logging
import secrets
import zlib
from . import packet
from . import payload
default_logger = logging.getLogger('engineio.server')
class BaseServer:
compression_methods = ['gzip', 'deflate']
event_names = ['connect', 'disconnect', 'message']
valid_transports = ['polling', 'websocket']
_default_monitor_clients = True
sequence_number = 0
class reason:
"""Disconnection reasons."""
#: Server-initiated disconnection.
SERVER_DISCONNECT = 'server disconnect'
#: Client-initiated disconnection.
CLIENT_DISCONNECT = 'client disconnect'
#: Ping timeout.
PING_TIMEOUT = 'ping timeout'
#: Transport close.
TRANSPORT_CLOSE = 'transport close'
#: Transport error.
TRANSPORT_ERROR = 'transport error'
def __init__(self, async_mode=None, ping_interval=25, ping_timeout=20,
max_http_buffer_size=1000000, allow_upgrades=True,
http_compression=True, compression_threshold=1024,
cookie=None, cors_allowed_origins=None,
cors_credentials=True, logger=False, json=None,
async_handlers=True, monitor_clients=None, transports=None,
**kwargs):
self.ping_timeout = ping_timeout
if isinstance(ping_interval, tuple):
self.ping_interval = ping_interval[0]
self.ping_interval_grace_period = ping_interval[1]
else:
self.ping_interval = ping_interval
self.ping_interval_grace_period = 0
self.max_http_buffer_size = max_http_buffer_size
self.allow_upgrades = allow_upgrades
self.http_compression = http_compression
self.compression_threshold = compression_threshold
self.cookie = cookie
self.cors_allowed_origins = cors_allowed_origins
self.cors_credentials = cors_credentials
self.async_handlers = async_handlers
self.sockets = {}
self.handlers = {}
self.log_message_keys = set()
self.start_service_task = monitor_clients \
if monitor_clients is not None else self._default_monitor_clients
self.service_task_handle = None
self.service_task_event = None
if json is not None:
packet.Packet.json = json
if not isinstance(logger, bool):
self.logger = logger
else:
self.logger = default_logger
if self.logger.level == logging.NOTSET:
if logger:
self.logger.setLevel(logging.INFO)
else:
self.logger.setLevel(logging.ERROR)
self.logger.addHandler(logging.StreamHandler())
modes = self.async_modes()
if async_mode is not None:
modes = [async_mode] if async_mode in modes else []
self._async = None
self.async_mode = None
for mode in modes:
try:
self._async = importlib.import_module(
'engineio.async_drivers.' + mode)._async
asyncio_based = self._async['asyncio'] \
if 'asyncio' in self._async else False
if asyncio_based != self.is_asyncio_based():
continue # pragma: no cover
self.async_mode = mode
break
except ImportError:
pass
if self.async_mode is None:
raise ValueError('Invalid async_mode specified')
if self.is_asyncio_based() and \
('asyncio' not in self._async or not
self._async['asyncio']): # pragma: no cover
raise ValueError('The selected async_mode is not asyncio '
'compatible')
if not self.is_asyncio_based() and 'asyncio' in self._async and \
self._async['asyncio']: # pragma: no cover
raise ValueError('The selected async_mode requires asyncio and '
'must use the AsyncServer class')
if transports is not None:
if isinstance(transports, str):
transports = [transports]
transports = [transport for transport in transports
if transport in self.valid_transports]
if not transports:
raise ValueError('No valid transports provided')
self.transports = transports or self.valid_transports
self.logger.info('Server initialized for %s.', self.async_mode)
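# Illustrative construction sketch (assumed usage of the concrete Server
# subclass, not code from this file):
#
#   eio = Server(async_mode='threading', transports=['polling'])
#
# A missing dependency for the requested mode raises
# ValueError('Invalid async_mode specified'), and an empty transport list
# after filtering raises ValueError('No valid transports provided').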
def is_asyncio_based(self):
return False
def async_modes(self):
return ['eventlet', 'gevent_uwsgi', 'gevent', 'threading']
def on(self, event, handler=None):
"""Register an event handler.
:param event: The event name. Can be ``'connect'``, ``'message'`` or
``'disconnect'``.
:param handler: The function that should be invoked to handle the
event. When this parameter is not given, the method
acts as a decorator for the handler function.
Example usage::
# as a decorator:
@eio.on('connect')
def connect_handler(sid, environ):
print('Connection request')
if environ['REMOTE_ADDR'] in blacklisted:
return False # reject
# as a method:
def message_handler(sid, msg):
print('Received message: ', msg)
eio.send(sid, 'response')
eio.on('message', message_handler)
The handler function receives the ``sid`` (session ID) for the
client as first argument. The ``'connect'`` event handler receives the
WSGI environment as a second argument, and can return ``False`` to
reject the connection. The ``'message'`` handler receives the message
payload as a second argument. The ``'disconnect'`` handler does not
take a second argument.
"""
if event not in self.event_names:
raise ValueError('Invalid event')
def set_handler(handler):
self.handlers[event] = handler
return handler
if handler is None:
return set_handler
set_handler(handler)
def transport(self, sid):
"""Return the name of the transport used by the client.
The two possible values returned by this function are ``'polling'``
and ``'websocket'``.
:param sid: The session of the client.
"""
return 'websocket' if self._get_socket(sid).upgraded else 'polling'
def create_queue(self, *args, **kwargs):
"""Create a queue object using the appropriate async model.
This is a utility function that applications can use to create a queue
without having to worry about using the correct call for the selected
async mode.
"""
return self._async['queue'](*args, **kwargs)
def get_queue_empty_exception(self):
"""Return the queue empty exception for the appropriate async model.
This is a utility function that applications can use to work with a
queue without having to worry about using the correct call for the
selected async mode.
"""
return self._async['queue_empty']
def create_event(self, *args, **kwargs):
"""Create an event object using the appropriate async model.
This is a utility function that applications can use to create an
event without having to worry about using the correct call for the
selected async mode.
"""
return self._async['event'](*args, **kwargs)
def generate_id(self):
"""Generate a unique session id."""
id = base64.b64encode(
secrets.token_bytes(12) + self.sequence_number.to_bytes(3, 'big'))
self.sequence_number = (self.sequence_number + 1) & 0xffffff
return id.decode('utf-8').replace('/', '_').replace('+', '-')
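# Note on the id format: 12 random bytes plus a 3-byte rolling sequence
# number are base64-encoded into a 20-character URL-safe string, e.g.
# 'TlGqF3Jk2oXdYwAAAAAB' (the value is made up; length and alphabet follow
# from the code above).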
def _generate_sid_cookie(self, sid, attributes):
"""Generate the sid cookie."""
cookie = attributes.get('name', 'io') + '=' + sid
for attribute, value in attributes.items():
if attribute == 'name':
continue
if callable(value):
value = value()
if value is True:
cookie += '; ' + attribute
else:
cookie += '; ' + attribute + '=' + value
return cookie
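# Illustrative example (assumed inputs): with sid 'abc' and attributes
# {'name': 'io', 'path': '/', 'SameSite': 'Lax', 'HttpOnly': True}, the
# generated header value is:
#
#   'io=abc; path=/; SameSite=Lax; HttpOnly'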
def _upgrades(self, sid, transport):
"""Return the list of possible upgrades for a client connection."""
if not self.allow_upgrades or self._get_socket(sid).upgraded or \
transport == 'websocket':
return []
if self._async['websocket'] is None: # pragma: no cover
self._log_error_once(
'The WebSocket transport is not available, you must install a '
'WebSocket server that is compatible with your async mode to '
'enable it. See the documentation for details.',
'no-websocket')
return []
return ['websocket']
def _get_socket(self, sid):
"""Return the socket object for a given session."""
try:
s = self.sockets[sid]
except KeyError:
raise KeyError('Session not found')
if s.closed:
del self.sockets[sid]
raise KeyError('Session is disconnected')
return s
def _ok(self, packets=None, headers=None, jsonp_index=None):
"""Generate a successful HTTP response."""
if packets is not None:
if headers is None:
headers = []
headers += [('Content-Type', 'text/plain; charset=UTF-8')]
return {'status': '200 OK',
'headers': headers,
'response': payload.Payload(packets=packets).encode(
jsonp_index=jsonp_index).encode('utf-8')}
else:
return {'status': '200 OK',
'headers': [('Content-Type', 'text/plain')],
'response': b'OK'}
def _bad_request(self, message=None):
"""Generate a bad request HTTP error response."""
if message is None:
message = 'Bad Request'
message = packet.Packet.json.dumps(message)
return {'status': '400 BAD REQUEST',
'headers': [('Content-Type', 'text/plain')],
'response': message.encode('utf-8')}
def _method_not_found(self):
"""Generate a method not found HTTP error response."""
return {'status': '405 METHOD NOT FOUND',
'headers': [('Content-Type', 'text/plain')],
'response': b'Method Not Found'}
def _unauthorized(self, message=None):
"""Generate a unauthorized HTTP error response."""
if message is None:
message = 'Unauthorized'
message = packet.Packet.json.dumps(message)
return {'status': '401 UNAUTHORIZED',
'headers': [('Content-Type', 'application/json')],
'response': message.encode('utf-8')}
def _cors_allowed_origins(self, environ):
default_origins = []
if 'wsgi.url_scheme' in environ and 'HTTP_HOST' in environ:
default_origins.append('{scheme}://{host}'.format(
scheme=environ['wsgi.url_scheme'], host=environ['HTTP_HOST']))
if 'HTTP_X_FORWARDED_PROTO' in environ or \
'HTTP_X_FORWARDED_HOST' in environ:
scheme = environ.get(
'HTTP_X_FORWARDED_PROTO',
environ['wsgi.url_scheme']).split(',')[0].strip()
default_origins.append('{scheme}://{host}'.format(
scheme=scheme, host=environ.get(
'HTTP_X_FORWARDED_HOST', environ['HTTP_HOST']).split(
',')[0].strip()))
if self.cors_allowed_origins is None:
allowed_origins = default_origins
elif self.cors_allowed_origins == '*':
allowed_origins = None
elif isinstance(self.cors_allowed_origins, str):
allowed_origins = [self.cors_allowed_origins]
elif callable(self.cors_allowed_origins):
origin = environ.get('HTTP_ORIGIN')
allowed_origins = [origin] \
if self.cors_allowed_origins(origin) else []
else:
allowed_origins = self.cors_allowed_origins
return allowed_origins
def _cors_headers(self, environ):
"""Return the cross-origin-resource-sharing headers."""
if self.cors_allowed_origins == []:
# special case, CORS handling is completely disabled
return []
headers = []
allowed_origins = self._cors_allowed_origins(environ)
if 'HTTP_ORIGIN' in environ and \
(allowed_origins is None or environ['HTTP_ORIGIN'] in
allowed_origins):
headers = [('Access-Control-Allow-Origin', environ['HTTP_ORIGIN'])]
if environ['REQUEST_METHOD'] == 'OPTIONS':
headers += [('Access-Control-Allow-Methods', 'OPTIONS, GET, POST')]
if 'HTTP_ACCESS_CONTROL_REQUEST_HEADERS' in environ:
headers += [('Access-Control-Allow-Headers',
environ['HTTP_ACCESS_CONTROL_REQUEST_HEADERS'])]
if self.cors_credentials:
headers += [('Access-Control-Allow-Credentials', 'true')]
return headers
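# Configuration sketch (assumed usage of the concrete Server subclass): the
# checks above are driven by the cors_allowed_origins constructor argument:
#
#   Server(cors_allowed_origins='*')                       # any origin
#   Server(cors_allowed_origins='https://app.example.com') # one origin
#   Server(cors_allowed_origins=[])                        # disable CORS
#
# A callable can also be given; it receives the request origin and must
# return True to accept it.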
def _gzip(self, response):
"""Apply gzip compression to a response."""
bytesio = io.BytesIO()
with gzip.GzipFile(fileobj=bytesio, mode='w') as gz:
gz.write(response)
return bytesio.getvalue()
def _deflate(self, response):
"""Apply deflate compression to a response."""
return zlib.compress(response)
def _log_error_once(self, message, message_key):
"""Log message with logging.ERROR level the first time, then log
with given level."""
if message_key not in self.log_message_keys:
self.logger.error(message + ' (further occurrences of this error '
'will be logged with level INFO)')
self.log_message_keys.add(message_key)
else:
self.logger.info(message)

View File

@ -0,0 +1,14 @@
class BaseSocket:
upgrade_protocols = ['websocket']
def __init__(self, server, sid):
self.server = server
self.sid = sid
self.queue = self.server.create_queue()
self.last_ping = None
self.connected = False
self.upgrading = False
self.upgraded = False
self.closing = False
self.closed = False
self.session = {}

View File

@ -0,0 +1,620 @@
from base64 import b64encode
from engineio.json import JSONDecodeError
import logging
import queue
import ssl
import threading
import time
import urllib
try:
import requests
except ImportError: # pragma: no cover
requests = None
try:
import websocket
except ImportError: # pragma: no cover
websocket = None
from . import base_client
from . import exceptions
from . import packet
from . import payload
default_logger = logging.getLogger('engineio.client')
class Client(base_client.BaseClient):
"""An Engine.IO client.
This class implements a fully compliant Engine.IO web client with support
for websocket and long-polling transports.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``. The default is
``False``. Note that fatal errors are logged even when
``logger`` is ``False``.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
:param request_timeout: A timeout in seconds for requests. The default is
5 seconds.
:param http_session: an initialized ``requests.Session`` object to be used
when sending requests to the server. Use it if you
need to add special client options such as proxy
servers, SSL certificates, custom CA bundle, etc.
:param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to
skip SSL certificate verification, allowing
connections to servers with self-signed certificates.
The default is ``True``.
:param handle_sigint: Set to ``True`` to automatically handle disconnection
when the process is interrupted, or to ``False`` to
leave interrupt handling to the calling application.
Interrupt handling can only be enabled when the
client instance is created in the main thread.
:param websocket_extra_options: Dictionary containing additional keyword
arguments passed to
``websocket.create_connection()``.
:param timestamp_requests: If ``True`` a timestamp is added to the query
string of Engine.IO requests as a cache-busting
measure. Set to ``False`` to disable.
"""
def connect(self, url, headers=None, transports=None,
engineio_path='engine.io'):
"""Connect to an Engine.IO server.
:param url: The URL of the Engine.IO server. It can include custom
query string parameters if required by the server.
:param headers: A dictionary with custom headers to send with the
connection request.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. If not
given, the polling transport is connected first,
then an upgrade to websocket is attempted.
:param engineio_path: The endpoint where the Engine.IO server is
installed. The default value is appropriate for
most cases.
Example usage::
eio = engineio.Client()
eio.connect('http://localhost:5000')
"""
if self.state != 'disconnected':
raise ValueError('Client is not in a disconnected state')
valid_transports = ['polling', 'websocket']
if transports is not None:
if isinstance(transports, str):
transports = [transports]
transports = [transport for transport in transports
if transport in valid_transports]
if not transports:
raise ValueError('No valid transports provided')
self.transports = transports or valid_transports
self.queue = self.create_queue()
return getattr(self, '_connect_' + self.transports[0])(
url, headers or {}, engineio_path)
def wait(self):
"""Wait until the connection with the server ends.
Client applications can use this function to block the main thread
during the life of the connection.
"""
if self.read_loop_task:
self.read_loop_task.join()
def send(self, data):
"""Send a message to the server.
:param data: The data to send to the server. Data can be of type
``str``, ``bytes``, ``list`` or ``dict``. If a ``list``
or ``dict``, the data will be serialized as JSON.
"""
self._send_packet(packet.Packet(packet.MESSAGE, data=data))
def disconnect(self, abort=False, reason=None):
"""Disconnect from the server.
:param abort: If set to ``True``, do not wait for background tasks
associated with the connection to end.
"""
if self.state == 'connected':
self._send_packet(packet.Packet(packet.CLOSE))
self.queue.put(None)
self.state = 'disconnecting'
self._trigger_event('disconnect',
reason or self.reason.CLIENT_DISCONNECT,
run_async=False)
if self.current_transport == 'websocket':
self.ws.close()
if not abort:
self.read_loop_task.join()
self.state = 'disconnected'
try:
base_client.connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
self._reset()
def start_background_task(self, target, *args, **kwargs):
"""Start a background task.
This is a utility function that applications can use to start a
background task.
:param target: the target function to execute.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
This function returns an object that represents the background task,
on which the ``join()`` method can be invoked to wait for the task to
complete.
"""
th = threading.Thread(target=target, args=args, kwargs=kwargs,
daemon=True)
th.start()
return th
def sleep(self, seconds=0):
"""Sleep for the requested amount of time."""
return time.sleep(seconds)
def create_queue(self, *args, **kwargs):
"""Create a queue object."""
q = queue.Queue(*args, **kwargs)
q.Empty = queue.Empty
return q
def create_event(self, *args, **kwargs):
"""Create an event object."""
return threading.Event(*args, **kwargs)
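# Minimal usage sketch (assumes an Engine.IO server at localhost:5000; not
# part of this module):
#
#   eio = Client()
#   eio.on('connect', lambda: print('connected'))
#   eio.on('message', lambda data: print('received:', data))
#   eio.connect('http://localhost:5000')
#   eio.send('hello')
#   eio.wait()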
def _connect_polling(self, url, headers, engineio_path):
"""Establish a long-polling connection to the Engine.IO server."""
if requests is None: # pragma: no cover
# not installed
self.logger.error('requests package is not installed -- cannot '
'send HTTP requests!')
return
self.base_url = self._get_engineio_url(url, engineio_path, 'polling')
self.logger.info('Attempting polling connection to ' + self.base_url)
r = self._send_request(
'GET', self.base_url + self._get_url_timestamp(), headers=headers,
timeout=self.request_timeout)
if r is None or isinstance(r, str):
self._reset()
raise exceptions.ConnectionError(
r or 'Connection refused by the server')
if r.status_code < 200 or r.status_code >= 300:
self._reset()
try:
arg = r.json()
except JSONDecodeError:
arg = None
raise exceptions.ConnectionError(
'Unexpected status code {} in server response'.format(
r.status_code), arg)
try:
p = payload.Payload(encoded_payload=r.content.decode('utf-8'))
except ValueError:
raise exceptions.ConnectionError(
'Unexpected response from server') from None
open_packet = p.packets[0]
if open_packet.packet_type != packet.OPEN:
raise exceptions.ConnectionError(
'OPEN packet not returned by server')
self.logger.info(
'Polling connection accepted with ' + str(open_packet.data))
self.sid = open_packet.data['sid']
self.upgrades = open_packet.data['upgrades']
self.ping_interval = int(open_packet.data['pingInterval']) / 1000.0
self.ping_timeout = int(open_packet.data['pingTimeout']) / 1000.0
self.current_transport = 'polling'
self.base_url += '&sid=' + self.sid
self.state = 'connected'
base_client.connected_clients.append(self)
self._trigger_event('connect', run_async=False)
for pkt in p.packets[1:]:
self._receive_packet(pkt)
if 'websocket' in self.upgrades and 'websocket' in self.transports:
# attempt to upgrade to websocket
if self._connect_websocket(url, headers, engineio_path):
# upgrade to websocket succeeded, we're done here
return
# start background tasks associated with this client
self.write_loop_task = self.start_background_task(self._write_loop)
self.read_loop_task = self.start_background_task(
self._read_loop_polling)
def _connect_websocket(self, url, headers, engineio_path):
"""Establish or upgrade to a WebSocket connection with the server."""
if websocket is None: # pragma: no cover
# not installed
self.logger.error('websocket-client package not installed, only '
'polling transport is available')
return False
websocket_url = self._get_engineio_url(url, engineio_path, 'websocket')
if self.sid:
self.logger.info(
'Attempting WebSocket upgrade to ' + websocket_url)
upgrade = True
websocket_url += '&sid=' + self.sid
else:
upgrade = False
self.base_url = websocket_url
self.logger.info(
'Attempting WebSocket connection to ' + websocket_url)
# get cookies and other settings from the long-polling connection
# so that they are preserved when connecting to the WebSocket route
cookies = None
extra_options = {}
if self.http:
# cookies
cookies = '; '.join([f"{cookie.name}={cookie.value}"
for cookie in self.http.cookies])
for header, value in headers.items():
if header.lower() == 'cookie':
if cookies:
cookies += '; '
cookies += value
del headers[header]
break
# auth
if 'Authorization' not in headers and self.http.auth is not None:
if not isinstance(self.http.auth, tuple): # pragma: no cover
raise ValueError('Only basic authentication is supported')
basic_auth = '{}:{}'.format(
self.http.auth[0], self.http.auth[1]).encode('utf-8')
basic_auth = b64encode(basic_auth).decode('utf-8')
headers['Authorization'] = 'Basic ' + basic_auth
# cert
# this can be given as ('certfile', 'keyfile') or just 'certfile'
if isinstance(self.http.cert, tuple):
extra_options['sslopt'] = {
'certfile': self.http.cert[0],
'keyfile': self.http.cert[1]}
elif self.http.cert:
extra_options['sslopt'] = {'certfile': self.http.cert}
# proxies
if self.http.proxies:
proxy_url = None
if websocket_url.startswith('ws://'):
proxy_url = self.http.proxies.get(
'ws', self.http.proxies.get('http'))
else: # wss://
proxy_url = self.http.proxies.get(
'wss', self.http.proxies.get('https'))
if proxy_url:
parsed_url = urllib.parse.urlparse(
proxy_url if '://' in proxy_url
else 'scheme://' + proxy_url)
extra_options['http_proxy_host'] = parsed_url.hostname
extra_options['http_proxy_port'] = parsed_url.port
extra_options['http_proxy_auth'] = (
(parsed_url.username, parsed_url.password)
if parsed_url.username or parsed_url.password
else None)
# verify
if isinstance(self.http.verify, str):
if 'sslopt' in extra_options:
extra_options['sslopt']['ca_certs'] = self.http.verify
else:
extra_options['sslopt'] = {'ca_certs': self.http.verify}
elif not self.http.verify:
self.ssl_verify = False
if not self.ssl_verify:
if 'sslopt' in extra_options:
extra_options['sslopt'].update({"cert_reqs": ssl.CERT_NONE})
else:
extra_options['sslopt'] = {"cert_reqs": ssl.CERT_NONE}
# combine internally generated options with the ones supplied by the
# caller. The caller's options take precedence.
headers.update(self.websocket_extra_options.pop('header', {}))
extra_options['header'] = headers
extra_options['cookie'] = cookies
extra_options['enable_multithread'] = True
extra_options['timeout'] = self.request_timeout
extra_options.update(self.websocket_extra_options)
try:
ws = websocket.create_connection(
websocket_url + self._get_url_timestamp(), **extra_options)
except (ConnectionError, OSError, websocket.WebSocketException):
if upgrade:
self.logger.warning(
'WebSocket upgrade failed: connection error')
return False
else:
raise exceptions.ConnectionError('Connection error')
if upgrade:
p = packet.Packet(packet.PING, data='probe').encode()
try:
ws.send(p)
except Exception as e: # pragma: no cover
self.logger.warning(
'WebSocket upgrade failed: unexpected send exception: %s',
str(e))
return False
try:
p = ws.recv()
except Exception as e: # pragma: no cover
self.logger.warning(
'WebSocket upgrade failed: unexpected recv exception: %s',
str(e))
return False
pkt = packet.Packet(encoded_packet=p)
if pkt.packet_type != packet.PONG or pkt.data != 'probe':
self.logger.warning(
'WebSocket upgrade failed: no PONG packet')
return False
p = packet.Packet(packet.UPGRADE).encode()
try:
ws.send(p)
except Exception as e: # pragma: no cover
self.logger.warning(
'WebSocket upgrade failed: unexpected send exception: %s',
str(e))
return False
self.current_transport = 'websocket'
self.logger.info('WebSocket upgrade was successful')
else:
try:
p = ws.recv()
except Exception as e: # pragma: no cover
raise exceptions.ConnectionError(
'Unexpected recv exception: ' + str(e))
open_packet = packet.Packet(encoded_packet=p)
if open_packet.packet_type != packet.OPEN:
raise exceptions.ConnectionError('no OPEN packet')
self.logger.info(
'WebSocket connection accepted with ' + str(open_packet.data))
self.sid = open_packet.data['sid']
self.upgrades = open_packet.data['upgrades']
self.ping_interval = int(open_packet.data['pingInterval']) / 1000.0
self.ping_timeout = int(open_packet.data['pingTimeout']) / 1000.0
self.current_transport = 'websocket'
self.state = 'connected'
base_client.connected_clients.append(self)
self._trigger_event('connect', run_async=False)
self.ws = ws
self.ws.settimeout(self.ping_interval + self.ping_timeout)
# start background tasks associated with this client
self.write_loop_task = self.start_background_task(self._write_loop)
self.read_loop_task = self.start_background_task(
self._read_loop_websocket)
return True
def _receive_packet(self, pkt):
"""Handle incoming packets from the server."""
packet_name = packet.packet_names[pkt.packet_type] \
if pkt.packet_type < len(packet.packet_names) else 'UNKNOWN'
self.logger.info(
'Received packet %s data %s', packet_name,
pkt.data if not isinstance(pkt.data, bytes) else '<binary>')
if pkt.packet_type == packet.MESSAGE:
self._trigger_event('message', pkt.data, run_async=True)
elif pkt.packet_type == packet.PING:
self._send_packet(packet.Packet(packet.PONG, pkt.data))
elif pkt.packet_type == packet.CLOSE:
self.disconnect(abort=True, reason=self.reason.SERVER_DISCONNECT)
elif pkt.packet_type == packet.NOOP:
pass
else:
self.logger.error('Received unexpected packet of type %s',
pkt.packet_type)
def _send_packet(self, pkt):
"""Queue a packet to be sent to the server."""
if self.state != 'connected':
return
self.queue.put(pkt)
self.logger.info(
'Sending packet %s data %s',
packet.packet_names[pkt.packet_type],
pkt.data if not isinstance(pkt.data, bytes) else '<binary>')
def _send_request(
self, method, url, headers=None, body=None,
timeout=None): # pragma: no cover
if self.http is None:
self.http = requests.Session()
if not self.ssl_verify:
self.http.verify = False
try:
return self.http.request(method, url, headers=headers, data=body,
timeout=timeout)
except requests.exceptions.RequestException as exc:
self.logger.info('HTTP %s request to %s failed with error %s.',
method, url, exc)
return str(exc)
def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
run_async = kwargs.pop('run_async', False)
if event in self.handlers:
if run_async:
return self.start_background_task(self.handlers[event], *args)
else:
try:
try:
return self.handlers[event](*args)
except TypeError:
if event == 'disconnect' and \
len(args) == 1: # pragma: no branch
# legacy disconnect events do not have a reason
# argument
return self.handlers[event]()
else: # pragma: no cover
raise
except:
self.logger.exception(event + ' handler error')
def _read_loop_polling(self):
"""Read packets by polling the Engine.IO server."""
while self.state == 'connected' and self.write_loop_task:
self.logger.info(
'Sending polling GET request to ' + self.base_url)
r = self._send_request(
'GET', self.base_url + self._get_url_timestamp(),
timeout=max(self.ping_interval, self.ping_timeout) + 5)
if r is None or isinstance(r, str):
self.logger.warning(
r or 'Connection refused by the server, aborting')
self.queue.put(None)
break
if r.status_code < 200 or r.status_code >= 300:
self.logger.warning('Unexpected status code %s in server '
'response, aborting', r.status_code)
self.queue.put(None)
break
try:
p = payload.Payload(encoded_payload=r.content.decode('utf-8'))
except ValueError:
self.logger.warning(
'Unexpected packet from server, aborting')
self.queue.put(None)
break
for pkt in p.packets:
self._receive_packet(pkt)
if self.write_loop_task: # pragma: no branch
self.logger.info('Waiting for write loop task to end')
self.write_loop_task.join()
if self.state == 'connected':
self._trigger_event('disconnect', self.reason.TRANSPORT_ERROR,
run_async=False)
try:
base_client.connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
self._reset()
self.logger.info('Exiting read loop task')
def _read_loop_websocket(self):
"""Read packets from the Engine.IO WebSocket connection."""
while self.state == 'connected':
p = None
try:
p = self.ws.recv()
if len(p) == 0 and not self.ws.connected: # pragma: no cover
# websocket client can return an empty string after close
raise websocket.WebSocketConnectionClosedException()
except websocket.WebSocketTimeoutException:
self.logger.warning(
'Server has stopped communicating, aborting')
self.queue.put(None)
break
except websocket.WebSocketConnectionClosedException:
self.logger.warning(
'WebSocket connection was closed, aborting')
self.queue.put(None)
break
except Exception as e: # pragma: no cover
if type(e) is OSError and e.errno == 9:
self.logger.info(
'WebSocket connection is closing, aborting')
else:
self.logger.info(
'Unexpected error receiving packet: "%s", aborting',
str(e))
self.queue.put(None)
break
try:
pkt = packet.Packet(encoded_packet=p)
except Exception as e: # pragma: no cover
self.logger.info(
'Unexpected error decoding packet: "%s", aborting', str(e))
self.queue.put(None)
break
self._receive_packet(pkt)
if self.write_loop_task: # pragma: no branch
self.logger.info('Waiting for write loop task to end')
self.write_loop_task.join()
if self.state == 'connected':
self._trigger_event('disconnect', self.reason.TRANSPORT_ERROR,
run_async=False)
try:
base_client.connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
self._reset()
self.logger.info('Exiting read loop task')
def _write_loop(self):
"""This background task sends packages to the server as they are
pushed to the send queue.
"""
while self.state == 'connected':
# to simplify the timeout handling, use the maximum of the
# ping interval and ping timeout as timeout, with an extra 5
# seconds grace period
timeout = max(self.ping_interval, self.ping_timeout) + 5
packets = None
try:
packets = [self.queue.get(timeout=timeout)]
except self.queue.Empty:
self.logger.error('packet queue is empty, aborting')
break
if packets == [None]:
self.queue.task_done()
packets = []
else:
while True:
try:
packets.append(self.queue.get(block=False))
except self.queue.Empty:
break
if packets[-1] is None:
packets = packets[:-1]
self.queue.task_done()
break
if not packets:
# empty packet list returned -> connection closed
break
if self.current_transport == 'polling':
p = payload.Payload(packets=packets)
r = self._send_request(
'POST', self.base_url, body=p.encode(),
headers={'Content-Type': 'text/plain'},
timeout=self.request_timeout)
for pkt in packets:
self.queue.task_done()
if r is None or isinstance(r, str):
self.logger.warning(
r or 'Connection refused by the server, aborting')
break
if r.status_code < 200 or r.status_code >= 300:
self.logger.warning('Unexpected status code %s in server '
'response, aborting', r.status_code)
self.write_loop_task = None
break
else:
# websocket
try:
for pkt in packets:
encoded_packet = pkt.encode()
if pkt.binary:
self.ws.send_binary(encoded_packet)
else:
self.ws.send(encoded_packet)
self.queue.task_done()
except (websocket.WebSocketConnectionClosedException,
BrokenPipeError, OSError):
self.logger.warning(
'WebSocket connection was closed, aborting')
break
self.logger.info('Exiting write loop task')

View File

@ -0,0 +1,22 @@
class EngineIOError(Exception):
pass
class ContentTooLongError(EngineIOError):
pass
class UnknownPacketError(EngineIOError):
pass
class QueueEmpty(EngineIOError):
pass
class SocketIsClosedError(EngineIOError):
pass
class ConnectionError(EngineIOError):
pass

View File

@ -0,0 +1,16 @@
"""JSON-compatible module with sane defaults."""
from json import * # noqa: F401, F403
from json import loads as original_loads
def _safe_int(s):
if len(s) > 100:
raise ValueError('Integer is too large')
return int(s)
def loads(*args, **kwargs):
if 'parse_int' not in kwargs: # pragma: no cover
kwargs['parse_int'] = _safe_int
return original_loads(*args, **kwargs)
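# Behaviour sketch: apart from the integer guard, this module mirrors the
# standard library json module, e.g. (values illustrative):
#
#   loads('{"a": 1}')   #=> {'a': 1}
#   loads('9' * 101)    # raises ValueError('Integer is too large')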

View File

@ -0,0 +1,86 @@
import os
from engineio.static_files import get_static_file
class WSGIApp:
"""WSGI application middleware for Engine.IO.
This middleware dispatches traffic to an Engine.IO application. It can
also serve a list of static files to the client, or forward unrelated
HTTP traffic to another WSGI application.
:param engineio_app: The Engine.IO server. Must be an instance of the
``engineio.Server`` class.
:param wsgi_app: The WSGI app that receives all other traffic.
:param static_files: A dictionary with static file mapping rules. See the
documentation for details on this argument.
:param engineio_path: The endpoint where the Engine.IO application should
be installed. The default value is appropriate for
most cases.
Example usage::
import engineio
import eventlet
eio = engineio.Server()
app = engineio.WSGIApp(eio, static_files={
'/': {'content_type': 'text/html', 'filename': 'index.html'},
'/index.html': {'content_type': 'text/html',
'filename': 'index.html'},
})
eventlet.wsgi.server(eventlet.listen(('', 8000)), app)
"""
def __init__(self, engineio_app, wsgi_app=None, static_files=None,
engineio_path='engine.io'):
self.engineio_app = engineio_app
self.wsgi_app = wsgi_app
self.engineio_path = engineio_path
if not self.engineio_path.startswith('/'):
self.engineio_path = '/' + self.engineio_path
if not self.engineio_path.endswith('/'):
self.engineio_path += '/'
self.static_files = static_files or {}
def __call__(self, environ, start_response):
if 'gunicorn.socket' in environ:
# gunicorn saves the socket under environ['gunicorn.socket'], while
# eventlet saves it under environ['eventlet.input']. Eventlet also
stores the socket inside a wrapper class, while gunicorn writes it
# directly into the environment. To give eventlet's WebSocket
# module access to this socket when running under gunicorn, here we
# copy the socket to the eventlet format.
class Input:
def __init__(self, socket):
self.socket = socket
def get_socket(self):
return self.socket
environ['eventlet.input'] = Input(environ['gunicorn.socket'])
path = environ['PATH_INFO']
if path is not None and path.startswith(self.engineio_path):
return self.engineio_app.handle_request(environ, start_response)
else:
static_file = get_static_file(path, self.static_files) \
if self.static_files else None
if static_file and os.path.exists(static_file['filename']):
start_response(
'200 OK',
[('Content-Type', static_file['content_type'])])
with open(static_file['filename'], 'rb') as f:
return [f.read()]
elif self.wsgi_app is not None:
return self.wsgi_app(environ, start_response)
return self.not_found(start_response)
def not_found(self, start_response):
start_response("404 Not Found", [('Content-Type', 'text/plain')])
return [b'Not Found']
class Middleware(WSGIApp):
"""This class has been renamed to ``WSGIApp`` and is now deprecated."""
def __init__(self, engineio_app, wsgi_app=None,
engineio_path='engine.io'):
super().__init__(engineio_app, wsgi_app, engineio_path=engineio_path)

View File

@ -0,0 +1,82 @@
import base64
from engineio import json as _json
(OPEN, CLOSE, PING, PONG, MESSAGE, UPGRADE, NOOP) = (0, 1, 2, 3, 4, 5, 6)
packet_names = ['OPEN', 'CLOSE', 'PING', 'PONG', 'MESSAGE', 'UPGRADE', 'NOOP']
binary_types = (bytes, bytearray)
class Packet:
"""Engine.IO packet."""
json = _json
def __init__(self, packet_type=NOOP, data=None, encoded_packet=None):
self.packet_type = packet_type
self.data = data
self.encode_cache = None
if isinstance(data, str):
self.binary = False
elif isinstance(data, binary_types):
self.binary = True
else:
self.binary = False
if self.binary and self.packet_type != MESSAGE:
raise ValueError('Binary packets can only be of type MESSAGE')
if encoded_packet is not None:
self.decode(encoded_packet)
def encode(self, b64=False):
"""Encode the packet for transmission.
Note: as a performance optimization, subsequent calls to this method
will return a cached encoded packet, even if the data has changed.
"""
if self.encode_cache:
return self.encode_cache
if self.binary:
if b64:
encoded_packet = 'b' + base64.b64encode(self.data).decode(
'utf-8')
else:
encoded_packet = self.data
else:
encoded_packet = str(self.packet_type)
if isinstance(self.data, str):
encoded_packet += self.data
elif isinstance(self.data, dict) or isinstance(self.data, list):
encoded_packet += self.json.dumps(self.data,
separators=(',', ':'))
elif self.data is not None:
encoded_packet += str(self.data)
self.encode_cache = encoded_packet
return encoded_packet
def decode(self, encoded_packet):
"""Decode a transmitted package."""
self.binary = isinstance(encoded_packet, binary_types)
if not self.binary and len(encoded_packet) == 0:
raise ValueError('Invalid empty packet received')
b64 = not self.binary and encoded_packet[0] == 'b'
if b64:
self.binary = True
self.packet_type = MESSAGE
self.data = base64.b64decode(encoded_packet[1:])
else:
if self.binary and not isinstance(encoded_packet, bytes):
encoded_packet = bytes(encoded_packet)
if self.binary:
self.packet_type = MESSAGE
self.data = encoded_packet
else:
self.packet_type = int(encoded_packet[0])
try:
self.data = self.json.loads(encoded_packet[1:])
if isinstance(self.data, int):
# do not allow integer payloads, see
# github.com/miguelgrinberg/python-engineio/issues/75
# for background on this decision
raise ValueError
except ValueError:
self.data = encoded_packet[1:]
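# Wire-format sketch (derived from the encode/decode logic above):
#
#   Packet(MESSAGE, data='hello').encode()              #=> '4hello'
#   Packet(MESSAGE, data={'a': 1}).encode()             #=> '4{"a":1}'
#   Packet(MESSAGE, data=b'\x01\x02').encode(b64=True)  #=> 'bAQI='
#   Packet(encoded_packet='2probe').packet_type         #=> 2 (PING)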

View File

@ -0,0 +1,46 @@
import urllib
from . import packet
class Payload:
"""Engine.IO payload."""
max_decode_packets = 16
def __init__(self, packets=None, encoded_payload=None):
self.packets = packets or []
if encoded_payload is not None:
self.decode(encoded_payload)
def encode(self, jsonp_index=None):
"""Encode the payload for transmission."""
encoded_payload = ''
for pkt in self.packets:
if encoded_payload:
encoded_payload += '\x1e'
encoded_payload += pkt.encode(b64=True)
if jsonp_index is not None:
encoded_payload = '___eio[' + \
str(jsonp_index) + \
']("' + \
encoded_payload.replace('"', '\\"') + \
'");'
return encoded_payload
def decode(self, encoded_payload):
"""Decode a transmitted payload."""
self.packets = []
if len(encoded_payload) == 0:
return
# JSONP POST payload starts with 'd='
if encoded_payload.startswith('d='):
encoded_payload = urllib.parse.parse_qs(
encoded_payload)['d'][0]
encoded_packets = encoded_payload.split('\x1e')
if len(encoded_packets) > self.max_decode_packets:
raise ValueError('Too many packets in payload')
self.packets = [packet.Packet(encoded_packet=encoded_packet)
for encoded_packet in encoded_packets]
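# Wire-format sketch: packets in a payload are joined with the 0x1e record
# separator, e.g. (using the Packet class from this package):
#
#   p = Payload(packets=[packet.Packet(packet.MESSAGE, data='hi'),
#                        packet.Packet(packet.MESSAGE, data='yo')])
#   p.encode()   #=> '4hi\x1e4yo'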

View File

@ -0,0 +1,503 @@
import logging
import urllib
from . import base_server
from . import exceptions
from . import packet
from . import socket
default_logger = logging.getLogger('engineio.server')
class Server(base_server.BaseServer):
"""An Engine.IO server.
This class implements a fully compliant Engine.IO web server with support
for websocket and long-polling transports.
:param async_mode: The asynchronous model to use. See the Deployment
section in the documentation for a description of the
available options. Valid async modes are "threading",
"eventlet", "gevent" and "gevent_uwsgi". If this
argument is not given, "eventlet" is tried first, then
"gevent_uwsgi", then "gevent", and finally "threading".
The first async mode that has all its dependencies
installed is the one that is chosen.
:param ping_interval: The interval in seconds at which the server pings
the client. The default is 25 seconds. For advanced
control, a two element tuple can be given, where
the first number is the ping interval and the second
is a grace period added by the server.
:param ping_timeout: The time in seconds that the client waits for the
server to respond before disconnecting. The default
is 20 seconds.
:param max_http_buffer_size: The maximum size that is accepted for incoming
messages. The default is 1,000,000 bytes. In
spite of its name, the value set in this
argument is enforced for HTTP long-polling and
WebSocket connections.
:param allow_upgrades: Whether to allow transport upgrades or not. The
default is ``True``.
:param http_compression: Whether to compress packets when using the
polling transport. The default is ``True``.
:param compression_threshold: Only compress messages when their byte size
is greater than this value. The default is
1024 bytes.
:param cookie: If set to a string, it is the name of the HTTP cookie the
server sends back to the client containing the client
session id. If set to a dictionary, the ``'name'`` key
contains the cookie name and other keys define cookie
attributes, where the value of each attribute can be a
string, a callable with no arguments, or a boolean. If set
to ``None`` (the default), a cookie is not sent to the
client.
:param cors_allowed_origins: Origin or list of origins that are allowed to
connect to this server. Only the same origin
is allowed by default. Set this argument to
``'*'`` to allow all origins, or to ``[]`` to
disable CORS handling.
:param cors_credentials: Whether credentials (cookies, authentication) are
allowed in requests to this server. The default
is ``True``.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``. The default is
``False``. Note that fatal errors are logged even when
``logger`` is ``False``.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
:param async_handlers: If set to ``True``, run message event handlers in
non-blocking threads. To run handlers synchronously,
set to ``False``. The default is ``True``.
:param monitor_clients: If set to ``True``, a background task will ensure
inactive clients are closed. Set to ``False`` to
disable the monitoring task (not recommended). The
default is ``True``.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. Defaults to
``['polling', 'websocket']``.
:param kwargs: Reserved for future extensions, any additional parameters
given as keyword arguments will be silently ignored.
"""
def send(self, sid, data):
"""Send a message to a client.
:param sid: The session id of the recipient client.
:param data: The data to send to the client. Data can be of type
``str``, ``bytes``, ``list`` or ``dict``. If a ``list``
or ``dict``, the data will be serialized as JSON.
"""
self.send_packet(sid, packet.Packet(packet.MESSAGE, data=data))
def send_packet(self, sid, pkt):
"""Send a raw packet to a client.
:param sid: The session id of the recipient client.
:param pkt: The packet to send to the client.
"""
try:
socket = self._get_socket(sid)
except KeyError:
# the socket is not available
self.logger.warning('Cannot send to sid %s', sid)
return
socket.send(pkt)
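# Usage sketch (assumed application code): list and dict payloads are
# serialized as JSON by the Packet class, e.g.
#
#   eio.send(sid, 'hello')            # sent as '4hello'
#   eio.send(sid, {'status': 'ok'})   # sent as '4{"status":"ok"}'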
def get_session(self, sid):
"""Return the user session for a client.
:param sid: The session id of the client.
The return value is a dictionary. Modifications made to this
dictionary are not guaranteed to be preserved unless
``save_session()`` is called, or when the ``session`` context manager
is used.
"""
socket = self._get_socket(sid)
return socket.session
def save_session(self, sid, session):
"""Store the user session for a client.
:param sid: The session id of the client.
:param session: The session dictionary.
"""
socket = self._get_socket(sid)
socket.session = session
def session(self, sid):
"""Return the user session for a client with context manager syntax.
:param sid: The session id of the client.
This is a context manager that returns the user session dictionary for
the client. Any changes that are made to this dictionary inside the
context manager block are saved back to the session. Example usage::
@eio.on('connect')
def on_connect(sid, environ):
username = authenticate_user(environ)
if not username:
return False
with eio.session(sid) as session:
session['username'] = username
@eio.on('message')
def on_message(sid, msg):
with eio.session(sid) as session:
print('received message from ', session['username'])
"""
class _session_context_manager:
def __init__(self, server, sid):
self.server = server
self.sid = sid
self.session = None
def __enter__(self):
self.session = self.server.get_session(sid)
return self.session
def __exit__(self, *args):
self.server.save_session(sid, self.session)
return _session_context_manager(self, sid)
def disconnect(self, sid=None):
"""Disconnect a client.
:param sid: The session id of the client to close. If this parameter
is not given, then all clients are closed.
"""
if sid is not None:
try:
socket = self._get_socket(sid)
except KeyError: # pragma: no cover
# the socket was already closed or gone
pass
else:
socket.close(reason=self.reason.SERVER_DISCONNECT)
if sid in self.sockets: # pragma: no cover
del self.sockets[sid]
else:
for client in self.sockets.copy().values():
client.close(reason=self.reason.SERVER_DISCONNECT)
self.sockets = {}
def handle_request(self, environ, start_response):
"""Handle an HTTP request from the client.
This is the entry point of the Engine.IO application, using the same
interface as a WSGI application. For the typical usage, this function
is invoked by the :class:`Middleware` instance, but it can be invoked
directly when the middleware is not used.
:param environ: The WSGI environment.
:param start_response: The WSGI ``start_response`` function.
This function returns the HTTP response body to deliver to the client
as a byte sequence.
"""
if self.cors_allowed_origins != []:
# Validate the origin header if present
# This is important for WebSocket more than for HTTP, since
# browsers only apply CORS controls to HTTP.
origin = environ.get('HTTP_ORIGIN')
if origin:
allowed_origins = self._cors_allowed_origins(environ)
if allowed_origins is not None and origin not in \
allowed_origins:
self._log_error_once(
origin + ' is not an accepted origin.', 'bad-origin')
r = self._bad_request('Not an accepted origin.')
start_response(r['status'], r['headers'])
return [r['response']]
method = environ['REQUEST_METHOD']
query = urllib.parse.parse_qs(environ.get('QUERY_STRING', ''))
jsonp = False
jsonp_index = None
# make sure the client uses an allowed transport
transport = query.get('transport', ['polling'])[0]
if transport not in self.transports:
self._log_error_once('Invalid transport', 'bad-transport')
r = self._bad_request('Invalid transport')
start_response(r['status'], r['headers'])
return [r['response']]
# make sure the client speaks a compatible Engine.IO version
sid = query['sid'][0] if 'sid' in query else None
if sid is None and query.get('EIO') != ['4']:
self._log_error_once(
'The client is using an unsupported version of the Socket.IO '
'or Engine.IO protocols', 'bad-version')
r = self._bad_request(
'The client is using an unsupported version of the Socket.IO '
'or Engine.IO protocols')
start_response(r['status'], r['headers'])
return [r['response']]
if 'j' in query:
jsonp = True
try:
jsonp_index = int(query['j'][0])
except (ValueError, KeyError, IndexError):
# Invalid JSONP index number
pass
if jsonp and jsonp_index is None:
self._log_error_once('Invalid JSONP index number',
'bad-jsonp-index')
r = self._bad_request('Invalid JSONP index number')
elif method == 'GET':
upgrade_header = environ.get('HTTP_UPGRADE').lower() \
if 'HTTP_UPGRADE' in environ else None
if sid is None:
# transport must be one of 'polling' or 'websocket'.
# if 'websocket', the HTTP_UPGRADE header must match.
if transport == 'polling' \
or transport == upgrade_header == 'websocket':
r = self._handle_connect(environ, start_response,
transport, jsonp_index)
else:
self._log_error_once('Invalid websocket upgrade',
'bad-upgrade')
r = self._bad_request('Invalid websocket upgrade')
else:
if sid not in self.sockets:
self._log_error_once(f'Invalid session {sid}', 'bad-sid')
r = self._bad_request(f'Invalid session {sid}')
else:
try:
socket = self._get_socket(sid)
except KeyError as e: # pragma: no cover
self._log_error_once(f'{e} {sid}', 'bad-sid')
r = self._bad_request(f'{e} {sid}')
else:
if self.transport(sid) != transport and \
transport != upgrade_header:
self._log_error_once(
f'Invalid transport for session {sid}',
'bad-transport')
r = self._bad_request('Invalid transport')
else:
try:
packets = socket.handle_get_request(
environ, start_response)
if isinstance(packets, list):
r = self._ok(packets,
jsonp_index=jsonp_index)
else:
r = packets
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
self.disconnect(sid)
r = self._bad_request()
if sid in self.sockets and \
self.sockets[sid].closed:
del self.sockets[sid]
elif method == 'POST':
if sid is None or sid not in self.sockets:
self._log_error_once(f'Invalid session {sid}', 'bad-sid')
r = self._bad_request(f'Invalid session {sid}')
else:
socket = self._get_socket(sid)
try:
socket.handle_post_request(environ)
r = self._ok(jsonp_index=jsonp_index)
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
self.disconnect(sid)
r = self._bad_request()
except: # pragma: no cover
# for any other unexpected errors, we log the error
# and keep going
self.logger.exception('post request handler error')
r = self._ok(jsonp_index=jsonp_index)
elif method == 'OPTIONS':
r = self._ok()
else:
self.logger.warning('Method %s not supported', method)
r = self._method_not_found()
if not isinstance(r, dict):
return r
if self.http_compression and \
len(r['response']) >= self.compression_threshold:
encodings = [e.split(';')[0].strip() for e in
environ.get('HTTP_ACCEPT_ENCODING', '').split(',')]
for encoding in encodings:
if encoding in self.compression_methods:
r['response'] = \
getattr(self, '_' + encoding)(r['response'])
r['headers'] += [('Content-Encoding', encoding)]
break
cors_headers = self._cors_headers(environ)
start_response(r['status'], r['headers'] + cors_headers)
return [r['response']]
def shutdown(self):
"""Stop Socket.IO background tasks.
This method stops background activity initiated by the Socket.IO
server. It must be called before shutting down the web server.
"""
self.logger.info('Engine.IO is shutting down')
if self.service_task_event: # pragma: no cover
self.service_task_event.set()
self.service_task_handle.join()
self.service_task_handle = None
def start_background_task(self, target, *args, **kwargs):
"""Start a background task using the appropriate async model.
This is a utility function that applications can use to start a
background task using the method that is compatible with the
selected async mode.
:param target: the target function to execute.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
This function returns an object that represents the background task,
on which the ``join()`` method can be invoked to wait for the task to
complete.
"""
th = self._async['thread'](target=target, args=args, kwargs=kwargs)
th.start()
return th # pragma: no cover
def sleep(self, seconds=0):
"""Sleep for the requested amount of time using the appropriate async
model.
This is a utility function that applications can use to put a task to
sleep without having to worry about using the correct call for the
selected async mode.
"""
return self._async['sleep'](seconds)
def _handle_connect(self, environ, start_response, transport,
jsonp_index=None):
"""Handle a client connection request."""
if self.start_service_task:
# start the service task to monitor connected clients
self.start_service_task = False
self.service_task_handle = self.start_background_task(
self._service_task)
sid = self.generate_id()
s = socket.Socket(self, sid)
self.sockets[sid] = s
pkt = packet.Packet(packet.OPEN, {
'sid': sid,
'upgrades': self._upgrades(sid, transport),
'pingTimeout': int(self.ping_timeout * 1000),
'pingInterval': int(
self.ping_interval + self.ping_interval_grace_period) * 1000,
'maxPayload': self.max_http_buffer_size,
})
s.send(pkt)
s.schedule_ping()
# NOTE: some sections below are marked as "no cover" to workaround
# what seems to be a bug in the coverage package. All the lines below
# are covered by tests, but some are not reported as such for some
# reason
ret = self._trigger_event('connect', sid, environ, run_async=False)
if ret is not None and ret is not True: # pragma: no cover
del self.sockets[sid]
self.logger.warning('Application rejected connection')
return self._unauthorized(ret or None)
if transport == 'websocket': # pragma: no cover
ret = s.handle_get_request(environ, start_response)
if s.closed and sid in self.sockets:
# websocket connection ended, so we are done
del self.sockets[sid]
return ret
else: # pragma: no cover
s.connected = True
headers = None
if self.cookie:
if isinstance(self.cookie, dict):
headers = [(
'Set-Cookie',
self._generate_sid_cookie(sid, self.cookie)
)]
else:
headers = [(
'Set-Cookie',
self._generate_sid_cookie(sid, {
'name': self.cookie, 'path': '/', 'SameSite': 'Lax'
})
)]
try:
return self._ok(s.poll(), headers=headers,
jsonp_index=jsonp_index)
except exceptions.QueueEmpty:
return self._bad_request()
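# Handshake sketch: the OPEN packet built above reaches the client as packet
# type '0' followed by the JSON handshake data, e.g. with the default
# configuration (the sid value is illustrative):
#
#   0{"sid":"<20-char id>","upgrades":["websocket"],"pingTimeout":20000,"pingInterval":25000,"maxPayload":1000000}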
def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
run_async = kwargs.pop('run_async', False)
if event in self.handlers:
def run_handler():
try:
try:
return self.handlers[event](*args)
except TypeError:
if event == 'disconnect' and \
len(args) == 2: # pragma: no branch
# legacy disconnect events do not have a reason
# argument
return self.handlers[event](args[0])
else: # pragma: no cover
raise
except:
self.logger.exception(event + ' handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
if run_async:
return self.start_background_task(run_handler)
else:
return run_handler()
def _service_task(self): # pragma: no cover
"""Monitor connected clients and clean up those that time out."""
self.service_task_event = self.create_event()
while not self.service_task_event.is_set():
if len(self.sockets) == 0:
# nothing to do
if self.service_task_event.wait(timeout=self.ping_timeout):
break
continue
# go through the entire client list in a ping interval cycle
sleep_interval = float(self.ping_timeout) / len(self.sockets)
try:
# iterate over the current clients
for s in self.sockets.copy().values():
if s.closed:
try:
del self.sockets[s.sid]
except KeyError:
# the socket could have also been removed by
# the _get_socket() method from another thread
pass
elif not s.closing:
s.check_ping_timeout()
if self.service_task_event.wait(timeout=sleep_interval):
raise KeyboardInterrupt()
except (SystemExit, KeyboardInterrupt):
self.logger.info('service task canceled')
break
except:
# an unexpected exception has occurred, log it and continue
self.logger.exception('service task exception')

View File

@ -0,0 +1,256 @@
import sys
import time
from . import base_socket
from . import exceptions
from . import packet
from . import payload
class Socket(base_socket.BaseSocket):
"""An Engine.IO socket."""
def poll(self):
"""Wait for packets to send to the client."""
queue_empty = self.server.get_queue_empty_exception()
try:
packets = [self.queue.get(
timeout=self.server.ping_interval + self.server.ping_timeout)]
self.queue.task_done()
except queue_empty:
raise exceptions.QueueEmpty()
if packets == [None]:
return []
while True:
try:
pkt = self.queue.get(block=False)
self.queue.task_done()
if pkt is None:
self.queue.put(None)
break
packets.append(pkt)
except queue_empty:
break
return packets
def receive(self, pkt):
"""Receive packet from the client."""
packet_name = packet.packet_names[pkt.packet_type] \
if pkt.packet_type < len(packet.packet_names) else 'UNKNOWN'
self.server.logger.info('%s: Received packet %s data %s',
self.sid, packet_name,
pkt.data if not isinstance(pkt.data, bytes)
else '<binary>')
if pkt.packet_type == packet.PONG:
self.schedule_ping()
elif pkt.packet_type == packet.MESSAGE:
self.server._trigger_event('message', self.sid, pkt.data,
run_async=self.server.async_handlers)
elif pkt.packet_type == packet.UPGRADE:
self.send(packet.Packet(packet.NOOP))
elif pkt.packet_type == packet.CLOSE:
self.close(wait=False, abort=True,
reason=self.server.reason.CLIENT_DISCONNECT)
else:
raise exceptions.UnknownPacketError()
def check_ping_timeout(self):
"""Make sure the client is still responding to pings."""
if self.closed:
raise exceptions.SocketIsClosedError()
if self.last_ping and \
time.time() - self.last_ping > self.server.ping_timeout:
self.server.logger.info('%s: Client is gone, closing socket',
self.sid)
# Passing abort=False here will cause close() to write a
# CLOSE packet. This has the effect of updating half-open sockets
# to their correct state of disconnected
self.close(wait=False, abort=False,
reason=self.server.reason.PING_TIMEOUT)
return False
return True
def send(self, pkt):
"""Send a packet to the client."""
if not self.check_ping_timeout():
return
else:
self.queue.put(pkt)
self.server.logger.info('%s: Sending packet %s data %s',
self.sid, packet.packet_names[pkt.packet_type],
pkt.data if not isinstance(pkt.data, bytes)
else '<binary>')
def handle_get_request(self, environ, start_response):
"""Handle a long-polling GET request from the client."""
connections = [
s.strip()
for s in environ.get('HTTP_CONNECTION', '').lower().split(',')]
transport = environ.get('HTTP_UPGRADE', '').lower()
if 'upgrade' in connections and transport in self.upgrade_protocols:
self.server.logger.info('%s: Received request to upgrade to %s',
self.sid, transport)
return getattr(self, '_upgrade_' + transport)(environ,
start_response)
if self.upgrading or self.upgraded:
# we are upgrading to WebSocket, do not return any more packets
# through the polling endpoint
return [packet.Packet(packet.NOOP)]
try:
packets = self.poll()
except exceptions.QueueEmpty:
exc = sys.exc_info()
self.close(wait=False, reason=self.server.reason.TRANSPORT_ERROR)
raise exc[1].with_traceback(exc[2])
return packets
def handle_post_request(self, environ):
"""Handle a long-polling POST request from the client."""
length = int(environ.get('CONTENT_LENGTH', '0'))
if length > self.server.max_http_buffer_size:
raise exceptions.ContentTooLongError()
else:
body = environ['wsgi.input'].read(length).decode('utf-8')
p = payload.Payload(encoded_payload=body)
for pkt in p.packets:
self.receive(pkt)
def close(self, wait=True, abort=False, reason=None):
"""Close the socket connection."""
if not self.closed and not self.closing:
self.closing = True
self.server._trigger_event(
'disconnect', self.sid,
reason or self.server.reason.SERVER_DISCONNECT,
run_async=False)
if not abort:
self.send(packet.Packet(packet.CLOSE))
self.closed = True
self.queue.put(None)
if wait:
self.queue.join()
def schedule_ping(self):
self.server.start_background_task(self._send_ping)
def _send_ping(self):
self.last_ping = None
self.server.sleep(self.server.ping_interval)
if not self.closing and not self.closed:
self.last_ping = time.time()
self.send(packet.Packet(packet.PING))
def _upgrade_websocket(self, environ, start_response):
"""Upgrade the connection from polling to websocket."""
if self.upgraded:
raise OSError('Socket has been upgraded already')
if self.server._async['websocket'] is None:
# the selected async mode does not support websocket
return self.server._bad_request()
ws = self.server._async['websocket'](
self._websocket_handler, self.server)
return ws(environ, start_response)
def _websocket_handler(self, ws):
"""Engine.IO handler for websocket transport."""
def websocket_wait():
data = ws.wait()
if data and len(data) > self.server.max_http_buffer_size:
raise ValueError('packet is too large')
return data
# try to set a socket timeout matching the configured ping interval
# and timeout
for attr in ['_sock', 'socket']: # pragma: no cover
if hasattr(ws, attr) and hasattr(getattr(ws, attr), 'settimeout'):
getattr(ws, attr).settimeout(
self.server.ping_interval + self.server.ping_timeout)
if self.connected:
# the socket was already connected, so this is an upgrade
self.upgrading = True # hold packet sends during the upgrade
pkt = websocket_wait()
decoded_pkt = packet.Packet(encoded_packet=pkt)
if decoded_pkt.packet_type != packet.PING or \
decoded_pkt.data != 'probe':
self.server.logger.info(
'%s: Failed websocket upgrade, no PING packet', self.sid)
self.upgrading = False
return []
ws.send(packet.Packet(packet.PONG, data='probe').encode())
self.queue.put(packet.Packet(packet.NOOP)) # end poll
pkt = websocket_wait()
decoded_pkt = packet.Packet(encoded_packet=pkt)
if decoded_pkt.packet_type != packet.UPGRADE:
self.upgraded = False
self.server.logger.info(
('%s: Failed websocket upgrade, expected UPGRADE packet, '
'received %s instead.'),
self.sid, pkt)
self.upgrading = False
return []
self.upgraded = True
self.upgrading = False
else:
self.connected = True
self.upgraded = True
# start separate writer thread
def writer():
while True:
packets = None
try:
packets = self.poll()
except exceptions.QueueEmpty:
break
if not packets:
# empty packet list returned -> connection closed
break
try:
for pkt in packets:
ws.send(pkt.encode())
except:
break
ws.close()
writer_task = self.server.start_background_task(writer)
self.server.logger.info(
'%s: Upgrade to websocket successful', self.sid)
while True:
p = None
try:
p = websocket_wait()
except Exception as e:
# if the socket is already closed, we can assume this is a
# downstream error of that
if not self.closed: # pragma: no cover
self.server.logger.info(
'%s: Unexpected error "%s", closing connection',
self.sid, str(e))
break
if p is None:
# connection closed by client
break
pkt = packet.Packet(encoded_packet=p)
try:
self.receive(pkt)
except exceptions.UnknownPacketError: # pragma: no cover
pass
except exceptions.SocketIsClosedError: # pragma: no cover
self.server.logger.info('Receive error -- socket is closed')
break
except: # pragma: no cover
# if we get an unexpected exception we log the error and exit
# the connection properly
self.server.logger.exception('Unknown receive error')
break
self.queue.put(None) # unlock the writer task so that it can exit
writer_task.join()
self.close(wait=False, abort=True,
reason=self.server.reason.TRANSPORT_CLOSE)
return []
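# Note on the upgrade handshake implemented above: the client probes the
# new websocket with a PING packet carrying "probe"; the server answers
# with a matching PONG and pushes a NOOP into the polling queue to end
# the long-poll; the client then confirms with an UPGRADE packet, after
# which all traffic flows over the websocket.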

View File

@ -0,0 +1,60 @@
content_types = {
'css': 'text/css',
'gif': 'image/gif',
'html': 'text/html',
'jpg': 'image/jpeg',
'js': 'application/javascript',
'json': 'application/json',
'png': 'image/png',
'txt': 'text/plain',
}
def get_static_file(path, static_files):
"""Return the local filename and content type for the requested static
file URL.
:param path: the path portion of the requested URL.
:param static_files: a static file configuration dictionary.
This function returns a dictionary with two keys, "filename" and
"content_type". If the requested URL does not match any static file, the
return value is None.
"""
extra_path = ''
if path in static_files:
f = static_files[path]
else:
f = None
while path != '':
path, last = path.rsplit('/', 1)
extra_path = '/' + last + extra_path
if path in static_files:
f = static_files[path]
break
elif path + '/' in static_files:
f = static_files[path + '/']
break
if f:
if isinstance(f, str):
f = {'filename': f}
else:
f = f.copy() # in case it is mutated below
if f['filename'].endswith('/') and extra_path.startswith('/'):
extra_path = extra_path[1:]
f['filename'] += extra_path
if f['filename'].endswith('/'):
if '' in static_files:
if isinstance(static_files[''], str):
f['filename'] += static_files['']
else:
f['filename'] += static_files['']['filename']
if 'content_type' in static_files['']:
f['content_type'] = static_files['']['content_type']
else:
f['filename'] += 'index.html'
if 'content_type' not in f:
ext = f['filename'].rsplit('.')[-1]
f['content_type'] = content_types.get(
ext, 'application/octet-stream')
return f
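# Usage sketch (the mapping below is hypothetical):
#
#     static_files = {'/': 'index.html', '/static/': '/var/www/static/'}
#     get_static_file('/static/app.js', static_files)
#     # -> {'filename': '/var/www/static/app.js',
#     #     'content_type': 'application/javascript'}
#     get_static_file('/', static_files)
#     # -> {'filename': 'index.html', 'content_type': 'text/html'}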

View File

@ -0,0 +1,2 @@
/home/davidson/Projects/evolution_client/python
.

View File

@ -0,0 +1 @@
pip

View File

@ -0,0 +1,22 @@
The MIT License (MIT)
Copyright (c) 2016 Nathaniel J. Smith <njs@pobox.com> and other contributors
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -0,0 +1,193 @@
Metadata-Version: 2.1
Name: h11
Version: 0.14.0
Summary: A pure-Python, bring-your-own-I/O implementation of HTTP/1.1
Home-page: https://github.com/python-hyper/h11
Author: Nathaniel J. Smith
Author-email: njs@pobox.com
License: MIT
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: Implementation :: CPython
Classifier: Programming Language :: Python :: Implementation :: PyPy
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Topic :: Internet :: WWW/HTTP
Classifier: Topic :: System :: Networking
Requires-Python: >=3.7
License-File: LICENSE.txt
Requires-Dist: typing-extensions ; python_version < "3.8"
h11
===
.. image:: https://travis-ci.org/python-hyper/h11.svg?branch=master
:target: https://travis-ci.org/python-hyper/h11
:alt: Automated test status
.. image:: https://codecov.io/gh/python-hyper/h11/branch/master/graph/badge.svg
:target: https://codecov.io/gh/python-hyper/h11
:alt: Test coverage
.. image:: https://readthedocs.org/projects/h11/badge/?version=latest
:target: http://h11.readthedocs.io/en/latest/?badge=latest
:alt: Documentation Status
This is a little HTTP/1.1 library written from scratch in Python,
heavily inspired by `hyper-h2 <https://hyper-h2.readthedocs.io/>`_.
It's a "bring-your-own-I/O" library; h11 contains no IO code
whatsoever. This means you can hook h11 up to your favorite network
API, and that could be anything you want: synchronous, threaded,
asynchronous, or your own implementation of `RFC 6214
<https://tools.ietf.org/html/rfc6214>`_ -- h11 won't judge you.
(Compare this to the current state of the art, where every time a `new
network API <https://trio.readthedocs.io/>`_ comes along then someone
gets to start over reimplementing the entire HTTP protocol from
scratch.) Cory Benfield made an `excellent blog post describing the
benefits of this approach
<https://lukasa.co.uk/2015/10/The_New_Hyper/>`_, or if you like video
then here's his `PyCon 2016 talk on the same theme
<https://www.youtube.com/watch?v=7cC3_jGwl_U>`_.
This also means that h11 is not immediately useful out of the box:
it's a toolkit for building programs that speak HTTP, not something
that could directly replace ``requests`` or ``twisted.web`` or
whatever. But h11 makes it much easier to implement something like
``requests`` or ``twisted.web``.
At a high level, working with h11 goes like this:
1) First, create an ``h11.Connection`` object to track the state of a
single HTTP/1.1 connection.
2) When you read data off the network, pass it to
``conn.receive_data(...)``; you'll get back a list of objects
representing high-level HTTP "events".
3) When you want to send a high-level HTTP event, create the
corresponding "event" object and pass it to ``conn.send(...)``;
this will give you back some bytes that you can then push out
through the network.
For example, a client might instantiate and then send a
``h11.Request`` object, then zero or more ``h11.Data`` objects for the
request body (e.g., if this is a POST), and then a
``h11.EndOfMessage`` to indicate the end of the message. The server
would then send back a ``h11.Response``, some ``h11.Data``, and
its own ``h11.EndOfMessage``. If either side violates the protocol,
you'll get a ``h11.ProtocolError`` exception.
h11 is suitable for implementing both servers and clients, and has a
pleasantly symmetric API: the events you send as a client are exactly
the ones that you receive as a server and vice-versa.
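To make the three steps concrete, here is a minimal client-side sketch
(the host, port, and buffer size are placeholders; real error and EOF
handling are omitted):
.. code-block:: python

    import socket
    import h11

    conn = h11.Connection(our_role=h11.CLIENT)
    sock = socket.create_connection(("example.com", 80))

    # Step 3: turn high-level events into bytes and push them out.
    sock.sendall(conn.send(h11.Request(
        method="GET", target="/", headers=[("Host", "example.com")])))
    sock.sendall(conn.send(h11.EndOfMessage()))

    # Step 2: feed raw bytes in, get high-level events back.
    while True:
        event = conn.next_event()
        if event is h11.NEED_DATA:
            conn.receive_data(sock.recv(4096))
            continue
        print(event)  # Response, then Data, then EndOfMessage
        if type(event) is h11.EndOfMessage:
            break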
`Here's an example of a tiny HTTP client
<https://github.com/python-hyper/h11/blob/master/examples/basic-client.py>`_
It also has `a fine manual <https://h11.readthedocs.io/>`_.
FAQ
---
*Whyyyyy?*
I wanted to play with HTTP in `Curio
<https://curio.readthedocs.io/en/latest/tutorial.html>`__ and `Trio
<https://trio.readthedocs.io>`__, which at the time didn't have any
HTTP libraries. So I thought, no big deal, Python has, like, a dozen
different implementations of HTTP, surely I can find one that's
reusable. I didn't find one, but I did find Cory's call-to-arms
blog-post. So I figured, well, fine, if I have to implement HTTP from
scratch, at least I can make sure no-one *else* has to ever again.
*Should I use it?*
Maybe. You should be aware that it's a very young project. But, it's
feature complete and has an exhaustive test-suite and complete docs,
so the next step is for people to try using it and see how it goes
:-). If you do then please let us know -- if nothing else we'll want
to talk to you before making any incompatible changes!
*What are the features/limitations?*
Roughly speaking, it's trying to be a robust, complete, and non-hacky
implementation of the first "chapter" of the HTTP/1.1 spec: `RFC 7230:
HTTP/1.1 Message Syntax and Routing
<https://tools.ietf.org/html/rfc7230>`_. That is, it mostly focuses on
implementing HTTP at the level of taking bytes on and off the wire,
and the headers related to that, and tries to be anal about spec
conformance. It doesn't know about higher-level concerns like URL
routing, conditional GETs, cross-origin cookie policies, or content
negotiation. But it does know how to take care of framing,
cross-version differences in keep-alive handling, and the "obsolete
line folding" rule, so you can focus your energies on the hard /
interesting parts for your application, and it tries to support the
full specification in the sense that any useful HTTP/1.1 conformant
application should be able to use h11.
It's pure Python, and has no dependencies outside of the standard
library.
It has a test suite with 100.0% coverage for both statements and
branches.
Currently it supports Python 3 (testing on 3.7-3.10) and PyPy 3.
The last Python 2-compatible version was h11 0.11.x.
(Originally it had a Cython wrapper for `http-parser
<https://github.com/nodejs/http-parser>`_ and a beautiful nested state
machine implemented with ``yield from`` to postprocess the output. But
I had to take these out -- the new *parser* needs fewer lines-of-code
than the old *parser wrapper*, is written in pure Python, uses no
exotic language syntax, and has more features. It's sad, really; that
old state machine was really slick. I just need a few sentences here
to mourn that.)
I don't know how fast it is. I haven't benchmarked or profiled it yet,
so it's probably got a few pointless hot spots, and I've been trying
to err on the side of simplicity and robustness instead of
micro-optimization. But at the architectural level I tried hard to
avoid fundamentally bad decisions, e.g., I believe that all the
parsing algorithms remain linear-time even in the face of pathological
input like slowloris, and there are no byte-by-byte loops. (I also
believe that it maintains bounded memory usage in the face of
arbitrary/pathological input.)
The whole library is ~800 lines-of-code. You can read and understand
the whole thing in less than an hour. Most of the energy invested in
this so far has been spent on trying to keep things simple by
minimizing special-cases and ad hoc state manipulation; even though it
is now quite small and simple, I'm still annoyed that I haven't
figured out how to make it even smaller and simpler. (Unfortunately,
HTTP does not lend itself to simplicity.)
The API is ~feature complete and I don't expect the general outlines
to change much, but you can't judge an API's ergonomics until you
actually document and use it, so I'd expect some changes in the
details.
*How do I try it?*
.. code-block:: sh
$ pip install h11
$ git clone git@github.com:python-hyper/h11
$ cd h11/examples
$ python basic-client.py
and go from there.
*License?*
MIT
*Code of conduct?*
Contributors are requested to follow our `code of conduct
<https://github.com/python-hyper/h11/blob/master/CODE_OF_CONDUCT.md>`_ in
all project spaces.

View File

@ -0,0 +1,52 @@
h11-0.14.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
h11-0.14.0.dist-info/LICENSE.txt,sha256=N9tbuFkm2yikJ6JYZ_ELEjIAOuob5pzLhRE4rbjm82E,1124
h11-0.14.0.dist-info/METADATA,sha256=B7pZ0m7WBXNs17vl6hUH9bJTL9s37DaGvY31w7jNxSg,8175
h11-0.14.0.dist-info/RECORD,,
h11-0.14.0.dist-info/WHEEL,sha256=ewwEueio1C2XeHTvT17n8dZUJgOvyCWCt0WVNLClP9o,92
h11-0.14.0.dist-info/top_level.txt,sha256=F7dC4jl3zeh8TGHEPaWJrMbeuoWbS379Gwdi-Yvdcis,4
h11/__init__.py,sha256=iO1KzkSO42yZ6ffg-VMgbx_ZVTWGUY00nRYEWn-s3kY,1507
h11/__pycache__/__init__.cpython-310.pyc,,
h11/__pycache__/_abnf.cpython-310.pyc,,
h11/__pycache__/_connection.cpython-310.pyc,,
h11/__pycache__/_events.cpython-310.pyc,,
h11/__pycache__/_headers.cpython-310.pyc,,
h11/__pycache__/_readers.cpython-310.pyc,,
h11/__pycache__/_receivebuffer.cpython-310.pyc,,
h11/__pycache__/_state.cpython-310.pyc,,
h11/__pycache__/_util.cpython-310.pyc,,
h11/__pycache__/_version.cpython-310.pyc,,
h11/__pycache__/_writers.cpython-310.pyc,,
h11/_abnf.py,sha256=ybixr0xsupnkA6GFAyMubuXF6Tc1lb_hF890NgCsfNc,4815
h11/_connection.py,sha256=eS2sorMD0zKLCFiB9lW9W9F_Nzny2tjHa4e6s1ujr1c,26539
h11/_events.py,sha256=LEfuvg1AbhHaVRwxCd0I-pFn9-ezUOaoL8o2Kvy1PBA,11816
h11/_headers.py,sha256=RqB8cd8CN0blYPzcLe5qeCh-phv6D1U_CHj4hs67lgQ,10230
h11/_readers.py,sha256=EbSed0jzwVUiD1nOPAeUcVE4Flf3wXkxfb8c06-OTBM,8383
h11/_receivebuffer.py,sha256=xrspsdsNgWFxRfQcTXxR8RrdjRXXTK0Io5cQYWpJ1Ws,5252
h11/_state.py,sha256=k1VL6SDbaPkSrZ-49ewCXDpuiUS69_46YhbWjuV1qEY,13300
h11/_util.py,sha256=LWkkjXyJaFlAy6Lt39w73UStklFT5ovcvo0TkY7RYuk,4888
h11/_version.py,sha256=LVyTdiZRzIIEv79UyOgbM5iUrJUllEzlCWaJEYBY1zc,686
h11/_writers.py,sha256=oFKm6PtjeHfbj4RLX7VB7KDc1gIY53gXG3_HR9ltmTA,5081
h11/py.typed,sha256=sow9soTwP9T_gEAQSVh7Gb8855h04Nwmhs2We-JRgZM,7
h11/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
h11/tests/__pycache__/__init__.cpython-310.pyc,,
h11/tests/__pycache__/helpers.cpython-310.pyc,,
h11/tests/__pycache__/test_against_stdlib_http.cpython-310.pyc,,
h11/tests/__pycache__/test_connection.cpython-310.pyc,,
h11/tests/__pycache__/test_events.cpython-310.pyc,,
h11/tests/__pycache__/test_headers.cpython-310.pyc,,
h11/tests/__pycache__/test_helpers.cpython-310.pyc,,
h11/tests/__pycache__/test_io.cpython-310.pyc,,
h11/tests/__pycache__/test_receivebuffer.cpython-310.pyc,,
h11/tests/__pycache__/test_state.cpython-310.pyc,,
h11/tests/__pycache__/test_util.cpython-310.pyc,,
h11/tests/data/test-file,sha256=ZJ03Rqs98oJw29OHzJg7LlMzyGQaRAY0r3AqBeM2wVU,65
h11/tests/helpers.py,sha256=a1EVG_p7xU4wRsa3tMPTRxuaKCmretok9sxXWvqfmQA,3355
h11/tests/test_against_stdlib_http.py,sha256=cojCHgHXFQ8gWhNlEEwl3trmOpN-5uDukRoHnElqo3A,3995
h11/tests/test_connection.py,sha256=ZbPLDPclKvjgjAhgk-WlCPBaf17c4XUIV2tpaW08jOI,38720
h11/tests/test_events.py,sha256=LPVLbcV-NvPNK9fW3rraR6Bdpz1hAlsWubMtNaJ5gHg,4657
h11/tests/test_headers.py,sha256=qd8T1Zenuz5GbD6wklSJ5G8VS7trrYgMV0jT-SMvqg8,5612
h11/tests/test_helpers.py,sha256=kAo0CEM4LGqmyyP2ZFmhsyq3UFJqoFfAbzu3hbWreRM,794
h11/tests/test_io.py,sha256=uCZVnjarkRBkudfC1ij-KSCQ71XWJhnkgkgWWkKgYPQ,16386
h11/tests/test_receivebuffer.py,sha256=3jGbeJM36Akqg_pAhPb7XzIn2NS6RhPg-Ryg8Eu6ytk,3454
h11/tests/test_state.py,sha256=rqll9WqFsJPE0zSrtCn9LH659mPKsDeXZ-DwXwleuBQ,8928
h11/tests/test_util.py,sha256=VO5L4nSFe4pgtSwKuv6u_6l0H7UeizF5WKuHTWreg70,2970

View File

@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: bdist_wheel (0.37.0)
Root-Is-Purelib: true
Tag: py3-none-any

View File

@ -0,0 +1 @@
h11

View File

@ -0,0 +1,62 @@
# A highish-level implementation of the HTTP/1.1 wire protocol (RFC 7230),
# containing no networking code at all, loosely modelled on hyper-h2's generic
# implementation of HTTP/2 (and in particular the h2.connection.H2Connection
# class). There's still a bunch of subtle details you need to get right if you
# want to make this actually useful, because it doesn't implement all the
# semantics to check that what you're asking to write to the wire is sensible,
# but at least it gets you out of dealing with the wire itself.
from h11._connection import Connection, NEED_DATA, PAUSED
from h11._events import (
ConnectionClosed,
Data,
EndOfMessage,
Event,
InformationalResponse,
Request,
Response,
)
from h11._state import (
CLIENT,
CLOSED,
DONE,
ERROR,
IDLE,
MIGHT_SWITCH_PROTOCOL,
MUST_CLOSE,
SEND_BODY,
SEND_RESPONSE,
SERVER,
SWITCHED_PROTOCOL,
)
from h11._util import LocalProtocolError, ProtocolError, RemoteProtocolError
from h11._version import __version__
PRODUCT_ID = "python-h11/" + __version__
__all__ = (
"Connection",
"NEED_DATA",
"PAUSED",
"ConnectionClosed",
"Data",
"EndOfMessage",
"Event",
"InformationalResponse",
"Request",
"Response",
"CLIENT",
"CLOSED",
"DONE",
"ERROR",
"IDLE",
"MUST_CLOSE",
"SEND_BODY",
"SEND_RESPONSE",
"SERVER",
"SWITCHED_PROTOCOL",
"ProtocolError",
"LocalProtocolError",
"RemoteProtocolError",
)

View File

@ -0,0 +1,132 @@
# We use native strings for all the re patterns, to take advantage of string
# formatting, and then convert to bytestrings when compiling the final re
# objects.
# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#whitespace
# OWS = *( SP / HTAB )
# ; optional whitespace
OWS = r"[ \t]*"
# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#rule.token.separators
# token = 1*tchar
#
# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*"
# / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
# / DIGIT / ALPHA
# ; any VCHAR, except delimiters
token = r"[-!#$%&'*+.^_`|~0-9a-zA-Z]+"
# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#header.fields
# field-name = token
field_name = token
# The standard says:
#
# field-value = *( field-content / obs-fold )
# field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ]
# field-vchar = VCHAR / obs-text
# obs-fold = CRLF 1*( SP / HTAB )
# ; obsolete line folding
# ; see Section 3.2.4
#
# https://tools.ietf.org/html/rfc5234#appendix-B.1
#
# VCHAR = %x21-7E
# ; visible (printing) characters
#
# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#rule.quoted-string
# obs-text = %x80-FF
#
# However, the standard definition of field-content is WRONG! It disallows
# fields containing a single visible character surrounded by whitespace,
# e.g. "foo a bar".
#
# See: https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189
#
# So our definition of field_content attempts to fix it up...
#
# Also, we allow lots of control characters, because apparently people assume
# that they're legal in practice (e.g., google analytics makes cookies with
# \x01 in them!):
# https://github.com/python-hyper/h11/issues/57
# We still don't allow NUL or whitespace, because those are often treated as
# meta-characters and letting them through can lead to nasty issues like SSRF.
vchar = r"[\x21-\x7e]"
vchar_or_obs_text = r"[^\x00\s]"
field_vchar = vchar_or_obs_text
field_content = r"{field_vchar}+(?:[ \t]+{field_vchar}+)*".format(**globals())
# We handle obs-fold at a different level, and our fixed-up field_content
# already grows to swallow the whole value, so ? instead of *
field_value = r"({field_content})?".format(**globals())
# header-field = field-name ":" OWS field-value OWS
header_field = (
r"(?P<field_name>{field_name})"
r":"
r"{OWS}"
r"(?P<field_value>{field_value})"
r"{OWS}".format(**globals())
)
# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#request.line
#
# request-line = method SP request-target SP HTTP-version CRLF
# method = token
# HTTP-version = HTTP-name "/" DIGIT "." DIGIT
# HTTP-name = %x48.54.54.50 ; "HTTP", case-sensitive
#
# request-target is complicated (see RFC 7230 sec 5.3) -- could be path, full
# URL, host+port (for connect), or even "*", but in any case we are guaranteed
# that it consists of the visible printing characters.
method = token
request_target = r"{vchar}+".format(**globals())
http_version = r"HTTP/(?P<http_version>[0-9]\.[0-9])"
request_line = (
r"(?P<method>{method})"
r" "
r"(?P<target>{request_target})"
r" "
r"{http_version}".format(**globals())
)
# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#status.line
#
# status-line = HTTP-version SP status-code SP reason-phrase CRLF
# status-code = 3DIGIT
# reason-phrase = *( HTAB / SP / VCHAR / obs-text )
status_code = r"[0-9]{3}"
reason_phrase = r"([ \t]|{vchar_or_obs_text})*".format(**globals())
status_line = (
r"{http_version}"
r" "
r"(?P<status_code>{status_code})"
# However, there are apparently a few too many servers out there that just
# leave out the reason phrase:
# https://github.com/scrapy/scrapy/issues/345#issuecomment-281756036
# https://github.com/seanmonstar/httparse/issues/29
# so make it optional. ?: is a non-capturing group.
r"(?: (?P<reason>{reason_phrase}))?".format(**globals())
)
HEXDIG = r"[0-9A-Fa-f]"
# Actually
#
# chunk-size = 1*HEXDIG
#
# but we impose an upper-limit to avoid ridiculosity. len(str(2**64)) == 20
chunk_size = r"({HEXDIG}){{1,20}}".format(**globals())
# Actually
#
# chunk-ext = *( ";" chunk-ext-name [ "=" chunk-ext-val ] )
#
# but we aren't parsing the things so we don't really care.
chunk_ext = r";.*"
chunk_header = (
r"(?P<chunk_size>{chunk_size})"
r"(?P<chunk_ext>{chunk_ext})?"
r"{OWS}\r\n".format(
**globals()
) # Even though the specification does not allow for extra whitespace,
# we are lenient with trailing whitespace because some servers in the wild use it.
)
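# For example (sketch): once compiled to bytes, the pattern matches
# b"1A;name=value\r\n" with chunk_size == b"1A" and
# chunk_ext == b";name=value".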

View File

@ -0,0 +1,633 @@
# This contains the main Connection class. Everything in h11 revolves around
# this.
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
from ._events import (
ConnectionClosed,
Data,
EndOfMessage,
Event,
InformationalResponse,
Request,
Response,
)
from ._headers import get_comma_header, has_expect_100_continue, set_comma_header
from ._readers import READERS, ReadersType
from ._receivebuffer import ReceiveBuffer
from ._state import (
_SWITCH_CONNECT,
_SWITCH_UPGRADE,
CLIENT,
ConnectionState,
DONE,
ERROR,
MIGHT_SWITCH_PROTOCOL,
SEND_BODY,
SERVER,
SWITCHED_PROTOCOL,
)
from ._util import ( # Import the internal things we need
LocalProtocolError,
RemoteProtocolError,
Sentinel,
)
from ._writers import WRITERS, WritersType
# Everything in __all__ gets re-exported as part of the h11 public API.
__all__ = ["Connection", "NEED_DATA", "PAUSED"]
class NEED_DATA(Sentinel, metaclass=Sentinel):
pass
class PAUSED(Sentinel, metaclass=Sentinel):
pass
# If we ever have this much buffered without it making a complete parseable
# event, we error out. The only time we really buffer is when reading the
# request/response line + headers together, so this is effectively the limit on
# the size of that.
#
# Some precedents for defaults:
# - node.js: 80 * 1024
# - tomcat: 8 * 1024
# - IIS: 16 * 1024
# - Apache: <8 KiB per line>
DEFAULT_MAX_INCOMPLETE_EVENT_SIZE = 16 * 1024
# RFC 7230's rules for connection lifecycles:
# - If either side says they want to close the connection, then the connection
# must close.
# - HTTP/1.1 defaults to keep-alive unless someone says Connection: close
# - HTTP/1.0 defaults to close unless both sides say Connection: keep-alive
# (and even this is a mess -- e.g. if you're implementing a proxy then
# sending Connection: keep-alive is forbidden).
#
# We simplify life by simply not supporting keep-alive with HTTP/1.0 peers. So
# our rule is:
# - If someone says Connection: close, we will close
# - If someone uses HTTP/1.0, we will close.
def _keep_alive(event: Union[Request, Response]) -> bool:
connection = get_comma_header(event.headers, b"connection")
if b"close" in connection:
return False
if getattr(event, "http_version", b"1.1") < b"1.1":
return False
return True
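# For example (sketch): either of these disables keep-alive:
#
#     _keep_alive(Request(method="GET", target="/", http_version="1.0",
#                         headers=[("Host", "a")]))                         # False
#     _keep_alive(Request(method="GET", target="/",
#                         headers=[("Host", "a"), ("Connection", "close")]))  # False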
def _body_framing(
request_method: bytes, event: Union[Request, Response]
) -> Tuple[str, Union[Tuple[()], Tuple[int]]]:
# Called when we enter SEND_BODY to figure out framing information for
# this body.
#
# These are the only two events that can trigger a SEND_BODY state:
assert type(event) in (Request, Response)
# Returns one of:
#
# ("content-length", count)
# ("chunked", ())
# ("http/1.0", ())
#
# which are (lookup key, *args) for constructing body reader/writer
# objects.
#
# Reference: https://tools.ietf.org/html/rfc7230#section-3.3.3
#
# Step 1: some responses always have an empty body, regardless of what the
# headers say.
if type(event) is Response:
if (
event.status_code in (204, 304)
or request_method == b"HEAD"
or (request_method == b"CONNECT" and 200 <= event.status_code < 300)
):
return ("content-length", (0,))
# Section 3.3.3 also lists another case -- responses with status_code
# < 200. For us these are InformationalResponses, not Responses, so
# they can't get into this function in the first place.
assert event.status_code >= 200
# Step 2: check for Transfer-Encoding (T-E beats C-L):
transfer_encodings = get_comma_header(event.headers, b"transfer-encoding")
if transfer_encodings:
assert transfer_encodings == [b"chunked"]
return ("chunked", ())
# Step 3: check for Content-Length
content_lengths = get_comma_header(event.headers, b"content-length")
if content_lengths:
return ("content-length", (int(content_lengths[0]),))
# Step 4: no applicable headers; fallback/default depends on type
if type(event) is Request:
return ("content-length", (0,))
else:
return ("http/1.0", ())
################################################################
#
# The main Connection class
#
################################################################
class Connection:
"""An object encapsulating the state of an HTTP connection.
Args:
our_role: If you're implementing a client, pass :data:`h11.CLIENT`. If
you're implementing a server, pass :data:`h11.SERVER`.
max_incomplete_event_size (int):
The maximum number of bytes we're willing to buffer of an
incomplete event. In practice this mostly sets a limit on the
maximum size of the request/response line + headers. If this is
exceeded, then :meth:`next_event` will raise
:exc:`RemoteProtocolError`.
"""
def __init__(
self,
our_role: Type[Sentinel],
max_incomplete_event_size: int = DEFAULT_MAX_INCOMPLETE_EVENT_SIZE,
) -> None:
self._max_incomplete_event_size = max_incomplete_event_size
# State and role tracking
if our_role not in (CLIENT, SERVER):
raise ValueError("expected CLIENT or SERVER, not {!r}".format(our_role))
self.our_role = our_role
self.their_role: Type[Sentinel]
if our_role is CLIENT:
self.their_role = SERVER
else:
self.their_role = CLIENT
self._cstate = ConnectionState()
# Callables for converting data->events or vice-versa given the
# current state
self._writer = self._get_io_object(self.our_role, None, WRITERS)
self._reader = self._get_io_object(self.their_role, None, READERS)
# Holds any unprocessed received data
self._receive_buffer = ReceiveBuffer()
# If this is true, then it indicates that the incoming connection was
# closed *after* the end of whatever's in self._receive_buffer:
self._receive_buffer_closed = False
# Extra bits of state that don't fit into the state machine.
#
# These two are only used to interpret framing headers for figuring
# out how to read/write response bodies. their_http_version is also
# made available as a convenient public API.
self.their_http_version: Optional[bytes] = None
self._request_method: Optional[bytes] = None
# This is pure flow-control and doesn't at all affect the set of legal
# transitions, so no need to bother ConnectionState with it:
self.client_is_waiting_for_100_continue = False
@property
def states(self) -> Dict[Type[Sentinel], Type[Sentinel]]:
"""A dictionary like::
{CLIENT: <client state>, SERVER: <server state>}
See :ref:`state-machine` for details.
"""
return dict(self._cstate.states)
@property
def our_state(self) -> Type[Sentinel]:
"""The current state of whichever role we are playing. See
:ref:`state-machine` for details.
"""
return self._cstate.states[self.our_role]
@property
def their_state(self) -> Type[Sentinel]:
"""The current state of whichever role we are NOT playing. See
:ref:`state-machine` for details.
"""
return self._cstate.states[self.their_role]
@property
def they_are_waiting_for_100_continue(self) -> bool:
return self.their_role is CLIENT and self.client_is_waiting_for_100_continue
def start_next_cycle(self) -> None:
"""Attempt to reset our connection state for a new request/response
cycle.
If both client and server are in :data:`DONE` state, then resets them
both to :data:`IDLE` state in preparation for a new request/response
cycle on this same connection. Otherwise, raises a
:exc:`LocalProtocolError`.
See :ref:`keepalive-and-pipelining`.
"""
old_states = dict(self._cstate.states)
self._cstate.start_next_cycle()
self._request_method = None
# self.their_http_version gets left alone, since it presumably lasts
# beyond a single request/response cycle
assert not self.client_is_waiting_for_100_continue
self._respond_to_state_changes(old_states)
def _process_error(self, role: Type[Sentinel]) -> None:
old_states = dict(self._cstate.states)
self._cstate.process_error(role)
self._respond_to_state_changes(old_states)
def _server_switch_event(self, event: Event) -> Optional[Type[Sentinel]]:
if type(event) is InformationalResponse and event.status_code == 101:
return _SWITCH_UPGRADE
if type(event) is Response:
if (
_SWITCH_CONNECT in self._cstate.pending_switch_proposals
and 200 <= event.status_code < 300
):
return _SWITCH_CONNECT
return None
# All events go through here
def _process_event(self, role: Type[Sentinel], event: Event) -> None:
# First, pass the event through the state machine to make sure it
# succeeds.
old_states = dict(self._cstate.states)
if role is CLIENT and type(event) is Request:
if event.method == b"CONNECT":
self._cstate.process_client_switch_proposal(_SWITCH_CONNECT)
if get_comma_header(event.headers, b"upgrade"):
self._cstate.process_client_switch_proposal(_SWITCH_UPGRADE)
server_switch_event = None
if role is SERVER:
server_switch_event = self._server_switch_event(event)
self._cstate.process_event(role, type(event), server_switch_event)
# Then perform the updates triggered by it.
if type(event) is Request:
self._request_method = event.method
if role is self.their_role and type(event) in (
Request,
Response,
InformationalResponse,
):
event = cast(Union[Request, Response, InformationalResponse], event)
self.their_http_version = event.http_version
# Keep alive handling
#
# RFC 7230 doesn't really say what one should do if Connection: close
# shows up on a 1xx InformationalResponse. I think the idea is that
# this is not supposed to happen. In any case, if it does happen, we
# ignore it.
if type(event) in (Request, Response) and not _keep_alive(
cast(Union[Request, Response], event)
):
self._cstate.process_keep_alive_disabled()
# 100-continue
if type(event) is Request and has_expect_100_continue(event):
self.client_is_waiting_for_100_continue = True
if type(event) in (InformationalResponse, Response):
self.client_is_waiting_for_100_continue = False
if role is CLIENT and type(event) in (Data, EndOfMessage):
self.client_is_waiting_for_100_continue = False
self._respond_to_state_changes(old_states, event)
def _get_io_object(
self,
role: Type[Sentinel],
event: Optional[Event],
io_dict: Union[ReadersType, WritersType],
) -> Optional[Callable[..., Any]]:
# event may be None; it's only used when entering SEND_BODY
state = self._cstate.states[role]
if state is SEND_BODY:
# Special case: the io_dict has a dict of reader/writer factories
# that depend on the request/response framing.
framing_type, args = _body_framing(
cast(bytes, self._request_method), cast(Union[Request, Response], event)
)
return io_dict[SEND_BODY][framing_type](*args) # type: ignore[index]
else:
# General case: the io_dict just has the appropriate reader/writer
# for this state
return io_dict.get((role, state)) # type: ignore[return-value]
# This must be called after any action that might have caused
# self._cstate.states to change.
def _respond_to_state_changes(
self,
old_states: Dict[Type[Sentinel], Type[Sentinel]],
event: Optional[Event] = None,
) -> None:
# Update reader/writer
if self.our_state != old_states[self.our_role]:
self._writer = self._get_io_object(self.our_role, event, WRITERS)
if self.their_state != old_states[self.their_role]:
self._reader = self._get_io_object(self.their_role, event, READERS)
@property
def trailing_data(self) -> Tuple[bytes, bool]:
"""Data that has been received, but not yet processed, represented as
a tuple with two elements, where the first is a byte-string containing
the unprocessed data itself, and the second is a bool that is True if
the receive connection was closed.
See :ref:`switching-protocols` for discussion of why you'd want this.
"""
return (bytes(self._receive_buffer), self._receive_buffer_closed)
def receive_data(self, data: bytes) -> None:
"""Add data to our internal receive buffer.
This does not actually do any processing on the data, just stores
it. To trigger processing, you have to call :meth:`next_event`.
Args:
data (:term:`bytes-like object`):
The new data that was just received.
Special case: If *data* is an empty byte-string like ``b""``,
then this indicates that the remote side has closed the
connection (end of file). Normally this is convenient, because
standard Python APIs like :meth:`file.read` or
:meth:`socket.recv` use ``b""`` to indicate end-of-file, while
other failures to read are indicated using other mechanisms
like raising :exc:`TimeoutError`. When using such an API you
can just blindly pass through whatever you get from ``read``
to :meth:`receive_data`, and everything will work.
But, if you have an API where reading an empty string is a
valid non-EOF condition, then you need to be aware of this and
make sure to check for such strings and avoid passing them to
:meth:`receive_data`.
Returns:
Nothing, but after calling this you should call :meth:`next_event`
to parse the newly received data.
Raises:
RuntimeError:
Raised if you pass an empty *data*, indicating EOF, and then
pass a non-empty *data*, indicating more data that somehow
arrived after the EOF.
(Calling ``receive_data(b"")`` multiple times is fine,
and equivalent to calling it once.)
"""
if data:
if self._receive_buffer_closed:
raise RuntimeError("received close, then received more data?")
self._receive_buffer += data
else:
self._receive_buffer_closed = True
def _extract_next_receive_event(
self,
) -> Union[Event, Type[NEED_DATA], Type[PAUSED]]:
state = self.their_state
# We don't pause immediately when they enter DONE, because even in
# DONE state we can still process a ConnectionClosed() event. But
# if we have data in our buffer, then we definitely aren't getting
# a ConnectionClosed() immediately and we need to pause.
if state is DONE and self._receive_buffer:
return PAUSED
if state is MIGHT_SWITCH_PROTOCOL or state is SWITCHED_PROTOCOL:
return PAUSED
assert self._reader is not None
event = self._reader(self._receive_buffer)
if event is None:
if not self._receive_buffer and self._receive_buffer_closed:
# In some unusual cases (basically just HTTP/1.0 bodies), EOF
# triggers an actual protocol event; in that case, we want to
# return that event, and then the state will change and we'll
# get called again to generate the actual ConnectionClosed().
if hasattr(self._reader, "read_eof"):
event = self._reader.read_eof() # type: ignore[attr-defined]
else:
event = ConnectionClosed()
if event is None:
event = NEED_DATA
return event # type: ignore[no-any-return]
def next_event(self) -> Union[Event, Type[NEED_DATA], Type[PAUSED]]:
"""Parse the next event out of our receive buffer, update our internal
state, and return it.
This is a mutating operation -- think of it like calling :func:`next`
on an iterator.
Returns:
: One of three things:
1) An event object -- see :ref:`events`.
2) The special constant :data:`NEED_DATA`, which indicates that
you need to read more data from your socket and pass it to
:meth:`receive_data` before this method will be able to return
any more events.
3) The special constant :data:`PAUSED`, which indicates that we
are not in a state where we can process incoming data (usually
because the peer has finished their part of the current
request/response cycle, and you have not yet called
:meth:`start_next_cycle`). See :ref:`flow-control` for details.
Raises:
RemoteProtocolError:
The peer has misbehaved. You should close the connection
(possibly after sending some kind of 4xx response).
Once this method returns :class:`ConnectionClosed` once, then all
subsequent calls will also return :class:`ConnectionClosed`.
If this method raises any exception besides :exc:`RemoteProtocolError`
then that's a bug -- if it happens please file a bug report!
If this method raises any exception then it also sets
:attr:`Connection.their_state` to :data:`ERROR` -- see
:ref:`error-handling` for discussion.
"""
if self.their_state is ERROR:
raise RemoteProtocolError("Can't receive data when peer state is ERROR")
try:
event = self._extract_next_receive_event()
if event not in [NEED_DATA, PAUSED]:
self._process_event(self.their_role, cast(Event, event))
if event is NEED_DATA:
if len(self._receive_buffer) > self._max_incomplete_event_size:
# 431 is "Request header fields too large" which is pretty
# much the only situation where we can get here
raise RemoteProtocolError(
"Receive buffer too long", error_status_hint=431
)
if self._receive_buffer_closed:
# We're still trying to complete some event, but that's
# never going to happen because no more data is coming
raise RemoteProtocolError("peer unexpectedly closed connection")
return event
except BaseException as exc:
self._process_error(self.their_role)
if isinstance(exc, LocalProtocolError):
exc._reraise_as_remote_protocol_error()
else:
raise
def send(self, event: Event) -> Optional[bytes]:
"""Convert a high-level event into bytes that can be sent to the peer,
while updating our internal state machine.
Args:
event: The :ref:`event <events>` to send.
Returns:
If ``type(event) is ConnectionClosed``, then returns
``None``. Otherwise, returns a :term:`bytes-like object`.
Raises:
LocalProtocolError:
Sending this event at this time would violate our
understanding of the HTTP/1.1 protocol.
If this method raises any exception then it also sets
:attr:`Connection.our_state` to :data:`ERROR` -- see
:ref:`error-handling` for discussion.
"""
data_list = self.send_with_data_passthrough(event)
if data_list is None:
return None
else:
return b"".join(data_list)
def send_with_data_passthrough(self, event: Event) -> Optional[List[bytes]]:
"""Identical to :meth:`send`, except that in situations where
:meth:`send` returns a single :term:`bytes-like object`, this instead
returns a list of them -- and when sending a :class:`Data` event, this
list is guaranteed to contain the exact object you passed in as
:attr:`Data.data`. See :ref:`sendfile` for discussion.
"""
if self.our_state is ERROR:
raise LocalProtocolError("Can't send data when our state is ERROR")
try:
if type(event) is Response:
event = self._clean_up_response_headers_for_sending(event)
# We want to call _process_event before calling the writer,
# because if someone tries to do something invalid then this will
# give a sensible error message, while our writers all just assume
# they will only receive valid events. But, _process_event might
# change self._writer. So we have to do a little dance:
writer = self._writer
self._process_event(self.our_role, event)
if type(event) is ConnectionClosed:
return None
else:
# In any situation where writer is None, process_event should
# have raised ProtocolError
assert writer is not None
data_list: List[bytes] = []
writer(event, data_list.append)
return data_list
except:
self._process_error(self.our_role)
raise
def send_failed(self) -> None:
"""Notify the state machine that we failed to send the data it gave
us.
This causes :attr:`Connection.our_state` to immediately become
:data:`ERROR` -- see :ref:`error-handling` for discussion.
"""
self._process_error(self.our_role)
# When sending a Response, we take responsibility for a few things:
#
# - Sometimes you MUST set Connection: close. We take care of those
# times. (You can also set it yourself if you want, and if you do then
# we'll respect that and close the connection at the right time. But you
# don't have to worry about that unless you want to.)
#
# - The user has to set Content-Length if they want it. Otherwise, for
# responses that have bodies (e.g. not HEAD), then we will automatically
# select the right mechanism for streaming a body of unknown length,
# which depends on the peer's HTTP version.
#
# This function's *only* responsibility is making sure headers are set up
# right -- everything downstream just looks at the headers. There are no
# side channels.
def _clean_up_response_headers_for_sending(self, response: Response) -> Response:
assert type(response) is Response
headers = response.headers
need_close = False
# HEAD requests need some special handling: they always act like they
# have Content-Length: 0, and that's how _body_framing treats
# them. But their headers are supposed to match what we would send if
# the request was a GET. (Technically there is one deviation allowed:
# we're allowed to leave out the framing headers -- see
# https://tools.ietf.org/html/rfc7231#section-4.3.2 . But it's just as
# easy to get them right.)
method_for_choosing_headers = cast(bytes, self._request_method)
if method_for_choosing_headers == b"HEAD":
method_for_choosing_headers = b"GET"
framing_type, _ = _body_framing(method_for_choosing_headers, response)
if framing_type in ("chunked", "http/1.0"):
# This response has a body of unknown length.
# If our peer is HTTP/1.1, we use Transfer-Encoding: chunked
# If our peer is HTTP/1.0, we use no framing headers, and close the
# connection afterwards.
#
# Make sure to clear Content-Length (in principle user could have
# set both and then we ignored Content-Length b/c
# Transfer-Encoding overwrote it -- this would be naughty of them,
# but the HTTP spec says that if our peer does this then we have
# to fix it instead of erroring out, so we'll accord the user the
# same respect).
headers = set_comma_header(headers, b"content-length", [])
if self.their_http_version is None or self.their_http_version < b"1.1":
# Either we never got a valid request and are sending back an
# error (their_http_version is None), so we assume the worst;
# or else we did get a valid HTTP/1.0 request, so we know that
# they don't understand chunked encoding.
headers = set_comma_header(headers, b"transfer-encoding", [])
# This is actually redundant ATM, since currently we
# unconditionally disable keep-alive when talking to HTTP/1.0
# peers. But let's be defensive just in case we add
# Connection: keep-alive support later:
if self._request_method != b"HEAD":
need_close = True
else:
headers = set_comma_header(headers, b"transfer-encoding", [b"chunked"])
if not self._cstate.keep_alive or need_close:
# Make sure Connection: close is set
connection = set(get_comma_header(headers, b"connection"))
connection.discard(b"keep-alive")
connection.add(b"close")
headers = set_comma_header(headers, b"connection", sorted(connection))
return Response(
headers=headers,
status_code=response.status_code,
http_version=response.http_version,
reason=response.reason,
)
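# Server-side usage sketch (``sock`` is an assumed socket object; see the
# docstrings above for the full contract):
#
#     conn = Connection(our_role=SERVER)
#     conn.receive_data(sock.recv(4096))
#     event = conn.next_event()  # Request, then Data/EndOfMessage,
#                                # or NEED_DATA / PAUSED
#     sock.sendall(conn.send(Response(status_code=200, headers=[])))
#     sock.sendall(conn.send(EndOfMessage()))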

View File

@ -0,0 +1,369 @@
# High level events that make up HTTP/1.1 conversations. Loosely inspired by
# the corresponding events in hyper-h2:
#
# http://python-hyper.org/h2/en/stable/api.html#events
#
# Don't subclass these. Stuff will break.
import re
from abc import ABC
from dataclasses import dataclass, field
from typing import Any, cast, Dict, List, Tuple, Union
from ._abnf import method, request_target
from ._headers import Headers, normalize_and_validate
from ._util import bytesify, LocalProtocolError, validate
# Everything in __all__ gets re-exported as part of the h11 public API.
__all__ = [
"Event",
"Request",
"InformationalResponse",
"Response",
"Data",
"EndOfMessage",
"ConnectionClosed",
]
method_re = re.compile(method.encode("ascii"))
request_target_re = re.compile(request_target.encode("ascii"))
class Event(ABC):
"""
Base class for h11 events.
"""
__slots__ = ()
@dataclass(init=False, frozen=True)
class Request(Event):
"""The beginning of an HTTP request.
Fields:
.. attribute:: method
An HTTP method, e.g. ``b"GET"`` or ``b"POST"``. Always a byte
string. :term:`Bytes-like objects <bytes-like object>` and native
strings containing only ascii characters will be automatically
converted to byte strings.
.. attribute:: target
The target of an HTTP request, e.g. ``b"/index.html"``, or one of the
more exotic formats described in `RFC 7230, section 5.3
<https://tools.ietf.org/html/rfc7230#section-5.3>`_. Always a byte
string. :term:`Bytes-like objects <bytes-like object>` and native
strings containing only ascii characters will be automatically
converted to byte strings.
.. attribute:: headers
Request headers, represented as a list of (name, value) pairs. See
:ref:`the header normalization rules <headers-format>` for details.
.. attribute:: http_version
The HTTP protocol version, represented as a byte string like
``b"1.1"``. See :ref:`the HTTP version normalization rules
<http_version-format>` for details.
"""
__slots__ = ("method", "headers", "target", "http_version")
method: bytes
headers: Headers
target: bytes
http_version: bytes
def __init__(
self,
*,
method: Union[bytes, str],
headers: Union[Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]]],
target: Union[bytes, str],
http_version: Union[bytes, str] = b"1.1",
_parsed: bool = False,
) -> None:
super().__init__()
if isinstance(headers, Headers):
object.__setattr__(self, "headers", headers)
else:
object.__setattr__(
self, "headers", normalize_and_validate(headers, _parsed=_parsed)
)
if not _parsed:
object.__setattr__(self, "method", bytesify(method))
object.__setattr__(self, "target", bytesify(target))
object.__setattr__(self, "http_version", bytesify(http_version))
else:
object.__setattr__(self, "method", method)
object.__setattr__(self, "target", target)
object.__setattr__(self, "http_version", http_version)
# "A server MUST respond with a 400 (Bad Request) status code to any
# HTTP/1.1 request message that lacks a Host header field and to any
# request message that contains more than one Host header field or a
# Host header field with an invalid field-value."
# -- https://tools.ietf.org/html/rfc7230#section-5.4
host_count = 0
for name, value in self.headers:
if name == b"host":
host_count += 1
if self.http_version == b"1.1" and host_count == 0:
raise LocalProtocolError("Missing mandatory Host: header")
if host_count > 1:
raise LocalProtocolError("Found multiple Host: headers")
validate(method_re, self.method, "Illegal method characters")
validate(request_target_re, self.target, "Illegal target characters")
# This is an unhashable type.
__hash__ = None # type: ignore
@dataclass(init=False, frozen=True)
class _ResponseBase(Event):
__slots__ = ("headers", "http_version", "reason", "status_code")
headers: Headers
http_version: bytes
reason: bytes
status_code: int
def __init__(
self,
*,
headers: Union[Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]]],
status_code: int,
http_version: Union[bytes, str] = b"1.1",
reason: Union[bytes, str] = b"",
_parsed: bool = False,
) -> None:
super().__init__()
if isinstance(headers, Headers):
object.__setattr__(self, "headers", headers)
else:
object.__setattr__(
self, "headers", normalize_and_validate(headers, _parsed=_parsed)
)
if not _parsed:
object.__setattr__(self, "reason", bytesify(reason))
object.__setattr__(self, "http_version", bytesify(http_version))
if not isinstance(status_code, int):
raise LocalProtocolError("status code must be integer")
# Because IntEnum objects are instances of int, but aren't
# duck-compatible (sigh), see gh-72.
object.__setattr__(self, "status_code", int(status_code))
else:
object.__setattr__(self, "reason", reason)
object.__setattr__(self, "http_version", http_version)
object.__setattr__(self, "status_code", status_code)
self.__post_init__()
def __post_init__(self) -> None:
pass
# This is an unhashable type.
__hash__ = None # type: ignore
@dataclass(init=False, frozen=True)
class InformationalResponse(_ResponseBase):
"""An HTTP informational response.
Fields:
.. attribute:: status_code
The status code of this response, as an integer. For an
:class:`InformationalResponse`, this is always in the range [100,
200).
.. attribute:: headers
Request headers, represented as a list of (name, value) pairs. See
:ref:`the header normalization rules <headers-format>` for
details.
.. attribute:: http_version
The HTTP protocol version, represented as a byte string like
``b"1.1"``. See :ref:`the HTTP version normalization rules
<http_version-format>` for details.
.. attribute:: reason
The reason phrase of this response, as a byte string. For example:
``b"OK"``, or ``b"Not Found"``.
"""
def __post_init__(self) -> None:
if not (100 <= self.status_code < 200):
raise LocalProtocolError(
"InformationalResponse status_code should be in range "
"[100, 200), not {}".format(self.status_code)
)
# This is an unhashable type.
__hash__ = None # type: ignore
@dataclass(init=False, frozen=True)
class Response(_ResponseBase):
"""The beginning of an HTTP response.
Fields:
.. attribute:: status_code
The status code of this response, as an integer. For an
:class:`Response`, this is always in the range [200,
1000).
.. attribute:: headers
Request headers, represented as a list of (name, value) pairs. See
:ref:`the header normalization rules <headers-format>` for details.
.. attribute:: http_version
The HTTP protocol version, represented as a byte string like
``b"1.1"``. See :ref:`the HTTP version normalization rules
<http_version-format>` for details.
.. attribute:: reason
The reason phrase of this response, as a byte string. For example:
``b"OK"``, or ``b"Not Found"``.
"""
def __post_init__(self) -> None:
if not (200 <= self.status_code < 1000):
raise LocalProtocolError(
"Response status_code should be in range [200, 1000), not {}".format(
self.status_code
)
)
# This is an unhashable type.
__hash__ = None # type: ignore
@dataclass(init=False, frozen=True)
class Data(Event):
"""Part of an HTTP message body.
Fields:
.. attribute:: data
A :term:`bytes-like object` containing part of a message body. Or, if
using the ``combine=False`` argument to :meth:`Connection.send`, then
any object that your socket writing code knows what to do with, and for
which calling :func:`len` returns the number of bytes that will be
written -- see :ref:`sendfile` for details.
.. attribute:: chunk_start
A marker that indicates whether this data object is from the start of a
chunked transfer encoding chunk. This field is ignored when a Data
event is provided to :meth:`Connection.send`: it is only valid on
events emitted from :meth:`Connection.next_event`. You probably
shouldn't use this attribute at all; see
:ref:`chunk-delimiters-are-bad` for details.
.. attribute:: chunk_end
A marker that indicates whether this data object is the last for a
given chunked transfer encoding chunk. This field is ignored when
a Data event is provided to :meth:`Connection.send`: it is only valid
on events emitted from :meth:`Connection.next_event`. You probably
shouldn't use this attribute at all; see
:ref:`chunk-delimiters-are-bad` for details.
"""
__slots__ = ("data", "chunk_start", "chunk_end")
data: bytes
chunk_start: bool
chunk_end: bool
def __init__(
self, data: bytes, chunk_start: bool = False, chunk_end: bool = False
) -> None:
object.__setattr__(self, "data", data)
object.__setattr__(self, "chunk_start", chunk_start)
object.__setattr__(self, "chunk_end", chunk_end)
# This is an unhashable type.
__hash__ = None # type: ignore
# XX FIXME: "A recipient MUST ignore (or consider as an error) any fields that
# are forbidden to be sent in a trailer, since processing them as if they were
# present in the header section might bypass external security filters."
# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#chunked.trailer.part
# Unfortunately, the list of forbidden fields is long and vague :-/
@dataclass(init=False, frozen=True)
class EndOfMessage(Event):
"""The end of an HTTP message.
Fields:
.. attribute:: headers
Default value: ``[]``
Any trailing headers attached to this message, represented as a list of
(name, value) pairs. See :ref:`the header normalization rules
<headers-format>` for details.
Must be empty unless ``Transfer-Encoding: chunked`` is in use.
"""
__slots__ = ("headers",)
headers: Headers
def __init__(
self,
*,
headers: Union[
Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]], None
] = None,
_parsed: bool = False,
) -> None:
super().__init__()
if headers is None:
headers = Headers([])
elif not isinstance(headers, Headers):
headers = normalize_and_validate(headers, _parsed=_parsed)
object.__setattr__(self, "headers", headers)
# This is an unhashable type.
__hash__ = None # type: ignore
@dataclass(frozen=True)
class ConnectionClosed(Event):
"""This event indicates that the sender has closed their outgoing
connection.
Note that this does not necessarily mean that they can't *receive* further
data, because TCP connections are composed of two one-way channels, which
can be closed independently. See :ref:`closing` for details.
No fields.
"""
pass
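# Construction sketch using the classes above (values are placeholders):
#
#     Request(method="GET", target="/", headers=[("Host", "example.com")])
#     Response(status_code=200, headers=[("Content-Length", "0")])
#     Data(data=b"hello")
#     EndOfMessage()
#     ConnectionClosed()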

View File

@ -0,0 +1,278 @@
import re
from typing import AnyStr, cast, List, overload, Sequence, Tuple, TYPE_CHECKING, Union
from ._abnf import field_name, field_value
from ._util import bytesify, LocalProtocolError, validate
if TYPE_CHECKING:
from ._events import Request
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal # type: ignore
# Facts
# -----
#
# Headers are:
# keys: case-insensitive ascii
# values: mixture of ascii and raw bytes
#
# "Historically, HTTP has allowed field content with text in the ISO-8859-1
# charset [ISO-8859-1], supporting other charsets only through use of
# [RFC2047] encoding. In practice, most HTTP header field values use only a
# subset of the US-ASCII charset [USASCII]. Newly defined header fields SHOULD
# limit their field values to US-ASCII octets. A recipient SHOULD treat other
# octets in field content (obs-text) as opaque data."
# And it deprecates all non-ascii values
#
# Leading/trailing whitespace in header names is forbidden
#
# Values get leading/trailing whitespace stripped
#
# Content-Disposition actually needs to contain unicode semantically; to
# accomplish this it has a terrifically weird way of encoding the filename
# itself as ascii (and even this still has lots of cross-browser
# incompatibilities)
#
# Order is important:
# "a proxy MUST NOT change the order of these field values when forwarding a
# message"
# (and there are several headers where the order indicates a preference)
#
# Multiple occurrences of the same header:
# "A sender MUST NOT generate multiple header fields with the same field name
# in a message unless either the entire field value for that header field is
# defined as a comma-separated list [or the header is Set-Cookie which gets a
# special exception]" - RFC 7230. (cookies are in RFC 6265)
#
# So every header aside from Set-Cookie can be merged by b", ".join if it
# occurs repeatedly. But, of course, they can't necessarily be split by
# .split(b","), because quoting.
#
# Given all this mess (case insensitive, duplicates allowed, order is
# important, ...), there doesn't appear to be any standard way to handle
# headers in Python -- they're almost like dicts, but... actually just
# aren't. For now we punt and just use a super simple representation: headers
# are a list of pairs
#
# [(name1, value1), (name2, value2), ...]
#
# where all entries are bytestrings, names are lowercase and have no
# leading/trailing whitespace, and values are bytestrings with no
# leading/trailing whitespace. Searching and updating are done via naive O(n)
# methods.
#
# Maybe a dict-of-lists would be better?
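# As a quick illustration of the representation described above (a hedged
# sketch in plain Python; the header names and values are invented):
#
#   >>> headers = [(b"accept", b"text/html"), (b"accept", b"*/*")]
#   >>> b", ".join(v for (n, v) in headers if n == b"accept")
#   b'text/html, */*'
#
# (Set-Cookie is the one header where this merge would be wrong.)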
_content_length_re = re.compile(rb"[0-9]+")
_field_name_re = re.compile(field_name.encode("ascii"))
_field_value_re = re.compile(field_value.encode("ascii"))
class Headers(Sequence[Tuple[bytes, bytes]]):
"""
A list-like interface that allows iterating over headers as byte-pairs
of (lowercased-name, value).
Internally we actually store the representation as three-tuples,
including both the raw original casing, in order to preserve casing
over-the-wire, and the lowercased name, for case-insensitive comparisons.
r = Request(
method="GET",
target="/",
headers=[("Host", "example.org"), ("Connection", "keep-alive")],
http_version="1.1",
)
assert r.headers == [
(b"host", b"example.org"),
(b"connection", b"keep-alive")
]
assert r.headers.raw_items() == [
(b"Host", b"example.org"),
(b"Connection", b"keep-alive")
]
"""
__slots__ = "_full_items"
def __init__(self, full_items: List[Tuple[bytes, bytes, bytes]]) -> None:
self._full_items = full_items
def __bool__(self) -> bool:
return bool(self._full_items)
def __eq__(self, other: object) -> bool:
return list(self) == list(other) # type: ignore
def __len__(self) -> int:
return len(self._full_items)
def __repr__(self) -> str:
return "<Headers(%s)>" % repr(list(self))
def __getitem__(self, idx: int) -> Tuple[bytes, bytes]: # type: ignore[override]
_, name, value = self._full_items[idx]
return (name, value)
def raw_items(self) -> List[Tuple[bytes, bytes]]:
return [(raw_name, value) for raw_name, _, value in self._full_items]
HeaderTypes = Union[
List[Tuple[bytes, bytes]],
List[Tuple[bytes, str]],
List[Tuple[str, bytes]],
List[Tuple[str, str]],
]
@overload
def normalize_and_validate(headers: Headers, _parsed: Literal[True]) -> Headers:
...
@overload
def normalize_and_validate(headers: HeaderTypes, _parsed: Literal[False]) -> Headers:
...
@overload
def normalize_and_validate(
headers: Union[Headers, HeaderTypes], _parsed: bool = False
) -> Headers:
...
def normalize_and_validate(
headers: Union[Headers, HeaderTypes], _parsed: bool = False
) -> Headers:
new_headers = []
seen_content_length = None
saw_transfer_encoding = False
for name, value in headers:
# For headers coming out of the parser, we can safely skip some steps,
# because it always returns bytes and has already run these regexes
# over the data:
if not _parsed:
name = bytesify(name)
value = bytesify(value)
validate(_field_name_re, name, "Illegal header name {!r}", name)
validate(_field_value_re, value, "Illegal header value {!r}", value)
assert isinstance(name, bytes)
assert isinstance(value, bytes)
raw_name = name
name = name.lower()
if name == b"content-length":
lengths = {length.strip() for length in value.split(b",")}
if len(lengths) != 1:
raise LocalProtocolError("conflicting Content-Length headers")
value = lengths.pop()
validate(_content_length_re, value, "bad Content-Length")
if seen_content_length is None:
seen_content_length = value
new_headers.append((raw_name, name, value))
elif seen_content_length != value:
raise LocalProtocolError("conflicting Content-Length headers")
elif name == b"transfer-encoding":
# "A server that receives a request message with a transfer coding
# it does not understand SHOULD respond with 501 (Not
# Implemented)."
# https://tools.ietf.org/html/rfc7230#section-3.3.1
if saw_transfer_encoding:
raise LocalProtocolError(
"multiple Transfer-Encoding headers", error_status_hint=501
)
# "All transfer-coding names are case-insensitive"
# -- https://tools.ietf.org/html/rfc7230#section-4
value = value.lower()
if value != b"chunked":
raise LocalProtocolError(
"Only Transfer-Encoding: chunked is supported",
error_status_hint=501,
)
saw_transfer_encoding = True
new_headers.append((raw_name, name, value))
else:
new_headers.append((raw_name, name, value))
return Headers(new_headers)
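# A doctest-style sketch of the behavior above (header values invented;
# genuinely conflicting values would raise LocalProtocolError instead):
#
#   >>> normalize_and_validate([("Content-Length", "10, 10")])
#   <Headers([(b'content-length', b'10')])>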
def get_comma_header(headers: Headers, name: bytes) -> List[bytes]:
# Should only be used for headers whose value is a list of
# comma-separated, case-insensitive values.
#
# The header name `name` is expected to be lower-case bytes.
#
# Connection: meets these criteria (including case insensitivity).
#
# Content-Length: technically is just a single value (1*DIGIT), but the
# standard makes reference to implementations that do multiple values, and
# using this doesn't hurt. Ditto, case insensitivity doesn't matter either
# way.
#
# Transfer-Encoding: is more complex (allows for quoted strings), so
# splitting on , is actually wrong. For example, this is legal:
#
# Transfer-Encoding: foo; options="1,2", chunked
#
# and should be parsed as
#
# foo; options="1,2"
# chunked
#
# but this naive function will parse it as
#
# foo; options="1
# 2"
# chunked
#
# However, this is okay because the only thing we are going to do with
# any Transfer-Encoding is reject ones that aren't just "chunked", so
# both of these will be treated the same anyway.
#
# Expect: the only legal value is the literal string
# "100-continue". Splitting on commas is harmless. Case insensitive.
#
out: List[bytes] = []
for _, found_name, found_raw_value in headers._full_items:
if found_name == name:
found_raw_value = found_raw_value.lower()
for found_split_value in found_raw_value.split(b","):
found_split_value = found_split_value.strip()
if found_split_value:
out.append(found_split_value)
return out
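# A hedged sketch (values invented) of how the splitting above behaves:
#
#   >>> hs = normalize_and_validate([("Connection", "close, Keep-Alive")])
#   >>> get_comma_header(hs, b"connection")
#   [b'close', b'keep-alive']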
def set_comma_header(headers: Headers, name: bytes, new_values: List[bytes]) -> Headers:
# The header name `name` is expected to be lower-case bytes.
#
# Note that when we store the header we use title casing for the header
# names, in order to match the conventional HTTP header style.
#
# Simply calling `.title()` is a blunt approach, but it's correct
# here given the cases where we're using `set_comma_header`...
#
# Connection, Content-Length, Transfer-Encoding.
new_headers: List[Tuple[bytes, bytes]] = []
for found_raw_name, found_name, found_raw_value in headers._full_items:
if found_name != name:
new_headers.append((found_raw_name, found_raw_value))
for new_value in new_values:
new_headers.append((name.title(), new_value))
return normalize_and_validate(new_headers)
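# For example (a hedged sketch; note the title-cased name on the wire):
#
#   >>> set_comma_header(Headers([]), b"connection", [b"close"]).raw_items()
#   [(b'Connection', b'close')]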
def has_expect_100_continue(request: "Request") -> bool:
# https://tools.ietf.org/html/rfc7231#section-5.1.1
# "A server that receives a 100-continue expectation in an HTTP/1.0 request
# MUST ignore that expectation."
if request.http_version < b"1.1":
return False
expect = get_comma_header(request.headers, b"expect")
return b"100-continue" in expect


@ -0,0 +1,247 @@
# Code to read HTTP data
#
# Strategy: each reader is a callable which takes a ReceiveBuffer object, and
# either:
# 1) consumes some of it and returns an Event
# 2) raises a LocalProtocolError (for consistency -- e.g. we call validate()
# and it might raise a LocalProtocolError, so simpler just to always use
# this)
# 3) returns None, meaning "I need more data"
#
# If they have a .read_eof attribute, then this will be called if an EOF is
# received -- but this is optional. Either way, the actual ConnectionClosed
# event will be generated afterwards.
#
# READERS is a dict describing how to pick a reader. It maps states to either:
# - a reader
# - or, for body readers, a dict of per-framing reader factories
import re
from typing import Any, Callable, Dict, Iterable, NoReturn, Optional, Tuple, Type, Union
from ._abnf import chunk_header, header_field, request_line, status_line
from ._events import Data, EndOfMessage, InformationalResponse, Request, Response
from ._receivebuffer import ReceiveBuffer
from ._state import (
CLIENT,
CLOSED,
DONE,
IDLE,
MUST_CLOSE,
SEND_BODY,
SEND_RESPONSE,
SERVER,
)
from ._util import LocalProtocolError, RemoteProtocolError, Sentinel, validate
__all__ = ["READERS"]
header_field_re = re.compile(header_field.encode("ascii"))
obs_fold_re = re.compile(rb"[ \t]+")
def _obsolete_line_fold(lines: Iterable[bytes]) -> Iterable[bytes]:
it = iter(lines)
last: Optional[bytes] = None
for line in it:
match = obs_fold_re.match(line)
if match:
if last is None:
raise LocalProtocolError("continuation line at start of headers")
if not isinstance(last, bytearray):
# Cast to a mutable type, avoiding copy on append to ensure O(n) time
last = bytearray(last)
last += b" "
last += line[match.end() :]
else:
if last is not None:
yield last
last = line
if last is not None:
yield last
def _decode_header_lines(
lines: Iterable[bytes],
) -> Iterable[Tuple[bytes, bytes]]:
for line in _obsolete_line_fold(lines):
matches = validate(header_field_re, line, "illegal header line: {!r}", line)
yield (matches["field_name"], matches["field_value"])
request_line_re = re.compile(request_line.encode("ascii"))
def maybe_read_from_IDLE_client(buf: ReceiveBuffer) -> Optional[Request]:
lines = buf.maybe_extract_lines()
if lines is None:
if buf.is_next_line_obviously_invalid_request_line():
raise LocalProtocolError("illegal request line")
return None
if not lines:
raise LocalProtocolError("no request line received")
matches = validate(
request_line_re, lines[0], "illegal request line: {!r}", lines[0]
)
return Request(
headers=list(_decode_header_lines(lines[1:])), _parsed=True, **matches
)
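# A doctest-style sketch of this reader (request bytes invented):
#
#   >>> buf = ReceiveBuffer()
#   >>> buf += b"GET / HTTP/1.1\r\nHost: a\r\n\r\n"
#   >>> req = maybe_read_from_IDLE_client(buf)
#   >>> (req.method, req.target, req.http_version)
#   (b'GET', b'/', b'1.1')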
status_line_re = re.compile(status_line.encode("ascii"))
def maybe_read_from_SEND_RESPONSE_server(
buf: ReceiveBuffer,
) -> Union[InformationalResponse, Response, None]:
lines = buf.maybe_extract_lines()
if lines is None:
if buf.is_next_line_obviously_invalid_request_line():
raise LocalProtocolError("illegal request line")
return None
if not lines:
raise LocalProtocolError("no response line received")
matches = validate(status_line_re, lines[0], "illegal status line: {!r}", lines[0])
http_version = (
b"1.1" if matches["http_version"] is None else matches["http_version"]
)
reason = b"" if matches["reason"] is None else matches["reason"]
status_code = int(matches["status_code"])
class_: Union[Type[InformationalResponse], Type[Response]] = (
InformationalResponse if status_code < 200 else Response
)
return class_(
headers=list(_decode_header_lines(lines[1:])),
_parsed=True,
status_code=status_code,
reason=reason,
http_version=http_version,
)
class ContentLengthReader:
def __init__(self, length: int) -> None:
self._length = length
self._remaining = length
def __call__(self, buf: ReceiveBuffer) -> Union[Data, EndOfMessage, None]:
if self._remaining == 0:
return EndOfMessage()
data = buf.maybe_extract_at_most(self._remaining)
if data is None:
return None
self._remaining -= len(data)
return Data(data=data)
def read_eof(self) -> NoReturn:
raise RemoteProtocolError(
"peer closed connection without sending complete message body "
"(received {} bytes, expected {})".format(
self._length - self._remaining, self._length
)
)
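# A hedged walk-through of the reader protocol described at the top of this
# file, using ContentLengthReader (byte strings invented):
#
#   >>> reader = ContentLengthReader(3)
#   >>> buf = ReceiveBuffer()
#   >>> reader(buf) is None   # "I need more data"
#   True
#   >>> buf += b"abc"
#   >>> reader(buf)
#   Data(data=bytearray(b'abc'), chunk_start=False, chunk_end=False)
#   >>> reader(buf)
#   EndOfMessage(headers=<Headers([])>)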
chunk_header_re = re.compile(chunk_header.encode("ascii"))
class ChunkedReader:
def __init__(self) -> None:
self._bytes_in_chunk = 0
# After reading a chunk, we have to throw away the trailing \r\n; if
# this is >0 then we discard that many bytes before resuming regular
# de-chunkification.
self._bytes_to_discard = 0
self._reading_trailer = False
def __call__(self, buf: ReceiveBuffer) -> Union[Data, EndOfMessage, None]:
if self._reading_trailer:
lines = buf.maybe_extract_lines()
if lines is None:
return None
return EndOfMessage(headers=list(_decode_header_lines(lines)))
if self._bytes_to_discard > 0:
data = buf.maybe_extract_at_most(self._bytes_to_discard)
if data is None:
return None
self._bytes_to_discard -= len(data)
if self._bytes_to_discard > 0:
return None
# else, fall through and read some more
assert self._bytes_to_discard == 0
if self._bytes_in_chunk == 0:
# We need to refill our chunk count
chunk_header = buf.maybe_extract_next_line()
if chunk_header is None:
return None
matches = validate(
chunk_header_re,
chunk_header,
"illegal chunk header: {!r}",
chunk_header,
)
# XX FIXME: we discard chunk extensions. Does anyone care?
self._bytes_in_chunk = int(matches["chunk_size"], base=16)
if self._bytes_in_chunk == 0:
self._reading_trailer = True
return self(buf)
chunk_start = True
else:
chunk_start = False
assert self._bytes_in_chunk > 0
data = buf.maybe_extract_at_most(self._bytes_in_chunk)
if data is None:
return None
self._bytes_in_chunk -= len(data)
if self._bytes_in_chunk == 0:
self._bytes_to_discard = 2
chunk_end = True
else:
chunk_end = False
return Data(data=data, chunk_start=chunk_start, chunk_end=chunk_end)
def read_eof(self) -> NoReturn:
raise RemoteProtocolError(
"peer closed connection without sending complete message body "
"(incomplete chunked read)"
)
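# The same protocol for chunked framing (a hedged sketch: one 3-byte chunk,
# then the terminating 0-length chunk with an empty trailer):
#
#   >>> buf = ReceiveBuffer()
#   >>> buf += b"3\r\nfoo\r\n0\r\n\r\n"
#   >>> reader = ChunkedReader()
#   >>> reader(buf)
#   Data(data=bytearray(b'foo'), chunk_start=True, chunk_end=True)
#   >>> reader(buf)
#   EndOfMessage(headers=<Headers([])>)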
class Http10Reader:
def __call__(self, buf: ReceiveBuffer) -> Optional[Data]:
data = buf.maybe_extract_at_most(999999999)
if data is None:
return None
return Data(data=data)
def read_eof(self) -> EndOfMessage:
return EndOfMessage()
def expect_nothing(buf: ReceiveBuffer) -> None:
if buf:
raise LocalProtocolError("Got data when expecting EOF")
return None
ReadersType = Dict[
Union[Type[Sentinel], Tuple[Type[Sentinel], Type[Sentinel]]],
Union[Callable[..., Any], Dict[str, Callable[..., Any]]],
]
READERS: ReadersType = {
(CLIENT, IDLE): maybe_read_from_IDLE_client,
(SERVER, IDLE): maybe_read_from_SEND_RESPONSE_server,
(SERVER, SEND_RESPONSE): maybe_read_from_SEND_RESPONSE_server,
(CLIENT, DONE): expect_nothing,
(CLIENT, MUST_CLOSE): expect_nothing,
(CLIENT, CLOSED): expect_nothing,
(SERVER, DONE): expect_nothing,
(SERVER, MUST_CLOSE): expect_nothing,
(SERVER, CLOSED): expect_nothing,
SEND_BODY: {
"chunked": ChunkedReader,
"content-length": ContentLengthReader,
"http/1.0": Http10Reader,
},
}


@ -0,0 +1,153 @@
import re
import sys
from typing import List, Optional, Union
__all__ = ["ReceiveBuffer"]
# Operations we want to support:
# - find next \r\n or \r\n\r\n (\n or \n\n are also acceptable),
# or wait until there is one
# - read at-most-N bytes
# Goals:
# - on average, do this fast
# - worst case, do this in O(n) where n is the number of bytes processed
# Plan:
# - store bytearray, offset, how far we've searched for a separator token
# - use the how-far-we've-searched data to avoid rescanning
# - while doing a stream of uninterrupted processing, advance offset instead
# of constantly copying
# WARNING:
# - I haven't benchmarked or profiled any of this yet.
#
# Note that starting in Python 3.4, deleting the initial n bytes from a
# bytearray is amortized O(n), thanks to some excellent work by Antoine
# Martin:
#
# https://bugs.python.org/issue19087
#
# This means that if we only supported 3.4+, we could get rid of the code here
# involving self._start and self.compress, because it's doing exactly the same
# thing that bytearray now does internally.
#
# BUT unfortunately, we still support 2.7, and reading short segments out of a
# long buffer MUST be O(bytes read) to avoid DoS issues, so we can't actually
# delete this code. Yet:
#
# https://pythonclock.org/
#
# (Two things to double-check first though: make sure PyPy also has the
# optimization, and benchmark to make sure it's a win, since we do have a
# slightly clever thing where we delay calling compress() until we've
# processed a whole event, which could in theory be slightly more efficient
# than the internal bytearray support.)
blank_line_regex = re.compile(b"\n\r?\n", re.MULTILINE)
class ReceiveBuffer:
def __init__(self) -> None:
self._data = bytearray()
self._next_line_search = 0
self._multiple_lines_search = 0
def __iadd__(self, byteslike: Union[bytes, bytearray]) -> "ReceiveBuffer":
self._data += byteslike
return self
def __bool__(self) -> bool:
return bool(len(self))
def __len__(self) -> int:
return len(self._data)
# for @property unprocessed_data
def __bytes__(self) -> bytes:
return bytes(self._data)
def _extract(self, count: int) -> bytearray:
# Extract an initial slice of the data buffer and return it.
out = self._data[:count]
del self._data[:count]
self._next_line_search = 0
self._multiple_lines_search = 0
return out
def maybe_extract_at_most(self, count: int) -> Optional[bytearray]:
"""
Extract a fixed number of bytes from the buffer.
"""
out = self._data[:count]
if not out:
return None
return self._extract(count)
def maybe_extract_next_line(self) -> Optional[bytearray]:
"""
Extract the first line, if it is completed in the buffer.
"""
# Only search in buffer space that we've not already looked at.
search_start_index = max(0, self._next_line_search - 1)
partial_idx = self._data.find(b"\r\n", search_start_index)
if partial_idx == -1:
self._next_line_search = len(self._data)
return None
# + 2 is to compensate for len(b"\r\n")
idx = partial_idx + 2
return self._extract(idx)
def maybe_extract_lines(self) -> Optional[List[bytearray]]:
"""
Extract everything up to the first blank line, and return a list of lines.
"""
# Handle the case where we have an immediate empty line.
if self._data[:1] == b"\n":
self._extract(1)
return []
if self._data[:2] == b"\r\n":
self._extract(2)
return []
# Only search in buffer space that we've not already looked at.
match = blank_line_regex.search(self._data, self._multiple_lines_search)
if match is None:
self._multiple_lines_search = max(0, len(self._data) - 2)
return None
# Truncate the buffer and return it.
idx = match.span(0)[-1]
out = self._extract(idx)
lines = out.split(b"\n")
for line in lines:
if line.endswith(b"\r"):
del line[-1]
assert lines[-2] == lines[-1] == b""
del lines[-2:]
return lines
# In theory we should wait until `\r\n` before starting to validate
# incoming data. However, it's useful to detect (very) invalid data
# early, since it might not contain `\r\n` at all (in which case only a
# timeout would get rid of it).
# This is not a 100% effective detection, but more of a cheap sanity check
# allowing early abort in some useful cases.
# It is especially helpful when a peer speaks TLS to us while we were
# expecting plain HTTP, since all versions of TLS so far start the
# handshake with a 0x16 message type code.
def is_next_line_obviously_invalid_request_line(self) -> bool:
try:
# HTTP header line must not contain non-printable characters
# and should not start with a space
return self._data[0] < 0x21
except IndexError:
return False
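# A doctest-style sketch of the buffer in action (bytes invented):
#
#   >>> buf = ReceiveBuffer()
#   >>> buf += b"HTTP/1.1 200 OK\r\nHost: a\r\n\r\ntrailing"
#   >>> buf.maybe_extract_lines()
#   [bytearray(b'HTTP/1.1 200 OK'), bytearray(b'Host: a')]
#   >>> bytes(buf)
#   b'trailing'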


@ -0,0 +1,367 @@
################################################################
# The core state machine
################################################################
#
# Rule 1: everything that affects the state machine and state transitions must
# live here in this file. As much as possible goes into the table-based
# representation, but for the bits that don't quite fit, the actual code and
# state must nonetheless live here.
#
# Rule 2: this file does not know about what role we're playing; it only knows
# about HTTP request/response cycles in the abstract. This ensures that we
# don't cheat and apply different rules to local and remote parties.
#
#
# Theory of operation
# ===================
#
# Possibly the simplest way to think about this is that we actually have 5
# different state machines here. Yes, 5. These are:
#
# 1) The client state, with its complicated automaton (see the docs)
# 2) The server state, with its complicated automaton (see the docs)
# 3) The keep-alive state, with possible states {True, False}
# 4) The SWITCH_CONNECT state, with possible states {False, True}
# 5) The SWITCH_UPGRADE state, with possible states {False, True}
#
# For (3)-(5), the first state listed is the initial state.
#
# (1)-(3) are stored explicitly in member variables. The last
# two are stored implicitly in the pending_switch_proposals set as:
# (state of 4) == (_SWITCH_CONNECT in pending_switch_proposals)
# (state of 5) == (_SWITCH_UPGRADE in pending_switch_proposals)
#
# And each of these machines has two different kinds of transitions:
#
# a) Event-triggered
# b) State-triggered
#
# Event triggered is the obvious thing that you'd think it is: some event
# happens, and if it's the right event at the right time then a transition
# happens. But there are somewhat complicated rules for which machines can
# "see" which events. (As a rule of thumb, if a machine "sees" an event, this
# means two things: the event can affect the machine, and if the machine is
# not in a state where it expects that event then it's an error.) These rules
# are:
#
# 1) The client machine sees all h11.events objects emitted by the client.
#
# 2) The server machine sees all h11.events objects emitted by the server.
#
# It also sees the client's Request event.
#
# And sometimes, server events are annotated with a _SWITCH_* event. For
# example, we can have a (Response, _SWITCH_CONNECT) event, which is
# different from a regular Response event.
#
# 3) The keep-alive machine sees the process_keep_alive_disabled() event
# (which is derived from Request/Response events), and this event
# transitions it from True -> False, or from False -> False. There's no way
# to transition back.
#
# 4&5) The _SWITCH_* machines transition from False->True when we get a
# Request that proposes the relevant type of switch (via
# process_client_switch_proposals), and they go from True->False when we
# get a Response that has no _SWITCH_* annotation.
#
# So that's event-triggered transitions.
#
# State-triggered transitions are less standard. What they do here is couple
# the machines together. The way this works is, when certain *joint*
# configurations of states are achieved, then we automatically transition to a
# new *joint* state. So, for example, if we're ever in a joint state with
#
# client: DONE
# keep-alive: False
#
# then the client state immediately transitions to:
#
# client: MUST_CLOSE
#
# This is fundamentally different from an event-based transition, because it
# doesn't matter how we arrived at the {client: DONE, keep-alive: False} state
# -- maybe the client transitioned SEND_BODY -> DONE, or keep-alive
# transitioned True -> False. Either way, once this precondition is satisfied,
# this transition is immediately triggered.
#
# What if two conflicting state-based transitions get enabled at the same
# time? In practice there's only one case where this arises (client DONE ->
# MIGHT_SWITCH_PROTOCOL versus DONE -> MUST_CLOSE), and we resolve it by
# explicitly prioritizing the DONE -> MIGHT_SWITCH_PROTOCOL transition.
#
# Implementation
# --------------
#
# The event-triggered transitions for the server and client machines are all
# stored explicitly in a table. Ditto for the state-triggered transitions that
# involve just the server and client state.
#
# The transitions for the other machines, and the state-triggered transitions
# that involve the other machines, are written out as explicit Python code.
#
# It'd be nice if there were some cleaner way to do all this. This isn't
# *too* terrible, but I feel like it could probably be better.
#
# WARNING
# -------
#
# The script that generates the state machine diagrams for the docs knows how
# to read out the EVENT_TRIGGERED_TRANSITIONS and STATE_TRIGGERED_TRANSITIONS
# tables. But it can't automatically read the transitions that are written
# directly in Python code. So if you touch those, you need to also update the
# script to keep it in sync!
from typing import cast, Dict, Optional, Set, Tuple, Type, Union
from ._events import *
from ._util import LocalProtocolError, Sentinel
# Everything in __all__ gets re-exported as part of the h11 public API.
__all__ = [
"CLIENT",
"SERVER",
"IDLE",
"SEND_RESPONSE",
"SEND_BODY",
"DONE",
"MUST_CLOSE",
"CLOSED",
"MIGHT_SWITCH_PROTOCOL",
"SWITCHED_PROTOCOL",
"ERROR",
]
class CLIENT(Sentinel, metaclass=Sentinel):
pass
class SERVER(Sentinel, metaclass=Sentinel):
pass
# States
class IDLE(Sentinel, metaclass=Sentinel):
pass
class SEND_RESPONSE(Sentinel, metaclass=Sentinel):
pass
class SEND_BODY(Sentinel, metaclass=Sentinel):
pass
class DONE(Sentinel, metaclass=Sentinel):
pass
class MUST_CLOSE(Sentinel, metaclass=Sentinel):
pass
class CLOSED(Sentinel, metaclass=Sentinel):
pass
class ERROR(Sentinel, metaclass=Sentinel):
pass
# Switch types
class MIGHT_SWITCH_PROTOCOL(Sentinel, metaclass=Sentinel):
pass
class SWITCHED_PROTOCOL(Sentinel, metaclass=Sentinel):
pass
class _SWITCH_UPGRADE(Sentinel, metaclass=Sentinel):
pass
class _SWITCH_CONNECT(Sentinel, metaclass=Sentinel):
pass
EventTransitionType = Dict[
Type[Sentinel],
Dict[
Type[Sentinel],
Dict[Union[Type[Event], Tuple[Type[Event], Type[Sentinel]]], Type[Sentinel]],
],
]
EVENT_TRIGGERED_TRANSITIONS: EventTransitionType = {
CLIENT: {
IDLE: {Request: SEND_BODY, ConnectionClosed: CLOSED},
SEND_BODY: {Data: SEND_BODY, EndOfMessage: DONE},
DONE: {ConnectionClosed: CLOSED},
MUST_CLOSE: {ConnectionClosed: CLOSED},
CLOSED: {ConnectionClosed: CLOSED},
MIGHT_SWITCH_PROTOCOL: {},
SWITCHED_PROTOCOL: {},
ERROR: {},
},
SERVER: {
IDLE: {
ConnectionClosed: CLOSED,
Response: SEND_BODY,
# Special case: server sees client Request events, in this form
(Request, CLIENT): SEND_RESPONSE,
},
SEND_RESPONSE: {
InformationalResponse: SEND_RESPONSE,
Response: SEND_BODY,
(InformationalResponse, _SWITCH_UPGRADE): SWITCHED_PROTOCOL,
(Response, _SWITCH_CONNECT): SWITCHED_PROTOCOL,
},
SEND_BODY: {Data: SEND_BODY, EndOfMessage: DONE},
DONE: {ConnectionClosed: CLOSED},
MUST_CLOSE: {ConnectionClosed: CLOSED},
CLOSED: {ConnectionClosed: CLOSED},
SWITCHED_PROTOCOL: {},
ERROR: {},
},
}
StateTransitionType = Dict[
Tuple[Type[Sentinel], Type[Sentinel]], Dict[Type[Sentinel], Type[Sentinel]]
]
# NB: there are also some special-case state-triggered transitions hard-coded
# into _fire_state_triggered_transitions below.
STATE_TRIGGERED_TRANSITIONS: StateTransitionType = {
# (Client state, Server state) -> new states
# Protocol negotiation
(MIGHT_SWITCH_PROTOCOL, SWITCHED_PROTOCOL): {CLIENT: SWITCHED_PROTOCOL},
# Socket shutdown
(CLOSED, DONE): {SERVER: MUST_CLOSE},
(CLOSED, IDLE): {SERVER: MUST_CLOSE},
(ERROR, DONE): {SERVER: MUST_CLOSE},
(DONE, CLOSED): {CLIENT: MUST_CLOSE},
(IDLE, CLOSED): {CLIENT: MUST_CLOSE},
(DONE, ERROR): {CLIENT: MUST_CLOSE},
}
class ConnectionState:
def __init__(self) -> None:
# Extra bits of state that don't quite fit into the state model.
# If this is False then it enables the automatic DONE -> MUST_CLOSE
# transition. Don't set this directly; call .process_keep_alive_disabled()
self.keep_alive = True
# This is a subset of {UPGRADE, CONNECT}, containing the proposals
# made by the client for switching protocols.
self.pending_switch_proposals: Set[Type[Sentinel]] = set()
self.states: Dict[Type[Sentinel], Type[Sentinel]] = {CLIENT: IDLE, SERVER: IDLE}
def process_error(self, role: Type[Sentinel]) -> None:
self.states[role] = ERROR
self._fire_state_triggered_transitions()
def process_keep_alive_disabled(self) -> None:
self.keep_alive = False
self._fire_state_triggered_transitions()
def process_client_switch_proposal(self, switch_event: Type[Sentinel]) -> None:
self.pending_switch_proposals.add(switch_event)
self._fire_state_triggered_transitions()
def process_event(
self,
role: Type[Sentinel],
event_type: Type[Event],
server_switch_event: Optional[Type[Sentinel]] = None,
) -> None:
_event_type: Union[Type[Event], Tuple[Type[Event], Type[Sentinel]]] = event_type
if server_switch_event is not None:
assert role is SERVER
if server_switch_event not in self.pending_switch_proposals:
raise LocalProtocolError(
"Received server {} event without a pending proposal".format(
server_switch_event
)
)
_event_type = (event_type, server_switch_event)
if server_switch_event is None and _event_type is Response:
self.pending_switch_proposals = set()
self._fire_event_triggered_transitions(role, _event_type)
# Special case: the server state does get to see Request
# events.
if _event_type is Request:
assert role is CLIENT
self._fire_event_triggered_transitions(SERVER, (Request, CLIENT))
self._fire_state_triggered_transitions()
def _fire_event_triggered_transitions(
self,
role: Type[Sentinel],
event_type: Union[Type[Event], Tuple[Type[Event], Type[Sentinel]]],
) -> None:
state = self.states[role]
try:
new_state = EVENT_TRIGGERED_TRANSITIONS[role][state][event_type]
except KeyError:
event_type = cast(Type[Event], event_type)
raise LocalProtocolError(
"can't handle event type {} when role={} and state={}".format(
event_type.__name__, role, self.states[role]
)
) from None
self.states[role] = new_state
def _fire_state_triggered_transitions(self) -> None:
# We apply these rules repeatedly until converging on a fixed point
while True:
start_states = dict(self.states)
# It could happen that both these special-case transitions are
# enabled at the same time:
#
# DONE -> MIGHT_SWITCH_PROTOCOL
# DONE -> MUST_CLOSE
#
# For example, this will always be true of a HTTP/1.0 client
# requesting CONNECT. If this happens, the protocol switch takes
# priority. From there the client will either go to
# SWITCHED_PROTOCOL, in which case it's none of our business when
# they close the connection, or else the server will deny the
# request, in which case the client will go back to DONE and then
# from there to MUST_CLOSE.
if self.pending_switch_proposals:
if self.states[CLIENT] is DONE:
self.states[CLIENT] = MIGHT_SWITCH_PROTOCOL
if not self.pending_switch_proposals:
if self.states[CLIENT] is MIGHT_SWITCH_PROTOCOL:
self.states[CLIENT] = DONE
if not self.keep_alive:
for role in (CLIENT, SERVER):
if self.states[role] is DONE:
self.states[role] = MUST_CLOSE
# Tabular state-triggered transitions
joint_state = (self.states[CLIENT], self.states[SERVER])
changes = STATE_TRIGGERED_TRANSITIONS.get(joint_state, {})
self.states.update(changes)
if self.states == start_states:
# Fixed point reached
return
def start_next_cycle(self) -> None:
if self.states != {CLIENT: DONE, SERVER: DONE}:
raise LocalProtocolError(
"not in a reusable state. self.states={}".format(self.states)
)
# Can't reach DONE/DONE with any of these active, but still, let's be
# sure.
assert self.keep_alive
assert not self.pending_switch_proposals
self.states = {CLIENT: IDLE, SERVER: IDLE}
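# A hedged walk-through of the machinery above: a client request runs the
# client machine IDLE -> SEND_BODY -> DONE, the server machine sees the
# Request too, and disabling keep-alive triggers the state-based
# DONE -> MUST_CLOSE transition:
#
#   >>> cs = ConnectionState()
#   >>> cs.process_event(CLIENT, Request)
#   >>> cs.states
#   {CLIENT: SEND_BODY, SERVER: SEND_RESPONSE}
#   >>> cs.process_event(CLIENT, EndOfMessage)
#   >>> cs.process_keep_alive_disabled()
#   >>> cs.states[CLIENT]
#   MUST_CLOSE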


@ -0,0 +1,135 @@
from typing import Any, Dict, NoReturn, Pattern, Tuple, Type, TypeVar, Union
__all__ = [
"ProtocolError",
"LocalProtocolError",
"RemoteProtocolError",
"validate",
"bytesify",
]
class ProtocolError(Exception):
"""Exception indicating a violation of the HTTP/1.1 protocol.
This is an abstract base class, with two concrete subclasses:
:exc:`LocalProtocolError`, which indicates that you tried to do something
that HTTP/1.1 says is illegal, and :exc:`RemoteProtocolError`, which
indicates that the remote peer tried to do something that HTTP/1.1 says is
illegal. See :ref:`error-handling` for details.
In addition to the normal :exc:`Exception` features, it has one attribute:
.. attribute:: error_status_hint
This gives a suggestion as to what status code a server might use if
this error occurred as part of a request.
For a :exc:`RemoteProtocolError`, this is useful as a suggestion for
how you might want to respond to a misbehaving peer, if you're
implementing a server.
For a :exc:`LocalProtocolError`, this can be taken as a suggestion for
how your peer might have responded to *you* if h11 had allowed you to
continue.
The default is 400 Bad Request, a generic catch-all for protocol
violations.
"""
def __init__(self, msg: str, error_status_hint: int = 400) -> None:
if type(self) is ProtocolError:
raise TypeError("tried to directly instantiate ProtocolError")
Exception.__init__(self, msg)
self.error_status_hint = error_status_hint
# Strategy: there are a number of public APIs where a LocalProtocolError can
# be raised (send(), all the different event constructors, ...), and only one
# public API where RemoteProtocolError can be raised
# (receive_data()). Therefore we always raise LocalProtocolError internally,
# and then receive_data will translate this into a RemoteProtocolError.
#
# Internally:
# LocalProtocolError is the generic "ProtocolError".
# Externally:
# LocalProtocolError is for local errors and RemoteProtocolError is for
# remote errors.
class LocalProtocolError(ProtocolError):
def _reraise_as_remote_protocol_error(self) -> NoReturn:
# After catching a LocalProtocolError, use this method to re-raise it
# as a RemoteProtocolError. This method must be called from inside an
# except: block.
#
# An easy way to get an equivalent RemoteProtocolError is just to
# modify 'self' in place.
self.__class__ = RemoteProtocolError # type: ignore
# But the re-raising is somewhat non-trivial -- you might think that
# now that we've modified the in-flight exception object, that just
# doing 'raise' to re-raise it would be enough. But it turns out that
# this doesn't work, because Python tracks the exception type
# (exc_info[0]) separately from the exception object (exc_info[1]),
# and we only modified the latter. So we really do need to re-raise
# the new type explicitly.
# On py3, the traceback is part of the exception object, so our
# in-place modification preserved it and we can just re-raise:
raise self
class RemoteProtocolError(ProtocolError):
pass
def validate(
regex: Pattern[bytes], data: bytes, msg: str = "malformed data", *format_args: Any
) -> Dict[str, bytes]:
match = regex.fullmatch(data)
if not match:
if format_args:
msg = msg.format(*format_args)
raise LocalProtocolError(msg)
return match.groupdict()
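# For example (a hedged sketch; the regex here is invented and `re` is
# imported just for the demo):
#
#   >>> import re
#   >>> validate(re.compile(rb"(?P<num>[0-9]+)"), b"123")
#   {'num': b'123'}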
# Sentinel values
#
# - Inherit identity-based comparison and hashing from object
# - Have a nice repr
# - Have a *bonus property*: type(sentinel) is sentinel
#
# The bonus property is useful if you want to take the return value from
# next_event() and do some sort of dispatch based on type(event).
_T_Sentinel = TypeVar("_T_Sentinel", bound="Sentinel")
class Sentinel(type):
def __new__(
cls: Type[_T_Sentinel],
name: str,
bases: Tuple[type, ...],
namespace: Dict[str, Any],
**kwds: Any
) -> _T_Sentinel:
assert bases == (Sentinel,)
v = super().__new__(cls, name, bases, namespace, **kwds)
v.__class__ = v # type: ignore
return v
def __repr__(self) -> str:
return self.__name__
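# A doctest-style sketch of the bonus property (the sentinel name here is
# invented for the example):
#
#   >>> class NEED_DATA(Sentinel, metaclass=Sentinel):
#   ...     pass
#   >>> NEED_DATA
#   NEED_DATA
#   >>> type(NEED_DATA) is NEED_DATA
#   True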
# Used for methods, request targets, HTTP versions, header names, and header
# values. Accepts ascii-strings, or bytes/bytearray/memoryview/..., and always
# returns bytes.
def bytesify(s: Union[bytes, bytearray, memoryview, int, str]) -> bytes:
# Fast-path:
if type(s) is bytes:
return s
if isinstance(s, str):
s = s.encode("ascii")
if isinstance(s, int):
raise TypeError("expected bytes-like object, not int")
return bytes(s)
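# For example (a hedged sketch):
#
#   >>> bytesify("GET")
#   b'GET'
#   >>> bytesify(bytearray(b"abc"))
#   b'abc'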


@ -0,0 +1,16 @@
# This file must be kept very simple, because it is consumed from several
# places -- it is imported by h11/__init__.py, execfile'd by setup.py, etc.
# We use a simple scheme:
# 1.0.0 -> 1.0.0+dev -> 1.1.0 -> 1.1.0+dev
# where the +dev versions are never released into the wild, they're just what
# we stick into the VCS in between releases.
#
# This is compatible with PEP 440:
# http://legacy.python.org/dev/peps/pep-0440/
# via the use of the "local suffix" "+dev", which is disallowed on index
# servers and causes 1.0.0+dev to sort after plain 1.0.0, which is what we
# want. (Contrast with the special suffix 1.0.0.dev, which sorts *before*
# 1.0.0.)
__version__ = "0.14.0"


@ -0,0 +1,145 @@
# Code to write HTTP data
#
# Strategy: each writer takes an event + a write-some-bytes function, which it
# calls.
#
# WRITERS is a dict describing how to pick a writer. It maps states to either:
# - a writer
# - or, for body writers, a dict of framing-dependent writer factories
from typing import Any, Callable, Dict, List, Tuple, Type, Union
from ._events import Data, EndOfMessage, Event, InformationalResponse, Request, Response
from ._headers import Headers
from ._state import CLIENT, IDLE, SEND_BODY, SEND_RESPONSE, SERVER
from ._util import LocalProtocolError, Sentinel
__all__ = ["WRITERS"]
Writer = Callable[[bytes], Any]
def write_headers(headers: Headers, write: Writer) -> None:
# "Since the Host field-value is critical information for handling a
# request, a user agent SHOULD generate Host as the first header field
# following the request-line." - RFC 7230
raw_items = headers._full_items
for raw_name, name, value in raw_items:
if name == b"host":
write(b"%s: %s\r\n" % (raw_name, value))
for raw_name, name, value in raw_items:
if name != b"host":
write(b"%s: %s\r\n" % (raw_name, value))
write(b"\r\n")
def write_request(request: Request, write: Writer) -> None:
if request.http_version != b"1.1":
raise LocalProtocolError("I only send HTTP/1.1")
write(b"%s %s HTTP/1.1\r\n" % (request.method, request.target))
write_headers(request.headers, write)
# Shared between InformationalResponse and Response
def write_any_response(
response: Union[InformationalResponse, Response], write: Writer
) -> None:
if response.http_version != b"1.1":
raise LocalProtocolError("I only send HTTP/1.1")
status_bytes = str(response.status_code).encode("ascii")
# We don't bother sending ascii status messages like "OK"; they're
# optional and ignored by the protocol. (But the space after the numeric
# status code is mandatory.)
#
# XX FIXME: could at least make an effort to pull out the status message
# from stdlib's http.HTTPStatus table. Or maybe just steal their enums
# (either by import or copy/paste). We already accept them as status codes
# since they're of type IntEnum < int.
write(b"HTTP/1.1 %s %s\r\n" % (status_bytes, response.reason))
write_headers(response.headers, write)
class BodyWriter:
def __call__(self, event: Event, write: Writer) -> None:
if type(event) is Data:
self.send_data(event.data, write)
elif type(event) is EndOfMessage:
self.send_eom(event.headers, write)
else: # pragma: no cover
assert False
def send_data(self, data: bytes, write: Writer) -> None:
pass
def send_eom(self, headers: Headers, write: Writer) -> None:
pass
#
# These are all careful not to do anything to 'data' except call len(data) and
# write(data). This allows us to transparently pass-through funny objects,
# like placeholder objects referring to files on disk that will be sent via
# sendfile(2).
#
class ContentLengthWriter(BodyWriter):
def __init__(self, length: int) -> None:
self._length = length
def send_data(self, data: bytes, write: Writer) -> None:
self._length -= len(data)
if self._length < 0:
raise LocalProtocolError("Too much data for declared Content-Length")
write(data)
def send_eom(self, headers: Headers, write: Writer) -> None:
if self._length != 0:
raise LocalProtocolError("Too little data for declared Content-Length")
if headers:
raise LocalProtocolError("Content-Length and trailers don't mix")
class ChunkedWriter(BodyWriter):
def send_data(self, data: bytes, write: Writer) -> None:
# if we encoded 0-length data in the naive way, it would look like an
# end-of-message.
if not data:
return
write(b"%x\r\n" % len(data))
write(data)
write(b"\r\n")
def send_eom(self, headers: Headers, write: Writer) -> None:
write(b"0\r\n")
write_headers(headers, write)
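# For example (a hedged sketch: one chunk, then the terminating chunk with
# an empty trailer):
#
#   >>> out = []
#   >>> w = ChunkedWriter()
#   >>> w.send_data(b"foo", out.append)
#   >>> w.send_eom(Headers([]), out.append)
#   >>> b"".join(out)
#   b'3\r\nfoo\r\n0\r\n\r\n'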
class Http10Writer(BodyWriter):
def send_data(self, data: bytes, write: Writer) -> None:
write(data)
def send_eom(self, headers: Headers, write: Writer) -> None:
if headers:
raise LocalProtocolError("can't send trailers to HTTP/1.0 client")
# no need to close the socket ourselves, that will be taken care of by
# Connection: close machinery
WritersType = Dict[
Union[Tuple[Type[Sentinel], Type[Sentinel]], Type[Sentinel]],
Union[
Dict[str, Type[BodyWriter]],
Callable[[Union[InformationalResponse, Response], Writer], None],
Callable[[Request, Writer], None],
],
]
WRITERS: WritersType = {
(CLIENT, IDLE): write_request,
(SERVER, IDLE): write_any_response,
(SERVER, SEND_RESPONSE): write_any_response,
SEND_BODY: {
"chunked": ChunkedWriter,
"content-length": ContentLengthWriter,
"http/1.0": Http10Writer,
},
}


@ -0,0 +1 @@
Marker



@ -0,0 +1 @@
92b12bc045050b55b848d37167a1a63947c364579889ce1d39788e45e9fac9e5


@ -0,0 +1,101 @@
from typing import cast, List, Type, Union, ValuesView
from .._connection import Connection, NEED_DATA, PAUSED
from .._events import (
ConnectionClosed,
Data,
EndOfMessage,
Event,
InformationalResponse,
Request,
Response,
)
from .._state import CLIENT, CLOSED, DONE, MUST_CLOSE, SERVER
from .._util import Sentinel
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal # type: ignore
def get_all_events(conn: Connection) -> List[Event]:
got_events = []
while True:
event = conn.next_event()
if event in (NEED_DATA, PAUSED):
break
event = cast(Event, event)
got_events.append(event)
if type(event) is ConnectionClosed:
break
return got_events
def receive_and_get(conn: Connection, data: bytes) -> List[Event]:
conn.receive_data(data)
return get_all_events(conn)
# Merges adjacent Data events, converts payloads to bytestrings, and removes
# chunk boundaries.
def normalize_data_events(in_events: List[Event]) -> List[Event]:
out_events: List[Event] = []
for event in in_events:
if type(event) is Data:
event = Data(data=bytes(event.data), chunk_start=False, chunk_end=False)
if out_events and type(out_events[-1]) is type(event) is Data:
out_events[-1] = Data(
data=out_events[-1].data + event.data,
chunk_start=out_events[-1].chunk_start,
chunk_end=out_events[-1].chunk_end,
)
else:
out_events.append(event)
return out_events
# Given that we want to write tests that push some events through a Connection
# and check that its state updates appropriately... we might as well make a
# habit
# of pushing them through two Connections with a fake network link in
# between.
class ConnectionPair:
def __init__(self) -> None:
self.conn = {CLIENT: Connection(CLIENT), SERVER: Connection(SERVER)}
self.other = {CLIENT: SERVER, SERVER: CLIENT}
@property
def conns(self) -> ValuesView[Connection]:
return self.conn.values()
# expect="match" if expect=send_events; expect=[...] to say what expected
def send(
self,
role: Type[Sentinel],
send_events: Union[List[Event], Event],
expect: Union[List[Event], Event, Literal["match"]] = "match",
) -> bytes:
if not isinstance(send_events, list):
send_events = [send_events]
data = b""
closed = False
for send_event in send_events:
new_data = self.conn[role].send(send_event)
if new_data is None:
closed = True
else:
data += new_data
# send uses b"" to mean b"", and None to mean closed
# receive uses b"" to mean closed, and None to mean "try again"
# so we have to translate between the two conventions
if data:
self.conn[self.other[role]].receive_data(data)
if closed:
self.conn[self.other[role]].receive_data(b"")
got_events = get_all_events(self.conn[self.other[role]])
if expect == "match":
expect = send_events
if not isinstance(expect, list):
expect = [expect]
assert got_events == expect
return data
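# A hedged sketch of a round trip through the pair (request invented; the
# returned value is the raw bytes that crossed the fake wire):
#
#   >>> p = ConnectionPair()
#   >>> p.send(
#   ...     CLIENT,
#   ...     [Request(method="GET", target="/", headers=[("Host", "a")]),
#   ...      EndOfMessage()],
#   ... )
#   b'GET / HTTP/1.1\r\nHost: a\r\n\r\n'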


@ -0,0 +1,115 @@
import json
import os.path
import socket
import socketserver
import threading
from contextlib import closing, contextmanager
from http.server import SimpleHTTPRequestHandler
from typing import Callable, Generator
from urllib.request import urlopen
import h11
@contextmanager
def socket_server(
handler: Callable[..., socketserver.BaseRequestHandler]
) -> Generator[socketserver.TCPServer, None, None]:
httpd = socketserver.TCPServer(("127.0.0.1", 0), handler)
thread = threading.Thread(
target=httpd.serve_forever, kwargs={"poll_interval": 0.01}
)
thread.daemon = True
try:
thread.start()
yield httpd
finally:
httpd.shutdown()
test_file_path = os.path.join(os.path.dirname(__file__), "data/test-file")
with open(test_file_path, "rb") as f:
test_file_data = f.read()
class SingleMindedRequestHandler(SimpleHTTPRequestHandler):
def translate_path(self, path: str) -> str:
return test_file_path
def test_h11_as_client() -> None:
with socket_server(SingleMindedRequestHandler) as httpd:
with closing(socket.create_connection(httpd.server_address)) as s:
c = h11.Connection(h11.CLIENT)
s.sendall(
c.send( # type: ignore[arg-type]
h11.Request(
method="GET", target="/foo", headers=[("Host", "localhost")]
)
)
)
s.sendall(c.send(h11.EndOfMessage())) # type: ignore[arg-type]
data = bytearray()
while True:
event = c.next_event()
print(event)
if event is h11.NEED_DATA:
# Use a small read buffer to make things more challenging
# and exercise more paths :-)
c.receive_data(s.recv(10))
continue
if type(event) is h11.Response:
assert event.status_code == 200
if type(event) is h11.Data:
data += event.data
if type(event) is h11.EndOfMessage:
break
assert bytes(data) == test_file_data
class H11RequestHandler(socketserver.BaseRequestHandler):
def handle(self) -> None:
with closing(self.request) as s:
c = h11.Connection(h11.SERVER)
request = None
while True:
event = c.next_event()
if event is h11.NEED_DATA:
# Use a small read buffer to make things more challenging
# and exercise more paths :-)
c.receive_data(s.recv(10))
continue
if type(event) is h11.Request:
request = event
if type(event) is h11.EndOfMessage:
break
assert request is not None
info = json.dumps(
{
"method": request.method.decode("ascii"),
"target": request.target.decode("ascii"),
"headers": {
name.decode("ascii"): value.decode("ascii")
for (name, value) in request.headers
},
}
)
s.sendall(c.send(h11.Response(status_code=200, headers=[]))) # type: ignore[arg-type]
s.sendall(c.send(h11.Data(data=info.encode("ascii"))))
s.sendall(c.send(h11.EndOfMessage()))
def test_h11_as_server() -> None:
with socket_server(H11RequestHandler) as httpd:
host, port = httpd.server_address
url = "http://{}:{}/some-path".format(host, port)
with closing(urlopen(url)) as f:
assert f.getcode() == 200
data = f.read()
info = json.loads(data.decode("ascii"))
print(info)
assert info["method"] == "GET"
assert info["target"] == "/some-path"
assert "urllib" in info["headers"]["user-agent"]

File diff suppressed because it is too large


@ -0,0 +1,150 @@
from http import HTTPStatus
import pytest
from .. import _events
from .._events import (
ConnectionClosed,
Data,
EndOfMessage,
Event,
InformationalResponse,
Request,
Response,
)
from .._util import LocalProtocolError
def test_events() -> None:
with pytest.raises(LocalProtocolError):
# Missing Host:
req = Request(
method="GET", target="/", headers=[("a", "b")], http_version="1.1"
)
# But this is okay (HTTP/1.0)
req = Request(method="GET", target="/", headers=[("a", "b")], http_version="1.0")
# fields are normalized
assert req.method == b"GET"
assert req.target == b"/"
assert req.headers == [(b"a", b"b")]
assert req.http_version == b"1.0"
# This is also okay -- has a Host (with weird capitalization, which is ok)
req = Request(
method="GET",
target="/",
headers=[("a", "b"), ("hOSt", "example.com")],
http_version="1.1",
)
# we normalize header capitalization
assert req.headers == [(b"a", b"b"), (b"host", b"example.com")]
# Multiple host is bad too
with pytest.raises(LocalProtocolError):
req = Request(
method="GET",
target="/",
headers=[("Host", "a"), ("Host", "a")],
http_version="1.1",
)
# Even for HTTP/1.0
with pytest.raises(LocalProtocolError):
req = Request(
method="GET",
target="/",
headers=[("Host", "a"), ("Host", "a")],
http_version="1.0",
)
# Header values are validated
for bad_char in "\x00\r\n\f\v":
with pytest.raises(LocalProtocolError):
req = Request(
method="GET",
target="/",
headers=[("Host", "a"), ("Foo", "asd" + bad_char)],
http_version="1.0",
)
# But for compatibility we allow non-whitespace control characters, even
# though they're forbidden by the spec.
Request(
method="GET",
target="/",
headers=[("Host", "a"), ("Foo", "asd\x01\x02\x7f")],
http_version="1.0",
)
# Request target is validated
for bad_byte in b"\x00\x20\x7f\xee":
target = bytearray(b"/")
target.append(bad_byte)
with pytest.raises(LocalProtocolError):
Request(
method="GET", target=target, headers=[("Host", "a")], http_version="1.1"
)
# Request method is validated
with pytest.raises(LocalProtocolError):
Request(
method="GET / HTTP/1.1",
target=target,
headers=[("Host", "a")],
http_version="1.1",
)
ir = InformationalResponse(status_code=100, headers=[("Host", "a")])
assert ir.status_code == 100
assert ir.headers == [(b"host", b"a")]
assert ir.http_version == b"1.1"
with pytest.raises(LocalProtocolError):
InformationalResponse(status_code=200, headers=[("Host", "a")])
resp = Response(status_code=204, headers=[], http_version="1.0") # type: ignore[arg-type]
assert resp.status_code == 204
assert resp.headers == []
assert resp.http_version == b"1.0"
with pytest.raises(LocalProtocolError):
resp = Response(status_code=100, headers=[], http_version="1.0") # type: ignore[arg-type]
with pytest.raises(LocalProtocolError):
Response(status_code="100", headers=[], http_version="1.0") # type: ignore[arg-type]
with pytest.raises(LocalProtocolError):
InformationalResponse(status_code=b"100", headers=[], http_version="1.0") # type: ignore[arg-type]
d = Data(data=b"asdf")
assert d.data == b"asdf"
eom = EndOfMessage()
assert eom.headers == []
cc = ConnectionClosed()
assert repr(cc) == "ConnectionClosed()"
def test_intenum_status_code() -> None:
# https://github.com/python-hyper/h11/issues/72
r = Response(status_code=HTTPStatus.OK, headers=[], http_version="1.0") # type: ignore[arg-type]
assert r.status_code == HTTPStatus.OK
assert type(r.status_code) is not type(HTTPStatus.OK)
assert type(r.status_code) is int
def test_header_casing() -> None:
r = Request(
method="GET",
target="/",
headers=[("Host", "example.org"), ("Connection", "keep-alive")],
http_version="1.1",
)
assert len(r.headers) == 2
assert r.headers[0] == (b"host", b"example.org")
assert r.headers == [(b"host", b"example.org"), (b"connection", b"keep-alive")]
assert r.headers.raw_items() == [
(b"Host", b"example.org"),
(b"Connection", b"keep-alive"),
]


@ -0,0 +1,157 @@
import pytest
from .._events import Request
from .._headers import (
get_comma_header,
has_expect_100_continue,
Headers,
normalize_and_validate,
set_comma_header,
)
from .._util import LocalProtocolError
def test_normalize_and_validate() -> None:
assert normalize_and_validate([("foo", "bar")]) == [(b"foo", b"bar")]
assert normalize_and_validate([(b"foo", b"bar")]) == [(b"foo", b"bar")]
# no leading/trailing whitespace in names
with pytest.raises(LocalProtocolError):
normalize_and_validate([(b"foo ", "bar")])
with pytest.raises(LocalProtocolError):
normalize_and_validate([(b" foo", "bar")])
# no weird characters in names
with pytest.raises(LocalProtocolError) as excinfo:
normalize_and_validate([(b"foo bar", b"baz")])
assert "foo bar" in str(excinfo.value)
with pytest.raises(LocalProtocolError):
normalize_and_validate([(b"foo\x00bar", b"baz")])
# Not even 8-bit characters:
with pytest.raises(LocalProtocolError):
normalize_and_validate([(b"foo\xffbar", b"baz")])
# And not even the control characters we allow in values:
with pytest.raises(LocalProtocolError):
normalize_and_validate([(b"foo\x01bar", b"baz")])
# no return or NUL characters in values
with pytest.raises(LocalProtocolError) as excinfo:
normalize_and_validate([("foo", "bar\rbaz")])
assert "bar\\rbaz" in str(excinfo.value)
with pytest.raises(LocalProtocolError):
normalize_and_validate([("foo", "bar\nbaz")])
with pytest.raises(LocalProtocolError):
normalize_and_validate([("foo", "bar\x00baz")])
# no leading/trailing whitespace
with pytest.raises(LocalProtocolError):
normalize_and_validate([("foo", "barbaz ")])
with pytest.raises(LocalProtocolError):
normalize_and_validate([("foo", " barbaz")])
with pytest.raises(LocalProtocolError):
normalize_and_validate([("foo", "barbaz\t")])
with pytest.raises(LocalProtocolError):
normalize_and_validate([("foo", "\tbarbaz")])
# content-length
assert normalize_and_validate([("Content-Length", "1")]) == [
(b"content-length", b"1")
]
with pytest.raises(LocalProtocolError):
normalize_and_validate([("Content-Length", "asdf")])
with pytest.raises(LocalProtocolError):
normalize_and_validate([("Content-Length", "1x")])
with pytest.raises(LocalProtocolError):
normalize_and_validate([("Content-Length", "1"), ("Content-Length", "2")])
assert normalize_and_validate(
[("Content-Length", "0"), ("Content-Length", "0")]
) == [(b"content-length", b"0")]
assert normalize_and_validate([("Content-Length", "0 , 0")]) == [
(b"content-length", b"0")
]
with pytest.raises(LocalProtocolError):
normalize_and_validate(
[("Content-Length", "1"), ("Content-Length", "1"), ("Content-Length", "2")]
)
with pytest.raises(LocalProtocolError):
normalize_and_validate([("Content-Length", "1 , 1,2")])
# transfer-encoding
assert normalize_and_validate([("Transfer-Encoding", "chunked")]) == [
(b"transfer-encoding", b"chunked")
]
assert normalize_and_validate([("Transfer-Encoding", "cHuNkEd")]) == [
(b"transfer-encoding", b"chunked")
]
with pytest.raises(LocalProtocolError) as excinfo:
normalize_and_validate([("Transfer-Encoding", "gzip")])
assert excinfo.value.error_status_hint == 501 # Not Implemented
with pytest.raises(LocalProtocolError) as excinfo:
normalize_and_validate(
[("Transfer-Encoding", "chunked"), ("Transfer-Encoding", "gzip")]
)
assert excinfo.value.error_status_hint == 501 # Not Implemented
def test_get_set_comma_header() -> None:
headers = normalize_and_validate(
[
("Connection", "close"),
("whatever", "something"),
("connectiON", "fOo,, , BAR"),
]
)
assert get_comma_header(headers, b"connection") == [b"close", b"foo", b"bar"]
headers = set_comma_header(headers, b"newthing", ["a", "b"]) # type: ignore
with pytest.raises(LocalProtocolError):
set_comma_header(headers, b"newthing", [" a", "b"]) # type: ignore
assert headers == [
(b"connection", b"close"),
(b"whatever", b"something"),
(b"connection", b"fOo,, , BAR"),
(b"newthing", b"a"),
(b"newthing", b"b"),
]
headers = set_comma_header(headers, b"whatever", ["different thing"]) # type: ignore
assert headers == [
(b"connection", b"close"),
(b"connection", b"fOo,, , BAR"),
(b"newthing", b"a"),
(b"newthing", b"b"),
(b"whatever", b"different thing"),
]
def test_has_100_continue() -> None:
assert has_expect_100_continue(
Request(
method="GET",
target="/",
headers=[("Host", "example.com"), ("Expect", "100-continue")],
)
)
assert not has_expect_100_continue(
Request(method="GET", target="/", headers=[("Host", "example.com")])
)
# Case insensitive
assert has_expect_100_continue(
Request(
method="GET",
target="/",
headers=[("Host", "example.com"), ("Expect", "100-Continue")],
)
)
# Doesn't work in HTTP/1.0
assert not has_expect_100_continue(
Request(
method="GET",
target="/",
headers=[("Host", "example.com"), ("Expect", "100-continue")],
http_version="1.0",
)
)


@ -0,0 +1,32 @@
from .._events import (
ConnectionClosed,
Data,
EndOfMessage,
Event,
InformationalResponse,
Request,
Response,
)
from .helpers import normalize_data_events
def test_normalize_data_events() -> None:
assert normalize_data_events(
[
Data(data=bytearray(b"1")),
Data(data=b"2"),
Response(status_code=200, headers=[]), # type: ignore[arg-type]
Data(data=b"3"),
Data(data=b"4"),
EndOfMessage(),
Data(data=b"5"),
Data(data=b"6"),
Data(data=b"7"),
]
) == [
Data(data=b"12"),
Response(status_code=200, headers=[]), # type: ignore[arg-type]
Data(data=b"34"),
EndOfMessage(),
Data(data=b"567"),
]


@ -0,0 +1,572 @@
from typing import Any, Callable, Generator, List
import pytest
from .._events import (
ConnectionClosed,
Data,
EndOfMessage,
Event,
InformationalResponse,
Request,
Response,
)
from .._headers import Headers, normalize_and_validate
from .._readers import (
_obsolete_line_fold,
ChunkedReader,
ContentLengthReader,
Http10Reader,
READERS,
)
from .._receivebuffer import ReceiveBuffer
from .._state import (
CLIENT,
CLOSED,
DONE,
IDLE,
MIGHT_SWITCH_PROTOCOL,
MUST_CLOSE,
SEND_BODY,
SEND_RESPONSE,
SERVER,
SWITCHED_PROTOCOL,
)
from .._util import LocalProtocolError
from .._writers import (
ChunkedWriter,
ContentLengthWriter,
Http10Writer,
write_any_response,
write_headers,
write_request,
WRITERS,
)
from .helpers import normalize_data_events
SIMPLE_CASES = [
(
(CLIENT, IDLE),
Request(
method="GET",
target="/a",
headers=[("Host", "foo"), ("Connection", "close")],
),
b"GET /a HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n",
),
(
(SERVER, SEND_RESPONSE),
Response(status_code=200, headers=[("Connection", "close")], reason=b"OK"),
b"HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n",
),
(
(SERVER, SEND_RESPONSE),
Response(status_code=200, headers=[], reason=b"OK"), # type: ignore[arg-type]
b"HTTP/1.1 200 OK\r\n\r\n",
),
(
(SERVER, SEND_RESPONSE),
InformationalResponse(
status_code=101, headers=[("Upgrade", "websocket")], reason=b"Upgrade"
),
b"HTTP/1.1 101 Upgrade\r\nUpgrade: websocket\r\n\r\n",
),
(
(SERVER, SEND_RESPONSE),
InformationalResponse(status_code=101, headers=[], reason=b"Upgrade"), # type: ignore[arg-type]
b"HTTP/1.1 101 Upgrade\r\n\r\n",
),
]
def dowrite(writer: Callable[..., None], obj: Any) -> bytes:
got_list: List[bytes] = []
writer(obj, got_list.append)
return b"".join(got_list)
def tw(writer: Any, obj: Any, expected: Any) -> None:
got = dowrite(writer, obj)
assert got == expected
def makebuf(data: bytes) -> ReceiveBuffer:
buf = ReceiveBuffer()
buf += data
return buf
def tr(reader: Any, data: bytes, expected: Any) -> None:
def check(got: Any) -> None:
assert got == expected
# Headers should always be returned as bytes, not e.g. bytearray
# https://github.com/python-hyper/wsproto/pull/54#issuecomment-377709478
for name, value in getattr(got, "headers", []):
assert type(name) is bytes
assert type(value) is bytes
# Simple: consume whole thing
buf = makebuf(data)
check(reader(buf))
assert not buf
# Incrementally growing buffer
buf = ReceiveBuffer()
for i in range(len(data)):
assert reader(buf) is None
buf += data[i : i + 1]
check(reader(buf))
# Trailing data
buf = makebuf(data)
buf += b"trailing"
check(reader(buf))
assert bytes(buf) == b"trailing"
def test_writers_simple() -> None:
for ((role, state), event, binary) in SIMPLE_CASES:
tw(WRITERS[role, state], event, binary)
def test_readers_simple() -> None:
for ((role, state), event, binary) in SIMPLE_CASES:
tr(READERS[role, state], binary, event)
def test_writers_unusual() -> None:
# Simple test of the write_headers utility routine
tw(
write_headers,
normalize_and_validate([("foo", "bar"), ("baz", "quux")]),
b"foo: bar\r\nbaz: quux\r\n\r\n",
)
tw(write_headers, Headers([]), b"\r\n")
# We understand HTTP/1.0, but we don't speak it
with pytest.raises(LocalProtocolError):
tw(
write_request,
Request(
method="GET",
target="/",
headers=[("Host", "foo"), ("Connection", "close")],
http_version="1.0",
),
None,
)
with pytest.raises(LocalProtocolError):
tw(
write_any_response,
Response(
status_code=200, headers=[("Connection", "close")], http_version="1.0"
),
None,
)
def test_readers_unusual() -> None:
# Reading HTTP/1.0
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.0\r\nSome: header\r\n\r\n",
Request(
method="HEAD",
target="/foo",
headers=[("Some", "header")],
http_version="1.0",
),
)
# check no-headers, since it's only legal with HTTP/1.0
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.0\r\n\r\n",
Request(method="HEAD", target="/foo", headers=[], http_version="1.0"), # type: ignore[arg-type]
)
tr(
READERS[SERVER, SEND_RESPONSE],
b"HTTP/1.0 200 OK\r\nSome: header\r\n\r\n",
Response(
status_code=200,
headers=[("Some", "header")],
http_version="1.0",
reason=b"OK",
),
)
# header values made of single characters separated by spaces (technically
# disallowed by the ABNF in RFC 7230 -- this is a bug in the standard that
# we originally copied...)
tr(
READERS[SERVER, SEND_RESPONSE],
b"HTTP/1.0 200 OK\r\n" b"Foo: a a a a a \r\n\r\n",
Response(
status_code=200,
headers=[("Foo", "a a a a a")],
http_version="1.0",
reason=b"OK",
),
)
# Empty headers -- also legal
tr(
READERS[SERVER, SEND_RESPONSE],
b"HTTP/1.0 200 OK\r\n" b"Foo:\r\n\r\n",
Response(
status_code=200, headers=[("Foo", "")], http_version="1.0", reason=b"OK"
),
)
tr(
READERS[SERVER, SEND_RESPONSE],
b"HTTP/1.0 200 OK\r\n" b"Foo: \t \t \r\n\r\n",
Response(
status_code=200, headers=[("Foo", "")], http_version="1.0", reason=b"OK"
),
)
# Tolerate broken servers that leave off the reason phrase
tr(
READERS[SERVER, SEND_RESPONSE],
b"HTTP/1.0 200\r\n" b"Foo: bar\r\n\r\n",
Response(
status_code=200, headers=[("Foo", "bar")], http_version="1.0", reason=b""
),
)
# Tolerate header line endings (\r\n and \n)
# \n\r\n between headers and body
tr(
READERS[SERVER, SEND_RESPONSE],
b"HTTP/1.1 200 OK\r\nSomeHeader: val\n\r\n",
Response(
status_code=200,
headers=[("SomeHeader", "val")],
http_version="1.1",
reason="OK",
),
)
# delimited only with \n
tr(
READERS[SERVER, SEND_RESPONSE],
b"HTTP/1.1 200 OK\nSomeHeader1: val1\nSomeHeader2: val2\n\n",
Response(
status_code=200,
headers=[("SomeHeader1", "val1"), ("SomeHeader2", "val2")],
http_version="1.1",
reason="OK",
),
)
# mixed \r\n and \n
tr(
READERS[SERVER, SEND_RESPONSE],
b"HTTP/1.1 200 OK\r\nSomeHeader1: val1\nSomeHeader2: val2\n\r\n",
Response(
status_code=200,
headers=[("SomeHeader1", "val1"), ("SomeHeader2", "val2")],
http_version="1.1",
reason="OK",
),
)
# obsolete line folding
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Some: multi-line\r\n"
b" header\r\n"
b"\tnonsense\r\n"
b" \t \t\tI guess\r\n"
b"Connection: close\r\n"
b"More-nonsense: in the\r\n"
b" last header \r\n\r\n",
Request(
method="HEAD",
target="/foo",
headers=[
("Host", "example.com"),
("Some", "multi-line header nonsense I guess"),
("Connection", "close"),
("More-nonsense", "in the last header"),
],
),
)
with pytest.raises(LocalProtocolError):
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.1\r\n" b" folded: line\r\n\r\n",
None,
)
with pytest.raises(LocalProtocolError):
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.1\r\n" b"foo : line\r\n\r\n",
None,
)
with pytest.raises(LocalProtocolError):
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.1\r\n" b"foo\t: line\r\n\r\n",
None,
)
with pytest.raises(LocalProtocolError):
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.1\r\n" b"foo\t: line\r\n\r\n",
None,
)
with pytest.raises(LocalProtocolError):
tr(READERS[CLIENT, IDLE], b"HEAD /foo HTTP/1.1\r\n" b": line\r\n\r\n", None)
def test__obsolete_line_fold_bytes() -> None:
# _obsolete_line_fold has a defensive cast to bytearray, which is
# necessary to protect against O(n^2) behavior in case anyone ever passes
# in regular bytestrings... but right now we never pass in regular
# bytestrings. so this test just exists to get some coverage on that
# defensive cast.
assert list(_obsolete_line_fold([b"aaa", b"bbb", b" ccc", b"ddd"])) == [
b"aaa",
bytearray(b"bbb ccc"),
b"ddd",
]
def _run_reader_iter(
reader: Any, buf: bytes, do_eof: bool
) -> Generator[Any, None, None]:
while True:
event = reader(buf)
if event is None:
break
yield event
# body readers have undefined behavior after returning EndOfMessage,
# because this changes the state so they don't get called again
if type(event) is EndOfMessage:
break
if do_eof:
assert not buf
yield reader.read_eof()
def _run_reader(*args: Any) -> List[Event]:
events = list(_run_reader_iter(*args))
return normalize_data_events(events)
def t_body_reader(thunk: Any, data: bytes, expected: Any, do_eof: bool = False) -> None:
# Simple: consume whole thing
print("Test 1")
buf = makebuf(data)
assert _run_reader(thunk(), buf, do_eof) == expected
# Incrementally growing buffer
print("Test 2")
reader = thunk()
buf = ReceiveBuffer()
events = []
for i in range(len(data)):
events += _run_reader(reader, buf, False)
buf += data[i : i + 1]
events += _run_reader(reader, buf, do_eof)
assert normalize_data_events(events) == expected
is_complete = any(type(event) is EndOfMessage for event in expected)
if is_complete and not do_eof:
buf = makebuf(data + b"trailing")
assert _run_reader(thunk(), buf, False) == expected
def test_ContentLengthReader() -> None:
t_body_reader(lambda: ContentLengthReader(0), b"", [EndOfMessage()])
t_body_reader(
lambda: ContentLengthReader(10),
b"0123456789",
[Data(data=b"0123456789"), EndOfMessage()],
)
def test_Http10Reader() -> None:
t_body_reader(Http10Reader, b"", [EndOfMessage()], do_eof=True)
t_body_reader(Http10Reader, b"asdf", [Data(data=b"asdf")], do_eof=False)
t_body_reader(
Http10Reader, b"asdf", [Data(data=b"asdf"), EndOfMessage()], do_eof=True
)
def test_ChunkedReader() -> None:
t_body_reader(ChunkedReader, b"0\r\n\r\n", [EndOfMessage()])
t_body_reader(
ChunkedReader,
b"0\r\nSome: header\r\n\r\n",
[EndOfMessage(headers=[("Some", "header")])],
)
t_body_reader(
ChunkedReader,
b"5\r\n01234\r\n"
+ b"10\r\n0123456789abcdef\r\n"
+ b"0\r\n"
+ b"Some: header\r\n\r\n",
[
Data(data=b"012340123456789abcdef"),
EndOfMessage(headers=[("Some", "header")]),
],
)
t_body_reader(
ChunkedReader,
b"5\r\n01234\r\n" + b"10\r\n0123456789abcdef\r\n" + b"0\r\n\r\n",
[Data(data=b"012340123456789abcdef"), EndOfMessage()],
)
# handles upper and lowercase hex
t_body_reader(
ChunkedReader,
b"aA\r\n" + b"x" * 0xAA + b"\r\n" + b"0\r\n\r\n",
[Data(data=b"x" * 0xAA), EndOfMessage()],
)
# refuses arbitrarily long chunk integers
with pytest.raises(LocalProtocolError):
# Technically this is legal HTTP/1.1, but we refuse to process chunk
# sizes that don't fit into 20 characters of hex
t_body_reader(ChunkedReader, b"9" * 100 + b"\r\nxxx", [Data(data=b"xxx")])
# refuses garbage in the chunk count
with pytest.raises(LocalProtocolError):
t_body_reader(ChunkedReader, b"10\x00\r\nxxx", None)
# handles (and discards) "chunk extensions" omg wtf
t_body_reader(
ChunkedReader,
b"5; hello=there\r\n"
+ b"xxxxx"
+ b"\r\n"
+ b'0; random="junk"; some=more; canbe=lonnnnngg\r\n\r\n',
[Data(data=b"xxxxx"), EndOfMessage()],
)
t_body_reader(
ChunkedReader,
b"5 \r\n01234\r\n" + b"0\r\n\r\n",
[Data(data=b"01234"), EndOfMessage()],
)
def test_ContentLengthWriter() -> None:
w = ContentLengthWriter(5)
assert dowrite(w, Data(data=b"123")) == b"123"
assert dowrite(w, Data(data=b"45")) == b"45"
assert dowrite(w, EndOfMessage()) == b""
w = ContentLengthWriter(5)
with pytest.raises(LocalProtocolError):
dowrite(w, Data(data=b"123456"))
w = ContentLengthWriter(5)
dowrite(w, Data(data=b"123"))
with pytest.raises(LocalProtocolError):
dowrite(w, Data(data=b"456"))
w = ContentLengthWriter(5)
dowrite(w, Data(data=b"123"))
with pytest.raises(LocalProtocolError):
dowrite(w, EndOfMessage())
w = ContentLengthWriter(5)
dowrite(w, Data(data=b"123")) == b"123"
dowrite(w, Data(data=b"45")) == b"45"
with pytest.raises(LocalProtocolError):
dowrite(w, EndOfMessage(headers=[("Etag", "asdf")]))
def test_ChunkedWriter() -> None:
w = ChunkedWriter()
assert dowrite(w, Data(data=b"aaa")) == b"3\r\naaa\r\n"
assert dowrite(w, Data(data=b"a" * 20)) == b"14\r\n" + b"a" * 20 + b"\r\n"
assert dowrite(w, Data(data=b"")) == b""
assert dowrite(w, EndOfMessage()) == b"0\r\n\r\n"
assert (
dowrite(w, EndOfMessage(headers=[("Etag", "asdf"), ("a", "b")]))
== b"0\r\nEtag: asdf\r\na: b\r\n\r\n"
)
def test_Http10Writer() -> None:
w = Http10Writer()
assert dowrite(w, Data(data=b"1234")) == b"1234"
assert dowrite(w, EndOfMessage()) == b""
with pytest.raises(LocalProtocolError):
dowrite(w, EndOfMessage(headers=[("Etag", "asdf")]))
def test_reject_garbage_after_request_line() -> None:
with pytest.raises(LocalProtocolError):
tr(READERS[SERVER, SEND_RESPONSE], b"HTTP/1.0 200 OK\x00xxxx\r\n\r\n", None)
def test_reject_garbage_after_response_line() -> None:
with pytest.raises(LocalProtocolError):
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.1 xxxxxx\r\n" b"Host: a\r\n\r\n",
None,
)
def test_reject_garbage_in_header_line() -> None:
with pytest.raises(LocalProtocolError):
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.1\r\n" b"Host: foo\x00bar\r\n\r\n",
None,
)
def test_reject_non_vchar_in_path() -> None:
for bad_char in b"\x00\x20\x7f\xee":
message = bytearray(b"HEAD /")
message.append(bad_char)
message.extend(b" HTTP/1.1\r\nHost: foobar\r\n\r\n")
with pytest.raises(LocalProtocolError):
tr(READERS[CLIENT, IDLE], message, None)
# https://github.com/python-hyper/h11/issues/57
def test_allow_some_garbage_in_cookies() -> None:
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.1\r\n"
b"Host: foo\r\n"
b"Set-Cookie: ___utmvafIumyLc=kUd\x01UpAt; path=/; Max-Age=900\r\n"
b"\r\n",
Request(
method="HEAD",
target="/foo",
headers=[
("Host", "foo"),
("Set-Cookie", "___utmvafIumyLc=kUd\x01UpAt; path=/; Max-Age=900"),
],
),
)
def test_host_comes_first() -> None:
tw(
write_headers,
normalize_and_validate([("foo", "bar"), ("Host", "example.com")]),
b"Host: example.com\r\nfoo: bar\r\n\r\n",
)

View File

@ -0,0 +1,135 @@
import re
from typing import Tuple
import pytest
from .._receivebuffer import ReceiveBuffer
def test_receivebuffer() -> None:
b = ReceiveBuffer()
assert not b
assert len(b) == 0
assert bytes(b) == b""
b += b"123"
assert b
assert len(b) == 3
assert bytes(b) == b"123"
assert bytes(b) == b"123"
assert b.maybe_extract_at_most(2) == b"12"
assert b
assert len(b) == 1
assert bytes(b) == b"3"
assert bytes(b) == b"3"
assert b.maybe_extract_at_most(10) == b"3"
assert bytes(b) == b""
assert b.maybe_extract_at_most(10) is None
assert not b
################################################################
# maybe_extract_until_next
################################################################
b += b"123\n456\r\n789\r\n"
assert b.maybe_extract_next_line() == b"123\n456\r\n"
assert bytes(b) == b"789\r\n"
assert b.maybe_extract_next_line() == b"789\r\n"
assert bytes(b) == b""
b += b"12\r"
assert b.maybe_extract_next_line() is None
assert bytes(b) == b"12\r"
b += b"345\n\r"
assert b.maybe_extract_next_line() is None
assert bytes(b) == b"12\r345\n\r"
# here we stopped in the middle of the b"\r\n" delimiter
b += b"\n6789aaa123\r\n"
assert b.maybe_extract_next_line() == b"12\r345\n\r\n"
assert b.maybe_extract_next_line() == b"6789aaa123\r\n"
assert b.maybe_extract_next_line() is None
assert bytes(b) == b""
################################################################
# maybe_extract_lines
################################################################
b += b"123\r\na: b\r\nfoo:bar\r\n\r\ntrailing"
lines = b.maybe_extract_lines()
assert lines == [b"123", b"a: b", b"foo:bar"]
assert bytes(b) == b"trailing"
assert b.maybe_extract_lines() is None
b += b"\r\n\r"
assert b.maybe_extract_lines() is None
assert b.maybe_extract_at_most(100) == b"trailing\r\n\r"
assert not b
# Empty body case (as happens, e.g., at the end of chunked encoding if
# there are no trailing headers)
b += b"\r\ntrailing"
assert b.maybe_extract_lines() == []
assert bytes(b) == b"trailing"
@pytest.mark.parametrize(
"data",
[
pytest.param(
(
b"HTTP/1.1 200 OK\r\n",
b"Content-type: text/plain\r\n",
b"Connection: close\r\n",
b"\r\n",
b"Some body",
),
id="with_crlf_delimiter",
),
pytest.param(
(
b"HTTP/1.1 200 OK\n",
b"Content-type: text/plain\n",
b"Connection: close\n",
b"\n",
b"Some body",
),
id="with_lf_only_delimiter",
),
pytest.param(
(
b"HTTP/1.1 200 OK\n",
b"Content-type: text/plain\r\n",
b"Connection: close\n",
b"\n",
b"Some body",
),
id="with_mixed_crlf_and_lf",
),
],
)
def test_receivebuffer_for_invalid_delimiter(data: Tuple[bytes, ...]) -> None:
b = ReceiveBuffer()
for line in data:
b += line
lines = b.maybe_extract_lines()
assert lines == [
b"HTTP/1.1 200 OK",
b"Content-type: text/plain",
b"Connection: close",
]
assert bytes(b) == b"Some body"

View File

@ -0,0 +1,271 @@
import pytest
from .._events import (
ConnectionClosed,
Data,
EndOfMessage,
Event,
InformationalResponse,
Request,
Response,
)
from .._state import (
_SWITCH_CONNECT,
_SWITCH_UPGRADE,
CLIENT,
CLOSED,
ConnectionState,
DONE,
IDLE,
MIGHT_SWITCH_PROTOCOL,
MUST_CLOSE,
SEND_BODY,
SEND_RESPONSE,
SERVER,
SWITCHED_PROTOCOL,
)
from .._util import LocalProtocolError
def test_ConnectionState() -> None:
cs = ConnectionState()
# Basic event-triggered transitions
assert cs.states == {CLIENT: IDLE, SERVER: IDLE}
cs.process_event(CLIENT, Request)
# The SERVER-Request special case:
assert cs.states == {CLIENT: SEND_BODY, SERVER: SEND_RESPONSE}
# Illegal transitions raise an error and nothing happens
with pytest.raises(LocalProtocolError):
cs.process_event(CLIENT, Request)
assert cs.states == {CLIENT: SEND_BODY, SERVER: SEND_RESPONSE}
cs.process_event(SERVER, InformationalResponse)
assert cs.states == {CLIENT: SEND_BODY, SERVER: SEND_RESPONSE}
cs.process_event(SERVER, Response)
assert cs.states == {CLIENT: SEND_BODY, SERVER: SEND_BODY}
cs.process_event(CLIENT, EndOfMessage)
cs.process_event(SERVER, EndOfMessage)
assert cs.states == {CLIENT: DONE, SERVER: DONE}
# State-triggered transition
cs.process_event(SERVER, ConnectionClosed)
assert cs.states == {CLIENT: MUST_CLOSE, SERVER: CLOSED}
def test_ConnectionState_keep_alive() -> None:
# keep_alive = False
cs = ConnectionState()
cs.process_event(CLIENT, Request)
cs.process_keep_alive_disabled()
cs.process_event(CLIENT, EndOfMessage)
assert cs.states == {CLIENT: MUST_CLOSE, SERVER: SEND_RESPONSE}
cs.process_event(SERVER, Response)
cs.process_event(SERVER, EndOfMessage)
assert cs.states == {CLIENT: MUST_CLOSE, SERVER: MUST_CLOSE}
def test_ConnectionState_keep_alive_in_DONE() -> None:
# Check that if keep_alive is disabled when the CLIENT is already in DONE,
# then this is sufficient to immediately trigger the DONE -> MUST_CLOSE
# transition
cs = ConnectionState()
cs.process_event(CLIENT, Request)
cs.process_event(CLIENT, EndOfMessage)
assert cs.states[CLIENT] is DONE
cs.process_keep_alive_disabled()
assert cs.states[CLIENT] is MUST_CLOSE
def test_ConnectionState_switch_denied() -> None:
for switch_type in (_SWITCH_CONNECT, _SWITCH_UPGRADE):
for deny_early in (True, False):
cs = ConnectionState()
cs.process_client_switch_proposal(switch_type)
cs.process_event(CLIENT, Request)
cs.process_event(CLIENT, Data)
assert cs.states == {CLIENT: SEND_BODY, SERVER: SEND_RESPONSE}
assert switch_type in cs.pending_switch_proposals
if deny_early:
# before client reaches DONE
cs.process_event(SERVER, Response)
assert not cs.pending_switch_proposals
cs.process_event(CLIENT, EndOfMessage)
if deny_early:
assert cs.states == {CLIENT: DONE, SERVER: SEND_BODY}
else:
assert cs.states == {
CLIENT: MIGHT_SWITCH_PROTOCOL,
SERVER: SEND_RESPONSE,
}
cs.process_event(SERVER, InformationalResponse)
assert cs.states == {
CLIENT: MIGHT_SWITCH_PROTOCOL,
SERVER: SEND_RESPONSE,
}
cs.process_event(SERVER, Response)
assert cs.states == {CLIENT: DONE, SERVER: SEND_BODY}
assert not cs.pending_switch_proposals
_response_type_for_switch = {
_SWITCH_UPGRADE: InformationalResponse,
_SWITCH_CONNECT: Response,
None: Response,
}
def test_ConnectionState_protocol_switch_accepted() -> None:
for switch_event in [_SWITCH_UPGRADE, _SWITCH_CONNECT]:
cs = ConnectionState()
cs.process_client_switch_proposal(switch_event)
cs.process_event(CLIENT, Request)
cs.process_event(CLIENT, Data)
assert cs.states == {CLIENT: SEND_BODY, SERVER: SEND_RESPONSE}
cs.process_event(CLIENT, EndOfMessage)
assert cs.states == {CLIENT: MIGHT_SWITCH_PROTOCOL, SERVER: SEND_RESPONSE}
cs.process_event(SERVER, InformationalResponse)
assert cs.states == {CLIENT: MIGHT_SWITCH_PROTOCOL, SERVER: SEND_RESPONSE}
cs.process_event(SERVER, _response_type_for_switch[switch_event], switch_event)
assert cs.states == {CLIENT: SWITCHED_PROTOCOL, SERVER: SWITCHED_PROTOCOL}
def test_ConnectionState_double_protocol_switch() -> None:
# CONNECT + Upgrade is legal! Very silly, but legal. So we support
# it. Because sometimes doing the silly thing is easier than not.
for server_switch in [None, _SWITCH_UPGRADE, _SWITCH_CONNECT]:
cs = ConnectionState()
cs.process_client_switch_proposal(_SWITCH_UPGRADE)
cs.process_client_switch_proposal(_SWITCH_CONNECT)
cs.process_event(CLIENT, Request)
cs.process_event(CLIENT, EndOfMessage)
assert cs.states == {CLIENT: MIGHT_SWITCH_PROTOCOL, SERVER: SEND_RESPONSE}
cs.process_event(
SERVER, _response_type_for_switch[server_switch], server_switch
)
if server_switch is None:
assert cs.states == {CLIENT: DONE, SERVER: SEND_BODY}
else:
assert cs.states == {CLIENT: SWITCHED_PROTOCOL, SERVER: SWITCHED_PROTOCOL}
def test_ConnectionState_inconsistent_protocol_switch() -> None:
for client_switches, server_switch in [
([], _SWITCH_CONNECT),
([], _SWITCH_UPGRADE),
([_SWITCH_UPGRADE], _SWITCH_CONNECT),
([_SWITCH_CONNECT], _SWITCH_UPGRADE),
]:
cs = ConnectionState()
for client_switch in client_switches: # type: ignore[attr-defined]
cs.process_client_switch_proposal(client_switch)
cs.process_event(CLIENT, Request)
with pytest.raises(LocalProtocolError):
cs.process_event(SERVER, Response, server_switch)
def test_ConnectionState_keepalive_protocol_switch_interaction() -> None:
# keep_alive=False + pending_switch_proposals
cs = ConnectionState()
cs.process_client_switch_proposal(_SWITCH_UPGRADE)
cs.process_event(CLIENT, Request)
cs.process_keep_alive_disabled()
cs.process_event(CLIENT, Data)
assert cs.states == {CLIENT: SEND_BODY, SERVER: SEND_RESPONSE}
# the protocol switch "wins"
cs.process_event(CLIENT, EndOfMessage)
assert cs.states == {CLIENT: MIGHT_SWITCH_PROTOCOL, SERVER: SEND_RESPONSE}
# but when the server denies the request, keep_alive comes back into play
cs.process_event(SERVER, Response)
assert cs.states == {CLIENT: MUST_CLOSE, SERVER: SEND_BODY}
def test_ConnectionState_reuse() -> None:
cs = ConnectionState()
with pytest.raises(LocalProtocolError):
cs.start_next_cycle()
cs.process_event(CLIENT, Request)
cs.process_event(CLIENT, EndOfMessage)
with pytest.raises(LocalProtocolError):
cs.start_next_cycle()
cs.process_event(SERVER, Response)
cs.process_event(SERVER, EndOfMessage)
cs.start_next_cycle()
assert cs.states == {CLIENT: IDLE, SERVER: IDLE}
# No keepalive
cs.process_event(CLIENT, Request)
cs.process_keep_alive_disabled()
cs.process_event(CLIENT, EndOfMessage)
cs.process_event(SERVER, Response)
cs.process_event(SERVER, EndOfMessage)
with pytest.raises(LocalProtocolError):
cs.start_next_cycle()
# One side closed
cs = ConnectionState()
cs.process_event(CLIENT, Request)
cs.process_event(CLIENT, EndOfMessage)
cs.process_event(CLIENT, ConnectionClosed)
cs.process_event(SERVER, Response)
cs.process_event(SERVER, EndOfMessage)
with pytest.raises(LocalProtocolError):
cs.start_next_cycle()
# Successful protocol switch
cs = ConnectionState()
cs.process_client_switch_proposal(_SWITCH_UPGRADE)
cs.process_event(CLIENT, Request)
cs.process_event(CLIENT, EndOfMessage)
cs.process_event(SERVER, InformationalResponse, _SWITCH_UPGRADE)
with pytest.raises(LocalProtocolError):
cs.start_next_cycle()
# Failed protocol switch
cs = ConnectionState()
cs.process_client_switch_proposal(_SWITCH_UPGRADE)
cs.process_event(CLIENT, Request)
cs.process_event(CLIENT, EndOfMessage)
cs.process_event(SERVER, Response)
cs.process_event(SERVER, EndOfMessage)
cs.start_next_cycle()
assert cs.states == {CLIENT: IDLE, SERVER: IDLE}
def test_server_request_is_illegal() -> None:
# There used to be a bug in how we handled the Request special case that
# allowed this...
cs = ConnectionState()
with pytest.raises(LocalProtocolError):
cs.process_event(SERVER, Request)

View File

@ -0,0 +1,112 @@
import re
import sys
import traceback
from typing import NoReturn
import pytest
from .._util import (
bytesify,
LocalProtocolError,
ProtocolError,
RemoteProtocolError,
Sentinel,
validate,
)
def test_ProtocolError() -> None:
with pytest.raises(TypeError):
ProtocolError("abstract base class")
def test_LocalProtocolError() -> None:
try:
raise LocalProtocolError("foo")
except LocalProtocolError as e:
assert str(e) == "foo"
assert e.error_status_hint == 400
try:
raise LocalProtocolError("foo", error_status_hint=418)
except LocalProtocolError as e:
assert str(e) == "foo"
assert e.error_status_hint == 418
def thunk() -> NoReturn:
raise LocalProtocolError("a", error_status_hint=420)
try:
try:
thunk()
except LocalProtocolError as exc1:
orig_traceback = "".join(traceback.format_tb(sys.exc_info()[2]))
exc1._reraise_as_remote_protocol_error()
except RemoteProtocolError as exc2:
assert type(exc2) is RemoteProtocolError
assert exc2.args == ("a",)
assert exc2.error_status_hint == 420
new_traceback = "".join(traceback.format_tb(sys.exc_info()[2]))
assert new_traceback.endswith(orig_traceback)
def test_validate() -> None:
my_re = re.compile(rb"(?P<group1>[0-9]+)\.(?P<group2>[0-9]+)")
with pytest.raises(LocalProtocolError):
validate(my_re, b"0.")
groups = validate(my_re, b"0.1")
assert groups == {"group1": b"0", "group2": b"1"}
# successful partial matches are an error - must match whole string
with pytest.raises(LocalProtocolError):
validate(my_re, b"0.1xx")
with pytest.raises(LocalProtocolError):
validate(my_re, b"0.1\n")
def test_validate_formatting() -> None:
my_re = re.compile(rb"foo")
with pytest.raises(LocalProtocolError) as excinfo:
validate(my_re, b"", "oops")
assert "oops" in str(excinfo.value)
with pytest.raises(LocalProtocolError) as excinfo:
validate(my_re, b"", "oops {}")
assert "oops {}" in str(excinfo.value)
with pytest.raises(LocalProtocolError) as excinfo:
validate(my_re, b"", "oops {} xx", 10)
assert "oops 10 xx" in str(excinfo.value)
def test_make_sentinel() -> None:
class S(Sentinel, metaclass=Sentinel):
pass
assert repr(S) == "S"
assert S == S
assert type(S).__name__ == "S"
assert S in {S}
assert type(S) is S
class S2(Sentinel, metaclass=Sentinel):
pass
assert repr(S2) == "S2"
assert S != S2
assert S not in {S2}
assert type(S) is not type(S2)
def test_bytesify() -> None:
assert bytesify(b"123") == b"123"
assert bytesify(bytearray(b"123")) == b"123"
assert bytesify("123") == b"123"
with pytest.raises(UnicodeEncodeError):
bytesify("\u1234")
with pytest.raises(TypeError):
bytesify(10)

View File

@ -0,0 +1 @@
pip

View File

@ -0,0 +1,20 @@
The MIT License (MIT)
Copyright (c) 2015 Miguel Grinberg
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -0,0 +1,48 @@
Metadata-Version: 2.1
Name: python-engineio
Version: 4.11.2
Summary: Engine.IO server and client for Python
Author-email: Miguel Grinberg <miguel.grinberg@gmail.com>
Project-URL: Homepage, https://github.com/miguelgrinberg/python-engineio
Project-URL: Bug Tracker, https://github.com/miguelgrinberg/python-engineio/issues
Classifier: Environment :: Web Environment
Classifier: Intended Audience :: Developers
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.6
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: simple-websocket>=0.10.0
Provides-Extra: client
Requires-Dist: requests>=2.21.0; extra == "client"
Requires-Dist: websocket-client>=0.54.0; extra == "client"
Provides-Extra: asyncio-client
Requires-Dist: aiohttp>=3.4; extra == "asyncio-client"
Provides-Extra: docs
Requires-Dist: sphinx; extra == "docs"
python-engineio
===============
[![Build status](https://github.com/miguelgrinberg/python-engineio/workflows/build/badge.svg)](https://github.com/miguelgrinberg/python-engineio/actions) [![codecov](https://codecov.io/gh/miguelgrinberg/python-engineio/branch/main/graph/badge.svg)](https://codecov.io/gh/miguelgrinberg/python-engineio)
Python implementation of the `Engine.IO` realtime client and server.
Sponsors
--------
The following organizations are funding this project:
![Socket.IO](https://images.opencollective.com/socketio/050e5eb/logo/64.png)<br>[Socket.IO](https://socket.io) | [Add your company here!](https://github.com/sponsors/miguelgrinberg)|
-|-
Many individual sponsors also support this project through small ongoing contributions. Why not [join them](https://github.com/sponsors/miguelgrinberg)?
Resources
---------
- [Documentation](https://python-engineio.readthedocs.io/)
- [PyPI](https://pypi.python.org/pypi/python-engineio)
- [Change Log](https://github.com/miguelgrinberg/python-engineio/blob/main/CHANGES.md)
- Questions? See the [questions](https://stackoverflow.com/questions/tagged/python-socketio) others have asked on Stack Overflow, or [ask](https://stackoverflow.com/questions/ask?tags=python+python-socketio) your own question.
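A minimal client sketch (the `http://localhost:5000` URL is a placeholder, and the threaded client shown here needs the `client` extra listed in the metadata above):

```python
import engineio

eio = engineio.Client()

@eio.on('connect')
def on_connect():
    print('connection established')

@eio.on('message')
def on_message(data):
    # Print whatever the server pushes to us.
    print('message received:', data)

eio.connect('http://localhost:5000')  # placeholder server address
eio.wait()
```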

View File

@ -0,0 +1,58 @@
engineio/__init__.py,sha256=0R2PY1EXu3sicP7mkA0_QxEVGRlFlgvsxfhByqREE1A,481
engineio/__pycache__/__init__.cpython-310.pyc,,
engineio/__pycache__/async_client.cpython-310.pyc,,
engineio/__pycache__/async_server.cpython-310.pyc,,
engineio/__pycache__/async_socket.cpython-310.pyc,,
engineio/__pycache__/base_client.cpython-310.pyc,,
engineio/__pycache__/base_server.cpython-310.pyc,,
engineio/__pycache__/base_socket.cpython-310.pyc,,
engineio/__pycache__/client.cpython-310.pyc,,
engineio/__pycache__/exceptions.cpython-310.pyc,,
engineio/__pycache__/json.cpython-310.pyc,,
engineio/__pycache__/middleware.cpython-310.pyc,,
engineio/__pycache__/packet.cpython-310.pyc,,
engineio/__pycache__/payload.cpython-310.pyc,,
engineio/__pycache__/server.cpython-310.pyc,,
engineio/__pycache__/socket.cpython-310.pyc,,
engineio/__pycache__/static_files.cpython-310.pyc,,
engineio/async_client.py,sha256=ZrK9j_sUNKRwjqeT6W26d1TtjoGAp2nULhIbX_1iivs,29446
engineio/async_drivers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
engineio/async_drivers/__pycache__/__init__.cpython-310.pyc,,
engineio/async_drivers/__pycache__/_websocket_wsgi.cpython-310.pyc,,
engineio/async_drivers/__pycache__/aiohttp.cpython-310.pyc,,
engineio/async_drivers/__pycache__/asgi.cpython-310.pyc,,
engineio/async_drivers/__pycache__/eventlet.cpython-310.pyc,,
engineio/async_drivers/__pycache__/gevent.cpython-310.pyc,,
engineio/async_drivers/__pycache__/gevent_uwsgi.cpython-310.pyc,,
engineio/async_drivers/__pycache__/sanic.cpython-310.pyc,,
engineio/async_drivers/__pycache__/threading.cpython-310.pyc,,
engineio/async_drivers/__pycache__/tornado.cpython-310.pyc,,
engineio/async_drivers/_websocket_wsgi.py,sha256=LuOEfKhbAw8SplB5PMpYKIUqfCPEadQEpqeiq_leOIA,949
engineio/async_drivers/aiohttp.py,sha256=OBDGhaNXWHxQkwhzZT2vlTAOqWReGS6Sjk9u3BEh_Mc,3754
engineio/async_drivers/asgi.py,sha256=mBu109j7R6esH-wI62jvTnfVAEjWvw2YNp0dZ74NyHg,11210
engineio/async_drivers/eventlet.py,sha256=n1y4OjPdj4J2GIep5N56O29oa5NQgFJVcTBjyO1C-Gs,1735
engineio/async_drivers/gevent.py,sha256=hnJHeWdDQE2jfoLCP5DnwVPzsQlcTLJUMA5EVf1UL-k,2962
engineio/async_drivers/gevent_uwsgi.py,sha256=m6ay5dov9FDQl0fbeiKeE-Orh5LiF6zLlYQ64Oa3T5g,5954
engineio/async_drivers/sanic.py,sha256=GYX8YWR1GbRm-GkMTAQkfkWbY12MOT1IV2DzH0Xx8Ns,4495
engineio/async_drivers/threading.py,sha256=ywmG59d4H6OHZjKarBN97-9BHEsRxFEz9YN-E9QAu_I,463
engineio/async_drivers/tornado.py,sha256=mbVHs1mECfzFSNv33uigkpTBtNPT0u49k5zaybewdIo,5893
engineio/async_server.py,sha256=8Af_uwf8mKOCJGqVKa3QN0kz7_B2K1cdpSmsvkAoBR4,27412
engineio/async_socket.py,sha256=nHY0DPPk0FtI9djUnQWtzZ3ce2OD184Tu-Dop7JLg9I,10715
engineio/base_client.py,sha256=uEI6OglvhdIGD1DAZuptbn0KXQMk_o3w0AH4wJvPwZ0,5322
engineio/base_server.py,sha256=H-aKQ9hMQjiIrjDaqbw_edw9yUVIbCQMzyIy9egLv-g,14458
engineio/base_socket.py,sha256=sQqbNSfGhMQG3xzwar6IXMal28C7Q5TIAQRGp74Wt2o,399
engineio/client.py,sha256=wfgB_MrpEh-v4dBt6W2eW6i0CRNEP0CIJ6d4Ap4CnZE,27227
engineio/exceptions.py,sha256=FyuMb5qhX9CUYP3fEoe1m-faU96ApdQTSbblaaoo8LA,292
engineio/json.py,sha256=SG5FTojqd1ix6u0dKXJsZVqqdYioZLO4S2GPL7BKl3U,405
engineio/middleware.py,sha256=5NKBXz-ftuFErUB_V9IDvRHaSOsjhtW-NnuJtquB1nc,3750
engineio/packet.py,sha256=Tejm9U5JcYs5LwZ_n_Xh0PIRv-U_JbHwGEivNXQN4eg,3181
engineio/payload.py,sha256=GIWu0Vnay4WNZlDxHqVgP34tKTBXX58OArJ-mO5zD3E,1539
engineio/server.py,sha256=4_xdtH0tyusigQurwOeDFrHtodthFTOIbOVZ-_ckh5U,22957
engineio/socket.py,sha256=Oaw1E7ZDyOCaS7KV151g1u9rqOf1JJHh5gttEU0cSeA,10342
engineio/static_files.py,sha256=pwez9LQFaSQXMbtI0vLyD6UDiokQ4rNfmRYgVLKOthc,2064
python_engineio-4.11.2.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
python_engineio-4.11.2.dist-info/LICENSE,sha256=yel9Pbwfu82094CLKCzWRtuIev9PUxP-a76NTDFAWpw,1082
python_engineio-4.11.2.dist-info/METADATA,sha256=WfrX2DOOQdxr6FoiFDIuNe8lorFgUdilJYBfWuD2-Kw,2237
python_engineio-4.11.2.dist-info/RECORD,,
python_engineio-4.11.2.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
python_engineio-4.11.2.dist-info/top_level.txt,sha256=u8PmNisCZLwRYcWrNLe9wutQ2tt4zNi8IH362c-HWuA,9

View File

@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: setuptools (75.6.0)
Root-Is-Purelib: true
Tag: py3-none-any

View File

@ -0,0 +1 @@
engineio

View File

@ -0,0 +1 @@
pip

View File

@ -0,0 +1,20 @@
The MIT License (MIT)
Copyright (c) 2015 Miguel Grinberg
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -0,0 +1,71 @@
Metadata-Version: 2.1
Name: python-socketio
Version: 5.12.1
Summary: Socket.IO server and client for Python
Author-email: Miguel Grinberg <miguel.grinberg@gmail.com>
Project-URL: Homepage, https://github.com/miguelgrinberg/python-socketio
Project-URL: Bug Tracker, https://github.com/miguelgrinberg/python-socketio/issues
Classifier: Environment :: Web Environment
Classifier: Intended Audience :: Developers
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: bidict>=0.21.0
Requires-Dist: python-engineio>=4.11.0
Provides-Extra: client
Requires-Dist: requests>=2.21.0; extra == "client"
Requires-Dist: websocket-client>=0.54.0; extra == "client"
Provides-Extra: asyncio-client
Requires-Dist: aiohttp>=3.4; extra == "asyncio-client"
Provides-Extra: docs
Requires-Dist: sphinx; extra == "docs"
python-socketio
===============
[![Build status](https://github.com/miguelgrinberg/python-socketio/workflows/build/badge.svg)](https://github.com/miguelgrinberg/python-socketio/actions) [![codecov](https://codecov.io/gh/miguelgrinberg/python-socketio/branch/main/graph/badge.svg)](https://codecov.io/gh/miguelgrinberg/python-socketio)
Python implementation of the `Socket.IO` realtime client and server.
Sponsors
--------
The following organizations are funding this project:
![Socket.IO](https://images.opencollective.com/socketio/050e5eb/logo/64.png)<br>[Socket.IO](https://socket.io) | [Add your company here!](https://github.com/sponsors/miguelgrinberg)|
-|-
Many individual sponsors also support this project through small ongoing contributions. Why not [join them](https://github.com/sponsors/miguelgrinberg)?
Version compatibility
---------------------
The Socket.IO protocol has been through a number of revisions, and some of these
introduced backward incompatible changes, which means that the client and the
server must use compatible versions for everything to work.
If you are using the Python client and server, the easiest way to ensure compatibility
is to use the same version of this package for the client and the server. If you are
using this package with a different client or server, then you must ensure the
versions are compatible.
The version compatibility chart below maps versions of this package to versions
of the JavaScript reference implementation and the versions of the Socket.IO and
Engine.IO protocols.
JavaScript Socket.IO version | Socket.IO protocol revision | Engine.IO protocol revision | python-socketio version
-|-|-|-
0.9.x | 1, 2 | 1, 2 | Not supported
1.x and 2.x | 3, 4 | 3 | 4.x
3.x and 4.x | 5 | 4 | 5.x
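A quick way to confirm which pairing is installed is to read the package versions with the standard library (a small sketch; compare the output against the chart above):

```python
from importlib.metadata import version

# python-socketio 5.x pairs with JavaScript Socket.IO 3.x/4.x
# (Socket.IO protocol 5, Engine.IO protocol 4).
print(version("python-socketio"))  # e.g. 5.12.1
print(version("python-engineio"))  # e.g. 4.11.2
```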
Resources
---------
- [Documentation](http://python-socketio.readthedocs.io/)
- [PyPI](https://pypi.python.org/pypi/python-socketio)
- [Change Log](https://github.com/miguelgrinberg/python-socketio/blob/main/CHANGES.md)
- Questions? See the [questions](https://stackoverflow.com/questions/tagged/python-socketio) others have asked on Stack Overflow, or [ask](https://stackoverflow.com/questions/ask?tags=python+python-socketio) your own question.
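For reference, a minimal client sketch (the URL is a placeholder; the `client` extra listed above provides the required transports):

```python
import socketio

sio = socketio.Client()

@sio.event
def connect():
    print('connected')

@sio.on('message')
def on_message(data):
    print('received:', data)

sio.connect('http://localhost:5000')  # placeholder server address
sio.wait()
```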

View File

@ -0,0 +1,68 @@
python_socketio-5.12.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
python_socketio-5.12.1.dist-info/LICENSE,sha256=yel9Pbwfu82094CLKCzWRtuIev9PUxP-a76NTDFAWpw,1082
python_socketio-5.12.1.dist-info/METADATA,sha256=Fv_RBJ7M_Ob-O245jQ0Z4TSY7kDEEoww7do3s6BfFqY,3205
python_socketio-5.12.1.dist-info/RECORD,,
python_socketio-5.12.1.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
python_socketio-5.12.1.dist-info/top_level.txt,sha256=xWd-HVUanhys_VzQQTRTRZBX8W448ayFytYf1Zffivs,9
socketio/__init__.py,sha256=DXxtwPIqHFIqV4BGTgJ86OvCXD6Mth3PxBYhFoJ1_7g,1269
socketio/__pycache__/__init__.cpython-310.pyc,,
socketio/__pycache__/admin.cpython-310.pyc,,
socketio/__pycache__/asgi.cpython-310.pyc,,
socketio/__pycache__/async_admin.cpython-310.pyc,,
socketio/__pycache__/async_aiopika_manager.cpython-310.pyc,,
socketio/__pycache__/async_client.cpython-310.pyc,,
socketio/__pycache__/async_manager.cpython-310.pyc,,
socketio/__pycache__/async_namespace.cpython-310.pyc,,
socketio/__pycache__/async_pubsub_manager.cpython-310.pyc,,
socketio/__pycache__/async_redis_manager.cpython-310.pyc,,
socketio/__pycache__/async_server.cpython-310.pyc,,
socketio/__pycache__/async_simple_client.cpython-310.pyc,,
socketio/__pycache__/base_client.cpython-310.pyc,,
socketio/__pycache__/base_manager.cpython-310.pyc,,
socketio/__pycache__/base_namespace.cpython-310.pyc,,
socketio/__pycache__/base_server.cpython-310.pyc,,
socketio/__pycache__/client.cpython-310.pyc,,
socketio/__pycache__/exceptions.cpython-310.pyc,,
socketio/__pycache__/kafka_manager.cpython-310.pyc,,
socketio/__pycache__/kombu_manager.cpython-310.pyc,,
socketio/__pycache__/manager.cpython-310.pyc,,
socketio/__pycache__/middleware.cpython-310.pyc,,
socketio/__pycache__/msgpack_packet.cpython-310.pyc,,
socketio/__pycache__/namespace.cpython-310.pyc,,
socketio/__pycache__/packet.cpython-310.pyc,,
socketio/__pycache__/pubsub_manager.cpython-310.pyc,,
socketio/__pycache__/redis_manager.cpython-310.pyc,,
socketio/__pycache__/server.cpython-310.pyc,,
socketio/__pycache__/simple_client.cpython-310.pyc,,
socketio/__pycache__/tornado.cpython-310.pyc,,
socketio/__pycache__/zmq_manager.cpython-310.pyc,,
socketio/admin.py,sha256=pfZ7ZtcZ9-aeaFZkOR4mFhsNPcy9WjZs4_5Os6xc9tA,15966
socketio/asgi.py,sha256=NaJtYhOswVVcwHU0zcMM5H5TrSzXq9K-CAYaeSNTZRY,2192
socketio/async_admin.py,sha256=opwgGfkREXb_T25FL7At6hkC3hTfY33bDooyNi1Dgvw,16317
socketio/async_aiopika_manager.py,sha256=DaBUjGRYaNIsOsk2xNjWylUsz2egmTAFFUiQkV6mNmk,5193
socketio/async_client.py,sha256=iVXDsHiU9aohwE2QkSwUOtU8GYivCZRapolJMCWeCPY,27810
socketio/async_manager.py,sha256=dSD2XVtWYwKHDWxAXSu4Xgqw6dXyy9P_6C8rwlguybM,4503
socketio/async_namespace.py,sha256=pSyJjIekWgydsmQHxmJvuc_NdI8SMGjGTAatLUtRvAk,12028
socketio/async_pubsub_manager.py,sha256=Dzt34zwWgxqGsB_61_hegSlTSZucciHX6aJrEPSuKos,11141
socketio/async_redis_manager.py,sha256=UZXKunvbSk8neRVhGqigQF5S0WwLYTKV0BKondnV_yY,4299
socketio/async_server.py,sha256=YrZ69AN1i8hK-TMZGtRiD6UnoQk_zwl2amHYaKk_1uI,36382
socketio/async_simple_client.py,sha256=Dj2h0iRR1qZ4BhOV6gpzvDM0K5XO4f-vdxmISiREzhQ,8908
socketio/base_client.py,sha256=AKwZprl7qwgdOaQwV2drBNx9bB3PBCyABm6HKton-w4,11637
socketio/base_manager.py,sha256=vmHGHlIUDJTCdp9MIFppqFJJuoN2M1MmEWTTyV35FeY,5727
socketio/base_namespace.py,sha256=mXECdZZ7jPLphU9yH4U4yOayqjMh6OyWgZ71mOJzl5A,970
socketio/base_server.py,sha256=JtHtmxFjtclcdORg7FIBoMtMxiaCFnuwulXrpLUSjUE,10637
socketio/client.py,sha256=gE8NH3oZrdwTMQN1j-D3J_opGZlYCxPyMO6m3rjFDC0,26040
socketio/exceptions.py,sha256=c8yKss_oJl-fkL52X_AagyJecL-9Mxlgb5xDRqSz5tA,975
socketio/kafka_manager.py,sha256=BbpNbEus0DCFXaohBAXlKoV2IHU8RhbGzpkL9QcqQNM,2388
socketio/kombu_manager.py,sha256=MhDhnbZoncW5_Y02Ojhu8qFUFdT7STZDnLPsMUARuik,5748
socketio/manager.py,sha256=RPYPcVBFAjN-fEtLfcsPlk6SOW_SBATvw0Tkq_PkGZw,3861
socketio/middleware.py,sha256=P8wOgSzy3YKOcRVI-r3KNKsEejBz_f5p2wdV8ZqW12E,1591
socketio/msgpack_packet.py,sha256=0K_XXM-OF3SdqOaLN_O5B4a1xHE6N_UhhiaRhQdseNw,514
socketio/namespace.py,sha256=80y8BN2FFlHK8JKF1TirWvvE4pn9FkGKk14IVFkCLEs,9488
socketio/packet.py,sha256=nYvjUEIEUMHThZj--xrmRCZX9jN1V9BwFB2GzRpDLWU,7069
socketio/pubsub_manager.py,sha256=JCB9aaEBbEw8Or6XaosoSpO-f6p5iF_BnNJOCul7ps4,10442
socketio/redis_manager.py,sha256=DIvqRXjsSsmvXYBwuRvEap70IFyJILLaicj1X2Hssug,4403
socketio/server.py,sha256=laukqFmlQK24bCmGtMP9KGGGUP8CqebTO4_SJeZrMGY,34788
socketio/simple_client.py,sha256=tZiX2sAPY66OJTIJPk-PIGQjmnmUxu3RnpgJ0nc1-y8,8326
socketio/tornado.py,sha256=R82JCqz-E1ibZAQX708h7FX3sguCHQ1OLYpnMag-LY8,295
socketio/zmq_manager.py,sha256=PVlx175_MqKQ6j0sqGpqbqN2vW5zf4BzviotbBQpdEE,3544

View File

@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: setuptools (75.6.0)
Root-Is-Purelib: true
Tag: py3-none-any

View File

@ -0,0 +1 @@
socketio

View File

@ -0,0 +1 @@
pip

View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2021 Miguel Grinberg
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,37 @@
Metadata-Version: 2.1
Name: simple-websocket
Version: 1.1.0
Summary: Simple WebSocket server and client for Python
Author-email: Miguel Grinberg <miguel.grinberg@gmail.com>
Project-URL: Homepage, https://github.com/miguelgrinberg/simple-websocket
Project-URL: Bug Tracker, https://github.com/miguelgrinberg/simple-websocket/issues
Classifier: Environment :: Web Environment
Classifier: Intended Audience :: Developers
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.6
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: wsproto
Provides-Extra: dev
Requires-Dist: tox ; extra == 'dev'
Requires-Dist: flake8 ; extra == 'dev'
Requires-Dist: pytest ; extra == 'dev'
Requires-Dist: pytest-cov ; extra == 'dev'
Provides-Extra: docs
Requires-Dist: sphinx ; extra == 'docs'
simple-websocket
================
[![Build status](https://github.com/miguelgrinberg/simple-websocket/workflows/build/badge.svg)](https://github.com/miguelgrinberg/simple-websocket/actions) [![codecov](https://codecov.io/gh/miguelgrinberg/simple-websocket/branch/main/graph/badge.svg)](https://codecov.io/gh/miguelgrinberg/simple-websocket)
Simple WebSocket server and client for Python.
## Resources
- [Documentation](http://simple-websocket.readthedocs.io/en/latest/)
- [PyPI](https://pypi.python.org/pypi/simple-websocket)
- [Change Log](https://github.com/miguelgrinberg/simple-websocket/blob/main/CHANGES.md)
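A minimal asyncio client sketch, based on the `AioClient.connect()` API vendored later in this commit (the `ws://localhost:8000/echo` endpoint is a placeholder):

```python
import asyncio
from simple_websocket import AioClient

async def main():
    ws = await AioClient.connect('ws://localhost:8000/echo')
    await ws.send('hello')              # str -> text frame, bytes -> binary
    print(await ws.receive(timeout=5))  # wait up to five seconds
    await ws.close()

asyncio.run(main())
```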

View File

@ -0,0 +1,16 @@
simple_websocket-1.1.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
simple_websocket-1.1.0.dist-info/LICENSE,sha256=S4q63MXj3SnHGQW4SVKUVpnwp7pB5q-Z6rpG-qvpW7c,1072
simple_websocket-1.1.0.dist-info/METADATA,sha256=jIZUFRCbg8Ae1BNwEipAlSMvuPJ05_bXpptlk2DHiNQ,1530
simple_websocket-1.1.0.dist-info/RECORD,,
simple_websocket-1.1.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
simple_websocket-1.1.0.dist-info/top_level.txt,sha256=gslMtkYd2H3exn9JQxdAgsKBCESZDyTxmukAF9Iz5aA,17
simple_websocket/__init__.py,sha256=EKakMkVO9vg5WlXjHEJiTwI2emAqs9q22ZxJz9vJ4co,167
simple_websocket/__pycache__/__init__.cpython-310.pyc,,
simple_websocket/__pycache__/aiows.cpython-310.pyc,,
simple_websocket/__pycache__/asgi.cpython-310.pyc,,
simple_websocket/__pycache__/errors.cpython-310.pyc,,
simple_websocket/__pycache__/ws.cpython-310.pyc,,
simple_websocket/aiows.py,sha256=CHIBIAN2cz004S4tPeTLAcQuT9iBgw6-hA0QD_JZD1A,20978
simple_websocket/asgi.py,sha256=ic2tmrUI-u9vjMNzjqIORc8g7pAsGwFd9YJIjppHHVU,1823
simple_websocket/errors.py,sha256=BtR8B4OI-FL2O_VSIi9cmLMobHqcJ2FhvQnRtvvMlSo,652
simple_websocket/ws.py,sha256=Nj7DSMnUhOXGYI9j5wvJMpm5X_c7iNDg0H7EpJQPb9o,22789

View File

@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: setuptools (75.1.0)
Root-Is-Purelib: true
Tag: py3-none-any

View File

@ -0,0 +1 @@
simple_websocket

View File

@ -0,0 +1,3 @@
from .ws import Server, Client # noqa: F401
from .aiows import AioServer, AioClient # noqa: F401
from .errors import ConnectionError, ConnectionClosed # noqa: F401

View File

@ -0,0 +1,467 @@
import asyncio
import ssl
from time import time
from urllib.parse import urlsplit
from wsproto import ConnectionType, WSConnection
from wsproto.events import (
AcceptConnection,
RejectConnection,
CloseConnection,
Message,
Request,
Ping,
Pong,
TextMessage,
BytesMessage,
)
from wsproto.extensions import PerMessageDeflate
from wsproto.frame_protocol import CloseReason
from wsproto.utilities import LocalProtocolError
from .errors import ConnectionError, ConnectionClosed
class AioBase:
def __init__(self, connection_type=None, receive_bytes=4096,
ping_interval=None, max_message_size=None):
#: The name of the subprotocol chosen for the WebSocket connection.
self.subprotocol = None
self.connection_type = connection_type
self.receive_bytes = receive_bytes
self.ping_interval = ping_interval
self.max_message_size = max_message_size
self.pong_received = True
self.input_buffer = []
self.incoming_message = None
self.incoming_message_len = 0
self.connected = False
self.is_server = (connection_type == ConnectionType.SERVER)
self.close_reason = CloseReason.NO_STATUS_RCVD
self.close_message = None
self.rsock = None
self.wsock = None
self.event = asyncio.Event()
self.ws = None
self.task = None
async def connect(self):
self.ws = WSConnection(self.connection_type)
await self.handshake()
if not self.connected: # pragma: no cover
raise ConnectionError()
self.task = asyncio.create_task(self._task())
async def handshake(self): # pragma: no cover
# to be implemented by subclasses
pass
async def send(self, data):
"""Send data over the WebSocket connection.
:param data: The data to send. If ``data`` is of type ``bytes``, then
a binary message is sent. Else, the message is sent in
text format.
"""
if not self.connected:
raise ConnectionClosed(self.close_reason, self.close_message)
if isinstance(data, bytes):
out_data = self.ws.send(Message(data=data))
else:
out_data = self.ws.send(TextMessage(data=str(data)))
self.wsock.write(out_data)
async def receive(self, timeout=None):
"""Receive data over the WebSocket connection.
:param timeout: Amount of time to wait for the data, in seconds. Set
to ``None`` (the default) to wait indefinitely. Set
to 0 to read without blocking.
The data received is returned, as ``bytes`` or ``str``, depending on
the type of the incoming message.
"""
while self.connected and not self.input_buffer:
try:
await asyncio.wait_for(self.event.wait(), timeout=timeout)
except asyncio.TimeoutError:
return None
self.event.clear() # pragma: no cover
try:
return self.input_buffer.pop(0)
except IndexError:
pass
if not self.connected: # pragma: no cover
raise ConnectionClosed(self.close_reason, self.close_message)
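# Usage sketch (illustrative): timeout=0 polls without blocking and
# returns None when nothing is queued; timeout=None waits indefinitely.
#
#     data = await ws.receive(timeout=0)   # None if no message queued
#     data = await ws.receive(timeout=5)   # wait up to five seconds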
async def close(self, reason=None, message=None):
"""Close the WebSocket connection.
:param reason: A numeric status code indicating the reason of the
closure, as defined by the WebSocket specification. The
default is 1000 (normal closure).
:param message: A text message to be sent to the other side.
"""
if not self.connected:
raise ConnectionClosed(self.close_reason, self.close_message)
out_data = self.ws.send(CloseConnection(
reason or CloseReason.NORMAL_CLOSURE, message))
try:
self.wsock.write(out_data)
except BrokenPipeError: # pragma: no cover
pass
self.connected = False
def choose_subprotocol(self, request): # pragma: no cover
# The method should return the subprotocol to use, or ``None`` if no
# subprotocol is chosen. Can be overridden by subclasses that implement
# the server-side of the WebSocket protocol.
return None
async def _task(self):
next_ping = None
if self.ping_interval:
next_ping = time() + self.ping_interval
while self.connected:
try:
in_data = b''
if next_ping:
now = time()
timed_out = True
if next_ping > now:
timed_out = False
try:
in_data = await asyncio.wait_for(
self.rsock.read(self.receive_bytes),
timeout=next_ping - now)
except asyncio.TimeoutError:
timed_out = True
if timed_out:
# we reached the timeout, we have to send a ping
if not self.pong_received:
await self.close(
reason=CloseReason.POLICY_VIOLATION,
message='Ping/Pong timeout')
break
self.pong_received = False
self.wsock.write(self.ws.send(Ping()))
next_ping = max(now, next_ping) + self.ping_interval
continue
else:
in_data = await self.rsock.read(self.receive_bytes)
if len(in_data) == 0:
raise OSError()
except (OSError, ConnectionResetError): # pragma: no cover
self.connected = False
self.event.set()
break
self.ws.receive_data(in_data)
self.connected = await self._handle_events()
self.wsock.close()
async def _handle_events(self):
keep_going = True
out_data = b''
for event in self.ws.events():
try:
if isinstance(event, Request):
self.subprotocol = self.choose_subprotocol(event)
out_data += self.ws.send(AcceptConnection(
subprotocol=self.subprotocol,
extensions=[PerMessageDeflate()]))
elif isinstance(event, CloseConnection):
if self.is_server:
out_data += self.ws.send(event.response())
self.close_reason = event.code
self.close_message = event.reason
self.connected = False
self.event.set()
keep_going = False
elif isinstance(event, Ping):
out_data += self.ws.send(event.response())
elif isinstance(event, Pong):
self.pong_received = True
elif isinstance(event, (TextMessage, BytesMessage)):
self.incoming_message_len += len(event.data)
if self.max_message_size and \
self.incoming_message_len > self.max_message_size:
out_data += self.ws.send(CloseConnection(
CloseReason.MESSAGE_TOO_BIG, 'Message is too big'))
self.event.set()
keep_going = False
break
if self.incoming_message is None:
# store message as is first
# if it is the first of a group, the message will be
# converted to bytearray on arrival of the second
# part, since bytearrays are mutable and can be
# concatenated more efficiently
self.incoming_message = event.data
elif isinstance(event, TextMessage):
if not isinstance(self.incoming_message, bytearray):
# convert to bytearray and append
self.incoming_message = bytearray(
(self.incoming_message + event.data).encode())
else:
# append to bytearray
self.incoming_message += event.data.encode()
else:
if not isinstance(self.incoming_message, bytearray):
# convert to mutable bytearray and append
self.incoming_message = bytearray(
self.incoming_message + event.data)
else:
# append to bytearray
self.incoming_message += event.data
if not event.message_finished:
continue
if isinstance(self.incoming_message, (str, bytes)):
# single part message
self.input_buffer.append(self.incoming_message)
elif isinstance(event, TextMessage):
# convert multi-part message back to text
self.input_buffer.append(
self.incoming_message.decode())
else:
# convert multi-part message back to bytes
self.input_buffer.append(bytes(self.incoming_message))
self.incoming_message = None
self.incoming_message_len = 0
self.event.set()
else: # pragma: no cover
pass
except LocalProtocolError: # pragma: no cover
out_data = b''
self.event.set()
keep_going = False
if out_data:
self.wsock.write(out_data)
return keep_going
class AioServer(AioBase):
"""This class implements a WebSocket server.
Instead of creating an instance of this class directly, use the
``accept()`` class method to create individual instances of the server,
each bound to a client request.
"""
def __init__(self, request, subprotocols=None, receive_bytes=4096,
ping_interval=None, max_message_size=None):
super().__init__(connection_type=ConnectionType.SERVER,
receive_bytes=receive_bytes,
ping_interval=ping_interval,
max_message_size=max_message_size)
self.request = request
self.headers = {}
self.subprotocols = subprotocols or []
if isinstance(self.subprotocols, str):
self.subprotocols = [self.subprotocols]
self.mode = 'unknown'
@classmethod
async def accept(cls, aiohttp=None, asgi=None, sock=None, headers=None,
subprotocols=None, receive_bytes=4096, ping_interval=None,
max_message_size=None):
"""Accept a WebSocket connection from a client.
:param aiohttp: The request object from aiohttp. If this argument is
provided, ``asgi``, ``sock`` and ``headers`` must not
be set.
:param asgi: A (scope, receive, send) tuple from an ASGI request. If
this argument is provided, ``aiohttp``, ``sock`` and
``headers`` must not be set.
:param sock: A connected socket to use. If this argument is provided,
``aiohttp`` and ``asgi`` must not be set. The ``headers``
argument must be set with the incoming request headers.
:param headers: A dictionary with the incoming request headers, when
``sock`` is used.
:param subprotocols: A list of supported subprotocols, or ``None`` (the
default) to disable subprotocol negotiation.
:param receive_bytes: The size of the receive buffer, in bytes. The
default is 4096.
:param ping_interval: Send ping packets to clients at the requested
interval in seconds. Set to ``None`` (the
default) to disable ping/pong logic. Enable to
prevent disconnections when the line is idle for
a certain amount of time, or to detect
unresponsive clients and disconnect them. A
recommended interval is 25 seconds.
:param max_message_size: The maximum size allowed for a message, in
bytes, or ``None`` for no limit. The default
is ``None``.
"""
if aiohttp and (asgi or sock):
raise ValueError('aiohttp argument cannot be used with asgi or '
'sock')
if asgi and (aiohttp or sock):
raise ValueError('asgi argument cannot be used with aiohttp or '
'sock')
if asgi: # pragma: no cover
from .asgi import WebSocketASGI
return await WebSocketASGI.accept(asgi[0], asgi[1], asgi[2],
subprotocols=subprotocols)
ws = cls({'aiohttp': aiohttp, 'sock': sock, 'headers': headers},
subprotocols=subprotocols, receive_bytes=receive_bytes,
ping_interval=ping_interval,
max_message_size=max_message_size)
await ws._accept()
return ws
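# Illustrative aiohttp usage sketch (the handler name is hypothetical,
# not part of this module; a real handler would also catch
# ConnectionClosed when the peer goes away):
#
#     async def echo(request):
#         ws = await AioServer.accept(aiohttp=request, ping_interval=25)
#         while True:
#             data = await ws.receive()
#             await ws.send(data)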
async def _accept(self):
if self.request['sock']: # pragma: no cover
# custom integration, request is a tuple with (socket, headers)
sock = self.request['sock']
self.headers = self.request['headers']
self.mode = 'custom'
elif self.request['aiohttp']:
# default implementation, request is an aiohttp request object
sock = self.request['aiohttp'].transport.get_extra_info(
'socket').dup()
self.headers = self.request['aiohttp'].headers
self.mode = 'aiohttp'
else: # pragma: no cover
raise ValueError('Invalid request')
self.rsock, self.wsock = await asyncio.open_connection(sock=sock)
await super().connect()
async def handshake(self):
in_data = b'GET / HTTP/1.1\r\n'
for header, value in self.headers.items():
in_data += f'{header}: {value}\r\n'.encode()
in_data += b'\r\n'
self.ws.receive_data(in_data)
self.connected = await self._handle_events()
def choose_subprotocol(self, request):
"""Choose a subprotocol to use for the WebSocket connection.
The default implementation selects the first protocol requested by the
client that is accepted by the server. Subclasses can override this
method to implement a different subprotocol negotiation algorithm.
:param request: A ``Request`` object.
The method should return the subprotocol to use, or ``None`` if no
subprotocol is chosen.
"""
for subprotocol in request.subprotocols:
if subprotocol in self.subprotocols:
return subprotocol
return None
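# Sketch (editor's note, not part of the upstream code): the docstring above
# says subclasses may override ``choose_subprotocol``. A subclass that honors
# the server's preference order instead of the client's could look like this
# (``AioServer`` is the assumed name of the class above):
#
#     class ServerPreferredSubprotocols(AioServer):
#         def choose_subprotocol(self, request):
#             for subprotocol in self.subprotocols:
#                 if subprotocol in request.subprotocols:
#                     return subprotocol
#             return None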
class AioClient(AioBase):
"""This class implements a WebSocket client.
Instead of creating an instance of this class directly, use the
``connect()`` class method to create an instance that is connected to a
server.
"""
def __init__(self, url, subprotocols=None, headers=None,
receive_bytes=4096, ping_interval=None, max_message_size=None,
ssl_context=None):
super().__init__(connection_type=ConnectionType.CLIENT,
receive_bytes=receive_bytes,
ping_interval=ping_interval,
max_message_size=max_message_size)
self.url = url
self.ssl_context = ssl_context
parsed_url = urlsplit(url)
self.is_secure = parsed_url.scheme in ['https', 'wss']
self.host = parsed_url.hostname
self.port = parsed_url.port or (443 if self.is_secure else 80)
self.path = parsed_url.path
if parsed_url.query:
self.path += '?' + parsed_url.query
self.subprotocols = subprotocols or []
if isinstance(self.subprotocols, str):
self.subprotocols = [self.subprotocols]
        self.extra_headers = []
        if isinstance(headers, dict):
            for key, value in headers.items():
                self.extra_headers.append((key, value))
        elif isinstance(headers, list):
            self.extra_headers = headers
@classmethod
async def connect(cls, url, subprotocols=None, headers=None,
receive_bytes=4096, ping_interval=None,
max_message_size=None, ssl_context=None,
thread_class=None, event_class=None):
"""Returns a WebSocket client connection.
:param url: The connection URL. Both ``ws://`` and ``wss://`` URLs are
accepted.
:param subprotocols: The name of the subprotocol to use, or a list of
subprotocol names in order of preference. Set to
``None`` (the default) to not use a subprotocol.
:param headers: A dictionary or list of tuples with additional HTTP
headers to send with the connection request. Note that
custom headers are not supported by the WebSocket
protocol, so the use of this parameter is not
recommended.
:param receive_bytes: The size of the receive buffer, in bytes. The
default is 4096.
:param ping_interval: Send ping packets to the server at the requested
interval in seconds. Set to ``None`` (the
default) to disable ping/pong logic. Enable to
prevent disconnections when the line is idle for
a certain amount of time, or to detect an
unresponsive server and disconnect. A recommended
interval is 25 seconds. In general it is
preferred to enable ping/pong on the server, and
let the client respond with pong (which it does
regardless of this setting).
:param max_message_size: The maximum size allowed for a message, in
bytes, or ``None`` for no limit. The default
is ``None``.
:param ssl_context: An ``SSLContext`` instance, if a default SSL
context isn't sufficient.
"""
ws = cls(url, subprotocols=subprotocols, headers=headers,
receive_bytes=receive_bytes, ping_interval=ping_interval,
max_message_size=max_message_size, ssl_context=ssl_context)
await ws._connect()
return ws
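    # Usage sketch (editor's note, not part of the upstream code): connect,
    # send one message, then print replies until the server closes. The URL
    # is illustrative; ``ConnectionClosed`` is assumed to be importable from
    # this package's errors module.
    #
    #     async def main():
    #         ws = await AioClient.connect('ws://localhost:8081',
    #                                      ping_interval=25)
    #         await ws.send('hello')
    #         try:
    #             while True:
    #                 print(await ws.receive())
    #         except ConnectionClosed:
    #             pass
    #
    #     asyncio.run(main())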
async def _connect(self):
if self.is_secure: # pragma: no cover
if self.ssl_context is None:
self.ssl_context = ssl.create_default_context(
purpose=ssl.Purpose.SERVER_AUTH)
self.rsock, self.wsock = await asyncio.open_connection(
self.host, self.port, ssl=self.ssl_context)
await super().connect()
async def handshake(self):
out_data = self.ws.send(Request(host=self.host, target=self.path,
subprotocols=self.subprotocols,
                                        extra_headers=self.extra_headers))
self.wsock.write(out_data)
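        # Feed bytes from the socket into wsproto until it emits the server's
        # handshake response (AcceptConnection or RejectConnection).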
while True:
in_data = await self.rsock.read(self.receive_bytes)
self.ws.receive_data(in_data)
try:
event = next(self.ws.events())
except StopIteration: # pragma: no cover
pass
else: # pragma: no cover
break
if isinstance(event, RejectConnection): # pragma: no cover
raise ConnectionError(event.status_code)
elif not isinstance(event, AcceptConnection): # pragma: no cover
raise ConnectionError(400)
self.subprotocol = event.subprotocol
self.connected = True
async def close(self, reason=None, message=None):
await super().close(reason=reason, message=message)
        self.wsock.close()
@ -0,0 +1,50 @@
from .errors import ConnectionClosed # pragma: no cover
class WebSocketASGI: # pragma: no cover
def __init__(self, scope, receive, send, subprotocols=None):
self._scope = scope
self._receive = receive
self._send = send
self.subprotocols = subprotocols or []
self.subprotocol = None
self.connected = False
@classmethod
async def accept(cls, scope, receive, send, subprotocols=None):
ws = WebSocketASGI(scope, receive, send, subprotocols=subprotocols)
await ws._accept()
return ws
async def _accept(self):
connect = await self._receive()
if connect['type'] != 'websocket.connect':
raise ValueError('Expected websocket.connect')
for subprotocol in self._scope['subprotocols']:
if subprotocol in self.subprotocols:
self.subprotocol = subprotocol
break
        await self._send({'type': 'websocket.accept',
                          'subprotocol': self.subprotocol})
        self.connected = True  # handshake complete; close() may send a close
async def receive(self):
message = await self._receive()
if message['type'] == 'websocket.disconnect':
raise ConnectionClosed()
elif message['type'] != 'websocket.receive':
raise OSError(32, 'Websocket message type not supported')
return message.get('text', message.get('bytes'))
async def send(self, data):
if isinstance(data, str):
await self._send({'type': 'websocket.send', 'text': data})
else:
await self._send({'type': 'websocket.send', 'bytes': data})
    async def close(self):
        if self.connected:
            self.connected = False
            try:
                await self._send({'type': 'websocket.close'})
            except Exception:
                pass
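# Usage sketch (editor's addition, not part of the upstream code): the bare
# ASGI callable shape this class expects; any ASGI server (uvicorn, for
# example) can run it. The app and its echo loop are illustrative.
async def example_asgi_echo(scope, receive, send):
    if scope['type'] == 'websocket':
        ws = await WebSocketASGI.accept(scope, receive, send)
        try:
            while True:
                # receive() returns str for text frames, bytes for binary
                await ws.send(await ws.receive())
        except ConnectionClosed:
            await ws.close()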