import datetime
import getpass
import hashlib
import logging
import os
import re
import socket
import subprocess
from typing import List, Optional, Tuple

log_formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
# Lazily-initialised tiktoken encoding, shared by openai_message_tokens().
encoding = None

# Regex fragments describing a formatted response:
#   "User: <name> <<email>>\nDate: <date>\n\n<content>\n<footer>"
# where the footer is one or more "prompt <64 lowercase hex chars>" lines.
_HEADER_PATTERN = r"User: .+ <.+@.+>\nDate: .+\n\n"
_FOOTER_PATTERN = r"\n(?:prompt [a-f0-9]{64}\n\n?)+"
_FORMAT_PATTERN = _HEADER_PATTERN + r"(?:.*\n)*" + _FOOTER_PATTERN


def setup_logger(file_path: Optional[str] = None) -> None:
    """Install a single global (root) log handler.

    Args:
        file_path: Path of a log file to write to. When None, a
            ``logging.StreamHandler`` is used instead.

    Any handlers previously attached to the root logger are replaced.
    """
    if file_path is None:
        handler = logging.StreamHandler()
    else:
        handler = logging.FileHandler(file_path)
    handler.setFormatter(log_formatter)
    logging.root.handlers = [handler]


def get_logger(
    name: Optional[str] = None, handler: Optional[logging.Handler] = None
) -> logging.Logger:
    """Return the logger called *name*, with its level taken from ``LOG_LEVEL``.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        handler: Optional handler to attach; it gets the same level and the
            module-wide formatter.

    Returns:
        The configured ``logging.Logger``.
    """
    local_logger = logging.getLogger(name)

    # Default to 'INFO' if 'LOG_LEVEL' env is not set or is not recognised.
    log_level_str = os.getenv("LOG_LEVEL", "INFO")
    log_level = getattr(logging, log_level_str.upper(), logging.INFO)
    local_logger.setLevel(log_level)

    if handler is not None:
        handler.setLevel(log_level)
        handler.setFormatter(log_formatter)
        local_logger.addHandler(handler)

    local_logger.info("Get %s", str(local_logger))
    return local_logger


def _git_root() -> Optional[str]:
    """Return the git top-level directory of the CWD, or None."""
    try:
        repo_dir = subprocess.run(
            ["git", "rev-parse", "--show-toplevel"],
            capture_output=True,
            text=True,
            check=True,
            encoding="utf-8",
        ).stdout.strip()
    except Exception:
        # Best effort: git missing, not a repo, decoding error, ...
        return None
    return repo_dir if os.path.isdir(repo_dir) else None


def _svn_root() -> Optional[str]:
    """Return the svn working-copy root of the CWD, or None."""
    try:
        result = subprocess.run(
            ["svn", "info"], capture_output=True, text=True, check=True, encoding="utf-8"
        )
    except Exception:
        # Best effort: svn missing, not a working copy, decoding error, ...
        return None
    prefix = "Working Copy Root Path: "
    for line in result.stdout.splitlines():
        if line.startswith(prefix):
            candidate = line[len(prefix):].strip()
            if os.path.isdir(candidate):
                return candidate
    return None


def find_root_dir() -> Tuple[Optional[str], Optional[str]]:
    """
    Find the root directory of the repository and the user's home directory.

    The repository root is looked up with git first, then svn. Each element of
    the returned ``(repo_dir, user_dir)`` tuple is None when it could not be
    determined or is not an existing directory.
    """
    try:
        user_dir = os.path.expanduser("~")
        if not os.path.isdir(user_dir):
            user_dir = None
    except Exception:
        user_dir = None

    repo_dir = _git_root()
    if repo_dir is None:
        repo_dir = _svn_root()
    return repo_dir, user_dir


def add_gitignore(target_dir: str, *ignore_entries: str) -> None:
    """Ensure each of *ignore_entries* is listed in ``target_dir/.gitignore``.

    If the file exists, entries already present as a whole line are skipped and
    any missing ones are appended under a ``# devchat`` marker. If the file does
    not exist, it is created with the marker followed by all entries.
    """
    gitignore_path = os.path.join(target_dir, ".gitignore")
    if os.path.exists(gitignore_path):
        with open(gitignore_path, "r", encoding="utf-8") as gitignore_file:
            existing = {line.strip() for line in gitignore_file.read().splitlines()}

        # Compare whole lines rather than substrings: ".chat" must not count as
        # present merely because ".chatignore" is listed.
        new_entries = [entry for entry in ignore_entries if entry.strip() not in existing]
        if new_entries:
            with open(gitignore_path, "a", encoding="utf-8") as gitignore_file:
                # Leading newline guards against a file without a trailing one.
                gitignore_file.write("\n# devchat\n")
                for entry in new_entries:
                    gitignore_file.write(f"{entry}\n")
    else:
        with open(gitignore_path, "w", encoding="utf-8") as gitignore_file:
            gitignore_file.write("# devchat\n")
            for entry in ignore_entries:
                gitignore_file.write(f"{entry}\n")


def unix_to_local_datetime(unix_time) -> datetime.datetime:
    """Convert a Unix timestamp to a timezone-aware datetime in local time."""
    # fromtimestamp(..., tz=utc) replaces the deprecated utcfromtimestamp().
    utc_dt = datetime.datetime.fromtimestamp(unix_time, tz=datetime.timezone.utc)
    return utc_dt.astimezone()


