233 lines
7.5 KiB
Python
233 lines
7.5 KiB
Python
import logging
|
|
import os
|
|
import re
|
|
import getpass
|
|
import socket
|
|
import subprocess
|
|
from typing import List, Tuple, Optional
|
|
import datetime
|
|
import hashlib
|
|
import tiktoken
|
|
|
|
try:
    encoding = tiktoken.get_encoding("cl100k_base")
except Exception:
    # The normal lookup failed: build cl100k_base by hand from tiktoken's
    # registry and construct the Encoding with use_pure_python=True.
    # NOTE(review): `_find_constructors` is private tiktoken API and
    # `use_pure_python` does not appear to be an upstream tiktoken `Encoding`
    # argument -- presumably this targets a patched tiktoken; confirm before
    # upgrading the dependency.
    from tiktoken import registry
    from tiktoken.registry import _find_constructors
    from tiktoken.core import Encoding

    def get_encoding(name: str):
        """Construct encoding ``name`` from the registry with the pure-Python backend."""
        # Presumably populates registry.ENCODING_CONSTRUCTORS before the lookup.
        _find_constructors()
        make_kwargs = registry.ENCODING_CONSTRUCTORS[name]
        return Encoding(**make_kwargs(), use_pure_python=True)

    encoding = get_encoding("cl100k_base")

# Formatter applied to the handlers configured by setup_logger() and get_logger().
log_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
|
|
|
|
def setup_logger(file_path: Optional[str] = None):
    """Install a single global log handler on the root logger.

    Args:
        file_path: Path of a log file to write to. When ``None``, records go
            to a ``logging.StreamHandler`` instead.

    The new handler uses the module's shared ``log_formatter``. Any handlers
    previously attached to the root logger are closed and replaced, so
    calling this repeatedly does not leak open log files.
    """
    if file_path is None:
        handler = logging.StreamHandler()
    else:
        handler = logging.FileHandler(file_path)
    handler.setFormatter(log_formatter)

    # Close the handlers being replaced; otherwise each earlier FileHandler
    # keeps its file open for the rest of the process.
    for old_handler in logging.root.handlers:
        if old_handler is not handler:
            old_handler.close()
    logging.root.handlers = [handler]
|
|
|
|
|
|
def get_logger(name: Optional[str] = None,
               handler: Optional[logging.Handler] = None) -> logging.Logger:
    """Return the logger ``name`` with its level taken from ``LOG_LEVEL``.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        handler: Optional handler to attach. It is given the same level and
            the module's shared ``log_formatter``.

    Returns:
        The configured ``logging.Logger``.

    ``LOG_LEVEL`` is upper-cased and looked up on the ``logging`` module
    (e.g. ``debug`` -> ``logging.DEBUG``). Unset or unrecognized values
    fall back to ``INFO``.
    """
    local_logger = logging.getLogger(name)

    log_level_str = os.getenv('LOG_LEVEL', 'INFO')
    log_level = getattr(logging, log_level_str.upper(), logging.INFO)
    # Only accept integer levels: other upper-case attributes of the logging
    # module (e.g. BASIC_FORMAT, a format string) would make setLevel() raise.
    if not isinstance(log_level, int):
        log_level = logging.INFO
    local_logger.setLevel(log_level)

    if handler is not None:
        handler.setLevel(log_level)
        handler.setFormatter(log_formatter)
        local_logger.addHandler(handler)

    local_logger.info("Get %s", local_logger)
    return local_logger
