refactor: Update rewrite.py with Chinese prompts and symbol analysis

- Translate prompts to Chinese for better localization
- Add symbol analysis functionality to detect missing definitions
- Implement context code gathering for better refactoring results
- Add helper functions for project root detection and symbol definition lookup
This commit is contained in:
bobo.yang 2025-03-07 12:00:49 +08:00
parent f4a4dbc529
commit 021efac1d9

View File

@ -2,9 +2,12 @@ import os
import re
import sys
from devchat.llm import chat
from typing import List, Dict, Any, Optional, Tuple
from devchat.llm import chat, chat_json
from lib.ide_service import IDEService
from lib.chatmark import Step, Form, TextEditor
def get_selected_code():
@ -37,21 +40,33 @@ def get_selected_code():
REWRITE_PROMPT = prompt = """
Your task is:
你是一个代码重构专家你的任务是根据用户的需求重写代码你需要根据用户的需求重写代码并保证代码的语法正确性和逻辑正确性
你的重构目标
{question}
Following the task requirements, modify only the selected portion of the code. \
Please ensure that the revised code segment maintains the same indentation as the \
selected code to seamlessly integrate with the existing code structure and maintain \
correct syntax. Just refactor the selected code. Keep all other information as it is. \
Here is the relevant context \
information for your reference:
1. selected code info: {selected_text}
待重构的代码
{selected_text}
项目中其他相关上下文代码
{context_code}
围绕重构目标对代码进行重构不要进行重构目标之外的其他代码优化修改
输出重构后的代码代码需要用markdown代码块包裹,代码需要与原重构代码保持同样的缩进格式并且代码块的语言类型需要与原重构代码保持一致例如
```python
...
```
以及
```java
...
```
只输出最终重构后代码不要输出其他任何内容
"""
@chat(prompt=REWRITE_PROMPT, stream_out=True)
# pylint: disable=unused-argument
def ai_rewrite(question, selected_text):
def ai_rewrite(question, selected_text, context_code):
"""
call ai to rewrite selected code
"""
@ -113,13 +128,373 @@ def replace_selected(new_code):
file.write(modified_text)
# 定义用于分析代码中缺少定义的符号的提示
SYMBOL_ANALYSIS_PROMPT = """
角色你是一个资深的程序工程师擅长代码分析与代码重构
重构目标
{task}
当前用户选择代码片段
{code}
你的任务
针对当前已知代码找出所有可能缺少定义的符号变量函数类等这些符号在代码中被使用但可能没有在当前上下文中定义
输出要求输出markdown代码块形式的JSON块形式为
```json
[
{{
"symbol": "<identifier>",
"line": "<identifier所在行完整代码包含缩进空格>",
"value": "<影响重要度0-1之间>",
"reason": "<影响描述>"
}},
......
]
```
示例
针对以下代码片段
```python
def fun1():
fun2("hello")
fun3(
"hello"
)
```
进行重构希望fun1执行打印输出hello DevChat
此时由于缺少fun2,fun3的函数定义所以并不清楚fun1的完整执行逻辑所以输出
```json
[
{{
"symbol": "fun2",
"line": " fun2(\"hello\")",
"value": 0.8,
"reason": "不能确定fun2中具体代码逻辑是否会打印输出参数'hello'如果是那么重构只需要修改fun1中参数信息"
}},
{{
"symbol": "fun3",
"line": " fun3(",
"value": 0.8,
"reason": "不能确定fun3中具体代码逻辑是否会打印输出参数'hello'如果是那么重构只需要修改fun1中参数信息"
}}
]
```
请确保返回的JSON格式正确只包含实际缺少定义的符号只输出最终结果JSON不要输出其他解释描述
"""
# 定义用于生成符号使用建议的提示
SYMBOL_USAGE_PROMPT = """
基于以下符号的定义生成正确使用这些符号的建议
符号定义:
{symbol_definitions}
原始代码片段:
```
{original_code}
```
请提供以下内容
1. 对每个符号的正确使用方法的解释
2. 在原始代码中可能存在的符号使用错误
3. 修复建议
4. 示例代码展示如何正确使用这些符号
请以JSON格式返回包含以下字段
1. "suggestions": 包含所有符号使用建议的数组每个建议包含
- "symbol": 符号名称
- "explanation": 符号的正确使用方法解释
- "errors": 在原始代码中可能存在的使用错误
- "fix": 修复建议
- "example": 示例代码
请确保返回的JSON格式正确
"""
@chat_json(prompt=SYMBOL_ANALYSIS_PROMPT)
def analyze_missing_symbols(code: str, task: str) -> Dict[str, List[Dict[str, Any]]]:
"""
使用大模型分析代码片段中缺少定义的符号
Args:
code: 需要分析的代码片段
Returns:
包含缺少定义的符号列表的字典
"""
pass
@chat_json(prompt=SYMBOL_USAGE_PROMPT)
def generate_symbol_usage_suggestions(symbol_definitions: str, original_code: str) -> Dict[str, List[Dict[str, Any]]]:
"""
基于符号定义生成符号使用建议
Args:
symbol_definitions: 符号定义的文本表示
original_code: 原始代码片段
Returns:
包含符号使用建议的字典
"""
pass
def get_symbol_definition(abspath: str, line: int, character: int, symbol_name: str, symbol_type: str, project_root_path: str) -> List[Tuple]:
"""
获取符号定义的代码
Args:
abspath: 文件的绝对路径
line: 符号所在行
character: 符号所在列
symbol_type: 符号类型
Returns:
符号定义的代码如果找不到则返回None
"""
ide_service = IDEService()
locations = []
has_visited = set()
# 根据符号类型选择合适的查找方法
locations1 = ide_service.find_type_def_locations(abspath, line, character)
locations2 = ide_service.find_def_locations(abspath, line, character)
locations3 = ide_service.find_type_def_locations(abspath, line, character + len(symbol_name) - 1)
locations4 = ide_service.find_def_locations(abspath, line, character + len(symbol_name) - 1)
for location in locations1 + locations2 + locations3 + locations4:
if not location.abspath.startswith(project_root_path):
continue
key = (location.abspath, location.range.start.line, location.range.end.line)
key_str = f"{location.abspath}:{location.range.start.line}:{location.range.end.line}"
if key_str not in has_visited:
has_visited.add(key_str)
locations.append(key)
return locations
def format_symbol_results(symbols: List[Dict[str, Any]], definitions: Dict[str, str]) -> str:
"""
格式化符号分析结果为Markdown格式
Args:
symbols: 符号列表
definitions: 符号定义字典
Returns:
格式化后的Markdown文本
"""
result = "## 符号分析结果\n\n"
if not symbols:
return result + "没有找到缺少定义的符号。"
result += f"找到 {len(symbols)} 个可能缺少定义的符号:\n\n"
for i, symbol in enumerate(symbols):
result += f"### {i+1}. {symbol['name']} ({symbol['type']})\n\n"
result += f"- 位置: 第{symbol['line'] + 1}行,第{symbol['character'] + 1}\n"
result += f"- 原因: {symbol['reason']}\n\n"
if symbol['name'] in definitions and definitions[symbol['name']]:
result += "#### 找到的定义:\n\n"
result += f"{definitions[symbol['name']]}\n\n"
else:
result += "#### 未找到定义\n\n"
result += "无法在当前项目中找到此符号的定义。可能是外部库、内置函数或拼写错误。\n\n"
return result
def format_usage_suggestions(suggestions: List[Dict[str, Any]]) -> str:
"""
格式化符号使用建议为Markdown格式
Args:
suggestions: 符号使用建议列表
Returns:
格式化后的Markdown文本
"""
if not suggestions:
return ""
result = "## 符号使用建议\n\n"
for i, suggestion in enumerate(suggestions):
result += f"### {i+1}. {suggestion['symbol']}\n\n"
result += f"**正确使用方法**:\n{suggestion['explanation']}\n\n"
if suggestion.get('errors'):
result += f"**可能存在的错误**:\n{suggestion['errors']}\n\n"
if suggestion.get('fix'):
result += f"**修复建议**:\n{suggestion['fix']}\n\n"
if suggestion.get('example'):
result += "**示例代码**:\n```python\n" + suggestion['example'] + "\n```\n\n"
return result
def find_project_root(file_path: str) -> str:
"""
根据文件路径查找项目根目录
通过向上遍历目录查找 .git .svn 目录来确定项目根目录
Args:
file_path: 文件的绝对路径
Returns:
项目根目录的绝对路径如果找不到则返回原始文件所在目录
"""
if not os.path.isabs(file_path):
file_path = os.path.abspath(file_path)
current_dir = os.path.dirname(file_path)
# 向上遍历目录,直到找到包含 .git 或 .svn 的目录,或者到达根目录
while current_dir and current_dir != '/':
# 检查当前目录是否包含 .git 或 .svn
if os.path.exists(os.path.join(current_dir, '.git')) or \
os.path.exists(os.path.join(current_dir, '.svn')):
return current_dir
# 向上移动一级目录
parent_dir = os.path.dirname(current_dir)
if parent_dir == current_dir: # 防止在Windows根目录下无限循环
break
current_dir = parent_dir
# 如果没有找到版本控制目录,返回文件所在目录
return os.path.dirname(file_path)
def main():
ide_service = IDEService()
question = sys.argv[1]
rafact_task = sys.argv[1]
# prepare code
selected_text = get_selected_code()
# 步骤1: 获取用户选中的代码片段
with Step("获取选中的代码片段..."):
selected_code = ide_service.get_selected_range()
if not selected_code or not selected_code.text.strip():
print("请先选择一段代码片段再执行此命令。")
return
# print(selected_code)
selected_text = selected_code.text
project_root_path = find_project_root(selected_code.abspath)
print(f"项目根目录: {project_root_path}\n\n")
# 步骤2: 分析代码片段中缺少定义的符号
with Step("分析代码中缺少定义的符号..."):
try:
analysis_result = analyze_missing_symbols(code=selected_code.text, task=rafact_task)
missing_symbols = analysis_result # 直接获取返回的列表
if not missing_symbols:
print("没有找到缺少定义的符号。")
return
ide_service.ide_logging("info", f"找到 {len(missing_symbols)} 个可能缺少定义的符号")
except Exception as e:
ide_service.ide_logging("error", f"分析符号时出错: {str(e)}")
print(f"分析代码时出错: {str(e)}")
return
current_filepath = selected_code.abspath
base_line = selected_code.range.start.line
# range = "abspath='/Users/boboyang/.chat/scripts/merico/symbol_resolver/command.py' range=line=248 character=0 - line=248 character=24 text=' print(selected_code)'"
# 步骤3: 将分析结果转换为可处理的结构
with Step("处理符号信息..."):
symbols = []
code_lines = selected_code.text.splitlines()
for symbol_info in missing_symbols:
symbol_name = symbol_info["symbol"]
symbol_line_text = symbol_info["line"]
# 在代码中查找匹配的行
line_index = -1
for i, line in enumerate(code_lines):
if line == symbol_line_text:
line_index = i
break
if line_index == -1:
ide_service.ide_logging("warning", f"找不到符号 {symbol_name} 所在行")
continue
# 在行中查找符号位置
char_index = symbol_line_text.find(symbol_name)
if char_index == -1:
ide_service.ide_logging("warning", f"在行中找不到符号 {symbol_name}")
continue
# 构建符号信息
symbol = {
"name": symbol_name,
"line": base_line + line_index,
"character": char_index,
"type": "unknown", # 默认类型
"reason": symbol_info.get("reason", "未知原因")
}
symbols.append(symbol)
# 步骤3: 查找符号的实际定义
with Step("查找符号的实际定义..."):
symbol_definitions = {}
for symbol in symbols:
symbol_name = symbol["name"]
symbol_line = symbol.get("line", 0)
symbol_char = symbol.get("character", 0)
symbol_type = symbol.get("type", "unknown")
definitions = get_symbol_definition(
selected_code.abspath,
symbol_line,
symbol_char,
symbol_name,
symbol_type,
project_root_path
)
symbol_definitions[symbol_name] = definitions
# 计算每个文件被引用次数
files_ref_counts = {}
# 当前选中代码文件默认计算100
files_ref_counts[selected_code.abspath] = 100
for symbol in symbol_definitions:
for definition in symbol_definitions[symbol]:
if definition[0] not in files_ref_counts:
files_ref_counts[definition[0]] = 0
files_ref_counts[definition[0]] += 1
context_code = ""
for filepath, ref_count in files_ref_counts.items():
if ref_count > 0:
with open(filepath, "r") as f:
context_code += f"文件名:{filepath}\n\n"
context_code += "文件内容:\n"
context_code += f.read()
context_code += "\n\n"
# rewrite
response = ai_rewrite(question=question, selected_text=selected_text)
response = ai_rewrite(question=question, selected_text=selected_text, context_code=context_code)
if not response:
sys.exit(1)