Merge pull request #121 from devchat-ai/update_model_context_size_for_pr
Refactor config_util.py to handle gitlab host input
commit 379f356d15
@@ -8,15 +8,12 @@ import os
 import sys
 
 # add the current directory to the path
 from os.path import abspath, dirname
 
-from lib.ide_service import IDEService
-
 sys.path.append(dirname(dirname(abspath(__file__))))
 
 # add new model configs to algo.MAX_TOKENS
 import pr_agent.algo as algo
 
+from lib.ide_service import IDEService
+from merico.pr.config_util import get_model_max_input
+
 algo.MAX_TOKENS["gpt-4-turbo-preview"] = 128000
 algo.MAX_TOKENS["claude-3-opus"] = 100000
 algo.MAX_TOKENS["claude-3-sonnet"] = 100000
@@ -42,12 +39,9 @@ algo.MAX_TOKENS["BAAI/bge-large-en-v1.5"] = 512
 algo.MAX_TOKENS["BAAI/bge-base-en-v1.5"] = 512
 algo.MAX_TOKENS["sentence-transformers/msmarco-bert-base-dot-v5"] = 512
 algo.MAX_TOKENS["bert-base-uncased"] = 512
-if os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106") not in algo.MAX_TOKENS:
-    current_model = os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106")
-    IDEService().ide_logging(
-        "info", f"{current_model}'s max tokens is not config, we use it as default 16000"
-    )
-    algo.MAX_TOKENS[os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106")] = 16000
+
+current_model = os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106")
+algo.MAX_TOKENS[current_model] = get_model_max_input(current_model)
 
 
 # add new git provider
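
Note on the hunk above: the old code pinned any model missing from algo.MAX_TOKENS to 16000 and logged that through IDEService; the new code always asks get_model_max_input (added to config_util.py later in this diff) for the limit. A minimal sketch of the config shape that helper reads; the model name and token count are illustrative, not values from this PR:

import yaml

# Hypothetical ~/.chat/config.yml contents; get_model_max_input looks up
# models.<name>.max_input_tokens and falls back to 6000 otherwise.
sample = """
models:
  my-custom-model:
    max_input_tokens: 32000
"""

parsed = yaml.safe_load(sample)
limit = parsed.get("models", {}).get("my-custom-model", {}).get("max_input_tokens", 6000)
print(limit)  # 32000
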
@@ -62,7 +56,8 @@ def get_git_provider():
 
 
 import pr_agent.git_providers as git_providers
-from providers.devchat_provider import DevChatProvider
+
+from merico.pr.providers.devchat_provider import DevChatProvider
 
 git_providers._GIT_PROVIDERS["devchat"] = DevChatProvider
 _get_git_provider_old = git_providers.get_git_provider
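
The two lines above extend pr-agent's module-level provider table and save its original factory. The wrapper that uses _get_git_provider_old sits outside this hunk; a minimal sketch of the monkey-patch pattern, under the assumption that the override simply delegates:

import pr_agent.git_providers as git_providers

_get_git_provider_old = git_providers.get_git_provider

def get_git_provider(*args, **kwargs):
    # Delegate to the saved factory, which can now resolve "devchat"
    # because _GIT_PROVIDERS was extended above.
    return _get_git_provider_old(*args, **kwargs)

git_providers.get_git_provider = get_git_provider
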
@@ -103,8 +98,12 @@ logger.add(
 )
 
 
-from config_util import get_gitlab_host, get_repo_type, read_server_access_token_with_input
-from custom_suggestions_config import get_custom_suggestions_system_prompt
+from merico.pr.config_util import (
+    get_gitlab_host,
+    get_repo_type,
+    read_server_access_token_with_input,
+)
+from merico.pr.custom_suggestions_config import get_custom_suggestions_system_prompt
 
 # set openai key and api base
 get_settings().set("OPENAI.KEY", os.environ.get("OPENAI_API_KEY", ""))
@@ -1,6 +1,8 @@
 import json
+import os
 
+import yaml
 
 from lib.chatmark import Radio, TextEditor
 
 
@@ -224,3 +226,17 @@ def get_gitlab_host(pr_url):
         gitlab_host_map[pr_host] = host
         _save_config_value("gitlab_host_map", gitlab_host_map)
     return host
+
+
+def get_model_max_input(model):
+    config_file = os.path.expanduser("~/.chat/config.yml")
+    try:
+        with open(config_file, "r", encoding="utf-8") as file:
+            yaml_contents = file.read()
+            parsed_yaml = yaml.safe_load(yaml_contents)
+            for model_t in parsed_yaml.get("models", {}):
+                if model_t == model:
+                    return parsed_yaml["models"][model_t].get("max_input_tokens", 6000)
+        return 6000
+    except Exception:
+        return 6000
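
A brief usage sketch for the new helper, against the hypothetical config shown earlier; both model names here are illustrative:

get_model_max_input("my-custom-model")  # -> 32000, from max_input_tokens
get_model_max_input("other-model")      # -> 6000, model not listed under models:
# A missing or unparseable ~/.chat/config.yml lands in the broad except
# and also yields the 6000 default.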