add max_tokens for llm call
parent afa5d281b6
commit 8dd91e1c57
@@ -7,12 +7,15 @@ openai api utils
 import json
 import os
 import re
-from typing import Dict, List
+from pathlib import Path
+from typing import Any, Dict, List

 import httpx
 import openai
+import oyaml as yaml

 from devchat.ide import IDEService
+from devchat.workflow.path import CHAT_CONFIG_FILENAME, CHAT_DIR

 from .pipeline import (
     RetryException,  # Import RetryException class
@@ -42,6 +45,35 @@ def _try_remove_markdown_block_flag(content):
     return content


+# Module-level variable used to cache the config
+_chat_config: Dict[str, Any] = {}
+
+
+def _load_chat_config() -> None:
+    """Load the chat config into the module-level variable."""
+    global _chat_config
+    chat_config_path = Path(CHAT_DIR) / CHAT_CONFIG_FILENAME
+    with open(chat_config_path, "r", encoding="utf-8") as file:
+        _chat_config = yaml.safe_load(file)
+
+
+def get_maxtokens_by_model(model: str) -> int:
+    # Load the config if it has not been loaded yet
+    if not _chat_config:
+        _load_chat_config()
+
+    # The default value is 1024
+    default_max_tokens = 1024
+
+    # Check whether the model is present in the config
+    if model in _chat_config.get("models", {}):
+        # If the model exists, try to read max_tokens; fall back to the default
+        return _chat_config["models"][model].get("max_tokens", default_max_tokens)
+    else:
+        # If the model is not in the config, return the default
+        return default_max_tokens
+
+
 def chat_completion_stream_commit(
     messages: List[Dict],  # [{"role": "user", "content": "hello"}]
     llm_config: Dict,  # {"model": "...", ...}
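For context, the lookup above implies a chat config YAML with a top-level "models" mapping whose entries may each carry a "max_tokens" value. Below is a minimal sketch of the resolution logic against an assumed payload; the model names and values are illustrative, not taken from a real config file at CHAT_DIR/CHAT_CONFIG_FILENAME.

import oyaml as yaml

# Assumed config shape, inferred only from the lookup in get_maxtokens_by_model:
# a "models" mapping where each entry may define "max_tokens".
sample_yaml = """
models:
  gpt-4o:
    max_tokens: 4096
  gpt-3.5-turbo: {}
"""
config = yaml.safe_load(sample_yaml)

default_max_tokens = 1024
models = config.get("models", {})

# Same resolution order as get_maxtokens_by_model:
print(models["gpt-4o"].get("max_tokens", default_max_tokens))         # 4096
print(models["gpt-3.5-turbo"].get("max_tokens", default_max_tokens))  # 1024
# A model absent from the "models" mapping falls back to the default as well.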
@@ -62,6 +94,7 @@ def chat_completion_stream_commit(
     # Update llm_config dictionary
     llm_config["stream"] = True
     llm_config["timeout"] = 60
+    llm_config["max_tokens"] = get_maxtokens_by_model(llm_config["model"])
     # Return chat completions
     return client.chat.completions.create(messages=messages, **llm_config)
@@ -83,6 +116,7 @@ def chat_completion_stream_raw(**kwargs):
     # Update kwargs dictionary
     kwargs["stream"] = True
     kwargs["timeout"] = 60
+    kwargs["max_tokens"] = get_maxtokens_by_model(kwargs["model"])
     # Return chat completions
     return client.chat.completions.create(**kwargs)
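With this change, both stream helpers pin max_tokens alongside stream and timeout before calling the client. A rough usage sketch of the resulting request, assuming an openai>=1.0 client and an illustrative API key (the real client construction happens elsewhere in this module):

import openai

client = openai.OpenAI(api_key="sk-...")  # illustrative; real setup differs

# Roughly what chat_completion_stream_raw sends after this commit, with 1024
# as the assumed fallback returned by get_maxtokens_by_model:
stream = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
    timeout=60,
    max_tokens=1024,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")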