from enum import Enum
|
|
import os
|
|
import sys
|
|
from typing import List, Dict, Tuple, Union, Optional
|
|
from pydantic import BaseModel
|
|
import oyaml as yaml
|
|
from devchat.openai import OpenAIChatParameters
|
|
from devchat.anthropic import AnthropicChatParameters
|
|
|
|
|
|
class Client(str, Enum):
    """Kinds of chat backend a provider entry can declare via its `client` key."""
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    GENERAL = "general"
|
|
|
|
|
|
class ProviderConfig(BaseModel):
    """Base configuration for a provider entry in config.yml."""
    # Backend kind; expected to match a Client enum value ("openai",
    # "anthropic", "general"). Optional — may be absent in the file.
    client: Optional[str]
|
|
|
|
|
|
class OpenAIProviderConfig(ProviderConfig, extra='forbid'):
    """Connection settings for an OpenAI-style provider.

    extra='forbid' makes pydantic reject unknown keys in the YAML entry.
    """
    api_key: Optional[str]
    api_base: Optional[str]
    # NOTE(review): api_type/api_version/deployment_name look like Azure
    # OpenAI fields — confirm against the OpenAI client that consumes them.
    api_type: Optional[str]
    api_version: Optional[str]
    deployment_name: Optional[str]
|
|
|
|
|
|
class AnthropicProviderConfig(ProviderConfig, extra='forbid'):
    """Connection settings for an Anthropic provider.

    extra='forbid' makes pydantic reject unknown keys in the YAML entry.
    """
    api_key: Optional[str]
    api_base: Optional[str]
    # Request timeout in seconds (presumably passed to the Anthropic client).
    timeout: Optional[float]
|
|
|
|
|
|
class ModelConfig(BaseModel, extra='forbid'):
    """Settings shared by every model entry; unknown keys are rejected."""
    # Cap on input tokens; defaults to sys.maxsize, i.e. effectively unlimited.
    max_input_tokens: Optional[int] = sys.maxsize
    # Name of a key in ChatConfig.providers that serves this model
    # (looked up in ConfigManager._load_and_validate_config).
    provider: Optional[str]
|
|
|
|
|
|
class OpenAIModelConfig(ModelConfig, OpenAIChatParameters):
    """ModelConfig combined with OpenAI chat parameters (e.g. temperature, stream)."""
    pass
|
|
|
|
|
|
class AnthropicModelConfig(ModelConfig, AnthropicChatParameters):
    """ModelConfig combined with Anthropic chat parameters."""
    pass
|
|
|
|
|
|
class GeneralModelConfig(ModelConfig):
    """ModelConfig for providers with no dedicated parameter class.

    Carries a generic set of sampling/streaming knobs inline instead of
    mixing in a client-specific parameters class.
    """
    max_tokens: Optional[int]
    stop_sequences: Optional[List[str]]
    temperature: Optional[float]
    top_p: Optional[float]
    top_k: Optional[int]
    stream: Optional[bool]
|
|
|
|
|
|
class ChatConfig(BaseModel, extra='forbid'):
    """Top-level schema for config.yml; unknown top-level keys are rejected."""
    # Provider name -> typed provider config. Entries are pre-coerced to the
    # concrete subclass in ConfigManager._load_and_validate_config.
    providers: Optional[Dict[str, Union[OpenAIProviderConfig,
                                        AnthropicProviderConfig,
                                        ProviderConfig]]]
    # Model name -> typed model config. Required field.
    models: Dict[str, Union[OpenAIModelConfig, AnthropicModelConfig, GeneralModelConfig]]
    # Model name used when the caller does not specify one.
    default_model: Optional[str]
