refactor: Update rewrite.py with Chinese prompts and symbol analysis
- Translate prompts to Chinese for better localization - Add symbol analysis functionality to detect missing definitions - Implement context code gathering for better refactoring results - Add helper functions for project root detection and symbol definition lookup
This commit is contained in:
parent
f4a4dbc529
commit
021efac1d9
@ -2,9 +2,12 @@ import os
|
|||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from devchat.llm import chat
|
from typing import List, Dict, Any, Optional, Tuple
|
||||||
|
|
||||||
|
from devchat.llm import chat, chat_json
|
||||||
|
|
||||||
from lib.ide_service import IDEService
|
from lib.ide_service import IDEService
|
||||||
|
from lib.chatmark import Step, Form, TextEditor
|
||||||
|
|
||||||
|
|
||||||
def get_selected_code():
|
def get_selected_code():
|
||||||
@ -37,21 +40,33 @@ def get_selected_code():
|
|||||||
|
|
||||||
|
|
||||||
REWRITE_PROMPT = prompt = """
|
REWRITE_PROMPT = prompt = """
|
||||||
Your task is:
|
你是一个代码重构专家,你的任务是根据用户的需求重写代码。你需要根据用户的需求,重写代码,并保证代码的语法正确性和逻辑正确性。
|
||||||
|
你的重构目标:
|
||||||
{question}
|
{question}
|
||||||
Following the task requirements, modify only the selected portion of the code. \
|
|
||||||
Please ensure that the revised code segment maintains the same indentation as the \
|
待重构的代码:
|
||||||
selected code to seamlessly integrate with the existing code structure and maintain \
|
{selected_text}
|
||||||
correct syntax. Just refactor the selected code. Keep all other information as it is. \
|
|
||||||
Here is the relevant context \
|
项目中其他相关上下文代码:
|
||||||
information for your reference:
|
{context_code}
|
||||||
1. selected code info: {selected_text}
|
|
||||||
|
围绕重构目标对代码进行重构,不要进行重构目标之外的其他代码优化修改。
|
||||||
|
输出重构后的代码,代码需要用markdown代码块包裹,代码需要与原重构代码保持同样的缩进格式,并且代码块的语言类型需要与原重构代码保持一致,例如:
|
||||||
|
```python
|
||||||
|
...
|
||||||
|
```
|
||||||
|
以及
|
||||||
|
```java
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
只输出最终重构后代码,不要输出其他任何内容。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@chat(prompt=REWRITE_PROMPT, stream_out=True)
|
@chat(prompt=REWRITE_PROMPT, stream_out=True)
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def ai_rewrite(question, selected_text):
|
def ai_rewrite(question, selected_text, context_code):
|
||||||
"""
|
"""
|
||||||
call ai to rewrite selected code
|
call ai to rewrite selected code
|
||||||
"""
|
"""
|
||||||
@ -113,13 +128,373 @@ def replace_selected(new_code):
|
|||||||
file.write(modified_text)
|
file.write(modified_text)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Prompt asking the model which symbols used in the selected code lack a
# visible definition. The text carries {task} and {code} placeholders, so
# literal braces are doubled ({{ }}) -- presumably it is rendered with
# str.format-style substitution (confirm against devchat.llm.chat_json).
# Backslashes in the JSON example are doubled on purpose: with a single
# backslash Python collapses \" to ", and the example shown to the model
# would no longer be valid JSON.
SYMBOL_ANALYSIS_PROMPT = """
角色:你是一个资深的程序工程师,擅长代码分析与代码重构。

重构目标:
{task}

当前用户选择代码片段:
{code}

你的任务:
针对当前已知代码,找出所有可能缺少定义的符号(变量、函数、类等)。这些符号在代码中被使用,但可能没有在当前上下文中定义。

输出要求,输出markdown代码块形式的JSON块,形式为:
```json
[
    {{
        "symbol": "<identifier>",
        "line": "<identifier所在行完整代码,包含缩进空格>",
        "value": "<影响重要度0-1之间>",
        "reason": "<影响描述>"
    }},
    ......
]
```

示例:
针对以下代码片段:
```python
def fun1():
    fun2("hello")
    fun3(
        "hello"
    )
