workflows/commit/commit.py
2023-12-27 15:42:51 +08:00

486 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# flake8: noqa: E402
import os
import re
import sys
import json
import subprocess
from typing import List
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "libs"))
sys.path.append(os.path.dirname(__file__))
from ui_utils import ui_checkbox_select, ui_text_edit, CheckboxOption # noqa: E402
from llm_api import chat_completion_no_stream, chat_completion_no_stream_return_json # noqa: E402
from ide_services.services import log_info
from prompts import (
PROMPT_SUMMARY_FOR_FILES,
PROMPT_GROUP_FILES,
PROMPT_COMMIT_MESSAGE_BY_DIFF_USER_INPUT,
PROMPT_COMMIT_MESSAGE_BY_SUMMARY_USER_INPUT,
PROMPT_SUMMARY_FOR_FILES_RETRY,
PROMPT_GROUP_FILES_RETRY,
prompt_summary_for_files_llm_config,
prompt_group_files_llm_config,
prompt_commit_message_by_diff_user_input_llm_config,
prompt_commit_message_by_summary_user_input_llm_config,
)
# Response language for LLM prompts. main() sets it from sys.argv[2]; when it
# equals "chinese", the prompt builders ask the model to answer in Chinese.
language = ""
def assert_value(value, message):
    """
    Abort the program when a condition holds.

    If ``value`` is truthy, ``message`` is written to stderr (flushed
    immediately) and the process exits with status -1. Otherwise nothing
    happens.

    Args:
        value: Condition to test; any truthy value triggers the abort.
        message: Text printed to stderr before exiting.

    Returns:
        None.
    """
    if not value:
        return
    print(message, file=sys.stderr, flush=True)
    sys.exit(-1)
def decode_path(encoded_path):
    """
    Decode a path that git printed with C-style octal escapes.

    ``git status`` may print non-ASCII bytes of a file name as ``\\NNN``
    octal sequences. When at least one such sequence is present, the escapes
    are turned back into raw bytes and the result is decoded as UTF-8.
    Paths without an octal escape are returned unchanged.

    Args:
        encoded_path (str): Path text as printed by git.

    Returns:
        str: The decoded path, or ``encoded_path`` itself when no octal
        escape was found.
    """
    if re.search(r"\\[0-7]{3}", encoded_path) is None:
        return encoded_path
    # unicode_escape maps each "\NNN" to code point NNN (< 256) and every other
    # byte to the latin-1 code point of the same value, so re-encoding with
    # latin-1 reproduces the raw UTF-8 byte sequence git escaped.
    raw_bytes = encoded_path.encode("utf-8").decode("unicode_escape").encode("latin1")
    return raw_bytes.decode("utf-8")
def get_modified_files():
    """
    Return the files changed in the working tree and those already staged.

    Parses ``git status -s -u`` (short format, untracked files listed
    individually). Each line has the form ``XY PATH``, where ``X`` is the
    index (staged) status and ``Y`` the work-tree status. Rename/copy lines
    have the form ``XY ORIG -> PATH``, and only ``PATH`` is kept.

    Returns:
        tuple: ``(modified_files, staged_files)``. ``modified_files`` holds
        every changed, added, deleted or untracked file; ``staged_files`` is
        the subset whose index column reports a staged change. Paths are
        normalised with ``os.path.normpath``; directory entries are skipped.
    """
    output = subprocess.check_output(["git", "status", "-s", "-u"], text=True, encoding="utf-8")
    # XY codes of unmerged (conflicted) entries; these are not treated as staged.
    unmerged_codes = {"DD", "AU", "UD", "UA", "DU", "AA", "UU"}
    modified_files = []
    staged_files = []

    def strip_file_name(file_name):
        # git wraps paths containing special characters in double quotes.
        file = file_name.strip()
        if len(file) >= 2 and file.startswith('"') and file.endswith('"'):
            file = file[1:-1]
        return file

    for line in output.splitlines():
        if len(line) <= 2:
            continue
        status, path_text = line[:2], line[3:]
        # Rename/copy entries read "ORIG -> PATH"; only PATH exists in the work tree.
        if status[0] in "RC" and " -> " in path_text:
            path_text = path_text.split(" -> ", 1)[1]
        filename = os.path.normpath(strip_file_name(decode_path(path_text)))
        # Check for directories only after the quotes have been removed.
        if os.path.isdir(filename):
            continue
        modified_files.append(filename)
        # The first column describes the index: M/A/D/R/C/T mean the file has
        # staged changes (e.g. "MM", "AM", "D "); ' ' and '?' mean it has none.
        if status[0] in "MADRCT" and status not in unmerged_codes:
            staged_files.append(filename)
    return modified_files, staged_files
