2023-12-20 21:13:05 +08:00
|
|
|
|
# flake8: noqa: E402
|
2024-01-05 12:27:57 +08:00
|
|
|
|
import json
|
2023-12-13 14:18:08 +08:00
|
|
|
|
import os
|
2024-01-05 12:27:57 +08:00
|
|
|
|
import re
|
2023-12-08 10:37:32 +08:00
|
|
|
|
import sys
|
2024-01-24 17:34:27 +08:00
|
|
|
|
from typing import Dict, List
|
2023-12-08 10:37:32 +08:00
|
|
|
|
|
|
|
|
|
import openai
|
|
|
|
|
|
2023-12-20 21:13:05 +08:00
|
|
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
|
|
|
|
|
|
|
|
|
from ide_services.services import log_warn
|
|
|
|
|
|
2023-12-08 10:55:48 +08:00
|
|
|
|
|
|
|
|
|
def _try_remove_markdown_block_flag(content):
|
|
|
|
|
"""
|
|
|
|
|
如果content是一个markdown块,则删除它的头部```xxx和尾部```
|
|
|
|
|
"""
|
|
|
|
|
# 定义正则表达式模式,用于匹配markdown块的头部和尾部
|
2023-12-08 18:28:36 +08:00
|
|
|
|
pattern = r"^\s*```\s*(\w+)\s*\n(.*?)\n\s*```\s*$"
|
|
|
|
|
|
2023-12-08 10:55:48 +08:00
|
|
|
|
# 使用re模块进行匹配
|
|
|
|
|
match = re.search(pattern, content, re.DOTALL | re.MULTILINE)
|
2023-12-08 18:28:36 +08:00
|
|
|
|
|
2023-12-08 10:55:48 +08:00
|
|
|
|
if match:
|
|
|
|
|
# 如果匹配成功,则提取出markdown块的内容并返回
|
2023-12-08 18:38:12 +08:00
|
|
|
|
_ = match.group(1) # language
|
2023-12-08 10:55:48 +08:00
|
|
|
|
markdown_content = match.group(2)
|
|
|
|
|
return markdown_content.strip()
|
|
|
|
|
else:
|
|
|
|
|
# 如果匹配失败,则返回原始内容
|
|
|
|
|
return content
|
|
|
|
|
|
2023-12-08 18:28:36 +08:00
|
|
|
|
|
2024-01-24 17:34:27 +08:00
|
|
|
|
def chat_completion_stream(
    messages: List[Dict],  # [{"role": "user", "content": "hello"}]
    llm_config: Dict,  # {"model": "...", ...}
    error_out: bool = True,
    stream_out=False,
) -> Dict:
    """
    Request a streamed chat completion and merge the chunks into one result.

    The client is built from the OPENAI_API_KEY and OPENAI_API_BASE environment
    variables. The request is always sent with ``stream=True`` and
    ``timeout=8``; these override any values in ``llm_config``, which itself is
    left unmodified. Connection and timeout errors are retried, up to three
    attempts in total.

    Args:
        messages: Chat messages, e.g. [{"role": "user", "content": "hello"}].
        llm_config: Extra keyword arguments for the chat completions request,
            e.g. {"model": "..."}.
        error_out: If True, non-retryable exceptions are printed to stderr.
        stream_out: If True, each content fragment is printed to stdout as it
            arrives.

    Returns:
        A dict with the keys "content" (the concatenated reply text, or None if
        no text was received), "function_name" (name of the called tool
        function, or None) and "parameters" (the concatenated tool-call
        arguments string). When the request fails, the same keys are returned
        with empty values plus an "error" key holding the exception.
    """
    for try_times in range(3):
        try:
            client = openai.OpenAI(
                api_key=os.environ.get("OPENAI_API_KEY", None),
                base_url=os.environ.get("OPENAI_API_BASE", None),
            )

            # Build the request options on a copy so the caller's dict is not
            # modified as a side effect of this call.
            request_config = {**llm_config, "stream": True, "timeout": 8}
            response = client.chat.completions.create(messages=messages, **request_config)

            response_result = {"content": None, "function_name": None, "parameters": ""}
            for chunk in response:  # pylint: disable=E1133
                chunk = chunk.dict()
                choices = chunk.get("choices")
                if not choices:
                    # A chunk may carry no choices (e.g. a usage-only chunk);
                    # there is nothing to merge from it.
                    continue
                delta = choices[0]["delta"]

                tool_calls = delta.get("tool_calls")
                if tool_calls:
                    # Only the first tool call of a chunk is considered.
                    # "function" may be None in a delta, hence the fallback.
                    tool_call = tool_calls[0].get("function") or {}
                    if tool_call.get("name", None):
                        response_result["function_name"] = tool_call["name"]
                    if tool_call.get("arguments", None):
                        # Arguments arrive as string fragments across chunks.
                        response_result["parameters"] += tool_call["arguments"]

                fragment = delta.get("content", None)
                if fragment:
                    if stream_out:
                        print(fragment, end="", flush=True)
                    # "content" stays None until the first fragment arrives.
                    response_result["content"] = (response_result["content"] or "") + fragment
            return response_result
        except (openai.APIConnectionError, openai.APITimeoutError) as err:
            log_warn(f"Exception: {err.__class__.__name__}: {err}")
            if try_times >= 2:
                # Last attempt failed as well: give up and report the error.
                return {"content": None, "function_name": None, "parameters": "", "error": err}
            continue
        except Exception as err:
            # Any other failure (including other openai.APIError subclasses)
            # is not retried.
            if error_out:
                print("Exception:", err, file=sys.stderr, flush=True)
            return {"content": None, "function_name": None, "parameters": "", "error": err}
|
2023-12-08 10:37:32 +08:00
|
|
|
|
|
2023-12-08 18:28:36 +08:00
|
|
|
|
|
|
|
|
|
def chat_completion_no_stream_return_json(messages, llm_config, error_out: bool = True):
    """
    Fetch a chat completion and parse its text content as JSON.

    The completion is requested up to three times; a new attempt is made only
    when the reply text is not valid JSON. A Markdown code fence around the
    reply is removed before parsing.

    Args:
        messages: Chat messages, passed through to chat_completion_stream.
        llm_config: Chat completion configuration, passed through to
            chat_completion_stream.
        error_out: If True, errors are printed to stderr. The flag is also
            forwarded to chat_completion_stream so that False silences its
            error output too.

    Returns:
        The value decoded from the JSON reply, or None if the reply has no
        content, an unexpected error occurs, or no attempt yields valid JSON.
    """
    for _1 in range(3):
        response = chat_completion_stream(messages, llm_config, error_out=error_out)
        if not response["content"]:
            # No text came back (request failed or the reply was empty), so
            # there is nothing to parse and no retry is attempted.
            return None

        try:
            # json will format as ```json ... ``` in 1106 model
            response_content = _try_remove_markdown_block_flag(response["content"])
            response_obj = json.loads(response_content)
            return response_obj
        except json.JSONDecodeError:
            # Invalid JSON: log the raw reply and ask the model again.
            log_warn(f"JSONDecodeError: {response['content']}")
            continue
        except Exception as err:
            if error_out:
                print("Exception: ", err, file=sys.stderr, flush=True)
            return None

    # All three replies failed to parse; report the last one.
    if error_out:
        print("Not valid json response:", response["content"], file=sys.stderr, flush=True)
    return None
|