Sync: devchat[main](ef593429) Merge pull request #414 from devchat-ai/fix-not-found-custom_git_urls

This commit is contained in:
Sync-Packages Action 2024-09-23 01:55:21 +00:00
parent 6ea48f76d1
commit f16f4ec2d5
105 changed files with 14696 additions and 13466 deletions

View File

@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: anyio
Version: 4.4.0
Version: 4.5.0
Summary: High level compatibility layer for multiple asynchronous event loop implementations
Author-email: Alex Grönholm <alex.gronholm@nextday.fi>
License: MIT
@ -20,6 +20,7 @@ Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Requires-Python: >=3.8
Description-Content-Type: text/x-rst
License-File: LICENSE
@ -29,7 +30,7 @@ Requires-Dist: exceptiongroup >=1.0.2 ; python_version < "3.11"
Requires-Dist: typing-extensions >=4.1 ; python_version < "3.11"
Provides-Extra: doc
Requires-Dist: packaging ; extra == 'doc'
Requires-Dist: Sphinx >=7 ; extra == 'doc'
Requires-Dist: Sphinx ~=7.4 ; extra == 'doc'
Requires-Dist: sphinx-rtd-theme ; extra == 'doc'
Requires-Dist: sphinx-autodoc-typehints >=1.2.0 ; extra == 'doc'
Provides-Extra: test
@ -41,9 +42,9 @@ Requires-Dist: psutil >=5.9 ; extra == 'test'
Requires-Dist: pytest >=7.0 ; extra == 'test'
Requires-Dist: pytest-mock >=3.6.1 ; extra == 'test'
Requires-Dist: trustme ; extra == 'test'
Requires-Dist: uvloop >=0.17 ; (platform_python_implementation == "CPython" and platform_system != "Windows") and extra == 'test'
Requires-Dist: uvloop >=0.21.0b1 ; (platform_python_implementation == "CPython" and platform_system != "Windows") and extra == 'test'
Provides-Extra: trio
Requires-Dist: trio >=0.23 ; extra == 'trio'
Requires-Dist: trio >=0.26.1 ; extra == 'trio'
.. image:: https://github.com/agronholm/anyio/actions/workflows/test.yml/badge.svg
:target: https://github.com/agronholm/anyio/actions/workflows/test.yml

View File

@ -1,11 +1,11 @@
anyio-4.4.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
anyio-4.4.0.dist-info/LICENSE,sha256=U2GsncWPLvX9LpsJxoKXwX8ElQkJu8gCO9uC6s8iwrA,1081
anyio-4.4.0.dist-info/METADATA,sha256=sbJaOJ_Ilka4D0U6yKtfOtVrYef7XRFzGjoBEZnRpes,4599
anyio-4.4.0.dist-info/RECORD,,
anyio-4.4.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
anyio-4.4.0.dist-info/entry_points.txt,sha256=_d6Yu6uiaZmNe0CydowirE9Cmg7zUL2g08tQpoS3Qvc,39
anyio-4.4.0.dist-info/top_level.txt,sha256=QglSMiWX8_5dpoVAEIHdEYzvqFMdSYWmCj6tYw2ITkQ,6
anyio/__init__.py,sha256=CxUxIHOIONI3KpsDLCg-dI6lQaDkW_4Zhtu5jWt1XO8,4344
anyio-4.5.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
anyio-4.5.0.dist-info/LICENSE,sha256=U2GsncWPLvX9LpsJxoKXwX8ElQkJu8gCO9uC6s8iwrA,1081
anyio-4.5.0.dist-info/METADATA,sha256=gveTB0gvT7MwEWHWcT9CUmYzS4qgywQPW4gweoJsESs,4658
anyio-4.5.0.dist-info/RECORD,,
anyio-4.5.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
anyio-4.5.0.dist-info/entry_points.txt,sha256=_d6Yu6uiaZmNe0CydowirE9Cmg7zUL2g08tQpoS3Qvc,39
anyio-4.5.0.dist-info/top_level.txt,sha256=QglSMiWX8_5dpoVAEIHdEYzvqFMdSYWmCj6tYw2ITkQ,6
anyio/__init__.py,sha256=myTIdg75VPwA-9L7BpislRQplJUPMeleUBHa4MyIruw,4315
anyio/__pycache__/__init__.cpython-38.pyc,,
anyio/__pycache__/from_thread.cpython-38.pyc,,
anyio/__pycache__/lowlevel.cpython-38.pyc,,
@ -16,8 +16,8 @@ anyio/_backends/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
anyio/_backends/__pycache__/__init__.cpython-38.pyc,,
anyio/_backends/__pycache__/_asyncio.cpython-38.pyc,,
anyio/_backends/__pycache__/_trio.cpython-38.pyc,,
anyio/_backends/_asyncio.py,sha256=CVy87WpTh1URmEjlE-AKTrBoPwsqH_nRxbGkLnTrfeg,83244
anyio/_backends/_trio.py,sha256=8gdA930WJFn4xrMNpMH6LrHC9MeyTdZBHsB_W-8HFBw,35909
anyio/_backends/_asyncio.py,sha256=eu1H_onDHe_Oa6_wiHxsQnOmEnxme0F9Dlmf7ykbkT8,88344
anyio/_backends/_trio.py,sha256=L2XiIaDruGztveS8SBPMJMjHHEnZDjXHhsbmu00II6Q,39735
anyio/_core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
anyio/_core/__pycache__/__init__.cpython-38.pyc,,
anyio/_core/__pycache__/_eventloop.cpython-38.pyc,,
@ -33,14 +33,14 @@ anyio/_core/__pycache__/_tasks.cpython-38.pyc,,
anyio/_core/__pycache__/_testing.cpython-38.pyc,,
anyio/_core/__pycache__/_typedattr.cpython-38.pyc,,
anyio/_core/_eventloop.py,sha256=t_tAwBFPjF8jrZGjlJ6bbYy6KA3bjsbZxV9mvh9t1i0,4695
anyio/_core/_exceptions.py,sha256=wUmhDu80qEB7z9EdCqUwVEhNUlNEok4_W2-rC6sCAUQ,2078
anyio/_core/_fileio.py,sha256=fC6H6DcueA-2AUaDkP91kmVeqoU7DlWj6O0CCAdnsdM,19456
anyio/_core/_exceptions.py,sha256=NPxECdXkG4nk3NOCUeFmBEAgPhmj7Bzs4vFAKaW_vqw,2481
anyio/_core/_fileio.py,sha256=a4Ab0y8UU0vkYYExM_7pLfBcjeqY0GM-4AyQ4EaRn6s,20067
anyio/_core/_resources.py,sha256=NbmU5O5UX3xEyACnkmYX28Fmwdl-f-ny0tHym26e0w0,435
anyio/_core/_signals.py,sha256=rDOVxtugZDgC5AhfW3lrwsre2n9Pj_adoRUidBiF6dA,878
anyio/_core/_sockets.py,sha256=2jOzi4bXQQYTLr9PSrCaeTgwaU_N7mt0yjHGmX4LvA8,24028
anyio/_core/_sockets.py,sha256=iM3UeMU68n0PlQjl2U9HyiOpV26rnjqV4KBr_Fo2z1I,24293
anyio/_core/_streams.py,sha256=Z8ZlTY6xom5EszrMsgCT3TphiT4JIlQG-y33CrD0NQY,1811
anyio/_core/_subprocesses.py,sha256=ZLLNXAtlRGfbyC4sOIltYB1k3NJa3tqk_x_Fsnbcs1M,5272
anyio/_core/_synchronization.py,sha256=h3o6dWWbzVrcNmi7i2mQjEgRtnIxkGtjmYK7KMpdlaE,18444
anyio/_core/_subprocesses.py,sha256=DysVq3ZEKayWtjKF3gtyteR_q-Q-K7CmpYTjhIEV_CQ,8314
anyio/_core/_synchronization.py,sha256=UDsbG5f8jWsWkRxYUOKp_WOBWCI9-vBO6wBrsR6WNjA,20121
anyio/_core/_tasks.py,sha256=pvVEX2Fw159sf0ypAPerukKsZgRRwvFFedVW52nR2Vk,4764
anyio/_core/_testing.py,sha256=YUGwA5cgFFbUTv4WFd7cv_BSVr4ryTtPp8owQA3JdWE,2118
anyio/_core/_typedattr.py,sha256=P4ozZikn3-DbpoYcvyghS_FOYAgbmUxeoU8-L_07pZM,2508
@ -53,17 +53,17 @@ anyio/abc/__pycache__/_streams.cpython-38.pyc,,
anyio/abc/__pycache__/_subprocesses.cpython-38.pyc,,
anyio/abc/__pycache__/_tasks.cpython-38.pyc,,
anyio/abc/__pycache__/_testing.cpython-38.pyc,,
anyio/abc/_eventloop.py,sha256=r9pldSu6p-ZsvO6D_brc0EIi1JgZRDbfgVLuy7Q7R6o,10085
anyio/abc/_eventloop.py,sha256=L0axMx7JdbH5efXSbkMaCYp0fMdk-zeWm-u4jx_nFW4,9593
anyio/abc/_resources.py,sha256=DrYvkNN1hH6Uvv5_5uKySvDsnknGVDe8FCKfko0VtN8,783
anyio/abc/_sockets.py,sha256=XdZ42TQ1omZN9Ec3HUfTMWG_i-21yMjXQ_FFslAZtzQ,6269
anyio/abc/_streams.py,sha256=GzST5Q2zQmxVzdrAqtbSyHNxkPlIC9AzeZJg_YyPAXw,6598
anyio/abc/_subprocesses.py,sha256=cumAPJTktOQtw63IqG0lDpyZqu_l1EElvQHMiwJgL08,2067
anyio/abc/_tasks.py,sha256=0Jc6oIwUjMIVReehF6knOZyAqlgwDt4TP1NQkx4IQGw,2731
anyio/abc/_testing.py,sha256=tBJUzkSfOXJw23fe8qSJ03kJlShOYjjaEyFB6k6MYT8,1821
anyio/from_thread.py,sha256=HtgJ7yZ6RLfRe0l0yyyhpg2mnoax0mqXXgKv8TORUlA,17700
anyio/from_thread.py,sha256=2tEb3LZeqVLl-WyLm9siIej4mKwcKK5zQIZEXyafC1A,17439
anyio/lowlevel.py,sha256=nkgmW--SdxGVp0cmLUYazjkigveRm5HY7-gW8Bpp9oY,4169
anyio/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
anyio/pytest_plugin.py,sha256=TBgRAfT-Oxy6efhO1Tziq54NND3Jy4dRmwkMmQXSvhI,5386
anyio/pytest_plugin.py,sha256=le89r6YzzM85skagv04dzDJiuPvSbxJVnwkoJZwaoY8,5852
anyio/streams/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
anyio/streams/__pycache__/__init__.cpython-38.pyc,,
anyio/streams/__pycache__/buffered.cpython-38.pyc,,
@ -74,9 +74,9 @@ anyio/streams/__pycache__/text.cpython-38.pyc,,
anyio/streams/__pycache__/tls.cpython-38.pyc,,
anyio/streams/buffered.py,sha256=UCldKC168YuLvT7n3HtNPnQ2iWAMSTYQWbZvzLwMwkM,4500
anyio/streams/file.py,sha256=6uoTNb5KbMoj-6gS3_xrrL8uZN8Q4iIvOS1WtGyFfKw,4383
anyio/streams/memory.py,sha256=Y286x16omNSSGONQx5CBLLNiB3vAJb_vVKt5vb3go-Q,10190
anyio/streams/memory.py,sha256=j8AyOExK4-UPaon_Xbhwax25Vqs0DwFg3ZXc-EIiHjY,10550
anyio/streams/stapled.py,sha256=U09pCrmOw9kkNhe6tKopsm1QIMT1lFTFvtb-A7SIe4k,4302
anyio/streams/text.py,sha256=6x8w8xlfCZKTUWQoJiMPoMhSSJFUBRKgoBNSBtbd9yg,5094
anyio/streams/tls.py,sha256=ev-6yNOGcIkziIkcIfKj8VmLqQJW-iDBJttaKgKDsF4,12752
anyio/to_process.py,sha256=lx_bt0CUJsS1eSlraw662OpCjRgGXowoyf1Q-i-kOxo,9535
anyio/to_process.py,sha256=cR4n7TssbbJowE_9cWme49zaeuoBuMzqgZ6cBIs0YIs,9571
anyio/to_thread.py,sha256=WM2JQ2MbVsd5D5CM08bQiTwzZIvpsGjfH1Fy247KoDQ,2396

View File

@ -1,5 +1,5 @@
Wheel-Version: 1.0
Generator: setuptools (73.0.1)
Generator: setuptools (75.1.0)
Root-Is-Purelib: true
Tag: py3-none-any

View File

@ -1,7 +1,5 @@
from __future__ import annotations
from typing import Any
from ._core._eventloop import current_time as current_time
from ._core._eventloop import get_all_backends as get_all_backends
from ._core._eventloop import get_cancelled_exc_class as get_cancelled_exc_class
@ -69,8 +67,8 @@ from ._core._typedattr import TypedAttributeSet as TypedAttributeSet
from ._core._typedattr import typed_attribute as typed_attribute
# Re-export imports so they look like they live directly in this package
key: str
value: Any
for key, value in list(locals().items()):
if getattr(value, "__module__", "").startswith("anyio."):
value.__module__ = __name__
for __value in list(locals().values()):
if getattr(__value, "__module__", "").startswith("anyio."):
__value.__module__ = __name__
del __value

View File

