# File viewer metadata: 309 lines, 9.5 KiB, Python
import os
|
|
import sys
|
|
from typing import Dict, List, Tuple
|
|
|
|
import click
|
|
|
|
sys.path.append(os.path.dirname(__file__))
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs"))
|
|
|
|
from chatmark import Checkbox, Form, Step, TextEditor # noqa: E402
|
|
from find_context import (
|
|
Context,
|
|
find_symbol_context_by_static_analysis,
|
|
find_symbol_context_of_llm_recommendation,
|
|
)
|
|
from find_reference_tests import find_reference_tests
|
|
from i18n import TUILanguage, get_translation
|
|
from ide_services import IDEService # noqa: E402
|
|
from model import (
|
|
FuncToTest,
|
|
TokenBudgetExceededException,
|
|
UserCancelledException,
|
|
)
|
|
from propose_test import propose_test
|
|
from tools.file_util import retrieve_file_content
|
|
from write_tests import write_and_print_tests
|
|
|
|
|
|
class UnitTestsWorkflow:
    """
    Interactive workflow that generates unit tests for one function.

    The workflow runs four steps in order: collect symbol context for the
    function, propose test cases and reference test files, let the user
    edit both in a form, then write and print the tests.
    """

    def __init__(
        self,
        user_prompt: str,
        func_to_test: FuncToTest,
        repo_root: str,
        tui_lang: TUILanguage,
    ) -> None:
        """
        Store the inputs shared by every step of the workflow.

        Args:
            user_prompt: The request text forwarded to the test proposer.
            func_to_test: Description of the function under test.
            repo_root: Root path used to look up and read reference files.
            tui_lang: UI language; also provides `chat_language` for the
                calls that talk to the model.
        """
        self.user_prompt = user_prompt
        self.func_to_test = func_to_test
        self.repo_root = repo_root
        self.tui_lang = tui_lang

    def run(self) -> None:
        """
        Run the workflow to generate unit tests.
        """
        symbol_context = self.step1_find_symbol_context()

        # Flatten the per-symbol lists into one de-duplicated collection.
        unique_contexts = set()
        for found in symbol_context.values():
            unique_contexts.update(found)
        contexts = list(unique_contexts)

        cases, files = self.step2_propose_cases_and_reference_files(contexts)
        cases, files = self.step3_edit_cases_and_reference_files(cases, files)

        self.step4_write_and_print_tests(cases, files, contexts)

    def step2_propose_cases_and_reference_files(
        self,
        contexts: List[Context],
    ) -> Tuple[List[str], List[str]]:
        """
        Propose test cases and reference files for a specified function.

        Args:
            contexts: Symbol contexts passed on to the test proposer.

        Return: (test_cases, reference_files)
        """
        test_cases: List[str] = []
        reference_files: List[str] = []

        _i = get_translation(self.tui_lang)
        msg = _i("Analyzing the function and current unit tests...")

        with Step(msg):
            test_cases = propose_test(
                user_prompt=self.user_prompt,
                func_to_test=self.func_to_test,
                contexts=contexts,
                chat_language=self.tui_lang.chat_language,
            )

            ref_files = find_reference_tests(
                self.repo_root,
                self.func_to_test.func_name,
                self.func_to_test.file_path,
            )

            if ref_files:
                # Only use the most relevant reference file currently
                reference_files.append(ref_files[0])

        return test_cases, reference_files

    @staticmethod
    def _numbered_lines(items: List[str]) -> List[str]:
        """Format items as a 1-based numbered list with right-aligned indices."""
        width = len(str(len(items)))
        return [f"{idx:>{width}}. {item}" for idx, item in enumerate(items, start=1)]

    def step3_edit_cases_and_reference_files(
        self, test_cases: List[str], reference_files: List[str]
    ) -> Tuple[List[str], List[str]]:
        """
        Edit test cases and reference files by user.

        Renders a form with a checkbox of the proposed cases, an editor for
        extra cases and an editor for reference files, then prints a summary
        of what will be used.

        Return the updated cases and valid reference files.

        Raises:
            UserCancelledException: If no case is selected and none is added.
        """
        _i = get_translation(self.tui_lang)

        checkbox = Checkbox(
            options=test_cases,
            title=_i("Select test cases to generate"),
        )
        case_editor = TextEditor(
            text="",
            title=_i(
                "You can add more test cases here\n"
                "(Multiple cases can be separated by line breaks)"
            ),
        )
        ref_editor = TextEditor(
            text=reference_files[0] if reference_files else "",
            title=_i("Edit reference test file\n(Multiple files can be separated by line breaks)"),
        )

        form = Form(components=[checkbox, case_editor, ref_editor])
        form.render()

        # Selected proposals first, then the non-empty lines typed by the user.
        cases = [checkbox.options[idx] for idx in checkbox.selections]
        if case_editor.new_text:
            typed = (line.strip() for line in case_editor.new_text.split("\n"))
            cases.extend(line for line in typed if line)

        # Check if any test case is selected
        if not cases:
            raise UserCancelledException(_i("No test case is selected. Quit generating tests."))

        # Validate reference files. The editor text is guarded the same way as
        # the case editor above so an empty/None result means "no files"
        # instead of raising AttributeError.
        ref_files = [f.strip() for f in (ref_editor.new_text or "").split("\n")]
        valid_files = []
        invalid_files = []

        for ref_file in ref_files:
            if not ref_file:
                continue
            try:
                retrieve_file_content(file_path=ref_file, root_path=self.repo_root)
            except Exception:
                # Any failure to read the file marks it invalid; the file is
                # reported in the summary rather than aborting the workflow.
                invalid_files.append(ref_file)
            else:
                valid_files.append(ref_file)

        # Print summary
        title = _i("Will generate tests for the following cases.")
        lines = [_i("\nTest cases:")]
        lines.extend(self._numbered_lines(cases))

        if not valid_files:
            lines.append(
                _i(
                    "\nNo valid reference file is provided. "
                    "Will not use reference to generate tests."
                )
            )
        else:
            lines.append(_i("\nWill use the following reference files to generate tests."))
            lines.append(_i("\nValid reference files:"))
            lines.extend(self._numbered_lines(valid_files))

        if invalid_files:
            lines.append(_i("\nInvalid files:"))
            lines.extend(self._numbered_lines(invalid_files))

        with Step(title):
            print("\n".join(lines), flush=True)

        return cases, valid_files

    def step1_find_symbol_context(self) -> Dict[str, List[Context]]:
        """
        Collect symbol context for the function under test.

        Static analysis results are gathered first. They are then passed,
        together with the function's container content when present, as the
        already-known context to the LLM recommendation lookup. The
        recommended entries are merged into the static-analysis mapping,
        replacing any entry that shares the same key.

        Return the merged mapping of symbol name to contexts.
        """
        symbol_context = find_symbol_context_by_static_analysis(
            self.func_to_test, self.tui_lang.chat_language
        )

        known_context_for_llm: List[Context] = []
        if self.func_to_test.container_content is not None:
            known_context_for_llm.append(
                Context(
                    file_path=self.func_to_test.file_path,
                    content=self.func_to_test.container_content,
                )
            )
        # De-duplicate contexts shared by several symbols before handing
        # them over as known context.
        known_context_for_llm += list(
            {item for sublist in symbol_context.values() for item in sublist}
        )

        recommended_context = find_symbol_context_of_llm_recommendation(
            self.func_to_test, known_context_for_llm
        )

        symbol_context.update(recommended_context)

        return symbol_context

    def step4_write_and_print_tests(
        self,
        cases: List[str],
        ref_files: List[str],
        symbol_contexts: List[Context],
    ) -> None:
        """
        Write and print tests.

        Args:
            cases: Test cases to generate.
            ref_files: Reference test files to pass to the writer.
            symbol_contexts: Symbol contexts to pass to the writer.
        """
        write_and_print_tests(
            root_path=self.repo_root,
            func_to_test=self.func_to_test,
            test_cases=cases,
            reference_files=ref_files,
            symbol_contexts=symbol_contexts,
            chat_language=self.tui_lang.chat_language,
        )
