# 2023-11-03 11:02:53 +08:00
#
# 180 lines
# 6.2 KiB
# Python

from enum import Enum
import os
import sys
from typing import List, Dict, Tuple, Union, Optional
from pydantic import BaseModel
import oyaml as yaml
from devchat.openai import OpenAIChatParameters
from devchat.anthropic import AnthropicChatParameters
class Client(str, Enum):
    """Names of the client kinds a provider entry can declare.

    The ``str`` mix-in makes every member compare equal to its plain string
    value, so ``Client.OPENAI == "openai"`` holds.
    """

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    GENERAL = "general"
class ProviderConfig(BaseModel):
    """Settings shared by every provider entry.

    Unlike the provider-specific subclasses in this file, this base model does
    not set ``extra='forbid'``, so pydantic's default handling of unknown keys
    applies to it.
    """

    # Name of the client used for this provider (``Client`` lists the values
    # used elsewhere in this file). The ``None`` default is spelled out so the
    # field stays optional on pydantic versions that treat a bare
    # ``Optional[...]`` annotation as required.
    client: Optional[str] = None
class OpenAIProviderConfig(ProviderConfig, extra='forbid'):
    """Provider entry for the OpenAI client; unknown keys are rejected.

    Every field is optional. The ``None`` defaults are explicit so the model
    behaves the same on pydantic versions that treat a bare ``Optional[...]``
    annotation as a required field.
    """

    api_key: Optional[str] = None
    api_base: Optional[str] = None
    # NOTE(review): api_type / api_version / deployment_name are presumably
    # the Azure-style OpenAI settings -- confirm against the client code.
    api_type: Optional[str] = None
    api_version: Optional[str] = None
    deployment_name: Optional[str] = None
class AnthropicProviderConfig(ProviderConfig, extra='forbid'):
    """Provider entry for the Anthropic client; unknown keys are rejected.

    Every field is optional. The ``None`` defaults are explicit so the model
    behaves the same on pydantic versions that treat a bare ``Optional[...]``
    annotation as a required field.
    """

    api_key: Optional[str] = None
    api_base: Optional[str] = None
    # NOTE(review): unit is not visible in this file (presumably seconds).
    timeout: Optional[float] = None
class ModelConfig(BaseModel, extra='forbid'):
    """Fields common to every model entry; unknown keys are rejected."""

    # Defaults to ``sys.maxsize`` when the entry does not set a limit.
    max_input_tokens: Optional[int] = sys.maxsize
    # Key of the provider entry this model refers to. The explicit ``None``
    # default keeps the field optional on pydantic versions that treat a bare
    # ``Optional[...]`` annotation as required.
    provider: Optional[str] = None
class OpenAIModelConfig(ModelConfig, OpenAIChatParameters):
    """Model entry combining ``ModelConfig`` with ``OpenAIChatParameters``.

    Adds no fields of its own.
    """
class AnthropicModelConfig(ModelConfig, AnthropicChatParameters):
    """Model entry combining ``ModelConfig`` with ``AnthropicChatParameters``.

    Adds no fields of its own.
    """
