2023-11-03 11:02:53 +08:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import inspect
|
|
|
|
import datetime
|
|
|
|
import functools
|
|
|
|
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast
|
|
|
|
from typing_extensions import Awaitable, ParamSpec, get_args, override, get_origin
|
|
|
|
|
|
|
|
import httpx
|
|
|
|
import pydantic
|
|
|
|
|
2023-11-15 15:46:46 +08:00
|
|
|
from ._types import NoneType, UnknownResponse, BinaryResponseContent
|
2023-11-03 11:02:53 +08:00
|
|
|
from ._utils import is_given
|
|
|
|
from ._models import BaseModel
|
|
|
|
from ._constants import RAW_RESPONSE_HEADER
|
|
|
|
from ._exceptions import APIResponseValidationError
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
from ._models import FinalRequestOptions
|
|
|
|
from ._base_client import Stream, BaseClient, AsyncStream
|
|
|
|
|
|
|
|
|
|
|
|
# ParamSpec standing in for the parameter list of a wrapped API method, so the
# raw-response wrappers in this module can forward `*args` / `**kwargs` with
# their original types.
P = ParamSpec("P")
|
|
|
|
# TypeVar for the parsed response type; it parameterizes `APIResponse` and the
# return types of the raw-response wrappers.
R = TypeVar("R")
|
|
|
|
|
|
|
|
|
|
|
|
class APIResponse(Generic[R]):
    """Pairs a raw `httpx.Response` with the information needed to parse it.

    The body is only converted into `cast_to` when `parse()` is called; the
    read-only properties simply forward to the underlying `httpx.Response`.
    """

    _cast_to: type[R]
    _client: BaseClient[Any, Any]
    # Cached result of `parse()`; `None` doubles as the "not parsed yet" marker.
    _parsed: R | None
    _stream: bool
    _stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None
    _options: FinalRequestOptions

    http_response: httpx.Response

    def __init__(
        self,
        *,
        raw: httpx.Response,
        cast_to: type[R],
        client: BaseClient[Any, Any],
        stream: bool,
        stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
        options: FinalRequestOptions,
    ) -> None:
        """Store the raw response and parsing configuration; nothing is parsed here.

        Args:
            raw: The HTTP response to wrap; exposed as `http_response`.
            cast_to: The type `parse()` should produce.
            client: The client whose `_default_stream_cls`,
                `_strict_response_validation` and `_process_response_data`
                are used while parsing.
            stream: Whether the response should be wrapped in a stream class.
            stream_cls: Explicit stream class to use when `stream` is true.
            options: Request options; `post_parser` is read from it in `parse()`.
        """
        self._cast_to = cast_to
        self._client = client
        self._parsed = None
        self._stream = stream
        self._stream_cls = stream_cls
        self._options = options
        self.http_response = raw

    def parse(self) -> R:
        """Return the parsed response, computing it on first use.

        The result of `_parse()` is passed through `options.post_parser` when
        one is given, then cached. A result of `None` is not cached because
        `None` is also the "not parsed yet" marker, so it is recomputed on
        every call.
        """
        if self._parsed is not None:
            return self._parsed

        parsed = self._parse()
        if is_given(self._options.post_parser):
            parsed = self._options.post_parser(parsed)

        self._parsed = parsed
        return parsed

    @property
    def headers(self) -> httpx.Headers:
        """Headers of the underlying HTTP response."""
        return self.http_response.headers

    @property
    def http_request(self) -> httpx.Request:
        """The request that produced the underlying HTTP response."""
        return self.http_response.request

    @property
    def status_code(self) -> int:
        """Status code of the underlying HTTP response."""
        return self.http_response.status_code

    @property
    def url(self) -> httpx.URL:
        """URL of the underlying HTTP response."""
        return self.http_response.url

    @property
    def method(self) -> str:
        """HTTP method of the originating request."""
        return self.http_request.method

    @property
    def content(self) -> bytes:
        """Body of the underlying HTTP response as bytes."""
        return self.http_response.content

    @property
    def text(self) -> str:
        """Body of the underlying HTTP response as text."""
        return self.http_response.text

    @property
    def http_version(self) -> str:
        """HTTP version reported by the underlying HTTP response."""
        return self.http_response.http_version

    @property
    def elapsed(self) -> datetime.timedelta:
        """The time taken for the complete request/response cycle to complete."""
        return self.http_response.elapsed

    def _parse(self) -> R:
        """Convert the raw HTTP response into `cast_to` (uncached).

        Raises:
            MissingStreamClassError: streaming was requested but neither an
                explicit nor a client-default stream class is available.
            RuntimeError: `cast_to` is `APIResponse` or is not one of the
                supported response types.
            ValueError: `cast_to` is a subclass of `httpx.Response`.
            APIResponseValidationError: strict validation is enabled and the
                response is not `application/json`, or the decoded data fails
                pydantic validation.
        """
        if self._stream:
            if self._stream_cls:
                # An explicit stream class carries its chunk type as a generic argument.
                return cast(
                    R,
                    self._stream_cls(
                        cast_to=_extract_stream_chunk_type(self._stream_cls),
                        response=self.http_response,
                        client=cast(Any, self._client),
                    ),
                )

            stream_cls = cast("type[Stream[Any]] | type[AsyncStream[Any]] | None", self._client._default_stream_cls)
            if stream_cls is None:
                raise MissingStreamClassError()

            return cast(
                R,
                stream_cls(
                    cast_to=self._cast_to,
                    response=self.http_response,
                    client=cast(Any, self._client),
                ),
            )

        cast_to = self._cast_to
        if cast_to is NoneType:
            return cast(R, None)

        response = self.http_response
        if cast_to == str:
            return cast(R, response.text)

        # For parameterized generics, dispatch on the unsubscripted type.
        origin = get_origin(cast_to) or cast_to

        if inspect.isclass(origin) and issubclass(origin, BinaryResponseContent):
            return cast(R, cast_to(response))  # type: ignore

        if origin == APIResponse:
            raise RuntimeError("Unexpected state - cast_to is `APIResponse`")

        if inspect.isclass(origin) and issubclass(origin, httpx.Response):
            # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
            # and pass that class to our request functions. We cannot change the variance to be either
            # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
            # the response class ourselves but that is something that should be supported directly in httpx
            # as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
            if cast_to != httpx.Response:
                raise ValueError("Subclasses of httpx.Response cannot be passed to `cast_to`")
            return cast(R, response)

        # The check here is necessary as we are subverting the type system
        # with casts as the relationship between TypeVars and Types are very strict
        # which means we must return *exactly* what was input or transform it in a
        # way that retains the TypeVar state. As we cannot do that in this function
        # then we have to resort to using `cast`. At the time of writing, we know this
        # to be safe as we have handled all the types that could be bound to the
        # `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then
        # this function would become unsafe but a type checker would not report an error.
        if (
            cast_to is not UnknownResponse
            and origin is not list
            and origin is not dict
            and origin is not Union
            # `issubclass` raises TypeError for non-class origins (e.g. typing special
            # forms), so guard it to surface the descriptive error below instead.
            and not (inspect.isclass(origin) and issubclass(origin, BaseModel))
        ):
            raise RuntimeError(
                f"Invalid state, expected {cast_to} to be a subclass type of {BaseModel}, {dict}, {list} or {Union}."
            )

        # split is required to handle cases where additional information is included
        # in the response, e.g. application/json; charset=utf-8
        #
        # A default is supplied because the header may be absent altogether; without
        # it `.get()` returns `None` and `.split` raises AttributeError before the
        # non-JSON handling below gets a chance to run.
        content_type, *_ = response.headers.get("content-type", "*").split(";")
        if content_type != "application/json":
            if self._client._strict_response_validation:
                raise APIResponseValidationError(
                    response=response,
                    message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.",
                    body=response.text,
                )

            # If the API responds with content that isn't JSON then we just return
            # the (decoded) text without performing any parsing so that you can still
            # handle the response however you need to.
            return response.text  # type: ignore

        data = response.json()

        try:
            return self._client._process_response_data(
                data=data,
                cast_to=cast_to,  # type: ignore
                response=response,
            )
        except pydantic.ValidationError as err:
            raise APIResponseValidationError(response=response, body=data) from err

    @override
    def __repr__(self) -> str:
        return f"<APIResponse [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_to}>"