@ -4,6 +4,7 @@ import array
import asyncio
import concurrent.futures
import math
import os
import socket
import sys
import threading
@ -19,7 +20,7 @@ from asyncio import (
)
from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined]
from collections import OrderedDict, deque
from collections.abc import AsyncIterator, Generator, Iterable
from collections.abc import AsyncIterator, Iterable
from concurrent.futures import Future
from contextlib import suppress
from contextvars import Context, copy_context
@ -47,7 +48,6 @@ from typing import (
Collection,
ContextManager,
Coroutine,
Mapping,
Optional,
Sequence,
Tuple,
@ -58,7 +58,13 @@ from weakref import WeakKeyDictionary
import sniffio
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
from .. import (
CapacityLimiterStatistics,
EventStatistics,
LockStatistics,
TaskInfo,
abc,
)
from .._core._eventloop import claim_worker_thread, threadlocals
from .._core._exceptions import (
BrokenResourceError,
@ -66,12 +72,20 @@ from .._core._exceptions import (
ClosedResourceError,
EndOfStream,
WouldBlock,
iterate_exceptions,
)
from .._core._sockets import convert_ipv6_sockaddr
from .._core._streams import create_memory_object_stream
from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter
from .._core._synchronization import (
CapacityLimiter as BaseCapacityLimiter,
)
from .._core._synchronization import Event as BaseEvent
from .._core._synchronization import ResourceGuard
from .._core._synchronization import Lock as BaseLock
from .._core._synchronization import (
ResourceGuard,
SemaphoreStatistics,
)
from .._core._synchronization import Semaphore as BaseSemaphore
from .._core._tasks import CancelScope as BaseCancelScope
from ..abc import (
AsyncBackend,
@ -80,6 +94,7 @@ from ..abc import (
UDPPacketType,
UNIXDatagramPacketType,
)
from ..abc._eventloop import StrOrBytesPath
from ..lowlevel import RunVar
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@ -630,16 +645,6 @@ class _AsyncioTaskStatus(abc.TaskStatus):
_task_states[task].parent_id = self._parent_id
def iterate_exceptions(
exception: BaseException,
) -> Generator[BaseException, None, None]:
if isinstance(exception, BaseExceptionGroup):
for exc in exception.exceptions:
yield from iterate_exceptions(exc)
else:
yield exception
class TaskGroup(abc.TaskGroup):
def __init__(self) -> None:
self.cancel_scope: CancelScope = CancelScope()
@ -925,7 +930,7 @@ class StreamReaderWrapper(abc.ByteReceiveStream):
raise EndOfStream
async def aclose(self) -> None:
self._stream.feed_eof()
self._stream.set_exception(ClosedResourceError())
await AsyncIOBackend.checkpoint()
@ -1073,7 +1078,8 @@ class StreamProtocol(asyncio.Protocol):
self.write_event.set()
def data_received(self, data: bytes) -> None:
self.read_queue.append(data)
# ProactorEventloop sometimes sends bytearray instead of bytes
self.read_queue.append(bytes(data))
self.read_event.set()
def eof_received(self) -> bool | None:
@ -1665,6 +1671,154 @@ class Event(BaseEvent):
return EventStatistics(len(self._event._waiters))
class Lock(BaseLock):
def __new__(cls, *, fast_acquire: bool = False) -> Lock:
return object.__new__(cls)
def __init__(self, *, fast_acquire: bool = False) -> None:
self._fast_acquire = fast_acquire
self._owner_task: asyncio.Task | None = None
self._waiters: deque[tuple[asyncio.Task, asyncio.Future]] = deque()
async def acquire(self) -> None:
if self._owner_task is None and not self._waiters:
await AsyncIOBackend.checkpoint_if_cancelled()
self._owner_task = current_task()
# Unless on the "fast path", yield control of the event loop so that other
# tasks can run too
if not self._fast_acquire:
try:
await AsyncIOBackend.cancel_shielded_checkpoint()
except CancelledError:
self.release()
raise
return
task = cast(asyncio.Task, current_task())
fut: asyncio.Future[None] = asyncio.Future()
item = task, fut
self._waiters.append(item)
try:
await fut
except CancelledError:
self._waiters.remove(item)
if self._owner_task is task:
self.release()
raise
self._waiters.remove(item)
def acquire_nowait(self) -> None:
if self._owner_task is None and not self._waiters:
self._owner_task = current_task()
return
raise WouldBlock
def locked(self) -> bool:
return self._owner_task is not None
def release(self) -> None:
if self._owner_task != current_task():
raise RuntimeError("The current task is not holding this lock")
for task, fut in self._waiters:
if not fut.cancelled():
self._owner_task = task
fut.set_result(None)
return
self._owner_task = None
def statistics(self) -> LockStatistics:
task_info = AsyncIOTaskInfo(self._owner_task) if self._owner_task else None
return LockStatistics(self.locked(), task_info, len(self._waiters))
class Semaphore(BaseSemaphore):
def __new__(
cls,
initial_value: int,
*,
max_value: int | None = None,
fast_acquire: bool = False,
) -> Semaphore:
return object.__new__(cls)
def __init__(
self,
initial_value: int,
*,
max_value: int | None = None,
fast_acquire: bool = False,
):
super().__init__(initial_value, max_value=max_value)
self._value = initial_value
self._max_value = max_value
self._fast_acquire = fast_acquire
self._waiters: deque[asyncio.Future[None]] = deque()
async def acquire(self) -> None:
if self._value > 0 and not self._waiters:
await AsyncIOBackend.checkpoint_if_cancelled()
self._value -= 1
# Unless on the "fast path", yield control of the event loop so that other
# tasks can run too
if not self._fast_acquire:
try:
await AsyncIOBackend.cancel_shielded_checkpoint()
except CancelledError:
self.release()
raise
return
fut: asyncio.Future[None] = asyncio.Future()
self._waiters.append(fut)
try:
await fut
except CancelledError:
try:
self._waiters.remove(fut)
except ValueError:
self.release()
raise
def acquire_nowait(self) -> None:
if self._value == 0:
raise WouldBlock
self._value -= 1
def release(self) -> None:
if self._max_value is not None and self._value == self._max_value:
raise ValueError("semaphore released too many times")
for fut in self._waiters:
if not fut.cancelled():
fut.set_result(None)
self._waiters.remove(fut)
return
self._value += 1
@property
def value(self) -> int:
return self._value
@property
def max_value(self) -> int | None:
return self._max_value
def statistics(self) -> SemaphoreStatistics:
return SemaphoreStatistics(len(self._waiters))
class CapacityLimiter(BaseCapacityLimiter):
_total_tokens: float = 0
@ -1861,7 +2015,9 @@ class AsyncIOTaskInfo(TaskInfo):
if task_state := _task_states.get(task):
if cancel_scope := task_state.cancel_scope:
return cancel_scope.cancel_called or cancel_scope._parent_cancelled()
return cancel_scope.cancel_called or (
not cancel_scope.shield and cancel_scope._parent_cancelled()
)
return False
@ -1926,13 +2082,23 @@ class TestRunner(abc.TestRunner):
tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
],
) -> None:
from _pytest.outcomes import OutcomeException
with receive_stream, self._send_stream:
async for coro, future in receive_stream:
try:
retval = await coro
except CancelledError as exc:
if not future.cancelled():
future.cancel(*exc.args)
raise
except BaseException as exc:
if not future.cancelled():
future.set_exception(exc)
if not isinstance(exc, (Exception, OutcomeException)):
raise
else:
if not future.cancelled():
future.set_result(retval)
@ -2113,6 +2279,20 @@ class AsyncIOBackend(AsyncBackend):
def create_event(cls) -> abc.Event:
return Event()
@classmethod
def create_lock(cls, *, fast_acquire: bool) -> abc.Lock:
return Lock(fast_acquire=fast_acquire)
@classmethod
def create_semaphore(
cls,
initial_value: int,
*,
max_value: int | None = None,
fast_acquire: bool = False,
) -> abc.Semaphore:
return Semaphore(initial_value, max_value=max_value, fast_acquire=fast_acquire)
@classmethod
def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
return CapacityLimiter(total_tokens)
@ -2245,26 +2425,24 @@ class AsyncIOBackend(AsyncBackend):
@classmethod
async def open_process(
cls,
command: str | bytes | Sequence[str | bytes],
command: StrOrBytesPath | Sequence[StrOrBytesPath],
*,
shell: bool,
stdin: int | IO[Any] | None,
stdout: int | IO[Any] | None,
stderr: int | IO[Any] | None,
cwd: str | bytes | PathLike | None = None,
env: Mapping[str, str] | None = None,
start_new_session: bool = False,
**kwargs: Any,
) -> Process:
await cls.checkpoint()
if shell:
if isinstance(command, PathLike):
command = os.fspath(command)
if isinstance(command, (str, bytes)):
process = await asyncio.create_subprocess_shell(
cast("str | bytes", command),
command,
stdin=stdin,
stdout=stdout,
stderr=stderr,
cwd=cwd,
env=env,
start_new_session=start_new_session,
**kwargs,
)
else:
process = await asyncio.create_subprocess_exec(
@ -2272,9 +2450,7 @@ class AsyncIOBackend(AsyncBackend):
stdin=stdin,
stdout=stdout,
stderr=stderr,
cwd=cwd,
env=env,
start_new_session=start_new_session,
**kwargs,
)
stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
@ -2289,7 +2465,7 @@ class AsyncIOBackend(AsyncBackend):
name="AnyIO process pool shutdown task",
)
find_root_task().add_done_callback(
partial(_forcibly_shutdown_process_pool_on_exit, workers)
partial(_forcibly_shutdown_process_pool_on_exit, workers) # type:ignore[arg-type]
)
@classmethod

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import array
import math
import os
import socket
import sys
import types
@ -25,7 +26,6 @@ from typing import (
ContextManager,
Coroutine,
Generic,
Mapping,
NoReturn,
Sequence,
TypeVar,
@ -45,7 +45,14 @@ from trio.lowlevel import (
from trio.socket import SocketType as TrioSocketType
from trio.to_thread import run_sync
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
from .. import (
CapacityLimiterStatistics,
EventStatistics,
LockStatistics,
TaskInfo,
WouldBlock,
abc,
)
from .._core._eventloop import claim_worker_thread
from .._core._exceptions import (
BrokenResourceError,
@ -55,12 +62,19 @@ from .._core._exceptions import (
)
from .._core._sockets import convert_ipv6_sockaddr
from .._core._streams import create_memory_object_stream
from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter
from .._core._synchronization import (
CapacityLimiter as BaseCapacityLimiter,
)
from .._core._synchronization import Event as BaseEvent
from .._core._synchronization import ResourceGuard
from .._core._synchronization import Lock as BaseLock
from .._core._synchronization import (
ResourceGuard,
SemaphoreStatistics,
)
from .._core._synchronization import Semaphore as BaseSemaphore
from .._core._tasks import CancelScope as BaseCancelScope
from ..abc import IPSockAddrType, UDPPacketType, UNIXDatagramPacketType
from ..abc._eventloop import AsyncBackend
from ..abc._eventloop import AsyncBackend, StrOrBytesPath
from ..streams.memory import MemoryObjectSendStream
if sys.version_info >= (3, 10):
@ -637,6 +651,100 @@ class Event(BaseEvent):
self.__original.set()
class Lock(BaseLock):
def __new__(cls, *, fast_acquire: bool = False) -> Lock:
return object.__new__(cls)
def __init__(self, *, fast_acquire: bool = False) -> None:
self._fast_acquire = fast_acquire
self.__original = trio.Lock()
async def acquire(self) -> None:
if not self._fast_acquire:
await self.__original.acquire()
return
# This is the "fast path" where we don't let other tasks run
await trio.lowlevel.checkpoint_if_cancelled()
try:
self.__original.acquire_nowait()
except trio.WouldBlock:
await self.__original._lot.park()
def acquire_nowait(self) -> None:
try:
self.__original.acquire_nowait()
except trio.WouldBlock:
raise WouldBlock from None
def locked(self) -> bool:
return self.__original.locked()
def release(self) -> None:
self.__original.release()
def statistics(self) -> LockStatistics:
orig_statistics = self.__original.statistics()
owner = TrioTaskInfo(orig_statistics.owner) if orig_statistics.owner else None
return LockStatistics(
orig_statistics.locked, owner, orig_statistics.tasks_waiting
)
class Semaphore(BaseSemaphore):
def __new__(
cls,
initial_value: int,
*,
max_value: int | None = None,
fast_acquire: bool = False,
) -> Semaphore:
return object.__new__(cls)
def __init__(
self,
initial_value: int,
*,
max_value: int | None = None,
fast_acquire: bool = False,
) -> None:
super().__init__(initial_value, max_value=max_value, fast_acquire=fast_acquire)
self.__original = trio.Semaphore(initial_value, max_value=max_value)
async def acquire(self) -> None:
if not self._fast_acquire:
await self.__original.acquire()
return
# This is the "fast path" where we don't let other tasks run
await trio.lowlevel.checkpoint_if_cancelled()
try:
self.__original.acquire_nowait()
except trio.WouldBlock:
await self.__original._lot.park()
def acquire_nowait(self) -> None:
try:
self.__original.acquire_nowait()
except trio.WouldBlock:
raise WouldBlock from None
@property
def max_value(self) -> int | None:
return self.__original.max_value
@property
def value(self) -> int:
return self.__original.value
def release(self) -> None:
self.__original.release()
def statistics(self) -> SemaphoreStatistics:
orig_statistics = self.__original.statistics()
return SemaphoreStatistics(orig_statistics.tasks_waiting)
class CapacityLimiter(BaseCapacityLimiter):
def __new__(
cls,
@ -915,6 +1023,20 @@ class TrioBackend(AsyncBackend):
def create_event(cls) -> abc.Event:
return Event()
@classmethod
def create_lock(cls, *, fast_acquire: bool) -> Lock:
return Lock(fast_acquire=fast_acquire)
@classmethod
def create_semaphore(
cls,
initial_value: int,
*,
max_value: int | None = None,
fast_acquire: bool = False,
) -> abc.Semaphore:
return Semaphore(initial_value, max_value=max_value, fast_acquire=fast_acquire)
@classmethod
def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter:
return CapacityLimiter(total_tokens)
@ -967,26 +1089,39 @@ class TrioBackend(AsyncBackend):
@classmethod
async def open_process(
cls,
command: str | bytes | Sequence[str | bytes],
command: StrOrBytesPath | Sequence[StrOrBytesPath],
*,
shell: bool,
stdin: int | IO[Any] | None,
stdout: int | IO[Any] | None,
stderr: int | IO[Any] | None,
cwd: str | bytes | PathLike | None = None,
env: Mapping[str, str] | None = None,
start_new_session: bool = False,
**kwargs: Any,
) -> Process:
process = await trio.lowlevel.open_process( # type: ignore[misc]
command, # type: ignore[arg-type]
stdin=stdin,
stdout=stdout,
stderr=stderr,
shell=shell,
cwd=cwd,
env=env,
start_new_session=start_new_session,
)
def convert_item(item: StrOrBytesPath) -> str:
str_or_bytes = os.fspath(item)
if isinstance(str_or_bytes, str):
return str_or_bytes
else:
return os.fsdecode(str_or_bytes)
if isinstance(command, (str, bytes, PathLike)):
process = await trio.lowlevel.open_process(
convert_item(command),
stdin=stdin,
stdout=stdout,
stderr=stderr,
shell=True,
**kwargs,
)
else:
process = await trio.lowlevel.open_process(
[convert_item(item) for item in command],
stdin=stdin,
stdout=stdout,
stderr=stderr,
shell=False,
**kwargs,
)
stdin_stream = SendStreamWrapper(process.stdin) if process.stdin else None
stdout_stream = ReceiveStreamWrapper(process.stdout) if process.stdout else None
stderr_stream = ReceiveStreamWrapper(process.stderr) if process.stderr else None

View File

@ -1,5 +1,11 @@
from __future__ import annotations
import sys
from collections.abc import Generator
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
class BrokenResourceError(Exception):
"""
@ -71,3 +77,13 @@ class TypedAttributeLookupError(LookupError):
class WouldBlock(Exception):
"""Raised by ``X_nowait`` functions if ``X()`` would block."""
def iterate_exceptions(
exception: BaseException,
) -> Generator[BaseException, None, None]:
if isinstance(exception, BaseExceptionGroup):
for exc in exception.exceptions:
yield from iterate_exceptions(exc)
else:
yield exception

View File

@ -358,8 +358,26 @@ class Path:
def as_uri(self) -> str:
return self._path.as_uri()
def match(self, path_pattern: str) -> bool:
return self._path.match(path_pattern)
if sys.version_info >= (3, 13):
parser = pathlib.Path.parser
@classmethod
def from_uri(cls, uri: str) -> Path:
return Path(pathlib.Path.from_uri(uri))
def full_match(
self, path_pattern: str, *, case_sensitive: bool | None = None
) -> bool:
return self._path.full_match(path_pattern, case_sensitive=case_sensitive)
def match(
self, path_pattern: str, *, case_sensitive: bool | None = None
) -> bool:
return self._path.match(path_pattern, case_sensitive=case_sensitive)
else:
def match(self, path_pattern: str) -> bool:
return self._path.match(path_pattern)
def is_relative_to(self, other: str | PathLike[str]) -> bool:
try:

View File

@ -680,19 +680,26 @@ async def setup_unix_local_socket(
:param socktype: socket.SOCK_STREAM or socket.SOCK_DGRAM
"""
path_str: str | bytes | None
path_str: str | None
if path is not None:
path_str = os.fspath(path)
path_str = os.fsdecode(path)
# Copied from pathlib...
try:
stat_result = os.stat(path)
except OSError as e:
if e.errno not in (errno.ENOENT, errno.ENOTDIR, errno.EBADF, errno.ELOOP):
raise
else:
if stat.S_ISSOCK(stat_result.st_mode):
os.unlink(path)
# Linux abstract namespace sockets aren't backed by a concrete file so skip stat call
if not path_str.startswith("\0"):
# Copied from pathlib...
try:
stat_result = os.stat(path)
except OSError as e:
if e.errno not in (
errno.ENOENT,
errno.ENOTDIR,
errno.EBADF,
errno.ELOOP,
):
raise
else:
if stat.S_ISSOCK(stat_result.st_mode):
os.unlink(path)
else:
path_str = None

View File

@ -1,26 +1,41 @@
from __future__ import annotations
from collections.abc import AsyncIterable, Mapping, Sequence
import sys
from collections.abc import AsyncIterable, Iterable, Mapping, Sequence
from io import BytesIO
from os import PathLike
from subprocess import DEVNULL, PIPE, CalledProcessError, CompletedProcess
from typing import IO, Any, cast
from typing import IO, Any, Union, cast
from ..abc import Process
from ._eventloop import get_async_backend
from ._tasks import create_task_group
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
StrOrBytesPath: TypeAlias = Union[str, bytes, "PathLike[str]", "PathLike[bytes]"]
async def run_process(
command: str | bytes | Sequence[str | bytes],
command: StrOrBytesPath | Sequence[StrOrBytesPath],
*,
input: bytes | None = None,
stdout: int | IO[Any] | None = PIPE,
stderr: int | IO[Any] | None = PIPE,
check: bool = True,
cwd: str | bytes | PathLike[str] | None = None,
cwd: StrOrBytesPath | None = None,
env: Mapping[str, str] | None = None,
startupinfo: Any = None,
creationflags: int = 0,
start_new_session: bool = False,
pass_fds: Sequence[int] = (),
user: str | int | None = None,
group: str | int | None = None,
extra_groups: Iterable[str | int] | None = None,
umask: int = -1,
) -> CompletedProcess[bytes]:
"""
Run an external command in a subprocess and wait until it completes.
@ -40,8 +55,20 @@ async def run_process(
command
:param env: if not ``None``, this mapping replaces the inherited environment
variables from the parent process
:param startupinfo: an instance of :class:`subprocess.STARTUPINFO` that can be used
to specify process startup parameters (Windows only)
:param creationflags: flags that can be used to control the creation of the
subprocess (see :class:`subprocess.Popen` for the specifics)
:param start_new_session: if ``true`` the setsid() system call will be made in the
child process prior to the execution of the subprocess. (POSIX only)
:param pass_fds: sequence of file descriptors to keep open between the parent and
child processes. (POSIX only)
:param user: effective user to run the process as (Python >= 3.9, POSIX only)
:param group: effective group to run the process as (Python >= 3.9, POSIX only)
:param extra_groups: supplementary groups to set in the subprocess (Python >= 3.9,
POSIX only)
:param umask: if not negative, this umask is applied in the child process before
running the given command (Python >= 3.9, POSIX only)
:return: an object representing the completed process
:raises ~subprocess.CalledProcessError: if ``check`` is ``True`` and the process
exits with a nonzero return code
@ -62,7 +89,14 @@ async def run_process(
stderr=stderr,
cwd=cwd,
env=env,
startupinfo=startupinfo,
creationflags=creationflags,
start_new_session=start_new_session,
pass_fds=pass_fds,
user=user,
group=group,
extra_groups=extra_groups,
umask=umask,
) as process:
stream_contents: list[bytes | None] = [None, None]
async with create_task_group() as tg:
@ -86,14 +120,21 @@ async def run_process(
async def open_process(
command: str | bytes | Sequence[str | bytes],
command: StrOrBytesPath | Sequence[StrOrBytesPath],
*,
stdin: int | IO[Any] | None = PIPE,
stdout: int | IO[Any] | None = PIPE,
stderr: int | IO[Any] | None = PIPE,
cwd: str | bytes | PathLike[str] | None = None,
cwd: StrOrBytesPath | None = None,
env: Mapping[str, str] | None = None,
startupinfo: Any = None,
creationflags: int = 0,
start_new_session: bool = False,
pass_fds: Sequence[int] = (),
user: str | int | None = None,
group: str | int | None = None,
extra_groups: Iterable[str | int] | None = None,
umask: int = -1,
) -> Process:
"""
Start an external command in a subprocess.
@ -111,30 +152,58 @@ async def open_process(
:param cwd: If not ``None``, the working directory is changed before executing
:param env: If env is not ``None``, it must be a mapping that defines the
environment variables for the new process
:param creationflags: flags that can be used to control the creation of the
subprocess (see :class:`subprocess.Popen` for the specifics)
:param startupinfo: an instance of :class:`subprocess.STARTUPINFO` that can be used
to specify process startup parameters (Windows only)
:param start_new_session: if ``true`` the setsid() system call will be made in the
child process prior to the execution of the subprocess. (POSIX only)
:param pass_fds: sequence of file descriptors to keep open between the parent and
child processes. (POSIX only)
:param user: effective user to run the process as (Python >= 3.9; POSIX only)
:param group: effective group to run the process as (Python >= 3.9; POSIX only)
:param extra_groups: supplementary groups to set in the subprocess (Python >= 3.9;
POSIX only)
:param umask: if not negative, this umask is applied in the child process before
running the given command (Python >= 3.9; POSIX only)
:return: an asynchronous process object
"""
if isinstance(command, (str, bytes)):
return await get_async_backend().open_process(
command,
shell=True,
stdin=stdin,
stdout=stdout,
stderr=stderr,
cwd=cwd,
env=env,
start_new_session=start_new_session,
)
else:
return await get_async_backend().open_process(
command,
shell=False,
stdin=stdin,
stdout=stdout,
stderr=stderr,
cwd=cwd,
env=env,
start_new_session=start_new_session,
)
kwargs: dict[str, Any] = {}
if user is not None:
if sys.version_info < (3, 9):
raise TypeError("the 'user' argument requires Python 3.9 or later")
kwargs["user"] = user
if group is not None:
if sys.version_info < (3, 9):
raise TypeError("the 'group' argument requires Python 3.9 or later")
kwargs["group"] = group
if extra_groups is not None:
if sys.version_info < (3, 9):
raise TypeError("the 'extra_groups' argument requires Python 3.9 or later")
kwargs["extra_groups"] = group
if umask >= 0:
if sys.version_info < (3, 9):
raise TypeError("the 'umask' argument requires Python 3.9 or later")
kwargs["umask"] = umask
return await get_async_backend().open_process(
command,
stdin=stdin,
stdout=stdout,
stderr=stderr,
cwd=cwd,
env=env,
startupinfo=startupinfo,
creationflags=creationflags,
start_new_session=start_new_session,
pass_fds=pass_fds,
**kwargs,
)

View File

@ -7,9 +7,9 @@ from types import TracebackType
from sniffio import AsyncLibraryNotFoundError
from ..lowlevel import cancel_shielded_checkpoint, checkpoint, checkpoint_if_cancelled
from ..lowlevel import checkpoint
from ._eventloop import get_async_backend
from ._exceptions import BusyResourceError, WouldBlock
from ._exceptions import BusyResourceError
from ._tasks import CancelScope
from ._testing import TaskInfo, get_current_task
@ -137,10 +137,11 @@ class EventAdapter(Event):
class Lock:
_owner_task: TaskInfo | None = None
def __init__(self) -> None:
self._waiters: deque[tuple[TaskInfo, Event]] = deque()
def __new__(cls, *, fast_acquire: bool = False) -> Lock:
try:
return get_async_backend().create_lock(fast_acquire=fast_acquire)
except AsyncLibraryNotFoundError:
return LockAdapter(fast_acquire=fast_acquire)
async def __aenter__(self) -> None:
await self.acquire()
@ -155,31 +156,7 @@ class Lock:
async def acquire(self) -> None:
"""Acquire the lock."""
await checkpoint_if_cancelled()
try:
self.acquire_nowait()
except WouldBlock:
task = get_current_task()
event = Event()
token = task, event
self._waiters.append(token)
try:
await event.wait()
except BaseException:
if not event.is_set():
self._waiters.remove(token)
elif self._owner_task == task:
self.release()
raise
assert self._owner_task == task
else:
try:
await cancel_shielded_checkpoint()
except BaseException:
self.release()
raise
raise NotImplementedError
def acquire_nowait(self) -> None:
"""
@ -188,29 +165,15 @@ class Lock:
:raises ~anyio.WouldBlock: if the operation would block
"""
task = get_current_task()
if self._owner_task == task:
raise RuntimeError("Attempted to acquire an already held Lock")
if self._owner_task is not None:
raise WouldBlock
self._owner_task = task
raise NotImplementedError
def release(self) -> None:
"""Release the lock."""
if self._owner_task != get_current_task():
raise RuntimeError("The current task is not holding this lock")
if self._waiters:
self._owner_task, event = self._waiters.popleft()
event.set()
else:
del self._owner_task
raise NotImplementedError
def locked(self) -> bool:
"""Return True if the lock is currently held."""
return self._owner_task is not None
raise NotImplementedError
def statistics(self) -> LockStatistics:
"""
@ -218,7 +181,71 @@ class Lock:
.. versionadded:: 3.0
"""
return LockStatistics(self.locked(), self._owner_task, len(self._waiters))
raise NotImplementedError
class LockAdapter(Lock):
_internal_lock: Lock | None = None
def __new__(cls, *, fast_acquire: bool = False) -> LockAdapter:
return object.__new__(cls)
def __init__(self, *, fast_acquire: bool = False):
self._fast_acquire = fast_acquire
@property
def _lock(self) -> Lock:
if self._internal_lock is None:
self._internal_lock = get_async_backend().create_lock(
fast_acquire=self._fast_acquire
)
return self._internal_lock
async def __aenter__(self) -> None:
await self._lock.acquire()
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self._internal_lock is not None:
self._internal_lock.release()
async def acquire(self) -> None:
"""Acquire the lock."""
await self._lock.acquire()
def acquire_nowait(self) -> None:
"""
Acquire the lock, without blocking.
:raises ~anyio.WouldBlock: if the operation would block
"""
self._lock.acquire_nowait()
def release(self) -> None:
"""Release the lock."""
self._lock.release()
def locked(self) -> bool:
"""Return True if the lock is currently held."""
return self._lock.locked()
def statistics(self) -> LockStatistics:
"""
Return statistics about the current state of this lock.
.. versionadded:: 3.0
"""
if self._internal_lock is None:
return LockStatistics(False, None, 0)
return self._internal_lock.statistics()
class Condition:
@ -312,7 +339,27 @@ class Condition:
class Semaphore:
def __init__(self, initial_value: int, *, max_value: int | None = None):
def __new__(
cls,
initial_value: int,
*,
max_value: int | None = None,
fast_acquire: bool = False,
) -> Semaphore:
try:
return get_async_backend().create_semaphore(
initial_value, max_value=max_value, fast_acquire=fast_acquire
)
except AsyncLibraryNotFoundError:
return SemaphoreAdapter(initial_value, max_value=max_value)
def __init__(
self,
initial_value: int,
*,
max_value: int | None = None,
fast_acquire: bool = False,
):
if not isinstance(initial_value, int):
raise TypeError("initial_value must be an integer")
if initial_value < 0:
@ -325,9 +372,7 @@ class Semaphore:
"max_value must be equal to or higher than initial_value"
)
self._value = initial_value
self._max_value = max_value
self._waiters: deque[Event] = deque()
self._fast_acquire = fast_acquire
async def __aenter__(self) -> Semaphore:
await self.acquire()
@ -343,27 +388,7 @@ class Semaphore:
async def acquire(self) -> None:
"""Decrement the semaphore value, blocking if necessary."""
await checkpoint_if_cancelled()
try:
self.acquire_nowait()
except WouldBlock:
event = Event()
self._waiters.append(event)
try:
await event.wait()
except BaseException:
if not event.is_set():
self._waiters.remove(event)
else:
self.release()
raise
else:
try:
await cancel_shielded_checkpoint()
except BaseException:
self.release()
raise
raise NotImplementedError
def acquire_nowait(self) -> None:
"""
@ -372,30 +397,21 @@ class Semaphore:
:raises ~anyio.WouldBlock: if the operation would block
"""
if self._value == 0:
raise WouldBlock
self._value -= 1
raise NotImplementedError
def release(self) -> None:
"""Increment the semaphore value."""
if self._max_value is not None and self._value == self._max_value:
raise ValueError("semaphore released too many times")
if self._waiters:
self._waiters.popleft().set()
else:
self._value += 1
raise NotImplementedError
@property
def value(self) -> int:
"""The current value of the semaphore."""
return self._value
raise NotImplementedError
@property
def max_value(self) -> int | None:
"""The maximum value of the semaphore."""
return self._max_value
raise NotImplementedError
def statistics(self) -> SemaphoreStatistics:
"""
@ -403,7 +419,66 @@ class Semaphore:
.. versionadded:: 3.0
"""
return SemaphoreStatistics(len(self._waiters))
raise NotImplementedError
class SemaphoreAdapter(Semaphore):
_internal_semaphore: Semaphore | None = None
def __new__(
cls,
initial_value: int,
*,
max_value: int | None = None,
fast_acquire: bool = False,
) -> SemaphoreAdapter:
return object.__new__(cls)
def __init__(
self,
initial_value: int,
*,
max_value: int | None = None,
fast_acquire: bool = False,
) -> None:
super().__init__(initial_value, max_value=max_value, fast_acquire=fast_acquire)
self._initial_value = initial_value
self._max_value = max_value
@property
def _semaphore(self) -> Semaphore:
if self._internal_semaphore is None:
self._internal_semaphore = get_async_backend().create_semaphore(
self._initial_value, max_value=self._max_value
)
return self._internal_semaphore
async def acquire(self) -> None:
await self._semaphore.acquire()
def acquire_nowait(self) -> None:
self._semaphore.acquire_nowait()
def release(self) -> None:
self._semaphore.release()
@property
def value(self) -> int:
if self._internal_semaphore is None:
return self._initial_value
return self._semaphore.value
@property
def max_value(self) -> int | None:
return self._max_value
def statistics(self) -> SemaphoreStatistics:
if self._internal_semaphore is None:
return SemaphoreStatistics(tasks_waiting=0)
return self._semaphore.statistics()
class CapacityLimiter:

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import math
import sys
from abc import ABCMeta, abstractmethod
from collections.abc import AsyncIterator, Awaitable, Mapping
from collections.abc import AsyncIterator, Awaitable
from os import PathLike
from signal import Signals
from socket import AddressFamily, SocketKind, socket
@ -15,6 +15,7 @@ from typing import (
ContextManager,
Sequence,
TypeVar,
Union,
overload,
)
@ -23,10 +24,13 @@ if sys.version_info >= (3, 11):
else:
from typing_extensions import TypeVarTuple, Unpack
if TYPE_CHECKING:
from typing import Literal
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
from .._core._synchronization import CapacityLimiter, Event
if TYPE_CHECKING:
from .._core._synchronization import CapacityLimiter, Event, Lock, Semaphore
from .._core._tasks import CancelScope
from .._core._testing import TaskInfo
from ..from_thread import BlockingPortal
@ -46,6 +50,7 @@ if TYPE_CHECKING:
T_Retval = TypeVar("T_Retval")
PosArgsT = TypeVarTuple("PosArgsT")
StrOrBytesPath: TypeAlias = Union[str, bytes, "PathLike[str]", "PathLike[bytes]"]
class AsyncBackend(metaclass=ABCMeta):
@ -167,6 +172,22 @@ class AsyncBackend(metaclass=ABCMeta):
def create_event(cls) -> Event:
pass
@classmethod
@abstractmethod
def create_lock(cls, *, fast_acquire: bool) -> Lock:
pass
@classmethod
@abstractmethod
def create_semaphore(
cls,
initial_value: int,
*,
max_value: int | None = None,
fast_acquire: bool = False,
) -> Semaphore:
pass
@classmethod
@abstractmethod
def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter:
@ -213,51 +234,16 @@ class AsyncBackend(metaclass=ABCMeta):
def create_blocking_portal(cls) -> BlockingPortal:
pass
@classmethod
@overload
async def open_process(
cls,
command: str | bytes,
*,
shell: Literal[True],
stdin: int | IO[Any] | None,
stdout: int | IO[Any] | None,
stderr: int | IO[Any] | None,
cwd: str | bytes | PathLike[str] | None = None,
env: Mapping[str, str] | None = None,
start_new_session: bool = False,
) -> Process:
pass
@classmethod
@overload
async def open_process(
cls,
command: Sequence[str | bytes],
*,
shell: Literal[False],
stdin: int | IO[Any] | None,
stdout: int | IO[Any] | None,
stderr: int | IO[Any] | None,
cwd: str | bytes | PathLike[str] | None = None,
env: Mapping[str, str] | None = None,
start_new_session: bool = False,
) -> Process:
pass
@classmethod
@abstractmethod
async def open_process(
cls,
command: str | bytes | Sequence[str | bytes],
command: StrOrBytesPath | Sequence[StrOrBytesPath],
*,
shell: bool,
stdin: int | IO[Any] | None,
stdout: int | IO[Any] | None,
stderr: int | IO[Any] | None,
cwd: str | bytes | PathLike[str] | None = None,
env: Mapping[str, str] | None = None,
start_new_session: bool = False,
**kwargs: Any,
) -> Process:
pass

View File

@ -1,19 +1,18 @@
from __future__ import annotations
import sys
import threading
from collections.abc import Awaitable, Callable, Generator
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
from concurrent.futures import Future
from contextlib import AbstractContextManager, contextmanager
from dataclasses import dataclass, field
from inspect import isawaitable
from threading import Lock, Thread, get_ident
from types import TracebackType
from typing import (
Any,
AsyncContextManager,
ContextManager,
Generic,
Iterable,
TypeVar,
cast,
overload,
@ -146,7 +145,7 @@ class BlockingPortal:
return get_async_backend().create_blocking_portal()
def __init__(self) -> None:
self._event_loop_thread_id: int | None = threading.get_ident()
self._event_loop_thread_id: int | None = get_ident()
self._stop_event = Event()
self._task_group = create_task_group()
self._cancelled_exc_class = get_cancelled_exc_class()
@ -167,7 +166,7 @@ class BlockingPortal:
def _check_running(self) -> None:
if self._event_loop_thread_id is None:
raise RuntimeError("This portal is not running")
if self._event_loop_thread_id == threading.get_ident():
if self._event_loop_thread_id == get_ident():
raise RuntimeError(
"This method cannot be called from the event loop thread"
)
@ -202,7 +201,7 @@ class BlockingPortal:
def callback(f: Future[T_Retval]) -> None:
if f.cancelled() and self._event_loop_thread_id not in (
None,
threading.get_ident(),
get_ident(),
):
self.call(scope.cancel)
@ -411,7 +410,7 @@ class BlockingPortalProvider:
backend: str = "asyncio"
backend_options: dict[str, Any] | None = None
_lock: threading.Lock = field(init=False, default_factory=threading.Lock)
_lock: Lock = field(init=False, default_factory=Lock)
_leases: int = field(init=False, default=0)
_portal: BlockingPortal = field(init=False)
_portal_cm: AbstractContextManager[BlockingPortal] | None = field(
@ -469,43 +468,37 @@ def start_blocking_portal(
async def run_portal() -> None:
async with BlockingPortal() as portal_:
if future.set_running_or_notify_cancel():
future.set_result(portal_)
await portal_.sleep_until_stopped()
future.set_result(portal_)
await portal_.sleep_until_stopped()
def run_blocking_portal() -> None:
if future.set_running_or_notify_cancel():
try:
_eventloop.run(
run_portal, backend=backend, backend_options=backend_options
)
except BaseException as exc:
if not future.done():
future.set_exception(exc)
future: Future[BlockingPortal] = Future()
with ThreadPoolExecutor(1) as executor:
run_future = executor.submit(
_eventloop.run, # type: ignore[arg-type]
run_portal,
backend=backend,
backend_options=backend_options,
)
thread = Thread(target=run_blocking_portal, daemon=True)
thread.start()
try:
cancel_remaining_tasks = False
portal = future.result()
try:
wait(
cast(Iterable[Future], [run_future, future]),
return_when=FIRST_COMPLETED,
)
yield portal
except BaseException:
future.cancel()
run_future.cancel()
cancel_remaining_tasks = True
raise
if future.done():
portal = future.result()
cancel_remaining_tasks = False
finally:
try:
yield portal
except BaseException:
cancel_remaining_tasks = True
raise
finally:
try:
portal.call(portal.stop, cancel_remaining_tasks)
except RuntimeError:
pass
run_future.result()
portal.call(portal.stop, cancel_remaining_tasks)
except RuntimeError:
pass
finally:
thread.join()
def check_cancelled() -> None:

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import sys
from collections.abc import Iterator
from contextlib import ExitStack, contextmanager
from inspect import isasyncgenfunction, iscoroutinefunction
@ -7,10 +8,15 @@ from typing import Any, Dict, Tuple, cast
import pytest
import sniffio
from _pytest.outcomes import Exit
from ._core._eventloop import get_all_backends, get_async_backend
from ._core._exceptions import iterate_exceptions
from .abc import TestRunner
if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup
_current_runner: TestRunner | None = None
_runner_stack: ExitStack | None = None
_runner_leases = 0
@ -121,7 +127,14 @@ def pytest_pyfunc_call(pyfuncitem: Any) -> bool | None:
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
with get_runner(backend_name, backend_options) as runner:
runner.run_test(pyfuncitem.obj, testargs)
try:
runner.run_test(pyfuncitem.obj, testargs)
except ExceptionGroup as excgrp:
for exc in iterate_exceptions(excgrp):
if isinstance(exc, (Exit, KeyboardInterrupt, SystemExit)):
raise exc from excgrp
raise
return True

View File

@ -38,6 +38,12 @@ class MemoryObjectItemReceiver(Generic[T_Item]):
task_info: TaskInfo = field(init=False, default_factory=get_current_task)
item: T_Item = field(init=False)
def __repr__(self) -> str:
# When item is not defined, we get following error with default __repr__:
# AttributeError: 'MemoryObjectItemReceiver' object has no attribute 'item'
item = getattr(self, "item", None)
return f"{self.__class__.__name__}(task_info={self.task_info}, item={item!r})"
@dataclass(eq=False)
class MemoryObjectStreamState(Generic[T_Item]):
@ -175,7 +181,7 @@ class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]):
def __del__(self) -> None:
if not self._closed:
warnings.warn(
f"Unclosed <{self.__class__.__name__}>",
f"Unclosed <{self.__class__.__name__} at {id(self):x}>",
ResourceWarning,
source=self,
)
@ -305,7 +311,7 @@ class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]):
def __del__(self) -> None:
if not self._closed:
warnings.warn(
f"Unclosed <{self.__class__.__name__}>",
f"Unclosed <{self.__class__.__name__} at {id(self):x}>",
ResourceWarning,
source=self,
)

View File

@ -223,7 +223,7 @@ def process_worker() -> None:
main_module_path: str | None
sys.path, main_module_path = args
del sys.modules["__main__"]
if main_module_path:
if main_module_path and os.path.isfile(main_module_path):
# Load the parent's main module but as __mp_main__ instead of
# __main__ (like multiprocessing does) to avoid infinite recursion
try:
@ -234,7 +234,6 @@ def process_worker() -> None:
sys.modules["__main__"] = main
except BaseException as exc:
exception = exc
try:
if exception is not None:
status = b"EXCEPTION"

View File

@ -4,7 +4,7 @@ charset_normalizer-3.3.2.dist-info/LICENSE,sha256=6zGgxaT7Cbik4yBV0lweX5w1iidS_v
charset_normalizer-3.3.2.dist-info/METADATA,sha256=cfLhl5A6SI-F0oclm8w8ux9wshL1nipdeCdVnYb4AaA,33550
charset_normalizer-3.3.2.dist-info/RECORD,,
charset_normalizer-3.3.2.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
charset_normalizer-3.3.2.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
charset_normalizer-3.3.2.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
charset_normalizer-3.3.2.dist-info/entry_points.txt,sha256=ADSTKrkXZ3hhdOVFi6DcUEHQRS0xfxDIE_pEz4wLIXA,65
charset_normalizer-3.3.2.dist-info/top_level.txt,sha256=7ASyzePr8_xuZWJsnqJjIBtyV8vhEo0wBCv1MPRRi3Q,19
charset_normalizer/__init__.py,sha256=UzI3xC8PhmcLRMzSgPb6minTmRq0kWznnCBJ8ZCc2XI,1577

View File

@ -1,5 +1,5 @@
Wheel-Version: 1.0
Generator: setuptools (73.0.1)
Generator: setuptools (74.1.2)
Root-Is-Purelib: true
Tag: py3-none-any

View File

@ -59,7 +59,7 @@ devchat/_service/route/__pycache__/workflows.cpython-38.pyc,,
devchat/_service/route/logs.py,sha256=3on6cRIJ8P1N0dMImTqQHe7CUUNrSgKehRnFQIS1qjc,1290
devchat/_service/route/message.py,sha256=Ex06NvmXEDVB4g1bMrZuAYvF_mkSSCJCQ3qRz86CrGs,3309
devchat/_service/route/topics.py,sha256=ox40XH3scK0Ey58nTEJ3h6i0s1PEdmOImI4WZXSLIqc,2331
devchat/_service/route/workflows.py,sha256=XaBwSX0QKuGP_WLCQGLBGmrDoj8bw8rEJuKe-qGkfxk,5252
devchat/_service/route/workflows.py,sha256=cGmHfS7rm4sTEUK_7fTD2z2p0hWwhHJ-WHLK81pA-GE,5383
devchat/_service/schema/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
devchat/_service/schema/__pycache__/__init__.cpython-38.pyc,,
devchat/_service/schema/__pycache__/request.cpython-38.pyc,,
@ -189,7 +189,7 @@ devchat/workflow/command/env.py,sha256=wXZc497GwSjWk8T37krTkxqyjUhWSAh0c0RCwmLzR
devchat/workflow/command/list.py,sha256=MYX469XXazUzNkahXuTlIJYIf_r_t1RwGrKEfugY8HY,1228
devchat/workflow/command/run.py,sha256=tSdDeVNBtNt7xwn6Q3hgjsMZUbL0NPt2g3-fsVAUg9A,214
devchat/workflow/command/update.py,sha256=7hmqqtJHpKZ2I03zQtT-aGvw63L5V0Lsmi-qe8ThjL4,761
devchat/workflow/env_manager.py,sha256=h5hPUfGrlG-8KMkfpVAuD_fvkS_myLi5s7jPf_uOyfs,8474
devchat/workflow/env_manager.py,sha256=XJksiBCBvkfR1rLxFt80bDK2slQkraeFa7Og19wI_Ho,9326
devchat/workflow/envs.py,sha256=-lVTLjWRMrb8RGVVlHgWKCiGZaojNdmycjHFT0ZKjEo,298
devchat/workflow/namespace.py,sha256=CaotAAcSwGINH871zXcs2frBlzuf-eZ2DMEIx_8uRUo,3880
devchat/workflow/path.py,sha256=6I3Fk-KTJRPPR9zKXP7pSfdrl8LcuWwOICSvWJiJNf0,1391

View File

@ -91,10 +91,12 @@ def update_custom_workflows():
updated_any = True
update_messages = []
for url in custom_git_urls:
repo_name = url.split("/")[-1].replace(".git", "") # 提取repo名称
for item in custom_git_urls:
git_url = item["git_url"]
branch = item["branch"]
repo_name = git_url.split("/")[-1].replace(".git", "") # 提取repo名称
repo_path: Path = base_path / repo_name # 拼接出clone路径
candidates_git_urls = [(url, "main")]
candidates_git_urls = [(git_url, branch)]
if repo_path.exists():
logger.info(f"Repo path not empty {repo_path}, removing it.")
@ -132,7 +134,7 @@ def update_custom_workflows():
return response.UpdateWorkflows(updated=updated_any, message=message_summary)
else:
return response.UpdateWorkflows(
False, "No custom_git_urls found in .chat/config.yaml"
updated=False, message="No custom_git_urls found in .chat/config.yaml"
)
else:
return response.UpdateWorkflows(False, "No .chat config found")
return response.UpdateWorkflows(updated=False, message="No .chat config found")

View File

@ -7,16 +7,13 @@ from typing import Dict, Optional, Tuple
from devchat.utils import get_logger, get_logging_file
from .envs import MAMBA_BIN_PATH
from .path import ENV_CACHE_DIR, MAMBA_PY_ENVS, MAMBA_ROOT
from .path import CHAT_CONFIG_FILENAME, CHAT_DIR, ENV_CACHE_DIR, MAMBA_PY_ENVS, MAMBA_ROOT
from .schema import ExternalPyConf
from .user_setting import USER_SETTINGS
# CONDA_FORGE = [
# "https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/",
# "conda-forge",
# ]
CONDA_FORGE_TUNA = "https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/"
PYPI_TUNA = "https://pypi.tuna.tsinghua.edu.cn/simple"
DEFAULT_CONDA_FORGE_URL = "https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/"
logger = get_logger(__name__)
@ -201,6 +198,9 @@ class PyEnvManager:
if is_exist:
return True, ""
# Get conda-forge URL from config file
conda_forge_url = self._get_conda_forge_url()
# create the environment
cmd = [
self.mamba_bin,
@ -208,7 +208,7 @@ class PyEnvManager:
"-n",
env_name,
"-c",
CONDA_FORGE_TUNA,
conda_forge_url,
"-r",
self.mamba_root,
f"python={py_version}",
@ -264,3 +264,26 @@ class PyEnvManager:
return env_path
return None
def _get_conda_forge_url(self) -> str:
"""
Read the conda-forge URL from the config file.
If the config file does not exist or does not contain the conda-forge URL,
use the default value.
"""
config_file = os.path.join(CHAT_DIR, CHAT_CONFIG_FILENAME)
try:
if not os.path.exists(config_file):
return DEFAULT_CONDA_FORGE_URL
import yaml
with open(config_file, "r", encoding="utf-8") as f:
config = yaml.safe_load(f)
return config.get("conda-forge-url", DEFAULT_CONDA_FORGE_URL)
except Exception as e:
# Log the exception if needed
print(f"An error occurred when loading conda-forge-url from config file: {e}")
return DEFAULT_CONDA_FORGE_URL

View File

@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: idna
Version: 3.8
Version: 3.10
Summary: Internationalized Domain Names in Applications (IDNA)
Author-email: Kim Davies <kim+pypi@gumleaf.org>
Requires-Python: >=3.6
@ -26,9 +26,14 @@ Classifier: Programming Language :: Python :: Implementation :: PyPy
Classifier: Topic :: Internet :: Name Service (DNS)
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Classifier: Topic :: Utilities
Requires-Dist: ruff >= 0.6.2 ; extra == "all"
Requires-Dist: mypy >= 1.11.2 ; extra == "all"
Requires-Dist: pytest >= 8.3.2 ; extra == "all"
Requires-Dist: flake8 >= 7.1.1 ; extra == "all"
Project-URL: Changelog, https://github.com/kjd/idna/blob/master/HISTORY.rst
Project-URL: Issue tracker, https://github.com/kjd/idna/issues
Project-URL: Source, https://github.com/kjd/idna
Provides-Extra: all
Internationalized Domain Names in Applications (IDNA)
=====================================================

View File

@ -0,0 +1,22 @@
idna-3.10.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
idna-3.10.dist-info/LICENSE.md,sha256=pZ8LDvNjWHQQmkRhykT_enDVBpboFHZ7-vch1Mmw2w8,1541
idna-3.10.dist-info/METADATA,sha256=URR5ZyDfQ1PCEGhkYoojqfi2Ra0tau2--lhwG4XSfjI,10158
idna-3.10.dist-info/RECORD,,
idna-3.10.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
idna/__init__.py,sha256=MPqNDLZbXqGaNdXxAFhiqFPKEQXju2jNQhCey6-5eJM,868
idna/__pycache__/__init__.cpython-38.pyc,,
idna/__pycache__/codec.cpython-38.pyc,,
idna/__pycache__/compat.cpython-38.pyc,,
idna/__pycache__/core.cpython-38.pyc,,
idna/__pycache__/idnadata.cpython-38.pyc,,
idna/__pycache__/intranges.cpython-38.pyc,,
idna/__pycache__/package_data.cpython-38.pyc,,
idna/__pycache__/uts46data.cpython-38.pyc,,
idna/codec.py,sha256=PEew3ItwzjW4hymbasnty2N2OXvNcgHB-JjrBuxHPYY,3422
idna/compat.py,sha256=RzLy6QQCdl9784aFhb2EX9EKGCJjg0P3PilGdeXXcx8,316
idna/core.py,sha256=YJYyAMnwiQEPjVC4-Fqu_p4CJ6yKKuDGmppBNQNQpFs,13239
idna/idnadata.py,sha256=W30GcIGvtOWYwAjZj4ZjuouUutC6ffgNuyjJy7fZ-lo,78306
idna/intranges.py,sha256=amUtkdhYcQG8Zr-CoMM_kVRacxkivC1WgxN1b63KKdU,1898
idna/package_data.py,sha256=q59S3OXsc5VI8j6vSD0sGBMyk6zZ4vWFREE88yCJYKs,21
idna/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
idna/uts46data.py,sha256=rt90K9J40gUSwppDPCrhjgi5AA6pWM65dEGRSf6rIhM,239289

View File

@ -1,22 +0,0 @@
idna-3.8.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
idna-3.8.dist-info/LICENSE.md,sha256=pZ8LDvNjWHQQmkRhykT_enDVBpboFHZ7-vch1Mmw2w8,1541
idna-3.8.dist-info/METADATA,sha256=t8baHZrBTPkJi3Lr8ZHm0pbRKnelgO5AU7EGIeTvEcg,9948
idna-3.8.dist-info/RECORD,,
idna-3.8.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
idna/__init__.py,sha256=KJQN1eQBr8iIK5SKrJ47lXvxG0BJ7Lm38W4zT0v_8lk,849
idna/__pycache__/__init__.cpython-38.pyc,,
idna/__pycache__/codec.cpython-38.pyc,,
idna/__pycache__/compat.cpython-38.pyc,,
idna/__pycache__/core.cpython-38.pyc,,
idna/__pycache__/idnadata.cpython-38.pyc,,
idna/__pycache__/intranges.cpython-38.pyc,,
idna/__pycache__/package_data.cpython-38.pyc,,
idna/__pycache__/uts46data.cpython-38.pyc,,
idna/codec.py,sha256=PS6m-XmdST7Wj7J7ulRMakPDt5EBJyYrT3CPtjh-7t4,3426
idna/compat.py,sha256=0_sOEUMT4CVw9doD3vyRhX80X19PwqFoUBs7gWsFME4,321
idna/core.py,sha256=OHDXwDVbb3R1gNXjHw7JWeeE2rn2u3a-QV-KCeznYcA,12884
idna/idnadata.py,sha256=dqRwytzkjIHMBa2R1lYvHDwACenZPt8eGVu1Y8UBE-E,78320
idna/intranges.py,sha256=YBr4fRYuWH7kTKS2tXlFjM24ZF1Pdvcir-aywniInqg,1881
idna/package_data.py,sha256=DogtAD5vs_-I2Q0k3_ZA4egUq2YLJ4pBbbhI8APzOcY,21
idna/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
idna/uts46data.py,sha256=1KuksWqLuccPXm2uyRVkhfiFLNIhM_H2m4azCcnOqEU,206503

View File

@ -1,4 +1,3 @@
from .package_data import __version__
from .core import (
IDNABidiError,
IDNAError,
@ -20,8 +19,10 @@ from .core import (
valid_string_length,
)
from .intranges import intranges_contain
from .package_data import __version__
__all__ = [
"__version__",
"IDNABidiError",
"IDNAError",
"InvalidCodepoint",

View File

@ -1,49 +1,51 @@
from .core import encode, decode, alabel, ulabel, IDNAError
import codecs
import re
from typing import Any, Tuple, Optional
from typing import Any, Optional, Tuple
from .core import IDNAError, alabel, decode, encode, ulabel
_unicode_dots_re = re.compile("[\u002e\u3002\uff0e\uff61]")
_unicode_dots_re = re.compile('[\u002e\u3002\uff0e\uff61]')
class Codec(codecs.Codec):
def encode(self, data: str, errors: str = 'strict') -> Tuple[bytes, int]:
if errors != 'strict':
raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
def encode(self, data: str, errors: str = "strict") -> Tuple[bytes, int]:
if errors != "strict":
raise IDNAError('Unsupported error handling "{}"'.format(errors))
if not data:
return b"", 0
return encode(data), len(data)
def decode(self, data: bytes, errors: str = 'strict') -> Tuple[str, int]:
if errors != 'strict':
raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
def decode(self, data: bytes, errors: str = "strict") -> Tuple[str, int]:
if errors != "strict":
raise IDNAError('Unsupported error handling "{}"'.format(errors))
if not data:
return '', 0
return "", 0
return decode(data), len(data)
class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
def _buffer_encode(self, data: str, errors: str, final: bool) -> Tuple[bytes, int]:
if errors != 'strict':
raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
if errors != "strict":
raise IDNAError('Unsupported error handling "{}"'.format(errors))
if not data:
return b'', 0
return b"", 0
labels = _unicode_dots_re.split(data)
trailing_dot = b''
trailing_dot = b""
if labels:
if not labels[-1]:
trailing_dot = b'.'
trailing_dot = b"."
del labels[-1]
elif not final:
# Keep potentially unfinished label until the next call
del labels[-1]
if labels:
trailing_dot = b'.'
trailing_dot = b"."
result = []
size = 0
@ -54,32 +56,33 @@ class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
size += len(label)
# Join with U+002E
result_bytes = b'.'.join(result) + trailing_dot
result_bytes = b".".join(result) + trailing_dot
size += len(trailing_dot)
return result_bytes, size
class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
def _buffer_decode(self, data: Any, errors: str, final: bool) -> Tuple[str, int]:
if errors != 'strict':
raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
if errors != "strict":
raise IDNAError('Unsupported error handling "{}"'.format(errors))
if not data:
return ('', 0)
return ("", 0)
if not isinstance(data, str):
data = str(data, 'ascii')
data = str(data, "ascii")
labels = _unicode_dots_re.split(data)
trailing_dot = ''
trailing_dot = ""
if labels:
if not labels[-1]:
trailing_dot = '.'
trailing_dot = "."
del labels[-1]
elif not final:
# Keep potentially unfinished label until the next call
del labels[-1]
if labels:
trailing_dot = '.'
trailing_dot = "."
result = []
size = 0
@ -89,7 +92,7 @@ class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
size += 1
size += len(label)
result_str = '.'.join(result) + trailing_dot
result_str = ".".join(result) + trailing_dot
size += len(trailing_dot)
return (result_str, size)
@ -103,7 +106,7 @@ class StreamReader(Codec, codecs.StreamReader):
def search_function(name: str) -> Optional[codecs.CodecInfo]:
if name != 'idna2008':
if name != "idna2008":
return None
return codecs.CodecInfo(
name=name,
@ -115,4 +118,5 @@ def search_function(name: str) -> Optional[codecs.CodecInfo]:
streamreader=StreamReader,
)
codecs.register(search_function)

View File

@ -1,13 +1,15 @@
from .core import *
from .codec import *
from typing import Any, Union
from .core import decode, encode
def ToASCII(label: str) -> bytes:
return encode(label)
def ToUnicode(label: Union[bytes, bytearray]) -> str:
return decode(label)
def nameprep(s: Any) -> None:
raise NotImplementedError('IDNA 2008 does not utilise nameprep protocol')
def nameprep(s: Any) -> None:
raise NotImplementedError("IDNA 2008 does not utilise nameprep protocol")

View File

@ -1,31 +1,37 @@
from . import idnadata
import bisect
import unicodedata
import re
from typing import Union, Optional
import unicodedata
from typing import Optional, Union
from . import idnadata
from .intranges import intranges_contain
_virama_combining_class = 9
_alabel_prefix = b'xn--'
_unicode_dots_re = re.compile('[\u002e\u3002\uff0e\uff61]')
_alabel_prefix = b"xn--"
_unicode_dots_re = re.compile("[\u002e\u3002\uff0e\uff61]")
class IDNAError(UnicodeError):
""" Base exception for all IDNA-encoding related problems """
"""Base exception for all IDNA-encoding related problems"""
pass
class IDNABidiError(IDNAError):
""" Exception when bidirectional requirements are not satisfied """
"""Exception when bidirectional requirements are not satisfied"""
pass
class InvalidCodepoint(IDNAError):
""" Exception when a disallowed or unallocated codepoint is used """
"""Exception when a disallowed or unallocated codepoint is used"""
pass
class InvalidCodepointContext(IDNAError):
""" Exception when the codepoint is not valid in the context it is used """
"""Exception when the codepoint is not valid in the context it is used"""
pass
@ -33,17 +39,20 @@ def _combining_class(cp: int) -> int:
v = unicodedata.combining(chr(cp))
if v == 0:
if not unicodedata.name(chr(cp)):
raise ValueError('Unknown character in unicodedata')
raise ValueError("Unknown character in unicodedata")
return v
def _is_script(cp: str, script: str) -> bool:
return intranges_contain(ord(cp), idnadata.scripts[script])
def _punycode(s: str) -> bytes:
return s.encode('punycode')
return s.encode("punycode")
def _unot(s: int) -> str:
return 'U+{:04X}'.format(s)
return "U+{:04X}".format(s)
def valid_label_length(label: Union[bytes, str]) -> bool:
@ -61,96 +70,106 @@ def valid_string_length(label: Union[bytes, str], trailing_dot: bool) -> bool:
def check_bidi(label: str, check_ltr: bool = False) -> bool:
# Bidi rules should only be applied if string contains RTL characters
bidi_label = False
for (idx, cp) in enumerate(label, 1):
for idx, cp in enumerate(label, 1):
direction = unicodedata.bidirectional(cp)
if direction == '':
if direction == "":
# String likely comes from a newer version of Unicode
raise IDNABidiError('Unknown directionality in label {} at position {}'.format(repr(label), idx))
if direction in ['R', 'AL', 'AN']:
raise IDNABidiError("Unknown directionality in label {} at position {}".format(repr(label), idx))
if direction in ["R", "AL", "AN"]:
bidi_label = True
if not bidi_label and not check_ltr:
return True
# Bidi rule 1
direction = unicodedata.bidirectional(label[0])
if direction in ['R', 'AL']:
if direction in ["R", "AL"]:
rtl = True
elif direction == 'L':
elif direction == "L":
rtl = False
else:
raise IDNABidiError('First codepoint in label {} must be directionality L, R or AL'.format(repr(label)))
raise IDNABidiError("First codepoint in label {} must be directionality L, R or AL".format(repr(label)))
valid_ending = False
number_type = None # type: Optional[str]
for (idx, cp) in enumerate(label, 1):
number_type: Optional[str] = None
for idx, cp in enumerate(label, 1):
direction = unicodedata.bidirectional(cp)
if rtl:
# Bidi rule 2
if not direction in ['R', 'AL', 'AN', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM']:
raise IDNABidiError('Invalid direction for codepoint at position {} in a right-to-left label'.format(idx))
if direction not in [
"R",
"AL",
"AN",
"EN",
"ES",
"CS",
"ET",
"ON",
"BN",
"NSM",
]:
raise IDNABidiError("Invalid direction for codepoint at position {} in a right-to-left label".format(idx))
# Bidi rule 3
if direction in ['R', 'AL', 'EN', 'AN']:
if direction in ["R", "AL", "EN", "AN"]:
valid_ending = True
elif direction != 'NSM':
elif direction != "NSM":
valid_ending = False
# Bidi rule 4
if direction in ['AN', 'EN']:
if direction in ["AN", "EN"]:
if not number_type:
number_type = direction
else:
if number_type != direction:
raise IDNABidiError('Can not mix numeral types in a right-to-left label')
raise IDNABidiError("Can not mix numeral types in a right-to-left label")
else:
# Bidi rule 5
if not direction in ['L', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM']:
raise IDNABidiError('Invalid direction for codepoint at position {} in a left-to-right label'.format(idx))
if direction not in ["L", "EN", "ES", "CS", "ET", "ON", "BN", "NSM"]:
raise IDNABidiError("Invalid direction for codepoint at position {} in a left-to-right label".format(idx))
# Bidi rule 6
if direction in ['L', 'EN']:
if direction in ["L", "EN"]:
valid_ending = True
elif direction != 'NSM':
elif direction != "NSM":
valid_ending = False
if not valid_ending:
raise IDNABidiError('Label ends with illegal codepoint directionality')
raise IDNABidiError("Label ends with illegal codepoint directionality")
return True
def check_initial_combiner(label: str) -> bool:
if unicodedata.category(label[0])[0] == 'M':
raise IDNAError('Label begins with an illegal combining character')
if unicodedata.category(label[0])[0] == "M":
raise IDNAError("Label begins with an illegal combining character")
return True
def check_hyphen_ok(label: str) -> bool:
if label[2:4] == '--':
raise IDNAError('Label has disallowed hyphens in 3rd and 4th position')
if label[0] == '-' or label[-1] == '-':
raise IDNAError('Label must not start or end with a hyphen')
if label[2:4] == "--":
raise IDNAError("Label has disallowed hyphens in 3rd and 4th position")
if label[0] == "-" or label[-1] == "-":
raise IDNAError("Label must not start or end with a hyphen")
return True
def check_nfc(label: str) -> None:
if unicodedata.normalize('NFC', label) != label:
raise IDNAError('Label must be in Normalization Form C')
if unicodedata.normalize("NFC", label) != label:
raise IDNAError("Label must be in Normalization Form C")
def valid_contextj(label: str, pos: int) -> bool:
cp_value = ord(label[pos])
if cp_value == 0x200c:
if cp_value == 0x200C:
if pos > 0:
if _combining_class(ord(label[pos - 1])) == _virama_combining_class:
return True
ok = False
for i in range(pos-1, -1, -1):
for i in range(pos - 1, -1, -1):
joining_type = idnadata.joining_types.get(ord(label[i]))
if joining_type == ord('T'):
if joining_type == ord("T"):
continue
elif joining_type in [ord('L'), ord('D')]:
elif joining_type in [ord("L"), ord("D")]:
ok = True
break
else:
@ -160,63 +179,61 @@ def valid_contextj(label: str, pos: int) -> bool:
return False
ok = False
for i in range(pos+1, len(label)):
for i in range(pos + 1, len(label)):
joining_type = idnadata.joining_types.get(ord(label[i]))
if joining_type == ord('T'):
if joining_type == ord("T"):
continue
elif joining_type in [ord('R'), ord('D')]:
elif joining_type in [ord("R"), ord("D")]:
ok = True
break
else:
break
return ok
if cp_value == 0x200d:
if cp_value == 0x200D:
if pos > 0:
if _combining_class(ord(label[pos - 1])) == _virama_combining_class:
return True
return False
else:
return False
def valid_contexto(label: str, pos: int, exception: bool = False) -> bool:
cp_value = ord(label[pos])
if cp_value == 0x00b7:
if 0 < pos < len(label)-1:
if ord(label[pos - 1]) == 0x006c and ord(label[pos + 1]) == 0x006c:
if cp_value == 0x00B7:
if 0 < pos < len(label) - 1:
if ord(label[pos - 1]) == 0x006C and ord(label[pos + 1]) == 0x006C:
return True
return False
elif cp_value == 0x0375:
if pos < len(label)-1 and len(label) > 1:
return _is_script(label[pos + 1], 'Greek')
if pos < len(label) - 1 and len(label) > 1:
return _is_script(label[pos + 1], "Greek")
return False
elif cp_value == 0x05f3 or cp_value == 0x05f4:
elif cp_value == 0x05F3 or cp_value == 0x05F4:
if pos > 0:
return _is_script(label[pos - 1], 'Hebrew')
return _is_script(label[pos - 1], "Hebrew")
return False
elif cp_value == 0x30fb:
elif cp_value == 0x30FB:
for cp in label:
if cp == '\u30fb':
if cp == "\u30fb":
continue
if _is_script(cp, 'Hiragana') or _is_script(cp, 'Katakana') or _is_script(cp, 'Han'):
if _is_script(cp, "Hiragana") or _is_script(cp, "Katakana") or _is_script(cp, "Han"):
return True
return False
elif 0x660 <= cp_value <= 0x669:
for cp in label:
if 0x6f0 <= ord(cp) <= 0x06f9:
if 0x6F0 <= ord(cp) <= 0x06F9:
return False
return True
elif 0x6f0 <= cp_value <= 0x6f9:
elif 0x6F0 <= cp_value <= 0x6F9:
for cp in label:
if 0x660 <= ord(cp) <= 0x0669:
return False
@ -227,41 +244,49 @@ def valid_contexto(label: str, pos: int, exception: bool = False) -> bool:
def check_label(label: Union[str, bytes, bytearray]) -> None:
if isinstance(label, (bytes, bytearray)):
label = label.decode('utf-8')
label = label.decode("utf-8")
if len(label) == 0:
raise IDNAError('Empty Label')
raise IDNAError("Empty Label")
check_nfc(label)
check_hyphen_ok(label)
check_initial_combiner(label)
for (pos, cp) in enumerate(label):
for pos, cp in enumerate(label):
cp_value = ord(cp)
if intranges_contain(cp_value, idnadata.codepoint_classes['PVALID']):
if intranges_contain(cp_value, idnadata.codepoint_classes["PVALID"]):
continue
elif intranges_contain(cp_value, idnadata.codepoint_classes['CONTEXTJ']):
elif intranges_contain(cp_value, idnadata.codepoint_classes["CONTEXTJ"]):
try:
if not valid_contextj(label, pos):
raise InvalidCodepointContext('Joiner {} not allowed at position {} in {}'.format(
_unot(cp_value), pos+1, repr(label)))
raise InvalidCodepointContext(
"Joiner {} not allowed at position {} in {}".format(_unot(cp_value), pos + 1, repr(label))
)
except ValueError:
raise IDNAError('Unknown codepoint adjacent to joiner {} at position {} in {}'.format(
_unot(cp_value), pos+1, repr(label)))
elif intranges_contain(cp_value, idnadata.codepoint_classes['CONTEXTO']):
raise IDNAError(
"Unknown codepoint adjacent to joiner {} at position {} in {}".format(
_unot(cp_value), pos + 1, repr(label)
)
)
elif intranges_contain(cp_value, idnadata.codepoint_classes["CONTEXTO"]):
if not valid_contexto(label, pos):
raise InvalidCodepointContext('Codepoint {} not allowed at position {} in {}'.format(_unot(cp_value), pos+1, repr(label)))
raise InvalidCodepointContext(
"Codepoint {} not allowed at position {} in {}".format(_unot(cp_value), pos + 1, repr(label))
)
else:
raise InvalidCodepoint('Codepoint {} at position {} of {} not allowed'.format(_unot(cp_value), pos+1, repr(label)))
raise InvalidCodepoint(
"Codepoint {} at position {} of {} not allowed".format(_unot(cp_value), pos + 1, repr(label))
)
check_bidi(label)
def alabel(label: str) -> bytes:
try:
label_bytes = label.encode('ascii')
label_bytes = label.encode("ascii")
ulabel(label_bytes)
if not valid_label_length(label_bytes):
raise IDNAError('Label too long')
raise IDNAError("Label too long")
return label_bytes
except UnicodeEncodeError:
pass
@ -270,7 +295,7 @@ def alabel(label: str) -> bytes:
label_bytes = _alabel_prefix + _punycode(label)
if not valid_label_length(label_bytes):
raise IDNAError('Label too long')
raise IDNAError("Label too long")
return label_bytes
@ -278,7 +303,7 @@ def alabel(label: str) -> bytes:
def ulabel(label: Union[str, bytes, bytearray]) -> str:
if not isinstance(label, (bytes, bytearray)):
try:
label_bytes = label.encode('ascii')
label_bytes = label.encode("ascii")
except UnicodeEncodeError:
check_label(label)
return label
@ -287,19 +312,19 @@ def ulabel(label: Union[str, bytes, bytearray]) -> str:
label_bytes = label_bytes.lower()
if label_bytes.startswith(_alabel_prefix):
label_bytes = label_bytes[len(_alabel_prefix):]
label_bytes = label_bytes[len(_alabel_prefix) :]
if not label_bytes:
raise IDNAError('Malformed A-label, no Punycode eligible content found')
if label_bytes.decode('ascii')[-1] == '-':
raise IDNAError('A-label must not end with a hyphen')
raise IDNAError("Malformed A-label, no Punycode eligible content found")
if label_bytes.decode("ascii")[-1] == "-":
raise IDNAError("A-label must not end with a hyphen")
else:
check_label(label_bytes)
return label_bytes.decode('ascii')
return label_bytes.decode("ascii")
try:
label = label_bytes.decode('punycode')
label = label_bytes.decode("punycode")
except UnicodeError:
raise IDNAError('Invalid A-label')
raise IDNAError("Invalid A-label")
check_label(label)
return label
@ -307,52 +332,60 @@ def ulabel(label: Union[str, bytes, bytearray]) -> str:
def uts46_remap(domain: str, std3_rules: bool = True, transitional: bool = False) -> str:
"""Re-map the characters in the string according to UTS46 processing."""
from .uts46data import uts46data
output = ''
output = ""
for pos, char in enumerate(domain):
code_point = ord(char)
try:
uts46row = uts46data[code_point if code_point < 256 else
bisect.bisect_left(uts46data, (code_point, 'Z')) - 1]
uts46row = uts46data[code_point if code_point < 256 else bisect.bisect_left(uts46data, (code_point, "Z")) - 1]
status = uts46row[1]
replacement = None # type: Optional[str]
replacement: Optional[str] = None
if len(uts46row) == 3:
replacement = uts46row[2]
if (status == 'V' or
(status == 'D' and not transitional) or
(status == '3' and not std3_rules and replacement is None)):
if (
status == "V"
or (status == "D" and not transitional)
or (status == "3" and not std3_rules and replacement is None)
):
output += char
elif replacement is not None and (status == 'M' or
(status == '3' and not std3_rules) or
(status == 'D' and transitional)):
elif replacement is not None and (
status == "M" or (status == "3" and not std3_rules) or (status == "D" and transitional)
):
output += replacement
elif status != 'I':
elif status != "I":
raise IndexError()
except IndexError:
raise InvalidCodepoint(
'Codepoint {} not allowed at position {} in {}'.format(
_unot(code_point), pos + 1, repr(domain)))
"Codepoint {} not allowed at position {} in {}".format(_unot(code_point), pos + 1, repr(domain))
)
return unicodedata.normalize('NFC', output)
return unicodedata.normalize("NFC", output)
def encode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool = False, std3_rules: bool = False, transitional: bool = False) -> bytes:
def encode(
s: Union[str, bytes, bytearray],
strict: bool = False,
uts46: bool = False,
std3_rules: bool = False,
transitional: bool = False,
) -> bytes:
if not isinstance(s, str):
try:
s = str(s, 'ascii')
s = str(s, "ascii")
except UnicodeDecodeError:
raise IDNAError('should pass a unicode string to the function rather than a byte string.')
raise IDNAError("should pass a unicode string to the function rather than a byte string.")
if uts46:
s = uts46_remap(s, std3_rules, transitional)
trailing_dot = False
result = []
if strict:
labels = s.split('.')
labels = s.split(".")
else:
labels = _unicode_dots_re.split(s)
if not labels or labels == ['']:
raise IDNAError('Empty domain')
if labels[-1] == '':
if not labels or labels == [""]:
raise IDNAError("Empty domain")
if labels[-1] == "":
del labels[-1]
trailing_dot = True
for label in labels:
@ -360,21 +393,26 @@ def encode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool =
if s:
result.append(s)
else:
raise IDNAError('Empty label')
raise IDNAError("Empty label")
if trailing_dot:
result.append(b'')
s = b'.'.join(result)
result.append(b"")
s = b".".join(result)
if not valid_string_length(s, trailing_dot):
raise IDNAError('Domain too long')
raise IDNAError("Domain too long")
return s
def decode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool = False, std3_rules: bool = False) -> str:
def decode(
s: Union[str, bytes, bytearray],
strict: bool = False,
uts46: bool = False,
std3_rules: bool = False,
) -> str:
try:
if not isinstance(s, str):
s = str(s, 'ascii')
s = str(s, "ascii")
except UnicodeDecodeError:
raise IDNAError('Invalid ASCII in A-label')
raise IDNAError("Invalid ASCII in A-label")
if uts46:
s = uts46_remap(s, std3_rules, False)
trailing_dot = False
@ -382,9 +420,9 @@ def decode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool =
if not strict:
labels = _unicode_dots_re.split(s)
else:
labels = s.split('.')
if not labels or labels == ['']:
raise IDNAError('Empty domain')
labels = s.split(".")
if not labels or labels == [""]:
raise IDNAError("Empty domain")
if not labels[-1]:
del labels[-1]
trailing_dot = True
@ -393,7 +431,7 @@ def decode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool =
if s:
result.append(s)
else:
raise IDNAError('Empty label')
raise IDNAError("Empty label")
if trailing_dot:
result.append('')
return '.'.join(result)
result.append("")
return ".".join(result)

File diff suppressed because it is too large Load Diff

View File

@ -8,6 +8,7 @@ in the original list?" in time O(log(# runs)).
import bisect
from typing import List, Tuple
def intranges_from_list(list_: List[int]) -> Tuple[int, ...]:
"""Represent a list of integers as a sequence of ranges:
((start_0, end_0), (start_1, end_1), ...), such that the original
@ -20,18 +21,20 @@ def intranges_from_list(list_: List[int]) -> Tuple[int, ...]:
ranges = []
last_write = -1
for i in range(len(sorted_list)):
if i+1 < len(sorted_list):
if sorted_list[i] == sorted_list[i+1]-1:
if i + 1 < len(sorted_list):
if sorted_list[i] == sorted_list[i + 1] - 1:
continue
current_range = sorted_list[last_write+1:i+1]
current_range = sorted_list[last_write + 1 : i + 1]
ranges.append(_encode_range(current_range[0], current_range[-1] + 1))
last_write = i
return tuple(ranges)
def _encode_range(start: int, end: int) -> int:
return (start << 32) | end
def _decode_range(r: int) -> Tuple[int, int]:
return (r >> 32), (r & ((1 << 32) - 1))
@ -43,7 +46,7 @@ def intranges_contain(int_: int, ranges: Tuple[int, ...]) -> bool:
# we could be immediately ahead of a tuple (start, end)
# with start < int_ <= end
if pos > 0:
left, right = _decode_range(ranges[pos-1])
left, right = _decode_range(ranges[pos - 1])
if left <= int_ < right:
return True
# or we could be immediately behind a tuple (int_, end)

View File

@ -1,2 +1 @@
__version__ = '3.8'
__version__ = "3.10"

File diff suppressed because it is too large Load Diff

View File

@ -2,11 +2,20 @@
__author__ = "Andrew Dunham"
__license__ = "Apache"
__copyright__ = "Copyright (c) 2012-2013, Andrew Dunham"
__version__ = "0.0.9"
__version__ = "0.0.10"
from .multipart import FormParser, MultipartParser, OctetStreamParser, QuerystringParser, create_form_parser, parse_form
from .multipart import (
BaseParser,
FormParser,
MultipartParser,
OctetStreamParser,
QuerystringParser,
create_form_parser,
parse_form,
)
__all__ = (
"BaseParser",
"FormParser",
"MultipartParser",
"OctetStreamParser",

View File

@ -1,5 +1,6 @@
import base64
import binascii
from io import BufferedWriter
from .exceptions import DecodeError
@ -33,11 +34,11 @@ class Base64Decoder:
:param underlying: the underlying object to pass writes to
"""
def __init__(self, underlying):
def __init__(self, underlying: BufferedWriter):
self.cache = bytearray()
self.underlying = underlying
def write(self, data):
def write(self, data: bytes) -> int:
"""Takes any input data provided, decodes it as base64, and passes it
on to the underlying object. If the data provided is invalid base64
data, then this method will raise
@ -73,14 +74,14 @@ class Base64Decoder:
# Return the length of the data to indicate no error.
return len(data)
def close(self):
def close(self) -> None:
"""Close this decoder. If the underlying object has a `close()`
method, this function will call it.
"""
if hasattr(self.underlying, "close"):
self.underlying.close()
def finalize(self):
def finalize(self) -> None:
"""Finalize this object. This should be called when no more data
should be written to the stream. This function can raise a
:class:`multipart.exceptions.DecodeError` if there is some remaining
@ -97,7 +98,7 @@ class Base64Decoder:
if hasattr(self.underlying, "finalize"):
self.underlying.finalize()
def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}(underlying={self.underlying!r})"
@ -111,11 +112,11 @@ class QuotedPrintableDecoder:
:param underlying: the underlying object to pass writes to
"""
def __init__(self, underlying):
def __init__(self, underlying: BufferedWriter) -> None:
self.cache = b""
self.underlying = underlying
def write(self, data):
def write(self, data: bytes) -> int:
"""Takes any input data provided, decodes it as quoted-printable, and
passes it on to the underlying object.
@ -142,14 +143,14 @@ class QuotedPrintableDecoder:
self.cache = rest
return len(data)
def close(self):
def close(self) -> None:
"""Close this decoder. If the underlying object has a `close()`
method, this function will call it.
"""
if hasattr(self.underlying, "close"):
self.underlying.close()
def finalize(self):
def finalize(self) -> None:
"""Finalize this object. This should be called when no more data
should be written to the stream. This function will not raise any
exceptions, but it may write more data to the underlying object if
@ -167,5 +168,5 @@ class QuotedPrintableDecoder:
if hasattr(self.underlying, "finalize"):
self.underlying.finalize()
def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}(underlying={self.underlying!r})"

File diff suppressed because it is too large Load Diff

View File

@ -3,7 +3,7 @@ pydantic-1.10.14.dist-info/LICENSE,sha256=njlGaQrIi2tz6PABoFhq8TVovohS_VFOQ5Pzl2
pydantic-1.10.14.dist-info/METADATA,sha256=lYLXr7lOF7BMRkokRnhTeJODH7dv_zlYVTKfw6M_rNg,150189
pydantic-1.10.14.dist-info/RECORD,,
pydantic-1.10.14.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
pydantic-1.10.14.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
pydantic-1.10.14.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
pydantic-1.10.14.dist-info/entry_points.txt,sha256=EquH5n3pilIXg-LLa1K4evpu5-6dnvxzi6vwvkoAMns,45
pydantic-1.10.14.dist-info/top_level.txt,sha256=cmo_5n0F_YY5td5nPZBfdjBENkmGg_pE5ShWXYbXxTM,9
pydantic/__init__.py,sha256=iTu8CwWWvn6zM_zYJtqhie24PImW25zokitz_06kDYw,2771

View File

@ -1,5 +1,5 @@
Wheel-Version: 1.0
Generator: setuptools (73.0.1)
Generator: setuptools (74.1.2)
Root-Is-Purelib: true
Tag: py3-none-any

View File

@ -0,0 +1,49 @@
Metadata-Version: 2.3
Name: python-multipart
Version: 0.0.10
Summary: A streaming multipart parser for Python
Project-URL: Homepage, https://github.com/Kludex/python-multipart
Project-URL: Documentation, https://kludex.github.io/python-multipart/
Project-URL: Changelog, https://github.com/Kludex/python-multipart/blob/master/CHANGELOG.md
Project-URL: Source, https://github.com/Kludex/python-multipart
Author-email: Andrew Dunham <andrew@du.nham.ca>, Marcelo Trylesinski <marcelotryle@gmail.com>
License-Expression: Apache-2.0
License-File: LICENSE.txt
Classifier: Development Status :: 5 - Production/Stable
Classifier: Environment :: Web Environment
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.8
Description-Content-Type: text/markdown
# [Python-Multipart](https://kludex.github.io/python-multipart/)
[![Package version](https://badge.fury.io/py/python-multipart.svg)](https://pypi.python.org/pypi/python-multipart)
[![Supported Python Version](https://img.shields.io/pypi/pyversions/python-multipart.svg?color=%2334D058)](https://pypi.org/project/python-multipart)
---
`python-multipart` is an Apache2-licensed streaming multipart parser for Python.
Test coverage is currently 100%.
## Why?
Because streaming uploads are awesome for large files.
## How to Test
If you want to test:
```bash
$ pip install '.[dev]'
$ inv test
```

View File

@ -0,0 +1,13 @@
multipart/__init__.py,sha256=q2w9JPTlzBOOBSIPOIKi0X7xEUByFKhbYjtPandI3c8,512
multipart/__pycache__/__init__.cpython-38.pyc,,
multipart/__pycache__/decoders.cpython-38.pyc,,
multipart/__pycache__/exceptions.cpython-38.pyc,,
multipart/__pycache__/multipart.cpython-38.pyc,,
multipart/decoders.py,sha256=5y_47RFs3mPVJbePcNe1__nGdQche0PghI_UsIKz3po,6182
multipart/exceptions.py,sha256=a9buSOv_eiHZoukEJhdWX9LJYSJ6t7XOK3ZEaWoQZlk,992
multipart/multipart.py,sha256=2hXvVTSEJ7PZOXlTMnxSVhlhT535QNT6neZVmgt3Beg,73780
python_multipart-0.0.10.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
python_multipart-0.0.10.dist-info/METADATA,sha256=pEXCzLDD0MPgQLffeich5yjgf87NxMSAkA5r_7FZ_nQ,1902
python_multipart-0.0.10.dist-info/RECORD,,
python_multipart-0.0.10.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
python_multipart-0.0.10.dist-info/licenses/LICENSE.txt,sha256=qOgzF2zWF9rwC51tOfoVyo7evG0WQwec0vSJPAwom-I,556

View File

@ -1,4 +1,4 @@
Wheel-Version: 1.0
Generator: hatchling 1.21.1
Generator: hatchling 1.25.0
Root-Is-Purelib: true
Tag: py3-none-any

View File

@ -1,70 +0,0 @@
Metadata-Version: 2.1
Name: python-multipart
Version: 0.0.9
Summary: A streaming multipart parser for Python
Project-URL: Homepage, https://github.com/andrew-d/python-multipart
Project-URL: Documentation, https://andrew-d.github.io/python-multipart/
Project-URL: Changelog, https://github.com/andrew-d/python-multipart/blob/master/CHANGELOG.md
Project-URL: Source, https://github.com/andrew-d/python-multipart
Author-email: Andrew Dunham <andrew@du.nham.ca>
License-Expression: Apache-2.0
License-File: LICENSE.txt
Classifier: Development Status :: 5 - Production/Stable
Classifier: Environment :: Web Environment
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.8
Provides-Extra: dev
Requires-Dist: atomicwrites==1.4.1; extra == 'dev'
Requires-Dist: attrs==23.2.0; extra == 'dev'
Requires-Dist: coverage==7.4.1; extra == 'dev'
Requires-Dist: hatch; extra == 'dev'
Requires-Dist: invoke==2.2.0; extra == 'dev'
Requires-Dist: more-itertools==10.2.0; extra == 'dev'
Requires-Dist: pbr==6.0.0; extra == 'dev'
Requires-Dist: pluggy==1.4.0; extra == 'dev'
Requires-Dist: py==1.11.0; extra == 'dev'
Requires-Dist: pytest-cov==4.1.0; extra == 'dev'
Requires-Dist: pytest-timeout==2.2.0; extra == 'dev'
Requires-Dist: pytest==8.0.0; extra == 'dev'
Requires-Dist: pyyaml==6.0.1; extra == 'dev'
Requires-Dist: ruff==0.2.1; extra == 'dev'
Description-Content-Type: text/x-rst
==================
Python-Multipart
==================
.. image:: https://github.com/andrew-d/python-multipart/actions/workflows/test.yaml/badge.svg
:target: https://github.com/andrew-d/python-multipart/actions
python-multipart is an Apache2 licensed streaming multipart parser for Python.
Test coverage is currently 100%.
Documentation is available `here`_.
.. _here: https://andrew-d.github.io/python-multipart/
Why?
----
Because streaming uploads are awesome for large files.
How to Test
-----------
If you want to test:
.. code-block:: bash
$ pip install '.[dev]'
$ inv test

View File

@ -1,13 +0,0 @@
multipart/__init__.py,sha256=Z_EnZFoG_zmw22n7BomPmnVSCCl4XVNDYVWL8hry6Sc,448
multipart/__pycache__/__init__.cpython-38.pyc,,
multipart/__pycache__/decoders.cpython-38.pyc,,
multipart/__pycache__/exceptions.cpython-38.pyc,,
multipart/__pycache__/multipart.cpython-38.pyc,,
multipart/decoders.py,sha256=A4SQHOwFRNzCfr5Fx0iOYpS8USTxE9ofYbL9kOJywHs,6038
multipart/exceptions.py,sha256=a9buSOv_eiHZoukEJhdWX9LJYSJ6t7XOK3ZEaWoQZlk,992
multipart/multipart.py,sha256=sThvJ7TSPQc1tjCfXMgqQduT3yKsmxhwzNZnmL5S7KY,72841
python_multipart-0.0.9.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
python_multipart-0.0.9.dist-info/METADATA,sha256=5BDtS_h0qAxUX55VWiuCb7FZj6WZRd51l6YenNy4Its,2528
python_multipart-0.0.9.dist-info/RECORD,,
python_multipart-0.0.9.dist-info/WHEEL,sha256=TJPnKdtrSue7xZ_AVGkp9YXcvDrobsjBds1du3Nx6dc,87
python_multipart-0.0.9.dist-info/licenses/LICENSE.txt,sha256=qOgzF2zWF9rwC51tOfoVyo7evG0WQwec0vSJPAwom-I,556

View File

@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: rich
Version: 13.8.0
Version: 13.8.1
Summary: Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal
Home-page: https://github.com/Textualize/rich
License: MIT
@ -22,6 +22,7 @@ Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Typing :: Typed
Provides-Extra: jupyter
Requires-Dist: ipywidgets (>=7.5.1,<9) ; extra == "jupyter"

View File

@ -1,8 +1,8 @@
rich-13.8.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
rich-13.8.0.dist-info/LICENSE,sha256=3u18F6QxgVgZCj6iOcyHmlpQJxzruYrnAl9I--WNyhU,1056
rich-13.8.0.dist-info/METADATA,sha256=Z0qmRupyrFfPHrIqUql-EF3Vf1EOJ1gKQH4PECmJRw0,18272
rich-13.8.0.dist-info/RECORD,,
rich-13.8.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
rich-13.8.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
rich-13.8.1.dist-info/LICENSE,sha256=3u18F6QxgVgZCj6iOcyHmlpQJxzruYrnAl9I--WNyhU,1056
rich-13.8.1.dist-info/METADATA,sha256=LUNl6wIGo9RlVR97ZMGhisWEApsSW5cOAtvI-zN39M4,18323
rich-13.8.1.dist-info/RECORD,,
rich-13.8.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
rich/__init__.py,sha256=lh2WcoIOJp5M5_lbAsSUMGv8oiJeumROazHH_AYMS8I,6066
rich/__main__.py,sha256=Wvh53rmOMyWeUeyqUHpn1PXsHlBc4TVcQnqrw46nf9Y,8333
rich/__pycache__/__init__.cpython-38.pyc,,
@ -155,7 +155,7 @@ rich/styled.py,sha256=wljVsVTXbABMMZvkzkO43ZEk_-irzEtvUiQ-sNnikQ8,1234
rich/syntax.py,sha256=hVZcSKZ2RP3EErgY5MGK1WpWHd0x3ZtLCqod_EqgYTM,35302
rich/table.py,sha256=Et3SB0mggJTIXGFOvShX8L9MDpehXJ2GexY6Lmo0Zaw,40041
rich/terminal_theme.py,sha256=1j5-ufJfnvlAo5Qsi_ACZiXDmwMXzqgmFByObT9-yJY,3370
rich/text.py,sha256=hldsjZVTBJZ40tvCsRaYs-EX_pkNNw08I4q_hctElQI,47351
rich/text.py,sha256=0-7de80WdltOq58Kmz4nwMwjx4bbc87qe_qHbZXCH0Q,47365
rich/theme.py,sha256=oNyhXhGagtDlbDye3tVu3esWOWk0vNkuxFw-_unlaK0,3771
rich/themes.py,sha256=0xgTLozfabebYtcJtDdC5QkX5IVUEaviqDUJJh4YVFk,102
rich/traceback.py,sha256=B76Q53tX9gef_DSF2Y0X1UJkEHlRGjuDElUpAe3LgIM,30027

View File

@ -998,7 +998,7 @@ class Text(JupyterMixin):
self._text.append(text.plain)
self._spans.extend(
_Span(start + text_length, end + text_length, style)
for start, end, style in text._spans
for start, end, style in text._spans.copy()
)
self._length += len(text)
return self
@ -1020,7 +1020,7 @@ class Text(JupyterMixin):
self._text.append(text.plain)
self._spans.extend(
_Span(start + text_length, end + text_length, style)
for start, end, style in text._spans
for start, end, style in text._spans.copy()
)
self._length += len(text)
return self

View File

@ -3,7 +3,7 @@ tiktoken-0.7.0.dist-info/LICENSE,sha256=QYy0mbQ2Eo1lPXmUEzOlQ3t74uqSE9zC8E0V1dLF
tiktoken-0.7.0.dist-info/METADATA,sha256=Y8pRXptX5mYKr-PpPwd8Ppsp4MdArChHX6PqWLL2GFA,6595
tiktoken-0.7.0.dist-info/RECORD,,
tiktoken-0.7.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
tiktoken-0.7.0.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
tiktoken-0.7.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
tiktoken-0.7.0.dist-info/direct_url.json,sha256=A8eZHYYbfBMLJxQ_oI7XGcXTSJIWbjiNhps2PIzNXPE,138
tiktoken-0.7.0.dist-info/top_level.txt,sha256=54G5MceQnuD7EXvp7jzGxDDapA1iOwsh77jhCN9WKkc,22
tiktoken/__init__.py,sha256=FNmz8KgZfaG62vRgMMkTL9jj0a2AI7JGV1b-RZ29_tY,322

View File

@ -1,5 +1,5 @@
Wheel-Version: 1.0
Generator: setuptools (74.1.2)
Generator: setuptools (75.1.0)
Root-Is-Purelib: true
Tag: py3-none-any

View File

@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: websockets
Version: 13.0.1
Version: 13.1
Summary: An implementation of the WebSocket Protocol (RFC 6455 & 7692)
Author-email: Aymeric Augustin <aymeric.augustin@m4x.org>
License: BSD-3-Clause
@ -24,6 +24,7 @@ Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Requires-Python: >=3.8
Description-Content-Type: text/x-rst
License-File: LICENSE
.. image:: logo/horizontal.svg

View File

@ -1,10 +1,10 @@
websockets-13.0.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
websockets-13.0.1.dist-info/LICENSE,sha256=PWoMBQ2L7FL6utUC5F-yW9ArytvXDeo01Ee2oP9Obag,1514
websockets-13.0.1.dist-info/METADATA,sha256=0Q5-GuI6Wbmb45tJw8z0cVRJQM3QKB87Yb6XUoZg9NU,6742
websockets-13.0.1.dist-info/RECORD,,
websockets-13.0.1.dist-info/WHEEL,sha256=2-wOIdov8OETygfr6JjpwxZ3R_9CXxWhMBD1nM6bmy8,216
websockets-13.0.1.dist-info/top_level.txt,sha256=CMpdKklxKsvZgCgyltxUWOHibZXZ1uYIVpca9xsQ8Hk,11
websockets/__init__.py,sha256=EDW5qZk8Dt656-GB5qlnxH3jfSde-Nh8yNzmBY0MPH4,5664
websockets-13.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
websockets-13.1.dist-info/LICENSE,sha256=PWoMBQ2L7FL6utUC5F-yW9ArytvXDeo01Ee2oP9Obag,1514
websockets-13.1.dist-info/METADATA,sha256=dYWGin2sQircUMT-_HkCtQkc5_LTzdQ7BLwIGDTHbbI,6777
websockets-13.1.dist-info/RECORD,,
websockets-13.1.dist-info/WHEEL,sha256=kEv-8RuW3g3tDaovfFYwwvLeK_2lL8t8SExFktY7GL0,216
websockets-13.1.dist-info/top_level.txt,sha256=CMpdKklxKsvZgCgyltxUWOHibZXZ1uYIVpca9xsQ8Hk,11
websockets/__init__.py,sha256=UlYOZWjPPdgEtBFq4CP5t7Kd1Jjq-iMJT62Ya9ImDSo,5936
websockets/__main__.py,sha256=q6tBA72COhz7NUkuP_VG9IVypJjOexx2Oi7qkKNxneg,4756
websockets/__pycache__/__init__.cpython-38.pyc,,
websockets/__pycache__/__main__.cpython-38.pyc,,
@ -34,48 +34,50 @@ websockets/asyncio/__pycache__/connection.cpython-38.pyc,,
websockets/asyncio/__pycache__/messages.cpython-38.pyc,,
websockets/asyncio/__pycache__/server.cpython-38.pyc,,
websockets/asyncio/async_timeout.py,sha256=N-6Mubyiaoh66PAXGvCzhgxCM-7V2XiRnH32Xi6J6TE,8971
websockets/asyncio/client.py,sha256=XQy1HhHm26WYyZj3wXUeX889DnwqOcenjDON-tdt9AE,13547
websockets/asyncio/client.py,sha256=Kx9L-AYQUlMRAyo0d2DjuebggcM-rogx3JB26rEebY4,21700
websockets/asyncio/compatibility.py,sha256=gkenDDhzNbm6_iXV5Edvbvp6uHZYdrTvGNjt8P_JtyQ,786
websockets/asyncio/connection.py,sha256=S1LNn_vnhw_Nz5yZ3IDgQ_SmQA1qg0hKDc6xByAkGfU,43514
websockets/asyncio/messages.py,sha256=l3sV10tl0jzYqhrPBiQsml_PBFPHAmjjbSHql3lOyow,9784
websockets/asyncio/server.py,sha256=yqh_tjtv-E9bFRG7Ns5OCmzWi8SeIJP0Yf_YpVZRvp8,31028
websockets/asyncio/connection.py,sha256=sxX1WTz2iVxCsUvLJUoogJT9SdHsHU4ut2PuSIbxVs4,44475
websockets/asyncio/messages.py,sha256=-sS9JCa4-aFVSv0sPJd_VtGcoADj8mE0sMxfsqW-rQw,9854
websockets/asyncio/server.py,sha256=In45P1Ng2gznGMbnwuz3brlIAsZkSel0ScshrJZSMw8,36548
websockets/auth.py,sha256=pCeunT3V2AdwRt_Tpq9TrkdGY7qUlDHIEqeggj5yQFk,262
websockets/client.py,sha256=GXcwcgVCzxlvmgfiuVjbaddmATmnKntbmXeyBB2GQr4,12484
websockets/connection.py,sha256=BQn7ws-u7hmsGhrPZmxLrbPqpoL3MrrcbdeUc93zZkE,288
websockets/datastructures.py,sha256=oGbm3ZjVx3BwrYfgyPPcTtFO8qbgCsFVApMNGTRTh2c,5681
websockets/exceptions.py,sha256=IraXIMjjSOvBuikGolls6AuOFOeYXRO4VryT091Dm8E,10248
websockets/client.py,sha256=cc8y1I2Firs1JRXCfgD4j2JWnneYAuQSpGWNjrhkqFY,13541
websockets/connection.py,sha256=OLiMVkNd25_86sB8Q7CrCwBoXy9nA0OCgdgLRA8WUR8,323
websockets/datastructures.py,sha256=s5Rkipz4n15HSZsOrs64CoCs-_3oSBCgpe9uPvztDkY,5677
websockets/exceptions.py,sha256=b2-QiL1pszljREQQCzbPE1Fv7-Xb-uwso2Zt6LLD10A,10594
websockets/extensions/__init__.py,sha256=QkZsxaJVllVSp1uhdD5uPGibdbx_091GrVVfS5LXcpw,98
websockets/extensions/__pycache__/__init__.cpython-38.pyc,,
websockets/extensions/__pycache__/base.cpython-38.pyc,,
websockets/extensions/__pycache__/permessage_deflate.cpython-38.pyc,,
websockets/extensions/base.py,sha256=sMjmUfov0-woUmT4MiIuwBjj4DAJu0l3gfniPHn67Ec,2952
websockets/extensions/permessage_deflate.py,sha256=Az-zpU9eJYlMi4mdA1Jx0qKs0FRzrjbp3MO9iXzKHIY,24716
websockets/frames.py,sha256=-chQjYwtdTlTdosSfnvxRyXWSfqRfLQmL0Bo46x2Rok,14111
websockets/headers.py,sha256=LkuZIpztb-llYblr9GrJK0SAA8fJz7JhErMj-aAzJMk,16109
websockets/http.py,sha256=id8BzOG4AtqIeFwKHB4dzc6gxg45vcsdSdWUCvWMDvk,447
websockets/http11.py,sha256=OI0U0yj8CDOT7h0IfHzRF1bIyY9zjeBTyDiicgfAasU,13255
websockets/extensions/base.py,sha256=jsSJnO47L2VxYzx0cZ_LLQcAyUudSDgJEtKN247H-38,2890
websockets/extensions/permessage_deflate.py,sha256=JR9s7pvAJv2zWRUfOLysOtAiO-ovgRMqnSUpb92gohI,24661
websockets/frames.py,sha256=H-4ULsevYdna_CjalVASRPlh2Z54NoQat_vq8P4cVfc,12765
websockets/headers.py,sha256=9OHHZvaj4hXrofi0HuJFNYJaE0yRoPmmrBYxMaDuCTs,15931
websockets/http.py,sha256=eWitbqWAmHeqYK4OF3JLRC4lWwI1OIeft7oY3OobXvc,481
websockets/http11.py,sha256=-TNxOVVLr0050-0Ac3jOlWt5G9HAfkHZrt8dqoto9bs,13376
websockets/imports.py,sha256=TNONfYXO1UPExiwCVMgmg77fH5b4nyNAKcqtTg0gO2I,2768
websockets/legacy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
websockets/legacy/__pycache__/__init__.cpython-38.pyc,,
websockets/legacy/__pycache__/auth.cpython-38.pyc,,
websockets/legacy/__pycache__/client.cpython-38.pyc,,
websockets/legacy/__pycache__/exceptions.cpython-38.pyc,,
websockets/legacy/__pycache__/framing.cpython-38.pyc,,
websockets/legacy/__pycache__/handshake.cpython-38.pyc,,
websockets/legacy/__pycache__/http.cpython-38.pyc,,
websockets/legacy/__pycache__/protocol.cpython-38.pyc,,
websockets/legacy/__pycache__/server.cpython-38.pyc,,
websockets/legacy/auth.py,sha256=UdK0eZg1TjMGY6iEVRbBn51M9AjpSyRv2lJbvuuI6aA,6567
websockets/legacy/client.py,sha256=qZcQgbBnSW2ha9IpzvRDGZNKwzXNJyVpB9ac86_y1Zk,26345
websockets/legacy/framing.py,sha256=2itTWkMHXzjleywjyQb39P__75ThMX1AX3Ly_AU94lk,4972
websockets/legacy/handshake.py,sha256=71rqsWT1-wZ8HaLorK1ac47atdmgv8q6YXV-qosjqxs,5445
websockets/legacy/client.py,sha256=iuyFib2kX5ybK9vLVpqJRNJHa4BuA0u5MLyoNnartY4,26706
websockets/legacy/exceptions.py,sha256=DbSHBKcDEoYoUeCxURo0cnH8PyCCKYzkTboP_tOtsxw,1967
websockets/legacy/framing.py,sha256=ALEDiBNq17FUqNEe5LHxkPxWoY6tPwffgGFiHMdnnIs,6371
websockets/legacy/handshake.py,sha256=2Nzr5AN2xvDC5EdNP-kB3lOcrAaUNlYuj_-hr_jv7pM,5285
websockets/legacy/http.py,sha256=cOCQmDWhIKQmm8UWGXPW7CDZg03wjogCsb0LP9oetNQ,7061
websockets/legacy/protocol.py,sha256=753BX2MRJPsMx4vgkCt0g9tRqFJHxi7Nx_CzuL05iKc,63689
websockets/legacy/server.py,sha256=qI9M9cuP-7CzDjMC5odivJNesbvNOEhj2G6b3-wEHJw,45120
websockets/protocol.py,sha256=zpfgfI82kfouh3JxONEjTrsEEknvNdKzG5Sq4bcNol0,24938
websockets/legacy/protocol.py,sha256=Rbk88lnbghWpcEBT-TuTtAGqDy9OA7VsUFEMUcv95RM,63681
websockets/legacy/server.py,sha256=lb26Vm_y7biVVYLyVzG9R1BiaLmuS3TrQh-LesjO4Ss,45318
websockets/protocol.py,sha256=yl1j9ecLShF0iTRALOTzFfq0KmW5XO74Mtk0baVkvo0,25512
websockets/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
websockets/server.py,sha256=fQn2lrVEzrKoQZfcz-WZ7c_DswuqgdhBj7B1mxCIw20,21354
websockets/speedups.c,sha256=ghPq-NF35VLVNkMv0uFDIruNpVISyW-qvoZgPpE65qw,5834
websockets/speedups.cpython-38-x86_64-linux-gnu.so,sha256=UxyakwjXRx4kYsOleAxkc8vXBqt89fv5rhzSAlI13W8,34072
websockets/server.py,sha256=BVoC433LZUgKVndtaYPrndB7uf_FTuG7MXrM9QHJEzk,21275
websockets/speedups.c,sha256=j-damnT02MKRoYw8MtTT45qLGX6z6TnriqhTkyfcNZE,5767
websockets/speedups.cpython-38-x86_64-linux-gnu.so,sha256=FRW0JugiQU4471Sd8Yergmr8u39ELoI5T9SIrQJ2Uqs,34072
websockets/speedups.pyi,sha256=NikZ3sAxs9Z2uWH_ZvctvMJUBbsHeC2D1L954EVSwJc,55
websockets/streams.py,sha256=3K3FcgTcXon-51P0sVyz0G4J-H51L82SVMS--W-gl6g,4038
websockets/sync/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@ -85,12 +87,12 @@ websockets/sync/__pycache__/connection.cpython-38.pyc,,
websockets/sync/__pycache__/messages.cpython-38.pyc,,
websockets/sync/__pycache__/server.cpython-38.pyc,,
websockets/sync/__pycache__/utils.cpython-38.pyc,,
websockets/sync/client.py,sha256=0Jw-jI7WvkuKzZ9_JNtrWBSu0eqt8JuIxfQjxzXpftY,11462
websockets/sync/connection.py,sha256=jTHj0OKGxrk_aowooLRnGbqiytlnfb8bTXhjPK3yx9g,29759
websockets/sync/messages.py,sha256=pDDKG7OHWcqXF9zzhm3TGYD92cfIaP1wNEX3U14b3ns,9739
websockets/sync/server.py,sha256=-gowaOvtLzzt9uhKykcrgHgyfgoeo7945PxptTG5G9c,20531
websockets/sync/client.py,sha256=QWs2wYU7S8--CZgFYWUliWjSYX5zrJEFQ6_gEcrW1sA,11372
websockets/sync/connection.py,sha256=Ve2aW760xPz8nXU56TuL7M3qV1THYmZZcfoS_0Wwh0c,30684
websockets/sync/messages.py,sha256=K-VHhUERUsS6bOaLgTox4kShnUKt8aPmWgOdqj_4E-Y,9809
websockets/sync/server.py,sha256=WutnccxDQWJNfPsX2WthvDr0QeVn36fUpf0MKmbeXY0,25608
websockets/sync/utils.py,sha256=TtW-ncYFvJmiSW2gO86ngE2BVsnnBdL-4H88kWNDYbg,1107
websockets/typing.py,sha256=b9F78aYY-sDNnIgSbvV_ApVBicVJdduLGv5wU0PVB5c,2157
websockets/uri.py,sha256=V7Qzs4ZvMUIYVqr0HDoIpsg7ah8iRGVmr1A8Vp91lPg,3159
websockets/uri.py,sha256=1r8dXNEiLcdMrCrzXmsy7DwSHiF3gaOWlmAdoFexOOM,3125
websockets/utils.py,sha256=ZpH3WJLsQS29Jf5R6lTacxf_hPd8E4zS2JmGyNpg4bA,1150
websockets/version.py,sha256=mYdU1lDxfvmRjqzZIlYCJatBinMmAlaoeb-AXPnPyeo,3204
websockets/version.py,sha256=M0HSppy6IqnAdAr0McbPGkyCuBlue4Uzigc78cOWHxs,3202

View File

@ -1,5 +1,5 @@
Wheel-Version: 1.0
Generator: setuptools (74.0.0)
Generator: setuptools (75.1.0)
Root-Is-Purelib: false
Tag: cp38-cp38-manylinux_2_5_x86_64
Tag: cp38-cp38-manylinux1_x86_64

View File

@ -14,7 +14,7 @@ __all__ = [
"HeadersLike",
"MultipleValuesError",
# .exceptions
"AbortHandshake",
"ConcurrencyError",
"ConnectionClosed",
"ConnectionClosedError",
"ConnectionClosedOK",
@ -23,19 +23,16 @@ __all__ = [
"InvalidHeader",
"InvalidHeaderFormat",
"InvalidHeaderValue",
"InvalidMessage",
"InvalidOrigin",
"InvalidParameterName",
"InvalidParameterValue",
"InvalidState",
"InvalidStatus",
"InvalidStatusCode",
"InvalidUpgrade",
"InvalidURI",
"NegotiationError",
"PayloadTooBig",
"ProtocolError",
"RedirectHandshake",
"SecurityError",
"WebSocketException",
"WebSocketProtocolError",
@ -46,6 +43,11 @@ __all__ = [
"WebSocketClientProtocol",
"connect",
"unix_connect",
# .legacy.exceptions
"AbortHandshake",
"InvalidMessage",
"InvalidStatusCode",
"RedirectHandshake",
# .legacy.protocol
"WebSocketCommonProtocol",
# .legacy.server
@ -71,7 +73,7 @@ if typing.TYPE_CHECKING:
from .client import ClientProtocol
from .datastructures import Headers, HeadersLike, MultipleValuesError
from .exceptions import (
AbortHandshake,
ConcurrencyError,
ConnectionClosed,
ConnectionClosedError,
ConnectionClosedOK,
@ -80,19 +82,16 @@ if typing.TYPE_CHECKING:
InvalidHeader,
InvalidHeaderFormat,
InvalidHeaderValue,
InvalidMessage,
InvalidOrigin,
InvalidParameterName,
InvalidParameterValue,
InvalidState,
InvalidStatus,
InvalidStatusCode,
InvalidUpgrade,
InvalidURI,
NegotiationError,
PayloadTooBig,
ProtocolError,
RedirectHandshake,
SecurityError,
WebSocketException,
WebSocketProtocolError,
@ -102,6 +101,12 @@ if typing.TYPE_CHECKING:
basic_auth_protocol_factory,
)
from .legacy.client import WebSocketClientProtocol, connect, unix_connect
from .legacy.exceptions import (
AbortHandshake,
InvalidMessage,
InvalidStatusCode,
RedirectHandshake,
)
from .legacy.protocol import WebSocketCommonProtocol
from .legacy.server import (
WebSocketServer,
@ -131,7 +136,7 @@ else:
"HeadersLike": ".datastructures",
"MultipleValuesError": ".datastructures",
# .exceptions
"AbortHandshake": ".exceptions",
"ConcurrencyError": ".exceptions",
"ConnectionClosed": ".exceptions",
"ConnectionClosedError": ".exceptions",
"ConnectionClosedOK": ".exceptions",
@ -140,19 +145,16 @@ else:
"InvalidHeader": ".exceptions",
"InvalidHeaderFormat": ".exceptions",
"InvalidHeaderValue": ".exceptions",
"InvalidMessage": ".exceptions",
"InvalidOrigin": ".exceptions",
"InvalidParameterName": ".exceptions",
"InvalidParameterValue": ".exceptions",
"InvalidState": ".exceptions",
"InvalidStatus": ".exceptions",
"InvalidStatusCode": ".exceptions",
"InvalidUpgrade": ".exceptions",
"InvalidURI": ".exceptions",
"NegotiationError": ".exceptions",
"PayloadTooBig": ".exceptions",
"ProtocolError": ".exceptions",
"RedirectHandshake": ".exceptions",
"SecurityError": ".exceptions",
"WebSocketException": ".exceptions",
"WebSocketProtocolError": ".exceptions",
@ -163,6 +165,11 @@ else:
"WebSocketClientProtocol": ".legacy.client",
"connect": ".legacy.client",
"unix_connect": ".legacy.client",
# .legacy.exceptions
"AbortHandshake": ".legacy.exceptions",
"InvalidMessage": ".legacy.exceptions",
"InvalidStatusCode": ".legacy.exceptions",
"RedirectHandshake": ".legacy.exceptions",
# .legacy.protocol
"WebSocketCommonProtocol": ".legacy.protocol",
# .legacy.server
@ -179,10 +186,11 @@ else:
"ExtensionParameter": ".typing",
"LoggerLike": ".typing",
"Origin": ".typing",
"StatusLike": "typing",
"StatusLike": ".typing",
"Subprotocol": ".typing",
},
deprecated_aliases={
# deprecated in 9.0 - 2021-09-01
"framing": ".legacy",
"handshake": ".legacy",
"parse_uri": ".uri",

View File

@ -1,24 +1,30 @@
from __future__ import annotations
import asyncio
import logging
import os
import urllib.parse
from types import TracebackType
from typing import Any, Generator, Sequence
from typing import Any, AsyncIterator, Callable, Generator, Sequence
from ..client import ClientProtocol
from ..client import ClientProtocol, backoff
from ..datastructures import HeadersLike
from ..exceptions import InvalidStatus, SecurityError
from ..extensions.base import ClientExtensionFactory
from ..extensions.permessage_deflate import enable_client_permessage_deflate
from ..headers import validate_subprotocols
from ..http11 import USER_AGENT, Response
from ..protocol import CONNECTING, Event
from ..typing import LoggerLike, Origin, Subprotocol
from ..uri import parse_uri
from ..uri import WebSocketURI, parse_uri
from .compatibility import TimeoutError, asyncio_timeout
from .connection import Connection
__all__ = ["connect", "unix_connect", "ClientConnection"]
MAX_REDIRECTS = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10"))
class ClientConnection(Connection):
"""
@ -83,20 +89,17 @@ class ClientConnection(Connection):
self.request.headers["User-Agent"] = user_agent_header
self.protocol.send_request(self.request)
# May raise CancelledError if open_timeout is exceeded.
await self.response_rcvd
await asyncio.wait(
[self.response_rcvd, self.connection_lost_waiter],
return_when=asyncio.FIRST_COMPLETED,
)
if self.response is None:
raise ConnectionError("connection closed during handshake")
# self.protocol.handshake_exc is always set when the connection is lost
# before receiving a response, when the response cannot be parsed, or
# when the response fails the handshake.
if self.protocol.handshake_exc is None:
self.start_keepalive()
else:
try:
async with asyncio_timeout(self.close_timeout):
await self.connection_lost_waiter
finally:
raise self.protocol.handshake_exc
if self.protocol.handshake_exc is not None:
raise self.protocol.handshake_exc
def process_event(self, event: Event) -> None:
"""
@ -112,13 +115,46 @@ class ClientConnection(Connection):
else:
super().process_event(event)
def connection_lost(self, exc: Exception | None) -> None:
try:
super().connection_lost(exc)
finally:
# If the connection is closed during the handshake, unblock it.
if not self.response_rcvd.done():
self.response_rcvd.set_result(None)
def process_exception(exc: Exception) -> Exception | None:
"""
Determine whether a connection error is retryable or fatal.
When reconnecting automatically with ``async for ... in connect(...)``, if a
connection attempt fails, :func:`process_exception` is called to determine
whether to retry connecting or to raise the exception.
This function defines the default behavior, which is to retry on:
* :exc:`EOFError`, :exc:`OSError`, :exc:`asyncio.TimeoutError`: network
errors;
* :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500,
502, 503, or 504: server or proxy errors.
All other exceptions are considered fatal.
You can change this behavior with the ``process_exception`` argument of
:func:`connect`.
Return :obj:`None` if the exception is retryable i.e. when the error could
be transient and trying to reconnect with the same parameters could succeed.
The exception will be logged at the ``INFO`` level.
Return an exception, either ``exc`` or a new exception, if the exception is
fatal i.e. when trying to reconnect will most likely produce the same error.
That exception will be raised, breaking out of the retry loop.
"""
if isinstance(exc, (EOFError, OSError, asyncio.TimeoutError)):
return None
if isinstance(exc, InvalidStatus) and exc.response.status_code in [
500, # Internal Server Error
502, # Bad Gateway
503, # Service Unavailable
504, # Gateway Timeout
]:
return None
return exc
# This is spelled in lower case because it's exposed as a callable in the API.
@ -131,11 +167,28 @@ class connect:
:func:`connect` may be used as an asynchronous context manager::
async with websockets.asyncio.client.connect(...) as websocket:
from websockets.asyncio.client import connect
async with connect(...) as websocket:
...
The connection is closed automatically when exiting the context.
:func:`connect` can be used as an infinite asynchronous iterator to
reconnect automatically on errors::
async for websocket in connect(...):
try:
...
except websockets.ConnectionClosed:
continue
If the connection fails with a transient error, it is retried with
exponential backoff. If it fails with a fatal error, the exception is
raised, breaking out of the loop.
The connection is closed automatically after each iteration of the loop.
Args:
uri: URI of the WebSocket server.
origin: Value of the ``Origin`` header, for servers that require it.
@ -151,6 +204,9 @@ class connect:
compression: The "permessage-deflate" extension is enabled by default.
Set ``compression`` to :obj:`None` to disable it. See the
:doc:`compression guide <../../topics/compression>` for details.
process_exception: When reconnecting automatically, tell whether an
error is transient or fatal. The default behavior is defined by
:func:`process_exception`. Refer to its documentation for details.
open_timeout: Timeout for opening the connection in seconds.
:obj:`None` disables the timeout.
ping_interval: Interval between keepalive pings in seconds.
@ -217,6 +273,7 @@ class connect:
additional_headers: HeadersLike | None = None,
user_agent_header: str | None = USER_AGENT,
compression: str | None = "deflate",
process_exception: Callable[[Exception], Exception | None] = process_exception,
# Timeouts
open_timeout: float | None = 10,
ping_interval: float | None = 20,
@ -233,17 +290,7 @@ class connect:
# Other keyword arguments are passed to loop.create_connection
**kwargs: Any,
) -> None:
wsuri = parse_uri(uri)
if wsuri.secure:
kwargs.setdefault("ssl", True)
kwargs.setdefault("server_hostname", wsuri.host)
if kwargs.get("ssl") is None:
raise TypeError("ssl=None is incompatible with a wss:// URI")
else:
if kwargs.get("ssl") is not None:
raise TypeError("ssl argument is incompatible with a ws:// URI")
self.uri = uri
if subprotocols is not None:
validate_subprotocols(subprotocols)
@ -253,10 +300,13 @@ class connect:
elif compression is not None:
raise ValueError(f"unsupported compression: {compression}")
if logger is None:
logger = logging.getLogger("websockets.client")
if create_connection is None:
create_connection = ClientConnection
def factory() -> ClientConnection:
def protocol_factory(wsuri: WebSocketURI) -> ClientConnection:
# This is a protocol in the Sans-I/O implementation of websockets.
protocol = ClientProtocol(
wsuri,
@ -277,21 +327,154 @@ class connect:
)
return connection
self.protocol_factory = protocol_factory
self.handshake_args = (
additional_headers,
user_agent_header,
)
self.process_exception = process_exception
self.open_timeout = open_timeout
self.logger = logger
self.connection_kwargs = kwargs
async def create_connection(self) -> ClientConnection:
"""Create TCP or Unix connection."""
loop = asyncio.get_running_loop()
wsuri = parse_uri(self.uri)
kwargs = self.connection_kwargs.copy()
def factory() -> ClientConnection:
return self.protocol_factory(wsuri)
if wsuri.secure:
kwargs.setdefault("ssl", True)
kwargs.setdefault("server_hostname", wsuri.host)
if kwargs.get("ssl") is None:
raise TypeError("ssl=None is incompatible with a wss:// URI")
else:
if kwargs.get("ssl") is not None:
raise TypeError("ssl argument is incompatible with a ws:// URI")
if kwargs.pop("unix", False):
self._create_connection = loop.create_unix_connection(factory, **kwargs)
_, connection = await loop.create_unix_connection(factory, **kwargs)
else:
if kwargs.get("sock") is None:
kwargs.setdefault("host", wsuri.host)
kwargs.setdefault("port", wsuri.port)
self._create_connection = loop.create_connection(factory, **kwargs)
_, connection = await loop.create_connection(factory, **kwargs)
return connection
self._handshake_args = (
additional_headers,
user_agent_header,
)
def process_redirect(self, exc: Exception) -> Exception | str:
"""
Determine whether a connection error is a redirect that can be followed.
self._open_timeout = open_timeout
Return the new URI if it's a valid redirect. Else, return an exception.
"""
if not (
isinstance(exc, InvalidStatus)
and exc.response.status_code
in [
300, # Multiple Choices
301, # Moved Permanently
302, # Found
303, # See Other
307, # Temporary Redirect
308, # Permanent Redirect
]
and "Location" in exc.response.headers
):
return exc
old_wsuri = parse_uri(self.uri)
new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"])
new_wsuri = parse_uri(new_uri)
# If connect() received a socket, it is closed and cannot be reused.
if self.connection_kwargs.get("sock") is not None:
return ValueError(
f"cannot follow redirect to {new_uri} with a preexisting socket"
)
# TLS downgrade is forbidden.
if old_wsuri.secure and not new_wsuri.secure:
return SecurityError(f"cannot follow redirect to non-secure URI {new_uri}")
# Apply restrictions to cross-origin redirects.
if (
old_wsuri.secure != new_wsuri.secure
or old_wsuri.host != new_wsuri.host
or old_wsuri.port != new_wsuri.port
):
# Cross-origin redirects on Unix sockets don't quite make sense.
if self.connection_kwargs.get("unix", False):
return ValueError(
f"cannot follow cross-origin redirect to {new_uri} "
f"with a Unix socket"
)
# Cross-origin redirects when host and port are overridden are ill-defined.
if (
self.connection_kwargs.get("host") is not None
or self.connection_kwargs.get("port") is not None
):
return ValueError(
f"cannot follow cross-origin redirect to {new_uri} "
f"with an explicit host or port"
)
return new_uri
# ... = await connect(...)
def __await__(self) -> Generator[Any, None, ClientConnection]:
# Create a suitable iterator by calling __await__ on a coroutine.
return self.__await_impl__().__await__()
async def __await_impl__(self) -> ClientConnection:
try:
async with asyncio_timeout(self.open_timeout):
for _ in range(MAX_REDIRECTS):
self.connection = await self.create_connection()
try:
await self.connection.handshake(*self.handshake_args)
except asyncio.CancelledError:
self.connection.close_transport()
raise
except Exception as exc:
# Always close the connection even though keep-alive is
# the default in HTTP/1.1 because create_connection ties
# opening the network connection with initializing the
# protocol. In the current design of connect(), there is
# no easy way to reuse the network connection that works
# in every case nor to reinitialize the protocol.
self.connection.close_transport()
uri_or_exc = self.process_redirect(exc)
# Response is a valid redirect; follow it.
if isinstance(uri_or_exc, str):
self.uri = uri_or_exc
continue
# Response isn't a valid redirect; raise the exception.
if uri_or_exc is exc:
raise
else:
raise uri_or_exc from exc
else:
self.connection.start_keepalive()
return self.connection
else:
raise SecurityError(f"more than {MAX_REDIRECTS} redirects")
except TimeoutError:
# Re-raise exception with an informative error message.
raise TimeoutError("timed out during handshake") from None
# ... = yield from connect(...) - remove when dropping Python < 3.10
__iter__ = __await__
# async with connect(...) as ...: ...
@ -306,30 +489,47 @@ class connect:
) -> None:
await self.connection.close()
# ... = await connect(...)
# async for ... in connect(...):
def __await__(self) -> Generator[Any, None, ClientConnection]:
# Create a suitable iterator by calling __await__ on a coroutine.
return self.__await_impl__().__await__()
async def __await_impl__(self) -> ClientConnection:
try:
async with asyncio_timeout(self._open_timeout):
_transport, self.connection = await self._create_connection
async def __aiter__(self) -> AsyncIterator[ClientConnection]:
delays: Generator[float, None, None] | None = None
while True:
try:
async with self as protocol:
yield protocol
except Exception as exc:
# Determine whether the exception is retryable or fatal.
# The API of process_exception is "return an exception or None";
# "raise an exception" is also supported because it's a frequent
# mistake. It isn't documented in order to keep the API simple.
try:
await self.connection.handshake(*self._handshake_args)
except (Exception, asyncio.CancelledError):
self.connection.transport.close()
new_exc = self.process_exception(exc)
except Exception as raised_exc:
new_exc = raised_exc
# The connection failed with a fatal error.
# Raise the exception and exit the loop.
if new_exc is exc:
raise
else:
return self.connection
except TimeoutError:
# Re-raise exception with an informative error message.
raise TimeoutError("timed out during handshake") from None
if new_exc is not None:
raise new_exc from exc
# ... = yield from connect(...) - remove when dropping Python < 3.10
# The connection failed with a retryable error.
# Start or continue backoff and reconnect.
if delays is None:
delays = backoff()
delay = next(delays)
self.logger.info(
"! connect failed; reconnecting in %.1f seconds",
delay,
exc_info=True,
)
await asyncio.sleep(delay)
continue
__iter__ = __await__
else:
# The connection succeeded. Reset backoff.
delays = None
def unix_connect(

View File

@ -19,8 +19,13 @@ from typing import (
cast,
)
from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError
from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl
from ..exceptions import (
ConcurrencyError,
ConnectionClosed,
ConnectionClosedOK,
ProtocolError,
)
from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode
from ..http11 import Request, Response
from ..protocol import CLOSED, OPEN, Event, Protocol, State
from ..typing import Data, LoggerLike, Subprotocol
@ -262,16 +267,18 @@ class Connection(asyncio.Protocol):
Raises:
ConnectionClosed: When the connection is closed.
RuntimeError: If two coroutines call :meth:`recv` or
ConcurrencyError: If two coroutines call :meth:`recv` or
:meth:`recv_streaming` concurrently.
"""
try:
return await self.recv_messages.get(decode)
except EOFError:
# Wait for the protocol state to be CLOSED before accessing close_exc.
await asyncio.shield(self.connection_lost_waiter)
raise self.protocol.close_exc from self.recv_exc
except RuntimeError:
raise RuntimeError(
except ConcurrencyError:
raise ConcurrencyError(
"cannot call recv while another coroutine "
"is already running recv or recv_streaming"
) from None
@ -283,8 +290,9 @@ class Connection(asyncio.Protocol):
This method is designed for receiving fragmented messages. It returns an
asynchronous iterator that yields each fragment as it is received. This
iterator must be fully consumed. Else, future calls to :meth:`recv` or
:meth:`recv_streaming` will raise :exc:`RuntimeError`, making the
connection unusable.
:meth:`recv_streaming` will raise
:exc:`~websockets.exceptions.ConcurrencyError`, making the connection
unusable.
:meth:`recv_streaming` raises the same exceptions as :meth:`recv`.
@ -315,7 +323,7 @@ class Connection(asyncio.Protocol):
Raises:
ConnectionClosed: When the connection is closed.
RuntimeError: If two coroutines call :meth:`recv` or
ConcurrencyError: If two coroutines call :meth:`recv` or
:meth:`recv_streaming` concurrently.
"""
@ -323,9 +331,11 @@ class Connection(asyncio.Protocol):
async for frame in self.recv_messages.get_iter(decode):
yield frame
except EOFError:
# Wait for the protocol state to be CLOSED before accessing close_exc.
await asyncio.shield(self.connection_lost_waiter)
raise self.protocol.close_exc from self.recv_exc
except RuntimeError:
raise RuntimeError(
except ConcurrencyError:
raise ConcurrencyError(
"cannot call recv_streaming while another coroutine "
"is already running recv or recv_streaming"
) from None
@ -593,17 +603,21 @@ class Connection(asyncio.Protocol):
Raises:
ConnectionClosed: When the connection is closed.
RuntimeError: If another ping was sent with the same data and
ConcurrencyError: If another ping was sent with the same data and
the corresponding pong wasn't received yet.
"""
if data is not None:
data = prepare_ctrl(data)
if isinstance(data, BytesLike):
data = bytes(data)
elif isinstance(data, str):
data = data.encode()
elif data is not None:
raise TypeError("data must be str or bytes-like")
async with self.send_context():
# Protect against duplicates if a payload is explicitly set.
if data in self.pong_waiters:
raise RuntimeError("already waiting for a pong with the same data")
raise ConcurrencyError("already waiting for a pong with the same data")
# Generate a unique random payload otherwise.
while data is None or data in self.pong_waiters:
@ -632,7 +646,12 @@ class Connection(asyncio.Protocol):
ConnectionClosed: When the connection is closed.
"""
data = prepare_ctrl(data)
if isinstance(data, BytesLike):
data = bytes(data)
elif isinstance(data, str):
data = data.encode()
else:
raise TypeError("data must be str or bytes-like")
async with self.send_context():
self.protocol.send_pong(data)
@ -784,7 +803,7 @@ class Connection(asyncio.Protocol):
# Let the caller interact with the protocol.
try:
yield
except (ProtocolError, RuntimeError):
except (ProtocolError, ConcurrencyError):
# The protocol state wasn't changed. Exit immediately.
raise
except Exception as exc:
@ -849,6 +868,7 @@ class Connection(asyncio.Protocol):
# raise an exception.
if raise_close_exc:
self.close_transport()
# Wait for the protocol state to be CLOSED before accessing close_exc.
await asyncio.shield(self.connection_lost_waiter)
raise self.protocol.close_exc from original_exc
@ -911,11 +931,14 @@ class Connection(asyncio.Protocol):
self.transport = transport
def connection_lost(self, exc: Exception | None) -> None:
self.protocol.receive_eof() # receive_eof is idempotent
# Calling protocol.receive_eof() is safe because it's idempotent.
# This guarantees that the protocol state becomes CLOSED.
self.protocol.receive_eof()
assert self.protocol.state is CLOSED
self.set_recv_exc(exc)
# Abort recv() and pending pings with a ConnectionClosed exception.
# Set recv_exc first to get proper exception reporting.
self.set_recv_exc(exc)
self.recv_messages.close()
self.abort_pings()
@ -1083,15 +1106,17 @@ def broadcast(
if raise_exceptions:
if sys.version_info[:2] < (3, 11): # pragma: no cover
raise ValueError("raise_exceptions requires at least Python 3.11")
exceptions = []
exceptions: list[Exception] = []
for connection in connections:
exception: Exception
if connection.protocol.state is not OPEN:
continue
if connection.fragmented_send_waiter is not None:
if raise_exceptions:
exception = RuntimeError("sending a fragmented message")
exception = ConcurrencyError("sending a fragmented message")
exceptions.append(exception)
else:
connection.logger.warning(

View File

@ -12,6 +12,7 @@ from typing import (
TypeVar,
)
from ..exceptions import ConcurrencyError
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
from ..typing import Data
@ -49,7 +50,7 @@ class SimpleQueue(Generic[T]):
"""Remove and return an item from the queue, waiting if necessary."""
if not self.queue:
if self.get_waiter is not None:
raise RuntimeError("get is already running")
raise ConcurrencyError("get is already running")
self.get_waiter = self.loop.create_future()
try:
await self.get_waiter
@ -135,15 +136,15 @@ class Assembler:
Raises:
EOFError: If the stream of frames has ended.
RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter`
concurrently.
ConcurrencyError: If two coroutines run :meth:`get` or
:meth:`get_iter` concurrently.
"""
if self.closed:
raise EOFError("stream of frames ended")
if self.get_in_progress:
raise RuntimeError("get() or get_iter() is already running")
raise ConcurrencyError("get() or get_iter() is already running")
# Locking with get_in_progress ensures only one coroutine can get here.
self.get_in_progress = True
@ -190,7 +191,7 @@ class Assembler:
:class:`str` or :class:`bytes` for each frame in the message.
The iterator must be fully consumed before calling :meth:`get_iter` or
:meth:`get` again. Else, :exc:`RuntimeError` is raised.
:meth:`get` again. Else, :exc:`ConcurrencyError` is raised.
This method only makes sense for fragmented messages. If messages aren't
fragmented, use :meth:`get` instead.
@ -202,15 +203,15 @@ class Assembler:
Raises:
EOFError: If the stream of frames has ended.
RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter`
concurrently.
ConcurrencyError: If two coroutines run :meth:`get` or
:meth:`get_iter` concurrently.
"""
if self.closed:
raise EOFError("stream of frames ended")
if self.get_in_progress:
raise RuntimeError("get() or get_iter() is already running")
raise ConcurrencyError("get() or get_iter() is already running")
# Locking with get_in_progress ensures only one coroutine can get here.
self.get_in_progress = True
@ -236,7 +237,7 @@ class Assembler:
# We cannot handle asyncio.CancelledError because we don't buffer
# previous fragments — we're streaming them. Canceling get_iter()
# here will leave the assembler in a stuck state. Future calls to
# get() or get_iter() will raise RuntimeError.
# get() or get_iter() will raise ConcurrencyError.
frame = await self.frames.get()
self.maybe_resume()
assert frame.opcode is OP_CONT

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import hmac
import http
import logging
import socket
@ -13,22 +14,35 @@ from typing import (
Generator,
Iterable,
Sequence,
Tuple,
cast,
)
from websockets.frames import CloseCode
from ..exceptions import InvalidHeader
from ..extensions.base import ServerExtensionFactory
from ..extensions.permessage_deflate import enable_server_permessage_deflate
from ..headers import validate_subprotocols
from ..frames import CloseCode
from ..headers import (
build_www_authenticate_basic,
parse_authorization_basic,
validate_subprotocols,
)
from ..http11 import SERVER, Request, Response
from ..protocol import CONNECTING, Event
from ..protocol import CONNECTING, OPEN, Event
from ..server import ServerProtocol
from ..typing import LoggerLike, Origin, StatusLike, Subprotocol
from .compatibility import asyncio_timeout
from .connection import Connection, broadcast
__all__ = ["broadcast", "serve", "unix_serve", "ServerConnection", "Server"]
__all__ = [
"broadcast",
"serve",
"unix_serve",
"ServerConnection",
"Server",
"basic_auth",
]
class ServerConnection(Connection):
@ -79,6 +93,7 @@ class ServerConnection(Connection):
)
self.server = server
self.request_rcvd: asyncio.Future[None] = self.loop.create_future()
self.username: str # see basic_auth()
def respond(self, status: StatusLike, text: str) -> Response:
"""
@ -123,78 +138,75 @@ class ServerConnection(Connection):
Perform the opening handshake.
"""
# May raise CancelledError if open_timeout is exceeded.
await self.request_rcvd
await asyncio.wait(
[self.request_rcvd, self.connection_lost_waiter],
return_when=asyncio.FIRST_COMPLETED,
)
if self.request is None:
raise ConnectionError("connection closed during handshake")
if self.request is not None:
async with self.send_context(expected_state=CONNECTING):
response = None
async with self.send_context(expected_state=CONNECTING):
response = None
if process_request is not None:
try:
response = process_request(self, self.request)
if isinstance(response, Awaitable):
response = await response
except Exception as exc:
self.protocol.handshake_exc = exc
response = self.protocol.reject(
http.HTTPStatus.INTERNAL_SERVER_ERROR,
(
"Failed to open a WebSocket connection.\n"
"See server log for more information.\n"
),
)
if process_request is not None:
try:
response = process_request(self, self.request)
if isinstance(response, Awaitable):
response = await response
except Exception as exc:
self.protocol.handshake_exc = exc
self.logger.error("opening handshake failed", exc_info=True)
response = self.protocol.reject(
http.HTTPStatus.INTERNAL_SERVER_ERROR,
(
"Failed to open a WebSocket connection.\n"
"See server log for more information.\n"
),
)
if response is None:
if self.server.is_serving():
self.response = self.protocol.accept(self.request)
if response is None:
if self.server.is_serving():
self.response = self.protocol.accept(self.request)
else:
self.response = self.protocol.reject(
http.HTTPStatus.SERVICE_UNAVAILABLE,
"Server is shutting down.\n",
)
else:
self.response = self.protocol.reject(
http.HTTPStatus.SERVICE_UNAVAILABLE,
"Server is shutting down.\n",
)
else:
assert isinstance(response, Response) # help mypy
self.response = response
assert isinstance(response, Response) # help mypy
self.response = response
if server_header:
self.response.headers["Server"] = server_header
if server_header:
self.response.headers["Server"] = server_header
response = None
response = None
if process_response is not None:
try:
response = process_response(self, self.request, self.response)
if isinstance(response, Awaitable):
response = await response
except Exception as exc:
self.protocol.handshake_exc = exc
self.logger.error("opening handshake failed", exc_info=True)
response = self.protocol.reject(
http.HTTPStatus.INTERNAL_SERVER_ERROR,
(
"Failed to open a WebSocket connection.\n"
"See server log for more information.\n"
),
)
if process_response is not None:
try:
response = process_response(self, self.request, self.response)
if isinstance(response, Awaitable):
response = await response
except Exception as exc:
self.protocol.handshake_exc = exc
response = self.protocol.reject(
http.HTTPStatus.INTERNAL_SERVER_ERROR,
(
"Failed to open a WebSocket connection.\n"
"See server log for more information.\n"
),
)
if response is not None:
assert isinstance(response, Response) # help mypy
self.response = response
if response is not None:
assert isinstance(response, Response) # help mypy
self.response = response
self.protocol.send_response(self.response)
self.protocol.send_response(self.response)
if self.protocol.handshake_exc is None:
self.start_keepalive()
else:
try:
async with asyncio_timeout(self.close_timeout):
await self.connection_lost_waiter
finally:
raise self.protocol.handshake_exc
# self.protocol.handshake_exc is always set when the connection is lost
# before receiving a request, when the request cannot be parsed, when
# the handshake encounters an error, or when process_request or
# process_response sends an HTTP response that rejects the handshake.
if self.protocol.handshake_exc is not None:
raise self.protocol.handshake_exc
def process_event(self, event: Event) -> None:
"""
@ -214,14 +226,6 @@ class ServerConnection(Connection):
super().connection_made(transport)
self.server.start_connection_handler(self)
def connection_lost(self, exc: Exception | None) -> None:
try:
super().connection_lost(exc)
finally:
# If the connection is closed during the handshake, unblock it.
if not self.request_rcvd.done():
self.request_rcvd.set_result(None)
class Server:
"""
@ -298,6 +302,18 @@ class Server:
# Completed when the server is closed and connections are terminated.
self.closed_waiter: asyncio.Future[None] = self.loop.create_future()
@property
def connections(self) -> set[ServerConnection]:
"""
Set of active connections.
This property contains all connections that completed the opening
handshake successfully and didn't start the closing handshake yet.
It can be useful in combination with :func:`~broadcast`.
"""
return {connection for connection in self.handlers if connection.state is OPEN}
def wrap(self, server: asyncio.Server) -> None:
"""
Attach to a given :class:`asyncio.Server`.
@ -337,25 +353,37 @@ class Server:
"""
try:
# On failure, handshake() closes the transport, raises an
# exception, and logs it.
async with asyncio_timeout(self.open_timeout):
await connection.handshake(
self.process_request,
self.process_response,
self.server_header,
)
try:
await connection.handshake(
self.process_request,
self.process_response,
self.server_header,
)
except asyncio.CancelledError:
connection.close_transport()
raise
except Exception:
connection.logger.error("opening handshake failed", exc_info=True)
connection.close_transport()
return
assert connection.protocol.state is OPEN
try:
connection.start_keepalive()
await self.handler(connection)
except Exception:
self.logger.error("connection handler failed", exc_info=True)
connection.logger.error("connection handler failed", exc_info=True)
await connection.close(CloseCode.INTERNAL_ERROR)
else:
await connection.close()
except Exception:
# Don't leak connections on errors.
except TimeoutError:
# When the opening handshake times out, there's nothing to log.
pass
except Exception: # pragma: no cover
# Don't leak connections on unexpected errors.
connection.transport.abort()
finally:
@ -548,19 +576,21 @@ class serve:
:class:`asyncio.Server`. Treat it as an asynchronous context manager to
ensure that the server will be closed::
from websockets.asyncio.server import serve
def handler(websocket):
...
# set this future to exit the server
stop = asyncio.get_running_loop().create_future()
async with websockets.asyncio.server.serve(handler, host, port):
async with serve(handler, host, port):
await stop
Alternatively, call :meth:`~Server.serve_forever` to serve requests and
cancel it to stop the server::
server = await websockets.asyncio.server.serve(handler, host, port)
server = await serve(handler, host, port)
await server.serve_forever()
Args:
@ -664,14 +694,14 @@ class serve:
process_request: (
Callable[
[ServerConnection, Request],
Response | None,
Awaitable[Response | None] | Response | None,
]
| None
) = None,
process_response: (
Callable[
[ServerConnection, Request, Response],
Response | None,
Awaitable[Response | None] | Response | None,
]
| None
) = None,
@ -693,7 +723,6 @@ class serve:
# Other keyword arguments are passed to loop.create_server
**kwargs: Any,
) -> None:
if subprotocols is not None:
validate_subprotocols(subprotocols)
@ -767,10 +796,10 @@ class serve:
loop = asyncio.get_running_loop()
if kwargs.pop("unix", False):
self._create_server = loop.create_unix_server(factory, **kwargs)
self.create_server = loop.create_unix_server(factory, **kwargs)
else:
# mypy cannot tell that kwargs must provide sock when port is None.
self._create_server = loop.create_server(factory, host, port, **kwargs) # type: ignore[arg-type]
self.create_server = loop.create_server(factory, host, port, **kwargs) # type: ignore[arg-type]
# async with serve(...) as ...: ...
@ -793,7 +822,7 @@ class serve:
return self.__await_impl__().__await__()
async def __await_impl__(self) -> Server:
server = await self._create_server
server = await self.create_server
self.server.wrap(server)
return self.server
@ -822,3 +851,123 @@ def unix_serve(
"""
return serve(handler, unix=True, path=path, **kwargs)
def is_credentials(credentials: Any) -> bool:
try:
username, password = credentials
except (TypeError, ValueError):
return False
else:
return isinstance(username, str) and isinstance(password, str)
def basic_auth(
realm: str = "",
credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None,
check_credentials: Callable[[str, str], Awaitable[bool] | bool] | None = None,
) -> Callable[[ServerConnection, Request], Awaitable[Response | None]]:
"""
Factory for ``process_request`` to enforce HTTP Basic Authentication.
:func:`basic_auth` is designed to integrate with :func:`serve` as follows::
from websockets.asyncio.server import basic_auth, serve
async with serve(
...,
process_request=basic_auth(
realm="my dev server",
credentials=("hello", "iloveyou"),
),
):
If authentication succeeds, the connection's ``username`` attribute is set.
If it fails, the server responds with an HTTP 401 Unauthorized status.
One of ``credentials`` or ``check_credentials`` must be provided; not both.
Args:
realm: Scope of protection. It should contain only ASCII characters
because the encoding of non-ASCII characters is undefined. Refer to
section 2.2 of :rfc:`7235` for details.
credentials: Hard coded authorized credentials. It can be a
``(username, password)`` pair or a list of such pairs.
check_credentials: Function or coroutine that verifies credentials.
It receives ``username`` and ``password`` arguments and returns
whether they're valid.
Raises:
TypeError: If ``credentials`` or ``check_credentials`` is wrong.
"""
if (credentials is None) == (check_credentials is None):
raise TypeError("provide either credentials or check_credentials")
if credentials is not None:
if is_credentials(credentials):
credentials_list = [cast(Tuple[str, str], credentials)]
elif isinstance(credentials, Iterable):
credentials_list = list(cast(Iterable[Tuple[str, str]], credentials))
if not all(is_credentials(item) for item in credentials_list):
raise TypeError(f"invalid credentials argument: {credentials}")
else:
raise TypeError(f"invalid credentials argument: {credentials}")
credentials_dict = dict(credentials_list)
def check_credentials(username: str, password: str) -> bool:
try:
expected_password = credentials_dict[username]
except KeyError:
return False
return hmac.compare_digest(expected_password, password)
assert check_credentials is not None # help mypy
async def process_request(
connection: ServerConnection,
request: Request,
) -> Response | None:
"""
Perform HTTP Basic Authentication.
If it succeeds, set the connection's ``username`` attribute and return
:obj:`None`. If it fails, return an HTTP 401 Unauthorized responss.
"""
try:
authorization = request.headers["Authorization"]
except KeyError:
response = connection.respond(
http.HTTPStatus.UNAUTHORIZED,
"Missing credentials\n",
)
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
return response
try:
username, password = parse_authorization_basic(authorization)
except InvalidHeader:
response = connection.respond(
http.HTTPStatus.UNAUTHORIZED,
"Unsupported credentials\n",
)
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
return response
valid_credentials = check_credentials(username, password)
if isinstance(valid_credentials, Awaitable):
valid_credentials = await valid_credentials
if not valid_credentials:
response = connection.respond(
http.HTTPStatus.UNAUTHORIZED,
"Invalid credentials\n",
)
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
return response
connection.username = username
return None
return process_request

View File

@ -1,5 +1,7 @@
from __future__ import annotations
import os
import random
import warnings
from typing import Any, Generator, Sequence
@ -173,13 +175,10 @@ class ClientProtocol(Protocol):
try:
s_w_accept = headers["Sec-WebSocket-Accept"]
except KeyError as exc:
raise InvalidHeader("Sec-WebSocket-Accept") from exc
except MultipleValuesError as exc:
raise InvalidHeader(
"Sec-WebSocket-Accept",
"more than one Sec-WebSocket-Accept header found",
) from exc
except KeyError:
raise InvalidHeader("Sec-WebSocket-Accept") from None
except MultipleValuesError:
raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from None
if s_w_accept != accept_key(self.key):
raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)
@ -225,7 +224,7 @@ class ClientProtocol(Protocol):
if extensions:
if self.available_extensions is None:
raise InvalidHandshake("no extensions supported")
raise NegotiationError("no extensions supported")
parsed_extensions: list[ExtensionHeader] = sum(
[parse_extension(header_value) for header_value in extensions], []
@ -280,15 +279,17 @@ class ClientProtocol(Protocol):
if subprotocols:
if self.available_subprotocols is None:
raise InvalidHandshake("no subprotocols supported")
raise NegotiationError("no subprotocols supported")
parsed_subprotocols: Sequence[Subprotocol] = sum(
[parse_subprotocol(header_value) for header_value in subprotocols], []
)
if len(parsed_subprotocols) > 1:
subprotocols_display = ", ".join(parsed_subprotocols)
raise InvalidHandshake(f"multiple subprotocols: {subprotocols_display}")
raise InvalidHeader(
"Sec-WebSocket-Protocol",
f"multiple values: {', '.join(parsed_subprotocols)}",
)
subprotocol = parsed_subprotocols[0]
@ -322,6 +323,7 @@ class ClientProtocol(Protocol):
)
except Exception as exc:
self.handshake_exc = exc
self.send_eof()
self.parser = self.discard()
next(self.parser) # start coroutine
yield
@ -340,6 +342,7 @@ class ClientProtocol(Protocol):
response._exception = exc
self.events.append(response)
self.handshake_exc = exc
self.send_eof()
self.parser = self.discard()
next(self.parser) # start coroutine
yield
@ -353,8 +356,38 @@ class ClientProtocol(Protocol):
class ClientConnection(ClientProtocol):
def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
warnings.warn( # deprecated in 11.0 - 2023-04-02
"ClientConnection was renamed to ClientProtocol",
DeprecationWarning,
)
super().__init__(*args, **kwargs)
BACKOFF_INITIAL_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5"))
BACKOFF_MIN_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1"))
BACKOFF_MAX_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0"))
BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618"))
def backoff(
initial_delay: float = BACKOFF_INITIAL_DELAY,
min_delay: float = BACKOFF_MIN_DELAY,
max_delay: float = BACKOFF_MAX_DELAY,
factor: float = BACKOFF_FACTOR,
) -> Generator[float, None, None]:
"""
Generate a series of backoff delays between reconnection attempts.
Yields:
How many seconds to wait before retrying to connect.
"""
# Add a random initial delay between 0 and 5 seconds.
# See 7.2.3. Recovering from Abnormal Closure in RFC 6455.
yield random.random() * initial_delay
delay = min_delay
while delay < max_delay:
yield delay
delay *= factor
while True:
yield max_delay

View File

@ -5,7 +5,7 @@ import warnings
from .protocol import SEND_EOF, Protocol as Connection, Side, State # noqa: F401
warnings.warn(
warnings.warn( # deprecated in 11.0 - 2023-04-02
"websockets.connection was renamed to websockets.protocol "
"and Connection was renamed to Protocol",
DeprecationWarning,

View File

@ -17,7 +17,7 @@ __all__ = ["Headers", "HeadersLike", "MultipleValuesError"]
class MultipleValuesError(LookupError):
"""
Exception raised when :class:`Headers` has more than one value for a key.
Exception raised when :class:`Headers` has multiple values for a key.
"""

View File

@ -1,67 +1,69 @@
"""
:mod:`websockets.exceptions` defines the following exception hierarchy:
:mod:`websockets.exceptions` defines the following hierarchy of exceptions.
* :exc:`WebSocketException`
* :exc:`ConnectionClosed`
* :exc:`ConnectionClosedError`
* :exc:`ConnectionClosedOK`
* :exc:`ConnectionClosedError`
* :exc:`InvalidURI`
* :exc:`InvalidHandshake`
* :exc:`SecurityError`
* :exc:`InvalidMessage`
* :exc:`InvalidMessage` (legacy)
* :exc:`InvalidStatus`
* :exc:`InvalidStatusCode` (legacy)
* :exc:`InvalidHeader`
* :exc:`InvalidHeaderFormat`
* :exc:`InvalidHeaderValue`
* :exc:`InvalidOrigin`
* :exc:`InvalidUpgrade`
* :exc:`InvalidStatus`
* :exc:`InvalidStatusCode` (legacy)
* :exc:`NegotiationError`
* :exc:`DuplicateParameter`
* :exc:`InvalidParameterName`
* :exc:`InvalidParameterValue`
* :exc:`AbortHandshake`
* :exc:`RedirectHandshake`
* :exc:`InvalidState`
* :exc:`InvalidURI`
* :exc:`PayloadTooBig`
* :exc:`ProtocolError`
* :exc:`AbortHandshake` (legacy)
* :exc:`RedirectHandshake` (legacy)
* :exc:`ProtocolError` (Sans-I/O)
* :exc:`PayloadTooBig` (Sans-I/O)
* :exc:`InvalidState` (Sans-I/O)
* :exc:`ConcurrencyError`
"""
from __future__ import annotations
import http
import typing
import warnings
from . import datastructures, frames, http11
from .typing import StatusLike
from .imports import lazy_import
__all__ = [
"WebSocketException",
"ConnectionClosed",
"ConnectionClosedError",
"ConnectionClosedOK",
"ConnectionClosedError",
"InvalidURI",
"InvalidHandshake",
"SecurityError",
"InvalidMessage",
"InvalidStatus",
"InvalidStatusCode",
"InvalidHeader",
"InvalidHeaderFormat",
"InvalidHeaderValue",
"InvalidOrigin",
"InvalidUpgrade",
"InvalidStatus",
"InvalidStatusCode",
"NegotiationError",
"DuplicateParameter",
"InvalidParameterName",
"InvalidParameterValue",
"AbortHandshake",
"RedirectHandshake",
"InvalidState",
"InvalidURI",
"PayloadTooBig",
"ProtocolError",
"WebSocketProtocolError",
"PayloadTooBig",
"InvalidState",
"ConcurrencyError",
]
@ -77,13 +79,13 @@ class ConnectionClosed(WebSocketException):
Raised when trying to interact with a closed connection.
Attributes:
rcvd (Close | None): if a close frame was received, its code and
reason are available in ``rcvd.code`` and ``rcvd.reason``.
sent (Close | None): if a close frame was sent, its code and reason
are available in ``sent.code`` and ``sent.reason``.
rcvd_then_sent (bool | None): if close frames were received and
sent, this attribute tells in which order this happened, from the
perspective of this side of the connection.
rcvd: If a close frame was received, its code and reason are available
in ``rcvd.code`` and ``rcvd.reason``.
sent: If a close frame was sent, its code and reason are available
in ``sent.code`` and ``sent.reason``.
rcvd_then_sent: If close frames were received and sent, this attribute
tells in which order this happened, from the perspective of this
side of the connection.
"""
@ -96,21 +98,18 @@ class ConnectionClosed(WebSocketException):
self.rcvd = rcvd
self.sent = sent
self.rcvd_then_sent = rcvd_then_sent
assert (self.rcvd_then_sent is None) == (self.rcvd is None or self.sent is None)
def __str__(self) -> str:
if self.rcvd is None:
if self.sent is None:
assert self.rcvd_then_sent is None
return "no close frame received or sent"
else:
assert self.rcvd_then_sent is None
return f"sent {self.sent}; no close frame received"
else:
if self.sent is None:
assert self.rcvd_then_sent is None
return f"received {self.rcvd}; no close frame sent"
else:
assert self.rcvd_then_sent is not None
if self.rcvd_then_sent:
return f"received {self.rcvd}; then sent {self.sent}"
else:
@ -120,27 +119,27 @@ class ConnectionClosed(WebSocketException):
@property
def code(self) -> int:
warnings.warn( # deprecated in 13.1
"ConnectionClosed.code is deprecated; "
"use Protocol.close_code or ConnectionClosed.rcvd.code",
DeprecationWarning,
)
if self.rcvd is None:
return frames.CloseCode.ABNORMAL_CLOSURE
return self.rcvd.code
@property
def reason(self) -> str:
warnings.warn( # deprecated in 13.1
"ConnectionClosed.reason is deprecated; "
"use Protocol.close_reason or ConnectionClosed.rcvd.reason",
DeprecationWarning,
)
if self.rcvd is None:
return ""
return self.rcvd.reason
class ConnectionClosedError(ConnectionClosed):
"""
Like :exc:`ConnectionClosed`, when the connection terminated with an error.
A close frame with a code other than 1000 (OK) or 1001 (going away) was
received or sent, or the closing handshake didn't complete properly.
"""
class ConnectionClosedOK(ConnectionClosed):
"""
Like :exc:`ConnectionClosed`, when the connection terminated properly.
@ -151,9 +150,33 @@ class ConnectionClosedOK(ConnectionClosed):
"""
class ConnectionClosedError(ConnectionClosed):
"""
Like :exc:`ConnectionClosed`, when the connection terminated with an error.
A close frame with a code other than 1000 (OK) or 1001 (going away) was
received or sent, or the closing handshake didn't complete properly.
"""
class InvalidURI(WebSocketException):
"""
Raised when connecting to a URI that isn't a valid WebSocket URI.
"""
def __init__(self, uri: str, msg: str) -> None:
self.uri = uri
self.msg = msg
def __str__(self) -> str:
return f"{self.uri} isn't a valid URI: {self.msg}"
class InvalidHandshake(WebSocketException):
"""
Raised during the handshake when the WebSocket connection fails.
Base class for exceptions raised when the opening handshake fails.
"""
@ -162,17 +185,27 @@ class SecurityError(InvalidHandshake):
"""
Raised when a handshake request or response breaks a security rule.
Security limits are hard coded.
Security limits can be configured with :doc:`environment variables
<../reference/variables>`.
"""
class InvalidMessage(InvalidHandshake):
class InvalidStatus(InvalidHandshake):
"""
Raised when a handshake request or response is malformed.
Raised when a handshake response rejects the WebSocket upgrade.
"""
def __init__(self, response: http11.Response) -> None:
self.response = response
def __str__(self) -> str:
return (
"server rejected WebSocket connection: "
f"HTTP {self.response.status_code:d}"
)
class InvalidHeader(InvalidHandshake):
"""
@ -209,7 +242,7 @@ class InvalidHeaderValue(InvalidHeader):
"""
Raised when an HTTP header has a wrong value.
The format of the header is correct but a value isn't acceptable.
The format of the header is correct but the value isn't acceptable.
"""
@ -231,39 +264,9 @@ class InvalidUpgrade(InvalidHeader):
"""
class InvalidStatus(InvalidHandshake):
"""
Raised when a handshake response rejects the WebSocket upgrade.
"""
def __init__(self, response: http11.Response) -> None:
self.response = response
def __str__(self) -> str:
return (
"server rejected WebSocket connection: "
f"HTTP {self.response.status_code:d}"
)
class InvalidStatusCode(InvalidHandshake):
"""
Raised when a handshake response status code is invalid.
"""
def __init__(self, status_code: int, headers: datastructures.Headers) -> None:
self.status_code = status_code
self.headers = headers
def __str__(self) -> str:
return f"server rejected WebSocket connection: HTTP {self.status_code}"
class NegotiationError(InvalidHandshake):
"""
Raised when negotiating an extension fails.
Raised when negotiating an extension or a subprotocol fails.
"""
@ -313,92 +316,77 @@ class InvalidParameterValue(NegotiationError):
return f"invalid value for parameter {self.name}: {self.value}"
class AbortHandshake(InvalidHandshake):
class ProtocolError(WebSocketException):
"""
Raised to abort the handshake on purpose and return an HTTP response.
Raised when receiving or sending a frame that breaks the protocol.
This exception is an implementation detail.
The Sans-I/O implementation raises this exception when:
The public API is
:meth:`~websockets.legacy.server.WebSocketServerProtocol.process_request`.
Attributes:
status (~http.HTTPStatus): HTTP status code.
headers (Headers): HTTP response headers.
body (bytes): HTTP response body.
"""
def __init__(
self,
status: StatusLike,
headers: datastructures.HeadersLike,
body: bytes = b"",
) -> None:
# If a user passes an int instead of a HTTPStatus, fix it automatically.
self.status = http.HTTPStatus(status)
self.headers = datastructures.Headers(headers)
self.body = body
def __str__(self) -> str:
return (
f"HTTP {self.status:d}, "
f"{len(self.headers)} headers, "
f"{len(self.body)} bytes"
)
class RedirectHandshake(InvalidHandshake):
"""
Raised when a handshake gets redirected.
This exception is an implementation detail.
* receiving or sending a frame that contains invalid data;
* receiving or sending an invalid sequence of frames.
"""
def __init__(self, uri: str) -> None:
self.uri = uri
def __str__(self) -> str:
return f"redirect to {self.uri}"
class InvalidState(WebSocketException, AssertionError):
"""
Raised when an operation is forbidden in the current state.
This exception is an implementation detail.
It should never be raised in normal circumstances.
"""
class InvalidURI(WebSocketException):
"""
Raised when connecting to a URI that isn't a valid WebSocket URI.
"""
def __init__(self, uri: str, msg: str) -> None:
self.uri = uri
self.msg = msg
def __str__(self) -> str:
return f"{self.uri} isn't a valid URI: {self.msg}"
class PayloadTooBig(WebSocketException):
"""
Raised when receiving a frame with a payload exceeding the maximum size.
Raised when parsing a frame with a payload that exceeds the maximum size.
The Sans-I/O layer uses this exception internally. It doesn't bubble up to
the I/O layer.
The :meth:`~websockets.extensions.Extension.decode` method of extensions
must raise :exc:`PayloadTooBig` if decoding a frame would exceed the limit.
"""
class ProtocolError(WebSocketException):
class InvalidState(WebSocketException, AssertionError):
"""
Raised when a frame breaks the protocol.
Raised when sending a frame is forbidden in the current state.
Specifically, the Sans-I/O layer raises this exception when:
* sending a data frame to a connection in a state other
:attr:`~websockets.protocol.State.OPEN`;
* sending a control frame to a connection in a state other than
:attr:`~websockets.protocol.State.OPEN` or
:attr:`~websockets.protocol.State.CLOSING`.
"""
WebSocketProtocolError = ProtocolError # for backwards compatibility
class ConcurrencyError(WebSocketException, RuntimeError):
"""
Raised when receiving or sending messages concurrently.
WebSocket is a connection-oriented protocol. Reads must be serialized; so
must be writes. However, reading and writing concurrently is possible.
"""
# When type checking, import non-deprecated aliases eagerly. Else, import on demand.
if typing.TYPE_CHECKING:
from .legacy.exceptions import (
AbortHandshake,
InvalidMessage,
InvalidStatusCode,
RedirectHandshake,
)
WebSocketProtocolError = ProtocolError
else:
lazy_import(
globals(),
aliases={
"AbortHandshake": ".legacy.exceptions",
"InvalidMessage": ".legacy.exceptions",
"InvalidStatusCode": ".legacy.exceptions",
"RedirectHandshake": ".legacy.exceptions",
"WebSocketProtocolError": ".legacy.exceptions",
},
)
# At the bottom to break import cycles created by type annotations.
from . import frames, http11 # noqa: E402

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Sequence
from .. import frames
from ..frames import Frame
from ..typing import ExtensionName, ExtensionParameter
@ -18,12 +18,7 @@ class Extension:
name: ExtensionName
"""Extension identifier."""
def decode(
self,
frame: frames.Frame,
*,
max_size: int | None = None,
) -> frames.Frame:
def decode(self, frame: Frame, *, max_size: int | None = None) -> Frame:
"""
Decode an incoming frame.
@ -40,7 +35,7 @@ class Extension:
"""
raise NotImplementedError
def encode(self, frame: frames.Frame) -> frames.Frame:
def encode(self, frame: Frame) -> Frame:
"""
Encode an outgoing frame.

View File

@ -4,7 +4,15 @@ import dataclasses
import zlib
from typing import Any, Sequence
from .. import exceptions, frames
from .. import frames
from ..exceptions import (
DuplicateParameter,
InvalidParameterName,
InvalidParameterValue,
NegotiationError,
PayloadTooBig,
ProtocolError,
)
from ..typing import ExtensionName, ExtensionParameter
from .base import ClientExtensionFactory, Extension, ServerExtensionFactory
@ -129,9 +137,9 @@ class PerMessageDeflate(Extension):
try:
data = self.decoder.decompress(data, max_length)
except zlib.error as exc:
raise exceptions.ProtocolError("decompression failed") from exc
raise ProtocolError("decompression failed") from exc
if self.decoder.unconsumed_tail:
raise exceptions.PayloadTooBig(f"over size limit (? > {max_size} bytes)")
raise PayloadTooBig(f"over size limit (? > {max_size} bytes)")
# Allow garbage collection of the decoder if it won't be reused.
if frame.fin and self.remote_no_context_takeover:
@ -215,40 +223,40 @@ def _extract_parameters(
for name, value in params:
if name == "server_no_context_takeover":
if server_no_context_takeover:
raise exceptions.DuplicateParameter(name)
raise DuplicateParameter(name)
if value is None:
server_no_context_takeover = True
else:
raise exceptions.InvalidParameterValue(name, value)
raise InvalidParameterValue(name, value)
elif name == "client_no_context_takeover":
if client_no_context_takeover:
raise exceptions.DuplicateParameter(name)
raise DuplicateParameter(name)
if value is None:
client_no_context_takeover = True
else:
raise exceptions.InvalidParameterValue(name, value)
raise InvalidParameterValue(name, value)
elif name == "server_max_window_bits":
if server_max_window_bits is not None:
raise exceptions.DuplicateParameter(name)
raise DuplicateParameter(name)
if value in _MAX_WINDOW_BITS_VALUES:
server_max_window_bits = int(value)
else:
raise exceptions.InvalidParameterValue(name, value)
raise InvalidParameterValue(name, value)
elif name == "client_max_window_bits":
if client_max_window_bits is not None:
raise exceptions.DuplicateParameter(name)
raise DuplicateParameter(name)
if is_server and value is None: # only in handshake requests
client_max_window_bits = True
elif value in _MAX_WINDOW_BITS_VALUES:
client_max_window_bits = int(value)
else:
raise exceptions.InvalidParameterValue(name, value)
raise InvalidParameterValue(name, value)
else:
raise exceptions.InvalidParameterName(name)
raise InvalidParameterName(name)
return (
server_no_context_takeover,
@ -340,7 +348,7 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory):
"""
if any(other.name == self.name for other in accepted_extensions):
raise exceptions.NegotiationError(f"received duplicate {self.name}")
raise NegotiationError(f"received duplicate {self.name}")
# Request parameters are available in instance variables.
@ -366,7 +374,7 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory):
if self.server_no_context_takeover:
if not server_no_context_takeover:
raise exceptions.NegotiationError("expected server_no_context_takeover")
raise NegotiationError("expected server_no_context_takeover")
# client_no_context_takeover
#
@ -396,9 +404,9 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory):
else:
if server_max_window_bits is None:
raise exceptions.NegotiationError("expected server_max_window_bits")
raise NegotiationError("expected server_max_window_bits")
elif server_max_window_bits > self.server_max_window_bits:
raise exceptions.NegotiationError("unsupported server_max_window_bits")
raise NegotiationError("unsupported server_max_window_bits")
# client_max_window_bits
@ -414,7 +422,7 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory):
if self.client_max_window_bits is None:
if client_max_window_bits is not None:
raise exceptions.NegotiationError("unexpected client_max_window_bits")
raise NegotiationError("unexpected client_max_window_bits")
elif self.client_max_window_bits is True:
pass
@ -423,7 +431,7 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory):
if client_max_window_bits is None:
client_max_window_bits = self.client_max_window_bits
elif client_max_window_bits > self.client_max_window_bits:
raise exceptions.NegotiationError("unsupported client_max_window_bits")
raise NegotiationError("unsupported client_max_window_bits")
return PerMessageDeflate(
server_no_context_takeover, # remote_no_context_takeover
@ -534,7 +542,7 @@ class ServerPerMessageDeflateFactory(ServerExtensionFactory):
"""
if any(other.name == self.name for other in accepted_extensions):
raise exceptions.NegotiationError(f"skipped duplicate {self.name}")
raise NegotiationError(f"skipped duplicate {self.name}")
# Load request parameters in local variables.
(
@ -613,7 +621,7 @@ class ServerPerMessageDeflateFactory(ServerExtensionFactory):
else:
if client_max_window_bits is None:
if self.require_client_max_window_bits:
raise exceptions.NegotiationError("required client_max_window_bits")
raise NegotiationError("required client_max_window_bits")
elif client_max_window_bits is True:
client_max_window_bits = self.client_max_window_bits
elif self.client_max_window_bits < client_max_window_bits:

View File

@ -8,8 +8,7 @@ import secrets
import struct
from typing import Callable, Generator, Sequence
from . import exceptions, extensions
from .typing import Data
from .exceptions import PayloadTooBig, ProtocolError
try:
@ -29,8 +28,6 @@ __all__ = [
"DATA_OPCODES",
"CTRL_OPCODES",
"Frame",
"prepare_data",
"prepare_ctrl",
"Close",
]
@ -242,10 +239,10 @@ class Frame:
try:
opcode = Opcode(head1 & 0b00001111)
except ValueError as exc:
raise exceptions.ProtocolError("invalid opcode") from exc
raise ProtocolError("invalid opcode") from exc
if (True if head2 & 0b10000000 else False) != mask:
raise exceptions.ProtocolError("incorrect masking")
raise ProtocolError("incorrect masking")
length = head2 & 0b01111111
if length == 126:
@ -255,9 +252,7 @@ class Frame:
data = yield from read_exact(8)
(length,) = struct.unpack("!Q", data)
if max_size is not None and length > max_size:
raise exceptions.PayloadTooBig(
f"over size limit ({length} > {max_size} bytes)"
)
raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)")
if mask:
mask_bytes = yield from read_exact(4)
@ -345,60 +340,13 @@ class Frame:
"""
if self.rsv1 or self.rsv2 or self.rsv3:
raise exceptions.ProtocolError("reserved bits must be 0")
raise ProtocolError("reserved bits must be 0")
if self.opcode in CTRL_OPCODES:
if len(self.data) > 125:
raise exceptions.ProtocolError("control frame too long")
raise ProtocolError("control frame too long")
if not self.fin:
raise exceptions.ProtocolError("fragmented control frame")
def prepare_data(data: Data) -> tuple[int, bytes]:
"""
Convert a string or byte-like object to an opcode and a bytes-like object.
This function is designed for data frames.
If ``data`` is a :class:`str`, return ``OP_TEXT`` and a :class:`bytes`
object encoding ``data`` in UTF-8.
If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like
object.
Raises:
TypeError: If ``data`` doesn't have a supported type.
"""
if isinstance(data, str):
return OP_TEXT, data.encode()
elif isinstance(data, BytesLike):
return OP_BINARY, data
else:
raise TypeError("data must be str or bytes-like")
def prepare_ctrl(data: Data) -> bytes:
"""
Convert a string or byte-like object to bytes.
This function is designed for ping and pong frames.
If ``data`` is a :class:`str`, return a :class:`bytes` object encoding
``data`` in UTF-8.
If ``data`` is a bytes-like object, return a :class:`bytes` object.
Raises:
TypeError: If ``data`` doesn't have a supported type.
"""
if isinstance(data, str):
return data.encode()
elif isinstance(data, BytesLike):
return bytes(data)
else:
raise TypeError("data must be str or bytes-like")
raise ProtocolError("fragmented control frame")
@dataclasses.dataclass
@ -455,7 +403,7 @@ class Close:
elif len(data) == 0:
return cls(CloseCode.NO_STATUS_RCVD, "")
else:
raise exceptions.ProtocolError("close frame too short")
raise ProtocolError("close frame too short")
def serialize(self) -> bytes:
"""
@ -474,4 +422,8 @@ class Close:
"""
if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000):
raise exceptions.ProtocolError("invalid status code")
raise ProtocolError("invalid status code")
# At the bottom to break import cycles created by type annotations.
from . import extensions # noqa: E402

View File

@ -6,7 +6,7 @@ import ipaddress
import re
from typing import Callable, Sequence, TypeVar, cast
from . import exceptions
from .exceptions import InvalidHeaderFormat, InvalidHeaderValue
from .typing import (
ConnectionOption,
ExtensionHeader,
@ -108,7 +108,7 @@ def parse_token(header: str, pos: int, header_name: str) -> tuple[str, int]:
"""
match = _token_re.match(header, pos)
if match is None:
raise exceptions.InvalidHeaderFormat(header_name, "expected token", header, pos)
raise InvalidHeaderFormat(header_name, "expected token", header, pos)
return match.group(), match.end()
@ -132,9 +132,7 @@ def parse_quoted_string(header: str, pos: int, header_name: str) -> tuple[str, i
"""
match = _quoted_string_re.match(header, pos)
if match is None:
raise exceptions.InvalidHeaderFormat(
header_name, "expected quoted string", header, pos
)
raise InvalidHeaderFormat(header_name, "expected quoted string", header, pos)
return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end()
@ -206,9 +204,7 @@ def parse_list(
if peek_ahead(header, pos) == ",":
pos = parse_OWS(header, pos + 1)
else:
raise exceptions.InvalidHeaderFormat(
header_name, "expected comma", header, pos
)
raise InvalidHeaderFormat(header_name, "expected comma", header, pos)
# Remove extra delimiters before the next item.
while peek_ahead(header, pos) == ",":
@ -276,9 +272,7 @@ def parse_upgrade_protocol(
"""
match = _protocol_re.match(header, pos)
if match is None:
raise exceptions.InvalidHeaderFormat(
header_name, "expected protocol", header, pos
)
raise InvalidHeaderFormat(header_name, "expected protocol", header, pos)
return cast(UpgradeProtocol, match.group()), match.end()
@ -324,7 +318,7 @@ def parse_extension_item_param(
# the value after quoted-string unescaping MUST conform to
# the 'token' ABNF.
if _token_re.fullmatch(value) is None:
raise exceptions.InvalidHeaderFormat(
raise InvalidHeaderFormat(
header_name, "invalid quoted header content", header, pos_before
)
else:
@ -510,9 +504,7 @@ def parse_token68(header: str, pos: int, header_name: str) -> tuple[str, int]:
"""
match = _token68_re.match(header, pos)
if match is None:
raise exceptions.InvalidHeaderFormat(
header_name, "expected token68", header, pos
)
raise InvalidHeaderFormat(header_name, "expected token68", header, pos)
return match.group(), match.end()
@ -522,7 +514,7 @@ def parse_end(header: str, pos: int, header_name: str) -> None:
"""
if pos < len(header):
raise exceptions.InvalidHeaderFormat(header_name, "trailing data", header, pos)
raise InvalidHeaderFormat(header_name, "trailing data", header, pos)
def parse_authorization_basic(header: str) -> tuple[str, str]:
@ -543,12 +535,12 @@ def parse_authorization_basic(header: str) -> tuple[str, str]:
# https://datatracker.ietf.org/doc/html/rfc7617#section-2
scheme, pos = parse_token(header, 0, "Authorization")
if scheme.lower() != "basic":
raise exceptions.InvalidHeaderValue(
raise InvalidHeaderValue(
"Authorization",
f"unsupported scheme: {scheme}",
)
if peek_ahead(header, pos) != " ":
raise exceptions.InvalidHeaderFormat(
raise InvalidHeaderFormat(
"Authorization", "expected space after scheme", header, pos
)
pos += 1
@ -558,14 +550,14 @@ def parse_authorization_basic(header: str) -> tuple[str, str]:
try:
user_pass = base64.b64decode(basic_credentials.encode()).decode()
except binascii.Error:
raise exceptions.InvalidHeaderValue(
raise InvalidHeaderValue(
"Authorization",
"expected base64-encoded credentials",
) from None
try:
username, password = user_pass.split(":", 1)
except ValueError:
raise exceptions.InvalidHeaderValue(
raise InvalidHeaderValue(
"Authorization",
"expected username:password credentials",
) from None

View File

@ -6,7 +6,7 @@ from .datastructures import Headers, MultipleValuesError # noqa: F401
from .legacy.http import read_request, read_response # noqa: F401
warnings.warn(
warnings.warn( # deprecated in 9.0 - 2021-09-01
"Headers and MultipleValuesError were moved "
"from websockets.http to websockets.datastructures"
"and read_request and read_response were moved "

View File

@ -7,7 +7,8 @@ import sys
import warnings
from typing import Callable, Generator
from . import datastructures, exceptions
from .datastructures import Headers
from .exceptions import SecurityError
from .version import version as websockets_version
@ -79,14 +80,14 @@ class Request:
"""
path: str
headers: datastructures.Headers
headers: Headers
# body isn't useful is the context of this library.
_exception: Exception | None = None
@property
def exception(self) -> Exception | None: # pragma: no cover
warnings.warn(
warnings.warn( # deprecated in 10.3 - 2022-04-17
"Request.exception is deprecated; "
"use ServerProtocol.handshake_exc instead",
DeprecationWarning,
@ -134,14 +135,15 @@ class Request:
raise EOFError("connection closed while reading HTTP request line") from exc
try:
method, raw_path, version = request_line.split(b" ", 2)
method, raw_path, protocol = request_line.split(b" ", 2)
except ValueError: # not enough values to unpack (expected 3, got 1-2)
raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None
if protocol != b"HTTP/1.1":
raise ValueError(
f"unsupported protocol; expected HTTP/1.1: {d(request_line)}"
)
if method != b"GET":
raise ValueError(f"unsupported HTTP method: {d(method)}")
if version != b"HTTP/1.1":
raise ValueError(f"unsupported HTTP version: {d(version)}")
raise ValueError(f"unsupported HTTP method; expected GET; got {d(method)}")
path = raw_path.decode("ascii", "surrogateescape")
headers = yield from parse_headers(read_line)
@ -183,14 +185,14 @@ class Response:
status_code: int
reason_phrase: str
headers: datastructures.Headers
headers: Headers
body: bytes | None = None
_exception: Exception | None = None
@property
def exception(self) -> Exception | None: # pragma: no cover
warnings.warn(
warnings.warn( # deprecated in 10.3 - 2022-04-17
"Response.exception is deprecated; "
"use ClientProtocol.handshake_exc instead",
DeprecationWarning,
@ -235,23 +237,26 @@ class Response:
raise EOFError("connection closed while reading HTTP status line") from exc
try:
version, raw_status_code, raw_reason = status_line.split(b" ", 2)
protocol, raw_status_code, raw_reason = status_line.split(b" ", 2)
except ValueError: # not enough values to unpack (expected 3, got 1-2)
raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None
if version != b"HTTP/1.1":
raise ValueError(f"unsupported HTTP version: {d(version)}")
if protocol != b"HTTP/1.1":
raise ValueError(
f"unsupported protocol; expected HTTP/1.1: {d(status_line)}"
)
try:
status_code = int(raw_status_code)
except ValueError: # invalid literal for int() with base 10
raise ValueError(
f"invalid HTTP status code: {d(raw_status_code)}"
f"invalid status code; expected integer; got {d(raw_status_code)}"
) from None
if not 100 <= status_code < 1000:
raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}")
if not 100 <= status_code < 600:
raise ValueError(
f"invalid status code; expected 100599; got {d(raw_status_code)}"
)
if not _value_re.fullmatch(raw_reason):
raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}")
reason = raw_reason.decode()
reason = raw_reason.decode("ascii", "surrogateescape")
headers = yield from parse_headers(read_line)
@ -280,13 +285,9 @@ class Response:
try:
body = yield from read_to_eof(MAX_BODY_SIZE)
except RuntimeError:
raise exceptions.SecurityError(
f"body too large: over {MAX_BODY_SIZE} bytes"
)
raise SecurityError(f"body too large: over {MAX_BODY_SIZE} bytes")
elif content_length > MAX_BODY_SIZE:
raise exceptions.SecurityError(
f"body too large: {content_length} bytes"
)
raise SecurityError(f"body too large: {content_length} bytes")
else:
body = yield from read_exact(content_length)
@ -308,7 +309,7 @@ class Response:
def parse_headers(
read_line: Callable[[int], Generator[None, None, bytes]],
) -> Generator[None, None, datastructures.Headers]:
) -> Generator[None, None, Headers]:
"""
Parse HTTP headers.
@ -328,7 +329,7 @@ def parse_headers(
# We don't attempt to support obsolete line folding.
headers = datastructures.Headers()
headers = Headers()
for _ in range(MAX_NUM_HEADERS + 1):
try:
line = yield from parse_line(read_line)
@ -352,7 +353,7 @@ def parse_headers(
headers[name] = value
else:
raise exceptions.SecurityError("too many HTTP headers")
raise SecurityError("too many HTTP headers")
return headers
@ -377,7 +378,7 @@ def parse_line(
try:
line = yield from read_line(MAX_LINE_LENGTH)
except RuntimeError:
raise exceptions.SecurityError("line too long")
raise SecurityError("line too long")
# Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5
if not line.endswith(b"\r\n"):
raise EOFError("line without CRLF")

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
import functools
import logging
import os
import random
import urllib.parse
import warnings
@ -19,12 +20,9 @@ from typing import (
from ..asyncio.compatibility import asyncio_timeout
from ..datastructures import Headers, HeadersLike
from ..exceptions import (
InvalidHandshake,
InvalidHeader,
InvalidMessage,
InvalidStatusCode,
InvalidHeaderValue,
NegotiationError,
RedirectHandshake,
SecurityError,
)
from ..extensions import ClientExtensionFactory, Extension
@ -41,6 +39,7 @@ from ..headers import (
from ..http11 import USER_AGENT
from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol
from ..uri import WebSocketURI, parse_uri
from .exceptions import InvalidMessage, InvalidStatusCode, RedirectHandshake
from .handshake import build_request, check_response
from .http import read_response
from .protocol import WebSocketCommonProtocol
@ -182,7 +181,7 @@ class WebSocketClientProtocol(WebSocketCommonProtocol):
if header_values:
if available_extensions is None:
raise InvalidHandshake("no extensions supported")
raise NegotiationError("no extensions supported")
parsed_header_values: list[ExtensionHeader] = sum(
[parse_extension(header_value) for header_value in header_values], []
@ -236,15 +235,17 @@ class WebSocketClientProtocol(WebSocketCommonProtocol):
if header_values:
if available_subprotocols is None:
raise InvalidHandshake("no subprotocols supported")
raise NegotiationError("no subprotocols supported")
parsed_header_values: Sequence[Subprotocol] = sum(
[parse_subprotocol(header_value) for header_value in header_values], []
)
if len(parsed_header_values) > 1:
subprotocols = ", ".join(parsed_header_values)
raise InvalidHandshake(f"multiple subprotocols: {subprotocols}")
raise InvalidHeaderValue(
"Sec-WebSocket-Protocol",
f"multiple values: {', '.join(parsed_header_values)}",
)
subprotocol = parsed_header_values[0]
@ -417,7 +418,7 @@ class Connect:
"""
MAX_REDIRECTS_ALLOWED = 10
MAX_REDIRECTS_ALLOWED = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10"))
def __init__(
self,
@ -591,13 +592,13 @@ class Connect:
# async for ... in connect(...):
BACKOFF_MIN = 1.92
BACKOFF_MAX = 60.0
BACKOFF_FACTOR = 1.618
BACKOFF_INITIAL = 5
BACKOFF_INITIAL = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5"))
BACKOFF_MIN = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1"))
BACKOFF_MAX = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0"))
BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618"))
async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]:
backoff_delay = self.BACKOFF_MIN
backoff_delay = self.BACKOFF_MIN / self.BACKOFF_FACTOR
while True:
try:
async with self as protocol:

View File

@ -0,0 +1,78 @@
import http
from .. import datastructures
from ..exceptions import (
InvalidHandshake,
ProtocolError as WebSocketProtocolError, # noqa: F401
)
from ..typing import StatusLike
class InvalidMessage(InvalidHandshake):
"""
Raised when a handshake request or response is malformed.
"""
class InvalidStatusCode(InvalidHandshake):
"""
Raised when a handshake response status code is invalid.
"""
def __init__(self, status_code: int, headers: datastructures.Headers) -> None:
self.status_code = status_code
self.headers = headers
def __str__(self) -> str:
return f"server rejected WebSocket connection: HTTP {self.status_code}"
class AbortHandshake(InvalidHandshake):
"""
Raised to abort the handshake on purpose and return an HTTP response.
This exception is an implementation detail.
The public API is
:meth:`~websockets.legacy.server.WebSocketServerProtocol.process_request`.
Attributes:
status (~http.HTTPStatus): HTTP status code.
headers (Headers): HTTP response headers.
body (bytes): HTTP response body.
"""
def __init__(
self,
status: StatusLike,
headers: datastructures.HeadersLike,
body: bytes = b"",
) -> None:
# If a user passes an int instead of a HTTPStatus, fix it automatically.
self.status = http.HTTPStatus(status)
self.headers = datastructures.Headers(headers)
self.body = body
def __str__(self) -> str:
return (
f"HTTP {self.status:d}, "
f"{len(self.headers)} headers, "
f"{len(self.body)} bytes"
)
class RedirectHandshake(InvalidHandshake):
"""
Raised when a handshake gets redirected.
This exception is an implementation detail.
"""
def __init__(self, uri: str) -> None:
self.uri = uri
def __str__(self) -> str:
return f"redirect to {self.uri}"

View File

@ -5,6 +5,8 @@ from typing import Any, Awaitable, Callable, NamedTuple, Sequence
from .. import extensions, frames
from ..exceptions import PayloadTooBig, ProtocolError
from ..frames import BytesLike
from ..typing import Data
try:
@ -144,12 +146,58 @@ class Frame(NamedTuple):
write(self.new_frame.serialize(mask=mask, extensions=extensions))
def prepare_data(data: Data) -> tuple[int, bytes]:
"""
Convert a string or byte-like object to an opcode and a bytes-like object.
This function is designed for data frames.
If ``data`` is a :class:`str`, return ``OP_TEXT`` and a :class:`bytes`
object encoding ``data`` in UTF-8.
If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like
object.
Raises:
TypeError: If ``data`` doesn't have a supported type.
"""
if isinstance(data, str):
return frames.Opcode.TEXT, data.encode()
elif isinstance(data, BytesLike):
return frames.Opcode.BINARY, data
else:
raise TypeError("data must be str or bytes-like")
def prepare_ctrl(data: Data) -> bytes:
"""
Convert a string or byte-like object to bytes.
This function is designed for ping and pong frames.
If ``data`` is a :class:`str`, return a :class:`bytes` object encoding
``data`` in UTF-8.
If ``data`` is a bytes-like object, return a :class:`bytes` object.
Raises:
TypeError: If ``data`` doesn't have a supported type.
"""
if isinstance(data, str):
return data.encode()
elif isinstance(data, BytesLike):
return bytes(data)
else:
raise TypeError("data must be str or bytes-like")
# Backwards compatibility with previously documented public APIs
from ..frames import ( # noqa: E402, F401, I001
Close,
prepare_ctrl as encode_data,
prepare_data,
)
encode_data = prepare_ctrl
# Backwards compatibility with previously documented public APIs
from ..frames import Close # noqa: E402 F401, I001
def parse_close(data: bytes) -> tuple[int, str]:

View File

@ -76,9 +76,7 @@ def check_request(headers: Headers) -> str:
except KeyError as exc:
raise InvalidHeader("Sec-WebSocket-Key") from exc
except MultipleValuesError as exc:
raise InvalidHeader(
"Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found"
) from exc
raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc
try:
raw_key = base64.b64decode(s_w_key.encode(), validate=True)
@ -92,9 +90,7 @@ def check_request(headers: Headers) -> str:
except KeyError as exc:
raise InvalidHeader("Sec-WebSocket-Version") from exc
except MultipleValuesError as exc:
raise InvalidHeader(
"Sec-WebSocket-Version", "more than one Sec-WebSocket-Version header found"
) from exc
raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc
if s_w_version != "13":
raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version)
@ -156,9 +152,7 @@ def check_response(headers: Headers, key: str) -> None:
except KeyError as exc:
raise InvalidHeader("Sec-WebSocket-Accept") from exc
except MultipleValuesError as exc:
raise InvalidHeader(
"Sec-WebSocket-Accept", "more than one Sec-WebSocket-Accept header found"
) from exc
raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc
if s_w_accept != accept(key):
raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)

View File

@ -45,12 +45,10 @@ from ..frames import (
Close,
CloseCode,
Opcode,
prepare_ctrl,
prepare_data,
)
from ..protocol import State
from ..typing import Data, LoggerLike, Subprotocol
from .framing import Frame
from .framing import Frame, prepare_ctrl, prepare_data
__all__ = ["WebSocketCommonProtocol"]

View File

@ -24,10 +24,8 @@ from typing import (
from ..asyncio.compatibility import asyncio_timeout
from ..datastructures import Headers, HeadersLike, MultipleValuesError
from ..exceptions import (
AbortHandshake,
InvalidHandshake,
InvalidHeader,
InvalidMessage,
InvalidOrigin,
InvalidUpgrade,
NegotiationError,
@ -43,6 +41,7 @@ from ..headers import (
from ..http11 import SERVER
from ..protocol import State
from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol
from .exceptions import AbortHandshake, InvalidMessage
from .handshake import build_response, check_request
from .http import read_request
from .protocol import WebSocketCommonProtocol, broadcast
@ -101,9 +100,10 @@ class WebSocketServerProtocol(WebSocketCommonProtocol):
def __init__(
self,
# The version that accepts the path in the second argument is deprecated.
ws_handler: (
Callable[[WebSocketServerProtocol], Awaitable[Any]]
| Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated
| Callable[[WebSocketServerProtocol, str], Awaitable[Any]]
),
ws_server: WebSocketServer,
*,
@ -398,7 +398,7 @@ class WebSocketServerProtocol(WebSocketCommonProtocol):
try:
origin = headers.get("Origin")
except MultipleValuesError as exc:
raise InvalidHeader("Origin", "more than one Origin header found") from exc
raise InvalidHeader("Origin", "multiple values") from exc
if origin is not None:
origin = cast(Origin, origin)
if origins is not None:
@ -984,9 +984,10 @@ class Serve:
def __init__(
self,
# The version that accepts the path in the second argument is deprecated.
ws_handler: (
Callable[[WebSocketServerProtocol], Awaitable[Any]]
| Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated
| Callable[[WebSocketServerProtocol, str], Awaitable[Any]]
),
host: str | Sequence[str] | None = None,
port: int | None = None,
@ -1141,9 +1142,10 @@ serve = Serve
def unix_serve(
# The version that accepts the path in the second argument is deprecated.
ws_handler: (
Callable[[WebSocketServerProtocol], Awaitable[Any]]
| Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated
| Callable[[WebSocketServerProtocol, str], Awaitable[Any]]
),
path: str | None = None,
**kwargs: Any,
@ -1170,7 +1172,7 @@ def remove_path_argument(
ws_handler: (
Callable[[WebSocketServerProtocol], Awaitable[Any]]
| Callable[[WebSocketServerProtocol, str], Awaitable[Any]]
)
),
) -> Callable[[WebSocketServerProtocol], Awaitable[Any]]:
try:
inspect.signature(ws_handler).bind(None)

View File

@ -378,9 +378,11 @@ class Protocol:
else:
close = Close(code, reason)
data = close.serialize()
# send_frame() guarantees that self.state is OPEN at this point.
# 7.1.3. The WebSocket Closing Handshake is Started
self.send_frame(Frame(OP_CLOSE, data))
# Since the state is OPEN, no close frame was received yet.
# As a consequence, self.close_rcvd_then_sent remains None.
assert self.close_rcvd is None
self.close_sent = close
self.state = CLOSING
@ -441,6 +443,12 @@ class Protocol:
data = close.serialize()
self.send_frame(Frame(OP_CLOSE, data))
self.close_sent = close
# If recv_messages() raised an exception upon receiving a close
# frame but before echoing it, then close_rcvd is not None even
# though the state is OPEN. This happens when the connection is
# closed while receiving a fragmented message.
if self.close_rcvd is not None:
self.close_rcvd_then_sent = True
self.state = CLOSING
# When failing the connection, a server closes the TCP connection
@ -602,18 +610,18 @@ class Protocol:
- after sending a close frame, during an abnormal closure (7.1.7).
"""
# The server close the TCP connection in the same circumstances where
# discard() replaces parse(). The client closes the connection later,
# after the server closes the connection or a timeout elapses.
# (The latter case cannot be handled in this Sans-I/O layer.)
assert (self.side is SERVER) == (self.eof_sent)
# After the opening handshake completes, the server closes the TCP
# connection in the same circumstances where discard() replaces parse().
# The client closes it when it receives EOF from the server or times
# out. (The latter case cannot be handled in this Sans-I/O layer.)
assert (self.state == CONNECTING or self.side is SERVER) == (self.eof_sent)
while not (yield from self.reader.at_eof()):
self.reader.discard()
if self.debug:
self.logger.debug("< EOF")
# A server closes the TCP connection immediately, while a client
# waits for the server to close the TCP connection.
if self.side is CLIENT:
if self.state != CONNECTING and self.side is CLIENT:
self.send_eof()
self.state = CLOSED
# If discard() completes normally, execution ends here.

View File

@ -69,7 +69,7 @@ class ServerProtocol(Protocol):
max_size: Maximum size of incoming messages in bytes;
:obj:`None` disables the limit.
logger: Logger for this connection;
defaults to ``logging.getLogger("websockets.client")``;
defaults to ``logging.getLogger("websockets.server")``;
see the :doc:`logging guide <../../topics/logging>` for details.
"""
@ -204,7 +204,6 @@ class ServerProtocol(Protocol):
if protocol_header is not None:
headers["Sec-WebSocket-Protocol"] = protocol_header
self.logger.info("connection open")
return Response(101, "Switching Protocols", headers)
def process_request(
@ -254,12 +253,10 @@ class ServerProtocol(Protocol):
try:
key = headers["Sec-WebSocket-Key"]
except KeyError as exc:
raise InvalidHeader("Sec-WebSocket-Key") from exc
except MultipleValuesError as exc:
raise InvalidHeader(
"Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found"
) from exc
except KeyError:
raise InvalidHeader("Sec-WebSocket-Key") from None
except MultipleValuesError:
raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from None
try:
raw_key = base64.b64decode(key.encode(), validate=True)
@ -270,13 +267,10 @@ class ServerProtocol(Protocol):
try:
version = headers["Sec-WebSocket-Version"]
except KeyError as exc:
raise InvalidHeader("Sec-WebSocket-Version") from exc
except MultipleValuesError as exc:
raise InvalidHeader(
"Sec-WebSocket-Version",
"more than one Sec-WebSocket-Version header found",
) from exc
except KeyError:
raise InvalidHeader("Sec-WebSocket-Version") from None
except MultipleValuesError:
raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from None
if version != "13":
raise InvalidHeaderValue("Sec-WebSocket-Version", version)
@ -314,8 +308,8 @@ class ServerProtocol(Protocol):
# per https://datatracker.ietf.org/doc/html/rfc6454#section-7.3.
try:
origin = headers.get("Origin")
except MultipleValuesError as exc:
raise InvalidHeader("Origin", "more than one Origin header found") from exc
except MultipleValuesError:
raise InvalidHeader("Origin", "multiple values") from None
if origin is not None:
origin = cast(Origin, origin)
if self.origins is not None:
@ -509,7 +503,7 @@ class ServerProtocol(Protocol):
HTTP response to send to the client.
"""
# If a user passes an int instead of a HTTPStatus, fix it automatically.
# If status is an int instead of an HTTPStatus, fix it automatically.
status = http.HTTPStatus(status)
body = text.encode()
headers = Headers(
@ -520,14 +514,7 @@ class ServerProtocol(Protocol):
("Content-Type", "text/plain; charset=utf-8"),
]
)
response = Response(status.value, status.phrase, headers, body)
# When reject() is called from accept(), handshake_exc is already set.
# If a user calls reject(), set handshake_exc to guarantee invariant:
# "handshake_exc is None if and only if opening handshake succeeded."
if self.handshake_exc is None:
self.handshake_exc = InvalidStatus(response)
self.logger.info("connection rejected (%d %s)", status.value, status.phrase)
return response
return Response(status.value, status.phrase, headers, body)
def send_response(self, response: Response) -> None:
"""
@ -550,7 +537,20 @@ class ServerProtocol(Protocol):
if response.status_code == 101:
assert self.state is CONNECTING
self.state = OPEN
self.logger.info("connection open")
else:
# handshake_exc may be already set if accept() encountered an error.
# If the connection isn't open, set handshake_exc to guarantee that
# handshake_exc is None if and only if opening handshake succeeded.
if self.handshake_exc is None:
self.handshake_exc = InvalidStatus(response)
self.logger.info(
"connection rejected (%d %s)",
response.status_code,
response.reason_phrase,
)
self.send_eof()
self.parser = self.discard()
next(self.parser) # start coroutine
@ -580,7 +580,7 @@ class ServerProtocol(Protocol):
class ServerConnection(ServerProtocol):
def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
warnings.warn( # deprecated in 11.0 - 2023-04-02
"ServerConnection was renamed to ServerProtocol",
DeprecationWarning,
)

View File

@ -19,7 +19,6 @@ _PyBytesLike_AsStringAndSize(PyObject *obj, PyObject **tmp, char **buffer, Py_ss
{
// This supports bytes, bytearrays, and memoryview objects,
// which are common data structures for handling byte streams.
// websockets.framing.prepare_data() returns only these types.
// If *tmp isn't NULL, the caller gets a new reference.
if (PyBytes_Check(obj))
{

View File

@ -12,7 +12,7 @@ from ..extensions.base import ClientExtensionFactory
from ..extensions.permessage_deflate import enable_client_permessage_deflate
from ..headers import validate_subprotocols
from ..http11 import USER_AGENT, Response
from ..protocol import CONNECTING, OPEN, Event
from ..protocol import CONNECTING, Event
from ..typing import LoggerLike, Origin, Subprotocol
from ..uri import parse_uri
from .connection import Connection
@ -80,19 +80,11 @@ class ClientConnection(Connection):
self.protocol.send_request(self.request)
if not self.response_rcvd.wait(timeout):
self.close_socket()
self.recv_events_thread.join()
raise TimeoutError("timed out during handshake")
if self.response is None:
self.close_socket()
self.recv_events_thread.join()
raise ConnectionError("connection closed during handshake")
if self.protocol.state is not OPEN:
self.recv_events_thread.join(self.close_timeout)
self.close_socket()
self.recv_events_thread.join()
# self.protocol.handshake_exc is always set when the connection is lost
# before receiving a response, when the response cannot be parsed, or
# when the response fails the handshake.
if self.protocol.handshake_exc is not None:
raise self.protocol.handshake_exc
@ -156,7 +148,9 @@ def connect(
:func:`connect` may be used as a context manager::
with websockets.sync.client.connect(...) as websocket:
from websockets.sync.client import connect
with connect(...) as websocket:
...
The connection is closed automatically when exiting the context.
@ -210,7 +204,10 @@ def connect(
# Backwards compatibility: ssl used to be called ssl_context.
if ssl is None and "ssl_context" in kwargs:
ssl = kwargs.pop("ssl_context")
warnings.warn("ssl_context was renamed to ssl", DeprecationWarning)
warnings.warn( # deprecated in 13.0 - 2024-08-20
"ssl_context was renamed to ssl",
DeprecationWarning,
)
wsuri = parse_uri(uri)
if not wsuri.secure and ssl is not None:
@ -290,16 +287,20 @@ def connect(
protocol,
close_timeout=close_timeout,
)
# On failure, handshake() closes the socket and raises an exception.
except Exception:
if sock is not None:
sock.close()
raise
try:
connection.handshake(
additional_headers,
user_agent_header,
deadline.timeout(),
)
except Exception:
if sock is not None:
sock.close()
connection.close_socket()
connection.recv_events_thread.join()
raise
return connection

View File

@ -10,8 +10,13 @@ import uuid
from types import TracebackType
from typing import Any, Iterable, Iterator, Mapping
from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError
from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl
from ..exceptions import (
ConcurrencyError,
ConnectionClosed,
ConnectionClosedOK,
ProtocolError,
)
from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode
from ..http11 import Request, Response
from ..protocol import CLOSED, OPEN, Event, Protocol, State
from ..typing import Data, LoggerLike, Subprotocol
@ -82,9 +87,9 @@ class Connection:
# Mapping of ping IDs to pong waiters, in chronological order.
self.ping_waiters: dict[bytes, threading.Event] = {}
# Receiving events from the socket. This thread explicitly is marked as
# to support creating a connection in a non-daemon thread then using it
# in a daemon thread; this shouldn't block the intpreter from exiting.
# Receiving events from the socket. This thread is marked as daemon to
# allow creating a connection in a non-daemon thread and using it in a
# daemon thread. This mustn't prevent the interpreter from exiting.
self.recv_events_thread = threading.Thread(
target=self.recv_events,
daemon=True,
@ -194,16 +199,18 @@ class Connection:
Raises:
ConnectionClosed: When the connection is closed.
RuntimeError: If two threads call :meth:`recv` or
ConcurrencyError: If two threads call :meth:`recv` or
:meth:`recv_streaming` concurrently.
"""
try:
return self.recv_messages.get(timeout)
except EOFError:
# Wait for the protocol state to be CLOSED before accessing close_exc.
self.recv_events_thread.join()
raise self.protocol.close_exc from self.recv_exc
except RuntimeError:
raise RuntimeError(
except ConcurrencyError:
raise ConcurrencyError(
"cannot call recv while another thread "
"is already running recv or recv_streaming"
) from None
@ -227,7 +234,7 @@ class Connection:
Raises:
ConnectionClosed: When the connection is closed.
RuntimeError: If two threads call :meth:`recv` or
ConcurrencyError: If two threads call :meth:`recv` or
:meth:`recv_streaming` concurrently.
"""
@ -235,9 +242,11 @@ class Connection:
for frame in self.recv_messages.get_iter():
yield frame
except EOFError:
# Wait for the protocol state to be CLOSED before accessing close_exc.
self.recv_events_thread.join()
raise self.protocol.close_exc from self.recv_exc
except RuntimeError:
raise RuntimeError(
except ConcurrencyError:
raise ConcurrencyError(
"cannot call recv_streaming while another thread "
"is already running recv or recv_streaming"
) from None
@ -277,7 +286,7 @@ class Connection:
Raises:
ConnectionClosed: When the connection is closed.
RuntimeError: If the connection is sending a fragmented message.
ConcurrencyError: If the connection is sending a fragmented message.
TypeError: If ``message`` doesn't have a supported type.
"""
@ -287,7 +296,7 @@ class Connection:
if isinstance(message, str):
with self.send_context():
if self.send_in_progress:
raise RuntimeError(
raise ConcurrencyError(
"cannot call send while another thread "
"is already running send"
)
@ -296,7 +305,7 @@ class Connection:
elif isinstance(message, BytesLike):
with self.send_context():
if self.send_in_progress:
raise RuntimeError(
raise ConcurrencyError(
"cannot call send while another thread "
"is already running send"
)
@ -322,7 +331,7 @@ class Connection:
text = True
with self.send_context():
if self.send_in_progress:
raise RuntimeError(
raise ConcurrencyError(
"cannot call send while another thread "
"is already running send"
)
@ -335,7 +344,7 @@ class Connection:
text = False
with self.send_context():
if self.send_in_progress:
raise RuntimeError(
raise ConcurrencyError(
"cannot call send while another thread "
"is already running send"
)
@ -371,7 +380,7 @@ class Connection:
self.protocol.send_continuation(b"", fin=True)
self.send_in_progress = False
except RuntimeError:
except ConcurrencyError:
# We didn't start sending a fragmented message.
# The connection is still usable.
raise
@ -445,17 +454,21 @@ class Connection:
Raises:
ConnectionClosed: When the connection is closed.
RuntimeError: If another ping was sent with the same data and
ConcurrencyError: If another ping was sent with the same data and
the corresponding pong wasn't received yet.
"""
if data is not None:
data = prepare_ctrl(data)
if isinstance(data, BytesLike):
data = bytes(data)
elif isinstance(data, str):
data = data.encode()
elif data is not None:
raise TypeError("data must be str or bytes-like")
with self.send_context():
# Protect against duplicates if a payload is explicitly set.
if data in self.ping_waiters:
raise RuntimeError("already waiting for a pong with the same data")
raise ConcurrencyError("already waiting for a pong with the same data")
# Generate a unique random payload otherwise.
while data is None or data in self.ping_waiters:
@ -481,7 +494,12 @@ class Connection:
ConnectionClosed: When the connection is closed.
"""
data = prepare_ctrl(data)
if isinstance(data, BytesLike):
data = bytes(data)
elif isinstance(data, str):
data = data.encode()
else:
raise TypeError("data must be str or bytes-like")
with self.send_context():
self.protocol.send_pong(data)
@ -615,8 +633,6 @@ class Connection:
self.logger.error("unexpected internal error", exc_info=True)
with self.protocol_mutex:
self.set_recv_exc(exc)
# We don't know where we crashed. Force protocol state to CLOSED.
self.protocol.state = CLOSED
finally:
# This isn't expected to raise an exception.
self.close_socket()
@ -656,7 +672,7 @@ class Connection:
# Let the caller interact with the protocol.
try:
yield
except (ProtocolError, RuntimeError):
except (ProtocolError, ConcurrencyError):
# The protocol state wasn't changed. Exit immediately.
raise
except Exception as exc:
@ -724,6 +740,7 @@ class Connection:
# raise an exception.
if raise_close_exc:
self.close_socket()
# Wait for the protocol state to be CLOSED before accessing close_exc.
self.recv_events_thread.join()
raise self.protocol.close_exc from original_exc
@ -774,4 +791,11 @@ class Connection:
except OSError:
pass # socket is already closed
self.socket.close()
# Calling protocol.receive_eof() is safe because it's idempotent.
# This guarantees that the protocol state becomes CLOSED.
self.protocol.receive_eof()
assert self.protocol.state is CLOSED
# Abort recv() with a ConnectionClosed exception.
self.recv_messages.close()

View File

@ -5,6 +5,7 @@ import queue
import threading
from typing import Iterator, cast
from ..exceptions import ConcurrencyError
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
from ..typing import Data
@ -74,7 +75,7 @@ class Assembler:
Raises:
EOFError: If the stream of frames has ended.
RuntimeError: If two threads run :meth:`get` or :meth:`get_iter`
ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter`
concurrently.
TimeoutError: If a timeout is provided and elapses before a
complete message is received.
@ -85,7 +86,7 @@ class Assembler:
raise EOFError("stream of frames ended")
if self.get_in_progress:
raise RuntimeError("get() or get_iter() is already running")
raise ConcurrencyError("get() or get_iter() is already running")
self.get_in_progress = True
@ -128,14 +129,14 @@ class Assembler:
:class:`bytes` for each frame in the message.
The iterator must be fully consumed before calling :meth:`get_iter` or
:meth:`get` again. Else, :exc:`RuntimeError` is raised.
:meth:`get` again. Else, :exc:`ConcurrencyError` is raised.
This method only makes sense for fragmented messages. If messages aren't
fragmented, use :meth:`get` instead.
Raises:
EOFError: If the stream of frames has ended.
RuntimeError: If two threads run :meth:`get` or :meth:`get_iter`
ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter`
concurrently.
"""
@ -144,7 +145,7 @@ class Assembler:
raise EOFError("stream of frames ended")
if self.get_in_progress:
raise RuntimeError("get() or get_iter() is already running")
raise ConcurrencyError("get() or get_iter() is already running")
chunks = self.chunks
self.chunks = []
@ -198,7 +199,7 @@ class Assembler:
Raises:
EOFError: If the stream of frames has ended.
RuntimeError: If two threads run :meth:`put` concurrently.
ConcurrencyError: If two threads run :meth:`put` concurrently.
"""
with self.mutex:
@ -206,7 +207,7 @@ class Assembler:
raise EOFError("stream of frames ended")
if self.put_in_progress:
raise RuntimeError("put is already running")
raise ConcurrencyError("put is already running")
if frame.opcode is OP_TEXT:
self.decoder = UTF8Decoder(errors="strict")

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import hmac
import http
import logging
import os
@ -10,12 +11,17 @@ import sys
import threading
import warnings
from types import TracebackType
from typing import Any, Callable, Sequence
from typing import Any, Callable, Iterable, Sequence, Tuple, cast
from ..exceptions import InvalidHeader
from ..extensions.base import ServerExtensionFactory
from ..extensions.permessage_deflate import enable_server_permessage_deflate
from ..frames import CloseCode
from ..headers import validate_subprotocols
from ..headers import (
build_www_authenticate_basic,
parse_authorization_basic,
validate_subprotocols,
)
from ..http11 import SERVER, Request, Response
from ..protocol import CONNECTING, OPEN, Event
from ..server import ServerProtocol
@ -24,7 +30,7 @@ from .connection import Connection
from .utils import Deadline
__all__ = ["serve", "unix_serve", "ServerConnection", "Server"]
__all__ = ["serve", "unix_serve", "ServerConnection", "Server", "basic_auth"]
class ServerConnection(Connection):
@ -65,6 +71,7 @@ class ServerConnection(Connection):
protocol,
close_timeout=close_timeout,
)
self.username: str # see basic_auth()
def respond(self, status: StatusLike, text: str) -> Response:
"""
@ -111,61 +118,57 @@ class ServerConnection(Connection):
"""
if not self.request_rcvd.wait(timeout):
self.close_socket()
self.recv_events_thread.join()
raise TimeoutError("timed out during handshake")
if self.request is None:
self.close_socket()
self.recv_events_thread.join()
raise ConnectionError("connection closed during handshake")
if self.request is not None:
with self.send_context(expected_state=CONNECTING):
response = None
with self.send_context(expected_state=CONNECTING):
self.response = None
if process_request is not None:
try:
response = process_request(self, self.request)
except Exception as exc:
self.protocol.handshake_exc = exc
response = self.protocol.reject(
http.HTTPStatus.INTERNAL_SERVER_ERROR,
(
"Failed to open a WebSocket connection.\n"
"See server log for more information.\n"
),
)
if process_request is not None:
try:
self.response = process_request(self, self.request)
except Exception as exc:
self.protocol.handshake_exc = exc
self.logger.error("opening handshake failed", exc_info=True)
self.response = self.protocol.reject(
http.HTTPStatus.INTERNAL_SERVER_ERROR,
(
"Failed to open a WebSocket connection.\n"
"See server log for more information.\n"
),
)
if self.response is None:
self.response = self.protocol.accept(self.request)
if server_header:
self.response.headers["Server"] = server_header
if process_response is not None:
try:
response = process_response(self, self.request, self.response)
except Exception as exc:
self.protocol.handshake_exc = exc
self.logger.error("opening handshake failed", exc_info=True)
self.response = self.protocol.reject(
http.HTTPStatus.INTERNAL_SERVER_ERROR,
(
"Failed to open a WebSocket connection.\n"
"See server log for more information.\n"
),
)
if response is None:
self.response = self.protocol.accept(self.request)
else:
self.response = response
if server_header:
self.response.headers["Server"] = server_header
response = None
if process_response is not None:
try:
response = process_response(self, self.request, self.response)
except Exception as exc:
self.protocol.handshake_exc = exc
response = self.protocol.reject(
http.HTTPStatus.INTERNAL_SERVER_ERROR,
(
"Failed to open a WebSocket connection.\n"
"See server log for more information.\n"
),
)
if response is not None:
self.response = response
self.protocol.send_response(self.response)
self.protocol.send_response(self.response)
if self.protocol.state is not OPEN:
self.recv_events_thread.join(self.close_timeout)
self.close_socket()
self.recv_events_thread.join()
# self.protocol.handshake_exc is always set when the connection is lost
# before receiving a request, when the request cannot be parsed, when
# the handshake encounters an error, or when process_request or
# process_response sends an HTTP response that rejects the handshake.
if self.protocol.handshake_exc is not None:
raise self.protocol.handshake_exc
@ -297,7 +300,7 @@ class Server:
def __getattr__(name: str) -> Any:
if name == "WebSocketServer":
warnings.warn(
warnings.warn( # deprecated in 13.0 - 2024-08-20
"WebSocketServer was renamed to Server",
DeprecationWarning,
)
@ -368,10 +371,12 @@ def serve(
that it will be closed and call :meth:`~Server.serve_forever` to serve
requests::
from websockets.sync.server import serve
def handler(websocket):
...
with websockets.sync.server.serve(handler, ...) as server:
with serve(handler, ...) as server:
server.serve_forever()
Args:
@ -437,7 +442,10 @@ def serve(
# Backwards compatibility: ssl used to be called ssl_context.
if ssl is None and "ssl_context" in kwargs:
ssl = kwargs.pop("ssl_context")
warnings.warn("ssl_context was renamed to ssl", DeprecationWarning)
warnings.warn( # deprecated in 13.0 - 2024-08-20
"ssl_context was renamed to ssl",
DeprecationWarning,
)
if subprotocols is not None:
validate_subprotocols(subprotocols)
@ -540,26 +548,40 @@ def serve(
protocol,
close_timeout=close_timeout,
)
# On failure, handshake() closes the socket, raises an exception, and
# logs it.
connection.handshake(
process_request,
process_response,
server_header,
deadline.timeout(),
)
except Exception:
sock.close()
return
try:
handler(connection)
except Exception:
protocol.logger.error("connection handler failed", exc_info=True)
connection.close(CloseCode.INTERNAL_ERROR)
else:
connection.close()
try:
connection.handshake(
process_request,
process_response,
server_header,
deadline.timeout(),
)
except TimeoutError:
connection.close_socket()
connection.recv_events_thread.join()
return
except Exception:
connection.logger.error("opening handshake failed", exc_info=True)
connection.close_socket()
connection.recv_events_thread.join()
return
assert connection.protocol.state is OPEN
try:
handler(connection)
except Exception:
connection.logger.error("connection handler failed", exc_info=True)
connection.close(CloseCode.INTERNAL_ERROR)
else:
connection.close()
except Exception: # pragma: no cover
# Don't leak sockets on unexpected errors.
sock.close()
# Initialize server
@ -587,3 +609,119 @@ def unix_serve(
"""
return serve(handler, unix=True, path=path, **kwargs)
def is_credentials(credentials: Any) -> bool:
try:
username, password = credentials
except (TypeError, ValueError):
return False
else:
return isinstance(username, str) and isinstance(password, str)
def basic_auth(
realm: str = "",
credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None,
check_credentials: Callable[[str, str], bool] | None = None,
) -> Callable[[ServerConnection, Request], Response | None]:
"""
Factory for ``process_request`` to enforce HTTP Basic Authentication.
:func:`basic_auth` is designed to integrate with :func:`serve` as follows::
from websockets.sync.server import basic_auth, serve
with serve(
...,
process_request=basic_auth(
realm="my dev server",
credentials=("hello", "iloveyou"),
),
):
If authentication succeeds, the connection's ``username`` attribute is set.
If it fails, the server responds with an HTTP 401 Unauthorized status.
One of ``credentials`` or ``check_credentials`` must be provided; not both.
Args:
realm: Scope of protection. It should contain only ASCII characters
because the encoding of non-ASCII characters is undefined. Refer to
section 2.2 of :rfc:`7235` for details.
credentials: Hard coded authorized credentials. It can be a
``(username, password)`` pair or a list of such pairs.
check_credentials: Function that verifies credentials.
It receives ``username`` and ``password`` arguments and returns
whether they're valid.
Raises:
TypeError: If ``credentials`` or ``check_credentials`` is wrong.
"""
if (credentials is None) == (check_credentials is None):
raise TypeError("provide either credentials or check_credentials")
if credentials is not None:
if is_credentials(credentials):
credentials_list = [cast(Tuple[str, str], credentials)]
elif isinstance(credentials, Iterable):
credentials_list = list(cast(Iterable[Tuple[str, str]], credentials))
if not all(is_credentials(item) for item in credentials_list):
raise TypeError(f"invalid credentials argument: {credentials}")
else:
raise TypeError(f"invalid credentials argument: {credentials}")
credentials_dict = dict(credentials_list)
def check_credentials(username: str, password: str) -> bool:
try:
expected_password = credentials_dict[username]
except KeyError:
return False
return hmac.compare_digest(expected_password, password)
assert check_credentials is not None # help mypy
def process_request(
connection: ServerConnection,
request: Request,
) -> Response | None:
"""
Perform HTTP Basic Authentication.
If it succeeds, set the connection's ``username`` attribute and return
:obj:`None`. If it fails, return an HTTP 401 Unauthorized responss.
"""
try:
authorization = request.headers["Authorization"]
except KeyError:
response = connection.respond(
http.HTTPStatus.UNAUTHORIZED,
"Missing credentials\n",
)
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
return response
try:
username, password = parse_authorization_basic(authorization)
except InvalidHeader:
response = connection.respond(
http.HTTPStatus.UNAUTHORIZED,
"Unsupported credentials\n",
)
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
return response
if not check_credentials(username, password):
response = connection.respond(
http.HTTPStatus.UNAUTHORIZED,
"Invalid credentials\n",
)
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
return response
connection.username = username
return None
return process_request

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import dataclasses
import urllib.parse
from . import exceptions
from .exceptions import InvalidURI
__all__ = ["parse_uri", "WebSocketURI"]
@ -73,11 +73,11 @@ def parse_uri(uri: str) -> WebSocketURI:
"""
parsed = urllib.parse.urlparse(uri)
if parsed.scheme not in ["ws", "wss"]:
raise exceptions.InvalidURI(uri, "scheme isn't ws or wss")
raise InvalidURI(uri, "scheme isn't ws or wss")
if parsed.hostname is None:
raise exceptions.InvalidURI(uri, "hostname isn't provided")
raise InvalidURI(uri, "hostname isn't provided")
if parsed.fragment != "":
raise exceptions.InvalidURI(uri, "fragment identifier is meaningless")
raise InvalidURI(uri, "fragment identifier is meaningless")
secure = parsed.scheme == "wss"
host = parsed.hostname
@ -89,7 +89,7 @@ def parse_uri(uri: str) -> WebSocketURI:
# urllib.parse.urlparse accepts URLs with a username but without a
# password. This doesn't make sense for HTTP Basic Auth credentials.
if username is not None and password is None:
raise exceptions.InvalidURI(uri, "username provided without password")
raise InvalidURI(uri, "username provided without password")
try:
uri.encode("ascii")

View File

@ -20,7 +20,7 @@ __all__ = ["tag", "version", "commit"]
released = True
tag = version = commit = "13.0.1"
tag = version = commit = "13.1"
if not released: # pragma: no cover

View File

@ -1,16 +0,0 @@
zipp-3.20.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
zipp-3.20.1.dist-info/LICENSE,sha256=htoPAa6uRjSKPD1GUZXcHOzN55956HdppkuNoEsqR0E,1023
zipp-3.20.1.dist-info/METADATA,sha256=J3sSw1smScyUJdU4tZp-6S5SoFY5hi8Wk1WebRFWcWA,3682
zipp-3.20.1.dist-info/RECORD,,
zipp-3.20.1.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
zipp-3.20.1.dist-info/top_level.txt,sha256=iAbdoSHfaGqBfVb2XuR9JqSQHCoOsOtG6y9C_LSpqFw,5
zipp/__init__.py,sha256=hFKawMr3tL33PV0OuYkg67V3cRvuvpOdj_zwvZCXmCk,11834
zipp/__pycache__/__init__.cpython-38.pyc,,
zipp/__pycache__/glob.cpython-38.pyc,,
zipp/compat/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
zipp/compat/__pycache__/__init__.cpython-38.pyc,,
zipp/compat/__pycache__/overlay.cpython-38.pyc,,
zipp/compat/__pycache__/py310.cpython-38.pyc,,
zipp/compat/overlay.py,sha256=B1lcbC9TZVA2terZQTjC8A1xfSCaah4oHNMejJZVwyU,295
zipp/compat/py310.py,sha256=KS3sidGTSkoGh3biXiCqRzE6RMEGH0sbRQBevWU73dU,256
zipp/glob.py,sha256=yPjGfHwcJxUn0fld7I-K-ZQSfTaJBBoimCIygU1SZQw,3315

View File

@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: zipp
Version: 3.20.1
Version: 3.20.2
Summary: Backport of pathlib-compatible object wrapper for zip files
Author-email: "Jason R. Coombs" <jaraco@jaraco.com>
Project-URL: Source, https://github.com/jaraco/zipp

Some files were not shown because too many files have changed in this diff Show More