import os
import sys
from typing import Dict, List, Optional, Tuple

import oyaml as yaml
from pydantic import BaseModel


class GeneralProviderConfig(BaseModel):
    """Provider-level settings: API credentials and endpoint."""

    api_key: Optional[str]
    api_base: Optional[str]


class ModelConfig(BaseModel):
    """Base configuration shared by all model entries."""

    max_input_tokens: Optional[int] = sys.maxsize
    provider: Optional[str]


class GeneralModelConfig(ModelConfig):
    """Common generation parameters for chat models."""

    max_tokens: Optional[int]
    stop_sequences: Optional[List[str]]
    temperature: Optional[float]
    top_p: Optional[float]
    top_k: Optional[int]
    stream: Optional[bool]


class ChatConfig(BaseModel):
    """Top-level schema of config.yml."""

    providers: Optional[Dict[str, GeneralProviderConfig]]
    models: Dict[str, GeneralModelConfig]
    default_model: Optional[str]

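# For illustration, a config.yml matching this schema might look like the
# following sketch (mirroring the sample file written by
# ConfigManager._create_sample_file below):
#
#   providers:
#     devchat.ai:
#       api_key: ""
#   models:
#     gpt-3.5-turbo:
#       max_input_tokens: 3000
#       provider: devchat.ai
#       temperature: 0
#       stream: true
#   default_model: gpt-3.5-turbo

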
class ConfigManager:
    """Load, validate, and persist config.yml in a given directory."""

    def __init__(self, dir_path: str):
        self.config_path = os.path.join(dir_path, "config.yml")
        # On first run, write a sample file so users have a template to edit.
        if not os.path.exists(self.config_path):
            self._create_sample_file()
            self._file_is_new = True
        else:
            self._file_is_new = False
        self.config = self._load_and_validate_config()

    @property
    def file_is_new(self) -> bool:
        """Whether the config file was just created by this instance."""
        return self._file_is_new

    @property
    def file_last_modified(self) -> float:
        """Last-modified time of the config file, as a POSIX timestamp."""
        return os.path.getmtime(self.config_path)

    def _load_and_validate_config(self) -> ChatConfig:
        """Read config.yml and validate it against the ChatConfig schema."""
        with open(self.config_path, "r", encoding="utf-8") as file:
            # An empty file loads as None; fall back to an empty dict so
            # validation raises a clear error instead of a TypeError.
            data = yaml.safe_load(file) or {}

        # Parse the nested sections into their pydantic models first.
        if "providers" in data:
            for provider, config in data["providers"].items():
                data["providers"][provider] = GeneralProviderConfig(**config)
        for model, config in data.get("models", {}).items():
            data["models"][model] = GeneralModelConfig(**config)

        return ChatConfig(**data)

    def model_config(self, model_id: Optional[str] = None) -> Tuple[str, ModelConfig]:
        """Look up a model's configuration by ID.

        With no ID, fall back to the default model, then to the first
        configured model.
        """
        if not model_id:
            if self.config.default_model:
                return self.model_config(self.config.default_model)
            if self.config.models:
                return next(iter(self.config.models.items()))
            raise ValueError(f"No models found in {self.config_path}")
        if model_id not in self.config.models:
            raise ValueError(f"Model '{model_id}' not found in {self.config_path}")
        return model_id, self.config.models[model_id]

    def update_model_config(
        self, model_id: str, new_config: GeneralModelConfig
    ) -> GeneralModelConfig:
        """Merge explicitly set fields of new_config into model_id's config."""
        _, old_config = self.model_config(model_id)
        # Only copy max_input_tokens when it was explicitly set; testing the
        # value against None would always pass because the field defaults to
        # sys.maxsize.
        if "max_input_tokens" in new_config.__fields_set__:
            old_config.max_input_tokens = new_config.max_input_tokens
        # Merge only the fields that were explicitly set, with new_config's
        # values taking precedence.
        updated_parameters = old_config.dict(exclude_unset=True)
        updated_parameters.update(new_config.dict(exclude_unset=True))
        self.config.models[model_id] = type(new_config)(**updated_parameters)
        return self.config.models[model_id]

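    # Illustrative merge semantics (a sketch, not exercised in this module):
    # if "gpt-4" is stored with temperature=0, then
    #     update_model_config("gpt-4", GeneralModelConfig(top_p=0.9))
    # keeps temperature=0 and adds top_p=0.9, because dict(exclude_unset=True)
    # drops every field the caller did not set.
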
    def sync(self):
        """Write the in-memory configuration back to config.yml."""
        with open(self.config_path, "w", encoding="utf-8") as file:
            yaml.dump(self.config.dict(exclude_unset=True), file)

    def _create_sample_file(self):
        """Write a sample config.yml with placeholder providers and models."""
        sample_config = ChatConfig(
            providers={
                "devchat.ai": GeneralProviderConfig(api_key=""),
                "openai.com": GeneralProviderConfig(api_key=""),
                "general": GeneralProviderConfig(),
            },
            models={
                "gpt-4": GeneralModelConfig(
                    max_input_tokens=6000, provider="devchat.ai", temperature=0, stream=True
                ),
                "gpt-3.5-turbo-16k": GeneralModelConfig(
                    max_input_tokens=12000, provider="devchat.ai", temperature=0, stream=True
                ),
                "gpt-3.5-turbo": GeneralModelConfig(
                    max_input_tokens=3000, provider="devchat.ai", temperature=0, stream=True
                ),
                "claude-2": GeneralModelConfig(provider="general", max_tokens=20000),
            },
            default_model="gpt-3.5-turbo",
        )
        with open(self.config_path, "w", encoding="utf-8") as file:
            yaml.dump(sample_config.dict(exclude_unset=True), file)
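

# Minimal usage sketch (illustrative only; not part of the module's API).
if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        manager = ConfigManager(tmp_dir)  # writes a sample config.yml
        model_id, model = manager.model_config()  # resolves the default model
        print(model_id, model.max_input_tokens)
        manager.update_model_config(model_id, GeneralModelConfig(temperature=0.5))
        manager.sync()  # persist the merged configuration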