class GeneralModelConfig(ModelConfig):
    """Model entry that is not tied to the OpenAI or Anthropic parameter sets.

    Every field is optional. The ``None`` defaults are explicit so the model
    behaves the same on pydantic versions that treat a bare ``Optional[...]``
    annotation as a required field.
    """

    max_tokens: Optional[int] = None
    stop_sequences: Optional[List[str]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    stream: Optional[bool] = None
class ChatConfig(BaseModel, extra='forbid'):
    """Top-level schema of the configuration file; unknown keys are rejected.

    ``models`` is the only required field. ``providers`` and ``default_model``
    carry explicit ``None`` defaults so they stay optional on pydantic versions
    that treat a bare ``Optional[...]`` annotation as required.
    """

    # Provider entries keyed by provider name.
    providers: Optional[Dict[str, Union[OpenAIProviderConfig,
                                        AnthropicProviderConfig,
                                        ProviderConfig]]] = None
    # Model entries keyed by model id.
    models: Dict[str, Union[OpenAIModelConfig, AnthropicModelConfig, GeneralModelConfig]]
    # Model id to use when a caller does not name one.
    default_model: Optional[str] = None
class ConfigManager:
    """Load, query, update and persist the chat configuration in ``config.yml``.

    On construction, ``<dir_path>/config.yml`` is written from a built-in
    sample if it does not exist yet; the file is then loaded and validated
    into a ``ChatConfig`` stored on ``self.config``.
    """

    def __init__(self, dir_path: str):
        """Open (or create) the configuration stored under *dir_path*.

        Args:
            dir_path: Directory that holds, or will hold, ``config.yml``.
        """
        self.config_path = os.path.join(dir_path, 'config.yml')
        # Remember whether this call had to create the file so callers can
        # tell a freshly written sample from a pre-existing configuration.
        self._file_is_new = not os.path.exists(self.config_path)
        if self._file_is_new:
            self._create_sample_file()
        self.config = self._load_and_validate_config()

    @property
    def file_is_new(self) -> bool:
        """Whether the constructor had to create ``config.yml``."""
        return self._file_is_new

    @property
    def file_last_modified(self) -> float:
        """Modification time of ``config.yml`` as given by ``os.path.getmtime``."""
        return os.path.getmtime(self.config_path)

    @staticmethod
    def _build_provider(raw: dict) -> ProviderConfig:
        """Turn one raw provider mapping into the model matching its client."""
        # ``client`` is optional in the schema, so a missing key must not
        # crash the loader; such entries use the generic ProviderConfig.
        client = raw.get('client')
        if client == Client.OPENAI:
            return OpenAIProviderConfig(**raw)
        if client == Client.ANTHROPIC:
            return AnthropicProviderConfig(**raw)
        return ProviderConfig(**raw)

    def _load_and_validate_config(self) -> ChatConfig:
        """Read ``config.yml`` and validate it into a ``ChatConfig``.

        Provider entries are converted according to their ``client`` value.
        Model entries without a ``provider`` key become ``GeneralModelConfig``;
        entries with both ``provider`` and ``parameters`` keys are converted
        according to the referenced provider's client.

        Returns:
            The validated configuration.

        Raises:
            KeyError: If a model entry with a ``parameters`` key names a
                provider that is not defined under ``providers``.
            Exception: Whatever ``ChatConfig`` and the entry models raise for
                invalid data, e.g. an empty file or a missing ``models``
                section.
        """
        with open(self.config_path, 'r', encoding='utf-8') as file:
            # An empty YAML document parses to None; use an empty mapping so
            # the schema validation below reports what is missing.
            data = yaml.safe_load(file) or {}

        raw_providers = data.get('providers')
        if raw_providers:
            data['providers'] = {
                name: self._build_provider(raw)
                for name, raw in raw_providers.items()
            }

        # A missing ``models`` section is left for ChatConfig to reject.
        for model_id, raw in list((data.get('models') or {}).items()):
            if 'provider' not in raw:
                data['models'][model_id] = GeneralModelConfig(**raw)
            elif 'parameters' in raw:
                provider = data['providers'][raw['provider']]
                if provider.client == Client.OPENAI:
                    data['models'][model_id] = OpenAIModelConfig(**raw)
                elif provider.client == Client.ANTHROPIC:
                    data['models'][model_id] = AnthropicModelConfig(**raw)
                else:
                    data['models'][model_id] = GeneralModelConfig(**raw)
            # NOTE(review): an entry that names a provider but has no
            # 'parameters' key stays a plain dict and is resolved by the Union
            # field of ChatConfig -- confirm whether the provider-based
            # dispatch above was meant to cover those entries as well.

        return ChatConfig(**data)

    def model_config(self, model_id: Optional[str] = None) -> Tuple[str, ModelConfig]:
        """Return ``(model_id, config)`` for the requested or default model.

        Args:
            model_id: Model to look up. When empty, ``default_model`` is used
                if set, otherwise the first configured model.

        Raises:
            ValueError: If no model is configured or the model is unknown.
        """
        if not model_id:
            model_id = self.config.default_model
        if not model_id:
            # No explicit id and no default: fall back to the first entry.
            if not self.config.models:
                raise ValueError(f"No models found in {self.config_path}")
            return next(iter(self.config.models.items()))
        if model_id not in self.config.models:
            raise ValueError(f"Model '{model_id}' not found in {self.config_path}")
        return model_id, self.config.models[model_id]

    def update_model_config(
        self,
        model_id: str,
        new_config: Union[OpenAIModelConfig, AnthropicModelConfig]
    ) -> Union[OpenAIModelConfig, AnthropicModelConfig]:
        """Merge *new_config* into the stored entry for *model_id*.

        Only fields explicitly set on *new_config* override the stored values;
        everything else is kept. The merged entry is rebuilt as an instance of
        ``type(new_config)`` and replaces the stored one in memory (call
        ``sync`` to persist it).

        Args:
            model_id: Id of the model entry to update.
            new_config: Values to apply on top of the stored entry.

        Returns:
            The merged entry now stored under *model_id*.

        Raises:
            ValueError: If *model_id* is not configured.
        """
        _, current = self.model_config(model_id)
        # Merge on the explicitly-set fields only. Copying
        # new_config.max_input_tokens unconditionally would replace a stored
        # limit with the sys.maxsize default whenever the caller did not set
        # the field, and would mutate the previous entry in place.
        merged = current.dict(exclude_unset=True)
        merged.update(new_config.dict(exclude_unset=True))
        replacement = type(new_config)(**merged)
        self.config.models[model_id] = replacement
        return replacement

    def _write_config(self, config: ChatConfig) -> None:
        """Dump the explicitly-set fields of *config* to ``config.yml``."""
        with open(self.config_path, 'w', encoding='utf-8') as file:
            yaml.dump(config.dict(exclude_unset=True), file)

    def sync(self):
        """Write the in-memory configuration back to ``config.yml``."""
        self._write_config(self.config)

    def _create_sample_file(self):
        """Write a starter ``config.yml`` with sample providers and models."""
        sample_config = ChatConfig(
            providers={
                "devchat.ai": OpenAIProviderConfig(
                    client=Client.OPENAI,
                    api_key=""
                ),
                "openai.com": OpenAIProviderConfig(
                    client=Client.OPENAI,
                    api_key=""
                ),
                "general": ProviderConfig(
                    client=Client.GENERAL
                )
            },
            models={
                "gpt-4": OpenAIModelConfig(
                    max_input_tokens=6000,
                    provider='devchat.ai',
                    temperature=0,
                    stream=True
                ),
                "gpt-3.5-turbo-16k": OpenAIModelConfig(
                    max_input_tokens=12000,
                    provider='devchat.ai',
                    temperature=0,
                    stream=True
                ),
                "gpt-3.5-turbo": OpenAIModelConfig(
                    max_input_tokens=3000,
                    provider='devchat.ai',
                    temperature=0,
                    stream=True
                ),
                "claude-2": GeneralModelConfig(
                    provider='general',
                    max_tokens=20000
                )
            },
            default_model="gpt-3.5-turbo"
        )
        self._write_config(sample_config)