def gpt_file_summary(diff, diff_files, user_input):
    """
    Ask the LLM for a summary of each changed file.

    The diff and the user's input are substituted into
    ``PROMPT_SUMMARY_FOR_FILES``. The model is expected to reply with a JSON
    object mapping file paths to summaries. If some of ``diff_files`` are
    missing from the reply, a follow-up listing them is sent. At most three
    requests are made in total.

    Args:
        diff (str): Combined diff text of all files.
        diff_files (List[str]): Paths that should each receive a summary.
        user_input (str): Extra instructions from the user.

    Returns:
        dict: ``os.path.normpath``-normalised file path -> summary,
        accumulated over all attempts. Files never summarised are absent.
    """
    language_suffix = " \nPlease response summaries in chinese" if language == "chinese" else ""
    prompt = PROMPT_SUMMARY_FOR_FILES.replace("{__DIFF__}", f"{diff}").replace(
        "{__USER_INPUT__}", f"{user_input}"
    )
    messages = [{"role": "user", "content": prompt + language_suffix}]
    normpath_summaries = {}
    for _ in range(3):
        file_summaries = chat_completion_no_stream_return_json(
            messages, prompt_summary_for_files_llm_config
        )
        # Only a non-empty JSON object can be read as a path -> summary map;
        # anything else (None, a JSON array, ...) is retried instead of crashing.
        if not file_summaries or not isinstance(file_summaries, dict):
            continue
        for key, value in file_summaries.items():
            normpath_summaries[os.path.normpath(key)] = value
        missed_files = [file for file in diff_files if file not in normpath_summaries]
        if not missed_files:
            break
        prompt_retry = PROMPT_SUMMARY_FOR_FILES_RETRY.replace(
            "{__MISSED_FILES__}", f"{missed_files}"
        )
        # ensure_ascii=False keeps non-ASCII summaries readable for the model.
        messages.append(
            {"role": "assistant", "content": json.dumps(file_summaries, ensure_ascii=False)}
        )
        messages.append({"role": "user", "content": prompt_retry + language_suffix})
    return normpath_summaries
def gpt_file_group(diff, diff_files):
    """
    Ask the LLM to group the changed files by related modifications.

    The model is expected to reply with either a JSON list of groups or an
    object whose ``"groups"`` key holds that list. Each group is a dict with
    a ``"files"`` list (and, per the prompt, a ``"name"``). When some of
    ``diff_files`` are not in any group, a follow-up listing them is sent.
    At most three requests are made in total.

    Args:
        diff (str): Combined diff text of all files.
        diff_files (List[str]): Paths that should each appear in some group.

    Returns:
        List[Dict]: The groups from the last well-formed reply, or an empty
        list if no well-formed reply was received.
    """
    prompt = PROMPT_GROUP_FILES.replace("{__DIFF__}", f"{diff}")
    messages = [{"role": "user", "content": prompt}]
    file_groups = []
    for _ in range(3):
        response = chat_completion_no_stream_return_json(messages, prompt_group_files_llm_config)
        if not response:
            continue
        # Accept either a bare list of groups or {"groups": [...]}.
        if isinstance(response, dict):
            response = response.get("groups")
        if not isinstance(response, list):
            continue
        groups = [group for group in response if isinstance(group, dict)]
        if not groups:
            continue
        file_groups = groups
        # Normalise returned paths so that e.g. "./a.py" and "a.py" compare equal
        # (callers pass normpath'd paths from get_modified_files).
        grouped_files = {
            os.path.normpath(file)
            for group in file_groups
            for file in (group.get("files") or [])
            if isinstance(file, str)
        }
        missed_files = [file for file in diff_files if file not in grouped_files]
        if not missed_files:
            break
        prompt_retry = PROMPT_GROUP_FILES_RETRY.replace("{__MISSED_FILES__}", f"{missed_files}")
        messages.append(
            {"role": "assistant", "content": json.dumps(file_groups, ensure_ascii=False)}
        )
        messages.append({"role": "user", "content": prompt_retry})
    return file_groups
