55 lines
1.5 KiB
Python
55 lines
1.5 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any, Dict, cast
|
|
|
|
import pydantic
|
|
|
|
from ._pydantic import to_strict_json_schema
|
|
from ..types.chat import ChatCompletionToolParam
|
|
from ..types.shared_params import FunctionDefinition
|
|
|
|
|
|
class PydanticFunctionTool(Dict[str, Any]):
    """A function-definition dictionary that also remembers its pydantic model.

    The instance is an ordinary ``dict`` holding the function definition, so
    it can be passed through the request stack like any other definition
    without special casing, while the originating base model stays reachable
    through the ``model`` attribute.
    """

    # The pydantic model class this function definition was built from.
    model: type[pydantic.BaseModel]

    def __init__(self, defn: FunctionDefinition, model: type[pydantic.BaseModel]) -> None:
        """Copy the entries of ``defn`` into this dict and keep ``model``."""
        # Stored as a plain instance attribute, so it is not one of the
        # dictionary's own keys.
        self.model = model
        super().__init__(defn)

    def cast(self) -> FunctionDefinition:
        """Return this same object, typed as a ``FunctionDefinition``."""
        # Inside the method body `cast` resolves to the module-level
        # `typing.cast` (class-scope names are not visible here); it is a
        # no-op at runtime and only informs static type checkers.
        return cast(FunctionDefinition, self)
|
|
|
|
|
|
def pydantic_function_tool(
    model: type[pydantic.BaseModel],
    *,
    name: str | None = None,  # inferred from class name by default
    description: str | None = None,  # inferred from class docstring by default
) -> ChatCompletionToolParam:
    """Build a ``"function"`` tool param from a pydantic model.

    Args:
        model: The pydantic model class; its strict JSON schema (as produced
            by ``to_strict_json_schema``) becomes the function's
            ``"parameters"``.
        name: Name of the function. When ``None`` or empty, the model's
            ``__name__`` is used.
        description: Description of the function. When ``None``, the model's
            own ``__doc__`` is used; if that is ``None`` as well, the
            ``"description"`` key is left out.

    Returns:
        A dict with ``"type": "function"`` and a ``"function"`` entry holding
        the definition, which is always marked ``"strict": True``.
    """
    # note: we intentionally read `__doc__` instead of using `.getdoc()` to
    # avoid including pydantic's docstrings
    resolved_description = model.__doc__ if description is None else description

    wrapper = PydanticFunctionTool(
        {
            "name": name or model.__name__,
            "strict": True,
            "parameters": to_strict_json_schema(model),
        },
        model,
    )
    function = wrapper.cast()

    if resolved_description is not None:
        function["description"] = resolved_description

    return {"type": "function", "function": function}
|