|
|
|
|
|
|
def find_root_dir() -> Tuple[Optional[str], Optional[str]]:
|
|
"""
|
|
Find the root directory of the repository and the user's home directory
|
|
"""
|
|
try:
|
|
user_dir = os.path.expanduser("~")
|
|
if not os.path.isdir(user_dir):
|
|
user_dir = None
|
|
except Exception:
|
|
user_dir = None
|
|
|
|
repo_dir = None
|
|
try:
|
|
repo_dir = subprocess.run(["git", "rev-parse", "--show-toplevel"],
|
|
capture_output=True, text=True, check=True,
|
|
encoding='utf-8').stdout.strip()
|
|
if not os.path.isdir(repo_dir):
|
|
repo_dir = None
|
|
else:
|
|
return repo_dir, user_dir
|
|
except Exception:
|
|
repo_dir = None
|
|
|
|
try:
|
|
result = subprocess.run(["svn", "info"],
|
|
capture_output=True, text=True, check=True, encoding='utf-8')
|
|
if result.returncode == 0:
|
|
for line in result.stdout.splitlines():
|
|
if line.startswith("Working Copy Root Path: "):
|
|
repo_dir = line.split("Working Copy Root Path: ", 1)[1].strip()
|
|
if os.path.isdir(repo_dir):
|
|
return repo_dir, user_dir
|
|
except Exception:
|
|
repo_dir = None
|
|
|
|
return repo_dir, user_dir
|
|
|
|
|
|
def add_gitignore(target_dir: str, *ignore_entries: str) -> None:
    """Ensure each of ``ignore_entries`` is listed in ``target_dir``/.gitignore.

    If the file exists, missing entries are appended after a ``# devchat``
    comment; if it does not exist, it is created with that comment and all
    entries.

    Entries are compared against whole lines (surrounding whitespace
    ignored), so an entry is not treated as present merely because it is a
    substring of another pattern (``.chat`` vs ``.chatfoo``). Duplicate
    arguments are written once.
    """
    gitignore_path = os.path.join(target_dir, '.gitignore')

    # dict.fromkeys de-duplicates while preserving argument order.
    wanted = list(dict.fromkeys(ignore_entries))

    if os.path.exists(gitignore_path):
        with open(gitignore_path, 'r', encoding='utf-8') as gitignore_file:
            existing = {line.strip() for line in gitignore_file}
        new_entries = [entry for entry in wanted if entry.strip() not in existing]

        if new_entries:
            with open(gitignore_path, 'a', encoding='utf-8') as gitignore_file:
                gitignore_file.write('\n# devchat\n')
                gitignore_file.writelines(f'{entry}\n' for entry in new_entries)
    else:
        with open(gitignore_path, 'w', encoding='utf-8') as gitignore_file:
            gitignore_file.write('# devchat\n')
            gitignore_file.writelines(f'{entry}\n' for entry in wanted)
|
|
|
|
|
|
def unix_to_local_datetime(unix_time) -> datetime.datetime:
    """Convert a Unix timestamp to a timezone-aware datetime in local time.

    Args:
        unix_time: Seconds since the Unix epoch (int or float).

    Returns:
        An aware ``datetime`` expressed in the system's local timezone.
    """
    # fromtimestamp(..., tz=utc) yields an aware UTC datetime directly and
    # avoids datetime.utcfromtimestamp(), which is deprecated since Python 3.12.
    utc_dt = datetime.datetime.fromtimestamp(unix_time, tz=datetime.timezone.utc)
    # astimezone() with no argument converts to the system's local timezone.
    return utc_dt.astimezone()
|
|
|
|
|
|
def get_user_info() -> Tuple[str, str]:
    """Return ``(user_name, user_email)`` for the current user.

    The name comes from ``git config user.name``. If that fails, it falls
    back to ``getpass.getuser()``, then to the last component of the home
    directory path.

    The email comes from ``git config user.email``. If that fails, it falls
    back to ``<name>@<hostname>``.
    """
    def read_git_config(key):
        return subprocess.check_output(['git', 'config', key], encoding='utf-8').strip()

    try:
        name = read_git_config('user.name')
    except Exception:
        try:
            name = getpass.getuser()
        except Exception:
            home = os.path.expanduser("~")
            name = home.split(os.sep)[-1]

    try:
        email = read_git_config('user.email')
    except Exception:
        email = f"{name}@{socket.gethostname()}"

    return name, email
|
|
|
|
|
|
def user_id(user_name, user_email) -> Tuple[str, str]:
    """Build the identity string ``"name <email>"`` and its SHA-1 hex digest.

    Returns:
        ``(identity, sha1_hexdigest_of_utf8_identity)``.
    """
    identity = "{} <{}>".format(user_name, user_email)
    digest = hashlib.sha1(identity.encode('utf-8')).hexdigest()
    return identity, digest
|
|
|
|
|
|
def parse_files(file_paths: List[str]) -> List[str]:
    """Read and return the UTF-8 contents of each file in ``file_paths``.

    Each path is stripped of surrounding whitespace and has ``~`` expanded.
    Every path is validated before any file is read.

    Returns:
        The file contents, in input order. Returns ``[]`` for an empty or
        ``None`` input.

    Raises:
        ValueError: if a path is not an existing regular file, or a file is empty.
    """
    if not file_paths:
        return []

    # Resolve once so validation and reading use the same path. Previously the
    # read loop used the raw path, so '~/x' or ' x ' passed the existence
    # check and then failed in open().
    resolved = [os.path.expanduser(path.strip()) for path in file_paths]

    for file_path in resolved:
        if not os.path.isfile(file_path):
            raise ValueError(f"File {file_path} does not exist.")

    contents = []
    for file_path in resolved:
        with open(file_path, 'r', encoding='utf-8') as file:
            content = file.read()
        if not content:
            raise ValueError(f"File {file_path} is empty.")
        contents.append(content)
    return contents