def get_file_summaries(modified_files, staged_files, user_input):
    """
    Summarise every changed file with the LLM.

    For each file a ``git diff --cached`` is taken. Files that are not
    already staged are staged temporarily (``git add``) so that new and
    untracked files also produce a diff. They are unstaged again
    (``git reset``) afterwards, even if taking the diff fails. Paths are
    passed after ``--`` so a file named like a branch or starting with "-"
    is never read as a revision or an option.

    Args:
        modified_files (list): Changed files in the working tree.
        staged_files (list): Files that were already staged before the run.
        user_input (str): Extra instructions from the user.

    Returns:
        dict: File path -> summary as returned by ``gpt_file_summary``, or an
        empty dict when the combined diff exceeds 15000 characters.
    """
    staged = set(staged_files)
    diffs = []
    for file in modified_files:
        temporarily_staged = file not in staged
        if temporarily_staged:
            subprocess.check_output(["git", "add", "--", file])
        try:
            diff = subprocess.check_output(["git", "diff", "--cached", "--", file])
        finally:
            if temporarily_staged:
                subprocess.check_output(["git", "reset", "--", file])
        # Binary or non-UTF-8 content must not abort the whole workflow.
        diffs.append(diff.decode("utf-8", errors="replace"))
    total_diff_decoded = "\n".join(diffs)
    if len(total_diff_decoded) > 15000:
        print("Current diff length:", len(total_diff_decoded), flush=True)
        return {}
    return gpt_file_summary(total_diff_decoded, modified_files, user_input)
def get_file_summaries_and_groups(modified_files, staged_files, user_input):
    """
    Summarise the changed files and group them by related changes.

    For each file a ``git diff --cached`` is taken. Files that are not
    already staged are staged temporarily with ``git add`` and unstaged again
    with ``git reset`` afterwards, even if taking the diff fails. Paths follow
    ``--`` so they are never parsed as revisions or options. The combined
    diff is sent to ``gpt_file_summary`` and ``gpt_file_group``, and both
    results are printed as fenced blocks.

    Args:
        modified_files (List[str]): Changed files in the working tree.
        staged_files (List[str]): Files that were already staged.
        user_input (str): Extra instructions from the user.

    Returns:
        Tuple[Dict[str, Any], List[Dict[str, Any]]]: ``(summaries, groups)``.
        ``summaries`` maps file path -> summary; ``groups`` is whatever
        ``gpt_file_group`` returned (group dicts holding a ``"files"`` list).
        Returns ``({}, [])`` when the combined diff exceeds 15000 characters.
    """
    staged = set(staged_files)
    diffs = []
    for file in modified_files:
        temporarily_staged = file not in staged
        if temporarily_staged:
            subprocess.check_output(["git", "add", "--", file])
        try:
            diff = subprocess.check_output(["git", "diff", "--cached", "--", file])
        finally:
            if temporarily_staged:
                subprocess.check_output(["git", "reset", "--", file])
        # Binary or non-UTF-8 content must not abort the whole workflow.
        diffs.append(diff.decode("utf-8", errors="replace"))
    total_diff_decoded = "\n".join(diffs)
    if len(total_diff_decoded) > 15000:
        print("Current diff length:", len(total_diff_decoded), flush=True)
        return {}, []
    normpath_summaries = gpt_file_summary(total_diff_decoded, modified_files, user_input)
    # ensure_ascii=False prints non-ASCII (e.g. Chinese) summaries readably.
    print(
        f"""
``` file summary
{json.dumps(normpath_summaries, indent=4, ensure_ascii=False)}
```
"""
    )
    # Let the model group the files by how closely their changes are related.
    file_groups = gpt_file_group(total_diff_decoded, modified_files)
    print(
        f"""
``` group
{json.dumps(file_groups, indent=4, ensure_ascii=False)}
```
"""
    )
    return normpath_summaries, file_groups
