317 lines
11 KiB
Python
Raw Normal View History

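"""
Utilities for calling the OpenAI chat completions API: client setup,
streamed-chunk parsing helpers, and retrying request pipelines.
"""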
2024-02-08 13:19:24 +08:00
# flake8: noqa: E402
# Import necessary libraries
2024-02-08 13:19:24 +08:00
import json
import os
import re
2024-11-12 11:58:54 +08:00
from pathlib import Path
from typing import Any, Dict, List
2024-02-08 13:19:24 +08:00
import httpx
2024-02-08 13:19:24 +08:00
import openai
2024-11-12 11:58:54 +08:00
import oyaml as yaml
2024-02-08 13:19:24 +08:00
from devchat.ide import IDEService
2024-11-12 11:58:54 +08:00
from devchat.workflow.path import CHAT_CONFIG_FILENAME, CHAT_DIR
2024-02-08 13:19:24 +08:00
from .pipeline import (
RetryException, # Import RetryException class
exception_handle, # Function to handle exceptions
parallel, # Function to run tasks in parallel
pipeline, # Function to create a pipeline of tasks
retry, # Function to retry a task
2024-02-08 13:19:24 +08:00
)
def _try_remove_markdown_block_flag(content):
"""
If the content is a markdown block, this function removes the header ```xxx and footer ```
2024-02-08 13:19:24 +08:00
"""
# Define a regex pattern to match the header and footer of a markdown block
2024-02-08 13:19:24 +08:00
pattern = r"^\s*```\s*(\w+)\s*\n(.*?)\n\s*```\s*$"
# Use the re module to match the pattern
2024-02-08 13:19:24 +08:00
match = re.search(pattern, content, re.DOTALL | re.MULTILINE)
if match:
# If a match is found, extract the content of the markdown block and return it
2024-02-08 13:19:24 +08:00
_ = match.group(1) # language
markdown_content = match.group(2)
return markdown_content.strip()
# If no match is found, return the original content
2024-02-08 13:19:24 +08:00
return content
2024-11-12 11:58:54 +08:00
# Module-level cache of the parsed chat config; populated lazily by _load_chat_config().
_chat_config: Dict[str, Any] = {}


def _load_chat_config() -> None:
    """
    Load the chat config file (CHAT_DIR / CHAT_CONFIG_FILENAME) into the
    module-level ``_chat_config`` cache.

    An empty file, or a document whose top level is not a mapping, is cached
    as an empty dict so that later ``_chat_config.get(...)`` calls stay safe.

    Raises:
        OSError: if the config file cannot be opened.
    """
    global _chat_config
    chat_config_path = Path(CHAT_DIR) / CHAT_CONFIG_FILENAME
    with open(chat_config_path, "r", encoding="utf-8") as file:
        loaded = yaml.safe_load(file)
    # safe_load returns None for an empty document; never cache a non-dict.
    _chat_config = loaded if isinstance(loaded, dict) else {}
def get_maxtokens_by_model(model: str) -> int:
    """
    Return the configured ``max_tokens`` for ``model``.

    The chat config is loaded on first use. Falls back to 1024 when the
    config has no usable ``models`` mapping, the model is not listed, its
    entry is empty/null, or its ``max_tokens`` is missing or null.
    """
    # Load the config if it has not been loaded yet (an empty cache also reloads).
    if not _chat_config:
        _load_chat_config()

    default_max_tokens = 1024
    config = _chat_config if isinstance(_chat_config, dict) else {}

    # YAML keys present with no value load as None, so check types explicitly.
    models = config.get("models")
    if not isinstance(models, dict):
        return default_max_tokens
    model_config = models.get(model)
    if not isinstance(model_config, dict):
        return default_max_tokens

    max_tokens = model_config.get("max_tokens")
    return default_max_tokens if max_tokens is None else max_tokens
def chat_completion_stream_commit(
    messages: List[Dict],  # [{"role": "user", "content": "hello"}]
    llm_config: Dict,  # {"model": "...", ...}
):
    """
    Start a streaming chat completion request and return the stream.

    Args:
        messages: chat messages, e.g. [{"role": "user", "content": "hello"}].
        llm_config: extra arguments for ``client.chat.completions.create``;
            must contain "model". It is not modified: the forced settings
            (stream=True, timeout=60, max_tokens from the chat config) are
            applied to a copy.

    Returns:
        The streaming response object from ``client.chat.completions.create``.

    Environment:
        OPENAI_API_KEY / OPENAI_API_BASE configure the client;
        DEVCHAT_PROXY, if set, is used as the proxy URL.
    """
    proxy_url = os.environ.get("DEVCHAT_PROXY", "")
    # httpx's ``proxy=`` expects a single proxy URL; a scheme->URL mapping there
    # fails on httpx >= 0.28. Both schemes used the same URL, so pass it directly.
    proxy_setting = {"proxy": proxy_url} if proxy_url else {}
    client = openai.OpenAI(
        api_key=os.environ.get("OPENAI_API_KEY"),
        base_url=os.environ.get("OPENAI_API_BASE"),
        # trust_env=False: ignore environment-based httpx config (e.g. HTTP(S)_PROXY);
        # only DEVCHAT_PROXY selects a proxy.
        # NOTE(review): this httpx.Client is not explicitly closed here.
        http_client=httpx.Client(**proxy_setting, trust_env=False),
    )

    # Work on a copy so the caller's dict is not mutated by the forced settings.
    request_config = dict(llm_config)
    request_config["stream"] = True
    request_config["timeout"] = 60
    request_config["max_tokens"] = get_maxtokens_by_model(request_config["model"])
    return client.chat.completions.create(messages=messages, **request_config)
def chat_completion_stream_raw(**kwargs):
    """
    Start a streaming chat completion with caller-supplied keyword arguments.

    ``kwargs`` are passed to ``client.chat.completions.create`` and must
    include "model". ``stream``, ``timeout`` and ``max_tokens`` are always
    overridden (True, 60, and the model's configured max_tokens).

    Returns:
        The streaming response object from ``client.chat.completions.create``.
    """
    proxy_url = os.environ.get("DEVCHAT_PROXY", "")
    # httpx's ``proxy=`` expects a single proxy URL; a scheme->URL mapping there
    # fails on httpx >= 0.28. Both schemes used the same URL, so pass it directly.
    proxy_setting = {"proxy": proxy_url} if proxy_url else {}
    client = openai.OpenAI(
        api_key=os.environ.get("OPENAI_API_KEY"),
        base_url=os.environ.get("OPENAI_API_BASE"),
        # trust_env=False: ignore environment-based httpx config (e.g. HTTP(S)_PROXY);
        # only DEVCHAT_PROXY selects a proxy.
        # NOTE(review): this httpx.Client is not explicitly closed here.
        http_client=httpx.Client(**proxy_setting, trust_env=False),
    )

    # kwargs is a fresh dict built for this call, so updating it in place is safe.
    kwargs["stream"] = True
    kwargs["timeout"] = 60
    kwargs["max_tokens"] = get_maxtokens_by_model(kwargs["model"])
    return client.chat.completions.create(**kwargs)
def stream_out_chunk(chunks):
    """
    Echo streamed text to stdout as it arrives.

    Every chunk is yielded onward unchanged, including chunks that carry no
    text; only non-empty ``delta.content`` of the first choice is printed.
    """
    for item in chunks:
        payload = item.dict()
        text = None
        if len(payload["choices"]) > 0:
            text = payload["choices"][0]["delta"].get("content", None)
        if text:
            print(text, end="", flush=True)
        yield item
def retry_timeout(chunks):
    """
    Relay chunks unchanged, converting network failures into RetryException.

    openai.APIConnectionError / openai.APITimeoutError raised while iterating
    are logged at "info" level and re-raised as RetryException (chained to the
    original error), presumably so the retry() wrappers re-run the request.
    """
    try:
        for item in chunks:
            yield item
    except (openai.APIConnectionError, openai.APITimeoutError) as exc:
        message = f"in retry_timeout: err: {exc}"
        IDEService().ide_logging("info", message)
        raise RetryException(exc) from exc
def chunk_list(chunks):
    """
    Materialize ``chunks`` into a list.

    Consuming the stream once lets the same chunks be read by more than one
    consumer (e.g. both chunks_content and chunks_call).
    """
    return list(chunks)
def chunks_content(chunks):
    """
    Concatenate the text content of streamed chat completion chunks.

    Only non-empty ``delta.content`` of each chunk's first choice is used.

    Returns:
        The joined text, or None if no chunk carried any text.
    """
    parts = []
    for chunk in chunks:
        chunk_dict = chunk.dict()
        if len(chunk_dict["choices"]) > 0:
            text = chunk_dict["choices"][0]["delta"].get("content", None)
            if text:
                parts.append(text)
    # Keep the original contract: None (not "") when nothing was streamed.
    return "".join(parts) if parts else None
def chunks_call(chunks):
    """
    Assemble streamed tool-call deltas into complete tool calls.

    Returns:
        A list of {"name": str | None, "arguments": str} dicts, one per tool
        call, with argument fragments concatenated in arrival order.
    """
    tool_calls = []
    for chunk in chunks:
        chunk_dict = chunk.dict()
        if len(chunk_dict["choices"]) == 0:
            continue
        delta = chunk_dict["choices"][0]["delta"]
        # A single chunk may carry deltas for several tool calls.
        for call_delta in delta.get("tool_calls") or []:
            index = call_delta.get("index", None)
            if index is not None and index < len(tool_calls):
                # Continuation of a call we have already started.
                target = tool_calls[index]
            else:
                # A new index starts a new call; a delta without an index
                # continues the latest call (or starts one if none exists yet).
                if index is not None or not tool_calls:
                    tool_calls.append({"name": None, "arguments": ""})
                target = tool_calls[-1]
            # "function" may be present but null in a delta.
            function = call_delta.get("function") or {}
            if function.get("name", None):
                target["name"] = function["name"]
            if function.get("arguments", None):
                target["arguments"] += function["arguments"]
    return tool_calls
def content_to_json(content):
    """
    Parse model output as JSON, tolerating a surrounding markdown code fence.

    Args:
        content: text returned by the model.

    Returns:
        The decoded JSON value.

    Raises:
        RetryException: if the text is not valid JSON; the decode error is
            logged first so the retrying pipeline can try the request again.
        Exception: any other error is logged and re-raised unchanged.
    """
    try:
        content_no_block = _try_remove_markdown_block_flag(content)
        # strict=False allows control characters (e.g. raw newlines) inside strings.
        return json.loads(content_no_block, strict=False)
    except json.JSONDecodeError as err:
        IDEService().ide_logging("debug", f"Receive content: {content}")
        IDEService().ide_logging("debug", f"in content_to_json: json decode error: {err}")
        raise RetryException(err) from err
    except Exception as err:
        IDEService().ide_logging("debug", f"in content_to_json: other error: {err}")
        # Bare raise re-raises the active exception with its original traceback.
        raise
def to_dict_content_and_call(content, tool_calls=None):
    """
    Package streamed text and tool calls into a single result dict.

    The first tool call (if any) is surfaced as "function_name"/"parameters";
    with no tool calls these are None and "" respectively. "tool_calls" holds
    the given list itself (a new empty list when None is passed).
    """
    calls = tool_calls if tool_calls is not None else []
    if calls:
        head = calls[0]
        function_name = head["name"]
        parameters = head["arguments"]
    else:
        function_name = None
        parameters = ""
    return {
        "content": content,
        "function_name": function_name,
        "parameters": parameters,
        "tool_calls": calls,
    }
# Text completion pipeline:
#   1. chat_completion_stream_commit - send a streaming chat completion request
#   2. retry_timeout                 - turn connection/timeout errors into RetryException
#   3. chunks_content                - join the streamed text (None if there was none)
# retry(..., times=3) re-runs the whole pipeline on failure, up to 3 times.
chat_completion_content = retry(
    pipeline(
        chat_completion_stream_commit,
        retry_timeout,
        chunks_content,
    ),
    times=3,
)
# Streaming text completion pipeline:
#   1. chat_completion_stream_commit - send a streaming chat completion request
#   2. retry_timeout                 - turn connection/timeout errors into RetryException
#   3. stream_out_chunk              - print each text delta to stdout as it arrives
#   4. chunks_content                - join the streamed text (None if there was none)
# retry(..., times=3) re-runs the whole pipeline on failure, up to 3 times.
chat_completion_stream_content = retry(
    pipeline(
        chat_completion_stream_commit,
        retry_timeout,
        stream_out_chunk,
        chunks_content,
    ),
    times=3,
)
# Tool-call pipeline:
#   1. chat_completion_stream_commit - send a streaming chat completion request
#   2. retry_timeout                 - turn connection/timeout errors into RetryException
#   3. chunks_call                   - assemble streamed tool-call deltas into a list of
#                                      {"name", "arguments"} dicts
# retry(..., times=3) re-runs the whole pipeline on failure, up to 3 times.
chat_completion_call = retry(
    pipeline(
        chat_completion_stream_commit,
        retry_timeout,
        chunks_call,
    ),
    times=3,
)
# JSON pipeline (the request itself is streamed; the caller receives one parsed value):
#   1. chat_completion_stream_commit - send a streaming chat completion request
#   2. retry_timeout                 - turn connection/timeout errors into RetryException
#   3. chunks_content                - join the streamed text
#   4. content_to_json               - strip a markdown fence and parse the JSON; invalid
#                                      JSON is logged and raised as RetryException, other
#                                      errors are logged and re-raised
# retry(..., times=3) re-runs the whole pipeline on failure, up to 3 times.
# NOTE(review): exception_handle(..., None) presumably makes the wrapper return None
# when the retried pipeline still fails -- confirm against .pipeline.exception_handle.
chat_completion_no_stream_return_json_with_retry = exception_handle(
    retry(
        pipeline(
            chat_completion_stream_commit,
            retry_timeout,
            chunks_content,
            content_to_json,
        ),
        times=3,
    ),
    None,
)
def chat_completion_no_stream_return_json(messages: List[Dict], llm_config: Dict):
    """
    Request a JSON-object response and return it parsed.

    Args:
        messages: chat messages to send.
        llm_config: request settings (must include "model"). It is not
            modified; the JSON ``response_format`` is added to a copy.

    Returns:
        Whatever ``chat_completion_no_stream_return_json_with_retry`` returns:
        the parsed JSON value on success (presumably None on final failure,
        given its ``exception_handle(..., None)`` wrapper).
    """
    # Copy rather than mutate: writing response_format into the caller's dict
    # would silently force JSON mode on any later request reusing that config.
    json_config = {**llm_config, "response_format": {"type": "json_object"}}
    return chat_completion_no_stream_return_json_with_retry(
        messages=messages, llm_config=json_config
    )
# Text-as-dict pipeline:
#   1. chat_completion_stream_commit - send a streaming chat completion request
#   2. retry_timeout                 - turn connection/timeout errors into RetryException
#   3. chunks_content                - join the streamed text
#   4. to_dict_content_and_call      - wrap it as {"content", "function_name",
#                                      "parameters", "tool_calls"}
# Tool calls are not extracted here, so presumably function_name is None,
# parameters is "" and tool_calls is [] in the result.
# retry(..., times=3) re-runs the whole pipeline on failure, up to 3 times.
# NOTE(review): exception_handle(..., None) presumably returns None when the
# retried pipeline still fails -- confirm against .pipeline.exception_handle.
chat_completion_stream = exception_handle(
    retry(
        pipeline(
            chat_completion_stream_commit,
            retry_timeout,
            chunks_content,
            to_dict_content_and_call,
        ),
        times=3,
    ),
    None,
)
# Text-and-tool-calls pipeline:
#   1. chat_completion_stream_commit     - send a streaming chat completion request
#   2. retry_timeout                     - turn connection/timeout errors into RetryException
#   3. chunk_list                        - buffer the stream so it can be read twice
#   4. parallel(chunks_content,          - extract the joined text and the assembled
#               chunks_call)               tool calls from the same buffered chunks
#   5. to_dict_content_and_call          - combine both into one result dict
# retry(..., times=3) re-runs the whole pipeline on failure, up to 3 times.
# NOTE(review): exception_handle(..., None) presumably returns None when the
# retried pipeline still fails -- confirm against .pipeline.exception_handle.
chat_call_completion_stream = exception_handle(
    retry(
        pipeline(
            chat_completion_stream_commit,
            retry_timeout,
            chunk_list,
            parallel(chunks_content, chunks_call),
            to_dict_content_and_call,
        ),
        times=3,
    ),
    None,
)