Print additional context in summary

This commit is contained in:
kagami 2024-03-15 16:39:08 +08:00
parent e7b3a66ab6
commit 4fb3b79ebe
2 changed files with 70 additions and 52 deletions

View File

@ -15,14 +15,14 @@ from tools.symbol_util import (
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from libs.ide_services import IDEService, Location, SymbolNode, Range
from libs.ide_services import IDEService, Location, Range, SymbolNode
@dataclass
class Context:
file_path: str # relative path to repo root
content: str
range: Range
range: Range
def __hash__(self) -> int:
return hash((self.file_path, self.content))

View File

@ -52,15 +52,43 @@ class UnitTestsWorkflow:
contexts = set()
for _, v in symbol_context.items():
contexts.update(v)
contexts = list(contexts)
cases, files = self.step2_propose_cases_and_reference_files(list(contexts))
cases, files = self.step2_propose_cases_and_reference_files(contexts)
res = self.step3_user_interaction(cases, files)
cases = res[0]
files = res[1]
requirements = res[2]
self.step4_write_and_print_tests(cases, files, list(contexts), requirements)
self.step4_print_test_summary(cases, files, requirements, contexts)
self.step5_write_and_print_tests(cases, files, contexts, requirements)
def step1_find_symbol_context(self) -> Dict[str, List[Context]]:
symbol_context = find_symbol_context_by_static_analysis(
self.func_to_test, self.tui_lang.chat_language
)
known_context_for_llm: List[Context] = []
if self.func_to_test.container_content is not None:
known_context_for_llm.append(
Context(
file_path=self.func_to_test.file_path,
content=self.func_to_test.container_content,
)
)
known_context_for_llm += list(
{item for sublist in list(symbol_context.values()) for item in sublist}
)
recommended_context = find_symbol_context_of_llm_recommendation(
self.func_to_test, known_context_for_llm
)
symbol_context.update(recommended_context)
return symbol_context
def step2_propose_cases_and_reference_files(
self,
@ -99,7 +127,9 @@ class UnitTestsWorkflow:
return test_cases, reference_files
def step3_user_interaction(
self, test_cases: List[str], reference_files: List[str]
self,
test_cases: List[str],
reference_files: List[str],
) -> Tuple[List[str], List[str], str]:
"""
Edit test cases and reference files by user.
@ -172,7 +202,21 @@ class UnitTestsWorkflow:
)
self.local_cache.set("user_requirements", requirements)
# Print summary
return cases, valid_files, requirements
# Tuple[List[str], List[str], str]:
def step4_print_test_summary(
self,
cases: List[str],
valid_files: List[str],
requirements: str,
contexts: List[Context],
):
"""
Print the summary message in Step
"""
_i = get_translation(self.tui_lang)
title = _i("Will generate tests for the following cases.")
lines = []
@ -189,61 +233,35 @@ class UnitTestsWorkflow:
)
else:
lines.append(_i("\nWill use the following reference files to generate tests."))
lines.append(_i("\nValid reference files:"))
# lines.append(_i("\nValid reference files:"))
width = len(str(len(valid_files)))
lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(valid_files)])
if invalid_files:
lines.append(_i("\nInvalid files:"))
width = len(str(len(invalid_files)))
lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(invalid_files)])
# if invalid_files:
# lines.append(_i("\nInvalid files:"))
# width = len(str(len(invalid_files)))
# lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(invalid_files)])
lines.append(_i("\nCustomized requirements(prompts):"))
lines.append(requirements)
if requirements.strip():
lines.append(requirements)
else:
lines.append(_i("No customized requirments."))
if contexts:
lines.append(_i("\nAdditional context:"))
width = len(str(len(contexts)))
lines.extend(
[
f"{(i+1):>{width}}. {c.file_path}:{c.range.start.line+1}-{c.range.end.line+1}"
for i, c in enumerate(contexts)
]
)
with Step(title):
print("\n".join(lines), flush=True)
return cases, valid_files, requirements
def step1_find_symbol_context(self) -> Dict[str, List[Context]]:
symbol_context = find_symbol_context_by_static_analysis(
self.func_to_test, self.tui_lang.chat_language
)
# with Step("Symbol context"):
# for k, v in symbol_context.items():
# print(f"\n- {k}: ")
# for item in v:
# print(f"{item.file_path}\n{item.content}")
known_context_for_llm: List[Context] = []
if self.func_to_test.container_content is not None:
known_context_for_llm.append(
Context(
file_path=self.func_to_test.file_path,
content=self.func_to_test.container_content,
)
)
known_context_for_llm += list(
{item for sublist in list(symbol_context.values()) for item in sublist}
)
recommended_context = find_symbol_context_of_llm_recommendation(
self.func_to_test, known_context_for_llm
)
# with Step("Recommended context"):
# for k, v in recommended_context.items():
# print(f"\n- {k}: ")
# for item in v:
# print(f"{item.file_path}\n{item.content}")
symbol_context.update(recommended_context)
return symbol_context
def step4_write_and_print_tests(
def step5_write_and_print_tests(
self,
cases: List[str],
ref_files: List[str],