# flake8: noqa: E402
import os
import re
import subprocess
import sys

from devchat.llm import chat_completion_stream
from lib.chatmark import Button, Checkbox, Form, TextEditor
from lib.ide_service import IDEService

# Make the parent directory importable so the shared helpers below can be found.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))

from common_util import assert_exit  # noqa: E402
from git_api import get_issue_info, subprocess_check_output, subprocess_run

diff_too_large_message_en = (
    "Commit failed. The modified content is too long "
    "and exceeds the model's length limit. "
    "You can try to make partial changes to the file and submit multiple times. "
    "Making small changes and submitting them multiple times is a better practice."
)
diff_too_large_message_zh = (
    "提交失败。修改内容太长,超出模型限制长度,"
    "可以尝试选择部分修改文件多次提交,小修改多提交是更好的做法。"
)

# Maximum length, in characters, of the fully rendered commit-message prompt.
COMMIT_PROMPT_LIMIT_SIZE = 20000

# Matches a three-digit octal escape such as "\303" in C-quoted git paths.
_OCTAL_ESCAPE_PATTERN = re.compile(r"\\[0-7]{3}")


def extract_markdown_block(text):
    """
    Extract the content of the first fenced Markdown code block in ``text``.

    The optional language tag after the opening fence is discarded, as is the
    newline immediately preceding the closing fence.

    :param text: A string containing Markdown text
    :return: The content of the first fenced code block, or ``text`` unchanged
             when no fenced block is found
    """
    # Match a fenced block, ignoring an optional language tag after the opening fence.
    pattern = r"```(?:\w+)?\s*\n(.*?)\n```"
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1)
    return text


# Read the prompt from the diffCommitMessagePrompt.txt file
def read_prompt_from_file(filename):
    """
    Read a prompt template from a UTF-8 text file.

    Leading and trailing whitespace is stripped from the file content.

    Parameters:
    - filename (str): The path to the file that contains the prompt message.

    Returns:
    - str: The stripped content of the file.

    On any failure (missing file or other read error) an error is logged through
    ``IDEService().ide_logging`` and the process exits with status 1; no exception
    propagates to the caller.
    """
    try:
        with open(filename, "r", encoding="utf-8") as file:
            return file.read().strip()
    except FileNotFoundError:
        IDEService().ide_logging(
            "error",
            f"File {filename} not found. "
            "Please make sure it exists in the same directory as the script.",
        )
        sys.exit(1)
    except Exception as e:
        IDEService().ide_logging(
            "error", f"An error occurred while reading the file {filename}: {e}"
        )
        sys.exit(1)


# Read the prompt content from the file (done once, at import time).
script_path = os.path.dirname(__file__)
PROMPT_FILENAME = os.path.join(script_path, "diffCommitMessagePrompt.txt")
PROMPT_COMMIT_MESSAGE_BY_DIFF_USER_INPUT = read_prompt_from_file(PROMPT_FILENAME)


prompt_commit_message_by_diff_user_input_llm_config = {
    "model": os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106")
}


# Output language; set by main() from argv[2]. Only "zh" selects Chinese output.
language = ""


def assert_value(value, message):
    """
    Abort the program when ``value`` is truthy.

    Args:
        value: The condition to check; a truthy value triggers the abort.
        message: Text written to stderr before exiting.

    Returns:
        None. Exits the process with status -1 when ``value`` is truthy.
    """
    if value:
        print(message, file=sys.stderr, flush=True)
        sys.exit(-1)


def decode_path(encoded_path):
    """
    Decode a path that git printed with C-style octal escapes (e.g. ``\\303\\251``).

    Git's short status output may escape non-ASCII bytes of a path as octal
    sequences. When such escapes are present they are turned back into raw bytes
    and decoded as UTF-8; otherwise the path is returned unchanged. Surrounding
    double quotes, if any, are left in place.
    """
    if not _OCTAL_ESCAPE_PATTERN.search(encoded_path):
        return encoded_path
    # unicode_escape turns "\303" into the code point U+00C3; latin1 maps those code
    # points 1:1 back to the original bytes, which are then decoded as UTF-8.
    raw_bytes = encoded_path.encode("utf-8").decode("unicode_escape").encode("latin1")
    return raw_bytes.decode("utf-8")


