2023-11-15 15:46:46 +08:00

233 lines
7.5 KiB
Python

import logging
import os
import re
import getpass
import socket
import subprocess
from typing import List, Tuple, Optional
import datetime
import hashlib
import tiktoken
# Build the module-level ``encoding`` used for token counting. The regular
# tiktoken loader is tried first; if it raises for any reason, the encoding is
# assembled by hand from tiktoken's constructor registry instead.
try:
    encoding = tiktoken.get_encoding("cl100k_base")
except Exception:
    from tiktoken import registry
    from tiktoken.registry import _find_constructors
    from tiktoken.core import Encoding

    def get_encoding(name: str):
        """
        Construct the encoding called *name* straight from the registry.

        ``_find_constructors()`` is called first so that
        ``registry.ENCODING_CONSTRUCTORS`` can be indexed by name; the
        constructor's keyword arguments are then forwarded to ``Encoding``
        together with ``use_pure_python=True``.
        NOTE(review): ``use_pure_python`` is not part of the upstream
        tiktoken ``Encoding`` signature as far as can be told from here --
        presumably a patched build is expected; confirm.
        """
        _find_constructors()
        build_kwargs = registry.ENCODING_CONSTRUCTORS[name]
        return Encoding(**build_kwargs(), use_pure_python=True)

    encoding = get_encoding("cl100k_base")
# Shared formatter applied to every handler configured by this module.
log_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')


def setup_logger(file_path: Optional[str] = None):
    """
    Install a single global log handler on the root logger.

    When *file_path* is None the handler is a ``logging.StreamHandler``;
    otherwise it is a ``logging.FileHandler`` writing to *file_path*. The
    handler uses the module-level ``log_formatter`` and replaces whatever
    handlers the root logger had before.
    """
    root_handler = (
        logging.StreamHandler() if file_path is None
        else logging.FileHandler(file_path)
    )
    root_handler.setFormatter(log_formatter)
    logging.root.handlers = [root_handler]
def get_logger(name: Optional[str] = None,
               handler: Optional[logging.Handler] = None) -> logging.Logger:
    """
    Return the logger called *name*, configured from the LOG_LEVEL env var.

    The level is looked up by upper-casing ``LOG_LEVEL`` (default ``'INFO'``)
    and resolving it as an attribute of the ``logging`` module. Anything that
    does not resolve to an integer level falls back to ``logging.INFO``.

    If *handler* is given, it receives the same level and the module-level
    ``log_formatter`` and is attached to the logger.

    Args:
        name: Logger name passed to ``logging.getLogger``; None gives the root logger.
        handler: Optional handler to configure and attach.

    Returns:
        The configured ``logging.Logger``.
    """
    local_logger = logging.getLogger(name)

    # Default to 'INFO' if 'LOG_LEVEL' env is not set
    log_level_str = os.getenv('LOG_LEVEL', 'INFO')
    log_level = getattr(logging, log_level_str.upper(), logging.INFO)
    # getattr may resolve a module attribute that is not a level at all
    # (e.g. logging.BASIC_FORMAT is a str); setLevel would reject it.
    if not isinstance(log_level, int):
        log_level = logging.INFO
    local_logger.setLevel(log_level)

    # If a handler is provided, configure and add it to the logger
    if handler is not None:
        handler.setLevel(log_level)
        handler.setFormatter(log_formatter)
        local_logger.addHandler(handler)

    local_logger.info("Get %s", str(local_logger))
    return local_logger
def find_root_dir() -> Tuple[Optional[str], Optional[str]]:
    """
    Find the root directory of the repository and the user's home directory.

    The repository root is taken from ``git rev-parse --show-toplevel`` and,
    if that fails, from the "Working Copy Root Path" line of ``svn info``.
    Both commands run in the current working directory.

    Returns:
        A ``(repo_dir, user_dir)`` tuple. Each element is an existing
        directory path, or None when it could not be determined.
    """
    try:
        user_dir = os.path.expanduser("~")
        if not os.path.isdir(user_dir):
            user_dir = None
    except Exception:
        user_dir = None

    # Detection is best-effort: a missing executable, a non-zero exit status
    # or undecodable output just means "not this kind of repository".
    try:
        git_root = subprocess.run(
            ["git", "rev-parse", "--show-toplevel"],
            capture_output=True, text=True, check=True, encoding='utf-8',
        ).stdout.strip()
        if os.path.isdir(git_root):
            return git_root, user_dir
    except Exception:
        pass

    try:
        # check=True already raises on a non-zero exit status.
        svn_info = subprocess.run(
            ["svn", "info"],
            capture_output=True, text=True, check=True, encoding='utf-8',
        )
        prefix = "Working Copy Root Path: "
        for line in svn_info.stdout.splitlines():
            if line.startswith(prefix):
                svn_root = line[len(prefix):].strip()
                if os.path.isdir(svn_root):
                    return svn_root, user_dir
    except Exception:
        pass

    # Never hand back a path that is not an existing directory.
    return None, user_dir