|
|
|
|
|
|
def valid_hash(hash_str):
    """Check if a string is a valid hash value.

    A valid hash is exactly 64 lowercase hexadecimal characters (the length
    of a SHA-256 hex digest), with nothing before or after it.
    """
    # fullmatch rather than match + '$': '$' also matches just before a
    # trailing newline, so 'abc...\n' used to be accepted.
    return re.fullmatch(r'[a-f0-9]{64}', hash_str) is not None
|
|
|
|
|
|
def check_format(formatted_response) -> bool:
    """Return True if the whole string matches the formatted-response layout.

    The expected layout is:
    - a ``User: name <email>`` line and a ``Date: ...`` line, then a blank line;
    - any number of body lines;
    - a newline, then one or more ``prompt <64 lowercase hex chars>`` lines,
      each optionally followed by a blank line.
    """
    response_pattern = r"(User: .+ <.+@.+>\nDate: .+\n\n(?:.*\n)*\n(?:prompt [a-f0-9]{64}\n\n?)+)"
    match = re.fullmatch(response_pattern, formatted_response)
    return match is not None
|
|
|
|
|
|
def get_content(formatted_response) -> str:
    """Strip the ``User:``/``Date:`` header and ``prompt <hash>`` footer from a response.

    Both substitutions apply to every match anywhere in the string, not only
    at its start or end.
    """
    without_header = re.sub(r"User: .+ <.+@.+>\nDate: .+\n\n", "", formatted_response)
    return re.sub(r"\n(?:prompt [a-f0-9]{64}\n\n?)+", "", without_header)
|
|
|
|
|
|
def get_prompt_hash(formatted_response) -> str:
    """Return the last ``prompt <hash>`` value from a formatted response.

    Raises:
        ValueError: if ``formatted_response`` does not pass ``check_format``.
    """
    if not check_format(formatted_response):
        raise ValueError("Invalid formatted response.")
    # The footer may contain several 'prompt <hash>' lines. Match them one
    # line at a time. The old approach matched the whole footer and removed
    # 'prompt ', which returned every hash joined by newlines when the footer
    # had more than one line.
    prompt_hashes = re.findall(r"^prompt ([a-f0-9]{64})$", formatted_response, re.MULTILINE)
    # check_format guarantees the string ends with at least one footer line,
    # so the final match is the last footer hash.
    return prompt_hashes[-1]
|
|
|
|
|
|
def update_dict(dict_to_update, key, value) -> dict:
    """Set ``dict_to_update[key] = value`` in place and return that same object.

    Returning the mapping lets the call be used as an expression.
    """
    target = dict_to_update
    target[key] = value
    return target
|
|
|
|
|
|
def _count_tokens(encoding_tik: tiktoken.Encoding, string: str) -> int:
    """Return the number of tokens ``encoding_tik`` produces for ``string``.

    If encoding (or measuring its result) raises, an approximation is
    returned instead: the count of word-character runs divided by 0.75.
    """
    try:
        token_count = len(encoding_tik.encode(string))
    except Exception:
        # Rough estimate only; it may differ noticeably from the real count.
        approx_words = len(re.findall(r'\w+', string))
        token_count = int(approx_words / 0.75)
    return token_count
|
|
|
|
|
|
def openai_message_tokens(messages: dict, model: str) -> int:  # pylint: disable=unused-argument
    """Return the number of tokens used by ``messages``.

    ``messages`` is converted with ``str()`` and counted with the
    module-level ``encoding``. ``model`` is currently ignored.

    Counting goes through ``_count_tokens`` so that text the encoder refuses
    yields an estimate instead of an exception. With tiktoken's default
    ``disallowed_special``, a literal special-token string such as
    ``<|endoftext|>`` in a message makes ``encode`` raise ValueError.
    """
    return _count_tokens(encoding, str(messages))
|
|
|
|
|
|
def openai_response_tokens(message: dict, model: str) -> int:
    """Return the number of tokens used by a response ``message``.

    Responses are counted the same way as request messages, so this
    delegates to ``openai_message_tokens``.
    """
    token_total = openai_message_tokens(message, model=model)
    return token_total
|