|
|
|
|
|
|
# NOTE: click uses the function docstring as the command's --help text, so
# implementation notes are kept in comments rather than in the docstring.
#
# The single `input` argument packs six fields separated by ":::":
#   file_path, func_name, func_start_line, func_end_line,
#   container_start_line, container_end_line
# The four line fields must parse as integers.
@click.command()
@click.argument("input", required=True)
def main(input: str):
    """
    The main entry point for the unit tests generation workflow.
    "/unit_tests {a}:::{b}:::{c}:::{d}:::{e}:::{f}"
    """
    # Parse input
    params = input.strip().split(":::")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(params) != 6:
        raise ValueError(f"Invalid input: {input}, number of params: {len(params)}")

    (
        file_path,
        func_name,
        func_start_line,  # 0-based, inclusive
        func_end_line,  # 0-based, inclusive
        container_start_line,  # 0-based, inclusive
        container_end_line,  # 0-based, inclusive
    ) = params

    try:
        func_start_line = int(func_start_line)
        func_end_line = int(func_end_line)
        container_start_line = int(container_start_line)
        container_end_line = int(container_end_line)
    except Exception as e:
        # Chain the original error so the traceback shows the bad field.
        raise Exception(f"Invalid input: {input}, error: {e}") from e

    user_prompt = f"Help me write unit tests for the `{func_name}` function"

    repo_root = os.getcwd()
    ide_lang = IDEService().ide_language()

    tui_lang = TUILanguage.from_str(ide_lang)
    _i = get_translation(tui_lang)

    # Use relative path in inner logic
    if os.path.isabs(file_path):
        file_path = os.path.relpath(file_path, repo_root)

    func_to_test = FuncToTest(
        func_name=func_name,
        repo_root=repo_root,
        file_path=file_path,
        func_start_line=func_start_line,
        func_end_line=func_end_line,
        container_start_line=container_start_line,
        container_end_line=container_end_line,
    )

    # Only the two expected stop conditions are reported as steps; any other
    # exception propagates unchanged.
    try:
        workflow = UnitTestsWorkflow(user_prompt, func_to_test, repo_root, tui_lang)
        workflow.run()

    except TokenBudgetExceededException as e:
        msg = _i("The function's size surpasses AI's context capacity.")

        with Step(msg):
            print(f"\n{e}\n", flush=True)

    except UserCancelledException as e:
        # The exception message itself is shown as the step title.
        with Step(f"{e}"):
            pass


if __name__ == "__main__":
    main()
|