Sync: devchat[main](ef593429) Merge pull request #414 from devchat-ai/fix-not-found-custom_git_urls
This commit is contained in:
parent
6ea48f76d1
commit
f16f4ec2d5
@ -1,6 +1,6 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: anyio
|
||||
Version: 4.4.0
|
||||
Version: 4.5.0
|
||||
Summary: High level compatibility layer for multiple asynchronous event loop implementations
|
||||
Author-email: Alex Grönholm <alex.gronholm@nextday.fi>
|
||||
License: MIT
|
||||
@ -20,6 +20,7 @@ Classifier: Programming Language :: Python :: 3.9
|
||||
Classifier: Programming Language :: Python :: 3.10
|
||||
Classifier: Programming Language :: Python :: 3.11
|
||||
Classifier: Programming Language :: Python :: 3.12
|
||||
Classifier: Programming Language :: Python :: 3.13
|
||||
Requires-Python: >=3.8
|
||||
Description-Content-Type: text/x-rst
|
||||
License-File: LICENSE
|
||||
@ -29,7 +30,7 @@ Requires-Dist: exceptiongroup >=1.0.2 ; python_version < "3.11"
|
||||
Requires-Dist: typing-extensions >=4.1 ; python_version < "3.11"
|
||||
Provides-Extra: doc
|
||||
Requires-Dist: packaging ; extra == 'doc'
|
||||
Requires-Dist: Sphinx >=7 ; extra == 'doc'
|
||||
Requires-Dist: Sphinx ~=7.4 ; extra == 'doc'
|
||||
Requires-Dist: sphinx-rtd-theme ; extra == 'doc'
|
||||
Requires-Dist: sphinx-autodoc-typehints >=1.2.0 ; extra == 'doc'
|
||||
Provides-Extra: test
|
||||
@ -41,9 +42,9 @@ Requires-Dist: psutil >=5.9 ; extra == 'test'
|
||||
Requires-Dist: pytest >=7.0 ; extra == 'test'
|
||||
Requires-Dist: pytest-mock >=3.6.1 ; extra == 'test'
|
||||
Requires-Dist: trustme ; extra == 'test'
|
||||
Requires-Dist: uvloop >=0.17 ; (platform_python_implementation == "CPython" and platform_system != "Windows") and extra == 'test'
|
||||
Requires-Dist: uvloop >=0.21.0b1 ; (platform_python_implementation == "CPython" and platform_system != "Windows") and extra == 'test'
|
||||
Provides-Extra: trio
|
||||
Requires-Dist: trio >=0.23 ; extra == 'trio'
|
||||
Requires-Dist: trio >=0.26.1 ; extra == 'trio'
|
||||
|
||||
.. image:: https://github.com/agronholm/anyio/actions/workflows/test.yml/badge.svg
|
||||
:target: https://github.com/agronholm/anyio/actions/workflows/test.yml
|
@ -1,11 +1,11 @@
|
||||
anyio-4.4.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
anyio-4.4.0.dist-info/LICENSE,sha256=U2GsncWPLvX9LpsJxoKXwX8ElQkJu8gCO9uC6s8iwrA,1081
|
||||
anyio-4.4.0.dist-info/METADATA,sha256=sbJaOJ_Ilka4D0U6yKtfOtVrYef7XRFzGjoBEZnRpes,4599
|
||||
anyio-4.4.0.dist-info/RECORD,,
|
||||
anyio-4.4.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
||||
anyio-4.4.0.dist-info/entry_points.txt,sha256=_d6Yu6uiaZmNe0CydowirE9Cmg7zUL2g08tQpoS3Qvc,39
|
||||
anyio-4.4.0.dist-info/top_level.txt,sha256=QglSMiWX8_5dpoVAEIHdEYzvqFMdSYWmCj6tYw2ITkQ,6
|
||||
anyio/__init__.py,sha256=CxUxIHOIONI3KpsDLCg-dI6lQaDkW_4Zhtu5jWt1XO8,4344
|
||||
anyio-4.5.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
anyio-4.5.0.dist-info/LICENSE,sha256=U2GsncWPLvX9LpsJxoKXwX8ElQkJu8gCO9uC6s8iwrA,1081
|
||||
anyio-4.5.0.dist-info/METADATA,sha256=gveTB0gvT7MwEWHWcT9CUmYzS4qgywQPW4gweoJsESs,4658
|
||||
anyio-4.5.0.dist-info/RECORD,,
|
||||
anyio-4.5.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
||||
anyio-4.5.0.dist-info/entry_points.txt,sha256=_d6Yu6uiaZmNe0CydowirE9Cmg7zUL2g08tQpoS3Qvc,39
|
||||
anyio-4.5.0.dist-info/top_level.txt,sha256=QglSMiWX8_5dpoVAEIHdEYzvqFMdSYWmCj6tYw2ITkQ,6
|
||||
anyio/__init__.py,sha256=myTIdg75VPwA-9L7BpislRQplJUPMeleUBHa4MyIruw,4315
|
||||
anyio/__pycache__/__init__.cpython-38.pyc,,
|
||||
anyio/__pycache__/from_thread.cpython-38.pyc,,
|
||||
anyio/__pycache__/lowlevel.cpython-38.pyc,,
|
||||
@ -16,8 +16,8 @@ anyio/_backends/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
anyio/_backends/__pycache__/__init__.cpython-38.pyc,,
|
||||
anyio/_backends/__pycache__/_asyncio.cpython-38.pyc,,
|
||||
anyio/_backends/__pycache__/_trio.cpython-38.pyc,,
|
||||
anyio/_backends/_asyncio.py,sha256=CVy87WpTh1URmEjlE-AKTrBoPwsqH_nRxbGkLnTrfeg,83244
|
||||
anyio/_backends/_trio.py,sha256=8gdA930WJFn4xrMNpMH6LrHC9MeyTdZBHsB_W-8HFBw,35909
|
||||
anyio/_backends/_asyncio.py,sha256=eu1H_onDHe_Oa6_wiHxsQnOmEnxme0F9Dlmf7ykbkT8,88344
|
||||
anyio/_backends/_trio.py,sha256=L2XiIaDruGztveS8SBPMJMjHHEnZDjXHhsbmu00II6Q,39735
|
||||
anyio/_core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
anyio/_core/__pycache__/__init__.cpython-38.pyc,,
|
||||
anyio/_core/__pycache__/_eventloop.cpython-38.pyc,,
|
||||
@ -33,14 +33,14 @@ anyio/_core/__pycache__/_tasks.cpython-38.pyc,,
|
||||
anyio/_core/__pycache__/_testing.cpython-38.pyc,,
|
||||
anyio/_core/__pycache__/_typedattr.cpython-38.pyc,,
|
||||
anyio/_core/_eventloop.py,sha256=t_tAwBFPjF8jrZGjlJ6bbYy6KA3bjsbZxV9mvh9t1i0,4695
|
||||
anyio/_core/_exceptions.py,sha256=wUmhDu80qEB7z9EdCqUwVEhNUlNEok4_W2-rC6sCAUQ,2078
|
||||
anyio/_core/_fileio.py,sha256=fC6H6DcueA-2AUaDkP91kmVeqoU7DlWj6O0CCAdnsdM,19456
|
||||
anyio/_core/_exceptions.py,sha256=NPxECdXkG4nk3NOCUeFmBEAgPhmj7Bzs4vFAKaW_vqw,2481
|
||||
anyio/_core/_fileio.py,sha256=a4Ab0y8UU0vkYYExM_7pLfBcjeqY0GM-4AyQ4EaRn6s,20067
|
||||
anyio/_core/_resources.py,sha256=NbmU5O5UX3xEyACnkmYX28Fmwdl-f-ny0tHym26e0w0,435
|
||||
anyio/_core/_signals.py,sha256=rDOVxtugZDgC5AhfW3lrwsre2n9Pj_adoRUidBiF6dA,878
|
||||
anyio/_core/_sockets.py,sha256=2jOzi4bXQQYTLr9PSrCaeTgwaU_N7mt0yjHGmX4LvA8,24028
|
||||
anyio/_core/_sockets.py,sha256=iM3UeMU68n0PlQjl2U9HyiOpV26rnjqV4KBr_Fo2z1I,24293
|
||||
anyio/_core/_streams.py,sha256=Z8ZlTY6xom5EszrMsgCT3TphiT4JIlQG-y33CrD0NQY,1811
|
||||
anyio/_core/_subprocesses.py,sha256=ZLLNXAtlRGfbyC4sOIltYB1k3NJa3tqk_x_Fsnbcs1M,5272
|
||||
anyio/_core/_synchronization.py,sha256=h3o6dWWbzVrcNmi7i2mQjEgRtnIxkGtjmYK7KMpdlaE,18444
|
||||
anyio/_core/_subprocesses.py,sha256=DysVq3ZEKayWtjKF3gtyteR_q-Q-K7CmpYTjhIEV_CQ,8314
|
||||
anyio/_core/_synchronization.py,sha256=UDsbG5f8jWsWkRxYUOKp_WOBWCI9-vBO6wBrsR6WNjA,20121
|
||||
anyio/_core/_tasks.py,sha256=pvVEX2Fw159sf0ypAPerukKsZgRRwvFFedVW52nR2Vk,4764
|
||||
anyio/_core/_testing.py,sha256=YUGwA5cgFFbUTv4WFd7cv_BSVr4ryTtPp8owQA3JdWE,2118
|
||||
anyio/_core/_typedattr.py,sha256=P4ozZikn3-DbpoYcvyghS_FOYAgbmUxeoU8-L_07pZM,2508
|
||||
@ -53,17 +53,17 @@ anyio/abc/__pycache__/_streams.cpython-38.pyc,,
|
||||
anyio/abc/__pycache__/_subprocesses.cpython-38.pyc,,
|
||||
anyio/abc/__pycache__/_tasks.cpython-38.pyc,,
|
||||
anyio/abc/__pycache__/_testing.cpython-38.pyc,,
|
||||
anyio/abc/_eventloop.py,sha256=r9pldSu6p-ZsvO6D_brc0EIi1JgZRDbfgVLuy7Q7R6o,10085
|
||||
anyio/abc/_eventloop.py,sha256=L0axMx7JdbH5efXSbkMaCYp0fMdk-zeWm-u4jx_nFW4,9593
|
||||
anyio/abc/_resources.py,sha256=DrYvkNN1hH6Uvv5_5uKySvDsnknGVDe8FCKfko0VtN8,783
|
||||
anyio/abc/_sockets.py,sha256=XdZ42TQ1omZN9Ec3HUfTMWG_i-21yMjXQ_FFslAZtzQ,6269
|
||||
anyio/abc/_streams.py,sha256=GzST5Q2zQmxVzdrAqtbSyHNxkPlIC9AzeZJg_YyPAXw,6598
|
||||
anyio/abc/_subprocesses.py,sha256=cumAPJTktOQtw63IqG0lDpyZqu_l1EElvQHMiwJgL08,2067
|
||||
anyio/abc/_tasks.py,sha256=0Jc6oIwUjMIVReehF6knOZyAqlgwDt4TP1NQkx4IQGw,2731
|
||||
anyio/abc/_testing.py,sha256=tBJUzkSfOXJw23fe8qSJ03kJlShOYjjaEyFB6k6MYT8,1821
|
||||
anyio/from_thread.py,sha256=HtgJ7yZ6RLfRe0l0yyyhpg2mnoax0mqXXgKv8TORUlA,17700
|
||||
anyio/from_thread.py,sha256=2tEb3LZeqVLl-WyLm9siIej4mKwcKK5zQIZEXyafC1A,17439
|
||||
anyio/lowlevel.py,sha256=nkgmW--SdxGVp0cmLUYazjkigveRm5HY7-gW8Bpp9oY,4169
|
||||
anyio/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
anyio/pytest_plugin.py,sha256=TBgRAfT-Oxy6efhO1Tziq54NND3Jy4dRmwkMmQXSvhI,5386
|
||||
anyio/pytest_plugin.py,sha256=le89r6YzzM85skagv04dzDJiuPvSbxJVnwkoJZwaoY8,5852
|
||||
anyio/streams/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
anyio/streams/__pycache__/__init__.cpython-38.pyc,,
|
||||
anyio/streams/__pycache__/buffered.cpython-38.pyc,,
|
||||
@ -74,9 +74,9 @@ anyio/streams/__pycache__/text.cpython-38.pyc,,
|
||||
anyio/streams/__pycache__/tls.cpython-38.pyc,,
|
||||
anyio/streams/buffered.py,sha256=UCldKC168YuLvT7n3HtNPnQ2iWAMSTYQWbZvzLwMwkM,4500
|
||||
anyio/streams/file.py,sha256=6uoTNb5KbMoj-6gS3_xrrL8uZN8Q4iIvOS1WtGyFfKw,4383
|
||||
anyio/streams/memory.py,sha256=Y286x16omNSSGONQx5CBLLNiB3vAJb_vVKt5vb3go-Q,10190
|
||||
anyio/streams/memory.py,sha256=j8AyOExK4-UPaon_Xbhwax25Vqs0DwFg3ZXc-EIiHjY,10550
|
||||
anyio/streams/stapled.py,sha256=U09pCrmOw9kkNhe6tKopsm1QIMT1lFTFvtb-A7SIe4k,4302
|
||||
anyio/streams/text.py,sha256=6x8w8xlfCZKTUWQoJiMPoMhSSJFUBRKgoBNSBtbd9yg,5094
|
||||
anyio/streams/tls.py,sha256=ev-6yNOGcIkziIkcIfKj8VmLqQJW-iDBJttaKgKDsF4,12752
|
||||
anyio/to_process.py,sha256=lx_bt0CUJsS1eSlraw662OpCjRgGXowoyf1Q-i-kOxo,9535
|
||||
anyio/to_process.py,sha256=cR4n7TssbbJowE_9cWme49zaeuoBuMzqgZ6cBIs0YIs,9571
|
||||
anyio/to_thread.py,sha256=WM2JQ2MbVsd5D5CM08bQiTwzZIvpsGjfH1Fy247KoDQ,2396
|
@ -1,5 +1,5 @@
|
||||
Wheel-Version: 1.0
|
||||
Generator: setuptools (73.0.1)
|
||||
Generator: setuptools (75.1.0)
|
||||
Root-Is-Purelib: true
|
||||
Tag: py3-none-any
|
||||
|
@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ._core._eventloop import current_time as current_time
|
||||
from ._core._eventloop import get_all_backends as get_all_backends
|
||||
from ._core._eventloop import get_cancelled_exc_class as get_cancelled_exc_class
|
||||
@ -69,8 +67,8 @@ from ._core._typedattr import TypedAttributeSet as TypedAttributeSet
|
||||
from ._core._typedattr import typed_attribute as typed_attribute
|
||||
|
||||
# Re-export imports so they look like they live directly in this package
|
||||
key: str
|
||||
value: Any
|
||||
for key, value in list(locals().items()):
|
||||
if getattr(value, "__module__", "").startswith("anyio."):
|
||||
value.__module__ = __name__
|
||||
for __value in list(locals().values()):
|
||||
if getattr(__value, "__module__", "").startswith("anyio."):
|
||||
__value.__module__ = __name__
|
||||
|
||||
del __value
|
||||
|
@ -4,6 +4,7 @@ import array
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import math
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
@ -19,7 +20,7 @@ from asyncio import (
|
||||
)
|
||||
from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined]
|
||||
from collections import OrderedDict, deque
|
||||
from collections.abc import AsyncIterator, Generator, Iterable
|
||||
from collections.abc import AsyncIterator, Iterable
|
||||
from concurrent.futures import Future
|
||||
from contextlib import suppress
|
||||
from contextvars import Context, copy_context
|
||||
@ -47,7 +48,6 @@ from typing import (
|
||||
Collection,
|
||||
ContextManager,
|
||||
Coroutine,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
@ -58,7 +58,13 @@ from weakref import WeakKeyDictionary
|
||||
|
||||
import sniffio
|
||||
|
||||
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
|
||||
from .. import (
|
||||
CapacityLimiterStatistics,
|
||||
EventStatistics,
|
||||
LockStatistics,
|
||||
TaskInfo,
|
||||
abc,
|
||||
)
|
||||
from .._core._eventloop import claim_worker_thread, threadlocals
|
||||
from .._core._exceptions import (
|
||||
BrokenResourceError,
|
||||
@ -66,12 +72,20 @@ from .._core._exceptions import (
|
||||
ClosedResourceError,
|
||||
EndOfStream,
|
||||
WouldBlock,
|
||||
iterate_exceptions,
|
||||
)
|
||||
from .._core._sockets import convert_ipv6_sockaddr
|
||||
from .._core._streams import create_memory_object_stream
|
||||
from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter
|
||||
from .._core._synchronization import (
|
||||
CapacityLimiter as BaseCapacityLimiter,
|
||||
)
|
||||
from .._core._synchronization import Event as BaseEvent
|
||||
from .._core._synchronization import ResourceGuard
|
||||
from .._core._synchronization import Lock as BaseLock
|
||||
from .._core._synchronization import (
|
||||
ResourceGuard,
|
||||
SemaphoreStatistics,
|
||||
)
|
||||
from .._core._synchronization import Semaphore as BaseSemaphore
|
||||
from .._core._tasks import CancelScope as BaseCancelScope
|
||||
from ..abc import (
|
||||
AsyncBackend,
|
||||
@ -80,6 +94,7 @@ from ..abc import (
|
||||
UDPPacketType,
|
||||
UNIXDatagramPacketType,
|
||||
)
|
||||
from ..abc._eventloop import StrOrBytesPath
|
||||
from ..lowlevel import RunVar
|
||||
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
|
||||
@ -630,16 +645,6 @@ class _AsyncioTaskStatus(abc.TaskStatus):
|
||||
_task_states[task].parent_id = self._parent_id
|
||||
|
||||
|
||||
def iterate_exceptions(
|
||||
exception: BaseException,
|
||||
) -> Generator[BaseException, None, None]:
|
||||
if isinstance(exception, BaseExceptionGroup):
|
||||
for exc in exception.exceptions:
|
||||
yield from iterate_exceptions(exc)
|
||||
else:
|
||||
yield exception
|
||||
|
||||
|
||||
class TaskGroup(abc.TaskGroup):
|
||||
def __init__(self) -> None:
|
||||
self.cancel_scope: CancelScope = CancelScope()
|
||||
@ -925,7 +930,7 @@ class StreamReaderWrapper(abc.ByteReceiveStream):
|
||||
raise EndOfStream
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self._stream.feed_eof()
|
||||
self._stream.set_exception(ClosedResourceError())
|
||||
await AsyncIOBackend.checkpoint()
|
||||
|
||||
|
||||
@ -1073,7 +1078,8 @@ class StreamProtocol(asyncio.Protocol):
|
||||
self.write_event.set()
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self.read_queue.append(data)
|
||||
# ProactorEventloop sometimes sends bytearray instead of bytes
|
||||
self.read_queue.append(bytes(data))
|
||||
self.read_event.set()
|
||||
|
||||
def eof_received(self) -> bool | None:
|
||||
@ -1665,6 +1671,154 @@ class Event(BaseEvent):
|
||||
return EventStatistics(len(self._event._waiters))
|
||||
|
||||
|
||||
class Lock(BaseLock):
|
||||
def __new__(cls, *, fast_acquire: bool = False) -> Lock:
|
||||
return object.__new__(cls)
|
||||
|
||||
def __init__(self, *, fast_acquire: bool = False) -> None:
|
||||
self._fast_acquire = fast_acquire
|
||||
self._owner_task: asyncio.Task | None = None
|
||||
self._waiters: deque[tuple[asyncio.Task, asyncio.Future]] = deque()
|
||||
|
||||
async def acquire(self) -> None:
|
||||
if self._owner_task is None and not self._waiters:
|
||||
await AsyncIOBackend.checkpoint_if_cancelled()
|
||||
self._owner_task = current_task()
|
||||
|
||||
# Unless on the "fast path", yield control of the event loop so that other
|
||||
# tasks can run too
|
||||
if not self._fast_acquire:
|
||||
try:
|
||||
await AsyncIOBackend.cancel_shielded_checkpoint()
|
||||
except CancelledError:
|
||||
self.release()
|
||||
raise
|
||||
|
||||
return
|
||||
|
||||
task = cast(asyncio.Task, current_task())
|
||||
fut: asyncio.Future[None] = asyncio.Future()
|
||||
item = task, fut
|
||||
self._waiters.append(item)
|
||||
try:
|
||||
await fut
|
||||
except CancelledError:
|
||||
self._waiters.remove(item)
|
||||
if self._owner_task is task:
|
||||
self.release()
|
||||
|
||||
raise
|
||||
|
||||
self._waiters.remove(item)
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
if self._owner_task is None and not self._waiters:
|
||||
self._owner_task = current_task()
|
||||
return
|
||||
|
||||
raise WouldBlock
|
||||
|
||||
def locked(self) -> bool:
|
||||
return self._owner_task is not None
|
||||
|
||||
def release(self) -> None:
|
||||
if self._owner_task != current_task():
|
||||
raise RuntimeError("The current task is not holding this lock")
|
||||
|
||||
for task, fut in self._waiters:
|
||||
if not fut.cancelled():
|
||||
self._owner_task = task
|
||||
fut.set_result(None)
|
||||
return
|
||||
|
||||
self._owner_task = None
|
||||
|
||||
def statistics(self) -> LockStatistics:
|
||||
task_info = AsyncIOTaskInfo(self._owner_task) if self._owner_task else None
|
||||
return LockStatistics(self.locked(), task_info, len(self._waiters))
|
||||
|
||||
|
||||
class Semaphore(BaseSemaphore):
|
||||
def __new__(
|
||||
cls,
|
||||
initial_value: int,
|
||||
*,
|
||||
max_value: int | None = None,
|
||||
fast_acquire: bool = False,
|
||||
) -> Semaphore:
|
||||
return object.__new__(cls)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_value: int,
|
||||
*,
|
||||
max_value: int | None = None,
|
||||
fast_acquire: bool = False,
|
||||
):
|
||||
super().__init__(initial_value, max_value=max_value)
|
||||
self._value = initial_value
|
||||
self._max_value = max_value
|
||||
self._fast_acquire = fast_acquire
|
||||
self._waiters: deque[asyncio.Future[None]] = deque()
|
||||
|
||||
async def acquire(self) -> None:
|
||||
if self._value > 0 and not self._waiters:
|
||||
await AsyncIOBackend.checkpoint_if_cancelled()
|
||||
self._value -= 1
|
||||
|
||||
# Unless on the "fast path", yield control of the event loop so that other
|
||||
# tasks can run too
|
||||
if not self._fast_acquire:
|
||||
try:
|
||||
await AsyncIOBackend.cancel_shielded_checkpoint()
|
||||
except CancelledError:
|
||||
self.release()
|
||||
raise
|
||||
|
||||
return
|
||||
|
||||
fut: asyncio.Future[None] = asyncio.Future()
|
||||
self._waiters.append(fut)
|
||||
try:
|
||||
await fut
|
||||
except CancelledError:
|
||||
try:
|
||||
self._waiters.remove(fut)
|
||||
except ValueError:
|
||||
self.release()
|
||||
|
||||
raise
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
if self._value == 0:
|
||||
raise WouldBlock
|
||||
|
||||
self._value -= 1
|
||||
|
||||
def release(self) -> None:
|
||||
if self._max_value is not None and self._value == self._max_value:
|
||||
raise ValueError("semaphore released too many times")
|
||||
|
||||
for fut in self._waiters:
|
||||
if not fut.cancelled():
|
||||
fut.set_result(None)
|
||||
self._waiters.remove(fut)
|
||||
return
|
||||
|
||||
self._value += 1
|
||||
|
||||
@property
|
||||
def value(self) -> int:
|
||||
return self._value
|
||||
|
||||
@property
|
||||
def max_value(self) -> int | None:
|
||||
return self._max_value
|
||||
|
||||
def statistics(self) -> SemaphoreStatistics:
|
||||
return SemaphoreStatistics(len(self._waiters))
|
||||
|
||||
|
||||
class CapacityLimiter(BaseCapacityLimiter):
|
||||
_total_tokens: float = 0
|
||||
|
||||
@ -1861,7 +2015,9 @@ class AsyncIOTaskInfo(TaskInfo):
|
||||
|
||||
if task_state := _task_states.get(task):
|
||||
if cancel_scope := task_state.cancel_scope:
|
||||
return cancel_scope.cancel_called or cancel_scope._parent_cancelled()
|
||||
return cancel_scope.cancel_called or (
|
||||
not cancel_scope.shield and cancel_scope._parent_cancelled()
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
@ -1926,13 +2082,23 @@ class TestRunner(abc.TestRunner):
|
||||
tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
|
||||
],
|
||||
) -> None:
|
||||
from _pytest.outcomes import OutcomeException
|
||||
|
||||
with receive_stream, self._send_stream:
|
||||
async for coro, future in receive_stream:
|
||||
try:
|
||||
retval = await coro
|
||||
except CancelledError as exc:
|
||||
if not future.cancelled():
|
||||
future.cancel(*exc.args)
|
||||
|
||||
raise
|
||||
except BaseException as exc:
|
||||
if not future.cancelled():
|
||||
future.set_exception(exc)
|
||||
|
||||
if not isinstance(exc, (Exception, OutcomeException)):
|
||||
raise
|
||||
else:
|
||||
if not future.cancelled():
|
||||
future.set_result(retval)
|
||||
@ -2113,6 +2279,20 @@ class AsyncIOBackend(AsyncBackend):
|
||||
def create_event(cls) -> abc.Event:
|
||||
return Event()
|
||||
|
||||
@classmethod
|
||||
def create_lock(cls, *, fast_acquire: bool) -> abc.Lock:
|
||||
return Lock(fast_acquire=fast_acquire)
|
||||
|
||||
@classmethod
|
||||
def create_semaphore(
|
||||
cls,
|
||||
initial_value: int,
|
||||
*,
|
||||
max_value: int | None = None,
|
||||
fast_acquire: bool = False,
|
||||
) -> abc.Semaphore:
|
||||
return Semaphore(initial_value, max_value=max_value, fast_acquire=fast_acquire)
|
||||
|
||||
@classmethod
|
||||
def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
|
||||
return CapacityLimiter(total_tokens)
|
||||
@ -2245,26 +2425,24 @@ class AsyncIOBackend(AsyncBackend):
|
||||
@classmethod
|
||||
async def open_process(
|
||||
cls,
|
||||
command: str | bytes | Sequence[str | bytes],
|
||||
command: StrOrBytesPath | Sequence[StrOrBytesPath],
|
||||
*,
|
||||
shell: bool,
|
||||
stdin: int | IO[Any] | None,
|
||||
stdout: int | IO[Any] | None,
|
||||
stderr: int | IO[Any] | None,
|
||||
cwd: str | bytes | PathLike | None = None,
|
||||
env: Mapping[str, str] | None = None,
|
||||
start_new_session: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Process:
|
||||
await cls.checkpoint()
|
||||
if shell:
|
||||
if isinstance(command, PathLike):
|
||||
command = os.fspath(command)
|
||||
|
||||
if isinstance(command, (str, bytes)):
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
cast("str | bytes", command),
|
||||
command,
|
||||
stdin=stdin,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
start_new_session=start_new_session,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
@ -2272,9 +2450,7 @@ class AsyncIOBackend(AsyncBackend):
|
||||
stdin=stdin,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
start_new_session=start_new_session,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
|
||||
@ -2289,7 +2465,7 @@ class AsyncIOBackend(AsyncBackend):
|
||||
name="AnyIO process pool shutdown task",
|
||||
)
|
||||
find_root_task().add_done_callback(
|
||||
partial(_forcibly_shutdown_process_pool_on_exit, workers)
|
||||
partial(_forcibly_shutdown_process_pool_on_exit, workers) # type:ignore[arg-type]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import array
|
||||
import math
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import types
|
||||
@ -25,7 +26,6 @@ from typing import (
|
||||
ContextManager,
|
||||
Coroutine,
|
||||
Generic,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
@ -45,7 +45,14 @@ from trio.lowlevel import (
|
||||
from trio.socket import SocketType as TrioSocketType
|
||||
from trio.to_thread import run_sync
|
||||
|
||||
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
|
||||
from .. import (
|
||||
CapacityLimiterStatistics,
|
||||
EventStatistics,
|
||||
LockStatistics,
|
||||
TaskInfo,
|
||||
WouldBlock,
|
||||
abc,
|
||||
)
|
||||
from .._core._eventloop import claim_worker_thread
|
||||
from .._core._exceptions import (
|
||||
BrokenResourceError,
|
||||
@ -55,12 +62,19 @@ from .._core._exceptions import (
|
||||
)
|
||||
from .._core._sockets import convert_ipv6_sockaddr
|
||||
from .._core._streams import create_memory_object_stream
|
||||
from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter
|
||||
from .._core._synchronization import (
|
||||
CapacityLimiter as BaseCapacityLimiter,
|
||||
)
|
||||
from .._core._synchronization import Event as BaseEvent
|
||||
from .._core._synchronization import ResourceGuard
|
||||
from .._core._synchronization import Lock as BaseLock
|
||||
from .._core._synchronization import (
|
||||
ResourceGuard,
|
||||
SemaphoreStatistics,
|
||||
)
|
||||
from .._core._synchronization import Semaphore as BaseSemaphore
|
||||
from .._core._tasks import CancelScope as BaseCancelScope
|
||||
from ..abc import IPSockAddrType, UDPPacketType, UNIXDatagramPacketType
|
||||
from ..abc._eventloop import AsyncBackend
|
||||
from ..abc._eventloop import AsyncBackend, StrOrBytesPath
|
||||
from ..streams.memory import MemoryObjectSendStream
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
@ -637,6 +651,100 @@ class Event(BaseEvent):
|
||||
self.__original.set()
|
||||
|
||||
|
||||
class Lock(BaseLock):
|
||||
def __new__(cls, *, fast_acquire: bool = False) -> Lock:
|
||||
return object.__new__(cls)
|
||||
|
||||
def __init__(self, *, fast_acquire: bool = False) -> None:
|
||||
self._fast_acquire = fast_acquire
|
||||
self.__original = trio.Lock()
|
||||
|
||||
async def acquire(self) -> None:
|
||||
if not self._fast_acquire:
|
||||
await self.__original.acquire()
|
||||
return
|
||||
|
||||
# This is the "fast path" where we don't let other tasks run
|
||||
await trio.lowlevel.checkpoint_if_cancelled()
|
||||
try:
|
||||
self.__original.acquire_nowait()
|
||||
except trio.WouldBlock:
|
||||
await self.__original._lot.park()
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
try:
|
||||
self.__original.acquire_nowait()
|
||||
except trio.WouldBlock:
|
||||
raise WouldBlock from None
|
||||
|
||||
def locked(self) -> bool:
|
||||
return self.__original.locked()
|
||||
|
||||
def release(self) -> None:
|
||||
self.__original.release()
|
||||
|
||||
def statistics(self) -> LockStatistics:
|
||||
orig_statistics = self.__original.statistics()
|
||||
owner = TrioTaskInfo(orig_statistics.owner) if orig_statistics.owner else None
|
||||
return LockStatistics(
|
||||
orig_statistics.locked, owner, orig_statistics.tasks_waiting
|
||||
)
|
||||
|
||||
|
||||
class Semaphore(BaseSemaphore):
|
||||
def __new__(
|
||||
cls,
|
||||
initial_value: int,
|
||||
*,
|
||||
max_value: int | None = None,
|
||||
fast_acquire: bool = False,
|
||||
) -> Semaphore:
|
||||
return object.__new__(cls)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_value: int,
|
||||
*,
|
||||
max_value: int | None = None,
|
||||
fast_acquire: bool = False,
|
||||
) -> None:
|
||||
super().__init__(initial_value, max_value=max_value, fast_acquire=fast_acquire)
|
||||
self.__original = trio.Semaphore(initial_value, max_value=max_value)
|
||||
|
||||
async def acquire(self) -> None:
|
||||
if not self._fast_acquire:
|
||||
await self.__original.acquire()
|
||||
return
|
||||
|
||||
# This is the "fast path" where we don't let other tasks run
|
||||
await trio.lowlevel.checkpoint_if_cancelled()
|
||||
try:
|
||||
self.__original.acquire_nowait()
|
||||
except trio.WouldBlock:
|
||||
await self.__original._lot.park()
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
try:
|
||||
self.__original.acquire_nowait()
|
||||
except trio.WouldBlock:
|
||||
raise WouldBlock from None
|
||||
|
||||
@property
|
||||
def max_value(self) -> int | None:
|
||||
return self.__original.max_value
|
||||
|
||||
@property
|
||||
def value(self) -> int:
|
||||
return self.__original.value
|
||||
|
||||
def release(self) -> None:
|
||||
self.__original.release()
|
||||
|
||||
def statistics(self) -> SemaphoreStatistics:
|
||||
orig_statistics = self.__original.statistics()
|
||||
return SemaphoreStatistics(orig_statistics.tasks_waiting)
|
||||
|
||||
|
||||
class CapacityLimiter(BaseCapacityLimiter):
|
||||
def __new__(
|
||||
cls,
|
||||
@ -915,6 +1023,20 @@ class TrioBackend(AsyncBackend):
|
||||
def create_event(cls) -> abc.Event:
|
||||
return Event()
|
||||
|
||||
@classmethod
|
||||
def create_lock(cls, *, fast_acquire: bool) -> Lock:
|
||||
return Lock(fast_acquire=fast_acquire)
|
||||
|
||||
@classmethod
|
||||
def create_semaphore(
|
||||
cls,
|
||||
initial_value: int,
|
||||
*,
|
||||
max_value: int | None = None,
|
||||
fast_acquire: bool = False,
|
||||
) -> abc.Semaphore:
|
||||
return Semaphore(initial_value, max_value=max_value, fast_acquire=fast_acquire)
|
||||
|
||||
@classmethod
|
||||
def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter:
|
||||
return CapacityLimiter(total_tokens)
|
||||
@ -967,26 +1089,39 @@ class TrioBackend(AsyncBackend):
|
||||
@classmethod
|
||||
async def open_process(
|
||||
cls,
|
||||
command: str | bytes | Sequence[str | bytes],
|
||||
command: StrOrBytesPath | Sequence[StrOrBytesPath],
|
||||
*,
|
||||
shell: bool,
|
||||
stdin: int | IO[Any] | None,
|
||||
stdout: int | IO[Any] | None,
|
||||
stderr: int | IO[Any] | None,
|
||||
cwd: str | bytes | PathLike | None = None,
|
||||
env: Mapping[str, str] | None = None,
|
||||
start_new_session: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Process:
|
||||
process = await trio.lowlevel.open_process( # type: ignore[misc]
|
||||
command, # type: ignore[arg-type]
|
||||
stdin=stdin,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
shell=shell,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
start_new_session=start_new_session,
|
||||
)
|
||||
def convert_item(item: StrOrBytesPath) -> str:
|
||||
str_or_bytes = os.fspath(item)
|
||||
if isinstance(str_or_bytes, str):
|
||||
return str_or_bytes
|
||||
else:
|
||||
return os.fsdecode(str_or_bytes)
|
||||
|
||||
if isinstance(command, (str, bytes, PathLike)):
|
||||
process = await trio.lowlevel.open_process(
|
||||
convert_item(command),
|
||||
stdin=stdin,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
shell=True,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
process = await trio.lowlevel.open_process(
|
||||
[convert_item(item) for item in command],
|
||||
stdin=stdin,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
shell=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
stdin_stream = SendStreamWrapper(process.stdin) if process.stdin else None
|
||||
stdout_stream = ReceiveStreamWrapper(process.stdout) if process.stdout else None
|
||||
stderr_stream = ReceiveStreamWrapper(process.stderr) if process.stderr else None
|
||||
|
@ -1,5 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from collections.abc import Generator
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import BaseExceptionGroup
|
||||
|
||||
|
||||
class BrokenResourceError(Exception):
|
||||
"""
|
||||
@ -71,3 +77,13 @@ class TypedAttributeLookupError(LookupError):
|
||||
|
||||
class WouldBlock(Exception):
|
||||
"""Raised by ``X_nowait`` functions if ``X()`` would block."""
|
||||
|
||||
|
||||
def iterate_exceptions(
|
||||
exception: BaseException,
|
||||
) -> Generator[BaseException, None, None]:
|
||||
if isinstance(exception, BaseExceptionGroup):
|
||||
for exc in exception.exceptions:
|
||||
yield from iterate_exceptions(exc)
|
||||
else:
|
||||
yield exception
|
||||
|
@ -358,8 +358,26 @@ class Path:
|
||||
def as_uri(self) -> str:
|
||||
return self._path.as_uri()
|
||||
|
||||
def match(self, path_pattern: str) -> bool:
|
||||
return self._path.match(path_pattern)
|
||||
if sys.version_info >= (3, 13):
|
||||
parser = pathlib.Path.parser
|
||||
|
||||
@classmethod
|
||||
def from_uri(cls, uri: str) -> Path:
|
||||
return Path(pathlib.Path.from_uri(uri))
|
||||
|
||||
def full_match(
|
||||
self, path_pattern: str, *, case_sensitive: bool | None = None
|
||||
) -> bool:
|
||||
return self._path.full_match(path_pattern, case_sensitive=case_sensitive)
|
||||
|
||||
def match(
|
||||
self, path_pattern: str, *, case_sensitive: bool | None = None
|
||||
) -> bool:
|
||||
return self._path.match(path_pattern, case_sensitive=case_sensitive)
|
||||
else:
|
||||
|
||||
def match(self, path_pattern: str) -> bool:
|
||||
return self._path.match(path_pattern)
|
||||
|
||||
def is_relative_to(self, other: str | PathLike[str]) -> bool:
|
||||
try:
|
||||
|
@ -680,19 +680,26 @@ async def setup_unix_local_socket(
|
||||
:param socktype: socket.SOCK_STREAM or socket.SOCK_DGRAM
|
||||
|
||||
"""
|
||||
path_str: str | bytes | None
|
||||
path_str: str | None
|
||||
if path is not None:
|
||||
path_str = os.fspath(path)
|
||||
path_str = os.fsdecode(path)
|
||||
|
||||
# Copied from pathlib...
|
||||
try:
|
||||
stat_result = os.stat(path)
|
||||
except OSError as e:
|
||||
if e.errno not in (errno.ENOENT, errno.ENOTDIR, errno.EBADF, errno.ELOOP):
|
||||
raise
|
||||
else:
|
||||
if stat.S_ISSOCK(stat_result.st_mode):
|
||||
os.unlink(path)
|
||||
# Linux abstract namespace sockets aren't backed by a concrete file so skip stat call
|
||||
if not path_str.startswith("\0"):
|
||||
# Copied from pathlib...
|
||||
try:
|
||||
stat_result = os.stat(path)
|
||||
except OSError as e:
|
||||
if e.errno not in (
|
||||
errno.ENOENT,
|
||||
errno.ENOTDIR,
|
||||
errno.EBADF,
|
||||
errno.ELOOP,
|
||||
):
|
||||
raise
|
||||
else:
|
||||
if stat.S_ISSOCK(stat_result.st_mode):
|
||||
os.unlink(path)
|
||||
else:
|
||||
path_str = None
|
||||
|
||||
|
@ -1,26 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterable, Mapping, Sequence
|
||||
import sys
|
||||
from collections.abc import AsyncIterable, Iterable, Mapping, Sequence
|
||||
from io import BytesIO
|
||||
from os import PathLike
|
||||
from subprocess import DEVNULL, PIPE, CalledProcessError, CompletedProcess
|
||||
from typing import IO, Any, cast
|
||||
from typing import IO, Any, Union, cast
|
||||
|
||||
from ..abc import Process
|
||||
from ._eventloop import get_async_backend
|
||||
from ._tasks import create_task_group
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import TypeAlias
|
||||
else:
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
StrOrBytesPath: TypeAlias = Union[str, bytes, "PathLike[str]", "PathLike[bytes]"]
|
||||
|
||||
|
||||
async def run_process(
|
||||
command: str | bytes | Sequence[str | bytes],
|
||||
command: StrOrBytesPath | Sequence[StrOrBytesPath],
|
||||
*,
|
||||
input: bytes | None = None,
|
||||
stdout: int | IO[Any] | None = PIPE,
|
||||
stderr: int | IO[Any] | None = PIPE,
|
||||
check: bool = True,
|
||||
cwd: str | bytes | PathLike[str] | None = None,
|
||||
cwd: StrOrBytesPath | None = None,
|
||||
env: Mapping[str, str] | None = None,
|
||||
startupinfo: Any = None,
|
||||
creationflags: int = 0,
|
||||
start_new_session: bool = False,
|
||||
pass_fds: Sequence[int] = (),
|
||||
user: str | int | None = None,
|
||||
group: str | int | None = None,
|
||||
extra_groups: Iterable[str | int] | None = None,
|
||||
umask: int = -1,
|
||||
) -> CompletedProcess[bytes]:
|
||||
"""
|
||||
Run an external command in a subprocess and wait until it completes.
|
||||
@ -40,8 +55,20 @@ async def run_process(
|
||||
command
|
||||
:param env: if not ``None``, this mapping replaces the inherited environment
|
||||
variables from the parent process
|
||||
:param startupinfo: an instance of :class:`subprocess.STARTUPINFO` that can be used
|
||||
to specify process startup parameters (Windows only)
|
||||
:param creationflags: flags that can be used to control the creation of the
|
||||
subprocess (see :class:`subprocess.Popen` for the specifics)
|
||||
:param start_new_session: if ``true`` the setsid() system call will be made in the
|
||||
child process prior to the execution of the subprocess. (POSIX only)
|
||||
:param pass_fds: sequence of file descriptors to keep open between the parent and
|
||||
child processes. (POSIX only)
|
||||
:param user: effective user to run the process as (Python >= 3.9, POSIX only)
|
||||
:param group: effective group to run the process as (Python >= 3.9, POSIX only)
|
||||
:param extra_groups: supplementary groups to set in the subprocess (Python >= 3.9,
|
||||
POSIX only)
|
||||
:param umask: if not negative, this umask is applied in the child process before
|
||||
running the given command (Python >= 3.9, POSIX only)
|
||||
:return: an object representing the completed process
|
||||
:raises ~subprocess.CalledProcessError: if ``check`` is ``True`` and the process
|
||||
exits with a nonzero return code
|
||||
@ -62,7 +89,14 @@ async def run_process(
|
||||
stderr=stderr,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
startupinfo=startupinfo,
|
||||
creationflags=creationflags,
|
||||
start_new_session=start_new_session,
|
||||
pass_fds=pass_fds,
|
||||
user=user,
|
||||
group=group,
|
||||
extra_groups=extra_groups,
|
||||
umask=umask,
|
||||
) as process:
|
||||
stream_contents: list[bytes | None] = [None, None]
|
||||
async with create_task_group() as tg:
|
||||
@ -86,14 +120,21 @@ async def run_process(
|
||||
|
||||
|
||||
async def open_process(
|
||||
command: str | bytes | Sequence[str | bytes],
|
||||
command: StrOrBytesPath | Sequence[StrOrBytesPath],
|
||||
*,
|
||||
stdin: int | IO[Any] | None = PIPE,
|
||||
stdout: int | IO[Any] | None = PIPE,
|
||||
stderr: int | IO[Any] | None = PIPE,
|
||||
cwd: str | bytes | PathLike[str] | None = None,
|
||||
cwd: StrOrBytesPath | None = None,
|
||||
env: Mapping[str, str] | None = None,
|
||||
startupinfo: Any = None,
|
||||
creationflags: int = 0,
|
||||
start_new_session: bool = False,
|
||||
pass_fds: Sequence[int] = (),
|
||||
user: str | int | None = None,
|
||||
group: str | int | None = None,
|
||||
extra_groups: Iterable[str | int] | None = None,
|
||||
umask: int = -1,
|
||||
) -> Process:
|
||||
"""
|
||||
Start an external command in a subprocess.
|
||||
@ -111,30 +152,58 @@ async def open_process(
|
||||
:param cwd: If not ``None``, the working directory is changed before executing
|
||||
:param env: If env is not ``None``, it must be a mapping that defines the
|
||||
environment variables for the new process
|
||||
:param creationflags: flags that can be used to control the creation of the
|
||||
subprocess (see :class:`subprocess.Popen` for the specifics)
|
||||
:param startupinfo: an instance of :class:`subprocess.STARTUPINFO` that can be used
|
||||
to specify process startup parameters (Windows only)
|
||||
:param start_new_session: if ``true`` the setsid() system call will be made in the
|
||||
child process prior to the execution of the subprocess. (POSIX only)
|
||||
:param pass_fds: sequence of file descriptors to keep open between the parent and
|
||||
child processes. (POSIX only)
|
||||
:param user: effective user to run the process as (Python >= 3.9; POSIX only)
|
||||
:param group: effective group to run the process as (Python >= 3.9; POSIX only)
|
||||
:param extra_groups: supplementary groups to set in the subprocess (Python >= 3.9;
|
||||
POSIX only)
|
||||
:param umask: if not negative, this umask is applied in the child process before
|
||||
running the given command (Python >= 3.9; POSIX only)
|
||||
:return: an asynchronous process object
|
||||
|
||||
"""
|
||||
if isinstance(command, (str, bytes)):
|
||||
return await get_async_backend().open_process(
|
||||
command,
|
||||
shell=True,
|
||||
stdin=stdin,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
start_new_session=start_new_session,
|
||||
)
|
||||
else:
|
||||
return await get_async_backend().open_process(
|
||||
command,
|
||||
shell=False,
|
||||
stdin=stdin,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
start_new_session=start_new_session,
|
||||
)
|
||||
kwargs: dict[str, Any] = {}
|
||||
if user is not None:
|
||||
if sys.version_info < (3, 9):
|
||||
raise TypeError("the 'user' argument requires Python 3.9 or later")
|
||||
|
||||
kwargs["user"] = user
|
||||
|
||||
if group is not None:
|
||||
if sys.version_info < (3, 9):
|
||||
raise TypeError("the 'group' argument requires Python 3.9 or later")
|
||||
|
||||
kwargs["group"] = group
|
||||
|
||||
if extra_groups is not None:
|
||||
if sys.version_info < (3, 9):
|
||||
raise TypeError("the 'extra_groups' argument requires Python 3.9 or later")
|
||||
|
||||
kwargs["extra_groups"] = group
|
||||
|
||||
if umask >= 0:
|
||||
if sys.version_info < (3, 9):
|
||||
raise TypeError("the 'umask' argument requires Python 3.9 or later")
|
||||
|
||||
kwargs["umask"] = umask
|
||||
|
||||
return await get_async_backend().open_process(
|
||||
command,
|
||||
stdin=stdin,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
startupinfo=startupinfo,
|
||||
creationflags=creationflags,
|
||||
start_new_session=start_new_session,
|
||||
pass_fds=pass_fds,
|
||||
**kwargs,
|
||||
)
|
||||
|
@ -7,9 +7,9 @@ from types import TracebackType
|
||||
|
||||
from sniffio import AsyncLibraryNotFoundError
|
||||
|
||||
from ..lowlevel import cancel_shielded_checkpoint, checkpoint, checkpoint_if_cancelled
|
||||
from ..lowlevel import checkpoint
|
||||
from ._eventloop import get_async_backend
|
||||
from ._exceptions import BusyResourceError, WouldBlock
|
||||
from ._exceptions import BusyResourceError
|
||||
from ._tasks import CancelScope
|
||||
from ._testing import TaskInfo, get_current_task
|
||||
|
||||
@ -137,10 +137,11 @@ class EventAdapter(Event):
|
||||
|
||||
|
||||
class Lock:
|
||||
_owner_task: TaskInfo | None = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._waiters: deque[tuple[TaskInfo, Event]] = deque()
|
||||
def __new__(cls, *, fast_acquire: bool = False) -> Lock:
|
||||
try:
|
||||
return get_async_backend().create_lock(fast_acquire=fast_acquire)
|
||||
except AsyncLibraryNotFoundError:
|
||||
return LockAdapter(fast_acquire=fast_acquire)
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
await self.acquire()
|
||||
@ -155,31 +156,7 @@ class Lock:
|
||||
|
||||
async def acquire(self) -> None:
|
||||
"""Acquire the lock."""
|
||||
await checkpoint_if_cancelled()
|
||||
try:
|
||||
self.acquire_nowait()
|
||||
except WouldBlock:
|
||||
task = get_current_task()
|
||||
event = Event()
|
||||
token = task, event
|
||||
self._waiters.append(token)
|
||||
try:
|
||||
await event.wait()
|
||||
except BaseException:
|
||||
if not event.is_set():
|
||||
self._waiters.remove(token)
|
||||
elif self._owner_task == task:
|
||||
self.release()
|
||||
|
||||
raise
|
||||
|
||||
assert self._owner_task == task
|
||||
else:
|
||||
try:
|
||||
await cancel_shielded_checkpoint()
|
||||
except BaseException:
|
||||
self.release()
|
||||
raise
|
||||
raise NotImplementedError
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
"""
|
||||
@ -188,29 +165,15 @@ class Lock:
|
||||
:raises ~anyio.WouldBlock: if the operation would block
|
||||
|
||||
"""
|
||||
task = get_current_task()
|
||||
if self._owner_task == task:
|
||||
raise RuntimeError("Attempted to acquire an already held Lock")
|
||||
|
||||
if self._owner_task is not None:
|
||||
raise WouldBlock
|
||||
|
||||
self._owner_task = task
|
||||
raise NotImplementedError
|
||||
|
||||
def release(self) -> None:
|
||||
"""Release the lock."""
|
||||
if self._owner_task != get_current_task():
|
||||
raise RuntimeError("The current task is not holding this lock")
|
||||
|
||||
if self._waiters:
|
||||
self._owner_task, event = self._waiters.popleft()
|
||||
event.set()
|
||||
else:
|
||||
del self._owner_task
|
||||
raise NotImplementedError
|
||||
|
||||
def locked(self) -> bool:
|
||||
"""Return True if the lock is currently held."""
|
||||
return self._owner_task is not None
|
||||
raise NotImplementedError
|
||||
|
||||
def statistics(self) -> LockStatistics:
|
||||
"""
|
||||
@ -218,7 +181,71 @@ class Lock:
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
return LockStatistics(self.locked(), self._owner_task, len(self._waiters))
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LockAdapter(Lock):
|
||||
_internal_lock: Lock | None = None
|
||||
|
||||
def __new__(cls, *, fast_acquire: bool = False) -> LockAdapter:
|
||||
return object.__new__(cls)
|
||||
|
||||
def __init__(self, *, fast_acquire: bool = False):
|
||||
self._fast_acquire = fast_acquire
|
||||
|
||||
@property
|
||||
def _lock(self) -> Lock:
|
||||
if self._internal_lock is None:
|
||||
self._internal_lock = get_async_backend().create_lock(
|
||||
fast_acquire=self._fast_acquire
|
||||
)
|
||||
|
||||
return self._internal_lock
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
await self._lock.acquire()
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
if self._internal_lock is not None:
|
||||
self._internal_lock.release()
|
||||
|
||||
async def acquire(self) -> None:
|
||||
"""Acquire the lock."""
|
||||
await self._lock.acquire()
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
"""
|
||||
Acquire the lock, without blocking.
|
||||
|
||||
:raises ~anyio.WouldBlock: if the operation would block
|
||||
|
||||
"""
|
||||
self._lock.acquire_nowait()
|
||||
|
||||
def release(self) -> None:
|
||||
"""Release the lock."""
|
||||
self._lock.release()
|
||||
|
||||
def locked(self) -> bool:
|
||||
"""Return True if the lock is currently held."""
|
||||
return self._lock.locked()
|
||||
|
||||
def statistics(self) -> LockStatistics:
|
||||
"""
|
||||
Return statistics about the current state of this lock.
|
||||
|
||||
.. versionadded:: 3.0
|
||||
|
||||
"""
|
||||
if self._internal_lock is None:
|
||||
return LockStatistics(False, None, 0)
|
||||
|
||||
return self._internal_lock.statistics()
|
||||
|
||||
|
||||
class Condition:
|
||||
@ -312,7 +339,27 @@ class Condition:
|
||||
|
||||
|
||||
class Semaphore:
|
||||
def __init__(self, initial_value: int, *, max_value: int | None = None):
|
||||
def __new__(
|
||||
cls,
|
||||
initial_value: int,
|
||||
*,
|
||||
max_value: int | None = None,
|
||||
fast_acquire: bool = False,
|
||||
) -> Semaphore:
|
||||
try:
|
||||
return get_async_backend().create_semaphore(
|
||||
initial_value, max_value=max_value, fast_acquire=fast_acquire
|
||||
)
|
||||
except AsyncLibraryNotFoundError:
|
||||
return SemaphoreAdapter(initial_value, max_value=max_value)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_value: int,
|
||||
*,
|
||||
max_value: int | None = None,
|
||||
fast_acquire: bool = False,
|
||||
):
|
||||
if not isinstance(initial_value, int):
|
||||
raise TypeError("initial_value must be an integer")
|
||||
if initial_value < 0:
|
||||
@ -325,9 +372,7 @@ class Semaphore:
|
||||
"max_value must be equal to or higher than initial_value"
|
||||
)
|
||||
|
||||
self._value = initial_value
|
||||
self._max_value = max_value
|
||||
self._waiters: deque[Event] = deque()
|
||||
self._fast_acquire = fast_acquire
|
||||
|
||||
async def __aenter__(self) -> Semaphore:
|
||||
await self.acquire()
|
||||
@ -343,27 +388,7 @@ class Semaphore:
|
||||
|
||||
async def acquire(self) -> None:
|
||||
"""Decrement the semaphore value, blocking if necessary."""
|
||||
await checkpoint_if_cancelled()
|
||||
try:
|
||||
self.acquire_nowait()
|
||||
except WouldBlock:
|
||||
event = Event()
|
||||
self._waiters.append(event)
|
||||
try:
|
||||
await event.wait()
|
||||
except BaseException:
|
||||
if not event.is_set():
|
||||
self._waiters.remove(event)
|
||||
else:
|
||||
self.release()
|
||||
|
||||
raise
|
||||
else:
|
||||
try:
|
||||
await cancel_shielded_checkpoint()
|
||||
except BaseException:
|
||||
self.release()
|
||||
raise
|
||||
raise NotImplementedError
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
"""
|
||||
@ -372,30 +397,21 @@ class Semaphore:
|
||||
:raises ~anyio.WouldBlock: if the operation would block
|
||||
|
||||
"""
|
||||
if self._value == 0:
|
||||
raise WouldBlock
|
||||
|
||||
self._value -= 1
|
||||
raise NotImplementedError
|
||||
|
||||
def release(self) -> None:
|
||||
"""Increment the semaphore value."""
|
||||
if self._max_value is not None and self._value == self._max_value:
|
||||
raise ValueError("semaphore released too many times")
|
||||
|
||||
if self._waiters:
|
||||
self._waiters.popleft().set()
|
||||
else:
|
||||
self._value += 1
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def value(self) -> int:
|
||||
"""The current value of the semaphore."""
|
||||
return self._value
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def max_value(self) -> int | None:
|
||||
"""The maximum value of the semaphore."""
|
||||
return self._max_value
|
||||
raise NotImplementedError
|
||||
|
||||
def statistics(self) -> SemaphoreStatistics:
|
||||
"""
|
||||
@ -403,7 +419,66 @@ class Semaphore:
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
return SemaphoreStatistics(len(self._waiters))
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SemaphoreAdapter(Semaphore):
|
||||
_internal_semaphore: Semaphore | None = None
|
||||
|
||||
def __new__(
|
||||
cls,
|
||||
initial_value: int,
|
||||
*,
|
||||
max_value: int | None = None,
|
||||
fast_acquire: bool = False,
|
||||
) -> SemaphoreAdapter:
|
||||
return object.__new__(cls)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_value: int,
|
||||
*,
|
||||
max_value: int | None = None,
|
||||
fast_acquire: bool = False,
|
||||
) -> None:
|
||||
super().__init__(initial_value, max_value=max_value, fast_acquire=fast_acquire)
|
||||
self._initial_value = initial_value
|
||||
self._max_value = max_value
|
||||
|
||||
@property
|
||||
def _semaphore(self) -> Semaphore:
|
||||
if self._internal_semaphore is None:
|
||||
self._internal_semaphore = get_async_backend().create_semaphore(
|
||||
self._initial_value, max_value=self._max_value
|
||||
)
|
||||
|
||||
return self._internal_semaphore
|
||||
|
||||
async def acquire(self) -> None:
|
||||
await self._semaphore.acquire()
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
self._semaphore.acquire_nowait()
|
||||
|
||||
def release(self) -> None:
|
||||
self._semaphore.release()
|
||||
|
||||
@property
|
||||
def value(self) -> int:
|
||||
if self._internal_semaphore is None:
|
||||
return self._initial_value
|
||||
|
||||
return self._semaphore.value
|
||||
|
||||
@property
|
||||
def max_value(self) -> int | None:
|
||||
return self._max_value
|
||||
|
||||
def statistics(self) -> SemaphoreStatistics:
|
||||
if self._internal_semaphore is None:
|
||||
return SemaphoreStatistics(tasks_waiting=0)
|
||||
|
||||
return self._semaphore.statistics()
|
||||
|
||||
|
||||
class CapacityLimiter:
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import math
|
||||
import sys
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections.abc import AsyncIterator, Awaitable, Mapping
|
||||
from collections.abc import AsyncIterator, Awaitable
|
||||
from os import PathLike
|
||||
from signal import Signals
|
||||
from socket import AddressFamily, SocketKind, socket
|
||||
@ -15,6 +15,7 @@ from typing import (
|
||||
ContextManager,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
@ -23,10 +24,13 @@ if sys.version_info >= (3, 11):
|
||||
else:
|
||||
from typing_extensions import TypeVarTuple, Unpack
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Literal
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import TypeAlias
|
||||
else:
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from .._core._synchronization import CapacityLimiter, Event
|
||||
if TYPE_CHECKING:
|
||||
from .._core._synchronization import CapacityLimiter, Event, Lock, Semaphore
|
||||
from .._core._tasks import CancelScope
|
||||
from .._core._testing import TaskInfo
|
||||
from ..from_thread import BlockingPortal
|
||||
@ -46,6 +50,7 @@ if TYPE_CHECKING:
|
||||
|
||||
T_Retval = TypeVar("T_Retval")
|
||||
PosArgsT = TypeVarTuple("PosArgsT")
|
||||
StrOrBytesPath: TypeAlias = Union[str, bytes, "PathLike[str]", "PathLike[bytes]"]
|
||||
|
||||
|
||||
class AsyncBackend(metaclass=ABCMeta):
|
||||
@ -167,6 +172,22 @@ class AsyncBackend(metaclass=ABCMeta):
|
||||
def create_event(cls) -> Event:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def create_lock(cls, *, fast_acquire: bool) -> Lock:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def create_semaphore(
|
||||
cls,
|
||||
initial_value: int,
|
||||
*,
|
||||
max_value: int | None = None,
|
||||
fast_acquire: bool = False,
|
||||
) -> Semaphore:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter:
|
||||
@ -213,51 +234,16 @@ class AsyncBackend(metaclass=ABCMeta):
|
||||
def create_blocking_portal(cls) -> BlockingPortal:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@overload
|
||||
async def open_process(
|
||||
cls,
|
||||
command: str | bytes,
|
||||
*,
|
||||
shell: Literal[True],
|
||||
stdin: int | IO[Any] | None,
|
||||
stdout: int | IO[Any] | None,
|
||||
stderr: int | IO[Any] | None,
|
||||
cwd: str | bytes | PathLike[str] | None = None,
|
||||
env: Mapping[str, str] | None = None,
|
||||
start_new_session: bool = False,
|
||||
) -> Process:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@overload
|
||||
async def open_process(
|
||||
cls,
|
||||
command: Sequence[str | bytes],
|
||||
*,
|
||||
shell: Literal[False],
|
||||
stdin: int | IO[Any] | None,
|
||||
stdout: int | IO[Any] | None,
|
||||
stderr: int | IO[Any] | None,
|
||||
cwd: str | bytes | PathLike[str] | None = None,
|
||||
env: Mapping[str, str] | None = None,
|
||||
start_new_session: bool = False,
|
||||
) -> Process:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def open_process(
|
||||
cls,
|
||||
command: str | bytes | Sequence[str | bytes],
|
||||
command: StrOrBytesPath | Sequence[StrOrBytesPath],
|
||||
*,
|
||||
shell: bool,
|
||||
stdin: int | IO[Any] | None,
|
||||
stdout: int | IO[Any] | None,
|
||||
stderr: int | IO[Any] | None,
|
||||
cwd: str | bytes | PathLike[str] | None = None,
|
||||
env: Mapping[str, str] | None = None,
|
||||
start_new_session: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Process:
|
||||
pass
|
||||
|
||||
|
@ -1,19 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import threading
|
||||
from collections.abc import Awaitable, Callable, Generator
|
||||
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
|
||||
from concurrent.futures import Future
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from inspect import isawaitable
|
||||
from threading import Lock, Thread, get_ident
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
ContextManager,
|
||||
Generic,
|
||||
Iterable,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
@ -146,7 +145,7 @@ class BlockingPortal:
|
||||
return get_async_backend().create_blocking_portal()
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._event_loop_thread_id: int | None = threading.get_ident()
|
||||
self._event_loop_thread_id: int | None = get_ident()
|
||||
self._stop_event = Event()
|
||||
self._task_group = create_task_group()
|
||||
self._cancelled_exc_class = get_cancelled_exc_class()
|
||||
@ -167,7 +166,7 @@ class BlockingPortal:
|
||||
def _check_running(self) -> None:
|
||||
if self._event_loop_thread_id is None:
|
||||
raise RuntimeError("This portal is not running")
|
||||
if self._event_loop_thread_id == threading.get_ident():
|
||||
if self._event_loop_thread_id == get_ident():
|
||||
raise RuntimeError(
|
||||
"This method cannot be called from the event loop thread"
|
||||
)
|
||||
@ -202,7 +201,7 @@ class BlockingPortal:
|
||||
def callback(f: Future[T_Retval]) -> None:
|
||||
if f.cancelled() and self._event_loop_thread_id not in (
|
||||
None,
|
||||
threading.get_ident(),
|
||||
get_ident(),
|
||||
):
|
||||
self.call(scope.cancel)
|
||||
|
||||
@ -411,7 +410,7 @@ class BlockingPortalProvider:
|
||||
|
||||
backend: str = "asyncio"
|
||||
backend_options: dict[str, Any] | None = None
|
||||
_lock: threading.Lock = field(init=False, default_factory=threading.Lock)
|
||||
_lock: Lock = field(init=False, default_factory=Lock)
|
||||
_leases: int = field(init=False, default=0)
|
||||
_portal: BlockingPortal = field(init=False)
|
||||
_portal_cm: AbstractContextManager[BlockingPortal] | None = field(
|
||||
@ -469,43 +468,37 @@ def start_blocking_portal(
|
||||
|
||||
async def run_portal() -> None:
|
||||
async with BlockingPortal() as portal_:
|
||||
if future.set_running_or_notify_cancel():
|
||||
future.set_result(portal_)
|
||||
await portal_.sleep_until_stopped()
|
||||
future.set_result(portal_)
|
||||
await portal_.sleep_until_stopped()
|
||||
|
||||
def run_blocking_portal() -> None:
|
||||
if future.set_running_or_notify_cancel():
|
||||
try:
|
||||
_eventloop.run(
|
||||
run_portal, backend=backend, backend_options=backend_options
|
||||
)
|
||||
except BaseException as exc:
|
||||
if not future.done():
|
||||
future.set_exception(exc)
|
||||
|
||||
future: Future[BlockingPortal] = Future()
|
||||
with ThreadPoolExecutor(1) as executor:
|
||||
run_future = executor.submit(
|
||||
_eventloop.run, # type: ignore[arg-type]
|
||||
run_portal,
|
||||
backend=backend,
|
||||
backend_options=backend_options,
|
||||
)
|
||||
thread = Thread(target=run_blocking_portal, daemon=True)
|
||||
thread.start()
|
||||
try:
|
||||
cancel_remaining_tasks = False
|
||||
portal = future.result()
|
||||
try:
|
||||
wait(
|
||||
cast(Iterable[Future], [run_future, future]),
|
||||
return_when=FIRST_COMPLETED,
|
||||
)
|
||||
yield portal
|
||||
except BaseException:
|
||||
future.cancel()
|
||||
run_future.cancel()
|
||||
cancel_remaining_tasks = True
|
||||
raise
|
||||
|
||||
if future.done():
|
||||
portal = future.result()
|
||||
cancel_remaining_tasks = False
|
||||
finally:
|
||||
try:
|
||||
yield portal
|
||||
except BaseException:
|
||||
cancel_remaining_tasks = True
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
portal.call(portal.stop, cancel_remaining_tasks)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
run_future.result()
|
||||
portal.call(portal.stop, cancel_remaining_tasks)
|
||||
except RuntimeError:
|
||||
pass
|
||||
finally:
|
||||
thread.join()
|
||||
|
||||
|
||||
def check_cancelled() -> None:
|
||||
|
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from collections.abc import Iterator
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from inspect import isasyncgenfunction, iscoroutinefunction
|
||||
@ -7,10 +8,15 @@ from typing import Any, Dict, Tuple, cast
|
||||
|
||||
import pytest
|
||||
import sniffio
|
||||
from _pytest.outcomes import Exit
|
||||
|
||||
from ._core._eventloop import get_all_backends, get_async_backend
|
||||
from ._core._exceptions import iterate_exceptions
|
||||
from .abc import TestRunner
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import ExceptionGroup
|
||||
|
||||
_current_runner: TestRunner | None = None
|
||||
_runner_stack: ExitStack | None = None
|
||||
_runner_leases = 0
|
||||
@ -121,7 +127,14 @@ def pytest_pyfunc_call(pyfuncitem: Any) -> bool | None:
|
||||
funcargs = pyfuncitem.funcargs
|
||||
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
|
||||
with get_runner(backend_name, backend_options) as runner:
|
||||
runner.run_test(pyfuncitem.obj, testargs)
|
||||
try:
|
||||
runner.run_test(pyfuncitem.obj, testargs)
|
||||
except ExceptionGroup as excgrp:
|
||||
for exc in iterate_exceptions(excgrp):
|
||||
if isinstance(exc, (Exit, KeyboardInterrupt, SystemExit)):
|
||||
raise exc from excgrp
|
||||
|
||||
raise
|
||||
|
||||
return True
|
||||
|
||||
|
@ -38,6 +38,12 @@ class MemoryObjectItemReceiver(Generic[T_Item]):
|
||||
task_info: TaskInfo = field(init=False, default_factory=get_current_task)
|
||||
item: T_Item = field(init=False)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# When item is not defined, we get following error with default __repr__:
|
||||
# AttributeError: 'MemoryObjectItemReceiver' object has no attribute 'item'
|
||||
item = getattr(self, "item", None)
|
||||
return f"{self.__class__.__name__}(task_info={self.task_info}, item={item!r})"
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class MemoryObjectStreamState(Generic[T_Item]):
|
||||
@ -175,7 +181,7 @@ class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]):
|
||||
def __del__(self) -> None:
|
||||
if not self._closed:
|
||||
warnings.warn(
|
||||
f"Unclosed <{self.__class__.__name__}>",
|
||||
f"Unclosed <{self.__class__.__name__} at {id(self):x}>",
|
||||
ResourceWarning,
|
||||
source=self,
|
||||
)
|
||||
@ -305,7 +311,7 @@ class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]):
|
||||
def __del__(self) -> None:
|
||||
if not self._closed:
|
||||
warnings.warn(
|
||||
f"Unclosed <{self.__class__.__name__}>",
|
||||
f"Unclosed <{self.__class__.__name__} at {id(self):x}>",
|
||||
ResourceWarning,
|
||||
source=self,
|
||||
)
|
||||
|
@ -223,7 +223,7 @@ def process_worker() -> None:
|
||||
main_module_path: str | None
|
||||
sys.path, main_module_path = args
|
||||
del sys.modules["__main__"]
|
||||
if main_module_path:
|
||||
if main_module_path and os.path.isfile(main_module_path):
|
||||
# Load the parent's main module but as __mp_main__ instead of
|
||||
# __main__ (like multiprocessing does) to avoid infinite recursion
|
||||
try:
|
||||
@ -234,7 +234,6 @@ def process_worker() -> None:
|
||||
sys.modules["__main__"] = main
|
||||
except BaseException as exc:
|
||||
exception = exc
|
||||
|
||||
try:
|
||||
if exception is not None:
|
||||
status = b"EXCEPTION"
|
||||
|
@ -4,7 +4,7 @@ charset_normalizer-3.3.2.dist-info/LICENSE,sha256=6zGgxaT7Cbik4yBV0lweX5w1iidS_v
|
||||
charset_normalizer-3.3.2.dist-info/METADATA,sha256=cfLhl5A6SI-F0oclm8w8ux9wshL1nipdeCdVnYb4AaA,33550
|
||||
charset_normalizer-3.3.2.dist-info/RECORD,,
|
||||
charset_normalizer-3.3.2.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
charset_normalizer-3.3.2.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
|
||||
charset_normalizer-3.3.2.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
||||
charset_normalizer-3.3.2.dist-info/entry_points.txt,sha256=ADSTKrkXZ3hhdOVFi6DcUEHQRS0xfxDIE_pEz4wLIXA,65
|
||||
charset_normalizer-3.3.2.dist-info/top_level.txt,sha256=7ASyzePr8_xuZWJsnqJjIBtyV8vhEo0wBCv1MPRRi3Q,19
|
||||
charset_normalizer/__init__.py,sha256=UzI3xC8PhmcLRMzSgPb6minTmRq0kWznnCBJ8ZCc2XI,1577
|
||||
|
@ -1,5 +1,5 @@
|
||||
Wheel-Version: 1.0
|
||||
Generator: setuptools (73.0.1)
|
||||
Generator: setuptools (74.1.2)
|
||||
Root-Is-Purelib: true
|
||||
Tag: py3-none-any
|
||||
|
||||
|
@ -59,7 +59,7 @@ devchat/_service/route/__pycache__/workflows.cpython-38.pyc,,
|
||||
devchat/_service/route/logs.py,sha256=3on6cRIJ8P1N0dMImTqQHe7CUUNrSgKehRnFQIS1qjc,1290
|
||||
devchat/_service/route/message.py,sha256=Ex06NvmXEDVB4g1bMrZuAYvF_mkSSCJCQ3qRz86CrGs,3309
|
||||
devchat/_service/route/topics.py,sha256=ox40XH3scK0Ey58nTEJ3h6i0s1PEdmOImI4WZXSLIqc,2331
|
||||
devchat/_service/route/workflows.py,sha256=XaBwSX0QKuGP_WLCQGLBGmrDoj8bw8rEJuKe-qGkfxk,5252
|
||||
devchat/_service/route/workflows.py,sha256=cGmHfS7rm4sTEUK_7fTD2z2p0hWwhHJ-WHLK81pA-GE,5383
|
||||
devchat/_service/schema/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
devchat/_service/schema/__pycache__/__init__.cpython-38.pyc,,
|
||||
devchat/_service/schema/__pycache__/request.cpython-38.pyc,,
|
||||
@ -189,7 +189,7 @@ devchat/workflow/command/env.py,sha256=wXZc497GwSjWk8T37krTkxqyjUhWSAh0c0RCwmLzR
|
||||
devchat/workflow/command/list.py,sha256=MYX469XXazUzNkahXuTlIJYIf_r_t1RwGrKEfugY8HY,1228
|
||||
devchat/workflow/command/run.py,sha256=tSdDeVNBtNt7xwn6Q3hgjsMZUbL0NPt2g3-fsVAUg9A,214
|
||||
devchat/workflow/command/update.py,sha256=7hmqqtJHpKZ2I03zQtT-aGvw63L5V0Lsmi-qe8ThjL4,761
|
||||
devchat/workflow/env_manager.py,sha256=h5hPUfGrlG-8KMkfpVAuD_fvkS_myLi5s7jPf_uOyfs,8474
|
||||
devchat/workflow/env_manager.py,sha256=XJksiBCBvkfR1rLxFt80bDK2slQkraeFa7Og19wI_Ho,9326
|
||||
devchat/workflow/envs.py,sha256=-lVTLjWRMrb8RGVVlHgWKCiGZaojNdmycjHFT0ZKjEo,298
|
||||
devchat/workflow/namespace.py,sha256=CaotAAcSwGINH871zXcs2frBlzuf-eZ2DMEIx_8uRUo,3880
|
||||
devchat/workflow/path.py,sha256=6I3Fk-KTJRPPR9zKXP7pSfdrl8LcuWwOICSvWJiJNf0,1391
|
||||
|
@ -91,10 +91,12 @@ def update_custom_workflows():
|
||||
updated_any = True
|
||||
update_messages = []
|
||||
|
||||
for url in custom_git_urls:
|
||||
repo_name = url.split("/")[-1].replace(".git", "") # 提取repo名称
|
||||
for item in custom_git_urls:
|
||||
git_url = item["git_url"]
|
||||
branch = item["branch"]
|
||||
repo_name = git_url.split("/")[-1].replace(".git", "") # 提取repo名称
|
||||
repo_path: Path = base_path / repo_name # 拼接出clone路径
|
||||
candidates_git_urls = [(url, "main")]
|
||||
candidates_git_urls = [(git_url, branch)]
|
||||
|
||||
if repo_path.exists():
|
||||
logger.info(f"Repo path not empty {repo_path}, removing it.")
|
||||
@ -132,7 +134,7 @@ def update_custom_workflows():
|
||||
return response.UpdateWorkflows(updated=updated_any, message=message_summary)
|
||||
else:
|
||||
return response.UpdateWorkflows(
|
||||
False, "No custom_git_urls found in .chat/config.yaml"
|
||||
updated=False, message="No custom_git_urls found in .chat/config.yaml"
|
||||
)
|
||||
else:
|
||||
return response.UpdateWorkflows(False, "No .chat config found")
|
||||
return response.UpdateWorkflows(updated=False, message="No .chat config found")
|
||||
|
@ -7,16 +7,13 @@ from typing import Dict, Optional, Tuple
|
||||
from devchat.utils import get_logger, get_logging_file
|
||||
|
||||
from .envs import MAMBA_BIN_PATH
|
||||
from .path import ENV_CACHE_DIR, MAMBA_PY_ENVS, MAMBA_ROOT
|
||||
from .path import CHAT_CONFIG_FILENAME, CHAT_DIR, ENV_CACHE_DIR, MAMBA_PY_ENVS, MAMBA_ROOT
|
||||
from .schema import ExternalPyConf
|
||||
from .user_setting import USER_SETTINGS
|
||||
|
||||
# CONDA_FORGE = [
|
||||
# "https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/",
|
||||
# "conda-forge",
|
||||
# ]
|
||||
CONDA_FORGE_TUNA = "https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/"
|
||||
PYPI_TUNA = "https://pypi.tuna.tsinghua.edu.cn/simple"
|
||||
DEFAULT_CONDA_FORGE_URL = "https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/"
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@ -201,6 +198,9 @@ class PyEnvManager:
|
||||
if is_exist:
|
||||
return True, ""
|
||||
|
||||
# Get conda-forge URL from config file
|
||||
conda_forge_url = self._get_conda_forge_url()
|
||||
|
||||
# create the environment
|
||||
cmd = [
|
||||
self.mamba_bin,
|
||||
@ -208,7 +208,7 @@ class PyEnvManager:
|
||||
"-n",
|
||||
env_name,
|
||||
"-c",
|
||||
CONDA_FORGE_TUNA,
|
||||
conda_forge_url,
|
||||
"-r",
|
||||
self.mamba_root,
|
||||
f"python={py_version}",
|
||||
@ -264,3 +264,26 @@ class PyEnvManager:
|
||||
return env_path
|
||||
|
||||
return None
|
||||
|
||||
def _get_conda_forge_url(self) -> str:
|
||||
"""
|
||||
Read the conda-forge URL from the config file.
|
||||
If the config file does not exist or does not contain the conda-forge URL,
|
||||
use the default value.
|
||||
"""
|
||||
config_file = os.path.join(CHAT_DIR, CHAT_CONFIG_FILENAME)
|
||||
|
||||
try:
|
||||
if not os.path.exists(config_file):
|
||||
return DEFAULT_CONDA_FORGE_URL
|
||||
|
||||
import yaml
|
||||
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
return config.get("conda-forge-url", DEFAULT_CONDA_FORGE_URL)
|
||||
except Exception as e:
|
||||
# Log the exception if needed
|
||||
print(f"An error occurred when loading conda-forge-url from config file: {e}")
|
||||
return DEFAULT_CONDA_FORGE_URL
|
||||
|
@ -1,6 +1,6 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: idna
|
||||
Version: 3.8
|
||||
Version: 3.10
|
||||
Summary: Internationalized Domain Names in Applications (IDNA)
|
||||
Author-email: Kim Davies <kim+pypi@gumleaf.org>
|
||||
Requires-Python: >=3.6
|
||||
@ -26,9 +26,14 @@ Classifier: Programming Language :: Python :: Implementation :: PyPy
|
||||
Classifier: Topic :: Internet :: Name Service (DNS)
|
||||
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
||||
Classifier: Topic :: Utilities
|
||||
Requires-Dist: ruff >= 0.6.2 ; extra == "all"
|
||||
Requires-Dist: mypy >= 1.11.2 ; extra == "all"
|
||||
Requires-Dist: pytest >= 8.3.2 ; extra == "all"
|
||||
Requires-Dist: flake8 >= 7.1.1 ; extra == "all"
|
||||
Project-URL: Changelog, https://github.com/kjd/idna/blob/master/HISTORY.rst
|
||||
Project-URL: Issue tracker, https://github.com/kjd/idna/issues
|
||||
Project-URL: Source, https://github.com/kjd/idna
|
||||
Provides-Extra: all
|
||||
|
||||
Internationalized Domain Names in Applications (IDNA)
|
||||
=====================================================
|
22
site-packages/idna-3.10.dist-info/RECORD
Normal file
22
site-packages/idna-3.10.dist-info/RECORD
Normal file
@ -0,0 +1,22 @@
|
||||
idna-3.10.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
idna-3.10.dist-info/LICENSE.md,sha256=pZ8LDvNjWHQQmkRhykT_enDVBpboFHZ7-vch1Mmw2w8,1541
|
||||
idna-3.10.dist-info/METADATA,sha256=URR5ZyDfQ1PCEGhkYoojqfi2Ra0tau2--lhwG4XSfjI,10158
|
||||
idna-3.10.dist-info/RECORD,,
|
||||
idna-3.10.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
|
||||
idna/__init__.py,sha256=MPqNDLZbXqGaNdXxAFhiqFPKEQXju2jNQhCey6-5eJM,868
|
||||
idna/__pycache__/__init__.cpython-38.pyc,,
|
||||
idna/__pycache__/codec.cpython-38.pyc,,
|
||||
idna/__pycache__/compat.cpython-38.pyc,,
|
||||
idna/__pycache__/core.cpython-38.pyc,,
|
||||
idna/__pycache__/idnadata.cpython-38.pyc,,
|
||||
idna/__pycache__/intranges.cpython-38.pyc,,
|
||||
idna/__pycache__/package_data.cpython-38.pyc,,
|
||||
idna/__pycache__/uts46data.cpython-38.pyc,,
|
||||
idna/codec.py,sha256=PEew3ItwzjW4hymbasnty2N2OXvNcgHB-JjrBuxHPYY,3422
|
||||
idna/compat.py,sha256=RzLy6QQCdl9784aFhb2EX9EKGCJjg0P3PilGdeXXcx8,316
|
||||
idna/core.py,sha256=YJYyAMnwiQEPjVC4-Fqu_p4CJ6yKKuDGmppBNQNQpFs,13239
|
||||
idna/idnadata.py,sha256=W30GcIGvtOWYwAjZj4ZjuouUutC6ffgNuyjJy7fZ-lo,78306
|
||||
idna/intranges.py,sha256=amUtkdhYcQG8Zr-CoMM_kVRacxkivC1WgxN1b63KKdU,1898
|
||||
idna/package_data.py,sha256=q59S3OXsc5VI8j6vSD0sGBMyk6zZ4vWFREE88yCJYKs,21
|
||||
idna/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
idna/uts46data.py,sha256=rt90K9J40gUSwppDPCrhjgi5AA6pWM65dEGRSf6rIhM,239289
|
@ -1,22 +0,0 @@
|
||||
idna-3.8.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
idna-3.8.dist-info/LICENSE.md,sha256=pZ8LDvNjWHQQmkRhykT_enDVBpboFHZ7-vch1Mmw2w8,1541
|
||||
idna-3.8.dist-info/METADATA,sha256=t8baHZrBTPkJi3Lr8ZHm0pbRKnelgO5AU7EGIeTvEcg,9948
|
||||
idna-3.8.dist-info/RECORD,,
|
||||
idna-3.8.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
|
||||
idna/__init__.py,sha256=KJQN1eQBr8iIK5SKrJ47lXvxG0BJ7Lm38W4zT0v_8lk,849
|
||||
idna/__pycache__/__init__.cpython-38.pyc,,
|
||||
idna/__pycache__/codec.cpython-38.pyc,,
|
||||
idna/__pycache__/compat.cpython-38.pyc,,
|
||||
idna/__pycache__/core.cpython-38.pyc,,
|
||||
idna/__pycache__/idnadata.cpython-38.pyc,,
|
||||
idna/__pycache__/intranges.cpython-38.pyc,,
|
||||
idna/__pycache__/package_data.cpython-38.pyc,,
|
||||
idna/__pycache__/uts46data.cpython-38.pyc,,
|
||||
idna/codec.py,sha256=PS6m-XmdST7Wj7J7ulRMakPDt5EBJyYrT3CPtjh-7t4,3426
|
||||
idna/compat.py,sha256=0_sOEUMT4CVw9doD3vyRhX80X19PwqFoUBs7gWsFME4,321
|
||||
idna/core.py,sha256=OHDXwDVbb3R1gNXjHw7JWeeE2rn2u3a-QV-KCeznYcA,12884
|
||||
idna/idnadata.py,sha256=dqRwytzkjIHMBa2R1lYvHDwACenZPt8eGVu1Y8UBE-E,78320
|
||||
idna/intranges.py,sha256=YBr4fRYuWH7kTKS2tXlFjM24ZF1Pdvcir-aywniInqg,1881
|
||||
idna/package_data.py,sha256=DogtAD5vs_-I2Q0k3_ZA4egUq2YLJ4pBbbhI8APzOcY,21
|
||||
idna/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
idna/uts46data.py,sha256=1KuksWqLuccPXm2uyRVkhfiFLNIhM_H2m4azCcnOqEU,206503
|
@ -1,4 +1,3 @@
|
||||
from .package_data import __version__
|
||||
from .core import (
|
||||
IDNABidiError,
|
||||
IDNAError,
|
||||
@ -20,8 +19,10 @@ from .core import (
|
||||
valid_string_length,
|
||||
)
|
||||
from .intranges import intranges_contain
|
||||
from .package_data import __version__
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"IDNABidiError",
|
||||
"IDNAError",
|
||||
"InvalidCodepoint",
|
||||
|
@ -1,49 +1,51 @@
|
||||
from .core import encode, decode, alabel, ulabel, IDNAError
|
||||
import codecs
|
||||
import re
|
||||
from typing import Any, Tuple, Optional
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
from .core import IDNAError, alabel, decode, encode, ulabel
|
||||
|
||||
_unicode_dots_re = re.compile("[\u002e\u3002\uff0e\uff61]")
|
||||
|
||||
_unicode_dots_re = re.compile('[\u002e\u3002\uff0e\uff61]')
|
||||
|
||||
class Codec(codecs.Codec):
|
||||
|
||||
def encode(self, data: str, errors: str = 'strict') -> Tuple[bytes, int]:
|
||||
if errors != 'strict':
|
||||
raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
|
||||
def encode(self, data: str, errors: str = "strict") -> Tuple[bytes, int]:
|
||||
if errors != "strict":
|
||||
raise IDNAError('Unsupported error handling "{}"'.format(errors))
|
||||
|
||||
if not data:
|
||||
return b"", 0
|
||||
|
||||
return encode(data), len(data)
|
||||
|
||||
def decode(self, data: bytes, errors: str = 'strict') -> Tuple[str, int]:
|
||||
if errors != 'strict':
|
||||
raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
|
||||
def decode(self, data: bytes, errors: str = "strict") -> Tuple[str, int]:
|
||||
if errors != "strict":
|
||||
raise IDNAError('Unsupported error handling "{}"'.format(errors))
|
||||
|
||||
if not data:
|
||||
return '', 0
|
||||
return "", 0
|
||||
|
||||
return decode(data), len(data)
|
||||
|
||||
|
||||
class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
|
||||
def _buffer_encode(self, data: str, errors: str, final: bool) -> Tuple[bytes, int]:
|
||||
if errors != 'strict':
|
||||
raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
|
||||
if errors != "strict":
|
||||
raise IDNAError('Unsupported error handling "{}"'.format(errors))
|
||||
|
||||
if not data:
|
||||
return b'', 0
|
||||
return b"", 0
|
||||
|
||||
labels = _unicode_dots_re.split(data)
|
||||
trailing_dot = b''
|
||||
trailing_dot = b""
|
||||
if labels:
|
||||
if not labels[-1]:
|
||||
trailing_dot = b'.'
|
||||
trailing_dot = b"."
|
||||
del labels[-1]
|
||||
elif not final:
|
||||
# Keep potentially unfinished label until the next call
|
||||
del labels[-1]
|
||||
if labels:
|
||||
trailing_dot = b'.'
|
||||
trailing_dot = b"."
|
||||
|
||||
result = []
|
||||
size = 0
|
||||
@ -54,32 +56,33 @@ class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
|
||||
size += len(label)
|
||||
|
||||
# Join with U+002E
|
||||
result_bytes = b'.'.join(result) + trailing_dot
|
||||
result_bytes = b".".join(result) + trailing_dot
|
||||
size += len(trailing_dot)
|
||||
return result_bytes, size
|
||||
|
||||
|
||||
class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
|
||||
def _buffer_decode(self, data: Any, errors: str, final: bool) -> Tuple[str, int]:
|
||||
if errors != 'strict':
|
||||
raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
|
||||
if errors != "strict":
|
||||
raise IDNAError('Unsupported error handling "{}"'.format(errors))
|
||||
|
||||
if not data:
|
||||
return ('', 0)
|
||||
return ("", 0)
|
||||
|
||||
if not isinstance(data, str):
|
||||
data = str(data, 'ascii')
|
||||
data = str(data, "ascii")
|
||||
|
||||
labels = _unicode_dots_re.split(data)
|
||||
trailing_dot = ''
|
||||
trailing_dot = ""
|
||||
if labels:
|
||||
if not labels[-1]:
|
||||
trailing_dot = '.'
|
||||
trailing_dot = "."
|
||||
del labels[-1]
|
||||
elif not final:
|
||||
# Keep potentially unfinished label until the next call
|
||||
del labels[-1]
|
||||
if labels:
|
||||
trailing_dot = '.'
|
||||
trailing_dot = "."
|
||||
|
||||
result = []
|
||||
size = 0
|
||||
@ -89,7 +92,7 @@ class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
|
||||
size += 1
|
||||
size += len(label)
|
||||
|
||||
result_str = '.'.join(result) + trailing_dot
|
||||
result_str = ".".join(result) + trailing_dot
|
||||
size += len(trailing_dot)
|
||||
return (result_str, size)
|
||||
|
||||
@ -103,7 +106,7 @@ class StreamReader(Codec, codecs.StreamReader):
|
||||
|
||||
|
||||
def search_function(name: str) -> Optional[codecs.CodecInfo]:
|
||||
if name != 'idna2008':
|
||||
if name != "idna2008":
|
||||
return None
|
||||
return codecs.CodecInfo(
|
||||
name=name,
|
||||
@ -115,4 +118,5 @@ def search_function(name: str) -> Optional[codecs.CodecInfo]:
|
||||
streamreader=StreamReader,
|
||||
)
|
||||
|
||||
|
||||
codecs.register(search_function)
|
||||
|
@ -1,13 +1,15 @@
|
||||
from .core import *
|
||||
from .codec import *
|
||||
from typing import Any, Union
|
||||
|
||||
from .core import decode, encode
|
||||
|
||||
|
||||
def ToASCII(label: str) -> bytes:
|
||||
return encode(label)
|
||||
|
||||
|
||||
def ToUnicode(label: Union[bytes, bytearray]) -> str:
|
||||
return decode(label)
|
||||
|
||||
def nameprep(s: Any) -> None:
|
||||
raise NotImplementedError('IDNA 2008 does not utilise nameprep protocol')
|
||||
|
||||
def nameprep(s: Any) -> None:
|
||||
raise NotImplementedError("IDNA 2008 does not utilise nameprep protocol")
|
||||
|
@ -1,31 +1,37 @@
|
||||
from . import idnadata
|
||||
import bisect
|
||||
import unicodedata
|
||||
import re
|
||||
from typing import Union, Optional
|
||||
import unicodedata
|
||||
from typing import Optional, Union
|
||||
|
||||
from . import idnadata
|
||||
from .intranges import intranges_contain
|
||||
|
||||
_virama_combining_class = 9
|
||||
_alabel_prefix = b'xn--'
|
||||
_unicode_dots_re = re.compile('[\u002e\u3002\uff0e\uff61]')
|
||||
_alabel_prefix = b"xn--"
|
||||
_unicode_dots_re = re.compile("[\u002e\u3002\uff0e\uff61]")
|
||||
|
||||
|
||||
class IDNAError(UnicodeError):
|
||||
""" Base exception for all IDNA-encoding related problems """
|
||||
"""Base exception for all IDNA-encoding related problems"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class IDNABidiError(IDNAError):
|
||||
""" Exception when bidirectional requirements are not satisfied """
|
||||
"""Exception when bidirectional requirements are not satisfied"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidCodepoint(IDNAError):
|
||||
""" Exception when a disallowed or unallocated codepoint is used """
|
||||
"""Exception when a disallowed or unallocated codepoint is used"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidCodepointContext(IDNAError):
|
||||
""" Exception when the codepoint is not valid in the context it is used """
|
||||
"""Exception when the codepoint is not valid in the context it is used"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -33,17 +39,20 @@ def _combining_class(cp: int) -> int:
|
||||
v = unicodedata.combining(chr(cp))
|
||||
if v == 0:
|
||||
if not unicodedata.name(chr(cp)):
|
||||
raise ValueError('Unknown character in unicodedata')
|
||||
raise ValueError("Unknown character in unicodedata")
|
||||
return v
|
||||
|
||||
|
||||
def _is_script(cp: str, script: str) -> bool:
|
||||
return intranges_contain(ord(cp), idnadata.scripts[script])
|
||||
|
||||
|
||||
def _punycode(s: str) -> bytes:
|
||||
return s.encode('punycode')
|
||||
return s.encode("punycode")
|
||||
|
||||
|
||||
def _unot(s: int) -> str:
|
||||
return 'U+{:04X}'.format(s)
|
||||
return "U+{:04X}".format(s)
|
||||
|
||||
|
||||
def valid_label_length(label: Union[bytes, str]) -> bool:
|
||||
@ -61,96 +70,106 @@ def valid_string_length(label: Union[bytes, str], trailing_dot: bool) -> bool:
|
||||
def check_bidi(label: str, check_ltr: bool = False) -> bool:
|
||||
# Bidi rules should only be applied if string contains RTL characters
|
||||
bidi_label = False
|
||||
for (idx, cp) in enumerate(label, 1):
|
||||
for idx, cp in enumerate(label, 1):
|
||||
direction = unicodedata.bidirectional(cp)
|
||||
if direction == '':
|
||||
if direction == "":
|
||||
# String likely comes from a newer version of Unicode
|
||||
raise IDNABidiError('Unknown directionality in label {} at position {}'.format(repr(label), idx))
|
||||
if direction in ['R', 'AL', 'AN']:
|
||||
raise IDNABidiError("Unknown directionality in label {} at position {}".format(repr(label), idx))
|
||||
if direction in ["R", "AL", "AN"]:
|
||||
bidi_label = True
|
||||
if not bidi_label and not check_ltr:
|
||||
return True
|
||||
|
||||
# Bidi rule 1
|
||||
direction = unicodedata.bidirectional(label[0])
|
||||
if direction in ['R', 'AL']:
|
||||
if direction in ["R", "AL"]:
|
||||
rtl = True
|
||||
elif direction == 'L':
|
||||
elif direction == "L":
|
||||
rtl = False
|
||||
else:
|
||||
raise IDNABidiError('First codepoint in label {} must be directionality L, R or AL'.format(repr(label)))
|
||||
raise IDNABidiError("First codepoint in label {} must be directionality L, R or AL".format(repr(label)))
|
||||
|
||||
valid_ending = False
|
||||
number_type = None # type: Optional[str]
|
||||
for (idx, cp) in enumerate(label, 1):
|
||||
number_type: Optional[str] = None
|
||||
for idx, cp in enumerate(label, 1):
|
||||
direction = unicodedata.bidirectional(cp)
|
||||
|
||||
if rtl:
|
||||
# Bidi rule 2
|
||||
if not direction in ['R', 'AL', 'AN', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM']:
|
||||
raise IDNABidiError('Invalid direction for codepoint at position {} in a right-to-left label'.format(idx))
|
||||
if direction not in [
|
||||
"R",
|
||||
"AL",
|
||||
"AN",
|
||||
"EN",
|
||||
"ES",
|
||||
"CS",
|
||||
"ET",
|
||||
"ON",
|
||||
"BN",
|
||||
"NSM",
|
||||
]:
|
||||
raise IDNABidiError("Invalid direction for codepoint at position {} in a right-to-left label".format(idx))
|
||||
# Bidi rule 3
|
||||
if direction in ['R', 'AL', 'EN', 'AN']:
|
||||
if direction in ["R", "AL", "EN", "AN"]:
|
||||
valid_ending = True
|
||||
elif direction != 'NSM':
|
||||
elif direction != "NSM":
|
||||
valid_ending = False
|
||||
# Bidi rule 4
|
||||
if direction in ['AN', 'EN']:
|
||||
if direction in ["AN", "EN"]:
|
||||
if not number_type:
|
||||
number_type = direction
|
||||
else:
|
||||
if number_type != direction:
|
||||
raise IDNABidiError('Can not mix numeral types in a right-to-left label')
|
||||
raise IDNABidiError("Can not mix numeral types in a right-to-left label")
|
||||
else:
|
||||
# Bidi rule 5
|
||||
if not direction in ['L', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM']:
|
||||
raise IDNABidiError('Invalid direction for codepoint at position {} in a left-to-right label'.format(idx))
|
||||
if direction not in ["L", "EN", "ES", "CS", "ET", "ON", "BN", "NSM"]:
|
||||
raise IDNABidiError("Invalid direction for codepoint at position {} in a left-to-right label".format(idx))
|
||||
# Bidi rule 6
|
||||
if direction in ['L', 'EN']:
|
||||
if direction in ["L", "EN"]:
|
||||
valid_ending = True
|
||||
elif direction != 'NSM':
|
||||
elif direction != "NSM":
|
||||
valid_ending = False
|
||||
|
||||
if not valid_ending:
|
||||
raise IDNABidiError('Label ends with illegal codepoint directionality')
|
||||
raise IDNABidiError("Label ends with illegal codepoint directionality")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def check_initial_combiner(label: str) -> bool:
|
||||
if unicodedata.category(label[0])[0] == 'M':
|
||||
raise IDNAError('Label begins with an illegal combining character')
|
||||
if unicodedata.category(label[0])[0] == "M":
|
||||
raise IDNAError("Label begins with an illegal combining character")
|
||||
return True
|
||||
|
||||
|
||||
def check_hyphen_ok(label: str) -> bool:
|
||||
if label[2:4] == '--':
|
||||
raise IDNAError('Label has disallowed hyphens in 3rd and 4th position')
|
||||
if label[0] == '-' or label[-1] == '-':
|
||||
raise IDNAError('Label must not start or end with a hyphen')
|
||||
if label[2:4] == "--":
|
||||
raise IDNAError("Label has disallowed hyphens in 3rd and 4th position")
|
||||
if label[0] == "-" or label[-1] == "-":
|
||||
raise IDNAError("Label must not start or end with a hyphen")
|
||||
return True
|
||||
|
||||
|
||||
def check_nfc(label: str) -> None:
|
||||
if unicodedata.normalize('NFC', label) != label:
|
||||
raise IDNAError('Label must be in Normalization Form C')
|
||||
if unicodedata.normalize("NFC", label) != label:
|
||||
raise IDNAError("Label must be in Normalization Form C")
|
||||
|
||||
|
||||
def valid_contextj(label: str, pos: int) -> bool:
|
||||
cp_value = ord(label[pos])
|
||||
|
||||
if cp_value == 0x200c:
|
||||
|
||||
if cp_value == 0x200C:
|
||||
if pos > 0:
|
||||
if _combining_class(ord(label[pos - 1])) == _virama_combining_class:
|
||||
return True
|
||||
|
||||
ok = False
|
||||
for i in range(pos-1, -1, -1):
|
||||
for i in range(pos - 1, -1, -1):
|
||||
joining_type = idnadata.joining_types.get(ord(label[i]))
|
||||
if joining_type == ord('T'):
|
||||
if joining_type == ord("T"):
|
||||
continue
|
||||
elif joining_type in [ord('L'), ord('D')]:
|
||||
elif joining_type in [ord("L"), ord("D")]:
|
||||
ok = True
|
||||
break
|
||||
else:
|
||||
@ -160,63 +179,61 @@ def valid_contextj(label: str, pos: int) -> bool:
|
||||
return False
|
||||
|
||||
ok = False
|
||||
for i in range(pos+1, len(label)):
|
||||
for i in range(pos + 1, len(label)):
|
||||
joining_type = idnadata.joining_types.get(ord(label[i]))
|
||||
if joining_type == ord('T'):
|
||||
if joining_type == ord("T"):
|
||||
continue
|
||||
elif joining_type in [ord('R'), ord('D')]:
|
||||
elif joining_type in [ord("R"), ord("D")]:
|
||||
ok = True
|
||||
break
|
||||
else:
|
||||
break
|
||||
return ok
|
||||
|
||||
if cp_value == 0x200d:
|
||||
|
||||
if cp_value == 0x200D:
|
||||
if pos > 0:
|
||||
if _combining_class(ord(label[pos - 1])) == _virama_combining_class:
|
||||
return True
|
||||
return False
|
||||
|
||||
else:
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def valid_contexto(label: str, pos: int, exception: bool = False) -> bool:
|
||||
cp_value = ord(label[pos])
|
||||
|
||||
if cp_value == 0x00b7:
|
||||
if 0 < pos < len(label)-1:
|
||||
if ord(label[pos - 1]) == 0x006c and ord(label[pos + 1]) == 0x006c:
|
||||
if cp_value == 0x00B7:
|
||||
if 0 < pos < len(label) - 1:
|
||||
if ord(label[pos - 1]) == 0x006C and ord(label[pos + 1]) == 0x006C:
|
||||
return True
|
||||
return False
|
||||
|
||||
elif cp_value == 0x0375:
|
||||
if pos < len(label)-1 and len(label) > 1:
|
||||
return _is_script(label[pos + 1], 'Greek')
|
||||
if pos < len(label) - 1 and len(label) > 1:
|
||||
return _is_script(label[pos + 1], "Greek")
|
||||
return False
|
||||
|
||||
elif cp_value == 0x05f3 or cp_value == 0x05f4:
|
||||
elif cp_value == 0x05F3 or cp_value == 0x05F4:
|
||||
if pos > 0:
|
||||
return _is_script(label[pos - 1], 'Hebrew')
|
||||
return _is_script(label[pos - 1], "Hebrew")
|
||||
return False
|
||||
|
||||
elif cp_value == 0x30fb:
|
||||
elif cp_value == 0x30FB:
|
||||
for cp in label:
|
||||
if cp == '\u30fb':
|
||||
if cp == "\u30fb":
|
||||
continue
|
||||
if _is_script(cp, 'Hiragana') or _is_script(cp, 'Katakana') or _is_script(cp, 'Han'):
|
||||
if _is_script(cp, "Hiragana") or _is_script(cp, "Katakana") or _is_script(cp, "Han"):
|
||||
return True
|
||||
return False
|
||||
|
||||
elif 0x660 <= cp_value <= 0x669:
|
||||
for cp in label:
|
||||
if 0x6f0 <= ord(cp) <= 0x06f9:
|
||||
if 0x6F0 <= ord(cp) <= 0x06F9:
|
||||
return False
|
||||
return True
|
||||
|
||||
elif 0x6f0 <= cp_value <= 0x6f9:
|
||||
elif 0x6F0 <= cp_value <= 0x6F9:
|
||||
for cp in label:
|
||||
if 0x660 <= ord(cp) <= 0x0669:
|
||||
return False
|
||||
@ -227,41 +244,49 @@ def valid_contexto(label: str, pos: int, exception: bool = False) -> bool:
|
||||
|
||||
def check_label(label: Union[str, bytes, bytearray]) -> None:
|
||||
if isinstance(label, (bytes, bytearray)):
|
||||
label = label.decode('utf-8')
|
||||
label = label.decode("utf-8")
|
||||
if len(label) == 0:
|
||||
raise IDNAError('Empty Label')
|
||||
raise IDNAError("Empty Label")
|
||||
|
||||
check_nfc(label)
|
||||
check_hyphen_ok(label)
|
||||
check_initial_combiner(label)
|
||||
|
||||
for (pos, cp) in enumerate(label):
|
||||
for pos, cp in enumerate(label):
|
||||
cp_value = ord(cp)
|
||||
if intranges_contain(cp_value, idnadata.codepoint_classes['PVALID']):
|
||||
if intranges_contain(cp_value, idnadata.codepoint_classes["PVALID"]):
|
||||
continue
|
||||
elif intranges_contain(cp_value, idnadata.codepoint_classes['CONTEXTJ']):
|
||||
elif intranges_contain(cp_value, idnadata.codepoint_classes["CONTEXTJ"]):
|
||||
try:
|
||||
if not valid_contextj(label, pos):
|
||||
raise InvalidCodepointContext('Joiner {} not allowed at position {} in {}'.format(
|
||||
_unot(cp_value), pos+1, repr(label)))
|
||||
raise InvalidCodepointContext(
|
||||
"Joiner {} not allowed at position {} in {}".format(_unot(cp_value), pos + 1, repr(label))
|
||||
)
|
||||
except ValueError:
|
||||
raise IDNAError('Unknown codepoint adjacent to joiner {} at position {} in {}'.format(
|
||||
_unot(cp_value), pos+1, repr(label)))
|
||||
elif intranges_contain(cp_value, idnadata.codepoint_classes['CONTEXTO']):
|
||||
raise IDNAError(
|
||||
"Unknown codepoint adjacent to joiner {} at position {} in {}".format(
|
||||
_unot(cp_value), pos + 1, repr(label)
|
||||
)
|
||||
)
|
||||
elif intranges_contain(cp_value, idnadata.codepoint_classes["CONTEXTO"]):
|
||||
if not valid_contexto(label, pos):
|
||||
raise InvalidCodepointContext('Codepoint {} not allowed at position {} in {}'.format(_unot(cp_value), pos+1, repr(label)))
|
||||
raise InvalidCodepointContext(
|
||||
"Codepoint {} not allowed at position {} in {}".format(_unot(cp_value), pos + 1, repr(label))
|
||||
)
|
||||
else:
|
||||
raise InvalidCodepoint('Codepoint {} at position {} of {} not allowed'.format(_unot(cp_value), pos+1, repr(label)))
|
||||
raise InvalidCodepoint(
|
||||
"Codepoint {} at position {} of {} not allowed".format(_unot(cp_value), pos + 1, repr(label))
|
||||
)
|
||||
|
||||
check_bidi(label)
|
||||
|
||||
|
||||
def alabel(label: str) -> bytes:
|
||||
try:
|
||||
label_bytes = label.encode('ascii')
|
||||
label_bytes = label.encode("ascii")
|
||||
ulabel(label_bytes)
|
||||
if not valid_label_length(label_bytes):
|
||||
raise IDNAError('Label too long')
|
||||
raise IDNAError("Label too long")
|
||||
return label_bytes
|
||||
except UnicodeEncodeError:
|
||||
pass
|
||||
@ -270,7 +295,7 @@ def alabel(label: str) -> bytes:
|
||||
label_bytes = _alabel_prefix + _punycode(label)
|
||||
|
||||
if not valid_label_length(label_bytes):
|
||||
raise IDNAError('Label too long')
|
||||
raise IDNAError("Label too long")
|
||||
|
||||
return label_bytes
|
||||
|
||||
@ -278,7 +303,7 @@ def alabel(label: str) -> bytes:
|
||||
def ulabel(label: Union[str, bytes, bytearray]) -> str:
|
||||
if not isinstance(label, (bytes, bytearray)):
|
||||
try:
|
||||
label_bytes = label.encode('ascii')
|
||||
label_bytes = label.encode("ascii")
|
||||
except UnicodeEncodeError:
|
||||
check_label(label)
|
||||
return label
|
||||
@ -287,19 +312,19 @@ def ulabel(label: Union[str, bytes, bytearray]) -> str:
|
||||
|
||||
label_bytes = label_bytes.lower()
|
||||
if label_bytes.startswith(_alabel_prefix):
|
||||
label_bytes = label_bytes[len(_alabel_prefix):]
|
||||
label_bytes = label_bytes[len(_alabel_prefix) :]
|
||||
if not label_bytes:
|
||||
raise IDNAError('Malformed A-label, no Punycode eligible content found')
|
||||
if label_bytes.decode('ascii')[-1] == '-':
|
||||
raise IDNAError('A-label must not end with a hyphen')
|
||||
raise IDNAError("Malformed A-label, no Punycode eligible content found")
|
||||
if label_bytes.decode("ascii")[-1] == "-":
|
||||
raise IDNAError("A-label must not end with a hyphen")
|
||||
else:
|
||||
check_label(label_bytes)
|
||||
return label_bytes.decode('ascii')
|
||||
return label_bytes.decode("ascii")
|
||||
|
||||
try:
|
||||
label = label_bytes.decode('punycode')
|
||||
label = label_bytes.decode("punycode")
|
||||
except UnicodeError:
|
||||
raise IDNAError('Invalid A-label')
|
||||
raise IDNAError("Invalid A-label")
|
||||
check_label(label)
|
||||
return label
|
||||
|
||||
@ -307,52 +332,60 @@ def ulabel(label: Union[str, bytes, bytearray]) -> str:
|
||||
def uts46_remap(domain: str, std3_rules: bool = True, transitional: bool = False) -> str:
|
||||
"""Re-map the characters in the string according to UTS46 processing."""
|
||||
from .uts46data import uts46data
|
||||
output = ''
|
||||
|
||||
output = ""
|
||||
|
||||
for pos, char in enumerate(domain):
|
||||
code_point = ord(char)
|
||||
try:
|
||||
uts46row = uts46data[code_point if code_point < 256 else
|
||||
bisect.bisect_left(uts46data, (code_point, 'Z')) - 1]
|
||||
uts46row = uts46data[code_point if code_point < 256 else bisect.bisect_left(uts46data, (code_point, "Z")) - 1]
|
||||
status = uts46row[1]
|
||||
replacement = None # type: Optional[str]
|
||||
replacement: Optional[str] = None
|
||||
if len(uts46row) == 3:
|
||||
replacement = uts46row[2]
|
||||
if (status == 'V' or
|
||||
(status == 'D' and not transitional) or
|
||||
(status == '3' and not std3_rules and replacement is None)):
|
||||
if (
|
||||
status == "V"
|
||||
or (status == "D" and not transitional)
|
||||
or (status == "3" and not std3_rules and replacement is None)
|
||||
):
|
||||
output += char
|
||||
elif replacement is not None and (status == 'M' or
|
||||
(status == '3' and not std3_rules) or
|
||||
(status == 'D' and transitional)):
|
||||
elif replacement is not None and (
|
||||
status == "M" or (status == "3" and not std3_rules) or (status == "D" and transitional)
|
||||
):
|
||||
output += replacement
|
||||
elif status != 'I':
|
||||
elif status != "I":
|
||||
raise IndexError()
|
||||
except IndexError:
|
||||
raise InvalidCodepoint(
|
||||
'Codepoint {} not allowed at position {} in {}'.format(
|
||||
_unot(code_point), pos + 1, repr(domain)))
|
||||
"Codepoint {} not allowed at position {} in {}".format(_unot(code_point), pos + 1, repr(domain))
|
||||
)
|
||||
|
||||
return unicodedata.normalize('NFC', output)
|
||||
return unicodedata.normalize("NFC", output)
|
||||
|
||||
|
||||
def encode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool = False, std3_rules: bool = False, transitional: bool = False) -> bytes:
|
||||
def encode(
|
||||
s: Union[str, bytes, bytearray],
|
||||
strict: bool = False,
|
||||
uts46: bool = False,
|
||||
std3_rules: bool = False,
|
||||
transitional: bool = False,
|
||||
) -> bytes:
|
||||
if not isinstance(s, str):
|
||||
try:
|
||||
s = str(s, 'ascii')
|
||||
s = str(s, "ascii")
|
||||
except UnicodeDecodeError:
|
||||
raise IDNAError('should pass a unicode string to the function rather than a byte string.')
|
||||
raise IDNAError("should pass a unicode string to the function rather than a byte string.")
|
||||
if uts46:
|
||||
s = uts46_remap(s, std3_rules, transitional)
|
||||
trailing_dot = False
|
||||
result = []
|
||||
if strict:
|
||||
labels = s.split('.')
|
||||
labels = s.split(".")
|
||||
else:
|
||||
labels = _unicode_dots_re.split(s)
|
||||
if not labels or labels == ['']:
|
||||
raise IDNAError('Empty domain')
|
||||
if labels[-1] == '':
|
||||
if not labels or labels == [""]:
|
||||
raise IDNAError("Empty domain")
|
||||
if labels[-1] == "":
|
||||
del labels[-1]
|
||||
trailing_dot = True
|
||||
for label in labels:
|
||||
@ -360,21 +393,26 @@ def encode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool =
|
||||
if s:
|
||||
result.append(s)
|
||||
else:
|
||||
raise IDNAError('Empty label')
|
||||
raise IDNAError("Empty label")
|
||||
if trailing_dot:
|
||||
result.append(b'')
|
||||
s = b'.'.join(result)
|
||||
result.append(b"")
|
||||
s = b".".join(result)
|
||||
if not valid_string_length(s, trailing_dot):
|
||||
raise IDNAError('Domain too long')
|
||||
raise IDNAError("Domain too long")
|
||||
return s
|
||||
|
||||
|
||||
def decode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool = False, std3_rules: bool = False) -> str:
|
||||
def decode(
|
||||
s: Union[str, bytes, bytearray],
|
||||
strict: bool = False,
|
||||
uts46: bool = False,
|
||||
std3_rules: bool = False,
|
||||
) -> str:
|
||||
try:
|
||||
if not isinstance(s, str):
|
||||
s = str(s, 'ascii')
|
||||
s = str(s, "ascii")
|
||||
except UnicodeDecodeError:
|
||||
raise IDNAError('Invalid ASCII in A-label')
|
||||
raise IDNAError("Invalid ASCII in A-label")
|
||||
if uts46:
|
||||
s = uts46_remap(s, std3_rules, False)
|
||||
trailing_dot = False
|
||||
@ -382,9 +420,9 @@ def decode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool =
|
||||
if not strict:
|
||||
labels = _unicode_dots_re.split(s)
|
||||
else:
|
||||
labels = s.split('.')
|
||||
if not labels or labels == ['']:
|
||||
raise IDNAError('Empty domain')
|
||||
labels = s.split(".")
|
||||
if not labels or labels == [""]:
|
||||
raise IDNAError("Empty domain")
|
||||
if not labels[-1]:
|
||||
del labels[-1]
|
||||
trailing_dot = True
|
||||
@ -393,7 +431,7 @@ def decode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool =
|
||||
if s:
|
||||
result.append(s)
|
||||
else:
|
||||
raise IDNAError('Empty label')
|
||||
raise IDNAError("Empty label")
|
||||
if trailing_dot:
|
||||
result.append('')
|
||||
return '.'.join(result)
|
||||
result.append("")
|
||||
return ".".join(result)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -8,6 +8,7 @@ in the original list?" in time O(log(# runs)).
|
||||
import bisect
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
def intranges_from_list(list_: List[int]) -> Tuple[int, ...]:
|
||||
"""Represent a list of integers as a sequence of ranges:
|
||||
((start_0, end_0), (start_1, end_1), ...), such that the original
|
||||
@ -20,18 +21,20 @@ def intranges_from_list(list_: List[int]) -> Tuple[int, ...]:
|
||||
ranges = []
|
||||
last_write = -1
|
||||
for i in range(len(sorted_list)):
|
||||
if i+1 < len(sorted_list):
|
||||
if sorted_list[i] == sorted_list[i+1]-1:
|
||||
if i + 1 < len(sorted_list):
|
||||
if sorted_list[i] == sorted_list[i + 1] - 1:
|
||||
continue
|
||||
current_range = sorted_list[last_write+1:i+1]
|
||||
current_range = sorted_list[last_write + 1 : i + 1]
|
||||
ranges.append(_encode_range(current_range[0], current_range[-1] + 1))
|
||||
last_write = i
|
||||
|
||||
return tuple(ranges)
|
||||
|
||||
|
||||
def _encode_range(start: int, end: int) -> int:
|
||||
return (start << 32) | end
|
||||
|
||||
|
||||
def _decode_range(r: int) -> Tuple[int, int]:
|
||||
return (r >> 32), (r & ((1 << 32) - 1))
|
||||
|
||||
@ -43,7 +46,7 @@ def intranges_contain(int_: int, ranges: Tuple[int, ...]) -> bool:
|
||||
# we could be immediately ahead of a tuple (start, end)
|
||||
# with start < int_ <= end
|
||||
if pos > 0:
|
||||
left, right = _decode_range(ranges[pos-1])
|
||||
left, right = _decode_range(ranges[pos - 1])
|
||||
if left <= int_ < right:
|
||||
return True
|
||||
# or we could be immediately behind a tuple (int_, end)
|
||||
|
@ -1,2 +1 @@
|
||||
__version__ = '3.8'
|
||||
|
||||
__version__ = "3.10"
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -2,11 +2,20 @@
|
||||
__author__ = "Andrew Dunham"
|
||||
__license__ = "Apache"
|
||||
__copyright__ = "Copyright (c) 2012-2013, Andrew Dunham"
|
||||
__version__ = "0.0.9"
|
||||
__version__ = "0.0.10"
|
||||
|
||||
from .multipart import FormParser, MultipartParser, OctetStreamParser, QuerystringParser, create_form_parser, parse_form
|
||||
from .multipart import (
|
||||
BaseParser,
|
||||
FormParser,
|
||||
MultipartParser,
|
||||
OctetStreamParser,
|
||||
QuerystringParser,
|
||||
create_form_parser,
|
||||
parse_form,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"BaseParser",
|
||||
"FormParser",
|
||||
"MultipartParser",
|
||||
"OctetStreamParser",
|
||||
|
@ -1,5 +1,6 @@
|
||||
import base64
|
||||
import binascii
|
||||
from io import BufferedWriter
|
||||
|
||||
from .exceptions import DecodeError
|
||||
|
||||
@ -33,11 +34,11 @@ class Base64Decoder:
|
||||
:param underlying: the underlying object to pass writes to
|
||||
"""
|
||||
|
||||
def __init__(self, underlying):
|
||||
def __init__(self, underlying: BufferedWriter):
|
||||
self.cache = bytearray()
|
||||
self.underlying = underlying
|
||||
|
||||
def write(self, data):
|
||||
def write(self, data: bytes) -> int:
|
||||
"""Takes any input data provided, decodes it as base64, and passes it
|
||||
on to the underlying object. If the data provided is invalid base64
|
||||
data, then this method will raise
|
||||
@ -73,14 +74,14 @@ class Base64Decoder:
|
||||
# Return the length of the data to indicate no error.
|
||||
return len(data)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Close this decoder. If the underlying object has a `close()`
|
||||
method, this function will call it.
|
||||
"""
|
||||
if hasattr(self.underlying, "close"):
|
||||
self.underlying.close()
|
||||
|
||||
def finalize(self):
|
||||
def finalize(self) -> None:
|
||||
"""Finalize this object. This should be called when no more data
|
||||
should be written to the stream. This function can raise a
|
||||
:class:`multipart.exceptions.DecodeError` if there is some remaining
|
||||
@ -97,7 +98,7 @@ class Base64Decoder:
|
||||
if hasattr(self.underlying, "finalize"):
|
||||
self.underlying.finalize()
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(underlying={self.underlying!r})"
|
||||
|
||||
|
||||
@ -111,11 +112,11 @@ class QuotedPrintableDecoder:
|
||||
:param underlying: the underlying object to pass writes to
|
||||
"""
|
||||
|
||||
def __init__(self, underlying):
|
||||
def __init__(self, underlying: BufferedWriter) -> None:
|
||||
self.cache = b""
|
||||
self.underlying = underlying
|
||||
|
||||
def write(self, data):
|
||||
def write(self, data: bytes) -> int:
|
||||
"""Takes any input data provided, decodes it as quoted-printable, and
|
||||
passes it on to the underlying object.
|
||||
|
||||
@ -142,14 +143,14 @@ class QuotedPrintableDecoder:
|
||||
self.cache = rest
|
||||
return len(data)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Close this decoder. If the underlying object has a `close()`
|
||||
method, this function will call it.
|
||||
"""
|
||||
if hasattr(self.underlying, "close"):
|
||||
self.underlying.close()
|
||||
|
||||
def finalize(self):
|
||||
def finalize(self) -> None:
|
||||
"""Finalize this object. This should be called when no more data
|
||||
should be written to the stream. This function will not raise any
|
||||
exceptions, but it may write more data to the underlying object if
|
||||
@ -167,5 +168,5 @@ class QuotedPrintableDecoder:
|
||||
if hasattr(self.underlying, "finalize"):
|
||||
self.underlying.finalize()
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(underlying={self.underlying!r})"
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -3,7 +3,7 @@ pydantic-1.10.14.dist-info/LICENSE,sha256=njlGaQrIi2tz6PABoFhq8TVovohS_VFOQ5Pzl2
|
||||
pydantic-1.10.14.dist-info/METADATA,sha256=lYLXr7lOF7BMRkokRnhTeJODH7dv_zlYVTKfw6M_rNg,150189
|
||||
pydantic-1.10.14.dist-info/RECORD,,
|
||||
pydantic-1.10.14.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
pydantic-1.10.14.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
|
||||
pydantic-1.10.14.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
||||
pydantic-1.10.14.dist-info/entry_points.txt,sha256=EquH5n3pilIXg-LLa1K4evpu5-6dnvxzi6vwvkoAMns,45
|
||||
pydantic-1.10.14.dist-info/top_level.txt,sha256=cmo_5n0F_YY5td5nPZBfdjBENkmGg_pE5ShWXYbXxTM,9
|
||||
pydantic/__init__.py,sha256=iTu8CwWWvn6zM_zYJtqhie24PImW25zokitz_06kDYw,2771
|
||||
|
@ -1,5 +1,5 @@
|
||||
Wheel-Version: 1.0
|
||||
Generator: setuptools (73.0.1)
|
||||
Generator: setuptools (74.1.2)
|
||||
Root-Is-Purelib: true
|
||||
Tag: py3-none-any
|
||||
|
||||
|
49
site-packages/python_multipart-0.0.10.dist-info/METADATA
Normal file
49
site-packages/python_multipart-0.0.10.dist-info/METADATA
Normal file
@ -0,0 +1,49 @@
|
||||
Metadata-Version: 2.3
|
||||
Name: python-multipart
|
||||
Version: 0.0.10
|
||||
Summary: A streaming multipart parser for Python
|
||||
Project-URL: Homepage, https://github.com/Kludex/python-multipart
|
||||
Project-URL: Documentation, https://kludex.github.io/python-multipart/
|
||||
Project-URL: Changelog, https://github.com/Kludex/python-multipart/blob/master/CHANGELOG.md
|
||||
Project-URL: Source, https://github.com/Kludex/python-multipart
|
||||
Author-email: Andrew Dunham <andrew@du.nham.ca>, Marcelo Trylesinski <marcelotryle@gmail.com>
|
||||
License-Expression: Apache-2.0
|
||||
License-File: LICENSE.txt
|
||||
Classifier: Development Status :: 5 - Production/Stable
|
||||
Classifier: Environment :: Web Environment
|
||||
Classifier: Intended Audience :: Developers
|
||||
Classifier: License :: OSI Approved :: Apache Software License
|
||||
Classifier: Operating System :: OS Independent
|
||||
Classifier: Programming Language :: Python :: 3
|
||||
Classifier: Programming Language :: Python :: 3 :: Only
|
||||
Classifier: Programming Language :: Python :: 3.8
|
||||
Classifier: Programming Language :: Python :: 3.9
|
||||
Classifier: Programming Language :: Python :: 3.10
|
||||
Classifier: Programming Language :: Python :: 3.11
|
||||
Classifier: Programming Language :: Python :: 3.12
|
||||
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
||||
Requires-Python: >=3.8
|
||||
Description-Content-Type: text/markdown
|
||||
|
||||
# [Python-Multipart](https://kludex.github.io/python-multipart/)
|
||||
|
||||
[](https://pypi.python.org/pypi/python-multipart)
|
||||
[](https://pypi.org/project/python-multipart)
|
||||
|
||||
---
|
||||
|
||||
`python-multipart` is an Apache2-licensed streaming multipart parser for Python.
|
||||
Test coverage is currently 100%.
|
||||
|
||||
## Why?
|
||||
|
||||
Because streaming uploads are awesome for large files.
|
||||
|
||||
## How to Test
|
||||
|
||||
If you want to test:
|
||||
|
||||
```bash
|
||||
$ pip install '.[dev]'
|
||||
$ inv test
|
||||
```
|
13
site-packages/python_multipart-0.0.10.dist-info/RECORD
Normal file
13
site-packages/python_multipart-0.0.10.dist-info/RECORD
Normal file
@ -0,0 +1,13 @@
|
||||
multipart/__init__.py,sha256=q2w9JPTlzBOOBSIPOIKi0X7xEUByFKhbYjtPandI3c8,512
|
||||
multipart/__pycache__/__init__.cpython-38.pyc,,
|
||||
multipart/__pycache__/decoders.cpython-38.pyc,,
|
||||
multipart/__pycache__/exceptions.cpython-38.pyc,,
|
||||
multipart/__pycache__/multipart.cpython-38.pyc,,
|
||||
multipart/decoders.py,sha256=5y_47RFs3mPVJbePcNe1__nGdQche0PghI_UsIKz3po,6182
|
||||
multipart/exceptions.py,sha256=a9buSOv_eiHZoukEJhdWX9LJYSJ6t7XOK3ZEaWoQZlk,992
|
||||
multipart/multipart.py,sha256=2hXvVTSEJ7PZOXlTMnxSVhlhT535QNT6neZVmgt3Beg,73780
|
||||
python_multipart-0.0.10.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
python_multipart-0.0.10.dist-info/METADATA,sha256=pEXCzLDD0MPgQLffeich5yjgf87NxMSAkA5r_7FZ_nQ,1902
|
||||
python_multipart-0.0.10.dist-info/RECORD,,
|
||||
python_multipart-0.0.10.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
||||
python_multipart-0.0.10.dist-info/licenses/LICENSE.txt,sha256=qOgzF2zWF9rwC51tOfoVyo7evG0WQwec0vSJPAwom-I,556
|
@ -1,4 +1,4 @@
|
||||
Wheel-Version: 1.0
|
||||
Generator: hatchling 1.21.1
|
||||
Generator: hatchling 1.25.0
|
||||
Root-Is-Purelib: true
|
||||
Tag: py3-none-any
|
@ -1,70 +0,0 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: python-multipart
|
||||
Version: 0.0.9
|
||||
Summary: A streaming multipart parser for Python
|
||||
Project-URL: Homepage, https://github.com/andrew-d/python-multipart
|
||||
Project-URL: Documentation, https://andrew-d.github.io/python-multipart/
|
||||
Project-URL: Changelog, https://github.com/andrew-d/python-multipart/blob/master/CHANGELOG.md
|
||||
Project-URL: Source, https://github.com/andrew-d/python-multipart
|
||||
Author-email: Andrew Dunham <andrew@du.nham.ca>
|
||||
License-Expression: Apache-2.0
|
||||
License-File: LICENSE.txt
|
||||
Classifier: Development Status :: 5 - Production/Stable
|
||||
Classifier: Environment :: Web Environment
|
||||
Classifier: Intended Audience :: Developers
|
||||
Classifier: License :: OSI Approved :: Apache Software License
|
||||
Classifier: Operating System :: OS Independent
|
||||
Classifier: Programming Language :: Python :: 3
|
||||
Classifier: Programming Language :: Python :: 3 :: Only
|
||||
Classifier: Programming Language :: Python :: 3.8
|
||||
Classifier: Programming Language :: Python :: 3.9
|
||||
Classifier: Programming Language :: Python :: 3.10
|
||||
Classifier: Programming Language :: Python :: 3.11
|
||||
Classifier: Programming Language :: Python :: 3.12
|
||||
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
||||
Requires-Python: >=3.8
|
||||
Provides-Extra: dev
|
||||
Requires-Dist: atomicwrites==1.4.1; extra == 'dev'
|
||||
Requires-Dist: attrs==23.2.0; extra == 'dev'
|
||||
Requires-Dist: coverage==7.4.1; extra == 'dev'
|
||||
Requires-Dist: hatch; extra == 'dev'
|
||||
Requires-Dist: invoke==2.2.0; extra == 'dev'
|
||||
Requires-Dist: more-itertools==10.2.0; extra == 'dev'
|
||||
Requires-Dist: pbr==6.0.0; extra == 'dev'
|
||||
Requires-Dist: pluggy==1.4.0; extra == 'dev'
|
||||
Requires-Dist: py==1.11.0; extra == 'dev'
|
||||
Requires-Dist: pytest-cov==4.1.0; extra == 'dev'
|
||||
Requires-Dist: pytest-timeout==2.2.0; extra == 'dev'
|
||||
Requires-Dist: pytest==8.0.0; extra == 'dev'
|
||||
Requires-Dist: pyyaml==6.0.1; extra == 'dev'
|
||||
Requires-Dist: ruff==0.2.1; extra == 'dev'
|
||||
Description-Content-Type: text/x-rst
|
||||
|
||||
==================
|
||||
Python-Multipart
|
||||
==================
|
||||
|
||||
.. image:: https://github.com/andrew-d/python-multipart/actions/workflows/test.yaml/badge.svg
|
||||
:target: https://github.com/andrew-d/python-multipart/actions
|
||||
|
||||
|
||||
python-multipart is an Apache2 licensed streaming multipart parser for Python.
|
||||
Test coverage is currently 100%.
|
||||
Documentation is available `here`_.
|
||||
|
||||
.. _here: https://andrew-d.github.io/python-multipart/
|
||||
|
||||
Why?
|
||||
----
|
||||
|
||||
Because streaming uploads are awesome for large files.
|
||||
|
||||
How to Test
|
||||
-----------
|
||||
|
||||
If you want to test:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ pip install '.[dev]'
|
||||
$ inv test
|
@ -1,13 +0,0 @@
|
||||
multipart/__init__.py,sha256=Z_EnZFoG_zmw22n7BomPmnVSCCl4XVNDYVWL8hry6Sc,448
|
||||
multipart/__pycache__/__init__.cpython-38.pyc,,
|
||||
multipart/__pycache__/decoders.cpython-38.pyc,,
|
||||
multipart/__pycache__/exceptions.cpython-38.pyc,,
|
||||
multipart/__pycache__/multipart.cpython-38.pyc,,
|
||||
multipart/decoders.py,sha256=A4SQHOwFRNzCfr5Fx0iOYpS8USTxE9ofYbL9kOJywHs,6038
|
||||
multipart/exceptions.py,sha256=a9buSOv_eiHZoukEJhdWX9LJYSJ6t7XOK3ZEaWoQZlk,992
|
||||
multipart/multipart.py,sha256=sThvJ7TSPQc1tjCfXMgqQduT3yKsmxhwzNZnmL5S7KY,72841
|
||||
python_multipart-0.0.9.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
python_multipart-0.0.9.dist-info/METADATA,sha256=5BDtS_h0qAxUX55VWiuCb7FZj6WZRd51l6YenNy4Its,2528
|
||||
python_multipart-0.0.9.dist-info/RECORD,,
|
||||
python_multipart-0.0.9.dist-info/WHEEL,sha256=TJPnKdtrSue7xZ_AVGkp9YXcvDrobsjBds1du3Nx6dc,87
|
||||
python_multipart-0.0.9.dist-info/licenses/LICENSE.txt,sha256=qOgzF2zWF9rwC51tOfoVyo7evG0WQwec0vSJPAwom-I,556
|
@ -1,6 +1,6 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: rich
|
||||
Version: 13.8.0
|
||||
Version: 13.8.1
|
||||
Summary: Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal
|
||||
Home-page: https://github.com/Textualize/rich
|
||||
License: MIT
|
||||
@ -22,6 +22,7 @@ Classifier: Programming Language :: Python :: 3.9
|
||||
Classifier: Programming Language :: Python :: 3.10
|
||||
Classifier: Programming Language :: Python :: 3.11
|
||||
Classifier: Programming Language :: Python :: 3.12
|
||||
Classifier: Programming Language :: Python :: 3.13
|
||||
Classifier: Typing :: Typed
|
||||
Provides-Extra: jupyter
|
||||
Requires-Dist: ipywidgets (>=7.5.1,<9) ; extra == "jupyter"
|
@ -1,8 +1,8 @@
|
||||
rich-13.8.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
rich-13.8.0.dist-info/LICENSE,sha256=3u18F6QxgVgZCj6iOcyHmlpQJxzruYrnAl9I--WNyhU,1056
|
||||
rich-13.8.0.dist-info/METADATA,sha256=Z0qmRupyrFfPHrIqUql-EF3Vf1EOJ1gKQH4PECmJRw0,18272
|
||||
rich-13.8.0.dist-info/RECORD,,
|
||||
rich-13.8.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
||||
rich-13.8.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
rich-13.8.1.dist-info/LICENSE,sha256=3u18F6QxgVgZCj6iOcyHmlpQJxzruYrnAl9I--WNyhU,1056
|
||||
rich-13.8.1.dist-info/METADATA,sha256=LUNl6wIGo9RlVR97ZMGhisWEApsSW5cOAtvI-zN39M4,18323
|
||||
rich-13.8.1.dist-info/RECORD,,
|
||||
rich-13.8.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
||||
rich/__init__.py,sha256=lh2WcoIOJp5M5_lbAsSUMGv8oiJeumROazHH_AYMS8I,6066
|
||||
rich/__main__.py,sha256=Wvh53rmOMyWeUeyqUHpn1PXsHlBc4TVcQnqrw46nf9Y,8333
|
||||
rich/__pycache__/__init__.cpython-38.pyc,,
|
||||
@ -155,7 +155,7 @@ rich/styled.py,sha256=wljVsVTXbABMMZvkzkO43ZEk_-irzEtvUiQ-sNnikQ8,1234
|
||||
rich/syntax.py,sha256=hVZcSKZ2RP3EErgY5MGK1WpWHd0x3ZtLCqod_EqgYTM,35302
|
||||
rich/table.py,sha256=Et3SB0mggJTIXGFOvShX8L9MDpehXJ2GexY6Lmo0Zaw,40041
|
||||
rich/terminal_theme.py,sha256=1j5-ufJfnvlAo5Qsi_ACZiXDmwMXzqgmFByObT9-yJY,3370
|
||||
rich/text.py,sha256=hldsjZVTBJZ40tvCsRaYs-EX_pkNNw08I4q_hctElQI,47351
|
||||
rich/text.py,sha256=0-7de80WdltOq58Kmz4nwMwjx4bbc87qe_qHbZXCH0Q,47365
|
||||
rich/theme.py,sha256=oNyhXhGagtDlbDye3tVu3esWOWk0vNkuxFw-_unlaK0,3771
|
||||
rich/themes.py,sha256=0xgTLozfabebYtcJtDdC5QkX5IVUEaviqDUJJh4YVFk,102
|
||||
rich/traceback.py,sha256=B76Q53tX9gef_DSF2Y0X1UJkEHlRGjuDElUpAe3LgIM,30027
|
@ -998,7 +998,7 @@ class Text(JupyterMixin):
|
||||
self._text.append(text.plain)
|
||||
self._spans.extend(
|
||||
_Span(start + text_length, end + text_length, style)
|
||||
for start, end, style in text._spans
|
||||
for start, end, style in text._spans.copy()
|
||||
)
|
||||
self._length += len(text)
|
||||
return self
|
||||
@ -1020,7 +1020,7 @@ class Text(JupyterMixin):
|
||||
self._text.append(text.plain)
|
||||
self._spans.extend(
|
||||
_Span(start + text_length, end + text_length, style)
|
||||
for start, end, style in text._spans
|
||||
for start, end, style in text._spans.copy()
|
||||
)
|
||||
self._length += len(text)
|
||||
return self
|
||||
|
@ -3,7 +3,7 @@ tiktoken-0.7.0.dist-info/LICENSE,sha256=QYy0mbQ2Eo1lPXmUEzOlQ3t74uqSE9zC8E0V1dLF
|
||||
tiktoken-0.7.0.dist-info/METADATA,sha256=Y8pRXptX5mYKr-PpPwd8Ppsp4MdArChHX6PqWLL2GFA,6595
|
||||
tiktoken-0.7.0.dist-info/RECORD,,
|
||||
tiktoken-0.7.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
tiktoken-0.7.0.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
||||
tiktoken-0.7.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
||||
tiktoken-0.7.0.dist-info/direct_url.json,sha256=A8eZHYYbfBMLJxQ_oI7XGcXTSJIWbjiNhps2PIzNXPE,138
|
||||
tiktoken-0.7.0.dist-info/top_level.txt,sha256=54G5MceQnuD7EXvp7jzGxDDapA1iOwsh77jhCN9WKkc,22
|
||||
tiktoken/__init__.py,sha256=FNmz8KgZfaG62vRgMMkTL9jj0a2AI7JGV1b-RZ29_tY,322
|
||||
|
@ -1,5 +1,5 @@
|
||||
Wheel-Version: 1.0
|
||||
Generator: setuptools (74.1.2)
|
||||
Generator: setuptools (75.1.0)
|
||||
Root-Is-Purelib: true
|
||||
Tag: py3-none-any
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: websockets
|
||||
Version: 13.0.1
|
||||
Version: 13.1
|
||||
Summary: An implementation of the WebSocket Protocol (RFC 6455 & 7692)
|
||||
Author-email: Aymeric Augustin <aymeric.augustin@m4x.org>
|
||||
License: BSD-3-Clause
|
||||
@ -24,6 +24,7 @@ Classifier: Programming Language :: Python :: 3.11
|
||||
Classifier: Programming Language :: Python :: 3.12
|
||||
Classifier: Programming Language :: Python :: 3.13
|
||||
Requires-Python: >=3.8
|
||||
Description-Content-Type: text/x-rst
|
||||
License-File: LICENSE
|
||||
|
||||
.. image:: logo/horizontal.svg
|
@ -1,10 +1,10 @@
|
||||
websockets-13.0.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
websockets-13.0.1.dist-info/LICENSE,sha256=PWoMBQ2L7FL6utUC5F-yW9ArytvXDeo01Ee2oP9Obag,1514
|
||||
websockets-13.0.1.dist-info/METADATA,sha256=0Q5-GuI6Wbmb45tJw8z0cVRJQM3QKB87Yb6XUoZg9NU,6742
|
||||
websockets-13.0.1.dist-info/RECORD,,
|
||||
websockets-13.0.1.dist-info/WHEEL,sha256=2-wOIdov8OETygfr6JjpwxZ3R_9CXxWhMBD1nM6bmy8,216
|
||||
websockets-13.0.1.dist-info/top_level.txt,sha256=CMpdKklxKsvZgCgyltxUWOHibZXZ1uYIVpca9xsQ8Hk,11
|
||||
websockets/__init__.py,sha256=EDW5qZk8Dt656-GB5qlnxH3jfSde-Nh8yNzmBY0MPH4,5664
|
||||
websockets-13.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
websockets-13.1.dist-info/LICENSE,sha256=PWoMBQ2L7FL6utUC5F-yW9ArytvXDeo01Ee2oP9Obag,1514
|
||||
websockets-13.1.dist-info/METADATA,sha256=dYWGin2sQircUMT-_HkCtQkc5_LTzdQ7BLwIGDTHbbI,6777
|
||||
websockets-13.1.dist-info/RECORD,,
|
||||
websockets-13.1.dist-info/WHEEL,sha256=kEv-8RuW3g3tDaovfFYwwvLeK_2lL8t8SExFktY7GL0,216
|
||||
websockets-13.1.dist-info/top_level.txt,sha256=CMpdKklxKsvZgCgyltxUWOHibZXZ1uYIVpca9xsQ8Hk,11
|
||||
websockets/__init__.py,sha256=UlYOZWjPPdgEtBFq4CP5t7Kd1Jjq-iMJT62Ya9ImDSo,5936
|
||||
websockets/__main__.py,sha256=q6tBA72COhz7NUkuP_VG9IVypJjOexx2Oi7qkKNxneg,4756
|
||||
websockets/__pycache__/__init__.cpython-38.pyc,,
|
||||
websockets/__pycache__/__main__.cpython-38.pyc,,
|
||||
@ -34,48 +34,50 @@ websockets/asyncio/__pycache__/connection.cpython-38.pyc,,
|
||||
websockets/asyncio/__pycache__/messages.cpython-38.pyc,,
|
||||
websockets/asyncio/__pycache__/server.cpython-38.pyc,,
|
||||
websockets/asyncio/async_timeout.py,sha256=N-6Mubyiaoh66PAXGvCzhgxCM-7V2XiRnH32Xi6J6TE,8971
|
||||
websockets/asyncio/client.py,sha256=XQy1HhHm26WYyZj3wXUeX889DnwqOcenjDON-tdt9AE,13547
|
||||
websockets/asyncio/client.py,sha256=Kx9L-AYQUlMRAyo0d2DjuebggcM-rogx3JB26rEebY4,21700
|
||||
websockets/asyncio/compatibility.py,sha256=gkenDDhzNbm6_iXV5Edvbvp6uHZYdrTvGNjt8P_JtyQ,786
|
||||
websockets/asyncio/connection.py,sha256=S1LNn_vnhw_Nz5yZ3IDgQ_SmQA1qg0hKDc6xByAkGfU,43514
|
||||
websockets/asyncio/messages.py,sha256=l3sV10tl0jzYqhrPBiQsml_PBFPHAmjjbSHql3lOyow,9784
|
||||
websockets/asyncio/server.py,sha256=yqh_tjtv-E9bFRG7Ns5OCmzWi8SeIJP0Yf_YpVZRvp8,31028
|
||||
websockets/asyncio/connection.py,sha256=sxX1WTz2iVxCsUvLJUoogJT9SdHsHU4ut2PuSIbxVs4,44475
|
||||
websockets/asyncio/messages.py,sha256=-sS9JCa4-aFVSv0sPJd_VtGcoADj8mE0sMxfsqW-rQw,9854
|
||||
websockets/asyncio/server.py,sha256=In45P1Ng2gznGMbnwuz3brlIAsZkSel0ScshrJZSMw8,36548
|
||||
websockets/auth.py,sha256=pCeunT3V2AdwRt_Tpq9TrkdGY7qUlDHIEqeggj5yQFk,262
|
||||
websockets/client.py,sha256=GXcwcgVCzxlvmgfiuVjbaddmATmnKntbmXeyBB2GQr4,12484
|
||||
websockets/connection.py,sha256=BQn7ws-u7hmsGhrPZmxLrbPqpoL3MrrcbdeUc93zZkE,288
|
||||
websockets/datastructures.py,sha256=oGbm3ZjVx3BwrYfgyPPcTtFO8qbgCsFVApMNGTRTh2c,5681
|
||||
websockets/exceptions.py,sha256=IraXIMjjSOvBuikGolls6AuOFOeYXRO4VryT091Dm8E,10248
|
||||
websockets/client.py,sha256=cc8y1I2Firs1JRXCfgD4j2JWnneYAuQSpGWNjrhkqFY,13541
|
||||
websockets/connection.py,sha256=OLiMVkNd25_86sB8Q7CrCwBoXy9nA0OCgdgLRA8WUR8,323
|
||||
websockets/datastructures.py,sha256=s5Rkipz4n15HSZsOrs64CoCs-_3oSBCgpe9uPvztDkY,5677
|
||||
websockets/exceptions.py,sha256=b2-QiL1pszljREQQCzbPE1Fv7-Xb-uwso2Zt6LLD10A,10594
|
||||
websockets/extensions/__init__.py,sha256=QkZsxaJVllVSp1uhdD5uPGibdbx_091GrVVfS5LXcpw,98
|
||||
websockets/extensions/__pycache__/__init__.cpython-38.pyc,,
|
||||
websockets/extensions/__pycache__/base.cpython-38.pyc,,
|
||||
websockets/extensions/__pycache__/permessage_deflate.cpython-38.pyc,,
|
||||
websockets/extensions/base.py,sha256=sMjmUfov0-woUmT4MiIuwBjj4DAJu0l3gfniPHn67Ec,2952
|
||||
websockets/extensions/permessage_deflate.py,sha256=Az-zpU9eJYlMi4mdA1Jx0qKs0FRzrjbp3MO9iXzKHIY,24716
|
||||
websockets/frames.py,sha256=-chQjYwtdTlTdosSfnvxRyXWSfqRfLQmL0Bo46x2Rok,14111
|
||||
websockets/headers.py,sha256=LkuZIpztb-llYblr9GrJK0SAA8fJz7JhErMj-aAzJMk,16109
|
||||
websockets/http.py,sha256=id8BzOG4AtqIeFwKHB4dzc6gxg45vcsdSdWUCvWMDvk,447
|
||||
websockets/http11.py,sha256=OI0U0yj8CDOT7h0IfHzRF1bIyY9zjeBTyDiicgfAasU,13255
|
||||
websockets/extensions/base.py,sha256=jsSJnO47L2VxYzx0cZ_LLQcAyUudSDgJEtKN247H-38,2890
|
||||
websockets/extensions/permessage_deflate.py,sha256=JR9s7pvAJv2zWRUfOLysOtAiO-ovgRMqnSUpb92gohI,24661
|
||||
websockets/frames.py,sha256=H-4ULsevYdna_CjalVASRPlh2Z54NoQat_vq8P4cVfc,12765
|
||||
websockets/headers.py,sha256=9OHHZvaj4hXrofi0HuJFNYJaE0yRoPmmrBYxMaDuCTs,15931
|
||||
websockets/http.py,sha256=eWitbqWAmHeqYK4OF3JLRC4lWwI1OIeft7oY3OobXvc,481
|
||||
websockets/http11.py,sha256=-TNxOVVLr0050-0Ac3jOlWt5G9HAfkHZrt8dqoto9bs,13376
|
||||
websockets/imports.py,sha256=TNONfYXO1UPExiwCVMgmg77fH5b4nyNAKcqtTg0gO2I,2768
|
||||
websockets/legacy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
websockets/legacy/__pycache__/__init__.cpython-38.pyc,,
|
||||
websockets/legacy/__pycache__/auth.cpython-38.pyc,,
|
||||
websockets/legacy/__pycache__/client.cpython-38.pyc,,
|
||||
websockets/legacy/__pycache__/exceptions.cpython-38.pyc,,
|
||||
websockets/legacy/__pycache__/framing.cpython-38.pyc,,
|
||||
websockets/legacy/__pycache__/handshake.cpython-38.pyc,,
|
||||
websockets/legacy/__pycache__/http.cpython-38.pyc,,
|
||||
websockets/legacy/__pycache__/protocol.cpython-38.pyc,,
|
||||
websockets/legacy/__pycache__/server.cpython-38.pyc,,
|
||||
websockets/legacy/auth.py,sha256=UdK0eZg1TjMGY6iEVRbBn51M9AjpSyRv2lJbvuuI6aA,6567
|
||||
websockets/legacy/client.py,sha256=qZcQgbBnSW2ha9IpzvRDGZNKwzXNJyVpB9ac86_y1Zk,26345
|
||||
websockets/legacy/framing.py,sha256=2itTWkMHXzjleywjyQb39P__75ThMX1AX3Ly_AU94lk,4972
|
||||
websockets/legacy/handshake.py,sha256=71rqsWT1-wZ8HaLorK1ac47atdmgv8q6YXV-qosjqxs,5445
|
||||
websockets/legacy/client.py,sha256=iuyFib2kX5ybK9vLVpqJRNJHa4BuA0u5MLyoNnartY4,26706
|
||||
websockets/legacy/exceptions.py,sha256=DbSHBKcDEoYoUeCxURo0cnH8PyCCKYzkTboP_tOtsxw,1967
|
||||
websockets/legacy/framing.py,sha256=ALEDiBNq17FUqNEe5LHxkPxWoY6tPwffgGFiHMdnnIs,6371
|
||||
websockets/legacy/handshake.py,sha256=2Nzr5AN2xvDC5EdNP-kB3lOcrAaUNlYuj_-hr_jv7pM,5285
|
||||
websockets/legacy/http.py,sha256=cOCQmDWhIKQmm8UWGXPW7CDZg03wjogCsb0LP9oetNQ,7061
|
||||
websockets/legacy/protocol.py,sha256=753BX2MRJPsMx4vgkCt0g9tRqFJHxi7Nx_CzuL05iKc,63689
|
||||
websockets/legacy/server.py,sha256=qI9M9cuP-7CzDjMC5odivJNesbvNOEhj2G6b3-wEHJw,45120
|
||||
websockets/protocol.py,sha256=zpfgfI82kfouh3JxONEjTrsEEknvNdKzG5Sq4bcNol0,24938
|
||||
websockets/legacy/protocol.py,sha256=Rbk88lnbghWpcEBT-TuTtAGqDy9OA7VsUFEMUcv95RM,63681
|
||||
websockets/legacy/server.py,sha256=lb26Vm_y7biVVYLyVzG9R1BiaLmuS3TrQh-LesjO4Ss,45318
|
||||
websockets/protocol.py,sha256=yl1j9ecLShF0iTRALOTzFfq0KmW5XO74Mtk0baVkvo0,25512
|
||||
websockets/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
websockets/server.py,sha256=fQn2lrVEzrKoQZfcz-WZ7c_DswuqgdhBj7B1mxCIw20,21354
|
||||
websockets/speedups.c,sha256=ghPq-NF35VLVNkMv0uFDIruNpVISyW-qvoZgPpE65qw,5834
|
||||
websockets/speedups.cpython-38-x86_64-linux-gnu.so,sha256=UxyakwjXRx4kYsOleAxkc8vXBqt89fv5rhzSAlI13W8,34072
|
||||
websockets/server.py,sha256=BVoC433LZUgKVndtaYPrndB7uf_FTuG7MXrM9QHJEzk,21275
|
||||
websockets/speedups.c,sha256=j-damnT02MKRoYw8MtTT45qLGX6z6TnriqhTkyfcNZE,5767
|
||||
websockets/speedups.cpython-38-x86_64-linux-gnu.so,sha256=FRW0JugiQU4471Sd8Yergmr8u39ELoI5T9SIrQJ2Uqs,34072
|
||||
websockets/speedups.pyi,sha256=NikZ3sAxs9Z2uWH_ZvctvMJUBbsHeC2D1L954EVSwJc,55
|
||||
websockets/streams.py,sha256=3K3FcgTcXon-51P0sVyz0G4J-H51L82SVMS--W-gl6g,4038
|
||||
websockets/sync/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
@ -85,12 +87,12 @@ websockets/sync/__pycache__/connection.cpython-38.pyc,,
|
||||
websockets/sync/__pycache__/messages.cpython-38.pyc,,
|
||||
websockets/sync/__pycache__/server.cpython-38.pyc,,
|
||||
websockets/sync/__pycache__/utils.cpython-38.pyc,,
|
||||
websockets/sync/client.py,sha256=0Jw-jI7WvkuKzZ9_JNtrWBSu0eqt8JuIxfQjxzXpftY,11462
|
||||
websockets/sync/connection.py,sha256=jTHj0OKGxrk_aowooLRnGbqiytlnfb8bTXhjPK3yx9g,29759
|
||||
websockets/sync/messages.py,sha256=pDDKG7OHWcqXF9zzhm3TGYD92cfIaP1wNEX3U14b3ns,9739
|
||||
websockets/sync/server.py,sha256=-gowaOvtLzzt9uhKykcrgHgyfgoeo7945PxptTG5G9c,20531
|
||||
websockets/sync/client.py,sha256=QWs2wYU7S8--CZgFYWUliWjSYX5zrJEFQ6_gEcrW1sA,11372
|
||||
websockets/sync/connection.py,sha256=Ve2aW760xPz8nXU56TuL7M3qV1THYmZZcfoS_0Wwh0c,30684
|
||||
websockets/sync/messages.py,sha256=K-VHhUERUsS6bOaLgTox4kShnUKt8aPmWgOdqj_4E-Y,9809
|
||||
websockets/sync/server.py,sha256=WutnccxDQWJNfPsX2WthvDr0QeVn36fUpf0MKmbeXY0,25608
|
||||
websockets/sync/utils.py,sha256=TtW-ncYFvJmiSW2gO86ngE2BVsnnBdL-4H88kWNDYbg,1107
|
||||
websockets/typing.py,sha256=b9F78aYY-sDNnIgSbvV_ApVBicVJdduLGv5wU0PVB5c,2157
|
||||
websockets/uri.py,sha256=V7Qzs4ZvMUIYVqr0HDoIpsg7ah8iRGVmr1A8Vp91lPg,3159
|
||||
websockets/uri.py,sha256=1r8dXNEiLcdMrCrzXmsy7DwSHiF3gaOWlmAdoFexOOM,3125
|
||||
websockets/utils.py,sha256=ZpH3WJLsQS29Jf5R6lTacxf_hPd8E4zS2JmGyNpg4bA,1150
|
||||
websockets/version.py,sha256=mYdU1lDxfvmRjqzZIlYCJatBinMmAlaoeb-AXPnPyeo,3204
|
||||
websockets/version.py,sha256=M0HSppy6IqnAdAr0McbPGkyCuBlue4Uzigc78cOWHxs,3202
|
@ -1,5 +1,5 @@
|
||||
Wheel-Version: 1.0
|
||||
Generator: setuptools (74.0.0)
|
||||
Generator: setuptools (75.1.0)
|
||||
Root-Is-Purelib: false
|
||||
Tag: cp38-cp38-manylinux_2_5_x86_64
|
||||
Tag: cp38-cp38-manylinux1_x86_64
|
@ -14,7 +14,7 @@ __all__ = [
|
||||
"HeadersLike",
|
||||
"MultipleValuesError",
|
||||
# .exceptions
|
||||
"AbortHandshake",
|
||||
"ConcurrencyError",
|
||||
"ConnectionClosed",
|
||||
"ConnectionClosedError",
|
||||
"ConnectionClosedOK",
|
||||
@ -23,19 +23,16 @@ __all__ = [
|
||||
"InvalidHeader",
|
||||
"InvalidHeaderFormat",
|
||||
"InvalidHeaderValue",
|
||||
"InvalidMessage",
|
||||
"InvalidOrigin",
|
||||
"InvalidParameterName",
|
||||
"InvalidParameterValue",
|
||||
"InvalidState",
|
||||
"InvalidStatus",
|
||||
"InvalidStatusCode",
|
||||
"InvalidUpgrade",
|
||||
"InvalidURI",
|
||||
"NegotiationError",
|
||||
"PayloadTooBig",
|
||||
"ProtocolError",
|
||||
"RedirectHandshake",
|
||||
"SecurityError",
|
||||
"WebSocketException",
|
||||
"WebSocketProtocolError",
|
||||
@ -46,6 +43,11 @@ __all__ = [
|
||||
"WebSocketClientProtocol",
|
||||
"connect",
|
||||
"unix_connect",
|
||||
# .legacy.exceptions
|
||||
"AbortHandshake",
|
||||
"InvalidMessage",
|
||||
"InvalidStatusCode",
|
||||
"RedirectHandshake",
|
||||
# .legacy.protocol
|
||||
"WebSocketCommonProtocol",
|
||||
# .legacy.server
|
||||
@ -71,7 +73,7 @@ if typing.TYPE_CHECKING:
|
||||
from .client import ClientProtocol
|
||||
from .datastructures import Headers, HeadersLike, MultipleValuesError
|
||||
from .exceptions import (
|
||||
AbortHandshake,
|
||||
ConcurrencyError,
|
||||
ConnectionClosed,
|
||||
ConnectionClosedError,
|
||||
ConnectionClosedOK,
|
||||
@ -80,19 +82,16 @@ if typing.TYPE_CHECKING:
|
||||
InvalidHeader,
|
||||
InvalidHeaderFormat,
|
||||
InvalidHeaderValue,
|
||||
InvalidMessage,
|
||||
InvalidOrigin,
|
||||
InvalidParameterName,
|
||||
InvalidParameterValue,
|
||||
InvalidState,
|
||||
InvalidStatus,
|
||||
InvalidStatusCode,
|
||||
InvalidUpgrade,
|
||||
InvalidURI,
|
||||
NegotiationError,
|
||||
PayloadTooBig,
|
||||
ProtocolError,
|
||||
RedirectHandshake,
|
||||
SecurityError,
|
||||
WebSocketException,
|
||||
WebSocketProtocolError,
|
||||
@ -102,6 +101,12 @@ if typing.TYPE_CHECKING:
|
||||
basic_auth_protocol_factory,
|
||||
)
|
||||
from .legacy.client import WebSocketClientProtocol, connect, unix_connect
|
||||
from .legacy.exceptions import (
|
||||
AbortHandshake,
|
||||
InvalidMessage,
|
||||
InvalidStatusCode,
|
||||
RedirectHandshake,
|
||||
)
|
||||
from .legacy.protocol import WebSocketCommonProtocol
|
||||
from .legacy.server import (
|
||||
WebSocketServer,
|
||||
@ -131,7 +136,7 @@ else:
|
||||
"HeadersLike": ".datastructures",
|
||||
"MultipleValuesError": ".datastructures",
|
||||
# .exceptions
|
||||
"AbortHandshake": ".exceptions",
|
||||
"ConcurrencyError": ".exceptions",
|
||||
"ConnectionClosed": ".exceptions",
|
||||
"ConnectionClosedError": ".exceptions",
|
||||
"ConnectionClosedOK": ".exceptions",
|
||||
@ -140,19 +145,16 @@ else:
|
||||
"InvalidHeader": ".exceptions",
|
||||
"InvalidHeaderFormat": ".exceptions",
|
||||
"InvalidHeaderValue": ".exceptions",
|
||||
"InvalidMessage": ".exceptions",
|
||||
"InvalidOrigin": ".exceptions",
|
||||
"InvalidParameterName": ".exceptions",
|
||||
"InvalidParameterValue": ".exceptions",
|
||||
"InvalidState": ".exceptions",
|
||||
"InvalidStatus": ".exceptions",
|
||||
"InvalidStatusCode": ".exceptions",
|
||||
"InvalidUpgrade": ".exceptions",
|
||||
"InvalidURI": ".exceptions",
|
||||
"NegotiationError": ".exceptions",
|
||||
"PayloadTooBig": ".exceptions",
|
||||
"ProtocolError": ".exceptions",
|
||||
"RedirectHandshake": ".exceptions",
|
||||
"SecurityError": ".exceptions",
|
||||
"WebSocketException": ".exceptions",
|
||||
"WebSocketProtocolError": ".exceptions",
|
||||
@ -163,6 +165,11 @@ else:
|
||||
"WebSocketClientProtocol": ".legacy.client",
|
||||
"connect": ".legacy.client",
|
||||
"unix_connect": ".legacy.client",
|
||||
# .legacy.exceptions
|
||||
"AbortHandshake": ".legacy.exceptions",
|
||||
"InvalidMessage": ".legacy.exceptions",
|
||||
"InvalidStatusCode": ".legacy.exceptions",
|
||||
"RedirectHandshake": ".legacy.exceptions",
|
||||
# .legacy.protocol
|
||||
"WebSocketCommonProtocol": ".legacy.protocol",
|
||||
# .legacy.server
|
||||
@ -179,10 +186,11 @@ else:
|
||||
"ExtensionParameter": ".typing",
|
||||
"LoggerLike": ".typing",
|
||||
"Origin": ".typing",
|
||||
"StatusLike": "typing",
|
||||
"StatusLike": ".typing",
|
||||
"Subprotocol": ".typing",
|
||||
},
|
||||
deprecated_aliases={
|
||||
# deprecated in 9.0 - 2021-09-01
|
||||
"framing": ".legacy",
|
||||
"handshake": ".legacy",
|
||||
"parse_uri": ".uri",
|
||||
|
@ -1,24 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import urllib.parse
|
||||
from types import TracebackType
|
||||
from typing import Any, Generator, Sequence
|
||||
from typing import Any, AsyncIterator, Callable, Generator, Sequence
|
||||
|
||||
from ..client import ClientProtocol
|
||||
from ..client import ClientProtocol, backoff
|
||||
from ..datastructures import HeadersLike
|
||||
from ..exceptions import InvalidStatus, SecurityError
|
||||
from ..extensions.base import ClientExtensionFactory
|
||||
from ..extensions.permessage_deflate import enable_client_permessage_deflate
|
||||
from ..headers import validate_subprotocols
|
||||
from ..http11 import USER_AGENT, Response
|
||||
from ..protocol import CONNECTING, Event
|
||||
from ..typing import LoggerLike, Origin, Subprotocol
|
||||
from ..uri import parse_uri
|
||||
from ..uri import WebSocketURI, parse_uri
|
||||
from .compatibility import TimeoutError, asyncio_timeout
|
||||
from .connection import Connection
|
||||
|
||||
|
||||
__all__ = ["connect", "unix_connect", "ClientConnection"]
|
||||
|
||||
MAX_REDIRECTS = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10"))
|
||||
|
||||
|
||||
class ClientConnection(Connection):
|
||||
"""
|
||||
@ -83,20 +89,17 @@ class ClientConnection(Connection):
|
||||
self.request.headers["User-Agent"] = user_agent_header
|
||||
self.protocol.send_request(self.request)
|
||||
|
||||
# May raise CancelledError if open_timeout is exceeded.
|
||||
await self.response_rcvd
|
||||
await asyncio.wait(
|
||||
[self.response_rcvd, self.connection_lost_waiter],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
if self.response is None:
|
||||
raise ConnectionError("connection closed during handshake")
|
||||
# self.protocol.handshake_exc is always set when the connection is lost
|
||||
# before receiving a response, when the response cannot be parsed, or
|
||||
# when the response fails the handshake.
|
||||
|
||||
if self.protocol.handshake_exc is None:
|
||||
self.start_keepalive()
|
||||
else:
|
||||
try:
|
||||
async with asyncio_timeout(self.close_timeout):
|
||||
await self.connection_lost_waiter
|
||||
finally:
|
||||
raise self.protocol.handshake_exc
|
||||
if self.protocol.handshake_exc is not None:
|
||||
raise self.protocol.handshake_exc
|
||||
|
||||
def process_event(self, event: Event) -> None:
|
||||
"""
|
||||
@ -112,13 +115,46 @@ class ClientConnection(Connection):
|
||||
else:
|
||||
super().process_event(event)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
try:
|
||||
super().connection_lost(exc)
|
||||
finally:
|
||||
# If the connection is closed during the handshake, unblock it.
|
||||
if not self.response_rcvd.done():
|
||||
self.response_rcvd.set_result(None)
|
||||
|
||||
def process_exception(exc: Exception) -> Exception | None:
|
||||
"""
|
||||
Determine whether a connection error is retryable or fatal.
|
||||
|
||||
When reconnecting automatically with ``async for ... in connect(...)``, if a
|
||||
connection attempt fails, :func:`process_exception` is called to determine
|
||||
whether to retry connecting or to raise the exception.
|
||||
|
||||
This function defines the default behavior, which is to retry on:
|
||||
|
||||
* :exc:`EOFError`, :exc:`OSError`, :exc:`asyncio.TimeoutError`: network
|
||||
errors;
|
||||
* :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500,
|
||||
502, 503, or 504: server or proxy errors.
|
||||
|
||||
All other exceptions are considered fatal.
|
||||
|
||||
You can change this behavior with the ``process_exception`` argument of
|
||||
:func:`connect`.
|
||||
|
||||
Return :obj:`None` if the exception is retryable i.e. when the error could
|
||||
be transient and trying to reconnect with the same parameters could succeed.
|
||||
The exception will be logged at the ``INFO`` level.
|
||||
|
||||
Return an exception, either ``exc`` or a new exception, if the exception is
|
||||
fatal i.e. when trying to reconnect will most likely produce the same error.
|
||||
That exception will be raised, breaking out of the retry loop.
|
||||
|
||||
"""
|
||||
if isinstance(exc, (EOFError, OSError, asyncio.TimeoutError)):
|
||||
return None
|
||||
if isinstance(exc, InvalidStatus) and exc.response.status_code in [
|
||||
500, # Internal Server Error
|
||||
502, # Bad Gateway
|
||||
503, # Service Unavailable
|
||||
504, # Gateway Timeout
|
||||
]:
|
||||
return None
|
||||
return exc
|
||||
|
||||
|
||||
# This is spelled in lower case because it's exposed as a callable in the API.
|
||||
@ -131,11 +167,28 @@ class connect:
|
||||
|
||||
:func:`connect` may be used as an asynchronous context manager::
|
||||
|
||||
async with websockets.asyncio.client.connect(...) as websocket:
|
||||
from websockets.asyncio.client import connect
|
||||
|
||||
async with connect(...) as websocket:
|
||||
...
|
||||
|
||||
The connection is closed automatically when exiting the context.
|
||||
|
||||
:func:`connect` can be used as an infinite asynchronous iterator to
|
||||
reconnect automatically on errors::
|
||||
|
||||
async for websocket in connect(...):
|
||||
try:
|
||||
...
|
||||
except websockets.ConnectionClosed:
|
||||
continue
|
||||
|
||||
If the connection fails with a transient error, it is retried with
|
||||
exponential backoff. If it fails with a fatal error, the exception is
|
||||
raised, breaking out of the loop.
|
||||
|
||||
The connection is closed automatically after each iteration of the loop.
|
||||
|
||||
Args:
|
||||
uri: URI of the WebSocket server.
|
||||
origin: Value of the ``Origin`` header, for servers that require it.
|
||||
@ -151,6 +204,9 @@ class connect:
|
||||
compression: The "permessage-deflate" extension is enabled by default.
|
||||
Set ``compression`` to :obj:`None` to disable it. See the
|
||||
:doc:`compression guide <../../topics/compression>` for details.
|
||||
process_exception: When reconnecting automatically, tell whether an
|
||||
error is transient or fatal. The default behavior is defined by
|
||||
:func:`process_exception`. Refer to its documentation for details.
|
||||
open_timeout: Timeout for opening the connection in seconds.
|
||||
:obj:`None` disables the timeout.
|
||||
ping_interval: Interval between keepalive pings in seconds.
|
||||
@ -217,6 +273,7 @@ class connect:
|
||||
additional_headers: HeadersLike | None = None,
|
||||
user_agent_header: str | None = USER_AGENT,
|
||||
compression: str | None = "deflate",
|
||||
process_exception: Callable[[Exception], Exception | None] = process_exception,
|
||||
# Timeouts
|
||||
open_timeout: float | None = 10,
|
||||
ping_interval: float | None = 20,
|
||||
@ -233,17 +290,7 @@ class connect:
|
||||
# Other keyword arguments are passed to loop.create_connection
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
||||
wsuri = parse_uri(uri)
|
||||
|
||||
if wsuri.secure:
|
||||
kwargs.setdefault("ssl", True)
|
||||
kwargs.setdefault("server_hostname", wsuri.host)
|
||||
if kwargs.get("ssl") is None:
|
||||
raise TypeError("ssl=None is incompatible with a wss:// URI")
|
||||
else:
|
||||
if kwargs.get("ssl") is not None:
|
||||
raise TypeError("ssl argument is incompatible with a ws:// URI")
|
||||
self.uri = uri
|
||||
|
||||
if subprotocols is not None:
|
||||
validate_subprotocols(subprotocols)
|
||||
@ -253,10 +300,13 @@ class connect:
|
||||
elif compression is not None:
|
||||
raise ValueError(f"unsupported compression: {compression}")
|
||||
|
||||
if logger is None:
|
||||
logger = logging.getLogger("websockets.client")
|
||||
|
||||
if create_connection is None:
|
||||
create_connection = ClientConnection
|
||||
|
||||
def factory() -> ClientConnection:
|
||||
def protocol_factory(wsuri: WebSocketURI) -> ClientConnection:
|
||||
# This is a protocol in the Sans-I/O implementation of websockets.
|
||||
protocol = ClientProtocol(
|
||||
wsuri,
|
||||
@ -277,21 +327,154 @@ class connect:
|
||||
)
|
||||
return connection
|
||||
|
||||
self.protocol_factory = protocol_factory
|
||||
self.handshake_args = (
|
||||
additional_headers,
|
||||
user_agent_header,
|
||||
)
|
||||
self.process_exception = process_exception
|
||||
self.open_timeout = open_timeout
|
||||
self.logger = logger
|
||||
self.connection_kwargs = kwargs
|
||||
|
||||
async def create_connection(self) -> ClientConnection:
|
||||
"""Create TCP or Unix connection."""
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
wsuri = parse_uri(self.uri)
|
||||
kwargs = self.connection_kwargs.copy()
|
||||
|
||||
def factory() -> ClientConnection:
|
||||
return self.protocol_factory(wsuri)
|
||||
|
||||
if wsuri.secure:
|
||||
kwargs.setdefault("ssl", True)
|
||||
kwargs.setdefault("server_hostname", wsuri.host)
|
||||
if kwargs.get("ssl") is None:
|
||||
raise TypeError("ssl=None is incompatible with a wss:// URI")
|
||||
else:
|
||||
if kwargs.get("ssl") is not None:
|
||||
raise TypeError("ssl argument is incompatible with a ws:// URI")
|
||||
|
||||
if kwargs.pop("unix", False):
|
||||
self._create_connection = loop.create_unix_connection(factory, **kwargs)
|
||||
_, connection = await loop.create_unix_connection(factory, **kwargs)
|
||||
else:
|
||||
if kwargs.get("sock") is None:
|
||||
kwargs.setdefault("host", wsuri.host)
|
||||
kwargs.setdefault("port", wsuri.port)
|
||||
self._create_connection = loop.create_connection(factory, **kwargs)
|
||||
_, connection = await loop.create_connection(factory, **kwargs)
|
||||
return connection
|
||||
|
||||
self._handshake_args = (
|
||||
additional_headers,
|
||||
user_agent_header,
|
||||
)
|
||||
def process_redirect(self, exc: Exception) -> Exception | str:
|
||||
"""
|
||||
Determine whether a connection error is a redirect that can be followed.
|
||||
|
||||
self._open_timeout = open_timeout
|
||||
Return the new URI if it's a valid redirect. Else, return an exception.
|
||||
|
||||
"""
|
||||
if not (
|
||||
isinstance(exc, InvalidStatus)
|
||||
and exc.response.status_code
|
||||
in [
|
||||
300, # Multiple Choices
|
||||
301, # Moved Permanently
|
||||
302, # Found
|
||||
303, # See Other
|
||||
307, # Temporary Redirect
|
||||
308, # Permanent Redirect
|
||||
]
|
||||
and "Location" in exc.response.headers
|
||||
):
|
||||
return exc
|
||||
|
||||
old_wsuri = parse_uri(self.uri)
|
||||
new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"])
|
||||
new_wsuri = parse_uri(new_uri)
|
||||
|
||||
# If connect() received a socket, it is closed and cannot be reused.
|
||||
if self.connection_kwargs.get("sock") is not None:
|
||||
return ValueError(
|
||||
f"cannot follow redirect to {new_uri} with a preexisting socket"
|
||||
)
|
||||
|
||||
# TLS downgrade is forbidden.
|
||||
if old_wsuri.secure and not new_wsuri.secure:
|
||||
return SecurityError(f"cannot follow redirect to non-secure URI {new_uri}")
|
||||
|
||||
# Apply restrictions to cross-origin redirects.
|
||||
if (
|
||||
old_wsuri.secure != new_wsuri.secure
|
||||
or old_wsuri.host != new_wsuri.host
|
||||
or old_wsuri.port != new_wsuri.port
|
||||
):
|
||||
# Cross-origin redirects on Unix sockets don't quite make sense.
|
||||
if self.connection_kwargs.get("unix", False):
|
||||
return ValueError(
|
||||
f"cannot follow cross-origin redirect to {new_uri} "
|
||||
f"with a Unix socket"
|
||||
)
|
||||
|
||||
# Cross-origin redirects when host and port are overridden are ill-defined.
|
||||
if (
|
||||
self.connection_kwargs.get("host") is not None
|
||||
or self.connection_kwargs.get("port") is not None
|
||||
):
|
||||
return ValueError(
|
||||
f"cannot follow cross-origin redirect to {new_uri} "
|
||||
f"with an explicit host or port"
|
||||
)
|
||||
|
||||
return new_uri
|
||||
|
||||
# ... = await connect(...)
|
||||
|
||||
def __await__(self) -> Generator[Any, None, ClientConnection]:
|
||||
# Create a suitable iterator by calling __await__ on a coroutine.
|
||||
return self.__await_impl__().__await__()
|
||||
|
||||
async def __await_impl__(self) -> ClientConnection:
|
||||
try:
|
||||
async with asyncio_timeout(self.open_timeout):
|
||||
for _ in range(MAX_REDIRECTS):
|
||||
self.connection = await self.create_connection()
|
||||
try:
|
||||
await self.connection.handshake(*self.handshake_args)
|
||||
except asyncio.CancelledError:
|
||||
self.connection.close_transport()
|
||||
raise
|
||||
except Exception as exc:
|
||||
# Always close the connection even though keep-alive is
|
||||
# the default in HTTP/1.1 because create_connection ties
|
||||
# opening the network connection with initializing the
|
||||
# protocol. In the current design of connect(), there is
|
||||
# no easy way to reuse the network connection that works
|
||||
# in every case nor to reinitialize the protocol.
|
||||
self.connection.close_transport()
|
||||
|
||||
uri_or_exc = self.process_redirect(exc)
|
||||
# Response is a valid redirect; follow it.
|
||||
if isinstance(uri_or_exc, str):
|
||||
self.uri = uri_or_exc
|
||||
continue
|
||||
# Response isn't a valid redirect; raise the exception.
|
||||
if uri_or_exc is exc:
|
||||
raise
|
||||
else:
|
||||
raise uri_or_exc from exc
|
||||
|
||||
else:
|
||||
self.connection.start_keepalive()
|
||||
return self.connection
|
||||
else:
|
||||
raise SecurityError(f"more than {MAX_REDIRECTS} redirects")
|
||||
|
||||
except TimeoutError:
|
||||
# Re-raise exception with an informative error message.
|
||||
raise TimeoutError("timed out during handshake") from None
|
||||
|
||||
# ... = yield from connect(...) - remove when dropping Python < 3.10
|
||||
|
||||
__iter__ = __await__
|
||||
|
||||
# async with connect(...) as ...: ...
|
||||
|
||||
@ -306,30 +489,47 @@ class connect:
|
||||
) -> None:
|
||||
await self.connection.close()
|
||||
|
||||
# ... = await connect(...)
|
||||
# async for ... in connect(...):
|
||||
|
||||
def __await__(self) -> Generator[Any, None, ClientConnection]:
|
||||
# Create a suitable iterator by calling __await__ on a coroutine.
|
||||
return self.__await_impl__().__await__()
|
||||
|
||||
async def __await_impl__(self) -> ClientConnection:
|
||||
try:
|
||||
async with asyncio_timeout(self._open_timeout):
|
||||
_transport, self.connection = await self._create_connection
|
||||
async def __aiter__(self) -> AsyncIterator[ClientConnection]:
|
||||
delays: Generator[float, None, None] | None = None
|
||||
while True:
|
||||
try:
|
||||
async with self as protocol:
|
||||
yield protocol
|
||||
except Exception as exc:
|
||||
# Determine whether the exception is retryable or fatal.
|
||||
# The API of process_exception is "return an exception or None";
|
||||
# "raise an exception" is also supported because it's a frequent
|
||||
# mistake. It isn't documented in order to keep the API simple.
|
||||
try:
|
||||
await self.connection.handshake(*self._handshake_args)
|
||||
except (Exception, asyncio.CancelledError):
|
||||
self.connection.transport.close()
|
||||
new_exc = self.process_exception(exc)
|
||||
except Exception as raised_exc:
|
||||
new_exc = raised_exc
|
||||
|
||||
# The connection failed with a fatal error.
|
||||
# Raise the exception and exit the loop.
|
||||
if new_exc is exc:
|
||||
raise
|
||||
else:
|
||||
return self.connection
|
||||
except TimeoutError:
|
||||
# Re-raise exception with an informative error message.
|
||||
raise TimeoutError("timed out during handshake") from None
|
||||
if new_exc is not None:
|
||||
raise new_exc from exc
|
||||
|
||||
# ... = yield from connect(...) - remove when dropping Python < 3.10
|
||||
# The connection failed with a retryable error.
|
||||
# Start or continue backoff and reconnect.
|
||||
if delays is None:
|
||||
delays = backoff()
|
||||
delay = next(delays)
|
||||
self.logger.info(
|
||||
"! connect failed; reconnecting in %.1f seconds",
|
||||
delay,
|
||||
exc_info=True,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
|
||||
__iter__ = __await__
|
||||
else:
|
||||
# The connection succeeded. Reset backoff.
|
||||
delays = None
|
||||
|
||||
|
||||
def unix_connect(
|
||||
|
@ -19,8 +19,13 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError
|
||||
from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl
|
||||
from ..exceptions import (
|
||||
ConcurrencyError,
|
||||
ConnectionClosed,
|
||||
ConnectionClosedOK,
|
||||
ProtocolError,
|
||||
)
|
||||
from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode
|
||||
from ..http11 import Request, Response
|
||||
from ..protocol import CLOSED, OPEN, Event, Protocol, State
|
||||
from ..typing import Data, LoggerLike, Subprotocol
|
||||
@ -262,16 +267,18 @@ class Connection(asyncio.Protocol):
|
||||
|
||||
Raises:
|
||||
ConnectionClosed: When the connection is closed.
|
||||
RuntimeError: If two coroutines call :meth:`recv` or
|
||||
ConcurrencyError: If two coroutines call :meth:`recv` or
|
||||
:meth:`recv_streaming` concurrently.
|
||||
|
||||
"""
|
||||
try:
|
||||
return await self.recv_messages.get(decode)
|
||||
except EOFError:
|
||||
# Wait for the protocol state to be CLOSED before accessing close_exc.
|
||||
await asyncio.shield(self.connection_lost_waiter)
|
||||
raise self.protocol.close_exc from self.recv_exc
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
except ConcurrencyError:
|
||||
raise ConcurrencyError(
|
||||
"cannot call recv while another coroutine "
|
||||
"is already running recv or recv_streaming"
|
||||
) from None
|
||||
@ -283,8 +290,9 @@ class Connection(asyncio.Protocol):
|
||||
This method is designed for receiving fragmented messages. It returns an
|
||||
asynchronous iterator that yields each fragment as it is received. This
|
||||
iterator must be fully consumed. Else, future calls to :meth:`recv` or
|
||||
:meth:`recv_streaming` will raise :exc:`RuntimeError`, making the
|
||||
connection unusable.
|
||||
:meth:`recv_streaming` will raise
|
||||
:exc:`~websockets.exceptions.ConcurrencyError`, making the connection
|
||||
unusable.
|
||||
|
||||
:meth:`recv_streaming` raises the same exceptions as :meth:`recv`.
|
||||
|
||||
@ -315,7 +323,7 @@ class Connection(asyncio.Protocol):
|
||||
|
||||
Raises:
|
||||
ConnectionClosed: When the connection is closed.
|
||||
RuntimeError: If two coroutines call :meth:`recv` or
|
||||
ConcurrencyError: If two coroutines call :meth:`recv` or
|
||||
:meth:`recv_streaming` concurrently.
|
||||
|
||||
"""
|
||||
@ -323,9 +331,11 @@ class Connection(asyncio.Protocol):
|
||||
async for frame in self.recv_messages.get_iter(decode):
|
||||
yield frame
|
||||
except EOFError:
|
||||
# Wait for the protocol state to be CLOSED before accessing close_exc.
|
||||
await asyncio.shield(self.connection_lost_waiter)
|
||||
raise self.protocol.close_exc from self.recv_exc
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
except ConcurrencyError:
|
||||
raise ConcurrencyError(
|
||||
"cannot call recv_streaming while another coroutine "
|
||||
"is already running recv or recv_streaming"
|
||||
) from None
|
||||
@ -593,17 +603,21 @@ class Connection(asyncio.Protocol):
|
||||
|
||||
Raises:
|
||||
ConnectionClosed: When the connection is closed.
|
||||
RuntimeError: If another ping was sent with the same data and
|
||||
ConcurrencyError: If another ping was sent with the same data and
|
||||
the corresponding pong wasn't received yet.
|
||||
|
||||
"""
|
||||
if data is not None:
|
||||
data = prepare_ctrl(data)
|
||||
if isinstance(data, BytesLike):
|
||||
data = bytes(data)
|
||||
elif isinstance(data, str):
|
||||
data = data.encode()
|
||||
elif data is not None:
|
||||
raise TypeError("data must be str or bytes-like")
|
||||
|
||||
async with self.send_context():
|
||||
# Protect against duplicates if a payload is explicitly set.
|
||||
if data in self.pong_waiters:
|
||||
raise RuntimeError("already waiting for a pong with the same data")
|
||||
raise ConcurrencyError("already waiting for a pong with the same data")
|
||||
|
||||
# Generate a unique random payload otherwise.
|
||||
while data is None or data in self.pong_waiters:
|
||||
@ -632,7 +646,12 @@ class Connection(asyncio.Protocol):
|
||||
ConnectionClosed: When the connection is closed.
|
||||
|
||||
"""
|
||||
data = prepare_ctrl(data)
|
||||
if isinstance(data, BytesLike):
|
||||
data = bytes(data)
|
||||
elif isinstance(data, str):
|
||||
data = data.encode()
|
||||
else:
|
||||
raise TypeError("data must be str or bytes-like")
|
||||
|
||||
async with self.send_context():
|
||||
self.protocol.send_pong(data)
|
||||
@ -784,7 +803,7 @@ class Connection(asyncio.Protocol):
|
||||
# Let the caller interact with the protocol.
|
||||
try:
|
||||
yield
|
||||
except (ProtocolError, RuntimeError):
|
||||
except (ProtocolError, ConcurrencyError):
|
||||
# The protocol state wasn't changed. Exit immediately.
|
||||
raise
|
||||
except Exception as exc:
|
||||
@ -849,6 +868,7 @@ class Connection(asyncio.Protocol):
|
||||
# raise an exception.
|
||||
if raise_close_exc:
|
||||
self.close_transport()
|
||||
# Wait for the protocol state to be CLOSED before accessing close_exc.
|
||||
await asyncio.shield(self.connection_lost_waiter)
|
||||
raise self.protocol.close_exc from original_exc
|
||||
|
||||
@ -911,11 +931,14 @@ class Connection(asyncio.Protocol):
|
||||
self.transport = transport
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
self.protocol.receive_eof() # receive_eof is idempotent
|
||||
# Calling protocol.receive_eof() is safe because it's idempotent.
|
||||
# This guarantees that the protocol state becomes CLOSED.
|
||||
self.protocol.receive_eof()
|
||||
assert self.protocol.state is CLOSED
|
||||
|
||||
self.set_recv_exc(exc)
|
||||
|
||||
# Abort recv() and pending pings with a ConnectionClosed exception.
|
||||
# Set recv_exc first to get proper exception reporting.
|
||||
self.set_recv_exc(exc)
|
||||
self.recv_messages.close()
|
||||
self.abort_pings()
|
||||
|
||||
@ -1083,15 +1106,17 @@ def broadcast(
|
||||
if raise_exceptions:
|
||||
if sys.version_info[:2] < (3, 11): # pragma: no cover
|
||||
raise ValueError("raise_exceptions requires at least Python 3.11")
|
||||
exceptions = []
|
||||
exceptions: list[Exception] = []
|
||||
|
||||
for connection in connections:
|
||||
exception: Exception
|
||||
|
||||
if connection.protocol.state is not OPEN:
|
||||
continue
|
||||
|
||||
if connection.fragmented_send_waiter is not None:
|
||||
if raise_exceptions:
|
||||
exception = RuntimeError("sending a fragmented message")
|
||||
exception = ConcurrencyError("sending a fragmented message")
|
||||
exceptions.append(exception)
|
||||
else:
|
||||
connection.logger.warning(
|
||||
|
@ -12,6 +12,7 @@ from typing import (
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from ..exceptions import ConcurrencyError
|
||||
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
|
||||
from ..typing import Data
|
||||
|
||||
@ -49,7 +50,7 @@ class SimpleQueue(Generic[T]):
|
||||
"""Remove and return an item from the queue, waiting if necessary."""
|
||||
if not self.queue:
|
||||
if self.get_waiter is not None:
|
||||
raise RuntimeError("get is already running")
|
||||
raise ConcurrencyError("get is already running")
|
||||
self.get_waiter = self.loop.create_future()
|
||||
try:
|
||||
await self.get_waiter
|
||||
@ -135,15 +136,15 @@ class Assembler:
|
||||
|
||||
Raises:
|
||||
EOFError: If the stream of frames has ended.
|
||||
RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter`
|
||||
concurrently.
|
||||
ConcurrencyError: If two coroutines run :meth:`get` or
|
||||
:meth:`get_iter` concurrently.
|
||||
|
||||
"""
|
||||
if self.closed:
|
||||
raise EOFError("stream of frames ended")
|
||||
|
||||
if self.get_in_progress:
|
||||
raise RuntimeError("get() or get_iter() is already running")
|
||||
raise ConcurrencyError("get() or get_iter() is already running")
|
||||
|
||||
# Locking with get_in_progress ensures only one coroutine can get here.
|
||||
self.get_in_progress = True
|
||||
@ -190,7 +191,7 @@ class Assembler:
|
||||
:class:`str` or :class:`bytes` for each frame in the message.
|
||||
|
||||
The iterator must be fully consumed before calling :meth:`get_iter` or
|
||||
:meth:`get` again. Else, :exc:`RuntimeError` is raised.
|
||||
:meth:`get` again. Else, :exc:`ConcurrencyError` is raised.
|
||||
|
||||
This method only makes sense for fragmented messages. If messages aren't
|
||||
fragmented, use :meth:`get` instead.
|
||||
@ -202,15 +203,15 @@ class Assembler:
|
||||
|
||||
Raises:
|
||||
EOFError: If the stream of frames has ended.
|
||||
RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter`
|
||||
concurrently.
|
||||
ConcurrencyError: If two coroutines run :meth:`get` or
|
||||
:meth:`get_iter` concurrently.
|
||||
|
||||
"""
|
||||
if self.closed:
|
||||
raise EOFError("stream of frames ended")
|
||||
|
||||
if self.get_in_progress:
|
||||
raise RuntimeError("get() or get_iter() is already running")
|
||||
raise ConcurrencyError("get() or get_iter() is already running")
|
||||
|
||||
# Locking with get_in_progress ensures only one coroutine can get here.
|
||||
self.get_in_progress = True
|
||||
@ -236,7 +237,7 @@ class Assembler:
|
||||
# We cannot handle asyncio.CancelledError because we don't buffer
|
||||
# previous fragments — we're streaming them. Canceling get_iter()
|
||||
# here will leave the assembler in a stuck state. Future calls to
|
||||
# get() or get_iter() will raise RuntimeError.
|
||||
# get() or get_iter() will raise ConcurrencyError.
|
||||
frame = await self.frames.get()
|
||||
self.maybe_resume()
|
||||
assert frame.opcode is OP_CONT
|
||||
|
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hmac
|
||||
import http
|
||||
import logging
|
||||
import socket
|
||||
@ -13,22 +14,35 @@ from typing import (
|
||||
Generator,
|
||||
Iterable,
|
||||
Sequence,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
from websockets.frames import CloseCode
|
||||
|
||||
from ..exceptions import InvalidHeader
|
||||
from ..extensions.base import ServerExtensionFactory
|
||||
from ..extensions.permessage_deflate import enable_server_permessage_deflate
|
||||
from ..headers import validate_subprotocols
|
||||
from ..frames import CloseCode
|
||||
from ..headers import (
|
||||
build_www_authenticate_basic,
|
||||
parse_authorization_basic,
|
||||
validate_subprotocols,
|
||||
)
|
||||
from ..http11 import SERVER, Request, Response
|
||||
from ..protocol import CONNECTING, Event
|
||||
from ..protocol import CONNECTING, OPEN, Event
|
||||
from ..server import ServerProtocol
|
||||
from ..typing import LoggerLike, Origin, StatusLike, Subprotocol
|
||||
from .compatibility import asyncio_timeout
|
||||
from .connection import Connection, broadcast
|
||||
|
||||
|
||||
__all__ = ["broadcast", "serve", "unix_serve", "ServerConnection", "Server"]
|
||||
__all__ = [
|
||||
"broadcast",
|
||||
"serve",
|
||||
"unix_serve",
|
||||
"ServerConnection",
|
||||
"Server",
|
||||
"basic_auth",
|
||||
]
|
||||
|
||||
|
||||
class ServerConnection(Connection):
|
||||
@ -79,6 +93,7 @@ class ServerConnection(Connection):
|
||||
)
|
||||
self.server = server
|
||||
self.request_rcvd: asyncio.Future[None] = self.loop.create_future()
|
||||
self.username: str # see basic_auth()
|
||||
|
||||
def respond(self, status: StatusLike, text: str) -> Response:
|
||||
"""
|
||||
@ -123,78 +138,75 @@ class ServerConnection(Connection):
|
||||
Perform the opening handshake.
|
||||
|
||||
"""
|
||||
# May raise CancelledError if open_timeout is exceeded.
|
||||
await self.request_rcvd
|
||||
await asyncio.wait(
|
||||
[self.request_rcvd, self.connection_lost_waiter],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
if self.request is None:
|
||||
raise ConnectionError("connection closed during handshake")
|
||||
if self.request is not None:
|
||||
async with self.send_context(expected_state=CONNECTING):
|
||||
response = None
|
||||
|
||||
async with self.send_context(expected_state=CONNECTING):
|
||||
response = None
|
||||
if process_request is not None:
|
||||
try:
|
||||
response = process_request(self, self.request)
|
||||
if isinstance(response, Awaitable):
|
||||
response = await response
|
||||
except Exception as exc:
|
||||
self.protocol.handshake_exc = exc
|
||||
response = self.protocol.reject(
|
||||
http.HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
(
|
||||
"Failed to open a WebSocket connection.\n"
|
||||
"See server log for more information.\n"
|
||||
),
|
||||
)
|
||||
|
||||
if process_request is not None:
|
||||
try:
|
||||
response = process_request(self, self.request)
|
||||
if isinstance(response, Awaitable):
|
||||
response = await response
|
||||
except Exception as exc:
|
||||
self.protocol.handshake_exc = exc
|
||||
self.logger.error("opening handshake failed", exc_info=True)
|
||||
response = self.protocol.reject(
|
||||
http.HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
(
|
||||
"Failed to open a WebSocket connection.\n"
|
||||
"See server log for more information.\n"
|
||||
),
|
||||
)
|
||||
|
||||
if response is None:
|
||||
if self.server.is_serving():
|
||||
self.response = self.protocol.accept(self.request)
|
||||
if response is None:
|
||||
if self.server.is_serving():
|
||||
self.response = self.protocol.accept(self.request)
|
||||
else:
|
||||
self.response = self.protocol.reject(
|
||||
http.HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
"Server is shutting down.\n",
|
||||
)
|
||||
else:
|
||||
self.response = self.protocol.reject(
|
||||
http.HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
"Server is shutting down.\n",
|
||||
)
|
||||
else:
|
||||
assert isinstance(response, Response) # help mypy
|
||||
self.response = response
|
||||
assert isinstance(response, Response) # help mypy
|
||||
self.response = response
|
||||
|
||||
if server_header:
|
||||
self.response.headers["Server"] = server_header
|
||||
if server_header:
|
||||
self.response.headers["Server"] = server_header
|
||||
|
||||
response = None
|
||||
response = None
|
||||
|
||||
if process_response is not None:
|
||||
try:
|
||||
response = process_response(self, self.request, self.response)
|
||||
if isinstance(response, Awaitable):
|
||||
response = await response
|
||||
except Exception as exc:
|
||||
self.protocol.handshake_exc = exc
|
||||
self.logger.error("opening handshake failed", exc_info=True)
|
||||
response = self.protocol.reject(
|
||||
http.HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
(
|
||||
"Failed to open a WebSocket connection.\n"
|
||||
"See server log for more information.\n"
|
||||
),
|
||||
)
|
||||
if process_response is not None:
|
||||
try:
|
||||
response = process_response(self, self.request, self.response)
|
||||
if isinstance(response, Awaitable):
|
||||
response = await response
|
||||
except Exception as exc:
|
||||
self.protocol.handshake_exc = exc
|
||||
response = self.protocol.reject(
|
||||
http.HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
(
|
||||
"Failed to open a WebSocket connection.\n"
|
||||
"See server log for more information.\n"
|
||||
),
|
||||
)
|
||||
|
||||
if response is not None:
|
||||
assert isinstance(response, Response) # help mypy
|
||||
self.response = response
|
||||
if response is not None:
|
||||
assert isinstance(response, Response) # help mypy
|
||||
self.response = response
|
||||
|
||||
self.protocol.send_response(self.response)
|
||||
self.protocol.send_response(self.response)
|
||||
|
||||
if self.protocol.handshake_exc is None:
|
||||
self.start_keepalive()
|
||||
else:
|
||||
try:
|
||||
async with asyncio_timeout(self.close_timeout):
|
||||
await self.connection_lost_waiter
|
||||
finally:
|
||||
raise self.protocol.handshake_exc
|
||||
# self.protocol.handshake_exc is always set when the connection is lost
|
||||
# before receiving a request, when the request cannot be parsed, when
|
||||
# the handshake encounters an error, or when process_request or
|
||||
# process_response sends an HTTP response that rejects the handshake.
|
||||
|
||||
if self.protocol.handshake_exc is not None:
|
||||
raise self.protocol.handshake_exc
|
||||
|
||||
def process_event(self, event: Event) -> None:
|
||||
"""
|
||||
@ -214,14 +226,6 @@ class ServerConnection(Connection):
|
||||
super().connection_made(transport)
|
||||
self.server.start_connection_handler(self)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
try:
|
||||
super().connection_lost(exc)
|
||||
finally:
|
||||
# If the connection is closed during the handshake, unblock it.
|
||||
if not self.request_rcvd.done():
|
||||
self.request_rcvd.set_result(None)
|
||||
|
||||
|
||||
class Server:
|
||||
"""
|
||||
@ -298,6 +302,18 @@ class Server:
|
||||
# Completed when the server is closed and connections are terminated.
|
||||
self.closed_waiter: asyncio.Future[None] = self.loop.create_future()
|
||||
|
||||
@property
|
||||
def connections(self) -> set[ServerConnection]:
|
||||
"""
|
||||
Set of active connections.
|
||||
|
||||
This property contains all connections that completed the opening
|
||||
handshake successfully and didn't start the closing handshake yet.
|
||||
It can be useful in combination with :func:`~broadcast`.
|
||||
|
||||
"""
|
||||
return {connection for connection in self.handlers if connection.state is OPEN}
|
||||
|
||||
def wrap(self, server: asyncio.Server) -> None:
|
||||
"""
|
||||
Attach to a given :class:`asyncio.Server`.
|
||||
@ -337,25 +353,37 @@ class Server:
|
||||
|
||||
"""
|
||||
try:
|
||||
# On failure, handshake() closes the transport, raises an
|
||||
# exception, and logs it.
|
||||
async with asyncio_timeout(self.open_timeout):
|
||||
await connection.handshake(
|
||||
self.process_request,
|
||||
self.process_response,
|
||||
self.server_header,
|
||||
)
|
||||
try:
|
||||
await connection.handshake(
|
||||
self.process_request,
|
||||
self.process_response,
|
||||
self.server_header,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
connection.close_transport()
|
||||
raise
|
||||
except Exception:
|
||||
connection.logger.error("opening handshake failed", exc_info=True)
|
||||
connection.close_transport()
|
||||
return
|
||||
|
||||
assert connection.protocol.state is OPEN
|
||||
try:
|
||||
connection.start_keepalive()
|
||||
await self.handler(connection)
|
||||
except Exception:
|
||||
self.logger.error("connection handler failed", exc_info=True)
|
||||
connection.logger.error("connection handler failed", exc_info=True)
|
||||
await connection.close(CloseCode.INTERNAL_ERROR)
|
||||
else:
|
||||
await connection.close()
|
||||
|
||||
except Exception:
|
||||
# Don't leak connections on errors.
|
||||
except TimeoutError:
|
||||
# When the opening handshake times out, there's nothing to log.
|
||||
pass
|
||||
|
||||
except Exception: # pragma: no cover
|
||||
# Don't leak connections on unexpected errors.
|
||||
connection.transport.abort()
|
||||
|
||||
finally:
|
||||
@ -548,19 +576,21 @@ class serve:
|
||||
:class:`asyncio.Server`. Treat it as an asynchronous context manager to
|
||||
ensure that the server will be closed::
|
||||
|
||||
from websockets.asyncio.server import serve
|
||||
|
||||
def handler(websocket):
|
||||
...
|
||||
|
||||
# set this future to exit the server
|
||||
stop = asyncio.get_running_loop().create_future()
|
||||
|
||||
async with websockets.asyncio.server.serve(handler, host, port):
|
||||
async with serve(handler, host, port):
|
||||
await stop
|
||||
|
||||
Alternatively, call :meth:`~Server.serve_forever` to serve requests and
|
||||
cancel it to stop the server::
|
||||
|
||||
server = await websockets.asyncio.server.serve(handler, host, port)
|
||||
server = await serve(handler, host, port)
|
||||
await server.serve_forever()
|
||||
|
||||
Args:
|
||||
@ -664,14 +694,14 @@ class serve:
|
||||
process_request: (
|
||||
Callable[
|
||||
[ServerConnection, Request],
|
||||
Response | None,
|
||||
Awaitable[Response | None] | Response | None,
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
process_response: (
|
||||
Callable[
|
||||
[ServerConnection, Request, Response],
|
||||
Response | None,
|
||||
Awaitable[Response | None] | Response | None,
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
@ -693,7 +723,6 @@ class serve:
|
||||
# Other keyword arguments are passed to loop.create_server
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
||||
if subprotocols is not None:
|
||||
validate_subprotocols(subprotocols)
|
||||
|
||||
@ -767,10 +796,10 @@ class serve:
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
if kwargs.pop("unix", False):
|
||||
self._create_server = loop.create_unix_server(factory, **kwargs)
|
||||
self.create_server = loop.create_unix_server(factory, **kwargs)
|
||||
else:
|
||||
# mypy cannot tell that kwargs must provide sock when port is None.
|
||||
self._create_server = loop.create_server(factory, host, port, **kwargs) # type: ignore[arg-type]
|
||||
self.create_server = loop.create_server(factory, host, port, **kwargs) # type: ignore[arg-type]
|
||||
|
||||
# async with serve(...) as ...: ...
|
||||
|
||||
@ -793,7 +822,7 @@ class serve:
|
||||
return self.__await_impl__().__await__()
|
||||
|
||||
async def __await_impl__(self) -> Server:
|
||||
server = await self._create_server
|
||||
server = await self.create_server
|
||||
self.server.wrap(server)
|
||||
return self.server
|
||||
|
||||
@ -822,3 +851,123 @@ def unix_serve(
|
||||
|
||||
"""
|
||||
return serve(handler, unix=True, path=path, **kwargs)
|
||||
|
||||
|
||||
def is_credentials(credentials: Any) -> bool:
|
||||
try:
|
||||
username, password = credentials
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
else:
|
||||
return isinstance(username, str) and isinstance(password, str)
|
||||
|
||||
|
||||
def basic_auth(
|
||||
realm: str = "",
|
||||
credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None,
|
||||
check_credentials: Callable[[str, str], Awaitable[bool] | bool] | None = None,
|
||||
) -> Callable[[ServerConnection, Request], Awaitable[Response | None]]:
|
||||
"""
|
||||
Factory for ``process_request`` to enforce HTTP Basic Authentication.
|
||||
|
||||
:func:`basic_auth` is designed to integrate with :func:`serve` as follows::
|
||||
|
||||
from websockets.asyncio.server import basic_auth, serve
|
||||
|
||||
async with serve(
|
||||
...,
|
||||
process_request=basic_auth(
|
||||
realm="my dev server",
|
||||
credentials=("hello", "iloveyou"),
|
||||
),
|
||||
):
|
||||
|
||||
If authentication succeeds, the connection's ``username`` attribute is set.
|
||||
If it fails, the server responds with an HTTP 401 Unauthorized status.
|
||||
|
||||
One of ``credentials`` or ``check_credentials`` must be provided; not both.
|
||||
|
||||
Args:
|
||||
realm: Scope of protection. It should contain only ASCII characters
|
||||
because the encoding of non-ASCII characters is undefined. Refer to
|
||||
section 2.2 of :rfc:`7235` for details.
|
||||
credentials: Hard coded authorized credentials. It can be a
|
||||
``(username, password)`` pair or a list of such pairs.
|
||||
check_credentials: Function or coroutine that verifies credentials.
|
||||
It receives ``username`` and ``password`` arguments and returns
|
||||
whether they're valid.
|
||||
Raises:
|
||||
TypeError: If ``credentials`` or ``check_credentials`` is wrong.
|
||||
|
||||
"""
|
||||
if (credentials is None) == (check_credentials is None):
|
||||
raise TypeError("provide either credentials or check_credentials")
|
||||
|
||||
if credentials is not None:
|
||||
if is_credentials(credentials):
|
||||
credentials_list = [cast(Tuple[str, str], credentials)]
|
||||
elif isinstance(credentials, Iterable):
|
||||
credentials_list = list(cast(Iterable[Tuple[str, str]], credentials))
|
||||
if not all(is_credentials(item) for item in credentials_list):
|
||||
raise TypeError(f"invalid credentials argument: {credentials}")
|
||||
else:
|
||||
raise TypeError(f"invalid credentials argument: {credentials}")
|
||||
|
||||
credentials_dict = dict(credentials_list)
|
||||
|
||||
def check_credentials(username: str, password: str) -> bool:
|
||||
try:
|
||||
expected_password = credentials_dict[username]
|
||||
except KeyError:
|
||||
return False
|
||||
return hmac.compare_digest(expected_password, password)
|
||||
|
||||
assert check_credentials is not None # help mypy
|
||||
|
||||
async def process_request(
|
||||
connection: ServerConnection,
|
||||
request: Request,
|
||||
) -> Response | None:
|
||||
"""
|
||||
Perform HTTP Basic Authentication.
|
||||
|
||||
If it succeeds, set the connection's ``username`` attribute and return
|
||||
:obj:`None`. If it fails, return an HTTP 401 Unauthorized responss.
|
||||
|
||||
"""
|
||||
try:
|
||||
authorization = request.headers["Authorization"]
|
||||
except KeyError:
|
||||
response = connection.respond(
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
"Missing credentials\n",
|
||||
)
|
||||
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
|
||||
return response
|
||||
|
||||
try:
|
||||
username, password = parse_authorization_basic(authorization)
|
||||
except InvalidHeader:
|
||||
response = connection.respond(
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
"Unsupported credentials\n",
|
||||
)
|
||||
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
|
||||
return response
|
||||
|
||||
valid_credentials = check_credentials(username, password)
|
||||
if isinstance(valid_credentials, Awaitable):
|
||||
valid_credentials = await valid_credentials
|
||||
|
||||
if not valid_credentials:
|
||||
response = connection.respond(
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
"Invalid credentials\n",
|
||||
)
|
||||
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
|
||||
return response
|
||||
|
||||
connection.username = username
|
||||
return None
|
||||
|
||||
return process_request
|
||||
|
@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from typing import Any, Generator, Sequence
|
||||
|
||||
@ -173,13 +175,10 @@ class ClientProtocol(Protocol):
|
||||
|
||||
try:
|
||||
s_w_accept = headers["Sec-WebSocket-Accept"]
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Accept") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader(
|
||||
"Sec-WebSocket-Accept",
|
||||
"more than one Sec-WebSocket-Accept header found",
|
||||
) from exc
|
||||
except KeyError:
|
||||
raise InvalidHeader("Sec-WebSocket-Accept") from None
|
||||
except MultipleValuesError:
|
||||
raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from None
|
||||
|
||||
if s_w_accept != accept_key(self.key):
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)
|
||||
@ -225,7 +224,7 @@ class ClientProtocol(Protocol):
|
||||
|
||||
if extensions:
|
||||
if self.available_extensions is None:
|
||||
raise InvalidHandshake("no extensions supported")
|
||||
raise NegotiationError("no extensions supported")
|
||||
|
||||
parsed_extensions: list[ExtensionHeader] = sum(
|
||||
[parse_extension(header_value) for header_value in extensions], []
|
||||
@ -280,15 +279,17 @@ class ClientProtocol(Protocol):
|
||||
|
||||
if subprotocols:
|
||||
if self.available_subprotocols is None:
|
||||
raise InvalidHandshake("no subprotocols supported")
|
||||
raise NegotiationError("no subprotocols supported")
|
||||
|
||||
parsed_subprotocols: Sequence[Subprotocol] = sum(
|
||||
[parse_subprotocol(header_value) for header_value in subprotocols], []
|
||||
)
|
||||
|
||||
if len(parsed_subprotocols) > 1:
|
||||
subprotocols_display = ", ".join(parsed_subprotocols)
|
||||
raise InvalidHandshake(f"multiple subprotocols: {subprotocols_display}")
|
||||
raise InvalidHeader(
|
||||
"Sec-WebSocket-Protocol",
|
||||
f"multiple values: {', '.join(parsed_subprotocols)}",
|
||||
)
|
||||
|
||||
subprotocol = parsed_subprotocols[0]
|
||||
|
||||
@ -322,6 +323,7 @@ class ClientProtocol(Protocol):
|
||||
)
|
||||
except Exception as exc:
|
||||
self.handshake_exc = exc
|
||||
self.send_eof()
|
||||
self.parser = self.discard()
|
||||
next(self.parser) # start coroutine
|
||||
yield
|
||||
@ -340,6 +342,7 @@ class ClientProtocol(Protocol):
|
||||
response._exception = exc
|
||||
self.events.append(response)
|
||||
self.handshake_exc = exc
|
||||
self.send_eof()
|
||||
self.parser = self.discard()
|
||||
next(self.parser) # start coroutine
|
||||
yield
|
||||
@ -353,8 +356,38 @@ class ClientProtocol(Protocol):
|
||||
|
||||
class ClientConnection(ClientProtocol):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
warnings.warn(
|
||||
warnings.warn( # deprecated in 11.0 - 2023-04-02
|
||||
"ClientConnection was renamed to ClientProtocol",
|
||||
DeprecationWarning,
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
BACKOFF_INITIAL_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5"))
|
||||
BACKOFF_MIN_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1"))
|
||||
BACKOFF_MAX_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0"))
|
||||
BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618"))
|
||||
|
||||
|
||||
def backoff(
|
||||
initial_delay: float = BACKOFF_INITIAL_DELAY,
|
||||
min_delay: float = BACKOFF_MIN_DELAY,
|
||||
max_delay: float = BACKOFF_MAX_DELAY,
|
||||
factor: float = BACKOFF_FACTOR,
|
||||
) -> Generator[float, None, None]:
|
||||
"""
|
||||
Generate a series of backoff delays between reconnection attempts.
|
||||
|
||||
Yields:
|
||||
How many seconds to wait before retrying to connect.
|
||||
|
||||
"""
|
||||
# Add a random initial delay between 0 and 5 seconds.
|
||||
# See 7.2.3. Recovering from Abnormal Closure in RFC 6455.
|
||||
yield random.random() * initial_delay
|
||||
delay = min_delay
|
||||
while delay < max_delay:
|
||||
yield delay
|
||||
delay *= factor
|
||||
while True:
|
||||
yield max_delay
|
||||
|
@ -5,7 +5,7 @@ import warnings
|
||||
from .protocol import SEND_EOF, Protocol as Connection, Side, State # noqa: F401
|
||||
|
||||
|
||||
warnings.warn(
|
||||
warnings.warn( # deprecated in 11.0 - 2023-04-02
|
||||
"websockets.connection was renamed to websockets.protocol "
|
||||
"and Connection was renamed to Protocol",
|
||||
DeprecationWarning,
|
||||
|
@ -17,7 +17,7 @@ __all__ = ["Headers", "HeadersLike", "MultipleValuesError"]
|
||||
|
||||
class MultipleValuesError(LookupError):
|
||||
"""
|
||||
Exception raised when :class:`Headers` has more than one value for a key.
|
||||
Exception raised when :class:`Headers` has multiple values for a key.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -1,67 +1,69 @@
|
||||
"""
|
||||
:mod:`websockets.exceptions` defines the following exception hierarchy:
|
||||
:mod:`websockets.exceptions` defines the following hierarchy of exceptions.
|
||||
|
||||
* :exc:`WebSocketException`
|
||||
* :exc:`ConnectionClosed`
|
||||
* :exc:`ConnectionClosedError`
|
||||
* :exc:`ConnectionClosedOK`
|
||||
* :exc:`ConnectionClosedError`
|
||||
* :exc:`InvalidURI`
|
||||
* :exc:`InvalidHandshake`
|
||||
* :exc:`SecurityError`
|
||||
* :exc:`InvalidMessage`
|
||||
* :exc:`InvalidMessage` (legacy)
|
||||
* :exc:`InvalidStatus`
|
||||
* :exc:`InvalidStatusCode` (legacy)
|
||||
* :exc:`InvalidHeader`
|
||||
* :exc:`InvalidHeaderFormat`
|
||||
* :exc:`InvalidHeaderValue`
|
||||
* :exc:`InvalidOrigin`
|
||||
* :exc:`InvalidUpgrade`
|
||||
* :exc:`InvalidStatus`
|
||||
* :exc:`InvalidStatusCode` (legacy)
|
||||
* :exc:`NegotiationError`
|
||||
* :exc:`DuplicateParameter`
|
||||
* :exc:`InvalidParameterName`
|
||||
* :exc:`InvalidParameterValue`
|
||||
* :exc:`AbortHandshake`
|
||||
* :exc:`RedirectHandshake`
|
||||
* :exc:`InvalidState`
|
||||
* :exc:`InvalidURI`
|
||||
* :exc:`PayloadTooBig`
|
||||
* :exc:`ProtocolError`
|
||||
* :exc:`AbortHandshake` (legacy)
|
||||
* :exc:`RedirectHandshake` (legacy)
|
||||
* :exc:`ProtocolError` (Sans-I/O)
|
||||
* :exc:`PayloadTooBig` (Sans-I/O)
|
||||
* :exc:`InvalidState` (Sans-I/O)
|
||||
* :exc:`ConcurrencyError`
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import http
|
||||
import typing
|
||||
import warnings
|
||||
|
||||
from . import datastructures, frames, http11
|
||||
from .typing import StatusLike
|
||||
from .imports import lazy_import
|
||||
|
||||
|
||||
__all__ = [
|
||||
"WebSocketException",
|
||||
"ConnectionClosed",
|
||||
"ConnectionClosedError",
|
||||
"ConnectionClosedOK",
|
||||
"ConnectionClosedError",
|
||||
"InvalidURI",
|
||||
"InvalidHandshake",
|
||||
"SecurityError",
|
||||
"InvalidMessage",
|
||||
"InvalidStatus",
|
||||
"InvalidStatusCode",
|
||||
"InvalidHeader",
|
||||
"InvalidHeaderFormat",
|
||||
"InvalidHeaderValue",
|
||||
"InvalidOrigin",
|
||||
"InvalidUpgrade",
|
||||
"InvalidStatus",
|
||||
"InvalidStatusCode",
|
||||
"NegotiationError",
|
||||
"DuplicateParameter",
|
||||
"InvalidParameterName",
|
||||
"InvalidParameterValue",
|
||||
"AbortHandshake",
|
||||
"RedirectHandshake",
|
||||
"InvalidState",
|
||||
"InvalidURI",
|
||||
"PayloadTooBig",
|
||||
"ProtocolError",
|
||||
"WebSocketProtocolError",
|
||||
"PayloadTooBig",
|
||||
"InvalidState",
|
||||
"ConcurrencyError",
|
||||
]
|
||||
|
||||
|
||||
@ -77,13 +79,13 @@ class ConnectionClosed(WebSocketException):
|
||||
Raised when trying to interact with a closed connection.
|
||||
|
||||
Attributes:
|
||||
rcvd (Close | None): if a close frame was received, its code and
|
||||
reason are available in ``rcvd.code`` and ``rcvd.reason``.
|
||||
sent (Close | None): if a close frame was sent, its code and reason
|
||||
are available in ``sent.code`` and ``sent.reason``.
|
||||
rcvd_then_sent (bool | None): if close frames were received and
|
||||
sent, this attribute tells in which order this happened, from the
|
||||
perspective of this side of the connection.
|
||||
rcvd: If a close frame was received, its code and reason are available
|
||||
in ``rcvd.code`` and ``rcvd.reason``.
|
||||
sent: If a close frame was sent, its code and reason are available
|
||||
in ``sent.code`` and ``sent.reason``.
|
||||
rcvd_then_sent: If close frames were received and sent, this attribute
|
||||
tells in which order this happened, from the perspective of this
|
||||
side of the connection.
|
||||
|
||||
"""
|
||||
|
||||
@ -96,21 +98,18 @@ class ConnectionClosed(WebSocketException):
|
||||
self.rcvd = rcvd
|
||||
self.sent = sent
|
||||
self.rcvd_then_sent = rcvd_then_sent
|
||||
assert (self.rcvd_then_sent is None) == (self.rcvd is None or self.sent is None)
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.rcvd is None:
|
||||
if self.sent is None:
|
||||
assert self.rcvd_then_sent is None
|
||||
return "no close frame received or sent"
|
||||
else:
|
||||
assert self.rcvd_then_sent is None
|
||||
return f"sent {self.sent}; no close frame received"
|
||||
else:
|
||||
if self.sent is None:
|
||||
assert self.rcvd_then_sent is None
|
||||
return f"received {self.rcvd}; no close frame sent"
|
||||
else:
|
||||
assert self.rcvd_then_sent is not None
|
||||
if self.rcvd_then_sent:
|
||||
return f"received {self.rcvd}; then sent {self.sent}"
|
||||
else:
|
||||
@ -120,27 +119,27 @@ class ConnectionClosed(WebSocketException):
|
||||
|
||||
@property
|
||||
def code(self) -> int:
|
||||
warnings.warn( # deprecated in 13.1
|
||||
"ConnectionClosed.code is deprecated; "
|
||||
"use Protocol.close_code or ConnectionClosed.rcvd.code",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if self.rcvd is None:
|
||||
return frames.CloseCode.ABNORMAL_CLOSURE
|
||||
return self.rcvd.code
|
||||
|
||||
@property
|
||||
def reason(self) -> str:
|
||||
warnings.warn( # deprecated in 13.1
|
||||
"ConnectionClosed.reason is deprecated; "
|
||||
"use Protocol.close_reason or ConnectionClosed.rcvd.reason",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if self.rcvd is None:
|
||||
return ""
|
||||
return self.rcvd.reason
|
||||
|
||||
|
||||
class ConnectionClosedError(ConnectionClosed):
|
||||
"""
|
||||
Like :exc:`ConnectionClosed`, when the connection terminated with an error.
|
||||
|
||||
A close frame with a code other than 1000 (OK) or 1001 (going away) was
|
||||
received or sent, or the closing handshake didn't complete properly.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ConnectionClosedOK(ConnectionClosed):
|
||||
"""
|
||||
Like :exc:`ConnectionClosed`, when the connection terminated properly.
|
||||
@ -151,9 +150,33 @@ class ConnectionClosedOK(ConnectionClosed):
|
||||
"""
|
||||
|
||||
|
||||
class ConnectionClosedError(ConnectionClosed):
|
||||
"""
|
||||
Like :exc:`ConnectionClosed`, when the connection terminated with an error.
|
||||
|
||||
A close frame with a code other than 1000 (OK) or 1001 (going away) was
|
||||
received or sent, or the closing handshake didn't complete properly.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class InvalidURI(WebSocketException):
|
||||
"""
|
||||
Raised when connecting to a URI that isn't a valid WebSocket URI.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, uri: str, msg: str) -> None:
|
||||
self.uri = uri
|
||||
self.msg = msg
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.uri} isn't a valid URI: {self.msg}"
|
||||
|
||||
|
||||
class InvalidHandshake(WebSocketException):
|
||||
"""
|
||||
Raised during the handshake when the WebSocket connection fails.
|
||||
Base class for exceptions raised when the opening handshake fails.
|
||||
|
||||
"""
|
||||
|
||||
@ -162,17 +185,27 @@ class SecurityError(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake request or response breaks a security rule.
|
||||
|
||||
Security limits are hard coded.
|
||||
Security limits can be configured with :doc:`environment variables
|
||||
<../reference/variables>`.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class InvalidMessage(InvalidHandshake):
|
||||
class InvalidStatus(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake request or response is malformed.
|
||||
Raised when a handshake response rejects the WebSocket upgrade.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, response: http11.Response) -> None:
|
||||
self.response = response
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
"server rejected WebSocket connection: "
|
||||
f"HTTP {self.response.status_code:d}"
|
||||
)
|
||||
|
||||
|
||||
class InvalidHeader(InvalidHandshake):
|
||||
"""
|
||||
@ -209,7 +242,7 @@ class InvalidHeaderValue(InvalidHeader):
|
||||
"""
|
||||
Raised when an HTTP header has a wrong value.
|
||||
|
||||
The format of the header is correct but a value isn't acceptable.
|
||||
The format of the header is correct but the value isn't acceptable.
|
||||
|
||||
"""
|
||||
|
||||
@ -231,39 +264,9 @@ class InvalidUpgrade(InvalidHeader):
|
||||
"""
|
||||
|
||||
|
||||
class InvalidStatus(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake response rejects the WebSocket upgrade.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, response: http11.Response) -> None:
|
||||
self.response = response
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
"server rejected WebSocket connection: "
|
||||
f"HTTP {self.response.status_code:d}"
|
||||
)
|
||||
|
||||
|
||||
class InvalidStatusCode(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake response status code is invalid.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, status_code: int, headers: datastructures.Headers) -> None:
|
||||
self.status_code = status_code
|
||||
self.headers = headers
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"server rejected WebSocket connection: HTTP {self.status_code}"
|
||||
|
||||
|
||||
class NegotiationError(InvalidHandshake):
|
||||
"""
|
||||
Raised when negotiating an extension fails.
|
||||
Raised when negotiating an extension or a subprotocol fails.
|
||||
|
||||
"""
|
||||
|
||||
@ -313,92 +316,77 @@ class InvalidParameterValue(NegotiationError):
|
||||
return f"invalid value for parameter {self.name}: {self.value}"
|
||||
|
||||
|
||||
class AbortHandshake(InvalidHandshake):
|
||||
class ProtocolError(WebSocketException):
|
||||
"""
|
||||
Raised to abort the handshake on purpose and return an HTTP response.
|
||||
Raised when receiving or sending a frame that breaks the protocol.
|
||||
|
||||
This exception is an implementation detail.
|
||||
The Sans-I/O implementation raises this exception when:
|
||||
|
||||
The public API is
|
||||
:meth:`~websockets.legacy.server.WebSocketServerProtocol.process_request`.
|
||||
|
||||
Attributes:
|
||||
status (~http.HTTPStatus): HTTP status code.
|
||||
headers (Headers): HTTP response headers.
|
||||
body (bytes): HTTP response body.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status: StatusLike,
|
||||
headers: datastructures.HeadersLike,
|
||||
body: bytes = b"",
|
||||
) -> None:
|
||||
# If a user passes an int instead of a HTTPStatus, fix it automatically.
|
||||
self.status = http.HTTPStatus(status)
|
||||
self.headers = datastructures.Headers(headers)
|
||||
self.body = body
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"HTTP {self.status:d}, "
|
||||
f"{len(self.headers)} headers, "
|
||||
f"{len(self.body)} bytes"
|
||||
)
|
||||
|
||||
|
||||
class RedirectHandshake(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake gets redirected.
|
||||
|
||||
This exception is an implementation detail.
|
||||
* receiving or sending a frame that contains invalid data;
|
||||
* receiving or sending an invalid sequence of frames.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, uri: str) -> None:
|
||||
self.uri = uri
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"redirect to {self.uri}"
|
||||
|
||||
|
||||
class InvalidState(WebSocketException, AssertionError):
|
||||
"""
|
||||
Raised when an operation is forbidden in the current state.
|
||||
|
||||
This exception is an implementation detail.
|
||||
|
||||
It should never be raised in normal circumstances.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class InvalidURI(WebSocketException):
|
||||
"""
|
||||
Raised when connecting to a URI that isn't a valid WebSocket URI.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, uri: str, msg: str) -> None:
|
||||
self.uri = uri
|
||||
self.msg = msg
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.uri} isn't a valid URI: {self.msg}"
|
||||
|
||||
|
||||
class PayloadTooBig(WebSocketException):
|
||||
"""
|
||||
Raised when receiving a frame with a payload exceeding the maximum size.
|
||||
Raised when parsing a frame with a payload that exceeds the maximum size.
|
||||
|
||||
The Sans-I/O layer uses this exception internally. It doesn't bubble up to
|
||||
the I/O layer.
|
||||
|
||||
The :meth:`~websockets.extensions.Extension.decode` method of extensions
|
||||
must raise :exc:`PayloadTooBig` if decoding a frame would exceed the limit.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ProtocolError(WebSocketException):
|
||||
class InvalidState(WebSocketException, AssertionError):
|
||||
"""
|
||||
Raised when a frame breaks the protocol.
|
||||
Raised when sending a frame is forbidden in the current state.
|
||||
|
||||
Specifically, the Sans-I/O layer raises this exception when:
|
||||
|
||||
* sending a data frame to a connection in a state other
|
||||
:attr:`~websockets.protocol.State.OPEN`;
|
||||
* sending a control frame to a connection in a state other than
|
||||
:attr:`~websockets.protocol.State.OPEN` or
|
||||
:attr:`~websockets.protocol.State.CLOSING`.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
WebSocketProtocolError = ProtocolError # for backwards compatibility
|
||||
class ConcurrencyError(WebSocketException, RuntimeError):
|
||||
"""
|
||||
Raised when receiving or sending messages concurrently.
|
||||
|
||||
WebSocket is a connection-oriented protocol. Reads must be serialized; so
|
||||
must be writes. However, reading and writing concurrently is possible.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# When type checking, import non-deprecated aliases eagerly. Else, import on demand.
|
||||
if typing.TYPE_CHECKING:
|
||||
from .legacy.exceptions import (
|
||||
AbortHandshake,
|
||||
InvalidMessage,
|
||||
InvalidStatusCode,
|
||||
RedirectHandshake,
|
||||
)
|
||||
|
||||
WebSocketProtocolError = ProtocolError
|
||||
else:
|
||||
lazy_import(
|
||||
globals(),
|
||||
aliases={
|
||||
"AbortHandshake": ".legacy.exceptions",
|
||||
"InvalidMessage": ".legacy.exceptions",
|
||||
"InvalidStatusCode": ".legacy.exceptions",
|
||||
"RedirectHandshake": ".legacy.exceptions",
|
||||
"WebSocketProtocolError": ".legacy.exceptions",
|
||||
},
|
||||
)
|
||||
|
||||
# At the bottom to break import cycles created by type annotations.
|
||||
from . import frames, http11 # noqa: E402
|
||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
from .. import frames
|
||||
from ..frames import Frame
|
||||
from ..typing import ExtensionName, ExtensionParameter
|
||||
|
||||
|
||||
@ -18,12 +18,7 @@ class Extension:
|
||||
name: ExtensionName
|
||||
"""Extension identifier."""
|
||||
|
||||
def decode(
|
||||
self,
|
||||
frame: frames.Frame,
|
||||
*,
|
||||
max_size: int | None = None,
|
||||
) -> frames.Frame:
|
||||
def decode(self, frame: Frame, *, max_size: int | None = None) -> Frame:
|
||||
"""
|
||||
Decode an incoming frame.
|
||||
|
||||
@ -40,7 +35,7 @@ class Extension:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def encode(self, frame: frames.Frame) -> frames.Frame:
|
||||
def encode(self, frame: Frame) -> Frame:
|
||||
"""
|
||||
Encode an outgoing frame.
|
||||
|
||||
|
@ -4,7 +4,15 @@ import dataclasses
|
||||
import zlib
|
||||
from typing import Any, Sequence
|
||||
|
||||
from .. import exceptions, frames
|
||||
from .. import frames
|
||||
from ..exceptions import (
|
||||
DuplicateParameter,
|
||||
InvalidParameterName,
|
||||
InvalidParameterValue,
|
||||
NegotiationError,
|
||||
PayloadTooBig,
|
||||
ProtocolError,
|
||||
)
|
||||
from ..typing import ExtensionName, ExtensionParameter
|
||||
from .base import ClientExtensionFactory, Extension, ServerExtensionFactory
|
||||
|
||||
@ -129,9 +137,9 @@ class PerMessageDeflate(Extension):
|
||||
try:
|
||||
data = self.decoder.decompress(data, max_length)
|
||||
except zlib.error as exc:
|
||||
raise exceptions.ProtocolError("decompression failed") from exc
|
||||
raise ProtocolError("decompression failed") from exc
|
||||
if self.decoder.unconsumed_tail:
|
||||
raise exceptions.PayloadTooBig(f"over size limit (? > {max_size} bytes)")
|
||||
raise PayloadTooBig(f"over size limit (? > {max_size} bytes)")
|
||||
|
||||
# Allow garbage collection of the decoder if it won't be reused.
|
||||
if frame.fin and self.remote_no_context_takeover:
|
||||
@ -215,40 +223,40 @@ def _extract_parameters(
|
||||
for name, value in params:
|
||||
if name == "server_no_context_takeover":
|
||||
if server_no_context_takeover:
|
||||
raise exceptions.DuplicateParameter(name)
|
||||
raise DuplicateParameter(name)
|
||||
if value is None:
|
||||
server_no_context_takeover = True
|
||||
else:
|
||||
raise exceptions.InvalidParameterValue(name, value)
|
||||
raise InvalidParameterValue(name, value)
|
||||
|
||||
elif name == "client_no_context_takeover":
|
||||
if client_no_context_takeover:
|
||||
raise exceptions.DuplicateParameter(name)
|
||||
raise DuplicateParameter(name)
|
||||
if value is None:
|
||||
client_no_context_takeover = True
|
||||
else:
|
||||
raise exceptions.InvalidParameterValue(name, value)
|
||||
raise InvalidParameterValue(name, value)
|
||||
|
||||
elif name == "server_max_window_bits":
|
||||
if server_max_window_bits is not None:
|
||||
raise exceptions.DuplicateParameter(name)
|
||||
raise DuplicateParameter(name)
|
||||
if value in _MAX_WINDOW_BITS_VALUES:
|
||||
server_max_window_bits = int(value)
|
||||
else:
|
||||
raise exceptions.InvalidParameterValue(name, value)
|
||||
raise InvalidParameterValue(name, value)
|
||||
|
||||
elif name == "client_max_window_bits":
|
||||
if client_max_window_bits is not None:
|
||||
raise exceptions.DuplicateParameter(name)
|
||||
raise DuplicateParameter(name)
|
||||
if is_server and value is None: # only in handshake requests
|
||||
client_max_window_bits = True
|
||||
elif value in _MAX_WINDOW_BITS_VALUES:
|
||||
client_max_window_bits = int(value)
|
||||
else:
|
||||
raise exceptions.InvalidParameterValue(name, value)
|
||||
raise InvalidParameterValue(name, value)
|
||||
|
||||
else:
|
||||
raise exceptions.InvalidParameterName(name)
|
||||
raise InvalidParameterName(name)
|
||||
|
||||
return (
|
||||
server_no_context_takeover,
|
||||
@ -340,7 +348,7 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory):
|
||||
|
||||
"""
|
||||
if any(other.name == self.name for other in accepted_extensions):
|
||||
raise exceptions.NegotiationError(f"received duplicate {self.name}")
|
||||
raise NegotiationError(f"received duplicate {self.name}")
|
||||
|
||||
# Request parameters are available in instance variables.
|
||||
|
||||
@ -366,7 +374,7 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory):
|
||||
|
||||
if self.server_no_context_takeover:
|
||||
if not server_no_context_takeover:
|
||||
raise exceptions.NegotiationError("expected server_no_context_takeover")
|
||||
raise NegotiationError("expected server_no_context_takeover")
|
||||
|
||||
# client_no_context_takeover
|
||||
#
|
||||
@ -396,9 +404,9 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory):
|
||||
|
||||
else:
|
||||
if server_max_window_bits is None:
|
||||
raise exceptions.NegotiationError("expected server_max_window_bits")
|
||||
raise NegotiationError("expected server_max_window_bits")
|
||||
elif server_max_window_bits > self.server_max_window_bits:
|
||||
raise exceptions.NegotiationError("unsupported server_max_window_bits")
|
||||
raise NegotiationError("unsupported server_max_window_bits")
|
||||
|
||||
# client_max_window_bits
|
||||
|
||||
@ -414,7 +422,7 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory):
|
||||
|
||||
if self.client_max_window_bits is None:
|
||||
if client_max_window_bits is not None:
|
||||
raise exceptions.NegotiationError("unexpected client_max_window_bits")
|
||||
raise NegotiationError("unexpected client_max_window_bits")
|
||||
|
||||
elif self.client_max_window_bits is True:
|
||||
pass
|
||||
@ -423,7 +431,7 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory):
|
||||
if client_max_window_bits is None:
|
||||
client_max_window_bits = self.client_max_window_bits
|
||||
elif client_max_window_bits > self.client_max_window_bits:
|
||||
raise exceptions.NegotiationError("unsupported client_max_window_bits")
|
||||
raise NegotiationError("unsupported client_max_window_bits")
|
||||
|
||||
return PerMessageDeflate(
|
||||
server_no_context_takeover, # remote_no_context_takeover
|
||||
@ -534,7 +542,7 @@ class ServerPerMessageDeflateFactory(ServerExtensionFactory):
|
||||
|
||||
"""
|
||||
if any(other.name == self.name for other in accepted_extensions):
|
||||
raise exceptions.NegotiationError(f"skipped duplicate {self.name}")
|
||||
raise NegotiationError(f"skipped duplicate {self.name}")
|
||||
|
||||
# Load request parameters in local variables.
|
||||
(
|
||||
@ -613,7 +621,7 @@ class ServerPerMessageDeflateFactory(ServerExtensionFactory):
|
||||
else:
|
||||
if client_max_window_bits is None:
|
||||
if self.require_client_max_window_bits:
|
||||
raise exceptions.NegotiationError("required client_max_window_bits")
|
||||
raise NegotiationError("required client_max_window_bits")
|
||||
elif client_max_window_bits is True:
|
||||
client_max_window_bits = self.client_max_window_bits
|
||||
elif self.client_max_window_bits < client_max_window_bits:
|
||||
|
@ -8,8 +8,7 @@ import secrets
|
||||
import struct
|
||||
from typing import Callable, Generator, Sequence
|
||||
|
||||
from . import exceptions, extensions
|
||||
from .typing import Data
|
||||
from .exceptions import PayloadTooBig, ProtocolError
|
||||
|
||||
|
||||
try:
|
||||
@ -29,8 +28,6 @@ __all__ = [
|
||||
"DATA_OPCODES",
|
||||
"CTRL_OPCODES",
|
||||
"Frame",
|
||||
"prepare_data",
|
||||
"prepare_ctrl",
|
||||
"Close",
|
||||
]
|
||||
|
||||
@ -242,10 +239,10 @@ class Frame:
|
||||
try:
|
||||
opcode = Opcode(head1 & 0b00001111)
|
||||
except ValueError as exc:
|
||||
raise exceptions.ProtocolError("invalid opcode") from exc
|
||||
raise ProtocolError("invalid opcode") from exc
|
||||
|
||||
if (True if head2 & 0b10000000 else False) != mask:
|
||||
raise exceptions.ProtocolError("incorrect masking")
|
||||
raise ProtocolError("incorrect masking")
|
||||
|
||||
length = head2 & 0b01111111
|
||||
if length == 126:
|
||||
@ -255,9 +252,7 @@ class Frame:
|
||||
data = yield from read_exact(8)
|
||||
(length,) = struct.unpack("!Q", data)
|
||||
if max_size is not None and length > max_size:
|
||||
raise exceptions.PayloadTooBig(
|
||||
f"over size limit ({length} > {max_size} bytes)"
|
||||
)
|
||||
raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)")
|
||||
if mask:
|
||||
mask_bytes = yield from read_exact(4)
|
||||
|
||||
@ -345,60 +340,13 @@ class Frame:
|
||||
|
||||
"""
|
||||
if self.rsv1 or self.rsv2 or self.rsv3:
|
||||
raise exceptions.ProtocolError("reserved bits must be 0")
|
||||
raise ProtocolError("reserved bits must be 0")
|
||||
|
||||
if self.opcode in CTRL_OPCODES:
|
||||
if len(self.data) > 125:
|
||||
raise exceptions.ProtocolError("control frame too long")
|
||||
raise ProtocolError("control frame too long")
|
||||
if not self.fin:
|
||||
raise exceptions.ProtocolError("fragmented control frame")
|
||||
|
||||
|
||||
def prepare_data(data: Data) -> tuple[int, bytes]:
|
||||
"""
|
||||
Convert a string or byte-like object to an opcode and a bytes-like object.
|
||||
|
||||
This function is designed for data frames.
|
||||
|
||||
If ``data`` is a :class:`str`, return ``OP_TEXT`` and a :class:`bytes`
|
||||
object encoding ``data`` in UTF-8.
|
||||
|
||||
If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like
|
||||
object.
|
||||
|
||||
Raises:
|
||||
TypeError: If ``data`` doesn't have a supported type.
|
||||
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
return OP_TEXT, data.encode()
|
||||
elif isinstance(data, BytesLike):
|
||||
return OP_BINARY, data
|
||||
else:
|
||||
raise TypeError("data must be str or bytes-like")
|
||||
|
||||
|
||||
def prepare_ctrl(data: Data) -> bytes:
|
||||
"""
|
||||
Convert a string or byte-like object to bytes.
|
||||
|
||||
This function is designed for ping and pong frames.
|
||||
|
||||
If ``data`` is a :class:`str`, return a :class:`bytes` object encoding
|
||||
``data`` in UTF-8.
|
||||
|
||||
If ``data`` is a bytes-like object, return a :class:`bytes` object.
|
||||
|
||||
Raises:
|
||||
TypeError: If ``data`` doesn't have a supported type.
|
||||
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
return data.encode()
|
||||
elif isinstance(data, BytesLike):
|
||||
return bytes(data)
|
||||
else:
|
||||
raise TypeError("data must be str or bytes-like")
|
||||
raise ProtocolError("fragmented control frame")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -455,7 +403,7 @@ class Close:
|
||||
elif len(data) == 0:
|
||||
return cls(CloseCode.NO_STATUS_RCVD, "")
|
||||
else:
|
||||
raise exceptions.ProtocolError("close frame too short")
|
||||
raise ProtocolError("close frame too short")
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""
|
||||
@ -474,4 +422,8 @@ class Close:
|
||||
|
||||
"""
|
||||
if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000):
|
||||
raise exceptions.ProtocolError("invalid status code")
|
||||
raise ProtocolError("invalid status code")
|
||||
|
||||
|
||||
# At the bottom to break import cycles created by type annotations.
|
||||
from . import extensions # noqa: E402
|
||||
|
@ -6,7 +6,7 @@ import ipaddress
|
||||
import re
|
||||
from typing import Callable, Sequence, TypeVar, cast
|
||||
|
||||
from . import exceptions
|
||||
from .exceptions import InvalidHeaderFormat, InvalidHeaderValue
|
||||
from .typing import (
|
||||
ConnectionOption,
|
||||
ExtensionHeader,
|
||||
@ -108,7 +108,7 @@ def parse_token(header: str, pos: int, header_name: str) -> tuple[str, int]:
|
||||
"""
|
||||
match = _token_re.match(header, pos)
|
||||
if match is None:
|
||||
raise exceptions.InvalidHeaderFormat(header_name, "expected token", header, pos)
|
||||
raise InvalidHeaderFormat(header_name, "expected token", header, pos)
|
||||
return match.group(), match.end()
|
||||
|
||||
|
||||
@ -132,9 +132,7 @@ def parse_quoted_string(header: str, pos: int, header_name: str) -> tuple[str, i
|
||||
"""
|
||||
match = _quoted_string_re.match(header, pos)
|
||||
if match is None:
|
||||
raise exceptions.InvalidHeaderFormat(
|
||||
header_name, "expected quoted string", header, pos
|
||||
)
|
||||
raise InvalidHeaderFormat(header_name, "expected quoted string", header, pos)
|
||||
return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end()
|
||||
|
||||
|
||||
@ -206,9 +204,7 @@ def parse_list(
|
||||
if peek_ahead(header, pos) == ",":
|
||||
pos = parse_OWS(header, pos + 1)
|
||||
else:
|
||||
raise exceptions.InvalidHeaderFormat(
|
||||
header_name, "expected comma", header, pos
|
||||
)
|
||||
raise InvalidHeaderFormat(header_name, "expected comma", header, pos)
|
||||
|
||||
# Remove extra delimiters before the next item.
|
||||
while peek_ahead(header, pos) == ",":
|
||||
@ -276,9 +272,7 @@ def parse_upgrade_protocol(
|
||||
"""
|
||||
match = _protocol_re.match(header, pos)
|
||||
if match is None:
|
||||
raise exceptions.InvalidHeaderFormat(
|
||||
header_name, "expected protocol", header, pos
|
||||
)
|
||||
raise InvalidHeaderFormat(header_name, "expected protocol", header, pos)
|
||||
return cast(UpgradeProtocol, match.group()), match.end()
|
||||
|
||||
|
||||
@ -324,7 +318,7 @@ def parse_extension_item_param(
|
||||
# the value after quoted-string unescaping MUST conform to
|
||||
# the 'token' ABNF.
|
||||
if _token_re.fullmatch(value) is None:
|
||||
raise exceptions.InvalidHeaderFormat(
|
||||
raise InvalidHeaderFormat(
|
||||
header_name, "invalid quoted header content", header, pos_before
|
||||
)
|
||||
else:
|
||||
@ -510,9 +504,7 @@ def parse_token68(header: str, pos: int, header_name: str) -> tuple[str, int]:
|
||||
"""
|
||||
match = _token68_re.match(header, pos)
|
||||
if match is None:
|
||||
raise exceptions.InvalidHeaderFormat(
|
||||
header_name, "expected token68", header, pos
|
||||
)
|
||||
raise InvalidHeaderFormat(header_name, "expected token68", header, pos)
|
||||
return match.group(), match.end()
|
||||
|
||||
|
||||
@ -522,7 +514,7 @@ def parse_end(header: str, pos: int, header_name: str) -> None:
|
||||
|
||||
"""
|
||||
if pos < len(header):
|
||||
raise exceptions.InvalidHeaderFormat(header_name, "trailing data", header, pos)
|
||||
raise InvalidHeaderFormat(header_name, "trailing data", header, pos)
|
||||
|
||||
|
||||
def parse_authorization_basic(header: str) -> tuple[str, str]:
|
||||
@ -543,12 +535,12 @@ def parse_authorization_basic(header: str) -> tuple[str, str]:
|
||||
# https://datatracker.ietf.org/doc/html/rfc7617#section-2
|
||||
scheme, pos = parse_token(header, 0, "Authorization")
|
||||
if scheme.lower() != "basic":
|
||||
raise exceptions.InvalidHeaderValue(
|
||||
raise InvalidHeaderValue(
|
||||
"Authorization",
|
||||
f"unsupported scheme: {scheme}",
|
||||
)
|
||||
if peek_ahead(header, pos) != " ":
|
||||
raise exceptions.InvalidHeaderFormat(
|
||||
raise InvalidHeaderFormat(
|
||||
"Authorization", "expected space after scheme", header, pos
|
||||
)
|
||||
pos += 1
|
||||
@ -558,14 +550,14 @@ def parse_authorization_basic(header: str) -> tuple[str, str]:
|
||||
try:
|
||||
user_pass = base64.b64decode(basic_credentials.encode()).decode()
|
||||
except binascii.Error:
|
||||
raise exceptions.InvalidHeaderValue(
|
||||
raise InvalidHeaderValue(
|
||||
"Authorization",
|
||||
"expected base64-encoded credentials",
|
||||
) from None
|
||||
try:
|
||||
username, password = user_pass.split(":", 1)
|
||||
except ValueError:
|
||||
raise exceptions.InvalidHeaderValue(
|
||||
raise InvalidHeaderValue(
|
||||
"Authorization",
|
||||
"expected username:password credentials",
|
||||
) from None
|
||||
|
@ -6,7 +6,7 @@ from .datastructures import Headers, MultipleValuesError # noqa: F401
|
||||
from .legacy.http import read_request, read_response # noqa: F401
|
||||
|
||||
|
||||
warnings.warn(
|
||||
warnings.warn( # deprecated in 9.0 - 2021-09-01
|
||||
"Headers and MultipleValuesError were moved "
|
||||
"from websockets.http to websockets.datastructures"
|
||||
"and read_request and read_response were moved "
|
||||
|
@ -7,7 +7,8 @@ import sys
|
||||
import warnings
|
||||
from typing import Callable, Generator
|
||||
|
||||
from . import datastructures, exceptions
|
||||
from .datastructures import Headers
|
||||
from .exceptions import SecurityError
|
||||
from .version import version as websockets_version
|
||||
|
||||
|
||||
@ -79,14 +80,14 @@ class Request:
|
||||
"""
|
||||
|
||||
path: str
|
||||
headers: datastructures.Headers
|
||||
headers: Headers
|
||||
# body isn't useful is the context of this library.
|
||||
|
||||
_exception: Exception | None = None
|
||||
|
||||
@property
|
||||
def exception(self) -> Exception | None: # pragma: no cover
|
||||
warnings.warn(
|
||||
warnings.warn( # deprecated in 10.3 - 2022-04-17
|
||||
"Request.exception is deprecated; "
|
||||
"use ServerProtocol.handshake_exc instead",
|
||||
DeprecationWarning,
|
||||
@ -134,14 +135,15 @@ class Request:
|
||||
raise EOFError("connection closed while reading HTTP request line") from exc
|
||||
|
||||
try:
|
||||
method, raw_path, version = request_line.split(b" ", 2)
|
||||
method, raw_path, protocol = request_line.split(b" ", 2)
|
||||
except ValueError: # not enough values to unpack (expected 3, got 1-2)
|
||||
raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None
|
||||
|
||||
if protocol != b"HTTP/1.1":
|
||||
raise ValueError(
|
||||
f"unsupported protocol; expected HTTP/1.1: {d(request_line)}"
|
||||
)
|
||||
if method != b"GET":
|
||||
raise ValueError(f"unsupported HTTP method: {d(method)}")
|
||||
if version != b"HTTP/1.1":
|
||||
raise ValueError(f"unsupported HTTP version: {d(version)}")
|
||||
raise ValueError(f"unsupported HTTP method; expected GET; got {d(method)}")
|
||||
path = raw_path.decode("ascii", "surrogateescape")
|
||||
|
||||
headers = yield from parse_headers(read_line)
|
||||
@ -183,14 +185,14 @@ class Response:
|
||||
|
||||
status_code: int
|
||||
reason_phrase: str
|
||||
headers: datastructures.Headers
|
||||
headers: Headers
|
||||
body: bytes | None = None
|
||||
|
||||
_exception: Exception | None = None
|
||||
|
||||
@property
|
||||
def exception(self) -> Exception | None: # pragma: no cover
|
||||
warnings.warn(
|
||||
warnings.warn( # deprecated in 10.3 - 2022-04-17
|
||||
"Response.exception is deprecated; "
|
||||
"use ClientProtocol.handshake_exc instead",
|
||||
DeprecationWarning,
|
||||
@ -235,23 +237,26 @@ class Response:
|
||||
raise EOFError("connection closed while reading HTTP status line") from exc
|
||||
|
||||
try:
|
||||
version, raw_status_code, raw_reason = status_line.split(b" ", 2)
|
||||
protocol, raw_status_code, raw_reason = status_line.split(b" ", 2)
|
||||
except ValueError: # not enough values to unpack (expected 3, got 1-2)
|
||||
raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None
|
||||
|
||||
if version != b"HTTP/1.1":
|
||||
raise ValueError(f"unsupported HTTP version: {d(version)}")
|
||||
if protocol != b"HTTP/1.1":
|
||||
raise ValueError(
|
||||
f"unsupported protocol; expected HTTP/1.1: {d(status_line)}"
|
||||
)
|
||||
try:
|
||||
status_code = int(raw_status_code)
|
||||
except ValueError: # invalid literal for int() with base 10
|
||||
raise ValueError(
|
||||
f"invalid HTTP status code: {d(raw_status_code)}"
|
||||
f"invalid status code; expected integer; got {d(raw_status_code)}"
|
||||
) from None
|
||||
if not 100 <= status_code < 1000:
|
||||
raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}")
|
||||
if not 100 <= status_code < 600:
|
||||
raise ValueError(
|
||||
f"invalid status code; expected 100–599; got {d(raw_status_code)}"
|
||||
)
|
||||
if not _value_re.fullmatch(raw_reason):
|
||||
raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}")
|
||||
reason = raw_reason.decode()
|
||||
reason = raw_reason.decode("ascii", "surrogateescape")
|
||||
|
||||
headers = yield from parse_headers(read_line)
|
||||
|
||||
@ -280,13 +285,9 @@ class Response:
|
||||
try:
|
||||
body = yield from read_to_eof(MAX_BODY_SIZE)
|
||||
except RuntimeError:
|
||||
raise exceptions.SecurityError(
|
||||
f"body too large: over {MAX_BODY_SIZE} bytes"
|
||||
)
|
||||
raise SecurityError(f"body too large: over {MAX_BODY_SIZE} bytes")
|
||||
elif content_length > MAX_BODY_SIZE:
|
||||
raise exceptions.SecurityError(
|
||||
f"body too large: {content_length} bytes"
|
||||
)
|
||||
raise SecurityError(f"body too large: {content_length} bytes")
|
||||
else:
|
||||
body = yield from read_exact(content_length)
|
||||
|
||||
@ -308,7 +309,7 @@ class Response:
|
||||
|
||||
def parse_headers(
|
||||
read_line: Callable[[int], Generator[None, None, bytes]],
|
||||
) -> Generator[None, None, datastructures.Headers]:
|
||||
) -> Generator[None, None, Headers]:
|
||||
"""
|
||||
Parse HTTP headers.
|
||||
|
||||
@ -328,7 +329,7 @@ def parse_headers(
|
||||
|
||||
# We don't attempt to support obsolete line folding.
|
||||
|
||||
headers = datastructures.Headers()
|
||||
headers = Headers()
|
||||
for _ in range(MAX_NUM_HEADERS + 1):
|
||||
try:
|
||||
line = yield from parse_line(read_line)
|
||||
@ -352,7 +353,7 @@ def parse_headers(
|
||||
headers[name] = value
|
||||
|
||||
else:
|
||||
raise exceptions.SecurityError("too many HTTP headers")
|
||||
raise SecurityError("too many HTTP headers")
|
||||
|
||||
return headers
|
||||
|
||||
@ -377,7 +378,7 @@ def parse_line(
|
||||
try:
|
||||
line = yield from read_line(MAX_LINE_LENGTH)
|
||||
except RuntimeError:
|
||||
raise exceptions.SecurityError("line too long")
|
||||
raise SecurityError("line too long")
|
||||
# Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5
|
||||
if not line.endswith(b"\r\n"):
|
||||
raise EOFError("line without CRLF")
|
||||
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import urllib.parse
|
||||
import warnings
|
||||
@ -19,12 +20,9 @@ from typing import (
|
||||
from ..asyncio.compatibility import asyncio_timeout
|
||||
from ..datastructures import Headers, HeadersLike
|
||||
from ..exceptions import (
|
||||
InvalidHandshake,
|
||||
InvalidHeader,
|
||||
InvalidMessage,
|
||||
InvalidStatusCode,
|
||||
InvalidHeaderValue,
|
||||
NegotiationError,
|
||||
RedirectHandshake,
|
||||
SecurityError,
|
||||
)
|
||||
from ..extensions import ClientExtensionFactory, Extension
|
||||
@ -41,6 +39,7 @@ from ..headers import (
|
||||
from ..http11 import USER_AGENT
|
||||
from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol
|
||||
from ..uri import WebSocketURI, parse_uri
|
||||
from .exceptions import InvalidMessage, InvalidStatusCode, RedirectHandshake
|
||||
from .handshake import build_request, check_response
|
||||
from .http import read_response
|
||||
from .protocol import WebSocketCommonProtocol
|
||||
@ -182,7 +181,7 @@ class WebSocketClientProtocol(WebSocketCommonProtocol):
|
||||
|
||||
if header_values:
|
||||
if available_extensions is None:
|
||||
raise InvalidHandshake("no extensions supported")
|
||||
raise NegotiationError("no extensions supported")
|
||||
|
||||
parsed_header_values: list[ExtensionHeader] = sum(
|
||||
[parse_extension(header_value) for header_value in header_values], []
|
||||
@ -236,15 +235,17 @@ class WebSocketClientProtocol(WebSocketCommonProtocol):
|
||||
|
||||
if header_values:
|
||||
if available_subprotocols is None:
|
||||
raise InvalidHandshake("no subprotocols supported")
|
||||
raise NegotiationError("no subprotocols supported")
|
||||
|
||||
parsed_header_values: Sequence[Subprotocol] = sum(
|
||||
[parse_subprotocol(header_value) for header_value in header_values], []
|
||||
)
|
||||
|
||||
if len(parsed_header_values) > 1:
|
||||
subprotocols = ", ".join(parsed_header_values)
|
||||
raise InvalidHandshake(f"multiple subprotocols: {subprotocols}")
|
||||
raise InvalidHeaderValue(
|
||||
"Sec-WebSocket-Protocol",
|
||||
f"multiple values: {', '.join(parsed_header_values)}",
|
||||
)
|
||||
|
||||
subprotocol = parsed_header_values[0]
|
||||
|
||||
@ -417,7 +418,7 @@ class Connect:
|
||||
|
||||
"""
|
||||
|
||||
MAX_REDIRECTS_ALLOWED = 10
|
||||
MAX_REDIRECTS_ALLOWED = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10"))
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -591,13 +592,13 @@ class Connect:
|
||||
|
||||
# async for ... in connect(...):
|
||||
|
||||
BACKOFF_MIN = 1.92
|
||||
BACKOFF_MAX = 60.0
|
||||
BACKOFF_FACTOR = 1.618
|
||||
BACKOFF_INITIAL = 5
|
||||
BACKOFF_INITIAL = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5"))
|
||||
BACKOFF_MIN = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1"))
|
||||
BACKOFF_MAX = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0"))
|
||||
BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618"))
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]:
|
||||
backoff_delay = self.BACKOFF_MIN
|
||||
backoff_delay = self.BACKOFF_MIN / self.BACKOFF_FACTOR
|
||||
while True:
|
||||
try:
|
||||
async with self as protocol:
|
||||
|
78
site-packages/websockets/legacy/exceptions.py
Normal file
78
site-packages/websockets/legacy/exceptions.py
Normal file
@ -0,0 +1,78 @@
|
||||
import http
|
||||
|
||||
from .. import datastructures
|
||||
from ..exceptions import (
|
||||
InvalidHandshake,
|
||||
ProtocolError as WebSocketProtocolError, # noqa: F401
|
||||
)
|
||||
from ..typing import StatusLike
|
||||
|
||||
|
||||
class InvalidMessage(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake request or response is malformed.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class InvalidStatusCode(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake response status code is invalid.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, status_code: int, headers: datastructures.Headers) -> None:
|
||||
self.status_code = status_code
|
||||
self.headers = headers
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"server rejected WebSocket connection: HTTP {self.status_code}"
|
||||
|
||||
|
||||
class AbortHandshake(InvalidHandshake):
|
||||
"""
|
||||
Raised to abort the handshake on purpose and return an HTTP response.
|
||||
|
||||
This exception is an implementation detail.
|
||||
|
||||
The public API is
|
||||
:meth:`~websockets.legacy.server.WebSocketServerProtocol.process_request`.
|
||||
|
||||
Attributes:
|
||||
status (~http.HTTPStatus): HTTP status code.
|
||||
headers (Headers): HTTP response headers.
|
||||
body (bytes): HTTP response body.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status: StatusLike,
|
||||
headers: datastructures.HeadersLike,
|
||||
body: bytes = b"",
|
||||
) -> None:
|
||||
# If a user passes an int instead of a HTTPStatus, fix it automatically.
|
||||
self.status = http.HTTPStatus(status)
|
||||
self.headers = datastructures.Headers(headers)
|
||||
self.body = body
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"HTTP {self.status:d}, "
|
||||
f"{len(self.headers)} headers, "
|
||||
f"{len(self.body)} bytes"
|
||||
)
|
||||
|
||||
|
||||
class RedirectHandshake(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake gets redirected.
|
||||
|
||||
This exception is an implementation detail.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, uri: str) -> None:
|
||||
self.uri = uri
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"redirect to {self.uri}"
|
@ -5,6 +5,8 @@ from typing import Any, Awaitable, Callable, NamedTuple, Sequence
|
||||
|
||||
from .. import extensions, frames
|
||||
from ..exceptions import PayloadTooBig, ProtocolError
|
||||
from ..frames import BytesLike
|
||||
from ..typing import Data
|
||||
|
||||
|
||||
try:
|
||||
@ -144,12 +146,58 @@ class Frame(NamedTuple):
|
||||
write(self.new_frame.serialize(mask=mask, extensions=extensions))
|
||||
|
||||
|
||||
def prepare_data(data: Data) -> tuple[int, bytes]:
|
||||
"""
|
||||
Convert a string or byte-like object to an opcode and a bytes-like object.
|
||||
|
||||
This function is designed for data frames.
|
||||
|
||||
If ``data`` is a :class:`str`, return ``OP_TEXT`` and a :class:`bytes`
|
||||
object encoding ``data`` in UTF-8.
|
||||
|
||||
If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like
|
||||
object.
|
||||
|
||||
Raises:
|
||||
TypeError: If ``data`` doesn't have a supported type.
|
||||
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
return frames.Opcode.TEXT, data.encode()
|
||||
elif isinstance(data, BytesLike):
|
||||
return frames.Opcode.BINARY, data
|
||||
else:
|
||||
raise TypeError("data must be str or bytes-like")
|
||||
|
||||
|
||||
def prepare_ctrl(data: Data) -> bytes:
|
||||
"""
|
||||
Convert a string or byte-like object to bytes.
|
||||
|
||||
This function is designed for ping and pong frames.
|
||||
|
||||
If ``data`` is a :class:`str`, return a :class:`bytes` object encoding
|
||||
``data`` in UTF-8.
|
||||
|
||||
If ``data`` is a bytes-like object, return a :class:`bytes` object.
|
||||
|
||||
Raises:
|
||||
TypeError: If ``data`` doesn't have a supported type.
|
||||
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
return data.encode()
|
||||
elif isinstance(data, BytesLike):
|
||||
return bytes(data)
|
||||
else:
|
||||
raise TypeError("data must be str or bytes-like")
|
||||
|
||||
|
||||
# Backwards compatibility with previously documented public APIs
|
||||
from ..frames import ( # noqa: E402, F401, I001
|
||||
Close,
|
||||
prepare_ctrl as encode_data,
|
||||
prepare_data,
|
||||
)
|
||||
encode_data = prepare_ctrl
|
||||
|
||||
# Backwards compatibility with previously documented public APIs
|
||||
from ..frames import Close # noqa: E402 F401, I001
|
||||
|
||||
|
||||
def parse_close(data: bytes) -> tuple[int, str]:
|
||||
|
@ -76,9 +76,7 @@ def check_request(headers: Headers) -> str:
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Key") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader(
|
||||
"Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found"
|
||||
) from exc
|
||||
raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc
|
||||
|
||||
try:
|
||||
raw_key = base64.b64decode(s_w_key.encode(), validate=True)
|
||||
@ -92,9 +90,7 @@ def check_request(headers: Headers) -> str:
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Version") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader(
|
||||
"Sec-WebSocket-Version", "more than one Sec-WebSocket-Version header found"
|
||||
) from exc
|
||||
raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc
|
||||
|
||||
if s_w_version != "13":
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version)
|
||||
@ -156,9 +152,7 @@ def check_response(headers: Headers, key: str) -> None:
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Accept") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader(
|
||||
"Sec-WebSocket-Accept", "more than one Sec-WebSocket-Accept header found"
|
||||
) from exc
|
||||
raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc
|
||||
|
||||
if s_w_accept != accept(key):
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)
|
||||
|
@ -45,12 +45,10 @@ from ..frames import (
|
||||
Close,
|
||||
CloseCode,
|
||||
Opcode,
|
||||
prepare_ctrl,
|
||||
prepare_data,
|
||||
)
|
||||
from ..protocol import State
|
||||
from ..typing import Data, LoggerLike, Subprotocol
|
||||
from .framing import Frame
|
||||
from .framing import Frame, prepare_ctrl, prepare_data
|
||||
|
||||
|
||||
__all__ = ["WebSocketCommonProtocol"]
|
||||
|
@ -24,10 +24,8 @@ from typing import (
|
||||
from ..asyncio.compatibility import asyncio_timeout
|
||||
from ..datastructures import Headers, HeadersLike, MultipleValuesError
|
||||
from ..exceptions import (
|
||||
AbortHandshake,
|
||||
InvalidHandshake,
|
||||
InvalidHeader,
|
||||
InvalidMessage,
|
||||
InvalidOrigin,
|
||||
InvalidUpgrade,
|
||||
NegotiationError,
|
||||
@ -43,6 +41,7 @@ from ..headers import (
|
||||
from ..http11 import SERVER
|
||||
from ..protocol import State
|
||||
from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol
|
||||
from .exceptions import AbortHandshake, InvalidMessage
|
||||
from .handshake import build_response, check_request
|
||||
from .http import read_request
|
||||
from .protocol import WebSocketCommonProtocol, broadcast
|
||||
@ -101,9 +100,10 @@ class WebSocketServerProtocol(WebSocketCommonProtocol):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# The version that accepts the path in the second argument is deprecated.
|
||||
ws_handler: (
|
||||
Callable[[WebSocketServerProtocol], Awaitable[Any]]
|
||||
| Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated
|
||||
| Callable[[WebSocketServerProtocol, str], Awaitable[Any]]
|
||||
),
|
||||
ws_server: WebSocketServer,
|
||||
*,
|
||||
@ -398,7 +398,7 @@ class WebSocketServerProtocol(WebSocketCommonProtocol):
|
||||
try:
|
||||
origin = headers.get("Origin")
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader("Origin", "more than one Origin header found") from exc
|
||||
raise InvalidHeader("Origin", "multiple values") from exc
|
||||
if origin is not None:
|
||||
origin = cast(Origin, origin)
|
||||
if origins is not None:
|
||||
@ -984,9 +984,10 @@ class Serve:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# The version that accepts the path in the second argument is deprecated.
|
||||
ws_handler: (
|
||||
Callable[[WebSocketServerProtocol], Awaitable[Any]]
|
||||
| Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated
|
||||
| Callable[[WebSocketServerProtocol, str], Awaitable[Any]]
|
||||
),
|
||||
host: str | Sequence[str] | None = None,
|
||||
port: int | None = None,
|
||||
@ -1141,9 +1142,10 @@ serve = Serve
|
||||
|
||||
|
||||
def unix_serve(
|
||||
# The version that accepts the path in the second argument is deprecated.
|
||||
ws_handler: (
|
||||
Callable[[WebSocketServerProtocol], Awaitable[Any]]
|
||||
| Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated
|
||||
| Callable[[WebSocketServerProtocol, str], Awaitable[Any]]
|
||||
),
|
||||
path: str | None = None,
|
||||
**kwargs: Any,
|
||||
@ -1170,7 +1172,7 @@ def remove_path_argument(
|
||||
ws_handler: (
|
||||
Callable[[WebSocketServerProtocol], Awaitable[Any]]
|
||||
| Callable[[WebSocketServerProtocol, str], Awaitable[Any]]
|
||||
)
|
||||
),
|
||||
) -> Callable[[WebSocketServerProtocol], Awaitable[Any]]:
|
||||
try:
|
||||
inspect.signature(ws_handler).bind(None)
|
||||
|
@ -378,9 +378,11 @@ class Protocol:
|
||||
else:
|
||||
close = Close(code, reason)
|
||||
data = close.serialize()
|
||||
# send_frame() guarantees that self.state is OPEN at this point.
|
||||
# 7.1.3. The WebSocket Closing Handshake is Started
|
||||
self.send_frame(Frame(OP_CLOSE, data))
|
||||
# Since the state is OPEN, no close frame was received yet.
|
||||
# As a consequence, self.close_rcvd_then_sent remains None.
|
||||
assert self.close_rcvd is None
|
||||
self.close_sent = close
|
||||
self.state = CLOSING
|
||||
|
||||
@ -441,6 +443,12 @@ class Protocol:
|
||||
data = close.serialize()
|
||||
self.send_frame(Frame(OP_CLOSE, data))
|
||||
self.close_sent = close
|
||||
# If recv_messages() raised an exception upon receiving a close
|
||||
# frame but before echoing it, then close_rcvd is not None even
|
||||
# though the state is OPEN. This happens when the connection is
|
||||
# closed while receiving a fragmented message.
|
||||
if self.close_rcvd is not None:
|
||||
self.close_rcvd_then_sent = True
|
||||
self.state = CLOSING
|
||||
|
||||
# When failing the connection, a server closes the TCP connection
|
||||
@ -602,18 +610,18 @@ class Protocol:
|
||||
- after sending a close frame, during an abnormal closure (7.1.7).
|
||||
|
||||
"""
|
||||
# The server close the TCP connection in the same circumstances where
|
||||
# discard() replaces parse(). The client closes the connection later,
|
||||
# after the server closes the connection or a timeout elapses.
|
||||
# (The latter case cannot be handled in this Sans-I/O layer.)
|
||||
assert (self.side is SERVER) == (self.eof_sent)
|
||||
# After the opening handshake completes, the server closes the TCP
|
||||
# connection in the same circumstances where discard() replaces parse().
|
||||
# The client closes it when it receives EOF from the server or times
|
||||
# out. (The latter case cannot be handled in this Sans-I/O layer.)
|
||||
assert (self.state == CONNECTING or self.side is SERVER) == (self.eof_sent)
|
||||
while not (yield from self.reader.at_eof()):
|
||||
self.reader.discard()
|
||||
if self.debug:
|
||||
self.logger.debug("< EOF")
|
||||
# A server closes the TCP connection immediately, while a client
|
||||
# waits for the server to close the TCP connection.
|
||||
if self.side is CLIENT:
|
||||
if self.state != CONNECTING and self.side is CLIENT:
|
||||
self.send_eof()
|
||||
self.state = CLOSED
|
||||
# If discard() completes normally, execution ends here.
|
||||
|
@ -69,7 +69,7 @@ class ServerProtocol(Protocol):
|
||||
max_size: Maximum size of incoming messages in bytes;
|
||||
:obj:`None` disables the limit.
|
||||
logger: Logger for this connection;
|
||||
defaults to ``logging.getLogger("websockets.client")``;
|
||||
defaults to ``logging.getLogger("websockets.server")``;
|
||||
see the :doc:`logging guide <../../topics/logging>` for details.
|
||||
|
||||
"""
|
||||
@ -204,7 +204,6 @@ class ServerProtocol(Protocol):
|
||||
if protocol_header is not None:
|
||||
headers["Sec-WebSocket-Protocol"] = protocol_header
|
||||
|
||||
self.logger.info("connection open")
|
||||
return Response(101, "Switching Protocols", headers)
|
||||
|
||||
def process_request(
|
||||
@ -254,12 +253,10 @@ class ServerProtocol(Protocol):
|
||||
|
||||
try:
|
||||
key = headers["Sec-WebSocket-Key"]
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Key") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader(
|
||||
"Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found"
|
||||
) from exc
|
||||
except KeyError:
|
||||
raise InvalidHeader("Sec-WebSocket-Key") from None
|
||||
except MultipleValuesError:
|
||||
raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from None
|
||||
|
||||
try:
|
||||
raw_key = base64.b64decode(key.encode(), validate=True)
|
||||
@ -270,13 +267,10 @@ class ServerProtocol(Protocol):
|
||||
|
||||
try:
|
||||
version = headers["Sec-WebSocket-Version"]
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Version") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader(
|
||||
"Sec-WebSocket-Version",
|
||||
"more than one Sec-WebSocket-Version header found",
|
||||
) from exc
|
||||
except KeyError:
|
||||
raise InvalidHeader("Sec-WebSocket-Version") from None
|
||||
except MultipleValuesError:
|
||||
raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from None
|
||||
|
||||
if version != "13":
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Version", version)
|
||||
@ -314,8 +308,8 @@ class ServerProtocol(Protocol):
|
||||
# per https://datatracker.ietf.org/doc/html/rfc6454#section-7.3.
|
||||
try:
|
||||
origin = headers.get("Origin")
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader("Origin", "more than one Origin header found") from exc
|
||||
except MultipleValuesError:
|
||||
raise InvalidHeader("Origin", "multiple values") from None
|
||||
if origin is not None:
|
||||
origin = cast(Origin, origin)
|
||||
if self.origins is not None:
|
||||
@ -509,7 +503,7 @@ class ServerProtocol(Protocol):
|
||||
HTTP response to send to the client.
|
||||
|
||||
"""
|
||||
# If a user passes an int instead of a HTTPStatus, fix it automatically.
|
||||
# If status is an int instead of an HTTPStatus, fix it automatically.
|
||||
status = http.HTTPStatus(status)
|
||||
body = text.encode()
|
||||
headers = Headers(
|
||||
@ -520,14 +514,7 @@ class ServerProtocol(Protocol):
|
||||
("Content-Type", "text/plain; charset=utf-8"),
|
||||
]
|
||||
)
|
||||
response = Response(status.value, status.phrase, headers, body)
|
||||
# When reject() is called from accept(), handshake_exc is already set.
|
||||
# If a user calls reject(), set handshake_exc to guarantee invariant:
|
||||
# "handshake_exc is None if and only if opening handshake succeeded."
|
||||
if self.handshake_exc is None:
|
||||
self.handshake_exc = InvalidStatus(response)
|
||||
self.logger.info("connection rejected (%d %s)", status.value, status.phrase)
|
||||
return response
|
||||
return Response(status.value, status.phrase, headers, body)
|
||||
|
||||
def send_response(self, response: Response) -> None:
|
||||
"""
|
||||
@ -550,7 +537,20 @@ class ServerProtocol(Protocol):
|
||||
if response.status_code == 101:
|
||||
assert self.state is CONNECTING
|
||||
self.state = OPEN
|
||||
self.logger.info("connection open")
|
||||
|
||||
else:
|
||||
# handshake_exc may be already set if accept() encountered an error.
|
||||
# If the connection isn't open, set handshake_exc to guarantee that
|
||||
# handshake_exc is None if and only if opening handshake succeeded.
|
||||
if self.handshake_exc is None:
|
||||
self.handshake_exc = InvalidStatus(response)
|
||||
self.logger.info(
|
||||
"connection rejected (%d %s)",
|
||||
response.status_code,
|
||||
response.reason_phrase,
|
||||
)
|
||||
|
||||
self.send_eof()
|
||||
self.parser = self.discard()
|
||||
next(self.parser) # start coroutine
|
||||
@ -580,7 +580,7 @@ class ServerProtocol(Protocol):
|
||||
|
||||
class ServerConnection(ServerProtocol):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
warnings.warn(
|
||||
warnings.warn( # deprecated in 11.0 - 2023-04-02
|
||||
"ServerConnection was renamed to ServerProtocol",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
@ -19,7 +19,6 @@ _PyBytesLike_AsStringAndSize(PyObject *obj, PyObject **tmp, char **buffer, Py_ss
|
||||
{
|
||||
// This supports bytes, bytearrays, and memoryview objects,
|
||||
// which are common data structures for handling byte streams.
|
||||
// websockets.framing.prepare_data() returns only these types.
|
||||
// If *tmp isn't NULL, the caller gets a new reference.
|
||||
if (PyBytes_Check(obj))
|
||||
{
|
||||
|
Binary file not shown.
@ -12,7 +12,7 @@ from ..extensions.base import ClientExtensionFactory
|
||||
from ..extensions.permessage_deflate import enable_client_permessage_deflate
|
||||
from ..headers import validate_subprotocols
|
||||
from ..http11 import USER_AGENT, Response
|
||||
from ..protocol import CONNECTING, OPEN, Event
|
||||
from ..protocol import CONNECTING, Event
|
||||
from ..typing import LoggerLike, Origin, Subprotocol
|
||||
from ..uri import parse_uri
|
||||
from .connection import Connection
|
||||
@ -80,19 +80,11 @@ class ClientConnection(Connection):
|
||||
self.protocol.send_request(self.request)
|
||||
|
||||
if not self.response_rcvd.wait(timeout):
|
||||
self.close_socket()
|
||||
self.recv_events_thread.join()
|
||||
raise TimeoutError("timed out during handshake")
|
||||
|
||||
if self.response is None:
|
||||
self.close_socket()
|
||||
self.recv_events_thread.join()
|
||||
raise ConnectionError("connection closed during handshake")
|
||||
|
||||
if self.protocol.state is not OPEN:
|
||||
self.recv_events_thread.join(self.close_timeout)
|
||||
self.close_socket()
|
||||
self.recv_events_thread.join()
|
||||
# self.protocol.handshake_exc is always set when the connection is lost
|
||||
# before receiving a response, when the response cannot be parsed, or
|
||||
# when the response fails the handshake.
|
||||
|
||||
if self.protocol.handshake_exc is not None:
|
||||
raise self.protocol.handshake_exc
|
||||
@ -156,7 +148,9 @@ def connect(
|
||||
|
||||
:func:`connect` may be used as a context manager::
|
||||
|
||||
with websockets.sync.client.connect(...) as websocket:
|
||||
from websockets.sync.client import connect
|
||||
|
||||
with connect(...) as websocket:
|
||||
...
|
||||
|
||||
The connection is closed automatically when exiting the context.
|
||||
@ -210,7 +204,10 @@ def connect(
|
||||
# Backwards compatibility: ssl used to be called ssl_context.
|
||||
if ssl is None and "ssl_context" in kwargs:
|
||||
ssl = kwargs.pop("ssl_context")
|
||||
warnings.warn("ssl_context was renamed to ssl", DeprecationWarning)
|
||||
warnings.warn( # deprecated in 13.0 - 2024-08-20
|
||||
"ssl_context was renamed to ssl",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
wsuri = parse_uri(uri)
|
||||
if not wsuri.secure and ssl is not None:
|
||||
@ -290,16 +287,20 @@ def connect(
|
||||
protocol,
|
||||
close_timeout=close_timeout,
|
||||
)
|
||||
# On failure, handshake() closes the socket and raises an exception.
|
||||
except Exception:
|
||||
if sock is not None:
|
||||
sock.close()
|
||||
raise
|
||||
|
||||
try:
|
||||
connection.handshake(
|
||||
additional_headers,
|
||||
user_agent_header,
|
||||
deadline.timeout(),
|
||||
)
|
||||
|
||||
except Exception:
|
||||
if sock is not None:
|
||||
sock.close()
|
||||
connection.close_socket()
|
||||
connection.recv_events_thread.join()
|
||||
raise
|
||||
|
||||
return connection
|
||||
|
@ -10,8 +10,13 @@ import uuid
|
||||
from types import TracebackType
|
||||
from typing import Any, Iterable, Iterator, Mapping
|
||||
|
||||
from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError
|
||||
from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl
|
||||
from ..exceptions import (
|
||||
ConcurrencyError,
|
||||
ConnectionClosed,
|
||||
ConnectionClosedOK,
|
||||
ProtocolError,
|
||||
)
|
||||
from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode
|
||||
from ..http11 import Request, Response
|
||||
from ..protocol import CLOSED, OPEN, Event, Protocol, State
|
||||
from ..typing import Data, LoggerLike, Subprotocol
|
||||
@ -82,9 +87,9 @@ class Connection:
|
||||
# Mapping of ping IDs to pong waiters, in chronological order.
|
||||
self.ping_waiters: dict[bytes, threading.Event] = {}
|
||||
|
||||
# Receiving events from the socket. This thread explicitly is marked as
|
||||
# to support creating a connection in a non-daemon thread then using it
|
||||
# in a daemon thread; this shouldn't block the intpreter from exiting.
|
||||
# Receiving events from the socket. This thread is marked as daemon to
|
||||
# allow creating a connection in a non-daemon thread and using it in a
|
||||
# daemon thread. This mustn't prevent the interpreter from exiting.
|
||||
self.recv_events_thread = threading.Thread(
|
||||
target=self.recv_events,
|
||||
daemon=True,
|
||||
@ -194,16 +199,18 @@ class Connection:
|
||||
|
||||
Raises:
|
||||
ConnectionClosed: When the connection is closed.
|
||||
RuntimeError: If two threads call :meth:`recv` or
|
||||
ConcurrencyError: If two threads call :meth:`recv` or
|
||||
:meth:`recv_streaming` concurrently.
|
||||
|
||||
"""
|
||||
try:
|
||||
return self.recv_messages.get(timeout)
|
||||
except EOFError:
|
||||
# Wait for the protocol state to be CLOSED before accessing close_exc.
|
||||
self.recv_events_thread.join()
|
||||
raise self.protocol.close_exc from self.recv_exc
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
except ConcurrencyError:
|
||||
raise ConcurrencyError(
|
||||
"cannot call recv while another thread "
|
||||
"is already running recv or recv_streaming"
|
||||
) from None
|
||||
@ -227,7 +234,7 @@ class Connection:
|
||||
|
||||
Raises:
|
||||
ConnectionClosed: When the connection is closed.
|
||||
RuntimeError: If two threads call :meth:`recv` or
|
||||
ConcurrencyError: If two threads call :meth:`recv` or
|
||||
:meth:`recv_streaming` concurrently.
|
||||
|
||||
"""
|
||||
@ -235,9 +242,11 @@ class Connection:
|
||||
for frame in self.recv_messages.get_iter():
|
||||
yield frame
|
||||
except EOFError:
|
||||
# Wait for the protocol state to be CLOSED before accessing close_exc.
|
||||
self.recv_events_thread.join()
|
||||
raise self.protocol.close_exc from self.recv_exc
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
except ConcurrencyError:
|
||||
raise ConcurrencyError(
|
||||
"cannot call recv_streaming while another thread "
|
||||
"is already running recv or recv_streaming"
|
||||
) from None
|
||||
@ -277,7 +286,7 @@ class Connection:
|
||||
|
||||
Raises:
|
||||
ConnectionClosed: When the connection is closed.
|
||||
RuntimeError: If the connection is sending a fragmented message.
|
||||
ConcurrencyError: If the connection is sending a fragmented message.
|
||||
TypeError: If ``message`` doesn't have a supported type.
|
||||
|
||||
"""
|
||||
@ -287,7 +296,7 @@ class Connection:
|
||||
if isinstance(message, str):
|
||||
with self.send_context():
|
||||
if self.send_in_progress:
|
||||
raise RuntimeError(
|
||||
raise ConcurrencyError(
|
||||
"cannot call send while another thread "
|
||||
"is already running send"
|
||||
)
|
||||
@ -296,7 +305,7 @@ class Connection:
|
||||
elif isinstance(message, BytesLike):
|
||||
with self.send_context():
|
||||
if self.send_in_progress:
|
||||
raise RuntimeError(
|
||||
raise ConcurrencyError(
|
||||
"cannot call send while another thread "
|
||||
"is already running send"
|
||||
)
|
||||
@ -322,7 +331,7 @@ class Connection:
|
||||
text = True
|
||||
with self.send_context():
|
||||
if self.send_in_progress:
|
||||
raise RuntimeError(
|
||||
raise ConcurrencyError(
|
||||
"cannot call send while another thread "
|
||||
"is already running send"
|
||||
)
|
||||
@ -335,7 +344,7 @@ class Connection:
|
||||
text = False
|
||||
with self.send_context():
|
||||
if self.send_in_progress:
|
||||
raise RuntimeError(
|
||||
raise ConcurrencyError(
|
||||
"cannot call send while another thread "
|
||||
"is already running send"
|
||||
)
|
||||
@ -371,7 +380,7 @@ class Connection:
|
||||
self.protocol.send_continuation(b"", fin=True)
|
||||
self.send_in_progress = False
|
||||
|
||||
except RuntimeError:
|
||||
except ConcurrencyError:
|
||||
# We didn't start sending a fragmented message.
|
||||
# The connection is still usable.
|
||||
raise
|
||||
@ -445,17 +454,21 @@ class Connection:
|
||||
|
||||
Raises:
|
||||
ConnectionClosed: When the connection is closed.
|
||||
RuntimeError: If another ping was sent with the same data and
|
||||
ConcurrencyError: If another ping was sent with the same data and
|
||||
the corresponding pong wasn't received yet.
|
||||
|
||||
"""
|
||||
if data is not None:
|
||||
data = prepare_ctrl(data)
|
||||
if isinstance(data, BytesLike):
|
||||
data = bytes(data)
|
||||
elif isinstance(data, str):
|
||||
data = data.encode()
|
||||
elif data is not None:
|
||||
raise TypeError("data must be str or bytes-like")
|
||||
|
||||
with self.send_context():
|
||||
# Protect against duplicates if a payload is explicitly set.
|
||||
if data in self.ping_waiters:
|
||||
raise RuntimeError("already waiting for a pong with the same data")
|
||||
raise ConcurrencyError("already waiting for a pong with the same data")
|
||||
|
||||
# Generate a unique random payload otherwise.
|
||||
while data is None or data in self.ping_waiters:
|
||||
@ -481,7 +494,12 @@ class Connection:
|
||||
ConnectionClosed: When the connection is closed.
|
||||
|
||||
"""
|
||||
data = prepare_ctrl(data)
|
||||
if isinstance(data, BytesLike):
|
||||
data = bytes(data)
|
||||
elif isinstance(data, str):
|
||||
data = data.encode()
|
||||
else:
|
||||
raise TypeError("data must be str or bytes-like")
|
||||
|
||||
with self.send_context():
|
||||
self.protocol.send_pong(data)
|
||||
@ -615,8 +633,6 @@ class Connection:
|
||||
self.logger.error("unexpected internal error", exc_info=True)
|
||||
with self.protocol_mutex:
|
||||
self.set_recv_exc(exc)
|
||||
# We don't know where we crashed. Force protocol state to CLOSED.
|
||||
self.protocol.state = CLOSED
|
||||
finally:
|
||||
# This isn't expected to raise an exception.
|
||||
self.close_socket()
|
||||
@ -656,7 +672,7 @@ class Connection:
|
||||
# Let the caller interact with the protocol.
|
||||
try:
|
||||
yield
|
||||
except (ProtocolError, RuntimeError):
|
||||
except (ProtocolError, ConcurrencyError):
|
||||
# The protocol state wasn't changed. Exit immediately.
|
||||
raise
|
||||
except Exception as exc:
|
||||
@ -724,6 +740,7 @@ class Connection:
|
||||
# raise an exception.
|
||||
if raise_close_exc:
|
||||
self.close_socket()
|
||||
# Wait for the protocol state to be CLOSED before accessing close_exc.
|
||||
self.recv_events_thread.join()
|
||||
raise self.protocol.close_exc from original_exc
|
||||
|
||||
@ -774,4 +791,11 @@ class Connection:
|
||||
except OSError:
|
||||
pass # socket is already closed
|
||||
self.socket.close()
|
||||
|
||||
# Calling protocol.receive_eof() is safe because it's idempotent.
|
||||
# This guarantees that the protocol state becomes CLOSED.
|
||||
self.protocol.receive_eof()
|
||||
assert self.protocol.state is CLOSED
|
||||
|
||||
# Abort recv() with a ConnectionClosed exception.
|
||||
self.recv_messages.close()
|
||||
|
@ -5,6 +5,7 @@ import queue
|
||||
import threading
|
||||
from typing import Iterator, cast
|
||||
|
||||
from ..exceptions import ConcurrencyError
|
||||
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
|
||||
from ..typing import Data
|
||||
|
||||
@ -74,7 +75,7 @@ class Assembler:
|
||||
|
||||
Raises:
|
||||
EOFError: If the stream of frames has ended.
|
||||
RuntimeError: If two threads run :meth:`get` or :meth:`get_iter`
|
||||
ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter`
|
||||
concurrently.
|
||||
TimeoutError: If a timeout is provided and elapses before a
|
||||
complete message is received.
|
||||
@ -85,7 +86,7 @@ class Assembler:
|
||||
raise EOFError("stream of frames ended")
|
||||
|
||||
if self.get_in_progress:
|
||||
raise RuntimeError("get() or get_iter() is already running")
|
||||
raise ConcurrencyError("get() or get_iter() is already running")
|
||||
|
||||
self.get_in_progress = True
|
||||
|
||||
@ -128,14 +129,14 @@ class Assembler:
|
||||
:class:`bytes` for each frame in the message.
|
||||
|
||||
The iterator must be fully consumed before calling :meth:`get_iter` or
|
||||
:meth:`get` again. Else, :exc:`RuntimeError` is raised.
|
||||
:meth:`get` again. Else, :exc:`ConcurrencyError` is raised.
|
||||
|
||||
This method only makes sense for fragmented messages. If messages aren't
|
||||
fragmented, use :meth:`get` instead.
|
||||
|
||||
Raises:
|
||||
EOFError: If the stream of frames has ended.
|
||||
RuntimeError: If two threads run :meth:`get` or :meth:`get_iter`
|
||||
ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter`
|
||||
concurrently.
|
||||
|
||||
"""
|
||||
@ -144,7 +145,7 @@ class Assembler:
|
||||
raise EOFError("stream of frames ended")
|
||||
|
||||
if self.get_in_progress:
|
||||
raise RuntimeError("get() or get_iter() is already running")
|
||||
raise ConcurrencyError("get() or get_iter() is already running")
|
||||
|
||||
chunks = self.chunks
|
||||
self.chunks = []
|
||||
@ -198,7 +199,7 @@ class Assembler:
|
||||
|
||||
Raises:
|
||||
EOFError: If the stream of frames has ended.
|
||||
RuntimeError: If two threads run :meth:`put` concurrently.
|
||||
ConcurrencyError: If two threads run :meth:`put` concurrently.
|
||||
|
||||
"""
|
||||
with self.mutex:
|
||||
@ -206,7 +207,7 @@ class Assembler:
|
||||
raise EOFError("stream of frames ended")
|
||||
|
||||
if self.put_in_progress:
|
||||
raise RuntimeError("put is already running")
|
||||
raise ConcurrencyError("put is already running")
|
||||
|
||||
if frame.opcode is OP_TEXT:
|
||||
self.decoder = UTF8Decoder(errors="strict")
|
||||
|
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hmac
|
||||
import http
|
||||
import logging
|
||||
import os
|
||||
@ -10,12 +11,17 @@ import sys
|
||||
import threading
|
||||
import warnings
|
||||
from types import TracebackType
|
||||
from typing import Any, Callable, Sequence
|
||||
from typing import Any, Callable, Iterable, Sequence, Tuple, cast
|
||||
|
||||
from ..exceptions import InvalidHeader
|
||||
from ..extensions.base import ServerExtensionFactory
|
||||
from ..extensions.permessage_deflate import enable_server_permessage_deflate
|
||||
from ..frames import CloseCode
|
||||
from ..headers import validate_subprotocols
|
||||
from ..headers import (
|
||||
build_www_authenticate_basic,
|
||||
parse_authorization_basic,
|
||||
validate_subprotocols,
|
||||
)
|
||||
from ..http11 import SERVER, Request, Response
|
||||
from ..protocol import CONNECTING, OPEN, Event
|
||||
from ..server import ServerProtocol
|
||||
@ -24,7 +30,7 @@ from .connection import Connection
|
||||
from .utils import Deadline
|
||||
|
||||
|
||||
__all__ = ["serve", "unix_serve", "ServerConnection", "Server"]
|
||||
__all__ = ["serve", "unix_serve", "ServerConnection", "Server", "basic_auth"]
|
||||
|
||||
|
||||
class ServerConnection(Connection):
|
||||
@ -65,6 +71,7 @@ class ServerConnection(Connection):
|
||||
protocol,
|
||||
close_timeout=close_timeout,
|
||||
)
|
||||
self.username: str # see basic_auth()
|
||||
|
||||
def respond(self, status: StatusLike, text: str) -> Response:
|
||||
"""
|
||||
@ -111,61 +118,57 @@ class ServerConnection(Connection):
|
||||
|
||||
"""
|
||||
if not self.request_rcvd.wait(timeout):
|
||||
self.close_socket()
|
||||
self.recv_events_thread.join()
|
||||
raise TimeoutError("timed out during handshake")
|
||||
|
||||
if self.request is None:
|
||||
self.close_socket()
|
||||
self.recv_events_thread.join()
|
||||
raise ConnectionError("connection closed during handshake")
|
||||
if self.request is not None:
|
||||
with self.send_context(expected_state=CONNECTING):
|
||||
response = None
|
||||
|
||||
with self.send_context(expected_state=CONNECTING):
|
||||
self.response = None
|
||||
if process_request is not None:
|
||||
try:
|
||||
response = process_request(self, self.request)
|
||||
except Exception as exc:
|
||||
self.protocol.handshake_exc = exc
|
||||
response = self.protocol.reject(
|
||||
http.HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
(
|
||||
"Failed to open a WebSocket connection.\n"
|
||||
"See server log for more information.\n"
|
||||
),
|
||||
)
|
||||
|
||||
if process_request is not None:
|
||||
try:
|
||||
self.response = process_request(self, self.request)
|
||||
except Exception as exc:
|
||||
self.protocol.handshake_exc = exc
|
||||
self.logger.error("opening handshake failed", exc_info=True)
|
||||
self.response = self.protocol.reject(
|
||||
http.HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
(
|
||||
"Failed to open a WebSocket connection.\n"
|
||||
"See server log for more information.\n"
|
||||
),
|
||||
)
|
||||
|
||||
if self.response is None:
|
||||
self.response = self.protocol.accept(self.request)
|
||||
|
||||
if server_header:
|
||||
self.response.headers["Server"] = server_header
|
||||
|
||||
if process_response is not None:
|
||||
try:
|
||||
response = process_response(self, self.request, self.response)
|
||||
except Exception as exc:
|
||||
self.protocol.handshake_exc = exc
|
||||
self.logger.error("opening handshake failed", exc_info=True)
|
||||
self.response = self.protocol.reject(
|
||||
http.HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
(
|
||||
"Failed to open a WebSocket connection.\n"
|
||||
"See server log for more information.\n"
|
||||
),
|
||||
)
|
||||
if response is None:
|
||||
self.response = self.protocol.accept(self.request)
|
||||
else:
|
||||
self.response = response
|
||||
|
||||
if server_header:
|
||||
self.response.headers["Server"] = server_header
|
||||
|
||||
response = None
|
||||
|
||||
if process_response is not None:
|
||||
try:
|
||||
response = process_response(self, self.request, self.response)
|
||||
except Exception as exc:
|
||||
self.protocol.handshake_exc = exc
|
||||
response = self.protocol.reject(
|
||||
http.HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
(
|
||||
"Failed to open a WebSocket connection.\n"
|
||||
"See server log for more information.\n"
|
||||
),
|
||||
)
|
||||
|
||||
if response is not None:
|
||||
self.response = response
|
||||
|
||||
self.protocol.send_response(self.response)
|
||||
self.protocol.send_response(self.response)
|
||||
|
||||
if self.protocol.state is not OPEN:
|
||||
self.recv_events_thread.join(self.close_timeout)
|
||||
self.close_socket()
|
||||
self.recv_events_thread.join()
|
||||
# self.protocol.handshake_exc is always set when the connection is lost
|
||||
# before receiving a request, when the request cannot be parsed, when
|
||||
# the handshake encounters an error, or when process_request or
|
||||
# process_response sends an HTTP response that rejects the handshake.
|
||||
|
||||
if self.protocol.handshake_exc is not None:
|
||||
raise self.protocol.handshake_exc
|
||||
@ -297,7 +300,7 @@ class Server:
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name == "WebSocketServer":
|
||||
warnings.warn(
|
||||
warnings.warn( # deprecated in 13.0 - 2024-08-20
|
||||
"WebSocketServer was renamed to Server",
|
||||
DeprecationWarning,
|
||||
)
|
||||
@ -368,10 +371,12 @@ def serve(
|
||||
that it will be closed and call :meth:`~Server.serve_forever` to serve
|
||||
requests::
|
||||
|
||||
from websockets.sync.server import serve
|
||||
|
||||
def handler(websocket):
|
||||
...
|
||||
|
||||
with websockets.sync.server.serve(handler, ...) as server:
|
||||
with serve(handler, ...) as server:
|
||||
server.serve_forever()
|
||||
|
||||
Args:
|
||||
@ -437,7 +442,10 @@ def serve(
|
||||
# Backwards compatibility: ssl used to be called ssl_context.
|
||||
if ssl is None and "ssl_context" in kwargs:
|
||||
ssl = kwargs.pop("ssl_context")
|
||||
warnings.warn("ssl_context was renamed to ssl", DeprecationWarning)
|
||||
warnings.warn( # deprecated in 13.0 - 2024-08-20
|
||||
"ssl_context was renamed to ssl",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
if subprotocols is not None:
|
||||
validate_subprotocols(subprotocols)
|
||||
@ -540,26 +548,40 @@ def serve(
|
||||
protocol,
|
||||
close_timeout=close_timeout,
|
||||
)
|
||||
# On failure, handshake() closes the socket, raises an exception, and
|
||||
# logs it.
|
||||
connection.handshake(
|
||||
process_request,
|
||||
process_response,
|
||||
server_header,
|
||||
deadline.timeout(),
|
||||
)
|
||||
|
||||
except Exception:
|
||||
sock.close()
|
||||
return
|
||||
|
||||
try:
|
||||
handler(connection)
|
||||
except Exception:
|
||||
protocol.logger.error("connection handler failed", exc_info=True)
|
||||
connection.close(CloseCode.INTERNAL_ERROR)
|
||||
else:
|
||||
connection.close()
|
||||
try:
|
||||
connection.handshake(
|
||||
process_request,
|
||||
process_response,
|
||||
server_header,
|
||||
deadline.timeout(),
|
||||
)
|
||||
except TimeoutError:
|
||||
connection.close_socket()
|
||||
connection.recv_events_thread.join()
|
||||
return
|
||||
except Exception:
|
||||
connection.logger.error("opening handshake failed", exc_info=True)
|
||||
connection.close_socket()
|
||||
connection.recv_events_thread.join()
|
||||
return
|
||||
|
||||
assert connection.protocol.state is OPEN
|
||||
try:
|
||||
handler(connection)
|
||||
except Exception:
|
||||
connection.logger.error("connection handler failed", exc_info=True)
|
||||
connection.close(CloseCode.INTERNAL_ERROR)
|
||||
else:
|
||||
connection.close()
|
||||
|
||||
except Exception: # pragma: no cover
|
||||
# Don't leak sockets on unexpected errors.
|
||||
sock.close()
|
||||
|
||||
# Initialize server
|
||||
|
||||
@ -587,3 +609,119 @@ def unix_serve(
|
||||
|
||||
"""
|
||||
return serve(handler, unix=True, path=path, **kwargs)
|
||||
|
||||
|
||||
def is_credentials(credentials: Any) -> bool:
|
||||
try:
|
||||
username, password = credentials
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
else:
|
||||
return isinstance(username, str) and isinstance(password, str)
|
||||
|
||||
|
||||
def basic_auth(
|
||||
realm: str = "",
|
||||
credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None,
|
||||
check_credentials: Callable[[str, str], bool] | None = None,
|
||||
) -> Callable[[ServerConnection, Request], Response | None]:
|
||||
"""
|
||||
Factory for ``process_request`` to enforce HTTP Basic Authentication.
|
||||
|
||||
:func:`basic_auth` is designed to integrate with :func:`serve` as follows::
|
||||
|
||||
from websockets.sync.server import basic_auth, serve
|
||||
|
||||
with serve(
|
||||
...,
|
||||
process_request=basic_auth(
|
||||
realm="my dev server",
|
||||
credentials=("hello", "iloveyou"),
|
||||
),
|
||||
):
|
||||
|
||||
If authentication succeeds, the connection's ``username`` attribute is set.
|
||||
If it fails, the server responds with an HTTP 401 Unauthorized status.
|
||||
|
||||
One of ``credentials`` or ``check_credentials`` must be provided; not both.
|
||||
|
||||
Args:
|
||||
realm: Scope of protection. It should contain only ASCII characters
|
||||
because the encoding of non-ASCII characters is undefined. Refer to
|
||||
section 2.2 of :rfc:`7235` for details.
|
||||
credentials: Hard coded authorized credentials. It can be a
|
||||
``(username, password)`` pair or a list of such pairs.
|
||||
check_credentials: Function that verifies credentials.
|
||||
It receives ``username`` and ``password`` arguments and returns
|
||||
whether they're valid.
|
||||
Raises:
|
||||
TypeError: If ``credentials`` or ``check_credentials`` is wrong.
|
||||
|
||||
"""
|
||||
if (credentials is None) == (check_credentials is None):
|
||||
raise TypeError("provide either credentials or check_credentials")
|
||||
|
||||
if credentials is not None:
|
||||
if is_credentials(credentials):
|
||||
credentials_list = [cast(Tuple[str, str], credentials)]
|
||||
elif isinstance(credentials, Iterable):
|
||||
credentials_list = list(cast(Iterable[Tuple[str, str]], credentials))
|
||||
if not all(is_credentials(item) for item in credentials_list):
|
||||
raise TypeError(f"invalid credentials argument: {credentials}")
|
||||
else:
|
||||
raise TypeError(f"invalid credentials argument: {credentials}")
|
||||
|
||||
credentials_dict = dict(credentials_list)
|
||||
|
||||
def check_credentials(username: str, password: str) -> bool:
|
||||
try:
|
||||
expected_password = credentials_dict[username]
|
||||
except KeyError:
|
||||
return False
|
||||
return hmac.compare_digest(expected_password, password)
|
||||
|
||||
assert check_credentials is not None # help mypy
|
||||
|
||||
def process_request(
|
||||
connection: ServerConnection,
|
||||
request: Request,
|
||||
) -> Response | None:
|
||||
"""
|
||||
Perform HTTP Basic Authentication.
|
||||
|
||||
If it succeeds, set the connection's ``username`` attribute and return
|
||||
:obj:`None`. If it fails, return an HTTP 401 Unauthorized responss.
|
||||
|
||||
"""
|
||||
try:
|
||||
authorization = request.headers["Authorization"]
|
||||
except KeyError:
|
||||
response = connection.respond(
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
"Missing credentials\n",
|
||||
)
|
||||
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
|
||||
return response
|
||||
|
||||
try:
|
||||
username, password = parse_authorization_basic(authorization)
|
||||
except InvalidHeader:
|
||||
response = connection.respond(
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
"Unsupported credentials\n",
|
||||
)
|
||||
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
|
||||
return response
|
||||
|
||||
if not check_credentials(username, password):
|
||||
response = connection.respond(
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
"Invalid credentials\n",
|
||||
)
|
||||
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
|
||||
return response
|
||||
|
||||
connection.username = username
|
||||
return None
|
||||
|
||||
return process_request
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import dataclasses
|
||||
import urllib.parse
|
||||
|
||||
from . import exceptions
|
||||
from .exceptions import InvalidURI
|
||||
|
||||
|
||||
__all__ = ["parse_uri", "WebSocketURI"]
|
||||
@ -73,11 +73,11 @@ def parse_uri(uri: str) -> WebSocketURI:
|
||||
"""
|
||||
parsed = urllib.parse.urlparse(uri)
|
||||
if parsed.scheme not in ["ws", "wss"]:
|
||||
raise exceptions.InvalidURI(uri, "scheme isn't ws or wss")
|
||||
raise InvalidURI(uri, "scheme isn't ws or wss")
|
||||
if parsed.hostname is None:
|
||||
raise exceptions.InvalidURI(uri, "hostname isn't provided")
|
||||
raise InvalidURI(uri, "hostname isn't provided")
|
||||
if parsed.fragment != "":
|
||||
raise exceptions.InvalidURI(uri, "fragment identifier is meaningless")
|
||||
raise InvalidURI(uri, "fragment identifier is meaningless")
|
||||
|
||||
secure = parsed.scheme == "wss"
|
||||
host = parsed.hostname
|
||||
@ -89,7 +89,7 @@ def parse_uri(uri: str) -> WebSocketURI:
|
||||
# urllib.parse.urlparse accepts URLs with a username but without a
|
||||
# password. This doesn't make sense for HTTP Basic Auth credentials.
|
||||
if username is not None and password is None:
|
||||
raise exceptions.InvalidURI(uri, "username provided without password")
|
||||
raise InvalidURI(uri, "username provided without password")
|
||||
|
||||
try:
|
||||
uri.encode("ascii")
|
||||
|
@ -20,7 +20,7 @@ __all__ = ["tag", "version", "commit"]
|
||||
|
||||
released = True
|
||||
|
||||
tag = version = commit = "13.0.1"
|
||||
tag = version = commit = "13.1"
|
||||
|
||||
|
||||
if not released: # pragma: no cover
|
||||
|
@ -1,16 +0,0 @@
|
||||
zipp-3.20.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
zipp-3.20.1.dist-info/LICENSE,sha256=htoPAa6uRjSKPD1GUZXcHOzN55956HdppkuNoEsqR0E,1023
|
||||
zipp-3.20.1.dist-info/METADATA,sha256=J3sSw1smScyUJdU4tZp-6S5SoFY5hi8Wk1WebRFWcWA,3682
|
||||
zipp-3.20.1.dist-info/RECORD,,
|
||||
zipp-3.20.1.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
|
||||
zipp-3.20.1.dist-info/top_level.txt,sha256=iAbdoSHfaGqBfVb2XuR9JqSQHCoOsOtG6y9C_LSpqFw,5
|
||||
zipp/__init__.py,sha256=hFKawMr3tL33PV0OuYkg67V3cRvuvpOdj_zwvZCXmCk,11834
|
||||
zipp/__pycache__/__init__.cpython-38.pyc,,
|
||||
zipp/__pycache__/glob.cpython-38.pyc,,
|
||||
zipp/compat/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
zipp/compat/__pycache__/__init__.cpython-38.pyc,,
|
||||
zipp/compat/__pycache__/overlay.cpython-38.pyc,,
|
||||
zipp/compat/__pycache__/py310.cpython-38.pyc,,
|
||||
zipp/compat/overlay.py,sha256=B1lcbC9TZVA2terZQTjC8A1xfSCaah4oHNMejJZVwyU,295
|
||||
zipp/compat/py310.py,sha256=KS3sidGTSkoGh3biXiCqRzE6RMEGH0sbRQBevWU73dU,256
|
||||
zipp/glob.py,sha256=yPjGfHwcJxUn0fld7I-K-ZQSfTaJBBoimCIygU1SZQw,3315
|
@ -1,6 +1,6 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: zipp
|
||||
Version: 3.20.1
|
||||
Version: 3.20.2
|
||||
Summary: Backport of pathlib-compatible object wrapper for zip files
|
||||
Author-email: "Jason R. Coombs" <jaraco@jaraco.com>
|
||||
Project-URL: Source, https://github.com/jaraco/zipp
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user