151 lines
4.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# flake8: noqa: E402
import json
import os
import re
from typing import Dict, List
import openai
from .pipeline import RetryException, exception_err, exception_handle, parallel, pipeline, retry
def _try_remove_markdown_block_flag(content):
    """
    Unwrap a Markdown fenced code block found in ``content``.

    A fenced block is an opening line of three backticks, optionally followed
    by a language tag (e.g. "json"), then the body, then a closing line of
    three backticks.

    Args:
        content: Text to inspect (a ``str``).

    Returns:
        The body of the first fenced block with surrounding whitespace
        stripped, or ``content`` unchanged when no fenced block is found.
    """
    # The language tag is optional (\w*): a bare ``` fence with no tag must be
    # unwrapped too, otherwise the fence characters reach the JSON parser.
    # MULTILINE makes ^/$ anchor on line boundaries; DOTALL lets the body
    # span several lines.
    pattern = r"^\s*```\s*(\w*)\s*\n(.*?)\n\s*```\s*$"
    match = re.search(pattern, content, re.DOTALL | re.MULTILINE)
    if match is None:
        # Not a fenced block: hand the original text back untouched.
        return content
    # Group 1 is the language tag (unused); group 2 is the block body.
    return match.group(2).strip()
def chat_completion_stream_commit(
    messages: List[Dict],  # [{"role": "user", "content": "hello"}]
    llm_config: Dict,  # {"model": "...", ...}
):
    """
    Send a streaming chat completion request.

    The client is configured from the ``OPENAI_API_KEY`` and
    ``OPENAI_API_BASE`` environment variables.

    Args:
        messages: Chat messages to send.
        llm_config: Extra keyword arguments for the completion call (model
            name and so on). ``stream`` and ``timeout`` are always forced to
            ``True`` and ``60``. The dict itself is NOT modified.

    Returns:
        Whatever ``client.chat.completions.create`` returns for a streaming
        request (an iterable of chunks).
    """
    client = openai.OpenAI(
        api_key=os.environ.get("OPENAI_API_KEY", None),
        base_url=os.environ.get("OPENAI_API_BASE", None),
    )
    # Build a copy instead of writing into llm_config: the caller usually
    # reuses the same config dict and should not see it change.
    request_config = {**llm_config, "stream": True, "timeout": 60}
    return client.chat.completions.create(messages=messages, **request_config)
def chat_completion_stream_raw(**kwargs):
    """
    Send a streaming chat completion request built from raw keyword arguments.

    All keyword arguments are forwarded to ``client.chat.completions.create``
    with ``stream`` forced to ``True`` and ``timeout`` forced to ``60``. The
    client is configured from the ``OPENAI_API_KEY`` and ``OPENAI_API_BASE``
    environment variables.
    """
    api_key = os.environ.get("OPENAI_API_KEY", None)
    base_url = os.environ.get("OPENAI_API_BASE", None)
    client = openai.OpenAI(api_key=api_key, base_url=base_url)
    # Streaming and the timeout are not negotiable for this entry point.
    kwargs.update(stream=True, timeout=60)
    return client.chat.completions.create(**kwargs)
def stream_out_chunk(chunks):
    """
    Echo streamed text to stdout while passing every chunk through unchanged.

    Args:
        chunks: Iterable of objects exposing ``.dict()`` that returns a
            mapping with a ``"choices"`` list.

    Yields:
        Each input chunk, in order, after any text it carries was printed.
    """
    for chunk in chunks:
        choices = chunk.dict()["choices"]
        # A chunk may carry no choices at all (e.g. a usage-only chunk);
        # there is nothing to print, but it still has to be forwarded.
        text = choices[0]["delta"].get("content") if choices else None
        if text:
            # flush so the text shows up immediately, not when the buffer fills
            print(text, end="", flush=True)
        yield chunk
def retry_timeout(chunks):
    """
    Pass chunks through, turning connection/timeout failures into retries.

    Args:
        chunks: Iterable of streamed chunks.

    Yields:
        Each chunk from ``chunks``, unchanged.

    Raises:
        RetryException: Wrapping an ``openai.APIConnectionError`` or
            ``openai.APITimeoutError`` raised while iterating. The original
            error is attached as the explicit cause.
    """
    try:
        for chunk in chunks:
            yield chunk
    except (openai.APIConnectionError, openai.APITimeoutError) as err:
        # "from err" records the network error as the direct cause so the
        # traceback reads as a conversion, not as a failure in the handler.
        raise RetryException(err) from err
def chunk_list(chunks):
    """
    Drain an iterable of chunks into a list.

    Args:
        chunks: Any iterable.

    Returns:
        A new list holding every item of ``chunks`` in iteration order.
    """
    # list() is the direct way to materialize an iterable; no comprehension needed.
    return list(chunks)