def add_gitignore(target_dir: str, *ignore_entries: str) -> None:
    """
    Make sure ``<target_dir>/.gitignore`` lists every given entry.

    If the file does not exist it is created with a ``# devchat`` header
    followed by all entries. If it exists, only entries that are not already
    present as a line of their own are appended under a new ``# devchat``
    header; nothing is written when there is nothing to add.

    Args:
        target_dir: Directory that holds (or will hold) the ``.gitignore``.
        *ignore_entries: Ignore patterns, one per line in the file.
    """
    gitignore_path = os.path.join(target_dir, '.gitignore')

    if not os.path.exists(gitignore_path):
        with open(gitignore_path, 'w', encoding='utf-8') as gitignore_file:
            gitignore_file.write('# devchat\n')
            gitignore_file.writelines(f'{entry}\n' for entry in ignore_entries)
        return

    with open(gitignore_path, 'r', encoding='utf-8') as gitignore_file:
        # Compare whole lines: a substring test would treat '.chat' as
        # present merely because '.chat_backup' is listed.
        existing_lines = {line.strip() for line in gitignore_file}

    new_entries = [entry for entry in ignore_entries
                   if entry.strip() not in existing_lines]
    if new_entries:
        with open(gitignore_path, 'a', encoding='utf-8') as gitignore_file:
            # Leading newline separates the block from prior content, which
            # may not end with a newline.
            gitignore_file.write('\n# devchat\n')
            gitignore_file.writelines(f'{entry}\n' for entry in new_entries)
def unix_to_local_datetime(unix_time) -> datetime.datetime:
    """
    Convert a Unix timestamp to a timezone-aware datetime in local time.

    Args:
        unix_time: Seconds since the Unix epoch.

    Returns:
        An aware ``datetime.datetime`` expressed in the system's local timezone.
    """
    # Build an aware UTC datetime directly; datetime.utcfromtimestamp() is
    # deprecated since Python 3.12.
    utc_dt = datetime.datetime.fromtimestamp(unix_time, tz=datetime.timezone.utc)
    # Convert the UTC datetime object to the local timezone
    return utc_dt.astimezone()
def get_user_info() -> Tuple[str, str]:
    """
    Work out a name and an email address for the current user.

    The name comes from ``git config user.name``; if that fails, from
    ``getpass.getuser()``; and if that fails too, from the last component of
    the home directory path. The email comes from ``git config user.email``,
    falling back to ``<user_name>@<hostname>``.

    Returns:
        A ``(user_name, user_email)`` tuple.
    """
    def _git_config(key: str) -> str:
        # Raises when git is unavailable or exits with a non-zero status.
        return subprocess.check_output(['git', 'config', key], encoding='utf-8').strip()

    try:
        user_name = _git_config('user.name')
    except Exception:
        try:
            user_name = getpass.getuser()
        except Exception:
            home_dir = os.path.expanduser("~")
            user_name = home_dir.split(os.sep)[-1]

    try:
        user_email = _git_config('user.email')
    except Exception:
        user_email = user_name + '@' + socket.gethostname()

    return user_name, user_email
def user_id(user_name, user_email) -> Tuple[str, str]:
    """
    Build the display string for a user and its SHA-1 fingerprint.

    Returns:
        A ``(user_str, user_hash)`` tuple where ``user_str`` is
        ``"<name> <<email>>"`` and ``user_hash`` is the hexadecimal SHA-1
        digest of ``user_str`` encoded as UTF-8.
    """
    user_str = f"{user_name} <{user_email}>"
    digest = hashlib.sha1()
    digest.update(user_str.encode('utf-8'))
    return user_str, digest.hexdigest()