|
|
|
|
|
|
class ConfigManager:
    """Loads, validates, and persists the chat configuration (config.yml).

    On construction, writes a sample configuration when none exists at
    ``dir_path/config.yml``, then loads and validates the file into a
    ChatConfig held on ``self.config``.
    """

    def __init__(self, dir_path: str):
        """
        :param dir_path: Directory that contains (or will contain) config.yml.
        """
        self.config_path = os.path.join(dir_path, 'config.yml')
        if not os.path.exists(self.config_path):
            self._create_sample_file()
            self._file_is_new = True
        else:
            self._file_is_new = False
        self.config = self._load_and_validate_config()

    @property
    def file_is_new(self) -> bool:
        # True when this instance created the sample config file itself.
        return self._file_is_new

    @property
    def file_last_modified(self) -> float:
        # Modification time of the config file, in seconds since the epoch.
        return os.path.getmtime(self.config_path)

    def _load_and_validate_config(self) -> ChatConfig:
        """Parse config.yml and coerce provider/model entries to typed configs.

        :return: The validated ChatConfig.
        :raises pydantic.ValidationError: when the file content does not
            match the ChatConfig schema (including an empty file, which is
            missing the required 'models' section).
        """
        with open(self.config_path, 'r', encoding='utf-8') as file:
            data = yaml.safe_load(file)
        if data is None:
            # Fix: safe_load returns None for an empty file; fall through to
            # ChatConfig validation instead of raising TypeError on `in data`.
            data = {}

        if 'providers' in data:
            for provider, config in data['providers'].items():
                # Fix: 'client' is optional on ProviderConfig, so use .get()
                # rather than config['client'] (which raised KeyError).
                client = config.get('client')
                if client == "openai":
                    data['providers'][provider] = OpenAIProviderConfig(**config)
                elif client == "anthropic":
                    data['providers'][provider] = AnthropicProviderConfig(**config)
                else:
                    data['providers'][provider] = ProviderConfig(**config)
        # Fix: a missing 'models' section is reported by ChatConfig validation
        # below rather than as a bare KeyError here.
        for model, config in data.get('models', {}).items():
            if 'provider' not in config:
                data['models'][model] = GeneralModelConfig(**config)
            elif 'parameters' in config:
                provider = data['providers'][config['provider']]
                if provider.client == Client.OPENAI:
                    data['models'][model] = OpenAIModelConfig(**config)
                elif provider.client == Client.ANTHROPIC:
                    data['models'][model] = AnthropicModelConfig(**config)
                else:
                    data['models'][model] = GeneralModelConfig(**config)
            # NOTE(review): a model with 'provider' but without 'parameters'
            # is deliberately left as a raw dict for pydantic's Union coercion
            # in ChatConfig — confirm this fallthrough is intentional.

        return ChatConfig(**data)

    def model_config(self, model_id: Optional[str] = None) -> Tuple[str, ModelConfig]:
        """Return ``(model_id, config)`` for the requested model.

        Without a model_id, falls back to the configured default model, then
        to the first model defined in the file.

        :raises ValueError: when no models exist or model_id is unknown.
        """
        if not model_id:
            if self.config.default_model:
                return self.model_config(self.config.default_model)
            if self.config.models:
                return next(iter(self.config.models.items()))
            raise ValueError(f"No models found in {self.config_path}")
        if model_id not in self.config.models:
            raise ValueError(f"Model '{model_id}' not found in {self.config_path}")
        return model_id, self.config.models[model_id]

    def update_model_config(
        self,
        model_id: str,
        new_config: Union[OpenAIModelConfig, AnthropicModelConfig]
    ) -> Union[OpenAIModelConfig, AnthropicModelConfig]:
        """Merge new_config over the stored config for model_id (in memory only).

        Only fields explicitly set on new_config override the stored values;
        the merged result is rebuilt as ``type(new_config)``. Call sync() to
        persist the change to disk.

        :raises ValueError: when model_id is unknown (via model_config).
        """
        _, old_config = self.model_config(model_id)
        if new_config.max_input_tokens is not None:
            old_config.max_input_tokens = new_config.max_input_tokens
        updated_parameters = old_config.dict(exclude_unset=True)
        updated_parameters.update(new_config.dict(exclude_unset=True))
        self.config.models[model_id] = type(new_config)(**updated_parameters)
        return self.config.models[model_id]

    def sync(self):
        """Write the in-memory configuration back to config.yml."""
        with open(self.config_path, 'w', encoding='utf-8') as file:
            yaml.dump(self.config.dict(exclude_unset=True), file)

    def _create_sample_file(self):
        """Write a starter config.yml with sample providers and models."""
        sample_config = ChatConfig(
            providers={
                "devchat.ai": OpenAIProviderConfig(
                    client=Client.OPENAI,
                    api_key=""
                ),
                "openai.com": OpenAIProviderConfig(
                    client=Client.OPENAI,
                    api_key=""
                ),
                "general": ProviderConfig(
                    client=Client.GENERAL
                )
            },
            models={
                "gpt-4": OpenAIModelConfig(
                    max_input_tokens=6000,
                    provider='devchat.ai',
                    temperature=0,
                    stream=True
                ),
                "gpt-3.5-turbo-16k": OpenAIModelConfig(
                    max_input_tokens=12000,
                    provider='devchat.ai',
                    temperature=0,
                    stream=True
                ),
                "gpt-3.5-turbo": OpenAIModelConfig(
                    max_input_tokens=3000,
                    provider='devchat.ai',
                    temperature=0,
                    stream=True
                ),
                "claude-2": GeneralModelConfig(
                    provider='general',
                    max_tokens=20000
                )
            },
            default_model="gpt-3.5-turbo"
        )
        with open(self.config_path, 'w', encoding='utf-8') as file:
            yaml.dump(sample_config.dict(exclude_unset=True), file)