Refactor code and improve comments in main.py

This commit is contained in:
bobo 2024-05-11 15:17:18 +08:00
parent ddd14faffa
commit dd716e4475
3 changed files with 77 additions and 56 deletions

View File

@ -62,7 +62,6 @@ here is an error example:
// const posOffset = document.offsetAt(position);
// await outputAst(filepath, fileContent, posOffset);
// // await testTreesitterQuery(filepath, fileContent);
// const result2 = await findSimilarCodeBlock(filepath, fileContent, position.line, position.character);
// logger.channel()?.info("Result:", result2);
// if (1) {
// return [];
@ -89,7 +88,7 @@ output is :
In this example, string in code is changed, so this output is bad.
Output should format as code block.
"""
""",
},
{
"role": "user",
@ -98,7 +97,7 @@ file: a1.py
```
print("hello")
```
"""
""",
},
{
"role": "assistant",
@ -107,7 +106,7 @@ file: a1.py
# print hello
print("hello")
```
"""
""",
},
{
"role": "user",
@ -116,7 +115,7 @@ file: a1.py
```
print("hell\\nworld")
```
"""
""",
},
{
"role": "assistant",
@ -125,7 +124,7 @@ file: a1.py
# print hello world
print("hell\\nworld")
```
"""
""",
},
{
"role": "user",
@ -140,7 +139,7 @@ file: t2.ts
length: response!.code.length
});
```
"""
""",
},
{
"role": "assistant",
@ -159,7 +158,7 @@ file: t2.ts
length: response!.code.length
});
```
"""
""",
},
{
"role": "user",
@ -174,7 +173,7 @@ file: t2.ts
console.log("a is 2");
}
```
"""
""",
},
{
"role": "assistant",
@ -190,8 +189,8 @@ file: t2.ts
console.log("a is 2");
}
```
"""
}
""",
},
]
@ -209,18 +208,22 @@ def get_selected_code():
sys.exit(-1)
return selected_data
memory = FixSizeChatMemory(max_size=20, messages=MESSAGES_FEW_SHOT)
def get_prompt():
ide_language = IDEService().ide_language()
return PROMPT_ZH if ide_language == "zh" else PROMPT
@chat(prompt=get_prompt(), stream_out=True, memory=memory)
# pylint: disable=unused-argument
def add_comments(selected_text, file_path):
"""Call AI to rewrite selected code"""
pass # pylint: disable=unnecessary-pass
def extract_markdown_block(text):
"""Extracts the first Markdown code block from the given text without the language specifier."""
pattern = r"```(?:\w+)?\s*\n(.*?)\n```"
@ -233,21 +236,23 @@ def extract_markdown_block(text):
return None
return text
def remove_unnecessary_escapes(code_a, code_b):
code_copy = code_b # Create a copy of the original code
escape_chars = re.finditer(r"\\(.)", code_b)
remove_char_index = []
for match in escape_chars:
before = code_b[max(0, match.start()-4):match.start()]
after = code_b[match.start()+1:match.start()+5]
before = code_b[max(0, match.start() - 4) : match.start()]
after = code_b[match.start() + 1 : match.start() + 5]
substr = before + after
if substr in code_a:
remove_char_index.append(match.start())
remove_char_index.reverse()
for index in remove_char_index:
code_copy = code_copy[:index] + code_copy[index+1:]
code_copy = code_copy[:index] + code_copy[index + 1 :]
return code_copy
def main():
selected_text = get_selected_code()
file_path = selected_text.get("abspath", "")
@ -260,7 +265,9 @@ def main():
ide_lang = IDEService().ide_language()
error_msg = (
"\n\nThe output of the LLM is incomplete and cannot perform code operations.\n\n"
if ide_lang != "zh" else "\n\n大模型输出不完整,不能进行代码操作。\n\n")
if ide_lang != "zh"
else "\n\n大模型输出不完整,不能进行代码操作。\n\n"
)
print(error_msg)
sys.exit(0)
@ -268,5 +275,6 @@ def main():
IDEService().diff_apply("", new_code)
sys.exit(0)
if __name__ == "__main__":
main()

View File