def get_modified_files():
    """
    Return the working-tree changes reported by ``git status -s -u``.

    Returns:
        tuple: ``(modified_files, staged_files)``. ``modified_files`` lists every
        changed or untracked path (directories are skipped). ``staged_files`` lists
        the paths whose two-character status is exactly ``"M "``, ``"A "`` or
        ``"D "``. All paths are normalized with ``os.path.normpath``.
    """
    output = subprocess_check_output(["git", "status", "-s", "-u"], text=True, encoding="utf-8")
    modified_files = []
    staged_files = []

    def strip_file_name(file_name):
        # Drop surrounding whitespace and the double quotes git adds around quoted paths.
        file = file_name.strip()
        if file.startswith('"'):
            file = file[1:-1]
        return file

    for line in output.split("\n"):
        if len(line) <= 2:
            continue
        # Short format: two status columns, a separating space, then the path.
        status, filename = line[:2], decode_path(line[3:])
        path = os.path.normpath(strip_file_name(filename))
        # Skip entries that are directories rather than files; check the unquoted path.
        if os.path.isdir(path):
            continue
        modified_files.append(path)
        if status in ("M ", "A ", "D "):
            staged_files.append(path)
    return modified_files, staged_files


def get_marked_files(modified_files, staged_files):
    """
    Let the user pick which files to include in the commit.

    Staged files are shown pre-checked, all other modified files unchecked.

    Args:
        modified_files (List[str]): All modified files.
        staged_files (List[str]): Files that are already staged.

    Returns:
        List[str]: The selected files, staged selections first.
    """
    # Create two Checkbox instances for staged and unstaged files
    staged_checkbox = Checkbox(staged_files, [True] * len(staged_files))

    unstaged_files = [file for file in modified_files if file not in staged_files]
    unstaged_checkbox = Checkbox(unstaged_files, [False] * len(unstaged_files))

    # Create a Form with both Checkbox instances (only the non-empty sections are shown)
    form_list = []
    if len(staged_files) > 0:
        form_list.append("Staged:\n\n")
        form_list.append(staged_checkbox)

    if len(unstaged_files) > 0:
        form_list.append("Unstaged:\n\n")
        form_list.append(unstaged_checkbox)

    form = Form(form_list, submit_button_name="Continue")

    # Render the Form and get user input
    form.render()

    # `selections` is used as a list of indices into the file lists; it may be empty/None.
    staged_checkbox_selections = staged_checkbox.selections if staged_checkbox.selections else []
    unstaged_selections = unstaged_checkbox.selections if unstaged_checkbox.selections else []
    selected_staged_files = [staged_files[idx] for idx in staged_checkbox_selections]
    selected_unstaged_files = [unstaged_files[idx] for idx in unstaged_selections]

    # Combine the selections from both checkboxes
    return selected_staged_files + selected_unstaged_files


def rebuild_stage_list(user_files):
    """
    Make the index contain exactly the user's selected files.

    Runs ``git reset`` to unstage everything, then ``git add`` for each selected file.

    Args:
        user_files: The files selected by the user.

    Returns:
        None
    """
    # Unstage all files
    subprocess_check_output(["git", "reset"])
    # Stage all user_files
    for file in user_files:
        subprocess_run(["git", "add", file])


def get_diff():
    """
    Return the diff of the staged changes.

    Returns:
        bytes: The raw output of ``git diff --cached``.
    """
    return subprocess_check_output(["git", "diff", "--cached"])


def get_current_branch():
    """
    Return the current branch name from ``git branch --show-current``.

    Returns:
        str | None: The branch name (an empty string if git prints nothing), or None
        when the git command fails or git cannot be found.
    """
    try:
        result = subprocess_check_output(
            ["git", "branch", "--show-current"], stderr=subprocess.STDOUT
        ).strip()
        # The helper returns bytes here; convert to str.
        return result.decode("utf-8")
    except subprocess.CalledProcessError:
        # git reported an error
        return None
    except FileNotFoundError:
        # git executable not found (not installed or not on PATH)
        return None