def parse_files(file_paths: List[str]) -> List[str]:
    """
    Read the text content of every file in *file_paths*.

    Each path is stripped of surrounding whitespace and has ``~`` expanded
    before use. All paths are checked for existence before any file is read.

    Args:
        file_paths: Paths of the files to read; may be empty or None.

    Returns:
        The UTF-8 decoded contents, in the same order as *file_paths*.

    Raises:
        ValueError: If a path is not an existing file, or a file is empty.
    """
    if not file_paths:
        return []

    # Normalize once and reuse: validating the cleaned path but opening the
    # raw one would break for '~/x' or whitespace-padded input.
    resolved_paths = [os.path.expanduser(path.strip()) for path in file_paths]

    for file_path in resolved_paths:
        if not os.path.isfile(file_path):
            raise ValueError(f"File {file_path} does not exist.")

    contents = []
    for file_path in resolved_paths:
        with open(file_path, 'r', encoding='utf-8') as file:
            content = file.read()
        if not content:
            raise ValueError(f"File {file_path} is empty.")
        contents.append(content)
    return contents
def valid_hash(hash_str):
    """
    Check if a string is a valid hash value.

    A valid value is exactly 64 lowercase hexadecimal characters (the length
    of a SHA-256 hex digest) and nothing else.
    """
    # fullmatch instead of '^...$' with match(): '$' also matches just before
    # a trailing newline, which would let '<hash>\n' through.
    return re.fullmatch(r'[a-f0-9]{64}', hash_str) is not None
def check_format(formatted_response) -> bool:
    """
    Tell whether *formatted_response* has the expected overall layout.

    The whole string must consist of a ``User: <name> <<email>>`` line, a
    ``Date:`` line, a blank line, zero or more content lines, a blank line,
    and one or more ``prompt <64 hex chars>`` lines.
    """
    layout = re.compile(
        r"(User: .+ <.+@.+>\nDate: .+\n\n(?:.*\n)*\n(?:prompt [a-f0-9]{64}\n\n?)+)"
    )
    return layout.fullmatch(formatted_response) is not None
def get_content(formatted_response) -> str:
    """
    Strip the header and the prompt footer from a formatted response.

    Every occurrence of the ``User:``/``Date:`` header and of a run of
    ``prompt <64 hex chars>`` lines (with its leading newline) is removed;
    what is left is returned.
    """
    header = re.compile(r"User: .+ <.+@.+>\nDate: .+\n\n")
    footer = re.compile(r"\n(?:prompt [a-f0-9]{64}\n\n?)+")
    without_header = header.sub("", formatted_response)
    return footer.sub("", without_header)
def get_prompt_hash(formatted_response) -> str:
    """
    Return the last prompt hash found in a formatted response.

    Args:
        formatted_response: Text that must pass ``check_format``.

    Returns:
        The 64-character hex hash from the final ``prompt <hash>`` line.

    Raises:
        ValueError: If *formatted_response* does not pass ``check_format``.
    """
    if not check_format(formatted_response):
        raise ValueError("Invalid formatted response.")
    footer_pattern = r"\n(?:prompt [a-f0-9]{64}\n\n?)+"
    # A footer may hold several 'prompt <hash>' lines in one match. Take the
    # last footer, then the last hash inside it, so that a single hash is
    # returned rather than all of them joined by newlines.
    last_footer = re.findall(footer_pattern, formatted_response)[-1]
    return re.findall(r"prompt ([a-f0-9]{64})", last_footer)[-1]
def update_dict(dict_to_update, key, value) -> dict:
    """
    Store *value* under *key* in *dict_to_update* and hand the dict back.

    The dictionary is modified in place; the returned object is the same
    one that was passed in, which allows the call to be used in expressions.
    """
    target = dict_to_update
    target[key] = value
    return target
def _count_tokens(encoding_tik: tiktoken.Encoding, string: str) -> int:
    """
    Return how many tokens *string* encodes to with *encoding_tik*.

    If encoding raises for any reason, the count is estimated from the
    number of ``\\w+`` word matches instead.
    """
    try:
        token_ids = encoding_tik.encode(string)
        return len(token_ids)
    except Exception:
        words = re.findall(r'\w+', string)
        # Note: This is a rough estimate and may not be accurate
        return int(len(words) / 0.75)
def openai_message_tokens(messages: dict, model: str) -> int:  # pylint: disable=unused-argument
    """
    Returns the number of tokens used by a message.

    The count is taken over ``str(messages)`` using the module-level
    ``encoding``. *model* is accepted for interface compatibility and is not
    used.
    """
    # Go through _count_tokens so a failing encode() (tiktoken rejects text
    # that looks like a special token by default) yields the word-based
    # estimate instead of propagating the exception to the caller.
    return _count_tokens(encoding, str(messages))
def openai_response_tokens(message: dict, model: str) -> int:
    """
    Returns the number of tokens used by a response.

    A response is counted exactly like a message: the call is forwarded to
    ``openai_message_tokens`` with the same arguments.
    """
    token_total = openai_message_tokens(message, model)
    return token_total