@ -7,7 +7,6 @@ from devchat.ide.vscode_services import selected_lines, visible_lines
from devchat.llm import chat
from devchat.memory import FixSizeChatMemory
PROMPT = prompt = """
file: {file_path}
```
@ -24,6 +23,7 @@ PROMPT_ZH = prompt = """
输出内容使用中文我的母语为中文
"""
def get_prompt():
ide_language = IDEService().ide_language()
return PROMPT_ZH if ide_language == "zh" else PROMPT
@ -40,10 +40,12 @@ And just add the documents for the selected portion of the code.
Output documentation comment is format as code block.\
You must follow the following rules:
1. Output documentation comment in ```comment <documentation comments without code lines> ``` format.
2. Different languages have different comment symbols, please choose the correct comment symbol according to the file name.
1. Output documentation comment in ```comment <documentation comments without code lines> ``` \
format.
2. Different languages have different comment symbols, please choose the correct comment symbol \
according to the file name.
3. You must output ... to indicate the remaining code, output all code block can make more errors.
"""
""",
},
{
"role": "user",
@ -54,7 +56,7 @@ file: a1.py
print("hello")
print("world")
```
"""
""",
},
{
"role": "assistant",
@ -73,7 +75,7 @@ file: a1.py
print("hello")
...
```
"""
""",
},
{
"role": "user",
@ -84,7 +86,7 @@ file: t1.java
System.out.println("Hello, World!");
}
```
"""
""",
},
{
"role": "assistant",
@ -99,7 +101,7 @@ file: t1.java
public static void main(String[] args) {
...
```
"""
""",
},
{
"role": "user",
@ -116,7 +118,7 @@ def content_to_json(content):
except Exception as err:
raise err
```
"""
""",
},
{
"role": "assistant",
@ -138,8 +140,8 @@ def content_to_json(content):
try:
...
```
"""
}
""",
},
]
@ -174,6 +176,7 @@ def get_selected_code():
memory = FixSizeChatMemory(max_size=20, messages=MESSAGES_A)
@chat(prompt=get_prompt(), stream_out=True, memory=memory)
# pylint: disable=unused-argument
def add_docstring(selected_text, file_path):
@ -202,6 +205,7 @@ def extract_markdown_block(text):
return None
return text
def get_indent_level(text):
"""
Returns the indentation level of the given text.
@ -219,6 +223,7 @@ def get_indent_level(text):
break
return indent_level
def offset_indent_level(text, indent_level):
"""
Offsets the indentation level of the given text by the specified amount.
@ -237,6 +242,7 @@ def offset_indent_level(text, indent_level):
text = "\n".join(new_lines)
return text
def merge_code(selected_text, docstring):
user_selected_lines = selected_text.split("\n")
docstring_lines = docstring.split("\n")
@ -264,11 +270,14 @@ def merge_code(selected_text, docstring):
for index_doc, line_doc in enumerate(docstring_trim_lines):
if line_doc == "...":
break
if line.strip().find(line_doc.strip())!= -1 or line_doc.strip().find(line.strip())!= -1:
if (
line.strip().find(line_doc.strip()) != -1
or line_doc.strip().find(line.strip()) != -1
):
break
if ((line.strip().find(line_doc.strip()) == -1 and
line_doc.strip().find(line.strip()) == -1) or
index == index_doc):
if (
line.strip().find(line_doc.strip()) == -1 and line_doc.strip().find(line.strip()) == -1
) or index == index_doc:
continue
return "\n".join(docstring_lines[:index_doc] + user_selected_lines[index:])
return docstring + "\n" + selected_text
@ -280,12 +289,11 @@ def main():
# Rewrite
response = add_docstring(
selected_text=selected_text.get('text', ''),
file_path=selected_text.get('abspath', '')
selected_text=selected_text.get("text", ""), file_path=selected_text.get("abspath", "")
)
# Get indent level
indent = get_indent_level(selected_text.get('text', ''))
indent = get_indent_level(selected_text.get("text", ""))
# Apply new code to editor
new_code = extract_markdown_block(response)
@ -298,12 +306,13 @@ def main():
new_code = offset_indent_level(new_code, indent)
# Merge code
docstring_code = merge_code(selected_text.get('text', ''), new_code)
docstring_code = merge_code(selected_text.get("text", ""), new_code)
# Apply diff
IDEService().diff_apply("", docstring_code)
sys.exit(0)
def print_message(language):
if language == "zh":
print("\n\n大模型输出不完整,不能进行代码操作。\n\n")

View File

@ -7,7 +7,6 @@ from devchat.ide.vscode_services import selected_lines, visible_lines
from devchat.llm import chat
from devchat.memory import FixSizeChatMemory
PROMPT = prompt = """
file: {file_path}
```
@ -36,14 +35,16 @@ MESSAGES_A = [
"role": "system",
"content": """
Your task is:
Refine internal variable and function names within the code to achieve concise and meaningful identifiers that comply with English naming conventions.
Refine internal variable and function names within the code to achieve concise and \
meaningful identifiers that comply with English naming conventions.
Rules:
1. Don't rename a call or global variable. for example, xx() is function call, xx is a bad name, but you MUST not rename it .
1. Don't rename a call or global variable. for example, xx() is function call, xx \
is a bad name, but you MUST not rename it .
2. You can rename a local variable or parameter variable name.
3. Current function's name can be renamed. Always this is a new function.
"""
""",
},
{
"role": "user",
@ -54,7 +55,7 @@ file: a1.py
a = "hello world"
print(a)
```
"""
""",
},
{
"role": "assistant",
@ -64,7 +65,7 @@ file: a1.py
msg = "hello world"
print(msg)
```
"""
""",
},
{
"role": "user",
@ -74,7 +75,7 @@ file: t1.py
def print_hello(a: str):
print(a)
```
"""
""",
},
{
"role": "assistant",
@ -83,7 +84,7 @@ file: t1.py
def print_hello(msg: str):
print(msg)
```
"""
""",
},
{
"role": "user",
@ -93,7 +94,7 @@ file: t1.py
def some():
print("hello")
```
"""
""",
},
{
"role": "assistant",
@ -102,7 +103,7 @@ file: t1.py
def output_hello():
print("hello")
```
"""
""",
},
{
"role": "user",
@ -112,7 +113,7 @@ file: t1.py
def print_hello():
print("hello")
```
"""
""",
},
{
"role": "assistant",
@ -121,14 +122,14 @@ file: t1.py
def print_hello():
output("hello")
```
"""
""",
},
{
"role": "user",
"content": """
Your response is error, you changed call name.
print is a function call, if you rename it, this will make a compile error.
"""
""",
},
{
"role": "assistant",
@ -137,10 +138,11 @@ print is a function call, if you rename it, this will make a compile error.
def print_hello():
print("hello")
```
"""
}
""",
},
]
def get_selected_code():
"""
Retrieves the selected lines of code from the user's selection.
@ -172,6 +174,7 @@ def get_selected_code():
memory = FixSizeChatMemory(max_size=20, messages=MESSAGES_A)
@chat(prompt=get_prompt(), stream_out=True, memory=memory)
# pylint: disable=unused-argument
def reanme_variable(selected_text, file_path):
@ -203,12 +206,12 @@ def extract_markdown_block(text):
def remove_unnecessary_escapes(code_a, code_b):
code_copy = code_b # Create a copy of the original code
escape_chars = re.finditer(r'\\(.)', code_b)
escape_chars = re.finditer(r"\\(.)", code_b)
remove_char_index = []
for match in escape_chars:
before = code_b[max(0, match.start()-4):match.start()]
after = code_b[match.start()+1:match.start()+5]
before = code_b[max(0, match.start() - 4) : match.start()]
after = code_b[match.start() + 1 : match.start() + 5]
substr = before + after
if substr in code_a:
remove_char_index.append(match.start())
@ -216,14 +219,15 @@ def remove_unnecessary_escapes(code_a, code_b):
# visit remove_char_index in reverse order
remove_char_index.reverse()
for index in remove_char_index:
code_copy = code_copy[:index] + code_copy[index+1:]
code_copy = code_copy[:index] + code_copy[index + 1 :]
return code_copy
def main():
# prepare code
selected_text = get_selected_code()
selected_code = selected_text.get('text', '')
selected_file = selected_text.get('abspath', '')
selected_code = selected_text.get("text", "")
selected_file = selected_text.get("abspath", "")
# rewrite
response = reanme_variable(selected_text=selected_code, file_path=selected_file)
@ -237,7 +241,7 @@ def main():
print("\n\nThe output of the LLM is incomplete and cannot perform code operations.\n\n")
sys.exit(0)
new_code = remove_unnecessary_escapes(selected_text.get('text', ''), new_code)
new_code = remove_unnecessary_escapes(selected_text.get("text", ""), new_code)
IDEService().diff_apply("", new_code)
sys.exit(0)