110 lines
4.0 KiB
Python
110 lines
4.0 KiB
Python
import json
|
|
import os
|
|
from typing import Optional, Union, List, Dict, Iterator
|
|
from pydantic import BaseModel, Field
|
|
import openai
|
|
from devchat.chat import Chat
|
|
from devchat.utils import get_user_info, user_id
|
|
from .openai_message import OpenAIMessage
|
|
from .openai_prompt import OpenAIPrompt
|
|
from .minimax_chat import chat_completion, stream_chat_completion
|
|
|
|
|
|
class OpenAIChatParameters(BaseModel, extra='ignore'):
    # Optional per-request parameters for the chat completion call.
    # Unknown keys are dropped (extra='ignore'); the numeric bounds below are
    # enforced by pydantic validation when the model is constructed.
    # Only values that were explicitly set are forwarded by callers that use
    # `.dict(exclude_unset=True)`, so these defaults are not sent on their own.

    # Validated to lie in [0, 2].
    temperature: Optional[float] = Field(default=0, ge=0, le=2)
    # Validated to lie in [0, 1].
    top_p: Optional[float] = Field(default=None, ge=0, le=1)
    # Number of choices; at least 1 when given.
    n: Optional[int] = Field(default=None, ge=1)
    stream: Optional[bool] = Field(default=None)
    # Either a single stop string or a list of them.
    stop: Optional[Union[str, List[str]]] = Field(default=None)
    # At least 1 when given.
    max_tokens: Optional[int] = Field(default=None, ge=1)
    # Both penalties are validated to lie in [-2.0, 2.0].
    presence_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
    frequency_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
    # Presumably token id -> bias value (OpenAI `logit_bias` semantics) — confirm.
    logit_bias: Optional[Dict[int, float]] = Field(default=None)
    # End-user identifier sent as the `user` request parameter.
    user: Optional[str] = Field(default=None)
    # Request timeout, at least 3 (presumably seconds — confirm); default 32.
    request_timeout: Optional[int] = Field(default=32, ge=3)
|
|
|
|
|
|
class OpenAIChatConfig(OpenAIChatParameters):
    """
    Complete configuration for OpenAIChat: every request parameter from
    OpenAIChatParameters plus the name of the model to call.
    """

    # Model identifier; required (no default value).
    model: str
|
|
|
|
|
|
class OpenAIChat(Chat):
    """
    Chat backend that sends prompts to the OpenAI Chat Completions API.

    Models whose name starts with ``abab`` are routed to the MiniMax helpers
    (``chat_completion`` / ``stream_chat_completion``) instead of the OpenAI
    client.
    """

    def __init__(self, config: OpenAIChatConfig):
        """
        Initialize the OpenAIChat class with a configuration object.

        Args:
            config (OpenAIChatConfig): Configuration object with parameters for the OpenAI Chat API.
        """
        self.config = config

    def init_prompt(self, request: str, function_name: Optional[str] = None) -> OpenAIPrompt:
        """
        Create a new prompt for ``request`` attributed to the current user.

        Side effect: sets ``self.config.user`` to the second element returned
        by ``user_id(user, email)``, so it is sent with later requests.

        Args:
            request: The user's request text.
            function_name: Optional function name passed to ``set_request``.

        Returns:
            OpenAIPrompt: A prompt for ``self.config.model`` with the request set.
        """
        user, email = get_user_info()
        self.config.user = user_id(user, email)[1]
        prompt = OpenAIPrompt(self.config.model, user, email)
        prompt.set_request(request, function_name=function_name)
        return prompt

    def load_prompt(self, data: dict) -> OpenAIPrompt:
        """
        Rebuild an OpenAIPrompt from its dict form.

        The dict is presumably the serialized prompt produced elsewhere; confirm
        against the caller. Entries under ``'_new_messages'`` are either a list of
        message dicts or a single message dict. The ``'function'`` entry is
        dropped. Entries under ``'_history_messages'`` are lists of message dicts.
        The caller's ``data`` is not modified.

        Args:
            data: Mapping of OpenAIPrompt constructor fields.

        Returns:
            OpenAIPrompt: The reconstructed prompt.
        """
        # Shallow copy: only the top-level keys below are replaced, so the caller's
        # dict keeps its original (JSON-serializable) contents.
        data = dict(data)
        data['_new_messages'] = {
            k: [OpenAIMessage.from_dict(m) for m in v]
            if isinstance(v, list) else OpenAIMessage.from_dict(v)
            for k, v in data['_new_messages'].items() if k != 'function'
        }
        data['_history_messages'] = {
            k: [OpenAIMessage.from_dict(m) for m in v]
            for k, v in data['_history_messages'].items()
        }
        return OpenAIPrompt(**data)

    def _build_config_params(self, prompt: OpenAIPrompt, stream: bool) -> dict:
        """Collect the explicitly-set config values plus the prompt's functions for one request."""
        # Only forward parameters the user actually set; unset defaults are left to the API.
        config_params = self.config.dict(exclude_unset=True)
        functions = prompt.get_functions()
        if functions:
            config_params['functions'] = functions
            config_params['function_call'] = 'auto'
        config_params['stream'] = stream
        return config_params

    def _create_openai_completion(self, prompt: OpenAIPrompt, config_params: dict,
                                  default_timeout: Optional[float] = None):
        """
        Send one Chat Completions request through the openai>=1.0 client.

        The openai>=1.0 ``create`` method has no ``request_timeout`` keyword; its
        per-request timeout is ``timeout``. This method translates the
        ``request_timeout`` config value into ``timeout``, or uses
        ``default_timeout`` when that value is not set. If neither is available,
        no ``timeout`` is passed and the client's own default applies.

        Args:
            prompt: Prompt whose ``messages`` are sent.
            config_params: Request parameters (not modified).
            default_timeout: Fallback timeout when ``request_timeout`` is unset.

        Returns:
            The object returned by ``client.chat.completions.create``.
        """
        params = dict(config_params)
        timeout = params.pop('request_timeout', None)
        if timeout is None:
            timeout = default_timeout
        # Passing timeout=None would disable the timeout entirely, so only pass real values.
        if timeout is not None:
            params['timeout'] = timeout

        client = openai.OpenAI(
            api_key=os.environ.get("OPENAI_API_KEY", None),
            base_url=os.environ.get("OPENAI_API_BASE", None)
        )
        return client.chat.completions.create(messages=prompt.messages, **params)

    def complete_response(self, prompt: OpenAIPrompt) -> str:
        """
        Request a single, non-streaming completion for ``prompt``.

        Returns:
            str: For ``abab`` models, whatever ``chat_completion`` returns. For
            OpenAI models, the JSON dump of the ``ChatCompletion``, or
            ``str(response)`` when another type comes back.
        """
        config_params = self._build_config_params(prompt, stream=False)

        if config_params['model'].startswith('abab'):
            return chat_completion(prompt.messages, config_params)

        response = self._create_openai_completion(prompt, config_params)
        if isinstance(response, openai.types.chat.chat_completion.ChatCompletion):
            return json.dumps(response.dict())
        return str(response)

    def stream_response(self, prompt: OpenAIPrompt) -> Iterator:
        """
        Request a streaming completion for ``prompt``.

        Returns:
            Iterator: For ``abab`` models, the result of ``stream_chat_completion``.
            Otherwise, the streaming response from the OpenAI client.
        """
        config_params = self._build_config_params(prompt, stream=True)

        if config_params['model'].startswith('abab'):
            return stream_chat_completion(prompt.messages, config_params)

        # Streaming requests keep their historical 60-second fallback timeout.
        return self._create_openai_completion(prompt, config_params, default_timeout=60)
|