def get_user_info() -> Tuple[str, str]:
    """Return ``(user_name, user_email)`` for the current user.

    The name comes from ``git config user.name``. If that fails, it falls back
    to ``getpass.getuser()`` and then to the last component of the home
    directory. The email comes from ``git config user.email``, falling back to
    ``<user_name>@<hostname>``.
    """
    try:
        cmd = ["git", "config", "user.name"]
        user_name = subprocess.check_output(cmd, encoding="utf-8").strip()
    except Exception:
        try:
            user_name = getpass.getuser()
        except Exception:
            user_dir = os.path.expanduser("~")
            user_name = user_dir.split(os.sep)[-1]

    try:
        cmd = ["git", "config", "user.email"]
        user_email = subprocess.check_output(cmd, encoding="utf-8").strip()
    except Exception:
        user_email = user_name + "@" + socket.gethostname()

    return user_name, user_email


def user_id(user_name, user_email) -> Tuple[str, str]:
    """Return ``("name <email>", sha1-hex of that string)``."""
    user_str = f"{user_name} <{user_email}>"
    user_hash = hashlib.sha1(user_str.encode("utf-8")).hexdigest()
    return user_str, user_hash


def parse_files(file_paths: List[str]) -> List[str]:
    """Read and return the UTF-8 contents of each file in *file_paths*.

    Each path is stripped of surrounding whitespace and has ``~`` expanded.
    The same normalised path is used both for validation and for reading.

    Raises:
        ValueError: if any file does not exist, or if a file is empty.
    """
    if not file_paths:
        return []

    expanded_paths = [os.path.expanduser(path.strip()) for path in file_paths]
    # Validate every path before reading any of them.
    for file_path in expanded_paths:
        if not os.path.isfile(file_path):
            raise ValueError(f"File {file_path} does not exist.")

    contents = []
    for file_path in expanded_paths:
        with open(file_path, "r", encoding="utf-8") as file:
            content = file.read()
            if not content:
                raise ValueError(f"File {file_path} is empty.")
            contents.append(content)
    return contents


def valid_hash(hash_str):
    """Check if a string is a valid hash value (64 lowercase hex chars, SHA-256)."""
    # fullmatch: unlike match("^...$"), it rejects a trailing newline.
    return bool(re.fullmatch(r"[a-f0-9]{64}", hash_str))


def check_format(formatted_response) -> bool:
    """Return True if *formatted_response* is header + content + prompt footer."""
    return bool(re.fullmatch(_FORMAT_PATTERN, formatted_response))


def get_content(formatted_response) -> str:
    """Strip the ``User:/Date:`` header and ``prompt <hash>`` footer lines."""
    content = re.sub(_HEADER_PATTERN, "", formatted_response)
    content = re.sub(_FOOTER_PATTERN, "", content)
    return content


def get_prompt_hash(formatted_response) -> str:
    """Return the last prompt hash in the footer of *formatted_response*.

    Raises:
        ValueError: if the response does not pass ``check_format``.
    """
    if not check_format(formatted_response):
        raise ValueError("Invalid formatted response.")
    # check_format guarantees the response ends with the footer, so the last
    # "prompt <hash>" occurrence is the footer's last hash. (Matching the whole
    # footer regex instead would join several hashes into one string.)
    return re.findall(r"prompt ([a-f0-9]{64})", formatted_response)[-1]


def update_dict(dict_to_update, key, value) -> dict:
    """
    Update a dictionary with a key-value pair and return the dictionary.
    """
    dict_to_update[key] = value
    return dict_to_update


def openai_message_tokens(messages: dict, model: str) -> int:
    """Return the number of tokens used by *messages*.

    When ``USE_TIKTOKEN`` is unset or empty, the count is a cheap estimate of one
    token per four characters of ``str(messages)``, rounded up. Note that any
    non-empty value, including "0", counts as set. Otherwise the tiktoken
    ``cl100k_base`` encoding is loaded once into the module-level ``encoding``
    and used for an exact count. *model* is currently unused.
    """
    if not os.environ.get("USE_TIKTOKEN", False):
        # Integer ceiling division: returns an int as annotated, and for any
        # integer limit L, (ceil(x) > L) == (x > L).
        return (len(str(messages)) + 3) // 4

    global encoding
    if not encoding:
        import tiktoken

        script_dir = os.path.dirname(os.path.realpath(__file__))
        os.environ["TIKTOKEN_CACHE_DIR"] = os.path.join(script_dir, "tiktoken_cache")

        try:
            encoding = tiktoken.get_encoding("cl100k_base")
        except Exception:
            # Fallback relying on tiktoken internals: build the encoding with
            # the pure-Python implementation.
            from tiktoken import registry
            from tiktoken.core import Encoding
            from tiktoken.registry import _find_constructors

            def get_encoding(name: str):
                _find_constructors()
                constructor = registry.ENCODING_CONSTRUCTORS[name]
                return Encoding(**constructor(), use_pure_python=True)

            encoding = get_encoding("cl100k_base")

    return len(encoding.encode(str(messages), disallowed_special=()))


def openai_response_tokens(message: dict, model: str) -> int:
    """Returns the number of tokens used by a response."""
    return openai_message_tokens(message, model)