def chunks_content(chunks):
    """
    Concatenate the text content carried by a stream of chunks.

    Args:
        chunks: Iterable of objects exposing ``.dict()`` that returns a
            mapping with a ``"choices"`` list; the first choice's ``"delta"``
            may hold a ``"content"`` string.

    Returns:
        The joined text, or ``None`` when no chunk carried non-empty content.
    """
    pieces = []
    for chunk in chunks:
        choices = chunk.dict()["choices"]
        if not choices:
            # Chunks without choices (e.g. usage-only) carry no text.
            continue
        text = choices[0]["delta"].get("content")
        if text:
            pieces.append(text)
    # None (not "") signals "the model produced no text at all".
    return "".join(pieces) if pieces else None
def chunks_call(chunks):
    """
    Reassemble a streamed tool call from a stream of chunks.

    Only the first entry of each delta's ``"tool_calls"`` list is read. The
    function name is the last non-empty name seen; the argument fragments are
    concatenated in arrival order.

    Args:
        chunks: Iterable of objects exposing ``.dict()`` that returns a
            mapping with a ``"choices"`` list.

    Returns:
        ``{"function_name": <str or None>, "parameters": <str>}``;
        ``parameters`` is the raw (unparsed) argument text, ``""`` if none.
    """
    function_name = None
    argument_parts = []
    for chunk in chunks:
        choices = chunk.dict()["choices"]
        if not choices:
            # Chunks without choices (e.g. usage-only) carry no tool call.
            continue
        tool_calls = choices[0]["delta"].get("tool_calls")
        if not tool_calls:
            continue
        # "function" may be absent or None on a fragment; treat it as empty.
        function = tool_calls[0].get("function") or {}
        if function.get("name"):
            function_name = function["name"]
        if function.get("arguments"):
            argument_parts.append(function["arguments"])
    return {"function_name": function_name, "parameters": "".join(argument_parts)}
def content_to_json(content):
    """
    Parse model output as JSON, unwrapping a Markdown code fence first.

    Args:
        content: Text produced by the model.

    Returns:
        The object decoded by ``json.loads``.

    Raises:
        RetryException: Wrapping the ``json.JSONDecodeError`` when the text
            is not valid JSON. The decode error is attached as the cause.
        Exception: Any other error propagates unchanged.
    """
    # json will format as ```json ... ``` in 1106 model
    response_content = _try_remove_markdown_block_flag(content)
    # Only the decode step is guarded: a decode failure is the one error
    # converted into a retry request.
    try:
        return json.loads(response_content)
    except json.JSONDecodeError as err:
        raise RetryException(err) from err
def to_dict_content_and_call(content, function_call):
    """
    Merge text content and tool-call fields into a single dict.

    The result starts with ``"content"`` and is then updated with every key
    of ``function_call``, so a ``"content"`` key in ``function_call`` wins.
    """
    merged = {"content": content}
    merged.update(function_call)
    return merged
# Public entry points, each assembled from the stage functions in this module
# and the helpers imported from .pipeline.
# NOTE(review): the helper semantics are not visible in this file. Presumably
# pipeline() feeds each stage's result into the next one and
# retry(..., times=3) re-runs the whole chain when a stage raises
# RetryException -- confirm against .pipeline.

# Request a completion and reduce the chunk stream to its text (chunks_content).
chat_completion_content = retry(
pipeline(chat_completion_stream_commit, retry_timeout, chunks_content), times=3
)
# Same as above, but stream_out_chunk also echoes the text to stdout as it arrives.
chat_completion_stream_content = retry(
pipeline(chat_completion_stream_commit, retry_timeout, stream_out_chunk, chunks_content),
times=3,
)
# Request a completion and reduce the chunk stream to its tool call (chunks_call).
chat_completion_call = retry(
pipeline(chat_completion_stream_commit, retry_timeout, chunks_call), times=3
)
# Despite "no_stream" in the name, this still goes through
# chat_completion_stream_commit (which forces stream=True); the text is
# collected and then decoded by content_to_json.
chat_completion_no_stream_return_json = retry(
pipeline(chat_completion_stream_commit, retry_timeout, chunks_content, content_to_json), times=3
)
# Collect both the text and the tool call from one request.
# NOTE(review): presumably parallel() hands the same chunks to both collectors
# and passes their two results to to_dict_content_and_call, and
# exception_handle() calls the lambda with the raised error -- confirm against
# .pipeline. The fallback dict has the keys of a normal result plus "error".
chat_completion_stream = exception_handle(
retry(
pipeline(
chat_completion_stream_commit,
retry_timeout,
parallel(chunks_content, chunks_call),
to_dict_content_and_call,
),
times=3,
),
lambda err: {"content": None, "function_name": None, "parameters": "", "error": err},
)