"""
openai api utils
"""

# flake8: noqa: E402

# Import necessary libraries
import json
import os
import re
from pathlib import Path
from typing import Any, Dict, List

import httpx
import openai
import oyaml as yaml

from devchat.ide import IDEService
from devchat.workflow.path import CHAT_CONFIG_FILENAME, CHAT_DIR

from .pipeline import (
    RetryException,  # Import RetryException class
    exception_handle,  # Function to handle exceptions
    parallel,  # Function to run tasks in parallel
    pipeline,  # Function to create a pipeline of tasks
    retry,  # Function to retry a task
)


def _try_remove_markdown_block_flag(content):
    """
    If the content is a markdown code block, remove the ```xxx header and the
    ``` footer and return only the inner content; otherwise return it unchanged.
    """
    # Define a regex pattern to match the header and footer of a markdown block
    pattern = r"^\s*```\s*(\w+)\s*\n(.*?)\n\s*```\s*$"

    # Use the re module to match the pattern
    match = re.search(pattern, content, re.DOTALL | re.MULTILINE)

    if match:
        # If a match is found, extract the content of the markdown block and return it
        _ = match.group(1)  # language
        markdown_content = match.group(2)
        return markdown_content.strip()
    # If no match is found, return the original content
    return content
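
# Example (a hedged sketch; the JSON payload is illustrative):
#
#     _try_remove_markdown_block_flag('```json\n{"a": 1}\n```')  # -> '{"a": 1}'
#     _try_remove_markdown_block_flag("plain text")              # -> "plain text"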


# Module-level variable used to cache the chat config
_chat_config: Dict[str, Any] = {}


def _load_chat_config() -> None:
    """Load the chat config into the module-level cache."""
    global _chat_config
    chat_config_path = Path(CHAT_DIR) / CHAT_CONFIG_FILENAME
    with open(chat_config_path, "r", encoding="utf-8") as file:
        _chat_config = yaml.safe_load(file)


def get_maxtokens_by_model(model: str) -> int:
    """Return the configured max_tokens for a model, defaulting to 1024."""
    # Load the config if it has not been loaded yet
    if not _chat_config:
        _load_chat_config()

    # The default value is 1024
    default_max_tokens = 1024

    # Check whether the model is present in the config
    if model in _chat_config.get("models", {}):
        # If the model exists, try to get max_tokens; fall back to the default
        return _chat_config["models"][model].get("max_tokens", default_max_tokens)
    # If the model is not in the config, return the default
    return default_max_tokens
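
# Expected config shape, inferred from the lookup above (a hedged sketch; the
# model name and value are illustrative):
#
#     # <CHAT_DIR>/<CHAT_CONFIG_FILENAME>
#     models:
#       gpt-3.5-turbo:
#         max_tokens: 4096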


def chat_completion_stream_commit(
    messages: List[Dict],  # [{"role": "user", "content": "hello"}]
    llm_config: Dict,  # {"model": "...", ...}
):
    """
    Commit a streaming chat completion request and return the response stream.
    """
    proxy_url = os.environ.get("DEVCHAT_PROXY", "")
    proxy_setting = {"proxy": {"https://": proxy_url, "http://": proxy_url}} if proxy_url else {}

    # Initialize the OpenAI client with API key, base URL and http client
    client = openai.OpenAI(
        api_key=os.environ.get("OPENAI_API_KEY", None),
        base_url=os.environ.get("OPENAI_API_BASE", None),
        http_client=httpx.Client(**proxy_setting, trust_env=False),
    )

    # Update the llm_config dictionary
    llm_config["stream"] = True
    llm_config["timeout"] = 60
    llm_config["max_tokens"] = get_maxtokens_by_model(llm_config["model"])
    # Return the chat completion stream
    return client.chat.completions.create(messages=messages, **llm_config)
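
# Usage sketch (hedged; the key, base URL, proxy and model name are illustrative):
#
#     os.environ["OPENAI_API_KEY"] = "sk-..."
#     os.environ["OPENAI_API_BASE"] = "https://api.openai.com/v1"  # optional override
#     os.environ["DEVCHAT_PROXY"] = "http://127.0.0.1:7890"        # optional proxy
#     stream = chat_completion_stream_commit(
#         messages=[{"role": "user", "content": "hello"}],
#         llm_config={"model": "gpt-3.5-turbo"},
#     )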


def chat_completion_stream_raw(**kwargs):
    """
    Commit a streaming chat completion request built from raw keyword arguments.
    """
    proxy_url = os.environ.get("DEVCHAT_PROXY", "")
    proxy_setting = {"proxy": {"https://": proxy_url, "http://": proxy_url}} if proxy_url else {}

    # Initialize the OpenAI client with API key, base URL and http client
    client = openai.OpenAI(
        api_key=os.environ.get("OPENAI_API_KEY", None),
        base_url=os.environ.get("OPENAI_API_BASE", None),
        http_client=httpx.Client(**proxy_setting, trust_env=False),
    )

    # Update the kwargs dictionary
    kwargs["stream"] = True
    kwargs["timeout"] = 60
    kwargs["max_tokens"] = get_maxtokens_by_model(kwargs["model"])
    # Return the chat completion stream
    return client.chat.completions.create(**kwargs)


def stream_out_chunk(chunks):
    """
    Print each chunk's streamed content as it arrives, then re-yield the chunk.
    """
    for chunk in chunks:
        chunk_dict = chunk.dict()
        if len(chunk_dict["choices"]) > 0:
            delta = chunk_dict["choices"][0]["delta"]
            if delta.get("content", None):
                print(delta["content"], end="", flush=True)
        yield chunk


def retry_timeout(chunks):
    """
    Re-yield chunks, converting connection and timeout errors into
    RetryException so the surrounding retry pipeline can retry the request.
    """
    try:
        for chunk in chunks:
            yield chunk
    except (openai.APIConnectionError, openai.APITimeoutError) as err:
        IDEService().ide_logging("info", f"in retry_timeout: err: {err}")
        raise RetryException(err) from err


def chunk_list(chunks):
    """
    Collect the streamed chunks into a list.
    """
    return [chunk for chunk in chunks]


def chunks_content(chunks):
    """
    Concatenate the content deltas from the streamed chunks into a single
    string; return None if no content was received.
    """
    content = None
    for chunk in chunks:
        chunk_dict = chunk.dict()
        if len(chunk_dict["choices"]) > 0:
            delta = chunk_dict["choices"][0]["delta"]
            if delta.get("content", None):
                if content is None:
                    content = ""
                content += delta["content"]
    return content


def chunks_call(chunks):
    """
    Extract tool calls from the streamed chunks, merging the partial
    name/arguments deltas into complete tool-call records.
    """
    tool_calls = []

    for chunk in chunks:
        chunk = chunk.dict()
        if len(chunk["choices"]) > 0:
            delta = chunk["choices"][0]["delta"]
            if "tool_calls" in delta and delta["tool_calls"]:
                tool_call = delta["tool_calls"][0]["function"]
                if delta["tool_calls"][0].get("index", None) is not None:
                    index = delta["tool_calls"][0]["index"]
                    if index >= len(tool_calls):
                        tool_calls.append({"name": None, "arguments": ""})
                if tool_call.get("name", None):
                    tool_calls[-1]["name"] = tool_call["name"]
                if tool_call.get("arguments", None):
                    tool_calls[-1]["arguments"] += tool_call["arguments"]
    return tool_calls
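
# Example of the merged result (hedged; the tool name and arguments are illustrative):
#
#     [{"name": "get_weather", "arguments": '{"city": "Beijing"}'}]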


def content_to_json(content):
    """
    Convert model output content to a JSON object, stripping any markdown
    code-block fences first. A JSON decode error raises RetryException so the
    surrounding retry pipeline can retry the request.
    """
    try:
        content_no_block = _try_remove_markdown_block_flag(content)
        response_obj = json.loads(content_no_block, strict=False)
        return response_obj
    except json.JSONDecodeError as err:
        IDEService().ide_logging("debug", f"Receive content: {content}")
        IDEService().ide_logging("debug", f"in content_to_json: json decode error: {err}")
        raise RetryException(err) from err
    except Exception as err:
        IDEService().ide_logging("debug", f"in content_to_json: other error: {err}")
        raise err


def to_dict_content_and_call(content, tool_calls=None):
    """
    Package content and tool calls into a result dictionary.
    """
    if tool_calls is None:
        tool_calls = []
    return {
        "content": content,
        "function_name": tool_calls[0]["name"] if tool_calls else None,
        "parameters": tool_calls[0]["arguments"] if tool_calls else "",
        "tool_calls": tool_calls,
    }


# Define a pipeline function for chat completion content.
# This pipeline first commits a chat completion stream, handles any timeout errors,
# and then extracts the content from the chunks.
# If any step in the pipeline fails, it will retry the entire pipeline up to 3 times.
chat_completion_content = retry(
    pipeline(chat_completion_stream_commit, retry_timeout, chunks_content), times=3
)
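
# Usage sketch (hedged; the model name is illustrative):
#
#     text = chat_completion_content(
#         messages=[{"role": "user", "content": "hello"}],
#         llm_config={"model": "gpt-3.5-turbo"},
#     )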


# Define a pipeline function for chat completion stream content.
# This pipeline first commits a chat completion stream, handles any timeout errors,
# streams out each chunk, and then extracts the content from the chunks.
# If any step in the pipeline fails, it will retry the entire pipeline up to 3 times.
chat_completion_stream_content = retry(
    pipeline(chat_completion_stream_commit, retry_timeout, stream_out_chunk, chunks_content),
    times=3,
)


# Define a pipeline function for chat completion call.
# This pipeline first commits a chat completion stream, handles any timeout errors,
# and then extracts the tool calls from the chunks.
# If any step in the pipeline fails, it will retry the entire pipeline up to 3 times.
chat_completion_call = retry(
    pipeline(chat_completion_stream_commit, retry_timeout, chunks_call), times=3
)


# Define a pipeline function for chat completion without streaming that returns a JSON object.
# This pipeline first commits a chat completion stream, handles any timeout errors, extracts
# the content from the chunks, and then converts the content to JSON.
# If any step in the pipeline fails, it will retry the entire pipeline up to 3 times.
# If a JSONDecodeError is encountered during the content-to-JSON conversion, it will log the
# error and retry the pipeline.
# If any other exception is encountered, it will log the error and raise it.
chat_completion_no_stream_return_json_with_retry = exception_handle(
    retry(
        pipeline(chat_completion_stream_commit, retry_timeout, chunks_content, content_to_json),
        times=3,
    ),
    None,
)


def chat_completion_no_stream_return_json(messages: List[Dict], llm_config: Dict):
    """
    Run a chat completion without streaming output and return the parsed JSON object.
    """
    llm_config["response_format"] = {"type": "json_object"}
    return chat_completion_no_stream_return_json_with_retry(
        messages=messages, llm_config=llm_config
    )
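
# Usage sketch (hedged; the model name and prompt are illustrative):
#
#     obj = chat_completion_no_stream_return_json(
#         messages=[{"role": "user", "content": 'Reply with {"ok": true}'}],
#         llm_config={"model": "gpt-3.5-turbo"},
#     )
#     # obj is the parsed JSON content (typically a dict)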


# Define a pipeline function for chat completion stream.
# This pipeline first commits a chat completion stream, handles any timeout errors,
# extracts the content from the chunks, and then converts the content and tool calls
# to a dictionary.
# If any step in the pipeline fails, it will retry the entire pipeline up to 3 times.
# If an exception is encountered, it will return a dictionary with None values and the error.
chat_completion_stream = exception_handle(
    retry(
        pipeline(
            chat_completion_stream_commit,
            retry_timeout,
            chunks_content,
            to_dict_content_and_call,
        ),
        times=3,
    ),
    None,
)


# Define a pipeline function for chat call completion stream.
# This pipeline first commits a chat completion stream, handles any timeout errors,
# converts the chunks to a list, extracts the content and tool calls from the chunks
# in parallel, and then converts the content and tool calls to a dictionary.
# If any step in the pipeline fails, it will retry the entire pipeline up to 3 times.
# If an exception is encountered, it will return a dictionary with None values, an empty
# tool calls list, and the error.
chat_call_completion_stream = exception_handle(
    retry(
        pipeline(
            chat_completion_stream_commit,
            retry_timeout,
            chunk_list,
            parallel(chunks_content, chunks_call),
            to_dict_content_and_call,
        ),
        times=3,
    ),
    None,
)
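
# Usage sketch (hedged; the tool schema and model name are illustrative):
#
#     result = chat_call_completion_stream(
#         messages=[{"role": "user", "content": "What is the weather in Beijing?"}],
#         llm_config={
#             "model": "gpt-3.5-turbo",
#             "tools": [{"type": "function", "function": {...}}],
#         },
#     )
#     # result has the shape produced by to_dict_content_and_call:
#     # {"content": ..., "function_name": ..., "parameters": ..., "tool_calls": [...]}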