2024-04-01 16:41:25 +08:00
|
|
|
|
import os
|
2024-04-01 16:55:10 +08:00
|
|
|
|
import re
|
2024-02-09 11:09:14 +08:00
|
|
|
|
import sys
|
2025-03-07 13:00:43 +08:00
|
|
|
|
from typing import Any, Dict, List, Tuple
|
2025-03-07 12:00:49 +08:00
|
|
|
|
|
|
|
|
|
from devchat.llm import chat, chat_json
|
2024-02-09 11:04:29 +08:00
|
|
|
|
|
2025-03-07 13:00:43 +08:00
|
|
|
|
from lib.chatmark import Step
|
2024-05-09 21:55:23 +08:00
|
|
|
|
from lib.ide_service import IDEService
|
|
|
|
|
|
2024-02-09 11:04:29 +08:00
|
|
|
|
|
|
|
|
|
def get_selected_code():
|
|
|
|
|
"""
|
|
|
|
|
Retrieves the selected lines of code from the user's selection.
|
|
|
|
|
|
|
|
|
|
This function extracts the text selected by the user in their IDE or text editor.
|
|
|
|
|
If no text has been selected, it prints an error message to stderr and exits the
|
|
|
|
|
program with a non-zero status indicating failure.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
dict: A dictionary containing the key 'selectedText' with the selected text
|
|
|
|
|
as its value. If no text is selected, the program exits.
|
|
|
|
|
"""
|
|
|
|
|
selected_data = IDEService().get_selected_range().dict()
|
|
|
|
|
|
|
|
|
|
miss_selected_error = "Please select some text."
|
2024-04-01 16:46:47 +08:00
|
|
|
|
if selected_data["range"]["start"] == selected_data["range"]["end"]:
|
2024-04-01 16:41:25 +08:00
|
|
|
|
readme_path = os.path.join(os.path.dirname(__file__), "README.md")
|
|
|
|
|
if os.path.exists(readme_path):
|
|
|
|
|
with open(readme_path, "r", encoding="utf-8") as f:
|
|
|
|
|
readme_text = f.read()
|
|
|
|
|
print(readme_text)
|
|
|
|
|
sys.exit(0)
|
|
|
|
|
|
2024-02-09 11:04:29 +08:00
|
|
|
|
print(miss_selected_error, file=sys.stderr, flush=True)
|
|
|
|
|
sys.exit(-1)
|
2024-02-09 11:09:14 +08:00
|
|
|
|
|
2024-02-09 11:04:29 +08:00
|
|
|
|
return selected_data
|
|
|
|
|
|
2024-02-09 11:09:14 +08:00
|
|
|
|
|
2024-02-09 11:04:29 +08:00
|
|
|
|
REWRITE_PROMPT = prompt = """
|
2025-03-07 12:00:49 +08:00
|
|
|
|
你是一个代码重构专家,你的任务是根据用户的需求重写代码。你需要根据用户的需求,重写代码,并保证代码的语法正确性和逻辑正确性。
|
|
|
|
|
你的重构目标:
|
2024-02-09 11:04:29 +08:00
|
|
|
|
{question}
|
2025-03-07 12:00:49 +08:00
|
|
|
|
|
|
|
|
|
待重构的代码:
|
|
|
|
|
{selected_text}
|
|
|
|
|
|
|
|
|
|
项目中其他相关上下文代码:
|
|
|
|
|
{context_code}
|
|
|
|
|
|
|
|
|
|
围绕重构目标对代码进行重构,不要进行重构目标之外的其他代码优化修改。
|
|
|
|
|
输出重构后的代码,代码需要用markdown代码块包裹,代码需要与原重构代码保持同样的缩进格式,并且代码块的语言类型需要与原重构代码保持一致,例如:
|
|
|
|
|
```python
|
|
|
|
|
...
|
|
|
|
|
```
|
|
|
|
|
以及
|
|
|
|
|
```java
|
|
|
|
|
...
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
只输出最终重构后代码,不要输出其他任何内容。
|
2024-02-09 11:04:29 +08:00
|
|
|
|
"""
|
2024-02-09 11:09:14 +08:00
|
|
|
|
|
|
|
|
|
|
2025-03-11 13:43:17 +08:00
|
|
|
|
@chat(prompt=REWRITE_PROMPT, stream_out=False)
|
2024-02-09 11:04:29 +08:00
|
|
|
|
# pylint: disable=unused-argument
|
2025-03-07 12:00:49 +08:00
|
|
|
|
def ai_rewrite(question, selected_text, context_code):
|
2024-02-09 11:04:29 +08:00
|
|
|
|
"""
|
|
|
|
|
call ai to rewrite selected code
|
|
|
|
|
"""
|
2024-02-09 11:09:14 +08:00
|
|
|
|
pass # pylint: disable=unnecessary-pass
|
|
|
|
|
|
2024-02-09 11:04:29 +08:00
|
|
|
|
|
|
|
|
|
def extract_markdown_block(text):
|
|
|
|
|
"""
|
|
|
|
|
Extracts the first Markdown code block from the given text without the language specifier.
|
|
|
|
|
|
|
|
|
|
:param text: A string containing Markdown text
|
|
|
|
|
:return: The content of the first Markdown code block, or None if not found
|
|
|
|
|
"""
|
|
|
|
|
pattern = r"```(?:\w+)?\s*\n(.*?)\n```"
|
|
|
|
|
match = re.search(pattern, text, re.DOTALL)
|
|
|
|
|
|
|
|
|
|
if match:
|
|
|
|
|
block_content = match.group(1)
|
|
|
|
|
return block_content
|
|
|
|
|
else:
|
2024-05-11 16:45:49 +08:00
|
|
|
|
# whether exist ```language?
|
|
|
|
|
if text.find("```"):
|
|
|
|
|
return None
|
2024-02-09 11:04:29 +08:00
|
|
|
|
return text
|
|
|
|
|
|
2024-02-09 11:09:14 +08:00
|
|
|
|
|
2024-02-09 11:04:29 +08:00
|
|
|
|
def replace_selected(new_code):
|
|
|
|
|
selected_data = IDEService().get_selected_code().dict()
|
|
|
|
|
select_file = selected_data["abspath"]
|
|
|
|
|
select_range = selected_data["range"] # [start_line, start_col, end_line, end_col]
|
|
|
|
|
|
|
|
|
|
# Read the file
|
|
|
|
|
with open(select_file, "r", encoding="utf-8") as file:
|
|
|
|
|
lines = file.readlines()
|
|
|
|
|
lines.append("\n")
|
|
|
|
|
|
|
|
|
|
# Modify the selected lines
|
|
|
|
|
start_line = select_range["start"]["line"]
|
|
|
|
|
start_col = select_range["start"]["character"]
|
|
|
|
|
end_line = select_range["end"]["line"]
|
|
|
|
|
end_col = select_range["end"]["character"]
|
|
|
|
|
|
|
|
|
|
# If the selection spans multiple lines, handle the last line and delete the lines in between
|
|
|
|
|
if start_line != end_line:
|
|
|
|
|
lines[start_line] = lines[start_line][:start_col] + new_code
|
|
|
|
|
# Append the text after the selection on the last line
|
|
|
|
|
lines[start_line] += lines[end_line][end_col:]
|
|
|
|
|
# Delete the lines between start_line and end_line
|
|
|
|
|
del lines[start_line + 1 : end_line + 1]
|
|
|
|
|
else:
|
|
|
|
|
# If the selection is within a single line, remove the selected text
|
|
|
|
|
lines[start_line] = lines[start_line][:start_col] + new_code + lines[end_line][end_col:]
|
|
|
|
|
|
|
|
|
|
# Combine everything back together
|
|
|
|
|
modified_text = "".join(lines)
|
|
|
|
|
|
|
|
|
|
# Write the changes back to the file
|
|
|
|
|
with open(select_file, "w", encoding="utf-8") as file:
|
|
|
|
|
file.write(modified_text)
|
|
|
|
|
|
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
# 定义用于分析代码中缺少定义的符号的提示
|
|
|
|
|
SYMBOL_ANALYSIS_PROMPT = """
|
|
|
|
|
角色:你是一个资深的程序工程师,擅长代码分析与代码重构。
|
|
|
|
|
|
|
|
|
|
重构目标:
|
|
|
|
|
{task}
|
|
|
|
|
|
|
|
|
|
当前用户选择代码片段:
|
|
|
|
|
{code}
|
|
|
|
|
|
|
|
|
|
你的任务:
|
|
|
|
|
针对当前已知代码,找出所有可能缺少定义的符号(变量、函数、类等)。这些符号在代码中被使用,但可能没有在当前上下文中定义。
|
|
|
|
|
|
|
|
|
|
输出要求,输出markdown代码块形式的JSON块,形式为:
|
|
|
|
|
```json
|
|
|
|
|
[
|
|
|
|
|
{{
|
|
|
|
|
"symbol": "<identifier>",
|
|
|
|
|
"line": "<identifier所在行完整代码,包含缩进空格>",
|
|
|
|
|
"value": "<影响重要度0-1之间>",
|
|
|
|
|
"reason": "<影响描述>"
|
|
|
|
|
}},
|
|
|
|
|
......
|
|
|
|
|
]
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
示例:
|
|
|
|
|
针对以下代码片段:
|
|
|
|
|
```python
|
|
|
|
|
def fun1():
|
|
|
|
|
fun2("hello")
|
|
|
|
|
fun3(
|
|
|
|
|
"hello"
|
|
|
|
|
)
|
|
|
|
|
```
|
|
|
|
|
进行重构,希望fun1执行打印输出hello DevChat。
|
|
|
|
|
|
|
|
|
|
此时由于缺少fun2,fun3的函数定义,所以并不清楚fun1的完整执行逻辑。所以输出:
|
|
|
|
|
```json
|
|
|
|
|
[
|
|
|
|
|
{{
|
|
|
|
|
"symbol": "fun2",
|
|
|
|
|
"line": " fun2(\"hello\")",
|
|
|
|
|
"value": 0.8,
|
2025-03-07 13:00:43 +08:00
|
|
|
|
"reason": "不能确定fun2中具体代码逻辑,是否会打印输出参数'hello',\
|
|
|
|
|
如果是,那么重构只需要修改fun1中参数信息"
|
2025-03-07 12:00:49 +08:00
|
|
|
|
}},
|
|
|
|
|
{{
|
|
|
|
|
"symbol": "fun3",
|
|
|
|
|
"line": " fun3(",
|
|
|
|
|
"value": 0.8,
|
2025-03-07 13:00:43 +08:00
|
|
|
|
"reason": "不能确定fun3中具体代码逻辑,是否会打印输出参数'hello',\
|
|
|
|
|
如果是,那么重构只需要修改fun1中参数信息"
|
2025-03-07 12:00:49 +08:00
|
|
|
|
}}
|
|
|
|
|
]
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
请确保返回的JSON格式正确,只包含实际缺少定义的符号。只输出最终结果JSON,不要输出其他解释描述。
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# 定义用于生成符号使用建议的提示
|
|
|
|
|
SYMBOL_USAGE_PROMPT = """
|
|
|
|
|
基于以下符号的定义,生成正确使用这些符号的建议。
|
|
|
|
|
|
|
|
|
|
符号定义:
|
|
|
|
|
{symbol_definitions}
|
|
|
|
|
|
|
|
|
|
原始代码片段:
|
|
|
|
|
```
|
|
|
|
|
{original_code}
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
请提供以下内容:
|
|
|
|
|
1. 对每个符号的正确使用方法的解释
|
|
|
|
|
2. 在原始代码中可能存在的符号使用错误
|
|
|
|
|
3. 修复建议
|
|
|
|
|
4. 示例代码,展示如何正确使用这些符号
|
|
|
|
|
|
|
|
|
|
请以JSON格式返回,包含以下字段:
|
|
|
|
|
1. "suggestions": 包含所有符号使用建议的数组,每个建议包含:
|
|
|
|
|
- "symbol": 符号名称
|
|
|
|
|
- "explanation": 符号的正确使用方法解释
|
|
|
|
|
- "errors": 在原始代码中可能存在的使用错误
|
|
|
|
|
- "fix": 修复建议
|
|
|
|
|
- "example": 示例代码
|
|
|
|
|
|
|
|
|
|
请确保返回的JSON格式正确。
|
|
|
|
|
"""
|
|
|
|
|
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
@chat_json(prompt=SYMBOL_ANALYSIS_PROMPT)
|
|
|
|
|
def analyze_missing_symbols(code: str, task: str) -> Dict[str, List[Dict[str, Any]]]:
|
|
|
|
|
"""
|
|
|
|
|
使用大模型分析代码片段中缺少定义的符号
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
Args:
|
|
|
|
|
code: 需要分析的代码片段
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
Returns:
|
|
|
|
|
包含缺少定义的符号列表的字典
|
|
|
|
|
"""
|
|
|
|
|
pass
|
|
|
|
|
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
@chat_json(prompt=SYMBOL_USAGE_PROMPT)
|
2025-03-07 13:00:43 +08:00
|
|
|
|
def generate_symbol_usage_suggestions(
|
|
|
|
|
symbol_definitions: str, original_code: str
|
|
|
|
|
) -> Dict[str, List[Dict[str, Any]]]:
|
2025-03-07 12:00:49 +08:00
|
|
|
|
"""
|
|
|
|
|
基于符号定义生成符号使用建议
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
Args:
|
|
|
|
|
symbol_definitions: 符号定义的文本表示
|
|
|
|
|
original_code: 原始代码片段
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
Returns:
|
|
|
|
|
包含符号使用建议的字典
|
|
|
|
|
"""
|
|
|
|
|
pass
|
|
|
|
|
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
|
|
|
|
def get_symbol_definition(
|
|
|
|
|
abspath: str,
|
|
|
|
|
line: int,
|
|
|
|
|
character: int,
|
|
|
|
|
symbol_name: str,
|
|
|
|
|
symbol_type: str,
|
|
|
|
|
project_root_path: str,
|
|
|
|
|
) -> List[Tuple]:
|
2025-03-07 12:00:49 +08:00
|
|
|
|
"""
|
|
|
|
|
获取符号定义的代码
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
Args:
|
|
|
|
|
abspath: 文件的绝对路径
|
|
|
|
|
line: 符号所在行
|
|
|
|
|
character: 符号所在列
|
|
|
|
|
symbol_type: 符号类型
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
Returns:
|
|
|
|
|
符号定义的代码,如果找不到则返回None
|
|
|
|
|
"""
|
|
|
|
|
ide_service = IDEService()
|
|
|
|
|
locations = []
|
|
|
|
|
has_visited = set()
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
# 根据符号类型选择合适的查找方法
|
|
|
|
|
locations1 = ide_service.find_type_def_locations(abspath, line, character)
|
|
|
|
|
locations2 = ide_service.find_def_locations(abspath, line, character)
|
2025-03-07 13:00:43 +08:00
|
|
|
|
locations3 = ide_service.find_type_def_locations(
|
|
|
|
|
abspath, line, character + len(symbol_name) - 1
|
|
|
|
|
)
|
2025-03-07 12:00:49 +08:00
|
|
|
|
locations4 = ide_service.find_def_locations(abspath, line, character + len(symbol_name) - 1)
|
|
|
|
|
for location in locations1 + locations2 + locations3 + locations4:
|
|
|
|
|
if not location.abspath.startswith(project_root_path):
|
|
|
|
|
continue
|
|
|
|
|
key = (location.abspath, location.range.start.line, location.range.end.line)
|
|
|
|
|
key_str = f"{location.abspath}:{location.range.start.line}:{location.range.end.line}"
|
|
|
|
|
if key_str not in has_visited:
|
|
|
|
|
has_visited.add(key_str)
|
|
|
|
|
locations.append(key)
|
|
|
|
|
|
|
|
|
|
return locations
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def format_symbol_results(symbols: List[Dict[str, Any]], definitions: Dict[str, str]) -> str:
|
|
|
|
|
"""
|
|
|
|
|
格式化符号分析结果为Markdown格式
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
Args:
|
|
|
|
|
symbols: 符号列表
|
|
|
|
|
definitions: 符号定义字典
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
Returns:
|
|
|
|
|
格式化后的Markdown文本
|
|
|
|
|
"""
|
|
|
|
|
result = "## 符号分析结果\n\n"
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
if not symbols:
|
|
|
|
|
return result + "没有找到缺少定义的符号。"
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
result += f"找到 {len(symbols)} 个可能缺少定义的符号:\n\n"
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
for i, symbol in enumerate(symbols):
|
2025-03-11 14:24:28 +08:00
|
|
|
|
result += f"### {i + 1}. {symbol['name']} ({symbol['type']})\n\n"
|
2025-03-07 12:00:49 +08:00
|
|
|
|
result += f"- 位置: 第{symbol['line'] + 1}行,第{symbol['character'] + 1}列\n"
|
|
|
|
|
result += f"- 原因: {symbol['reason']}\n\n"
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
|
|
|
|
if symbol["name"] in definitions and definitions[symbol["name"]]:
|
2025-03-07 12:00:49 +08:00
|
|
|
|
result += "#### 找到的定义:\n\n"
|
|
|
|
|
result += f"{definitions[symbol['name']]}\n\n"
|
|
|
|
|
else:
|
|
|
|
|
result += "#### 未找到定义\n\n"
|
|
|
|
|
result += "无法在当前项目中找到此符号的定义。可能是外部库、内置函数或拼写错误。\n\n"
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
return result
|
|
|
|
|
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
def format_usage_suggestions(suggestions: List[Dict[str, Any]]) -> str:
|
|
|
|
|
"""
|
|
|
|
|
格式化符号使用建议为Markdown格式
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
Args:
|
|
|
|
|
suggestions: 符号使用建议列表
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
Returns:
|
|
|
|
|
格式化后的Markdown文本
|
|
|
|
|
"""
|
|
|
|
|
if not suggestions:
|
|
|
|
|
return ""
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
result = "## 符号使用建议\n\n"
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
for i, suggestion in enumerate(suggestions):
|
2025-03-11 14:24:28 +08:00
|
|
|
|
result += f"### {i + 1}. {suggestion['symbol']}\n\n"
|
2025-03-07 12:00:49 +08:00
|
|
|
|
result += f"**正确使用方法**:\n{suggestion['explanation']}\n\n"
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
|
|
|
|
if suggestion.get("errors"):
|
2025-03-07 12:00:49 +08:00
|
|
|
|
result += f"**可能存在的错误**:\n{suggestion['errors']}\n\n"
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
|
|
|
|
if suggestion.get("fix"):
|
2025-03-07 12:00:49 +08:00
|
|
|
|
result += f"**修复建议**:\n{suggestion['fix']}\n\n"
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
|
|
|
|
if suggestion.get("example"):
|
|
|
|
|
result += "**示例代码**:\n```python\n" + suggestion["example"] + "\n```\n\n"
|
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def find_project_root(file_path: str) -> str:
|
|
|
|
|
"""
|
|
|
|
|
根据文件路径查找项目根目录
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
通过向上遍历目录,查找 .git 或 .svn 目录来确定项目根目录
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
Args:
|
|
|
|
|
file_path: 文件的绝对路径
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
Returns:
|
|
|
|
|
项目根目录的绝对路径,如果找不到则返回原始文件所在目录
|
|
|
|
|
"""
|
|
|
|
|
if not os.path.isabs(file_path):
|
|
|
|
|
file_path = os.path.abspath(file_path)
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
current_dir = os.path.dirname(file_path)
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
# 向上遍历目录,直到找到包含 .git 或 .svn 的目录,或者到达根目录
|
2025-03-07 13:00:43 +08:00
|
|
|
|
while current_dir and current_dir != "/":
|
2025-03-07 12:00:49 +08:00
|
|
|
|
# 检查当前目录是否包含 .git 或 .svn
|
2025-03-07 13:00:43 +08:00
|
|
|
|
if os.path.exists(os.path.join(current_dir, ".git")) or os.path.exists(
|
|
|
|
|
os.path.join(current_dir, ".svn")
|
|
|
|
|
):
|
2025-03-07 12:00:49 +08:00
|
|
|
|
return current_dir
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
# 向上移动一级目录
|
|
|
|
|
parent_dir = os.path.dirname(current_dir)
|
|
|
|
|
if parent_dir == current_dir: # 防止在Windows根目录下无限循环
|
|
|
|
|
break
|
|
|
|
|
current_dir = parent_dir
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
# 如果没有找到版本控制目录,返回文件所在目录
|
|
|
|
|
return os.path.dirname(file_path)
|
|
|
|
|
|
|
|
|
|
|
2024-02-09 11:04:29 +08:00
|
|
|
|
def main():
|
2025-03-07 12:00:49 +08:00
|
|
|
|
ide_service = IDEService()
|
2024-02-09 11:04:29 +08:00
|
|
|
|
question = sys.argv[1]
|
2025-03-07 12:00:49 +08:00
|
|
|
|
rafact_task = sys.argv[1]
|
2024-02-09 11:04:29 +08:00
|
|
|
|
# prepare code
|
2025-03-07 12:00:49 +08:00
|
|
|
|
|
|
|
|
|
# 步骤1: 获取用户选中的代码片段
|
|
|
|
|
with Step("获取选中的代码片段..."):
|
|
|
|
|
selected_code = ide_service.get_selected_range()
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
if not selected_code or not selected_code.text.strip():
|
|
|
|
|
print("请先选择一段代码片段再执行此命令。")
|
|
|
|
|
return
|
|
|
|
|
# print(selected_code)
|
|
|
|
|
selected_text = selected_code.text
|
|
|
|
|
project_root_path = find_project_root(selected_code.abspath)
|
|
|
|
|
print(f"项目根目录: {project_root_path}\n\n")
|
|
|
|
|
|
|
|
|
|
# 步骤2: 分析代码片段中缺少定义的符号
|
|
|
|
|
with Step("分析代码中缺少定义的符号..."):
|
|
|
|
|
try:
|
|
|
|
|
analysis_result = analyze_missing_symbols(code=selected_code.text, task=rafact_task)
|
|
|
|
|
missing_symbols = analysis_result # 直接获取返回的列表
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
if not missing_symbols:
|
|
|
|
|
print("没有找到缺少定义的符号。")
|
|
|
|
|
return
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
ide_service.ide_logging("info", f"找到 {len(missing_symbols)} 个可能缺少定义的符号")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
ide_service.ide_logging("error", f"分析符号时出错: {str(e)}")
|
|
|
|
|
print(f"分析代码时出错: {str(e)}")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
base_line = selected_code.range.start.line
|
|
|
|
|
|
|
|
|
|
# 步骤3: 将分析结果转换为可处理的结构
|
|
|
|
|
with Step("处理符号信息..."):
|
|
|
|
|
symbols = []
|
|
|
|
|
code_lines = selected_code.text.splitlines()
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
for symbol_info in missing_symbols:
|
|
|
|
|
symbol_name = symbol_info["symbol"]
|
|
|
|
|
symbol_line_text = symbol_info["line"]
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
# 在代码中查找匹配的行
|
|
|
|
|
line_index = -1
|
|
|
|
|
for i, line in enumerate(code_lines):
|
|
|
|
|
if line == symbol_line_text:
|
|
|
|
|
line_index = i
|
|
|
|
|
break
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
if line_index == -1:
|
|
|
|
|
ide_service.ide_logging("warning", f"找不到符号 {symbol_name} 所在行")
|
|
|
|
|
continue
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
# 在行中查找符号位置
|
|
|
|
|
char_index = symbol_line_text.find(symbol_name)
|
|
|
|
|
if char_index == -1:
|
|
|
|
|
ide_service.ide_logging("warning", f"在行中找不到符号 {symbol_name}")
|
|
|
|
|
continue
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
# 构建符号信息
|
|
|
|
|
symbol = {
|
|
|
|
|
"name": symbol_name,
|
|
|
|
|
"line": base_line + line_index,
|
|
|
|
|
"character": char_index,
|
|
|
|
|
"type": "unknown", # 默认类型
|
2025-03-07 13:00:43 +08:00
|
|
|
|
"reason": symbol_info.get("reason", "未知原因"),
|
2025-03-07 12:00:49 +08:00
|
|
|
|
}
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
symbols.append(symbol)
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
# 步骤3: 查找符号的实际定义
|
|
|
|
|
with Step("查找符号的实际定义..."):
|
|
|
|
|
symbol_definitions = {}
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
for symbol in symbols:
|
|
|
|
|
symbol_name = symbol["name"]
|
|
|
|
|
symbol_line = symbol.get("line", 0)
|
|
|
|
|
symbol_char = symbol.get("character", 0)
|
|
|
|
|
symbol_type = symbol.get("type", "unknown")
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
definitions = get_symbol_definition(
|
2025-03-07 13:00:43 +08:00
|
|
|
|
selected_code.abspath,
|
|
|
|
|
symbol_line,
|
|
|
|
|
symbol_char,
|
2025-03-07 12:00:49 +08:00
|
|
|
|
symbol_name,
|
|
|
|
|
symbol_type,
|
2025-03-07 13:00:43 +08:00
|
|
|
|
project_root_path,
|
2025-03-07 12:00:49 +08:00
|
|
|
|
)
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
symbol_definitions[symbol_name] = definitions
|
2025-03-07 13:00:43 +08:00
|
|
|
|
|
2025-03-07 12:00:49 +08:00
|
|
|
|
# 计算每个文件被引用次数
|
|
|
|
|
files_ref_counts = {}
|
|
|
|
|
# 当前选中代码文件,默认计算100
|
|
|
|
|
files_ref_counts[selected_code.abspath] = 100
|
|
|
|
|
for symbol in symbol_definitions:
|
|
|
|
|
for definition in symbol_definitions[symbol]:
|
|
|
|
|
if definition[0] not in files_ref_counts:
|
|
|
|
|
files_ref_counts[definition[0]] = 0
|
|
|
|
|
files_ref_counts[definition[0]] += 1
|
|
|
|
|
|
|
|
|
|
context_code = ""
|
|
|
|
|
for filepath, ref_count in files_ref_counts.items():
|
|
|
|
|
if ref_count > 0:
|
|
|
|
|
with open(filepath, "r") as f:
|
|
|
|
|
context_code += f"文件名:{filepath}\n\n"
|
|
|
|
|
context_code += "文件内容:\n"
|
|
|
|
|
context_code += f.read()
|
|
|
|
|
context_code += "\n\n"
|
2024-02-09 11:04:29 +08:00
|
|
|
|
|
|
|
|
|
# rewrite
|
2025-03-07 12:00:49 +08:00
|
|
|
|
response = ai_rewrite(question=question, selected_text=selected_text, context_code=context_code)
|
2024-06-18 18:21:25 +08:00
|
|
|
|
if not response:
|
|
|
|
|
sys.exit(1)
|
2024-02-09 11:09:14 +08:00
|
|
|
|
|
2024-02-09 11:04:29 +08:00
|
|
|
|
# apply new code to editor
|
|
|
|
|
new_code = extract_markdown_block(response)
|
2024-05-11 16:45:49 +08:00
|
|
|
|
if not new_code:
|
|
|
|
|
if IDEService().ide_language() == "zh":
|
|
|
|
|
print("\n\n大模型输出不完整,不能进行代码操作。\n\n")
|
|
|
|
|
else:
|
|
|
|
|
print("\n\nThe output of the LLM is incomplete and cannot perform code operations.\n\n")
|
|
|
|
|
sys.exit(0)
|
|
|
|
|
|
2024-02-09 11:04:29 +08:00
|
|
|
|
IDEService().diff_apply("", new_code)
|
|
|
|
|
|
|
|
|
|
sys.exit(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
main()
|