# 265 lines
# 8.9 KiB
# Python

from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any, Iterable, cast
from typing_extensions import TypeVar, TypeGuard, assert_never
import pydantic
from .._tools import PydanticFunctionTool
from ..._types import NOT_GIVEN, NotGiven
from ..._utils import is_dict, is_given
from ..._compat import PYDANTIC_V2, model_parse_json
from ..._models import construct_type_unchecked
from .._pydantic import is_basemodel_type, to_strict_json_schema, is_dataclass_like_type
from ...types.chat import (
ParsedChoice,
ChatCompletion,
ParsedFunction,
ParsedChatCompletion,
ChatCompletionMessage,
ParsedFunctionToolCall,
ChatCompletionToolParam,
ParsedChatCompletionMessage,
completion_create_params,
)
from ..._exceptions import LengthFinishReasonError, ContentFilterFinishReasonError
from ...types.shared_params import FunctionDefinition
from ...types.chat.completion_create_params import ResponseFormat as ResponseFormatParam
from ...types.chat.chat_completion_message_tool_call import Function
# Type variable standing for the caller-supplied `response_format` type that
# message content is parsed into. Its default is `None`, matching the case
# where no response format is supplied.
ResponseFormatT = TypeVar(
"ResponseFormatT",
# if it isn't given then we don't do any parsing
default=None,
)
# Runtime placeholder that `solve_response_format_t` returns (cast to the
# type variable) when there is no rich response format to parse into.
_default_response_format: None = None
def validate_input_tools(
    tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
) -> None:
    """Check that every supplied tool is eligible for auto-parsing.

    Nothing is checked when `tools` is not given.

    Raises:
        ValueError: if a tool's `type` is not `"function"`, or if its function
            definition does not have `strict` set to exactly `True`.
    """
    if not is_given(tools):
        return

    for tool in tools:
        if tool["type"] != "function":
            raise ValueError(
                f'Currently only `function` tool types support auto-parsing; Received `{tool["type"]}`',
            )

        # `is not True` deliberately rejects a missing key and truthy non-bool values alike.
        if tool["function"].get("strict") is not True:
            raise ValueError(
                f'`{tool["function"]["name"]}` is not strict. Only `strict` function tools can be auto-parsed'
            )
def _parse_tool_calls(
    *,
    message: ChatCompletionMessage | ParsedChatCompletionMessage[object],
    input_tools: list[ChatCompletionToolParam],
) -> list[ParsedFunctionToolCall]:
    """Return the tool calls of `message` with their function arguments parsed.

    Each `function` tool call is rebuilt as a `ParsedFunctionToolCall` whose
    `function` entry gains a `parsed_arguments` key computed by
    `parse_function_tool_arguments`. Any other tool call is passed through
    unchanged. An empty list is returned when the message has no tool calls.
    """
    parsed_calls: list[ParsedFunctionToolCall] = []
    if not message.tool_calls:
        return parsed_calls

    for tool_call in message.tool_calls:
        if tool_call.type == "function":
            tool_call_dict = tool_call.to_dict()
            parsed_calls.append(
                construct_type_unchecked(
                    value={
                        **tool_call_dict,
                        "function": {
                            **cast(Any, tool_call_dict["function"]),
                            "parsed_arguments": parse_function_tool_arguments(
                                input_tools=input_tools, function=tool_call.function
                            ),
                        },
                    },
                    type_=ParsedFunctionToolCall,
                )
            )
        elif TYPE_CHECKING:  # type: ignore[unreachable]
            # Static exhaustiveness check only; this branch is never taken at runtime.
            assert_never(tool_call)
        else:
            # Runtime fallback for tool call types other than `function`.
            parsed_calls.append(tool_call)

    return parsed_calls


