# workflows/libs/llm_api/minimax_chat.py
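"""OpenAI-style chat-completion adapters for the MiniMax API.

Selects between the public ``api.minimax.chat`` endpoint and a private
deployment based on the ``OPENAI_API_BASE`` environment variable, and wraps
both response shapes behind a small chat-completion interface.
"""
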
import json
import os
import sys
import time

import requests


class StreamIterWrapper:
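    """Iterator over a streaming MiniMax response.

    Parses each SSE line and yields OpenAI-style ``chat.completion.chunk``
    dicts, handling both the private and the public response formats.
    """
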
    def __init__(self, response, is_private=True):
        self.response = response
        self.create_time = int(time.time())
        self.line_iterator = response.iter_lines()
        self.is_private = is_private
        self.stop = False

    def __iter__(self):
        return self

    def __next__(self):
        try:
            if self.stop:
                raise StopIteration
            response_line = next(self.line_iterator)
            # Skip SSE keep-alive blank lines.
            if response_line in (b"", b"\n"):
                return self.__next__()
            response_line = response_line.replace(b"data: ", b"")
            response_result = json.loads(response_line.decode("utf-8"))
            if self.is_private:
                # The private API signals completion and errors in-band.
                if response_result.get("finish"):
                    self.stop = True
                if response_result.get("err"):
                    raise ValueError(f"minimax api response error: {response_result['err']}")
            else:
                if response_result["choices"][0].get("finish_reason", None):
                    raise StopIteration
            data = {}
            if self.is_private:
                data = json.loads(response_result["data"])
            stream_response = {
                "id": f"minimax_{self.create_time}",
                "created": self.create_time,
                "object": "chat.completion.chunk",
                "model": response_result.get("model", "abab5.5-chat"),
                "choices": [
                    {
                        "index": 0,
                        "finish_reason": "stop",
                        "delta": {
                            "role": "assistant",
                            "content": (
                                data.get("text", "")
                                if self.is_private
                                else response_result["choices"][0]["messages"][0]["text"]
                            ),
                        },
                    }
                ],
                # Hard-coded placeholder usage values.
                "usage": {"prompt_tokens": 10, "completion_tokens": 100},
            }
            return stream_response
        except StopIteration as exc:  # no more events
            raise StopIteration from exc
        except Exception as err:
            print("Exception:", err.__class__.__name__, err, file=sys.stderr, end="\n\n")
            raise StopIteration from err


def chat_completion(messages, llm_config):
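    """Non-streaming chat completion; returns ``{"content": <reply text>}``."""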
    url = _make_api_url()
    headers = _make_header()
    if _is_private_llm():
        payload = _make_private_payload(messages, llm_config)
    else:
        payload = _make_public_payload(messages, llm_config)
    response = requests.post(url, headers=headers, json=payload)
    response_json = response.json()
    if not response_json.get("texts", []):
        raise ValueError(f"minimax api response error: {response_json}")
    return {"content": response_json["texts"][0]}


def stream_chat_completion(messages, llm_config):
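    """Streaming chat completion; returns an iterator of OpenAI-style chunks."""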
    url = _make_api_url()
    headers = _make_header()
    if _is_private_llm():
        payload = _make_private_payload(messages, llm_config, True)
    else:
        payload = _make_public_payload(messages, llm_config, True)
    # stream=True keeps requests from buffering the whole SSE body up front.
    response = requests.post(url, headers=headers, json=payload, stream=True)
    return StreamIterWrapper(response, _is_private_llm())


def _is_private_llm():
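    """Treat any base URL other than api.minimax.chat as a private deployment."""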
    api_base_url = os.environ.get("OPENAI_API_BASE", "")
    return not api_base_url.startswith("https://api.minimax.chat")


def _make_api_url():
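    """Build the request URL from OPENAI_API_BASE for either deployment."""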
    api_base_url = os.environ.get("OPENAI_API_BASE", None)
    if not api_base_url:
        raise ValueError("minimax api url is not set")
    if api_base_url.startswith("https://api.minimax.chat"):
        # Public API: normalize to <base>/v1/text/chatcompletion_pro?GroupId=<id>
        if api_base_url.endswith("/"):
            api_base_url = api_base_url[:-1]
        if not api_base_url.endswith("/v1"):
            api_base_url = api_base_url + "/v1"
        api_base_url += "/text/chatcompletion_pro"
        api_key = os.environ.get("OPENAI_API_KEY", None)
        if not api_key:
            raise ValueError("minimax api key is not set")
        # OPENAI_API_KEY is stored as "<group_id>##<secret>".
        group_id = api_key.split("##")[0]
        api_base_url += f"?GroupId={group_id}"
        return api_base_url
    else:
        # Private deployment: normalize to <base>/interact
        if api_base_url.endswith("/"):
            api_base_url = api_base_url[:-1]
        if not api_base_url.endswith("/interact"):
            api_base_url = api_base_url + "/interact"
        return api_base_url


def _make_api_key():
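    """Return the bearer secret for the public API; private deployments use none."""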
    if _is_private_llm():
        return ""
    api_key = os.environ.get("OPENAI_API_KEY", None)
    if not api_key or "##" not in api_key:
        raise ValueError("minimax api key is not set or not in '<group_id>##<secret>' form")
    return api_key.split("##")[1]


def _make_header():
    api_key = _make_api_key()
    return {
        **({"Authorization": f"Bearer {api_key}"} if not _is_private_llm() else {}),
        "Content-Type": "application/json",
    }


def _to_private_messages(messages):
    new_messages = []
    for message in messages:
        if message["role"] == "user":
            new_messages.append({"role": "user", "name": "user", "text": message["content"]})
        else:
            new_messages.append({"role": "ai", "name": "ai", "text": message["content"]})
    # A trailing empty AI turn cues the model to produce the next reply.
    new_messages.append({"role": "ai", "name": "ai", "text": ""})
    return new_messages


def _make_private_payload(messages, llm_config, stream=False):
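    """Build the request body for a private MiniMax deployment."""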
    return {
        "data": _to_private_messages(messages),
        "model_control": {
            "system_data": [
                {
                    "role": "system",
                    "ai_setting": "ai",
                    # "You are the minimax coding assistant; you are good at
                    # writing code, comments, and test cases, and you pay
                    # close attention to coding standards."
                    "text": "你是minimax编码助理，擅长编写代码、编写注释、编写测试用例，并且很注重编码的规范性。",
                },
            ],
            # "alpha_frequency": 128,
            # "alpha_frequency_src": 1,
            # "alpha_presence": 0,
            # "alpha_presence_src": 0,
            # "block_ngram": 0,
            # "clean_init_no_penalty_list": True,
            # "context_block_ngram": 0,
            # "factual_topp": False,
            # "lamda_decay": 1,
            # "length_penalty": 1,
            # "no_penalty_list": ",",
            # "omega_bound": 0,
            # "repeat_filter": False,
            # "repeat_sampling": 1,
            # "skip_text_mask": True,
            "tokens_to_generate": llm_config.get("max_tokens", 2048),
            # "sampler_type": "nucleus",
            "beam_width": 1,
            # "delimiter": "\n",
            # "min_length": 0,
            # "skip_info_mask": True,
            "stop_sequence": [],
            # "top_p": 0.95,
            "temperature": llm_config.get("temperature", 0.95),
        },
        "stream": stream,
    }


def _to_public_messages(messages):
    new_messages = []
    for message in messages:
        if message["role"] == "user":
            new_messages.append(
                {"sender_type": "USER", "sender_name": "USER", "text": message["content"]}
            )
        else:
            new_messages.append(
                {"sender_type": "BOT", "sender_name": "ai", "text": message["content"]}
            )
    return new_messages


def _make_public_payload(messages, llm_config, stream=False):
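    """Build the request body for the public chatcompletion_pro endpoint."""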
    payload = {
        "model": "abab5.5-chat",
        "tokens_to_generate": llm_config.get("max_tokens", 2048),
        "temperature": llm_config.get("temperature", 0.1),
        # "top_p": 0.9,
        "reply_constraints": {"sender_type": "BOT", "sender_name": "ai"},
        "sample_messages": [],
        "plugins": [],
        "messages": _to_public_messages(messages),
        "bot_setting": [
            {
                "bot_name": "ai",
                # "MM Intelligent Assistant is a large language model
                # developed in-house by MiniMax that calls no other
                # product's interfaces. MiniMax is a Chinese technology
                # company committed to research on large models."
                "content": (
                    "MM智能助理是一款由MiniMax自研的，"
                    "没有调用其他产品的接口的大型语言模型。"
                    "MiniMax是一家中国科技公司，"
                    "一直致力于进行大模型相关的研究。"
                ),
            }
        ],
        "stream": stream,
    }
    return payload
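

# A minimal usage sketch (an illustration, not part of the module); it assumes
# OPENAI_API_BASE and OPENAI_API_KEY are set as described above.
if __name__ == "__main__":
    demo_messages = [{"role": "user", "content": "Write a hello-world in Python."}]
    demo_config = {"max_tokens": 256, "temperature": 0.2}
    for chunk in stream_chat_completion(demo_messages, demo_config):
        print(chunk["choices"][0]["delta"]["content"], end="", flush=True)
    print()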