From 8dd91e1c57752be181e6e4c814089850f3f0ebd8 Mon Sep 17 00:00:00 2001
From: "bobo.yang"
Date: Tue, 12 Nov 2024 11:58:54 +0800
Subject: [PATCH] add max_tokens for llm call

---
 site-packages/devchat/llm/openai.py | 36 ++++++++++++++++++++++++++++-
 1 file changed, 35 insertions(+), 1 deletion(-)

diff --git a/site-packages/devchat/llm/openai.py b/site-packages/devchat/llm/openai.py
index 60b5513..4404435 100644
--- a/site-packages/devchat/llm/openai.py
+++ b/site-packages/devchat/llm/openai.py
@@ -7,12 +7,15 @@ openai api utils
 import json
 import os
 import re
-from typing import Dict, List
+from pathlib import Path
+from typing import Any, Dict, List
 
 import httpx
 import openai
+import oyaml as yaml
 
 from devchat.ide import IDEService
+from devchat.workflow.path import CHAT_CONFIG_FILENAME, CHAT_DIR
 
 from .pipeline import (
     RetryException,  # Import RetryException class
@@ -42,6 +45,35 @@ def _try_remove_markdown_block_flag(content):
     return content
 
 
+# Module-level variable caching the loaded chat config
+_chat_config: Dict[str, Any] = {}
+
+
+def _load_chat_config() -> None:
+    """Load the chat config file into the module-level cache."""
+    global _chat_config
+    chat_config_path = Path(CHAT_DIR) / CHAT_CONFIG_FILENAME
+    with open(chat_config_path, "r", encoding="utf-8") as file:
+        _chat_config = yaml.safe_load(file) or {}
+
+
+def get_maxtokens_by_model(model: str) -> int:
+    # Load the config on first use
+    if not _chat_config:
+        _load_chat_config()
+
+    # Default max_tokens when none is configured for the model
+    default_max_tokens = 1024
+
+    # Check whether the model has an entry in the config
+    if model in _chat_config.get("models", {}):
+        # Use the model's max_tokens if set, otherwise fall back to the default
+        return _chat_config["models"][model].get("max_tokens", default_max_tokens)
+    else:
+        # Model not in the config: return the default
+        return default_max_tokens
+
+
 def chat_completion_stream_commit(
     messages: List[Dict],  # [{"role": "user", "content": "hello"}]
     llm_config: Dict,  # {"model": "...", ...}
@@ -62,6 +94,7 @@ def chat_completion_stream_commit(
     # Update llm_config dictionary
     llm_config["stream"] = True
     llm_config["timeout"] = 60
+    llm_config["max_tokens"] = get_maxtokens_by_model(llm_config["model"])
 
     # Return chat completions
     return client.chat.completions.create(messages=messages, **llm_config)
@@ -83,6 +116,7 @@ def chat_completion_stream_raw(**kwargs):
     # Update kwargs dictionary
     kwargs["stream"] = True
    kwargs["timeout"] = 60
+    kwargs["max_tokens"] = get_maxtokens_by_model(kwargs["model"])
 
     # Return chat completions
     return client.chat.completions.create(**kwargs)
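
Note: for illustration, below is a minimal, self-contained sketch of the
lookup that get_maxtokens_by_model performs, run against an in-memory dict
instead of the YAML file at CHAT_DIR/CHAT_CONFIG_FILENAME. The "models"
schema, the model names, and the resolve_max_tokens helper are assumptions
made for this sketch, not part of the patch or a confirmed devchat config
format.

    # Hedged sketch: mirrors the patch's per-model max_tokens lookup.
    # The config shape below is an assumed example, not devchat's schema.
    from typing import Any, Dict

    _example_config: Dict[str, Any] = {
        "models": {
            "gpt-4o": {"max_tokens": 4096},   # explicit per-model limit
            "example-model": {},              # no max_tokens -> default applies
        }
    }

    def resolve_max_tokens(config: Dict[str, Any], model: str, default: int = 1024) -> int:
        # Per-model max_tokens if configured; otherwise the default (1024 in the patch).
        return config.get("models", {}).get(model, {}).get("max_tokens", default)

    assert resolve_max_tokens(_example_config, "gpt-4o") == 4096
    assert resolve_max_tokens(_example_config, "example-model") == 1024
    assert resolve_max_tokens(_example_config, "unknown-model") == 1024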