diff --git a/merico/refactor/rewrite.py b/merico/refactor/rewrite.py index d83f7ea..6849ef6 100644 --- a/merico/refactor/rewrite.py +++ b/merico/refactor/rewrite.py @@ -2,9 +2,12 @@ import os import re import sys -from devchat.llm import chat +from typing import List, Dict, Any, Optional, Tuple + +from devchat.llm import chat, chat_json from lib.ide_service import IDEService +from lib.chatmark import Step, Form, TextEditor def get_selected_code(): @@ -37,21 +40,33 @@ def get_selected_code(): REWRITE_PROMPT = prompt = """ -Your task is: +你是一个代码重构专家,你的任务是根据用户的需求重写代码。你需要根据用户的需求,重写代码,并保证代码的语法正确性和逻辑正确性。 +你的重构目标: {question} -Following the task requirements, modify only the selected portion of the code. \ -Please ensure that the revised code segment maintains the same indentation as the \ -selected code to seamlessly integrate with the existing code structure and maintain \ -correct syntax. Just refactor the selected code. Keep all other information as it is. \ -Here is the relevant context \ -information for your reference: -1. selected code info: {selected_text} + +待重构的代码: +{selected_text} + +项目中其他相关上下文代码: +{context_code} + +围绕重构目标对代码进行重构,不要进行重构目标之外的其他代码优化修改。 +输出重构后的代码,代码需要用markdown代码块包裹,代码需要与原重构代码保持同样的缩进格式,并且代码块的语言类型需要与原重构代码保持一致,例如: +```python +... +``` +以及 +```java +... +``` + +只输出最终重构后代码,不要输出其他任何内容。 """ @chat(prompt=REWRITE_PROMPT, stream_out=True) # pylint: disable=unused-argument -def ai_rewrite(question, selected_text): +def ai_rewrite(question, selected_text, context_code): """ call ai to rewrite selected code """ @@ -113,13 +128,373 @@ def replace_selected(new_code): file.write(modified_text) + + + +# 定义用于分析代码中缺少定义的符号的提示 +SYMBOL_ANALYSIS_PROMPT = """ +角色:你是一个资深的程序工程师,擅长代码分析与代码重构。 + +重构目标: +{task} + +当前用户选择代码片段: +{code} + +你的任务: +针对当前已知代码,找出所有可能缺少定义的符号(变量、函数、类等)。这些符号在代码中被使用,但可能没有在当前上下文中定义。 + +输出要求,输出markdown代码块形式的JSON块,形式为: +```json +[ + {{ + "symbol": "", + "line": "", + "value": "<影响重要度0-1之间>", + "reason": "<影响描述>" + }}, + ...... +] +``` + +示例: +针对以下代码片段: +```python +def fun1(): + fun2("hello") + fun3( + "hello" + ) +``` +进行重构,希望fun1执行打印输出hello DevChat。 + +此时由于缺少fun2,fun3的函数定义,所以并不清楚fun1的完整执行逻辑。所以输出: +```json +[ + {{ + "symbol": "fun2", + "line": " fun2(\"hello\")", + "value": 0.8, + "reason": "不能确定fun2中具体代码逻辑,是否会打印输出参数'hello',如果是,那么重构只需要修改fun1中参数信息" + }}, + {{ + "symbol": "fun3", + "line": " fun3(", + "value": 0.8, + "reason": "不能确定fun3中具体代码逻辑,是否会打印输出参数'hello',如果是,那么重构只需要修改fun1中参数信息" + }} +] +``` + +请确保返回的JSON格式正确,只包含实际缺少定义的符号。只输出最终结果JSON,不要输出其他解释描述。 +""" + +# 定义用于生成符号使用建议的提示 +SYMBOL_USAGE_PROMPT = """ +基于以下符号的定义,生成正确使用这些符号的建议。 + +符号定义: +{symbol_definitions} + +原始代码片段: +``` +{original_code} +``` + +请提供以下内容: +1. 对每个符号的正确使用方法的解释 +2. 在原始代码中可能存在的符号使用错误 +3. 修复建议 +4. 示例代码,展示如何正确使用这些符号 + +请以JSON格式返回,包含以下字段: +1. "suggestions": 包含所有符号使用建议的数组,每个建议包含: + - "symbol": 符号名称 + - "explanation": 符号的正确使用方法解释 + - "errors": 在原始代码中可能存在的使用错误 + - "fix": 修复建议 + - "example": 示例代码 + +请确保返回的JSON格式正确。 +""" + +@chat_json(prompt=SYMBOL_ANALYSIS_PROMPT) +def analyze_missing_symbols(code: str, task: str) -> Dict[str, List[Dict[str, Any]]]: + """ + 使用大模型分析代码片段中缺少定义的符号 + + Args: + code: 需要分析的代码片段 + + Returns: + 包含缺少定义的符号列表的字典 + """ + pass + +@chat_json(prompt=SYMBOL_USAGE_PROMPT) +def generate_symbol_usage_suggestions(symbol_definitions: str, original_code: str) -> Dict[str, List[Dict[str, Any]]]: + """ + 基于符号定义生成符号使用建议 + + Args: + symbol_definitions: 符号定义的文本表示 + original_code: 原始代码片段 + + Returns: + 包含符号使用建议的字典 + """ + pass + +def get_symbol_definition(abspath: str, line: int, character: int, symbol_name: str, symbol_type: str, project_root_path: str) -> List[Tuple]: + """ + 获取符号定义的代码 + + Args: + abspath: 文件的绝对路径 + line: 符号所在行 + character: 符号所在列 + symbol_type: 符号类型 + + Returns: + 符号定义的代码,如果找不到则返回None + """ + ide_service = IDEService() + locations = [] + has_visited = set() + + # 根据符号类型选择合适的查找方法 + locations1 = ide_service.find_type_def_locations(abspath, line, character) + locations2 = ide_service.find_def_locations(abspath, line, character) + locations3 = ide_service.find_type_def_locations(abspath, line, character + len(symbol_name) - 1) + locations4 = ide_service.find_def_locations(abspath, line, character + len(symbol_name) - 1) + for location in locations1 + locations2 + locations3 + locations4: + if not location.abspath.startswith(project_root_path): + continue + key = (location.abspath, location.range.start.line, location.range.end.line) + key_str = f"{location.abspath}:{location.range.start.line}:{location.range.end.line}" + if key_str not in has_visited: + has_visited.add(key_str) + locations.append(key) + + return locations + + +def format_symbol_results(symbols: List[Dict[str, Any]], definitions: Dict[str, str]) -> str: + """ + 格式化符号分析结果为Markdown格式 + + Args: + symbols: 符号列表 + definitions: 符号定义字典 + + Returns: + 格式化后的Markdown文本 + """ + result = "## 符号分析结果\n\n" + + if not symbols: + return result + "没有找到缺少定义的符号。" + + result += f"找到 {len(symbols)} 个可能缺少定义的符号:\n\n" + + for i, symbol in enumerate(symbols): + result += f"### {i+1}. {symbol['name']} ({symbol['type']})\n\n" + result += f"- 位置: 第{symbol['line'] + 1}行,第{symbol['character'] + 1}列\n" + result += f"- 原因: {symbol['reason']}\n\n" + + if symbol['name'] in definitions and definitions[symbol['name']]: + result += "#### 找到的定义:\n\n" + result += f"{definitions[symbol['name']]}\n\n" + else: + result += "#### 未找到定义\n\n" + result += "无法在当前项目中找到此符号的定义。可能是外部库、内置函数或拼写错误。\n\n" + + return result + +def format_usage_suggestions(suggestions: List[Dict[str, Any]]) -> str: + """ + 格式化符号使用建议为Markdown格式 + + Args: + suggestions: 符号使用建议列表 + + Returns: + 格式化后的Markdown文本 + """ + if not suggestions: + return "" + + result = "## 符号使用建议\n\n" + + for i, suggestion in enumerate(suggestions): + result += f"### {i+1}. {suggestion['symbol']}\n\n" + result += f"**正确使用方法**:\n{suggestion['explanation']}\n\n" + + if suggestion.get('errors'): + result += f"**可能存在的错误**:\n{suggestion['errors']}\n\n" + + if suggestion.get('fix'): + result += f"**修复建议**:\n{suggestion['fix']}\n\n" + + if suggestion.get('example'): + result += "**示例代码**:\n```python\n" + suggestion['example'] + "\n```\n\n" + + return result + + +def find_project_root(file_path: str) -> str: + """ + 根据文件路径查找项目根目录 + + 通过向上遍历目录,查找 .git 或 .svn 目录来确定项目根目录 + + Args: + file_path: 文件的绝对路径 + + Returns: + 项目根目录的绝对路径,如果找不到则返回原始文件所在目录 + """ + if not os.path.isabs(file_path): + file_path = os.path.abspath(file_path) + + current_dir = os.path.dirname(file_path) + + # 向上遍历目录,直到找到包含 .git 或 .svn 的目录,或者到达根目录 + while current_dir and current_dir != '/': + # 检查当前目录是否包含 .git 或 .svn + if os.path.exists(os.path.join(current_dir, '.git')) or \ + os.path.exists(os.path.join(current_dir, '.svn')): + return current_dir + + # 向上移动一级目录 + parent_dir = os.path.dirname(current_dir) + if parent_dir == current_dir: # 防止在Windows根目录下无限循环 + break + current_dir = parent_dir + + # 如果没有找到版本控制目录,返回文件所在目录 + return os.path.dirname(file_path) + + def main(): + ide_service = IDEService() question = sys.argv[1] + rafact_task = sys.argv[1] # prepare code - selected_text = get_selected_code() + + + # 步骤1: 获取用户选中的代码片段 + with Step("获取选中的代码片段..."): + selected_code = ide_service.get_selected_range() + + if not selected_code or not selected_code.text.strip(): + print("请先选择一段代码片段再执行此命令。") + return + # print(selected_code) + selected_text = selected_code.text + project_root_path = find_project_root(selected_code.abspath) + print(f"项目根目录: {project_root_path}\n\n") + + + # 步骤2: 分析代码片段中缺少定义的符号 + with Step("分析代码中缺少定义的符号..."): + try: + analysis_result = analyze_missing_symbols(code=selected_code.text, task=rafact_task) + missing_symbols = analysis_result # 直接获取返回的列表 + + if not missing_symbols: + print("没有找到缺少定义的符号。") + return + + ide_service.ide_logging("info", f"找到 {len(missing_symbols)} 个可能缺少定义的符号") + except Exception as e: + ide_service.ide_logging("error", f"分析符号时出错: {str(e)}") + print(f"分析代码时出错: {str(e)}") + return + + current_filepath = selected_code.abspath + base_line = selected_code.range.start.line + # range = "abspath='/Users/boboyang/.chat/scripts/merico/symbol_resolver/command.py' range=line=248 character=0 - line=248 character=24 text=' print(selected_code)'" + + # 步骤3: 将分析结果转换为可处理的结构 + with Step("处理符号信息..."): + symbols = [] + code_lines = selected_code.text.splitlines() + + for symbol_info in missing_symbols: + symbol_name = symbol_info["symbol"] + symbol_line_text = symbol_info["line"] + + # 在代码中查找匹配的行 + line_index = -1 + for i, line in enumerate(code_lines): + if line == symbol_line_text: + line_index = i + break + + if line_index == -1: + ide_service.ide_logging("warning", f"找不到符号 {symbol_name} 所在行") + continue + + # 在行中查找符号位置 + char_index = symbol_line_text.find(symbol_name) + if char_index == -1: + ide_service.ide_logging("warning", f"在行中找不到符号 {symbol_name}") + continue + + # 构建符号信息 + symbol = { + "name": symbol_name, + "line": base_line + line_index, + "character": char_index, + "type": "unknown", # 默认类型 + "reason": symbol_info.get("reason", "未知原因") + } + + symbols.append(symbol) + + # 步骤3: 查找符号的实际定义 + with Step("查找符号的实际定义..."): + symbol_definitions = {} + + for symbol in symbols: + symbol_name = symbol["name"] + symbol_line = symbol.get("line", 0) + symbol_char = symbol.get("character", 0) + symbol_type = symbol.get("type", "unknown") + + definitions = get_symbol_definition( + selected_code.abspath, + symbol_line, + symbol_char, + symbol_name, + symbol_type, + project_root_path + ) + + symbol_definitions[symbol_name] = definitions + + # 计算每个文件被引用次数 + files_ref_counts = {} + # 当前选中代码文件,默认计算100 + files_ref_counts[selected_code.abspath] = 100 + for symbol in symbol_definitions: + for definition in symbol_definitions[symbol]: + if definition[0] not in files_ref_counts: + files_ref_counts[definition[0]] = 0 + files_ref_counts[definition[0]] += 1 + + context_code = "" + for filepath, ref_count in files_ref_counts.items(): + if ref_count > 0: + with open(filepath, "r") as f: + context_code += f"文件名:{filepath}\n\n" + context_code += "文件内容:\n" + context_code += f.read() + context_code += "\n\n" # rewrite - response = ai_rewrite(question=question, selected_text=selected_text) + response = ai_rewrite(question=question, selected_text=selected_text, context_code=context_code) if not response: sys.exit(1)