def get_marked_files(modified_files, staged_files, file_summaries):
    """
    Let the user choose which changed files to commit.

    Builds a checkbox list: first every staged file (created with
    ``CheckboxOption(..., "Staged", True)``), then every other modified file
    (``CheckboxOption(..., "Unstaged", False)``). The third and fourth
    arguments are presumably the group label and the initial checked state;
    confirm against ui_utils. Each entry is labelled "<path> - <summary>".

    Args:
        modified_files (List[str]): All changed files.
        staged_files (List[str]): Files already staged.
        file_summaries (Dict[str, str]): Path -> summary. Files without an
            entry get an empty summary.

    Returns:
        The value returned by ``ui_checkbox_select``; main() treats it as the
        list of file paths to commit.
    """
    def label(file):
        # Formatting (rather than "+") tolerates non-string summaries from the LLM.
        return f"{file} - {file_summaries.get(file, '')}"

    staged = set(staged_files)
    options: List[CheckboxOption] = [
        CheckboxOption(file, label(file), "Staged", True) for file in staged_files
    ]
    options += [
        CheckboxOption(file, label(file), "Unstaged", False)
        for file in modified_files
        if file not in staged
    ]
    return ui_checkbox_select("Select files to commit", options)
def rebuild_stage_list(user_files):
    """
    Make the index contain exactly the user's selected files.

    Unstages everything with ``git reset``, then stages each selected file.

    Args:
        user_files: Iterable of file paths to stage.

    Returns:
        None

    Raises:
        subprocess.CalledProcessError: If a git command fails, so a partially
        rebuilt index is not silently committed.
    """
    subprocess.check_output(["git", "reset"])
    for file in user_files:
        # The path is passed as its own argv element (no shell), so names with
        # quotes, "$" or backticks are neither misparsed nor executed; "--"
        # stops names starting with "-" from being read as options.
        subprocess.check_output(["git", "add", "--", file])
def get_diff():
    """
    Return the diff of everything currently staged.

    Returns:
        bytes: Raw (undecoded) output of ``git diff --cached``.
    """
    staged_diff_command = ["git", "diff", "--cached"]
    staged_diff = subprocess.check_output(staged_diff_command)
    return staged_diff
def generate_commit_message_base_diff(user_input, diff):
    """
    Generate a commit message with the LLM from the staged diff.

    Args:
        user_input (str): Extra instructions from the user.
        diff (str | bytes): The staged diff. Bytes (as returned by get_diff)
            are decoded as UTF-8 first.

    Returns:
        The value returned by ``chat_completion_no_stream`` (main() reads its
        ``"content"`` key), or ``{"content": ""}`` when the prompt would be
        longer than 10000 characters. Exits the process via ``assert_value``
        if the model returns an empty response.
    """
    if isinstance(diff, bytes):
        # f"{diff}" on bytes would embed their repr (b'...\n...') instead of the
        # diff text, and inflate the length check with escape sequences.
        diff = diff.decode("utf-8", errors="replace")
    language_prompt = (
        "You must response commit message in chinese。\n" if language == "chinese" else ""
    )
    prompt = PROMPT_COMMIT_MESSAGE_BY_DIFF_USER_INPUT.replace("{__DIFF__}", f"{diff}").replace(
        "{__USER_INPUT__}", f"{user_input + language_prompt}"
    )
    if len(prompt) > 10000:
        print(
            "Due to the large size of the diff data, "
            "generating a commit message through AI would be very costly, therefore, "
            "it is not recommended to use AI for generating the description. "
            "Please manually edit the commit message before submitting."
        )
        return {"content": ""}
    messages = [{"role": "user", "content": prompt}]
    response = chat_completion_no_stream(
        messages, prompt_commit_message_by_diff_user_input_llm_config
    )
    assert_value(not response, "")
    return response