def generate_commit_message_base_diff(user_input, diff, issue):
    """
    Ask the LLM to draft a commit message for the staged diff.

    Args:
        user_input (str): Extra instructions from the user (main() appends the branch name).
        diff (bytes | str): The staged diff, as returned by ``get_diff``.
        issue (str): Serialized issue information to include in the prompt.

    Returns:
        dict: The LLM response; ``response["content"]`` holds the commit message with
        any surrounding Markdown code fence removed.

    Exits with status 0 after printing a "diff too large" message when the prompt
    exceeds ``COMMIT_PROMPT_LIMIT_SIZE`` characters or the model reports a context
    length overflow; exits with status -1 when the model returns no content.
    """
    # Decode bytes so the prompt contains the diff text itself rather than its
    # b'...' repr, which escapes every newline and non-ASCII byte.
    if isinstance(diff, bytes):
        diff = diff.decode("utf-8", errors="replace")

    language_prompt = "You must response commit message in chinese。\n" if language == "zh" else ""
    prompt = (
        PROMPT_COMMIT_MESSAGE_BY_DIFF_USER_INPUT.replace("{__DIFF__}", f"{diff}")
        .replace("{__USER_INPUT__}", f"{user_input + language_prompt}")
        .replace("{__ISSUE__}", f"{issue}")
    )

    # Same language test as language_prompt: only "zh" selects the Chinese message,
    # so the default ("english") gets the English one.
    model_token_limit_error = (
        diff_too_large_message_zh if language == "zh" else diff_too_large_message_en
    )
    if len(prompt) > COMMIT_PROMPT_LIMIT_SIZE:
        print(model_token_limit_error, flush=True)
        sys.exit(0)

    messages = [{"role": "user", "content": prompt}]
    response = chat_completion_stream(messages, prompt_commit_message_by_diff_user_input_llm_config)

    error = response.get("error", None)
    if (
        not response["content"]
        and error
        # `in` also matches when the phrase starts the error text (find() == 0).
        and "This model's maximum context length is" in f"{error}"
    ):
        print(model_token_limit_error, flush=True)
        sys.exit(0)

    assert_value(not response["content"], response.get("error", ""))
    response["content"] = extract_markdown_block(response["content"])
    return response


def display_commit_message_and_commit(commit_message):
    """
    Show the commit message in an editor and commit with the edited text.

    Args:
        commit_message: The drafted commit message.

    Returns:
        The output of ``git commit``, or None when the user submits an empty message.
    """
    text_editor = TextEditor(commit_message, submit_button_name="Commit")
    text_editor.render()

    new_commit_message = text_editor.new_text
    if not new_commit_message:
        return None
    return subprocess_check_output(["git", "commit", "-m", new_commit_message])


def extract_issue_id(branch_name):
    """
    Return the text after the last ``#`` in ``branch_name`` (``"feat#42"`` -> ``"42"``).

    Returns None when ``branch_name`` is empty or None, or contains no ``#``.
    """
    if branch_name and "#" in branch_name:
        return branch_name.rsplit("#", 1)[-1]
    return None


def get_issue_json(issue_id):
    """
    Build a dict describing the issue referenced by ``issue_id``.

    Returns a placeholder dict when ``issue_id`` is falsy. Otherwise the issue is
    fetched with ``get_issue_info``; the process exits (status -1) if that fails.
    """
    issue = {"id": "no issue id", "title": "", "description": ""}
    if issue_id:
        issue = get_issue_info(issue_id)
        assert_exit(not issue, f"Failed to retrieve issue with ID: {issue_id}", exit_code=-1)
        issue = {
            "id": issue_id,
            "web_url": issue["web_url"],
            "title": issue["title"],
            "description": issue["description"],
        }
    return issue