def parse_chat_completion(
    *,
    response_format: type[ResponseFormatT] | completion_create_params.ResponseFormat | NotGiven,
    input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
    chat_completion: ChatCompletion | ParsedChatCompletion[object],
) -> ParsedChatCompletion[ResponseFormatT]:
    """Build a `ParsedChatCompletion` from `chat_completion`.

    For every choice, the message gets a `parsed` entry (the result of
    `maybe_parse_content`) and a `tool_calls` list in which each function
    tool call carries `parsed_arguments`. `tool_calls` is always a list here,
    and is empty when the message had no tool calls.

    Args:
        response_format: The type to parse message content into, a plain
            response format param, or not given.
        input_tools: The tools that were sent with the request, used to
            decide how each function call's arguments are parsed.
        chat_completion: The completion to convert.

    Raises:
        LengthFinishReasonError: if any choice has `finish_reason == "length"`.
        ContentFilterFinishReasonError: if any choice has
            `finish_reason == "content_filter"`.
    """
    # Materialise once: the tools are searched again for every tool call, so a
    # one-shot iterable would be exhausted after the first lookup.
    if is_given(input_tools):
        input_tools = list(input_tools)
    else:
        input_tools = []

    # Depends only on `response_format`, so resolve it once rather than per choice.
    solved_t = solve_response_format_t(response_format)

    choices: list[ParsedChoice[ResponseFormatT]] = []
    for choice in chat_completion.choices:
        if choice.finish_reason == "length":
            raise LengthFinishReasonError()

        if choice.finish_reason == "content_filter":
            raise ContentFilterFinishReasonError()

        message = choice.message
        tool_calls = _parse_tool_calls(message=message, input_tools=input_tools)

        choices.append(
            construct_type_unchecked(
                type_=cast(Any, ParsedChoice)[solved_t],
                value={
                    **choice.to_dict(),
                    # Later keys win, so these override the raw message fields.
                    "message": {
                        **message.to_dict(),
                        "parsed": maybe_parse_content(
                            response_format=response_format,
                            message=message,
                        ),
                        "tool_calls": tool_calls,
                    },
                },
            )
        )

    return cast(
        ParsedChatCompletion[ResponseFormatT],
        construct_type_unchecked(
            type_=cast(Any, ParsedChatCompletion)[solved_t],
            value={
                **chat_completion.to_dict(),
                "choices": choices,
            },
        ),
    )
def get_input_tool_by_name(*, input_tools: list[ChatCompletionToolParam], name: str) -> ChatCompletionToolParam | None:
    """Return the first tool whose function definition is named `name`.

    Tools without a `function` entry are skipped. Returns `None` when no tool
    matches.
    """
    for candidate in input_tools:
        definition = candidate.get("function", {})
        if definition.get("name") == name:
            return candidate
    return None
def parse_function_tool_arguments(
    *, input_tools: list[ChatCompletionToolParam], function: Function | ParsedFunction
) -> object:
    """Parse the JSON arguments of a function call using its matching input tool.

    Returns:
        - `None` when no input tool matches `function.name`;
        - the result of `model_parse_json` on the tool's `model` when the
          tool's function definition is a `PydanticFunctionTool`;
        - the `json.loads` result when the definition has a truthy `strict`;
        - `None` otherwise.
    """
    matched_tool = get_input_tool_by_name(input_tools=input_tools, name=function.name)
    if not matched_tool:
        return None

    definition = cast(object, matched_tool.get("function"))
    if isinstance(definition, PydanticFunctionTool):
        return model_parse_json(definition.model, function.arguments)

    plain_definition = cast(FunctionDefinition, definition)
    # Only strict function definitions have their arguments decoded.
    return json.loads(function.arguments) if plain_definition.get("strict") else None
def maybe_parse_content(
    *,
    response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
    message: ChatCompletionMessage | ParsedChatCompletionMessage[object],
) -> ResponseFormatT | None:
    """Parse `message.content` into `response_format` when that is possible.

    Returns `None` when `response_format` is not a rich (type) format, when
    the message has no content, or when the message carries a refusal.
    """
    if not has_rich_response_format(response_format):
        return None

    content = message.content
    if content is None or message.refusal:
        return None

    return _parse_content(response_format, content)
def solve_response_format_t(
    response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
) -> type[ResponseFormatT]:
    """Resolve the runtime type that corresponds to `response_format`.

    A rich response format (a type that will be auto-parsed) is returned
    as-is. In every other case, i.e. not given or a plain response format
    param, the module default `_default_response_format` (`None`) is returned.
    """
    if not has_rich_response_format(response_format):
        return cast("type[ResponseFormatT]", _default_response_format)

    return response_format