|
|
|
|
|
|
|
|
|
|
|
|
class MissingStreamClassError(TypeError):
    """Raised when streaming is requested but no stream class can be resolved."""

    def __init__(self) -> None:
        # The message is fixed; callers never supply their own text.
        message = "The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference"
        super().__init__(message)
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_stream_chunk_type(stream_cls: type) -> type:
    """Return the first generic type argument of a parameterized stream class.

    For example, given `Stream[Foo]` this returns `Foo`.

    Raises:
        TypeError: `stream_cls` carries no generic type arguments.
    """
    type_params = get_args(stream_cls)
    if type_params:
        return cast(type, type_params[0])

    raise TypeError(
        f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}",
    )
|
|
|
|
|
|
|
|
|
|
|
|
def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]:
    """Wrap a bound API method so that it is called with the raw-response header set.

    The returned callable forwards all arguments to `func`, but first sets
    `RAW_RESPONSE_HEADER` to `"true"` in the `extra_headers` keyword argument.
    Any headers supplied by the caller are copied rather than mutated. The
    result of `func` is returned typed as `APIResponse[R]`.
    """

    @functools.wraps(func)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]:
        supplied = cast(Any, kwargs.get("extra_headers"))
        # Build a fresh dict so the caller's mapping is never modified.
        kwargs["extra_headers"] = {**(supplied or {}), RAW_RESPONSE_HEADER: "true"}

        response = func(*args, **kwargs)
        return cast(APIResponse[R], response)

    return wrapped
|
|
|
|
|
|
|
|
|
|
|
|
def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[APIResponse[R]]]:
    """Wrap a bound async API method so that it is called with the raw-response header set.

    The returned coroutine function forwards all arguments to `func`, but first
    sets `RAW_RESPONSE_HEADER` to `"true"` in the `extra_headers` keyword
    argument. Any headers supplied by the caller are copied rather than
    mutated. The awaited result of `func` is returned typed as `APIResponse[R]`.
    """

    @functools.wraps(func)
    async def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]:
        supplied = cast(Any, kwargs.get("extra_headers"))
        # Build a fresh dict so the caller's mapping is never modified.
        kwargs["extra_headers"] = {**(supplied or {}), RAW_RESPONSE_HEADER: "true"}

        response = await func(*args, **kwargs)
        return cast(APIResponse[R], response)

    return wrapped
|