```
进行重构,希望fun1执行打印输出hello DevChat。

此时由于缺少fun2,fun3的函数定义,所以并不清楚fun1的完整执行逻辑。所以输出:
```json
[
    {{
        "symbol": "fun2",
        "line": "    fun2(\\"hello\\")",
        "value": 0.8,
        "reason": "不能确定fun2中具体代码逻辑,是否会打印输出参数'hello',如果是,那么重构只需要修改fun1中参数信息"
    }},
    {{
        "symbol": "fun3",
        "line": "    fun3(",
        "value": 0.8,
        "reason": "不能确定fun3中具体代码逻辑,是否会打印输出参数'hello',如果是,那么重构只需要修改fun1中参数信息"
    }}
]
```

请确保返回的JSON格式正确,只包含实际缺少定义的符号。只输出最终结果JSON,不要输出其他解释描述。
"""
|
||||||
|
|
||||||
|
# Prompt asking the model for usage advice about a set of symbols, given
# their definitions ({symbol_definitions}) and the code that uses them
# ({original_code}). The reply is requested as JSON with a top-level
# "suggestions" array.
SYMBOL_USAGE_PROMPT = """
基于以下符号的定义,生成正确使用这些符号的建议。

符号定义:
{symbol_definitions}

原始代码片段:
```
{original_code}
```

请提供以下内容:
1. 对每个符号的正确使用方法的解释
2. 在原始代码中可能存在的符号使用错误
3. 修复建议
4. 示例代码,展示如何正确使用这些符号

请以JSON格式返回,包含以下字段:
1. "suggestions": 包含所有符号使用建议的数组,每个建议包含:
   - "symbol": 符号名称
   - "explanation": 符号的正确使用方法解释
   - "errors": 在原始代码中可能存在的使用错误
   - "fix": 修复建议
   - "example": 示例代码

请确保返回的JSON格式正确。
"""
|
||||||
|
|
||||||
|
@chat_json(prompt=SYMBOL_ANALYSIS_PROMPT)
def analyze_missing_symbols(code: str, task: str) -> List[Dict[str, Any]]:
    """
    Ask the model which symbols used in a code fragment lack a visible definition.

    The body is intentionally empty: the ``chat_json`` decorator supplies the
    behaviour, using ``SYMBOL_ANALYSIS_PROMPT`` with the keyword arguments
    below filling its ``{code}`` and ``{task}`` placeholders (presumably by
    name -- confirm against devchat.llm.chat_json).

    Args:
        code: The code fragment to analyse.
        task: The refactoring goal the analysis should be relevant to.

    Returns:
        The parsed model reply. The prompt requests a JSON array whose items
        carry the keys "symbol", "line", "value" and "reason", and the caller
        in this file iterates over the result as a list.
    """
|
||||||
|
|
||||||
|
@chat_json(prompt=SYMBOL_USAGE_PROMPT)
def generate_symbol_usage_suggestions(
    symbol_definitions: str,
    original_code: str,
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Ask the model for advice on using a set of symbols correctly.

    This is a stub whose behaviour comes from the ``chat_json`` decorator and
    ``SYMBOL_USAGE_PROMPT``; the function itself contains no logic.

    Args:
        symbol_definitions: Textual rendering of the symbol definitions.
        original_code: The code fragment in which the symbols are used.

    Returns:
        A dictionary holding the usage suggestions (the prompt asks for a
        "suggestions" array).
    """
|
||||||
|
|
||||||
|
def get_symbol_definition(abspath: str, line: int, character: int, symbol_name: str, symbol_type: str, project_root_path: str) -> List[Tuple]:
    """
    Find the in-project locations where a symbol is defined.

    Both the definition lookup and the type-definition lookup of the IDE
    service are queried, at the first and at the last character of the
    identifier, and the answers are merged.

    Args:
        abspath: Absolute path of the file that uses the symbol.
        line: Line of the symbol occurrence, passed to the IDE service as-is.
        character: Column of the first character of the symbol.
        symbol_name: The identifier; only its length is used, to locate the
            last character of the occurrence.
        symbol_type: Currently unused; kept so existing callers keep working.
        project_root_path: Only locations inside this directory are returned.
            An empty value disables the filter.

    Returns:
        A list of ``(abspath, start_line, end_line)`` tuples without
        duplicates, in the order the IDE service reported them. The list is
        empty when nothing inside the project was found.
    """
    ide_service = IDEService()

    # Probe both ends of the identifier. Clamp so an empty name never probes
    # the column *before* the symbol, and skip the second probe when it would
    # repeat the first one.
    last_character = character + max(len(symbol_name) - 1, 0)
    probe_columns = [character]
    if last_character != character:
        probe_columns.append(last_character)

    candidates = []
    for column in probe_columns:
        candidates.extend(ide_service.find_type_def_locations(abspath, line, column))
        candidates.extend(ide_service.find_def_locations(abspath, line, column))

    # Compare against the root *with* a trailing separator: a bare
    # startswith() would also accept sibling directories such as
    # "/work/project2" for the root "/work/project".
    root_prefix = ""
    if project_root_path:
        root_prefix = os.path.join(os.path.normpath(project_root_path), "")

    locations = []
    seen = set()
    for location in candidates:
        if root_prefix and not os.path.normpath(location.abspath).startswith(root_prefix):
            continue
        key = (location.abspath, location.range.start.line, location.range.end.line)
        if key not in seen:
            seen.add(key)
            locations.append(key)

    return locations
|
||||||
|
|
||||||
|
|
||||||
|
def format_symbol_results(symbols: List[Dict[str, Any]], definitions: Dict[str, str]) -> str:
    """
    Render the symbol analysis as a Markdown report.

    Args:
        symbols: Symbol records; each one is read through the keys "name",
            "type", "line", "character" and "reason". "line" and "character"
            are shown incremented by one.
        definitions: Maps a symbol name to the text of its definition. A
            missing or empty entry is reported as "definition not found".

    Returns:
        The Markdown text of the report.
    """
    heading = "## 符号分析结果\n\n"
    if not symbols:
        return heading + "没有找到缺少定义的符号。"

    pieces = [heading, f"找到 {len(symbols)} 个可能缺少定义的符号:\n\n"]

    for number, entry in enumerate(symbols, start=1):
        name = entry['name']
        pieces.append(f"### {number}. {name} ({entry['type']})\n\n")
        pieces.append(f"- 位置: 第{entry['line'] + 1}行,第{entry['character'] + 1}列\n")
        pieces.append(f"- 原因: {entry['reason']}\n\n")

        found = definitions[name] if name in definitions else None
        if found:
            pieces.append("#### 找到的定义:\n\n")
            pieces.append(f"{found}\n\n")
        else:
            pieces.append("#### 未找到定义\n\n")
            pieces.append("无法在当前项目中找到此符号的定义。可能是外部库、内置函数或拼写错误。\n\n")

    return "".join(pieces)
|
||||||
|
|
||||||
|
def format_usage_suggestions(suggestions: List[Dict[str, Any]]) -> str:
    """
    Render symbol usage suggestions as Markdown.

    Args:
        suggestions: Suggestion records. "symbol" and "explanation" are
            required; "errors", "fix" and "example" are optional and only
            rendered when present and non-empty.

    Returns:
        The Markdown text, or an empty string when there are no suggestions.
    """
    if not suggestions:
        return ""

    # Optional sections that share the same "title + value" layout.
    optional_sections = (
        ('errors', "**可能存在的错误**:\n"),
        ('fix', "**修复建议**:\n"),
    )

    pieces = ["## 符号使用建议\n\n"]
    for number, item in enumerate(suggestions, start=1):
        pieces.append(f"### {number}. {item['symbol']}\n\n")
        pieces.append(f"**正确使用方法**:\n{item['explanation']}\n\n")

        for field, title in optional_sections:
            content = item.get(field)
            if content:
                pieces.append(f"{title}{content}\n\n")

        sample = item.get('example')
        if sample:
            # Plain concatenation on purpose: a non-string example raises
            # TypeError instead of being silently converted.
            pieces.append("**示例代码**:\n```python\n" + sample + "\n```\n\n")

    return "".join(pieces)
|
||||||
|
|
||||||
|
|
||||||
|
def find_project_root(file_path: str) -> str:
    """
    Determine the project root for a file.

    Starting at the directory that contains the file, walk upwards and return
    the first directory that holds a ``.git`` or ``.svn`` entry.

    Args:
        file_path: Path of the file; a relative path is made absolute first.

    Returns:
        The directory that contains the version-control marker, or the
        directory of the file itself when no marker is found.
    """
    absolute_path = file_path if os.path.isabs(file_path) else os.path.abspath(file_path)
    start_dir = os.path.dirname(absolute_path)
    markers = ('.git', '.svn')

    candidate = start_dir
    # The walk stops before the POSIX root "/", which is never inspected.
    while candidate and candidate != '/':
        if any(os.path.exists(os.path.join(candidate, marker)) for marker in markers):
            return candidate

        parent = os.path.dirname(candidate)
        if parent == candidate:
            # dirname() no longer moves up (e.g. a Windows drive root).
            break
        candidate = parent

    return start_dir
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
ide_service = IDEService()
|
||||||
question = sys.argv[1]
|
question = sys.argv[1]
|
||||||
|
rafact_task = sys.argv[1]
|
||||||
# prepare code
|
# prepare code
|
||||||
selected_text = get_selected_code()
|
|
||||||
|
|
||||||
|
# 步骤1: 获取用户选中的代码片段
|
||||||
|
with Step("获取选中的代码片段..."):
|
||||||
|
selected_code = ide_service.get_selected_range()
|
||||||
|
|
||||||
|
if not selected_code or not selected_code.text.strip():
|
||||||
|
print("请先选择一段代码片段再执行此命令。")
|
||||||
|
return
|
||||||
|
# print(selected_code)
|
||||||
|
selected_text = selected_code.text
|
||||||
|
project_root_path = find_project_root(selected_code.abspath)
|
||||||
|
print(f"项目根目录: {project_root_path}\n\n")
|
||||||
|
|
||||||
|
|
||||||
|
# 步骤2: 分析代码片段中缺少定义的符号
|
||||||
|
with Step("分析代码中缺少定义的符号..."):
|
||||||
|
try:
|
||||||
|
analysis_result = analyze_missing_symbols(code=selected_code.text, task=rafact_task)
|
||||||
|
missing_symbols = analysis_result # 直接获取返回的列表
|
||||||
|
|
||||||
|
if not missing_symbols:
|
||||||
|
print("没有找到缺少定义的符号。")
|
||||||
|
return
|
||||||
|
|
||||||
|
ide_service.ide_logging("info", f"找到 {len(missing_symbols)} 个可能缺少定义的符号")
|
||||||
|
except Exception as e:
|
||||||
|
ide_service.ide_logging("error", f"分析符号时出错: {str(e)}")
|
||||||
|
print(f"分析代码时出错: {str(e)}")
|
||||||
|
return
|
||||||
|
|
||||||
|
current_filepath = selected_code.abspath
|
||||||
|
base_line = selected_code.range.start.line
|
||||||
|
# range = "abspath='/Users/boboyang/.chat/scripts/merico/symbol_resolver/command.py' range=line=248 character=0 - line=248 character=24 text=' print(selected_code)'"
|
||||||
|
|
||||||
|
# 步骤3: 将分析结果转换为可处理的结构
|
||||||
|
with Step("处理符号信息..."):
|
||||||
|
symbols = []
|
||||||
|
code_lines = selected_code.text.splitlines()
|
||||||
|
|
||||||
|
for symbol_info in missing_symbols:
|
||||||
|
symbol_name = symbol_info["symbol"]
|
||||||
|
symbol_line_text = symbol_info["line"]
|
||||||
|
|
||||||
|
# 在代码中查找匹配的行
|
||||||
|
line_index = -1
|
||||||
|
for i, line in enumerate(code_lines):
|
||||||
|
if line == symbol_line_text:
|
||||||
|
line_index = i
|
||||||
|
break
|
||||||
|
|
||||||
|
if line_index == -1:
|
||||||
|
ide_service.ide_logging("warning", f"找不到符号 {symbol_name} 所在行")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 在行中查找符号位置
|
||||||
|
char_index = symbol_line_text.find(symbol_name)
|
||||||
|
if char_index == -1:
|
||||||
|
ide_service.ide_logging("warning", f"在行中找不到符号 {symbol_name}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 构建符号信息
|
||||||
|
symbol = {
|
||||||
|
"name": symbol_name,
|
||||||
|
"line": base_line + line_index,
|
||||||
|
"character": char_index,
|
||||||
|
"type": "unknown", # 默认类型
|
||||||
|
"reason": symbol_info.get("reason", "未知原因")
|
||||||
|
}
|
||||||
|
|
||||||
|
symbols.append(symbol)
|
||||||
|
|
||||||
|
# 步骤3: 查找符号的实际定义
|
||||||
|
with Step("查找符号的实际定义..."):
|
||||||
|
symbol_definitions = {}
|
||||||
|
|
||||||
|
for symbol in symbols:
|
||||||
|
symbol_name = symbol["name"]
|
||||||
|
symbol_line = symbol.get("line", 0)
|
||||||
|
symbol_char = symbol.get("character", 0)
|
||||||
|
symbol_type = symbol.get("type", "unknown")
|
||||||
|
|
||||||
|
definitions = get_symbol_definition(
|
||||||
|
selected_code.abspath,
|
||||||
|
symbol_line,
|
||||||
|
symbol_char,
|
||||||
|
symbol_name,
|
||||||
|
symbol_type,
|
||||||
|
project_root_path
|
||||||
|
)
|
||||||
|
|
||||||
|
symbol_definitions[symbol_name] = definitions
|
||||||
|
|
||||||
|
# 计算每个文件被引用次数
|
||||||
|
files_ref_counts = {}
|
||||||
|
# 当前选中代码文件,默认计算100
|
||||||
|
files_ref_counts[selected_code.abspath] = 100
|
||||||
|
for symbol in symbol_definitions:
|
||||||
|
for definition in symbol_definitions[symbol]:
|
||||||
|
if definition[0] not in files_ref_counts:
|
||||||
|
files_ref_counts[definition[0]] = 0
|
||||||
|
files_ref_counts[definition[0]] += 1
|
||||||
|
|
||||||
|
context_code = ""
|
||||||
|
for filepath, ref_count in files_ref_counts.items():
|
||||||
|
if ref_count > 0:
|
||||||
|
with open(filepath, "r") as f:
|
||||||
|
context_code += f"文件名:{filepath}\n\n"
|
||||||
|
context_code += "文件内容:\n"
|
||||||
|
context_code += f.read()
|
||||||
|
context_code += "\n\n"
|
||||||
|
|
||||||
# rewrite
|
# rewrite
|
||||||
response = ai_rewrite(question=question, selected_text=selected_text)
|
response = ai_rewrite(question=question, selected_text=selected_text, context_code=context_code)
|
||||||
if not response:
|
if not response:
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user