import os import sys from typing import Dict, List, Tuple import click sys.path.append(os.path.dirname(__file__)) sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs")) from chatmark import Checkbox, Form, Step, TextEditor # noqa: E402 from find_context import ( Context, find_symbol_context_by_static_analysis, find_symbol_context_of_llm_recommendation, ) from find_reference_tests import find_reference_tests from i18n import TUILanguage, get_translation from ide_services import IDEService # noqa: E402 from model import ( FuncToTest, TokenBudgetExceededException, UserCancelledException, ) from propose_test import propose_test from tools.file_util import retrieve_file_content from write_tests import write_and_print_tests class UnitTestsWorkflow: def __init__( self, user_prompt: str, func_to_test: FuncToTest, repo_root: str, tui_lang: TUILanguage, ): self.user_prompt = user_prompt self.func_to_test = func_to_test self.repo_root = repo_root self.tui_lang = tui_lang def run(self): """ Run the workflow to generate unit tests. """ symbol_context = self.step1_find_symbol_context() contexts = set() for _, v in symbol_context.items(): contexts.update(v) cases, files = self.step2_propose_cases_and_reference_files(list(contexts)) res = self.step3_edit_cases_and_reference_files(cases, files) cases = res[0] files = res[1] self.step4_write_and_print_tests(cases, files, list(contexts)) def step2_propose_cases_and_reference_files( self, contexts: List[Context], ) -> Tuple[List[str], List[str]]: """ Propose test cases and reference files for a specified function. Return: (test_cases, reference_files) """ test_cases: List[str] = [] reference_files: List[str] = [] _i = get_translation(self.tui_lang) msg = _i("Analyzing the function and current unit tests...") with Step(msg): test_cases = propose_test( user_prompt=self.user_prompt, func_to_test=self.func_to_test, contexts=contexts, chat_language=self.tui_lang.chat_language, ) ref_files = find_reference_tests( self.repo_root, self.func_to_test.func_name, self.func_to_test.file_path, ) if ref_files: # Only use the most relevant reference file currently reference_files.append(ref_files[0]) return test_cases, reference_files def step3_edit_cases_and_reference_files( self, test_cases: List[str], reference_files: List[str] ) -> Tuple[List[str], List[str]]: """ Edit test cases and reference files by user. Return the updated cases and valid reference files. """ _i = get_translation(self.tui_lang) checkbox = Checkbox( options=test_cases, title=_i("Select test cases to generate"), ) case_editor = TextEditor( text="", title=_i( "You can add more test cases here\n" "(Multiple cases can be separated by line breaks)" ), ) ref_editor = TextEditor( text=reference_files[0] if reference_files else "", title=_i("Edit reference test file\n(Multiple files can be separated by line breaks)"), ) form = Form(components=[checkbox, case_editor, ref_editor]) form.render() # Check test cases cases = [checkbox.options[idx] for idx in checkbox.selections] user_cases = [] if case_editor.new_text: user_cases = [c.strip() for c in case_editor.new_text.split("\n")] user_cases = [c for c in user_cases if c] cases.extend(user_cases) # Check if any test case is selected if not cases: raise UserCancelledException(_i("No test case is selected. Quit generating tests.")) # Validate reference files ref_files = [f.strip() for f in ref_editor.new_text.split("\n")] valid_files = [] invalid_files = [] for ref_file in ref_files: if not ref_file: continue try: retrieve_file_content(file_path=ref_file, root_path=self.repo_root) valid_files.append(ref_file) except Exception: invalid_files.append(ref_file) # Print summary title = _i("Will generate tests for the following cases.") lines = [] lines.append(_i("\nTest cases:")) width = len(str(len(cases))) lines.extend([f"{(i+1):>{width}}. {c}" for i, c in enumerate(cases)]) if not valid_files: lines.append( _i( "\nNo valid reference file is provided. " "Will not use reference to generate tests." ) ) else: lines.append(_i("\nWill use the following reference files to generate tests.")) lines.append(_i("\nValid reference files:")) width = len(str(len(valid_files))) lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(valid_files)]) if invalid_files: lines.append(_i("\nInvalid files:")) width = len(str(len(invalid_files))) lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(invalid_files)]) with Step(title): print("\n".join(lines), flush=True) return cases, valid_files def step1_find_symbol_context(self) -> Dict[str, List[Context]]: symbol_context = find_symbol_context_by_static_analysis( self.func_to_test, self.tui_lang.chat_language ) # with Step("Symbol context"): # for k, v in symbol_context.items(): # print(f"\n- {k}: ") # for item in v: # print(f"{item.file_path}\n{item.content}") known_context_for_llm: List[Context] = [] if self.func_to_test.container_content is not None: known_context_for_llm.append( Context( file_path=self.func_to_test.file_path, content=self.func_to_test.container_content, ) ) known_context_for_llm += list( {item for sublist in list(symbol_context.values()) for item in sublist} ) recommended_context = find_symbol_context_of_llm_recommendation( self.func_to_test, known_context_for_llm ) # with Step("Recommended context"): # for k, v in recommended_context.items(): # print(f"\n- {k}: ") # for item in v: # print(f"{item.file_path}\n{item.content}") symbol_context.update(recommended_context) return symbol_context def step4_write_and_print_tests( self, cases: List[str], ref_files: List[str], symbol_contexts: List[Context], ): """ Write and print tests. """ write_and_print_tests( root_path=self.repo_root, func_to_test=self.func_to_test, test_cases=cases, reference_files=ref_files, symbol_contexts=symbol_contexts, chat_language=self.tui_lang.chat_language, ) @click.command() @click.argument("input", required=True) def main(input: str): """ The main entry point for the unit tests generation workflow. "/unit_tests {a}:::{b}:::{c}:::{d}:::{e}:::{f}" """ # Parse input params = input.strip().split(":::") assert len(params) == 6, f"Invalid input: {input}, number of params: {len(params)}" ( file_path, func_name, func_start_line, # 0-based, inclusive func_end_line, # 0-based, inclusive container_start_line, # 0-based, inclusive container_end_line, # 0-based, inclusive ) = params try: func_start_line = int(func_start_line) func_end_line = int(func_end_line) container_start_line = int(container_start_line) container_end_line = int(container_end_line) except Exception as e: raise Exception(f"Invalid input: {input}, error: {e}") user_prompt = f"Help me write unit tests for the `{func_name}` function" repo_root = os.getcwd() ide_lang = IDEService().ide_language() tui_lang = TUILanguage.from_str(ide_lang) _i = get_translation(tui_lang) # Use relative path in inner logic if os.path.isabs(file_path): file_path = os.path.relpath(file_path, repo_root) func_to_test = FuncToTest( func_name=func_name, repo_root=repo_root, file_path=file_path, func_start_line=func_start_line, func_end_line=func_end_line, container_start_line=container_start_line, container_end_line=container_end_line, ) try: workflow = UnitTestsWorkflow(user_prompt, func_to_test, repo_root, tui_lang) workflow.run() except TokenBudgetExceededException as e: msg = _i("The function's size surpasses AI's context capacity.") with Step(msg): print(f"\n{e}\n", flush=True) except UserCancelledException as e: with Step(f"{e}"): pass except Exception as e: raise e if __name__ == "__main__": main()