add max_tokens for llm call
parent afa5d281b6
commit 8dd91e1c57
@@ -7,12 +7,15 @@ openai api utils
 import json
 import os
 import re
-from typing import Dict, List
+from pathlib import Path
+from typing import Any, Dict, List
 
 import httpx
 import openai
 import oyaml as yaml
 
 from devchat.ide import IDEService
+from devchat.workflow.path import CHAT_CONFIG_FILENAME, CHAT_DIR
 
 from .pipeline import (
     RetryException,  # Import RetryException class
@@ -42,6 +45,35 @@ def _try_remove_markdown_block_flag(content):
     return content
 
 
+# Module-level variable used to cache the chat configuration
+_chat_config: Dict[str, Any] = {}
+
+
+def _load_chat_config() -> None:
+    """Load the chat configuration into the module-level cache."""
+    global _chat_config
+    chat_config_path = Path(CHAT_DIR) / CHAT_CONFIG_FILENAME
+    with open(chat_config_path, "r", encoding="utf-8") as file:
+        _chat_config = yaml.safe_load(file)
+
+
+def get_maxtokens_by_model(model: str) -> int:
+    # Load the configuration if it has not been loaded yet
+    if not _chat_config:
+        _load_chat_config()
+
+    # The default max_tokens is 1024
+    default_max_tokens = 1024
+
+    # Check whether the model is present in the configuration
+    if model in _chat_config.get("models", {}):
+        # If it is, return its max_tokens, falling back to the default
+        return _chat_config["models"][model].get("max_tokens", default_max_tokens)
+    else:
+        # If the model is not in the configuration, return the default
+        return default_max_tokens
+
+
 def chat_completion_stream_commit(
     messages: List[Dict],  # [{"role": "user", "content": "hello"}]
     llm_config: Dict,  # {"model": "...", ...}
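For reference, the lookup above assumes the chat config YAML carries a "models" mapping with an optional per-model "max_tokens" field; that shape follows from how the code reads the parsed dict. A minimal sketch of the structure yaml.safe_load would return and the resulting lookups (the model names and limits here are hypothetical, not from this commit):

# Hypothetical parsed chat config; only the "models" key and the
# optional per-model "max_tokens" field are assumed by the code above.
_chat_config = {
    "models": {
        "gpt-4o": {"max_tokens": 4096},
        "gpt-3.5-turbo": {},  # no max_tokens set for this model
    }
}

get_maxtokens_by_model("gpt-4o")         # -> 4096
get_maxtokens_by_model("gpt-3.5-turbo")  # -> 1024 (falls back to the default)
get_maxtokens_by_model("unknown-model")  # -> 1024 (model not in config)
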
@@ -62,6 +94,7 @@ def chat_completion_stream_commit(
     # Update llm_config dictionary
     llm_config["stream"] = True
     llm_config["timeout"] = 60
+    llm_config["max_tokens"] = get_maxtokens_by_model(llm_config["model"])
     # Return chat completions
     return client.chat.completions.create(messages=messages, **llm_config)
@@ -83,6 +116,7 @@ def chat_completion_stream_raw(**kwargs):
     # Update kwargs dictionary
     kwargs["stream"] = True
     kwargs["timeout"] = 60
+    kwargs["max_tokens"] = get_maxtokens_by_model(kwargs["model"])
     # Return chat completions
     return client.chat.completions.create(**kwargs)
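Both streaming helpers now inject the per-model max_tokens before calling client.chat.completions.create. A minimal, hypothetical call through the raw path, assuming it returns the OpenAI stream directly (the model name and prompt are placeholders, and any retry decoration from .pipeline is ignored here):

# Hypothetical usage; chat_completion_stream_raw forwards **kwargs to
# client.chat.completions.create after injecting stream=True, timeout=60,
# and the max_tokens resolved by get_maxtokens_by_model.
stream = chat_completion_stream_raw(
    model="gpt-3.5-turbo",  # placeholder model name
    messages=[{"role": "user", "content": "hello"}],
)
for chunk in stream:
    # Each chunk is an OpenAI streaming delta; content may be None.
    print(chunk.choices[0].delta.content or "", end="")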