def has_parseable_input(
    *,
    response_format: type | ResponseFormatParam | NotGiven,
    input_tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
) -> bool:
    """Report whether the request contains anything that can be auto-parsed.

    True when `response_format` is a rich response format, or when at least
    one of `input_tools` satisfies `is_parseable_tool`.
    """
    if has_rich_response_format(response_format):
        return True

    # A falsy `input_tools` is treated as "no tools".
    # NOTE(review): this relies on the not-given sentinel being falsy - confirm.
    return any(is_parseable_tool(tool) for tool in (input_tools or []))
def has_rich_response_format(
    response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
) -> TypeGuard[type[ResponseFormatT]]:
    """Tell whether `response_format` is a type to auto-parse into.

    False when the value is not given or is a plain response format param
    (see `is_response_format_param`); True for anything else.
    """
    return is_given(response_format) and not is_response_format_param(response_format)
def is_response_format_param(response_format: object) -> TypeGuard[ResponseFormatParam]:
    """Tell whether `response_format` is a plain response format param.

    The check is delegated entirely to `is_dict`; the contents of the value
    are not inspected here.
    """
    looks_like_param = is_dict(response_format)
    return looks_like_param
def is_parseable_tool(input_tool: ChatCompletionToolParam) -> bool:
    """Tell whether the arguments of calls to `input_tool` can be auto-parsed.

    A tool qualifies when its function definition is a `PydanticFunctionTool`
    or has a truthy `strict` entry.
    """
    definition = cast(object, input_tool.get("function"))
    if isinstance(definition, PydanticFunctionTool):
        return True

    strict_flag = cast(FunctionDefinition, definition).get("strict")
    # A missing or falsy `strict` collapses to `False`.
    return strict_flag or False
def _parse_content(response_format: type[ResponseFormatT], content: str) -> ResponseFormatT:
    """Parse the JSON string `content` into an instance of `response_format`.

    BaseModel types are parsed with `model_parse_json`. Dataclass-like types
    are validated through a `pydantic.TypeAdapter`.

    Raises:
        TypeError: if `response_format` is neither kind of type, or if it is
            dataclass-like while `PYDANTIC_V2` is false.
    """
    if is_basemodel_type(response_format):
        parsed_model = model_parse_json(response_format, content)
        return cast(ResponseFormatT, parsed_model)

    if not is_dataclass_like_type(response_format):
        raise TypeError(f"Unable to automatically parse response format type {response_format}")

    if not PYDANTIC_V2:
        raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {response_format}")

    adapter = pydantic.TypeAdapter(response_format)
    return adapter.validate_json(content)
def type_to_response_format_param(
    response_format: type | completion_create_params.ResponseFormat | NotGiven,
) -> ResponseFormatParam | NotGiven:
    """Convert `response_format` into the param form sent with a request.

    - Not given: `NOT_GIVEN` is returned.
    - Already a response format param: returned unchanged.
    - A BaseModel or dataclass-like type: a `json_schema` response format is
      built from it, named after the type, with `strict` set to `True`.

    Raises:
        TypeError: if the value is none of the above.
    """
    if not is_given(response_format):
        return NOT_GIVEN

    if is_response_format_param(response_format):
        return response_format

    # type checkers don't narrow the negation of a `TypeGuard`, so the
    # narrowing is spelled out: past the two checks above only a `type` remains
    response_format = cast(type, response_format)

    schema_source: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any]
    if is_basemodel_type(response_format):
        schema_name, schema_source = response_format.__name__, response_format
    elif is_dataclass_like_type(response_format):
        # Non-BaseModel types are wrapped in a TypeAdapter before the schema is built.
        schema_name, schema_source = response_format.__name__, pydantic.TypeAdapter(response_format)
    else:
        raise TypeError(f"Unsupported response_format type - {response_format}")

    return {
        "type": "json_schema",
        "json_schema": {
            "schema": to_strict_json_schema(schema_source),
            "name": schema_name,
            "strict": True,
        },
    }