def check_git_installed():
    """
    Return True if ``git --version`` runs successfully, else print an error and return False.
    """
    try:
        subprocess.run(
            ["git", "--version"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True,
        )
        return True
    except Exception:
        # Covers CalledProcessError, FileNotFoundError and any other launch failure.
        print("Git is not installed on your system.", file=sys.stderr, flush=True)
    return False


def ask_for_push():
    """
    Ask the user whether to push the new commit to the remote repository.

    Returns:
        bool: True if the user clicked the first button ("Yes, push now").
    """
    print(
        "Step 3/3: Would you like to push your commit to the remote repository?",
        end="\n\n",
        flush=True,
    )

    button = Button(["Yes, push now", "No, I'll push later"])
    button.render()

    return button.clicked == 0  # index 0 is the "Yes" button


def push_changes():
    """
    Push the current branch to ``origin``.

    Returns:
        bool: True if the push succeeded, False otherwise (an error is printed).
    """
    try:
        current_branch = get_current_branch()
        if not current_branch:
            print(
                "Could not determine current branch. Push failed.",
                end="\n\n",
                file=sys.stderr,
                flush=True,
            )
            return False

        print(f"Pushing changes to origin/{current_branch}...", end="\n\n", flush=True)
        result = subprocess_run(
            ["git", "push", "origin", current_branch],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        if result.returncode != 0:
            print(f"Push failed: {result.stderr}", end="\n\n", flush=True)
            return False
        print("Push completed successfully.", end="\n\n", flush=True)
        return True
    except subprocess.CalledProcessError as e:
        print(f"Push failed: {str(e)}", end="\n\n", file=sys.stderr, flush=True)
        return False
    except Exception as e:
        print(f"An unexpected error occurred: {str(e)}", end="\n\n", file=sys.stderr, flush=True)
        return False


def main():
    """
    Run the interactive commit workflow.

    argv[1] is the user's extra instruction for the commit message; the optional
    argv[2] sets the output language (default "english"; "zh" selects Chinese).
    Steps: choose files -> review/edit the generated message and commit -> optionally push.
    """
    global language
    try:
        print("Let's follow the steps below.\n\n", flush=True)
        # Ensure enough command line arguments are provided
        if len(sys.argv) < 2:
            print("Usage: python script.py <user_input> [language]", file=sys.stderr, flush=True)
            sys.exit(-1)

        user_input = sys.argv[1]
        language = "english"
        if len(sys.argv) > 2:
            language = sys.argv[2]

        if not check_git_installed():
            sys.exit(-1)

        print(
            "Step 1/3: Select the files you've changed that you wish to include in this commit, "
            "then click 'Continue'.",
            end="\n\n",
            flush=True,
        )
        modified_files, staged_files = get_modified_files()
        if len(modified_files) == 0:
            print("No files to commit.", file=sys.stderr, flush=True)
            sys.exit(-1)

        selected_files = get_marked_files(modified_files, staged_files)
        if not selected_files:
            print("No files selected, commit aborted.", flush=True)
            return

        rebuild_stage_list(selected_files)

        print(
            "Step 2/3: Review the commit message I've drafted for you. "
            "Edit it below if needed. Then click 'Commit' to proceed with "
            "the commit using this message.",
            end="\n\n",
            flush=True,
        )
        diff = get_diff()
        branch_name = get_current_branch()
        # extract_issue_id tolerates a None/empty branch name.
        issue_id = extract_issue_id(branch_name)
        issue = str(get_issue_json(issue_id))
        if branch_name:
            user_input += "\ncurrent repo branch name is:" + branch_name
        commit_message = generate_commit_message_base_diff(user_input, diff, issue)

        # Strip placeholder phrases the model may emit when there is no issue to reference.
        commit_message["content"] = (
            commit_message["content"]
            .replace("Closes #IssueNumber", "")
            .replace("No specific issue to close", "")
            .replace("No specific issue mentioned.", "")
        )

        commit_result = display_commit_message_and_commit(commit_message["content"])
        if not commit_result:
            print("Commit aborted.", flush=True)
        else:
            # Optional push step
            if ask_for_push():
                if not push_changes():
                    print("Push failed.", flush=True)
                    sys.exit(-1)

            print("Commit completed.", flush=True)
        sys.exit(0)
    except Exception as err:
        print("Exception:", err, file=sys.stderr, flush=True)
        sys.exit(-1)


if __name__ == "__main__":
    main()