2023-11-03 11:02:53 +08:00
|
|
|
import ast
|
|
|
|
import json
|
2024-05-14 08:51:42 +00:00
|
|
|
from dataclasses import asdict, dataclass, field, fields
|
2023-11-03 11:02:53 +08:00
|
|
|
from typing import Dict, Optional
|
|
|
|
|
|
|
|
from devchat.message import Message
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class OpenAIMessage(Message):
    """A chat message in the shape used by OpenAI-style chat APIs.

    Extends ``Message`` with ``role``, ``name`` and ``function_call`` fields,
    and validates ``role`` and ``name`` when an instance is constructed.
    """

    # Must be one of _VALID_ROLES; the None default is rejected by
    # __post_init__, so callers always have to supply a role.
    role: Optional[str] = None
    # Optional participant name; accepted format is checked by _validate_string.
    name: Optional[str] = None
    # Shaped like {"name": ..., "arguments": ...}; an empty dict means
    # "no function call" and is dropped by to_dict().
    function_call: Dict[str, str] = field(default_factory=dict)

    # Plain (unannotated) class attribute, so @dataclass does not turn it into
    # a field and it never shows up in to_dict()/from_dict()/fields().
    _VALID_ROLES = ("system", "user", "assistant", "function")

    def __post_init__(self):
        """Validate ``role`` and ``name`` right after dataclass initialization.

        Raises:
            ValueError: if ``role`` is not one of the accepted roles, or
                ``name`` is set but does not satisfy ``_validate_string``.
        """
        if not self._validate_role():
            # Keep this message in sync with _VALID_ROLES.
            raise ValueError(
                "Invalid role. Must be one of 'system', 'user', 'assistant', or 'function'."
            )
        if not self._validate_name():
            raise ValueError(
                "Invalid name. Must contain a-z, A-Z, 0-9, and underscores, "
                "with a maximum length of 64 characters."
            )

    def to_dict(self) -> dict:
        """Serialize the message into a plain dict.

        ``name`` is omitted when it is None and ``function_call`` is omitted
        when it is empty, so the result contains only meaningful keys.
        """
        state = asdict(self)
        if state["name"] is None:
            del state["name"]
        # An empty function_call carries no information; drop it.
        if not state["function_call"]:
            del state["function_call"]
        return state

    @classmethod
    def from_dict(cls, message_data: dict) -> "OpenAIMessage":
        """Build a message from a dict, ignoring keys that are not dataclass fields.

        Args:
            message_data: mapping whose keys matching this dataclass's fields
                are passed to the constructor; all other keys are discarded.

        Raises:
            ValueError: propagated from ``__post_init__`` on invalid role/name.
        """
        keys = {f.name for f in fields(cls)}
        kwargs = {k: v for k, v in message_data.items() if k in keys}
        return cls(**kwargs)

    def function_call_to_json(self) -> str:
        '''Render ``function_call`` as JSON wrapped in a fenced ``command`` block.

        ``function_call`` is expected to look like this::

            {
                "name": function_name,
                "arguments": '{"key": """value"""}'
            }

        ``arguments`` is a string that may be a Python literal (as above) or
        JSON. It is decoded on a best-effort basis so that it is embedded as a
        structured value rather than as an escaped string; if neither parser
        accepts it, the original value is kept unchanged.

        Returns:
            An empty string when there is no function call; otherwise the JSON
            dump of the (decoded) function call inside a Markdown code fence
            tagged ``command``.
        '''
        if not self.function_call:
            return ""

        call = self.function_call.copy()
        if "arguments" in call:
            raw_arguments = call["arguments"]
            # ast.literal_eval only evaluates literals (no code execution), and
            # unlike json.loads it accepts Python-only syntax such as
            # triple-quoted strings. Both attempts are deliberately best-effort.
            try:
                call["arguments"] = ast.literal_eval(raw_arguments)
            except Exception:
                try:
                    # Covers JSON-only tokens like true/false/null.
                    call["arguments"] = json.loads(raw_arguments)
                except Exception:
                    pass  # Not parseable either way: keep the raw value.
        return "```command\n" + json.dumps(call) + "\n```"

    def stream_from_dict(self, message_data: dict) -> str:
        """Append to the message from a dictionary returned from a streaming chat API.

        Args:
            message_data: one streaming delta; its ``"content"`` entry may be
                missing or explicitly None (e.g. role-only or function-call
                chunks), both of which are treated as an empty delta.

        Returns:
            The text that was appended (always a string).
        """
        # `.get("content", "")` would still yield None for {"content": None},
        # which previously crashed `+=` or stored None in self.content.
        delta = message_data.get("content") or ""
        self.content = (self.content or "") + delta
        return delta

    def _validate_role(self) -> bool:
        """Validate the role attribute.

        Returns:
            bool: True if the role is one of ``_VALID_ROLES``, False otherwise.
        """
        return self.role in self._VALID_ROLES

    def _validate_name(self) -> bool:
        """Validate the name attribute.

        Returns:
            bool: True if the name is valid or None, False otherwise.
        """
        return self._validate_string(self.name)

    def _validate_string(self, string: Optional[str]) -> bool:
        """Validate a string attribute.

        A valid string is non-blank, at most 64 characters long, and made only
        of ASCII letters, ASCII digits and underscores (with at least one
        letter or digit).

        Returns:
            bool: True if the string is valid or None, False otherwise.
        """
        if string is None:
            return True
        if not string.strip():
            return False
        # isalnum() alone accepts any Unicode letter/digit (e.g. "é", "²"),
        # which contradicts the documented a-z / A-Z / 0-9 charset.
        return (
            len(string) <= 64
            and string.isascii()
            and string.replace("_", "").isalnum()
        )
|