Merge pull request #121 from devchat-ai/update_model_context_size_for_pr

Refactor config_util.py to handle gitlab host input
commit 379f356d15
Author: kagami (committed by GitHub)
Date: 2024-05-24 13:33:23 +00:00
2 changed files with 30 additions and 15 deletions


@@ -8,15 +8,12 @@ import os
 import sys

 # add the current directory to the path
 from os.path import abspath, dirname

-from lib.ide_service import IDEService
-
 sys.path.append(dirname(dirname(abspath(__file__))))

 # add new model configs to algo.MAX_TOKENS
 import pr_agent.algo as algo
+from lib.ide_service import IDEService
+from merico.pr.config_util import get_model_max_input

 algo.MAX_TOKENS["gpt-4-turbo-preview"] = 128000
 algo.MAX_TOKENS["claude-3-opus"] = 100000
 algo.MAX_TOKENS["claude-3-sonnet"] = 100000
@@ -42,12 +39,9 @@ algo.MAX_TOKENS["BAAI/bge-large-en-v1.5"] = 512
 algo.MAX_TOKENS["BAAI/bge-base-en-v1.5"] = 512
 algo.MAX_TOKENS["sentence-transformers/msmarco-bert-base-dot-v5"] = 512
 algo.MAX_TOKENS["bert-base-uncased"] = 512

-if os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106") not in algo.MAX_TOKENS:
-    current_model = os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106")
-    IDEService().ide_logging(
-        "info", f"{current_model}'s max tokens is not config, we use it as default 16000"
-    )
-    algo.MAX_TOKENS[os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106")] = 16000
+current_model = os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106")
+algo.MAX_TOKENS[current_model] = get_model_max_input(current_model)

 # add new git provider
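Note the behavioral change in this hunk: previously a model missing from algo.MAX_TOKENS was logged and pinned to a 16000-token default; now the active model's limit is always read from the user's DevChat config via get_model_max_input (which falls back to 6000, as the second file below shows). A minimal sketch of the new resolution path, assuming only the imports introduced in the first hunk:

import os

import pr_agent.algo as algo
from merico.pr.config_util import get_model_max_input

# Resolve the active model once and register its context size unconditionally;
# the value comes from ~/.chat/config.yml, or 6000 when no entry exists.
current_model = os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106")
algo.MAX_TOKENS[current_model] = get_model_max_input(current_model)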
@@ -62,7 +56,8 @@ def get_git_provider():
 import pr_agent.git_providers as git_providers

-from providers.devchat_provider import DevChatProvider
+from merico.pr.providers.devchat_provider import DevChatProvider

 git_providers._GIT_PROVIDERS["devchat"] = DevChatProvider
 _get_git_provider_old = git_providers.get_git_provider
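The context lines kept in this hunk follow a registration-plus-wrapper (monkey-patch) pattern: DevChatProvider is added to pr_agent's internal provider registry, and the original get_git_provider factory is saved so a replacement can delegate to it. The replacement body is outside this hunk, so the wrapper in the sketch below is hypothetical:

import pr_agent.git_providers as git_providers
from merico.pr.providers.devchat_provider import DevChatProvider

# Register the custom provider under the "devchat" key and keep a reference
# to the original factory function before replacing it.
git_providers._GIT_PROVIDERS["devchat"] = DevChatProvider
_get_git_provider_old = git_providers.get_git_provider

def get_git_provider(*args, **kwargs):
    # Hypothetical wrapper body: delegate to the saved original.
    return _get_git_provider_old(*args, **kwargs)

git_providers.get_git_provider = get_git_provider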
@@ -103,8 +98,12 @@ logger.add(
 )

-from config_util import get_gitlab_host, get_repo_type, read_server_access_token_with_input
-from custom_suggestions_config import get_custom_suggestions_system_prompt
+from merico.pr.config_util import (
+    get_gitlab_host,
+    get_repo_type,
+    read_server_access_token_with_input,
+)
+from merico.pr.custom_suggestions_config import get_custom_suggestions_system_prompt

 # set openai key and api base
 get_settings().set("OPENAI.KEY", os.environ.get("OPENAI_API_KEY", ""))

merico/pr/config_util.py

@@ -1,6 +1,8 @@
 import json
 import os

+import yaml
+
 from lib.chatmark import Radio, TextEditor
@@ -224,3 +226,17 @@ def get_gitlab_host(pr_url):
     gitlab_host_map[pr_host] = host
     _save_config_value("gitlab_host_map", gitlab_host_map)
     return host
+
+
+def get_model_max_input(model):
+    config_file = os.path.expanduser("~/.chat/config.yml")
+    try:
+        with open(config_file, "r", encoding="utf-8") as file:
+            yaml_contents = file.read()
+            parsed_yaml = yaml.safe_load(yaml_contents)
+        for model_t in parsed_yaml.get("models", {}):
+            if model_t == model:
+                return parsed_yaml["models"][model_t].get("max_input_tokens", 6000)
+        return 6000
+    except Exception:
+        return 6000
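A short usage sketch for the new helper. The config layout below is inferred from the function itself (a top-level "models" mapping whose entries may carry "max_input_tokens"); the model name and token value are illustrative, not part of this commit:

# Assumed ~/.chat/config.yml layout (example values):
#
#   models:
#     gpt-3.5-turbo-1106:
#       max_input_tokens: 16385
#
from merico.pr.config_util import get_model_max_input

print(get_model_max_input("gpt-3.5-turbo-1106"))  # configured value, e.g. 16385
print(get_model_max_input("unknown-model"))  # 6000 fallback (also returned on read/parse errors)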