def generate_commit_message_base_file_summaries(user_input, file_summaries):
    """
    Generate a commit message with the LLM from per-file summaries.

    Args:
        user_input (str): Extra instructions from the user.
        file_summaries (dict): File path -> summary for the files being
            committed.

    Returns:
        The value returned by ``chat_completion_no_stream`` (main() reads its
        ``"content"`` key). Exits the process via ``assert_value`` if the
        model returns an empty response.
    """
    language_prompt = (
        "Please response commit message in chinese.\n" if language == "chinese" else ""
    )
    # ensure_ascii=False keeps non-ASCII (e.g. Chinese) summaries readable in
    # the prompt instead of turning them into \uXXXX escape sequences.
    summaries_json = json.dumps(file_summaries, indent=4, ensure_ascii=False)
    prompt = PROMPT_COMMIT_MESSAGE_BY_SUMMARY_USER_INPUT.replace(
        "{__USER_INPUT__}", f"{user_input}"
    ).replace("{__FILE_SUMMARY__}", summaries_json)
    messages = [{"role": "user", "content": language_prompt + prompt}]
    response = chat_completion_no_stream(
        messages, prompt_commit_message_by_summary_user_input_llm_config
    )
    assert_value(not response, "")
    return response
def display_commit_message_and_commit(commit_message):
    """
    Let the user edit the proposed commit message, then commit with it.

    Nothing is committed if the edited message is empty or whitespace-only;
    ``git commit -m`` would abort on such a message anyway.

    Args:
        commit_message: Proposed message shown in the editor.

    Returns:
        None.
    """
    new_commit_message = ui_text_edit("Edit commit message", commit_message)
    if not new_commit_message or not str(new_commit_message).strip():
        return
    subprocess.check_output(["git", "commit", "-m", new_commit_message])
def check_git_installed():
    """
    Report whether a working ``git`` executable is available.

    Runs ``git --version``. Any failure (missing executable, non-zero exit
    status, or any other exception) prints a notice to stderr.

    Returns:
        bool: True if the command succeeded, otherwise False.
    """
    probe_command = ["git", "--version"]
    try:
        subprocess.run(
            probe_command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True,
        )
    except Exception:
        # FileNotFoundError, CalledProcessError and anything else are all
        # reported with the same message.
        print("Git is not installed on your system.", file=sys.stderr, flush=True)
        return False
    return True
def main():
    """
    Entry point of the commit workflow.

    Usage: ``python script.py <user_input> <language>``.

    Checks for git and collects the changed files, then summarises them with
    the LLM and lets the user pick which ones to commit. The index is rebuilt
    from that selection. A commit message is generated from the file
    summaries when every selected file has one, otherwise from the staged
    diff. The user then edits the message and the commit is made.

    Exits with status -1 on bad usage, missing git, nothing to commit, or
    any ``Exception``; exits with 0 after the commit step. Returns normally
    when the user selects no files.
    """
    global language
    try:
        log_info("Start commit workflow ...")
        # Both the user input and the response language are required.
        if len(sys.argv) < 3:
            print("Usage: python script.py <user_input> <language>", file=sys.stderr, flush=True)
            sys.exit(-1)
        user_input, language = sys.argv[1], sys.argv[2]

        if not check_git_installed():
            sys.exit(-1)

        modified_files, staged_files = get_modified_files()
        if not modified_files:
            print("No files to commit.", file=sys.stderr, flush=True)
            sys.exit(-1)

        file_summaries = get_file_summaries(modified_files, staged_files, user_input)
        selected_files = get_marked_files(modified_files, staged_files, file_summaries)
        if not selected_files:
            print("No files selected, commit aborted.")
            return

        rebuild_stage_list(selected_files)

        selected_summaries = {}
        for path in selected_files:
            if path in file_summaries:
                selected_summaries[path] = file_summaries[path]

        if len(selected_summaries) >= len(selected_files):
            commit_message = generate_commit_message_base_file_summaries(
                user_input, selected_summaries
            )
        else:
            commit_message = generate_commit_message_base_diff(user_input, get_diff())

        display_commit_message_and_commit(commit_message["content"])
        sys.exit(0)
    except Exception as err:
        print("Exception:", err, file=sys.stderr, flush=True)
        sys.exit(-1)


if __name__ == "__main__":
    main()