2023-11-29 14:07:47 +08:00

250 lines
10 KiB
Python

from dataclasses import dataclass
import json
import sys
from typing import List, Optional
from devchat.prompt import Prompt
from devchat.message import Message
from devchat.utils import update_dict, get_logger
from devchat.utils import openai_message_tokens, openai_response_tokens
from .openai_message import OpenAIMessage
# Module-level logger, obtained from the project's get_logger helper using this module's name.
logger = get_logger(__name__)
@dataclass
class OpenAIPrompt(Prompt):
    """
    A class to represent a prompt and its corresponding responses from OpenAI APIs.

    Messages are kept in two groups: "new" messages (instructions, context and
    functions attached to the current request) and "history" messages (context
    and chat turns that precede the current request).
    """

    # ID reported by the API for the response; None until a response is parsed.
    _id: Optional[str] = None

    @property
    def id(self) -> Optional[str]:
        """Return the response ID recorded from the API, or None if none was seen yet."""
        return self._id

    @staticmethod
    def _context_to_dict(message: Message) -> dict:
        """Return the message as a dict whose content is wrapped in <context> tags."""
        return update_dict(message.to_dict(), 'content',
                           f"<context>\n{message.content}\n</context>")

    @staticmethod
    def _strip_context_tags(content: str) -> str:
        """Remove the <context> tags from the content and trim surrounding whitespace."""
        return content.replace("<context>", "").replace("</context>", "").strip()

    @property
    def messages(self) -> List[dict]:
        """
        Build the list of message dicts in the order they are sent:
        instructions, history context, history chat, the request, new context.
        """
        combined = []
        # Instruction
        if self._new_messages[Message.INSTRUCT]:
            combined += [msg.to_dict() for msg in self._new_messages[Message.INSTRUCT]]
        # History context
        if self._history_messages[Message.CONTEXT]:
            combined += [self._context_to_dict(msg)
                         for msg in self._history_messages[Message.CONTEXT]]
        # History chat
        if self._history_messages[Message.CHAT]:
            combined += [msg.to_dict() for msg in self._history_messages[Message.CHAT]]
        # Request
        if self.request:
            combined += [self.request.to_dict()]
        # New context
        if self.new_context:
            combined += [self._context_to_dict(msg) for msg in self.new_context]
        return combined

    def input_messages(self, messages: List[dict]):
        """
        Populate the prompt from a list of message dicts.

        The messages are classified by a small state machine that expects them
        in the same order as produced by `messages`: instructions, history
        context, history chat, and finally new context. The request token
        counter is reset and recomputed from the given messages.

        Args:
            messages (List[dict]): The message dicts to parse.
        """
        self._request_tokens = 0
        state = "new_instruct"
        for message_data in messages:
            self._request_tokens += openai_message_tokens(message_data, self.model)
            message = OpenAIMessage.from_dict(message_data)

            # Each stage falls through to the next one as soon as a message
            # does not belong to it, so the same message is re-examined.
            if state == "new_instruct":
                if message.role == "system" and not message.content.startswith("<context>"):
                    self._new_messages[Message.INSTRUCT].append(message)
                else:
                    state = "history_context"

            if state == "history_context":
                if message.role == "system" and message.content.startswith("<context>"):
                    message.content = self._strip_context_tags(message.content)
                    self._history_messages[Message.CONTEXT].append(message)
                else:
                    state = "history_chat"

            if state == "history_chat":
                if message.role in ("user", "assistant", "function"):
                    self._history_messages[Message.CHAT].append(message)
                else:
                    state = "new_context"

            if state == "new_context":
                if message.role == "system" and message.content.startswith("<context>"):
                    message.content = self._strip_context_tags(message.content)
                    self._new_messages[Message.CONTEXT].append(message)
                else:
                    logger.warning("Invalid new context message: %s", message)

        if not self.request:
            # The request is the last user/function message of the chat history.
            # Assistant messages that trail it are moved to the responses.
            chat_history = self._history_messages[Message.CHAT]
            while chat_history:
                last_message = chat_history.pop()
                if last_message.role in ("user", "function"):
                    self.request = last_message
                    break
                if last_message.role == "assistant":
                    self.responses.append(last_message)
                    continue
                # Unexpected role: put it back and stop instead of popping and
                # re-appending the same message forever.
                chat_history.append(last_message)
                logger.warning("Invalid request message: %s", last_message)
                break
            else:
                # The history ran out (or was empty) without a request; popping
                # further would raise IndexError.
                logger.warning("No user or function request found in the input messages.")

    def append_new(self, message_type: str, content: str,
                   available_tokens: int = sys.maxsize) -> bool:
        """
        Append a new instruction or context message if it fits the token budget.

        Args:
            message_type (str): Message.INSTRUCT or Message.CONTEXT.
            content (str): The content of the message.
            available_tokens (int): The number of tokens the message may use.

        Returns:
            bool: True if the message was appended, False if it needs more
                tokens than available.

        Raises:
            ValueError: If message_type is neither INSTRUCT nor CONTEXT.
        """
        if message_type not in (Message.INSTRUCT, Message.CONTEXT):
            raise ValueError(f"Current messages cannot be of type {message_type}.")
        # New instructions and context are of the system role
        message = OpenAIMessage(content=content, role='system')

        num_tokens = openai_message_tokens(message.to_dict(), self.model)
        if num_tokens > available_tokens:
            return False

        self._new_messages[message_type].append(message)
        self._request_tokens += num_tokens
        return True

    def set_functions(self, functions, available_tokens: int = sys.maxsize) -> bool:
        """
        Set the functions of the prompt if their JSON form fits the token budget.

        Args:
            functions: The JSON-serializable function definitions.
            available_tokens (int): The number of tokens the functions may use.

        Returns:
            bool: True if the functions were set, False if they need more
                tokens than available.
        """
        num_tokens = openai_message_tokens({"functions": json.dumps(functions)}, self.model)
        if num_tokens > available_tokens:
            return False

        self._new_messages[Message.FUNCTION] = functions
        self._request_tokens += num_tokens
        return True

    def get_functions(self):
        """Return the functions set on the prompt, or None if there are none."""
        return self._new_messages.get(Message.FUNCTION, None)

    def _prepend_history(self, message_type: str, message: Message,
                         token_limit: int = sys.maxsize) -> bool:
        """
        Insert a message at the front of the history of the given type.

        Args:
            message_type (str): The history type; INSTRUCT is not allowed.
            message (Message): The message to insert.
            token_limit (int): The limit for the total request tokens.

        Returns:
            bool: True if the message was inserted, False if it would push the
                request tokens over the limit.

        Raises:
            ValueError: If message_type is INSTRUCT.
        """
        if message_type == Message.INSTRUCT:
            raise ValueError("History messages cannot be of type INSTRUCT.")
        num_tokens = openai_message_tokens(message.to_dict(), self.model)
        if num_tokens > token_limit - self._request_tokens:
            return False
        self._history_messages[message_type].insert(0, message)
        self._request_tokens += num_tokens
        return True

    def prepend_history(self, prompt: 'OpenAIPrompt', token_limit: int = sys.maxsize) -> bool:
        """
        Prepend another prompt (its first response, request and new context)
        to the history of this prompt.

        Messages inserted before a failure stay in the history.

        Args:
            prompt (OpenAIPrompt): The earlier prompt to prepend.
            token_limit (int): The limit for the total request tokens.

        Returns:
            bool: True if everything was prepended, False as soon as one
                message does not fit within the token limit.
        """
        # Prepend the first response and the request of the prompt.
        # The response goes first so that the request ends up in front of it.
        if not self._prepend_history(Message.CHAT, prompt.responses[0], token_limit):
            return False
        if not self._prepend_history(Message.CHAT, prompt.request, token_limit):
            return False

        # Prepend the context messages of the prepended prompt. Each one is
        # inserted at the front, so they end up in reverse iteration order.
        for context_message in prompt.new_context:
            if not self._prepend_history(Message.CONTEXT, context_message, token_limit):
                return False
        return True

    def set_request(self, content: str, function_name: Optional[str] = None) -> None:
        """
        Set the request message and add its tokens to the request token count.

        Args:
            content (str): The content of the request.
            function_name (Optional[str]): If given, the request is sent with
                the 'function' role under this name instead of the 'user' role.

        Raises:
            ValueError: If the content is empty or only whitespace.
        """
        if not content.strip():
            raise ValueError("The request cannot be empty.")
        message = OpenAIMessage(content=content,
                                role=('user' if not function_name else 'function'),
                                name=function_name)
        self.request = message
        self._request_tokens += openai_message_tokens(message.to_dict(), self.model)

    def _ensure_response_slot(self, index: int) -> None:
        """
        Pad the responses and the finish reasons with None so that `index` is
        valid in both lists.

        The two lists are padded independently because they may differ in
        length; a non-positive pad count extends by nothing.
        """
        self.responses.extend([None] * (index + 1 - len(self.responses)))
        self._response_reasons.extend([None] * (index + 1 - len(self._response_reasons)))

    def set_response(self, response_str: str):
        """
        Parse the API response string and set the Prompt object's attributes.

        Args:
            response_str (str): The JSON-formatted response string from the chat API.

        Raises:
            ValueError: If the model or the ID of the response does not match
                the prompt.
        """
        response_data = json.loads(response_str)
        self._validate_model(response_data)
        self._timestamp_from_dict(response_data)
        self._id_from_dict(response_data)

        # The usage reported by the API replaces the locally counted tokens.
        self._request_tokens = response_data['usage']['prompt_tokens']
        self._response_tokens = response_data['usage']['completion_tokens']

        for choice in response_data['choices']:
            index = choice['index']
            self._ensure_response_slot(index)
            self.responses[index] = OpenAIMessage.from_dict(choice['message'])
            if choice['finish_reason']:
                self._response_reasons[index] = choice['finish_reason']

    def append_response(self, delta_str: str) -> str:
        """
        Append the content of a streaming response to the existing messages.

        Args:
            delta_str (str): The JSON-formatted delta string from the chat API.

        Returns:
            str: The delta content with index 0. An empty string when the
                delta has no choice with index 0 or that choice has no content.

        Raises:
            ValueError: If the model or the ID of the delta does not match
                the prompt.
        """
        response_data = json.loads(delta_str)
        self._validate_model(response_data)
        self._timestamp_from_dict(response_data)
        self._id_from_dict(response_data)

        delta_content = ''
        for choice in response_data['choices']:
            delta = choice['delta']
            index = choice['index']
            finish_reason = choice['finish_reason']

            self._ensure_response_slot(index)

            if not self.responses[index]:
                # First delta for this choice: it creates the message.
                self.responses[index] = OpenAIMessage.from_dict(delta)
                if index == 0:
                    delta_content = self.responses[0].content or ''
            else:
                if index == 0:
                    delta_content = self.responses[0].stream_from_dict(delta)
                else:
                    self.responses[index].stream_from_dict(delta)

                # Merge a streamed function call: the name is set once, the
                # arguments are accumulated across deltas.
                if 'function_call' in delta:
                    function_call = self.responses[index].function_call
                    if 'name' in delta['function_call'] and \
                            function_call.get('name', '') == '':
                        function_call['name'] = delta['function_call']['name']
                    if 'arguments' in delta['function_call']:
                        function_call['arguments'] = \
                            function_call.get('arguments', '') + \
                            delta['function_call']['arguments']

            if finish_reason:
                self._response_reasons[index] = finish_reason
        return delta_content

    def _count_response_tokens(self) -> int:
        """Return the total number of tokens of all the responses."""
        return sum(openai_response_tokens(resp.to_dict(), self.model) for resp in self.responses)

    def _validate_model(self, response_data: dict):
        """
        Check that the model of the response starts with the model of the prompt.

        Raises:
            ValueError: If the model names do not match.
        """
        if not response_data['model'].startswith(self.model):
            raise ValueError(f"Model mismatch: expected '{self.model}', "
                             f"got '{response_data['model']}'")

    def _timestamp_from_dict(self, response_data: dict):
        """Record the creation time of the response; a differing value overwrites the old one."""
        # Both the "unset" and the "mismatch" case store the new value, so a
        # plain assignment covers them (and is a no-op when they are equal).
        self._timestamp = response_data['created']

    def _id_from_dict(self, response_data: dict):
        """
        Record the ID of the response, or check it against the recorded one.

        Raises:
            ValueError: If an ID was already recorded and the new one differs.
        """
        if self._id is None:
            self._id = response_data['id']
        elif self._id != response_data['id']:
            raise ValueError(f"ID mismatch: expected {self._id}, "
                             f"got {response_data['id']}")