workflows/merico/pr/providers/devchat_provider.py
2024-05-19 15:34:26 +08:00

262 lines
8.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import sys
from typing import Optional, Tuple
import pr_agent.git_providers as git_providers
from pr_agent.git_providers.git_provider import GitProvider, IncrementalPR
from pr_agent.git_providers.github_provider import GithubProvider
from lib.chatmark import Button, Form, TextEditor
class DevChatProvider(GitProvider):
    """Interactive wrapper around a concrete ``GitProvider``.

    The wrapped provider is selected from the ``CONFIG.GIT_PROVIDER_TYPE``
    environment variable. Most methods forward unchanged to
    ``self.provider``. The methods that publish text (PR description,
    comments, code suggestions) first preview the text on stdout and offer
    the user a chance to edit it through chatmark widgets before forwarding.

    When the text to publish ends up empty or unusable, these methods print
    a message and terminate the process with ``sys.exit(0)``.
    """

    def __init__(self, pr_url: Optional[str] = None, incremental=None):
        """Create the wrapped provider.

        Args:
            pr_url: URL of the pull request, forwarded to the wrapped provider.
            incremental: incremental-review settings forwarded to the wrapped
                provider; ``None`` means a fresh ``IncrementalPR(False)``.

        Raises:
            KeyError: if ``CONFIG.GIT_PROVIDER_TYPE`` is unset or names a
                provider that is not registered in ``_GIT_PROVIDERS``.
        """
        # Build the default per call so instances never share one object.
        if incremental is None:
            incremental = IncrementalPR(False)

        # Pick the concrete GitProvider according to the configured type.
        provider_type = os.environ.get("CONFIG.GIT_PROVIDER_TYPE")
        try:
            provider_cls = git_providers._GIT_PROVIDERS[provider_type]
        except KeyError:
            raise KeyError(
                f"Unsupported or missing CONFIG.GIT_PROVIDER_TYPE: {provider_type!r}"
            ) from None
        self.provider: GitProvider = provider_cls(pr_url, incremental)

    # ---- attributes exposed from the wrapped provider ----

    @property
    def pr(self):
        return self.provider.pr

    @property
    def diff_files(self):
        return self.provider.diff_files

    @property
    def github_client(self):
        return self.provider.github_client

    def is_supported(self, capability: str) -> bool:
        return self.provider.is_supported(capability)

    def get_diff_files(self):
        return self.provider.get_diff_files()

    def need_edit(self):
        """Show "Commit"/"Edit" buttons and report whether editing was chosen.

        Returns:
            True when ``button.clicked`` equals 1 (presumably the zero-based
            index of the "Edit" entry -- confirm against chatmark's Button).
        """
        button = Button(
            ["Commit", "Edit"],
        )
        button.render()
        return button.clicked == 1

    def publish_description(self, pr_title: str, pr_body: str):
        """Preview the PR title/body, optionally let the user edit, then publish.

        Exits the process (status 0) if the title or body is empty.
        """
        # Preview pr title and body
        print(f"\n\nPR Title: {pr_title}", end="\n\n", flush=True)
        print("PR Body:", end="\n\n", flush=True)
        print(pr_body, end="\n\n", flush=True)

        # Need Edit?
        if self.need_edit():
            # Edit pr title and body
            title_editor = TextEditor(pr_title)
            body_editor = TextEditor(pr_body)
            form = Form(["Edit pr title:", title_editor, "Edit pr body:", body_editor])
            form.render()

            pr_title = title_editor.new_text
            pr_body = body_editor.new_text

        if not pr_title or not pr_body:
            print("Title or body is empty, please fill in the title and body.")
            sys.exit(0)

        return self.provider.publish_description(pr_title, pr_body)

    def publish_code_suggestions(self, code_suggestions: list) -> bool:
        """Let the user edit the suggestions as JSON, then publish them.

        Exits the process (status 0) if the edited text is empty or is not
        valid JSON, so nothing unreviewed is published.
        """
        # ensure_ascii=False keeps non-ASCII text readable in the editor;
        # json.loads accepts both escaped and unescaped forms.
        code_suggestions_json_str = json.dumps(
            code_suggestions, indent=4, ensure_ascii=False
        )
        code_suggestions_editor = TextEditor(
            code_suggestions_json_str, "Edit code suggestions in JSON format:"
        )
        code_suggestions_editor.render()

        code_suggestions_json_new = code_suggestions_editor.new_text
        if not code_suggestions_json_new:
            print("Code suggestions are empty, please fill in the code suggestions.")
            sys.exit(0)

        try:
            code_suggestions = json.loads(code_suggestions_json_new)
        except json.JSONDecodeError as err:
            # Hand-edited JSON is easy to break; stop cleanly like the other
            # invalid-input paths instead of surfacing a traceback.
            print(f"Code suggestions are not valid JSON: {err}")
            sys.exit(0)

        return self.provider.publish_code_suggestions(code_suggestions)

    def get_languages(self):
        return self.provider.get_languages()

    def get_pr_branch(self):
        return self.provider.get_pr_branch()

    def get_files(self):
        return self.provider.get_files()

    def get_user_id(self):
        return self.provider.get_user_id()

    def get_pr_description_full(self) -> str:
        return self.provider.get_pr_description_full()

    def edit_comment(self, comment, body: str):
        """Edit an existing comment, with user review for code-suggestion bodies.

        Bodies that do not contain "## PR Code Suggestions" are forwarded
        untouched. Otherwise the body is previewed, optionally edited, and the
        process exits (status 0) if the result is empty.
        """
        if "## PR Code Suggestions" not in body:
            return self.provider.edit_comment(comment, body)

        print(f"\n\n{body}", end="\n\n", flush=True)
        if self.need_edit():
            comment_editor = TextEditor(body, "Edit Comment:")
            comment_editor.render()

            body = comment_editor.new_text

        if not body:
            print("Comment is empty, please fill in the comment.")
            sys.exit(0)
        return self.provider.edit_comment(comment, body)

    def reply_to_comment_from_comment_id(self, comment_id: int, body: str):
        return self.provider.reply_to_comment_from_comment_id(comment_id, body)

    def get_pr_description(self, *, full: bool = True) -> str:
        return self.provider.get_pr_description(full=full)

    def get_user_description(self) -> str:
        return self.provider.get_user_description()

    def _possible_headers(self):
        return self.provider._possible_headers()

    def _is_generated_by_pr_agent(self, description_lowercase: str) -> bool:
        return self.provider._is_generated_by_pr_agent(description_lowercase)

    def get_repo_settings(self):
        return self.provider.get_repo_settings()

    def get_pr_id(self):
        return self.provider.get_pr_id()

    def get_line_link(
        self,
        relevant_file: str,
        relevant_line_start: int,
        relevant_line_end: Optional[int] = None,
    ) -> str:
        return self.provider.get_line_link(relevant_file, relevant_line_start, relevant_line_end)

    #### comments operations ####
    def publish_comment(self, pr_comment: str, is_temporary: bool = False):
        """Publish a comment after an optional user review.

        Temporary comments and the "## Generating PR code suggestions"
        progress comment are dropped (returns ``None``). Comments containing
        "**[PR Description]" are forwarded without review. Everything else is
        previewed, optionally edited, and the process exits (status 0) if the
        result is empty.
        """
        if is_temporary:
            return None

        if "## Generating PR code suggestions" in pr_comment:
            return None

        # The two guards above already exclude temporary and progress
        # comments, so only the PR-description marker is left to test.
        if "**[PR Description]" not in pr_comment:
            print(f"\n\n{pr_comment}", end="\n\n", flush=True)

            if self.need_edit():
                pr_comment_editor = TextEditor(pr_comment)
                form = Form(["Edit pr comment:", pr_comment_editor])
                form.render()

                pr_comment = pr_comment_editor.new_text

            if not pr_comment:
                print("Comment is empty, please fill in the comment.")
                sys.exit(0)

        return self.provider.publish_comment(pr_comment, is_temporary=is_temporary)

    def publish_persistent_comment(
        self,
        pr_comment: str,
        initial_header: str,
        update_header: bool = True,
        name="review",
        final_update_message=True,
    ):
        """Preview a persistent comment, optionally let the user edit, then publish.

        Exits the process (status 0) if the comment is empty.
        """
        print(f"\n\n{initial_header}", end="\n\n", flush=True)
        print(pr_comment, end="\n\n", flush=True)

        if self.need_edit():
            pr_comment_editor = TextEditor(pr_comment)
            form = Form(["Edit pr comment:", pr_comment_editor])
            form.render()

            pr_comment = pr_comment_editor.new_text

        if not pr_comment:
            print("Comment is empty, please fill in the comment.")
            sys.exit(0)
        return self.provider.publish_persistent_comment(
            pr_comment, initial_header, update_header, name, final_update_message
        )

    def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
        return self.provider.publish_inline_comment(body, relevant_file, relevant_line_in_file)

    def create_inline_comment(
        self,
        body: str,
        relevant_file: str,
        relevant_line_in_file: str,
        absolute_position: Optional[int] = None,
    ):
        return self.provider.create_inline_comment(
            body, relevant_file, relevant_line_in_file, absolute_position
        )

    def publish_inline_comments(self, comments: list[dict]):
        return self.provider.publish_inline_comments(comments)

    def remove_initial_comment(self):
        return self.provider.remove_initial_comment()

    def remove_comment(self, comment):
        return self.provider.remove_comment(comment)

    def get_issue_comments(self):
        return self.provider.get_issue_comments()

    def get_comment_url(self, comment) -> str:
        return self.provider.get_comment_url(comment)

    #### labels operations ####
    def publish_labels(self, labels):
        """Publish labels only when ``ENABLE_PUBLISH_LABELS`` is set and non-empty.

        Any non-empty value enables publishing; otherwise returns ``None``
        without calling the wrapped provider.
        """
        if not os.environ.get("ENABLE_PUBLISH_LABELS"):
            return None
        return self.provider.publish_labels(labels)

    def get_pr_labels(self, update=False):
        return self.provider.get_pr_labels(update=update)

    def get_repo_labels(self):
        return self.provider.get_repo_labels()

    def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
        return self.provider.add_eyes_reaction(issue_comment_id, disable_eyes=disable_eyes)

    def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
        return self.provider.remove_reaction(issue_comment_id, reaction_id)

    #### commits operations ####
    def get_commit_messages(self):
        return self.provider.get_commit_messages()

    def get_pr_url(self) -> str:
        return self.provider.get_pr_url()

    def get_latest_commit_url(self) -> str:
        return self.provider.get_latest_commit_url()

    def auto_approve(self) -> bool:
        return self.provider.auto_approve()

    def calc_pr_statistics(self, pull_request_data: dict):
        return self.provider.calc_pr_statistics(pull_request_data)

    def get_num_of_files(self):
        return self.provider.get_num_of_files()

    @staticmethod
    def _parse_issue_url(issue_url: str) -> Tuple[str, int]:
        # NOTE(review): this is a staticmethod, so it cannot consult
        # self.provider and always uses the GitHub parser, whatever
        # CONFIG.GIT_PROVIDER_TYPE selects -- confirm this is intended.
        return GithubProvider._parse_issue_url(issue_url)