From adc78c431b684dd9c8a7174046fe47ec8df4f6b1 Mon Sep 17 00:00:00 2001
From: kagami
Date: Sun, 19 May 2024 15:05:26 +0800
Subject: [PATCH 1/2] Reorganize unit tests steps

---
 merico/unit_tests/main.py        | 279 +-------------------------
 merico/unit_tests/ut_workflow.py | 324 +++++++++++++++++++++++++++++++
 2 files changed, 332 insertions(+), 271 deletions(-)
 create mode 100644 merico/unit_tests/ut_workflow.py

diff --git a/merico/unit_tests/main.py b/merico/unit_tests/main.py
index 01de96e..e3a402e 100644
--- a/merico/unit_tests/main.py
+++ b/merico/unit_tests/main.py
@@ -1,295 +1,27 @@
 # ruff: noqa: E402
 import os
 import sys
-from typing import Dict, List, Tuple
 
 import click
+import openai
 
 sys.path.append(os.path.dirname(__file__))
 
 from cache import LocalCache
-from find_context import (
-    Context,
-    Position,
-    Range,
-    find_symbol_context_by_static_analysis,
-    find_symbol_context_of_llm_recommendation,
-)
-from find_reference_tests import find_reference_tests
 from i18n import TUILanguage, get_translation
 from model import (
     FuncToTest,
     TokenBudgetExceededException,
     UserCancelledException,
 )
-from propose_test import propose_test
-from tools.file_util import retrieve_file_content
-from write_tests import write_and_print_tests
+from ut_workflow import UnitTestsWorkflow
 
-from lib.chatmark import Checkbox, Form, Step, TextEditor
+from lib.chatmark import Step
 from lib.ide_service import IDEService
 
 
 CHAT_WORKFLOW_DIR_PATH = [".chat", "workflows"]
 
 
-class UnitTestsWorkflow:
-    def __init__(
-        self,
-        user_prompt: str,
-        func_to_test: FuncToTest,
-        repo_root: str,
-        tui_lang: TUILanguage,
-        local_cache: LocalCache,
-    ):
-        self.user_prompt = user_prompt
-        self.func_to_test = func_to_test
-        self.repo_root = repo_root
-        self.tui_lang = tui_lang
-        self.local_cache = local_cache
-
-    def run(self):
-        """
-        Run the workflow to generate unit tests.
-        """
-        symbol_context = self.step1_find_symbol_context()
-        contexts = set()
-        for _, v in symbol_context.items():
-            contexts.update(v)
-        contexts = list(contexts)
-
-        cases, files = self.step2_propose_cases_and_reference_files(contexts)
-
-        res = self.step3_user_interaction(cases, files)
-        cases = res[0]
-        files = res[1]
-        requirements = res[2]
-
-        self.step4_print_test_summary(cases, files, requirements, contexts)
-
-        self.step5_write_and_print_tests(cases, files, contexts, requirements)
-
-    def step1_find_symbol_context(self) -> Dict[str, List[Context]]:
-        symbol_context = find_symbol_context_by_static_analysis(
-            self.func_to_test, self.tui_lang.chat_language
-        )
-
-        known_context_for_llm: List[Context] = []
-        if self.func_to_test.container_content is not None:
-            known_context_for_llm.append(
-                Context(
-                    file_path=self.func_to_test.file_path,
-                    content=self.func_to_test.container_content,
-                    range=Range(
-                        start=Position(line=self.func_to_test.container_start_line, character=0),
-                        end=Position(line=self.func_to_test.container_end_line, character=0),
-                    ),
-                )
-            )
-        known_context_for_llm += list(
-            {item for sublist in list(symbol_context.values()) for item in sublist}
-        )
-
-        recommended_context = find_symbol_context_of_llm_recommendation(
-            self.func_to_test, known_context_for_llm
-        )
-
-        symbol_context.update(recommended_context)
-
-        return symbol_context
-
-    def step2_propose_cases_and_reference_files(
-        self,
-        contexts: List[Context],
-    ) -> Tuple[List[str], List[str]]:
-        """
-        Propose test cases and reference files for a specified function.
-
-        Return: (test_cases, reference_files)
-        """
-        test_cases: List[str] = []
-        reference_files: List[str] = []
-
-        _i = get_translation(self.tui_lang)
-
-        msg = _i("Analyzing the function and current unit tests...")
-
-        with Step(msg):
-            test_cases = propose_test(
-                user_prompt=self.user_prompt,
-                func_to_test=self.func_to_test,
-                contexts=contexts,
-                chat_language=self.tui_lang.chat_language,
-            )
-
-            ref_files = find_reference_tests(
-                self.repo_root,
-                self.func_to_test.func_name,
-                self.func_to_test.file_path,
-            )
-
-        if ref_files:
-            # Only use the most relevant reference file currently
-            reference_files.append(ref_files[0])
-
-        return test_cases, reference_files
-
-    def step3_user_interaction(
-        self,
-        test_cases: List[str],
-        reference_files: List[str],
-    ) -> Tuple[List[str], List[str], str]:
-        """
-        Edit test cases and reference files by user.
-
-        Return:
-        - the updated cases
-        - valid reference files
-        - customized requirements(prompts)
-        """
-        _i = get_translation(self.tui_lang)
-
-        checkbox = Checkbox(
-            options=test_cases,
-            title=_i("Select test cases to generate"),
-        )
-        case_editor = TextEditor(
-            text="",
-            title=_i(
-                "You can add more test cases here\n"
-                "(Multiple cases can be separated by line breaks)"
-            ),
-        )
-        ref_editor = TextEditor(
-            text=reference_files[0] if reference_files else "",
-            title=_i("Edit reference test file\n(Multiple files can be separated by line breaks)"),
-        )
-
-        cached_requirements = self.local_cache.get("user_requirements") or ""
-        requirements_editor = TextEditor(
-            text=cached_requirements,
-            title=_i(
-                "Write your customized requirements(prompts) for tests here."
-                "\n(For example, what testing framework to use.)"
-            ),
-        )
-
-        form = Form(components=[checkbox, case_editor, ref_editor, requirements_editor])
-        form.render()
-
-        # Check test cases
-        cases = [checkbox.options[idx] for idx in checkbox.selections]
-        user_cases = []
-        if case_editor.new_text:
-            user_cases = [c.strip() for c in case_editor.new_text.split("\n")]
-            user_cases = [c for c in user_cases if c]
-
-        cases.extend(user_cases)
-
-        # Check if any test case is selected
-        if not cases:
-            raise UserCancelledException(_i("No test case is selected. Quit generating tests."))
-
-        # Validate reference files
-        ref_files = [f.strip() for f in ref_editor.new_text.split("\n")]
-        valid_files = []
-        invalid_files = []
-
-        for ref_file in ref_files:
-            if not ref_file:
-                continue
-            try:
-                retrieve_file_content(file_path=ref_file, root_path=self.repo_root)
-                valid_files.append(ref_file)
-            except Exception:
-                invalid_files.append(ref_file)
-
-        # Get customized requirements
-        requirements: str = (
-            requirements_editor.new_text.strip() if requirements_editor.new_text else ""
-        )
-        self.local_cache.set("user_requirements", requirements)
-
-        return cases, valid_files, requirements
-
-    # Tuple[List[str], List[str], str]:
-    def step4_print_test_summary(
-        self,
-        cases: List[str],
-        valid_files: List[str],
-        requirements: str,
-        contexts: List[Context],
-    ):
-        """
-        Print the summary message in Step
-        """
-        _i = get_translation(self.tui_lang)
-
-        title = _i("Will generate tests for the following cases.")
-        lines = []
-
-        lines.append(_i("\nTest cases:"))
-        width = len(str(len(cases)))
-        lines.extend([f"{(i+1):>{width}}. {c}" for i, c in enumerate(cases)])
-
-        if not valid_files:
-            lines.append(
-                _i(
-                    "\nNo valid reference file is provided. "
-                    "Will not use reference to generate tests."
-                )
-            )
-        else:
-            lines.append(_i("\nWill use the following reference files to generate tests."))
-            # lines.append(_i("\nValid reference files:"))
-            width = len(str(len(valid_files)))
-            lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(valid_files)])
-
-        # if invalid_files:
-        #     lines.append(_i("\nInvalid files:"))
-        #     width = len(str(len(invalid_files)))
-        #     lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(invalid_files)])
-
-        lines.append(_i("\nCustomized requirements(prompts):"))
-        if requirements.strip():
-            lines.append(requirements)
-        else:
-            lines.append(_i("No customized requirements."))
-
-        if contexts:
-            lines.append(_i("\nAdditional context:"))
-            width = len(str(len(contexts)))
-            lines.extend(
-                [
-                    f"{(i+1):>{width}}. {c.file_path}:{c.range.start.line+1}-{c.range.end.line+1}"
-                    for i, c in enumerate(contexts)
-                ]
-            )
-
-        with Step(title):
-            print("\n".join(lines), flush=True)
-
-    def step5_write_and_print_tests(
-        self,
-        cases: List[str],
-        ref_files: List[str],
-        symbol_contexts: List[Context],
-        user_requirements: str,
-    ):
-        """
-        Write and print tests.
-        """
-
-        write_and_print_tests(
-            root_path=self.repo_root,
-            func_to_test=self.func_to_test,
-            test_cases=cases,
-            reference_files=ref_files,
-            symbol_contexts=symbol_contexts,
-            user_requirements=user_requirements,
-            chat_language=self.tui_lang.chat_language,
-        )
-
-
 @click.command()
 @click.argument("input", required=True)
 def main(input: str):
@@ -369,6 +101,11 @@ def main(input: str):
         with Step(f"{e}"):
             pass
 
+    except (openai.APIConnectionError, openai.APITimeoutError) as e:
+        msg = "Model API connection error. Please try again later."
+        with Step(msg):
+            print(f"\n{e}\n", flush=True)
+
     except Exception as e:
         raise e
 
diff --git a/merico/unit_tests/ut_workflow.py b/merico/unit_tests/ut_workflow.py
new file mode 100644
index 0000000..b4c650b
--- /dev/null
+++ b/merico/unit_tests/ut_workflow.py
@@ -0,0 +1,324 @@
+from typing import Dict, List, Tuple
+
+from cache import LocalCache
+from find_context import (
+    Context,
+    Position,
+    Range,
+    find_symbol_context_by_static_analysis,
+    find_symbol_context_of_llm_recommendation,
+)
+from find_reference_tests import find_reference_tests
+from i18n import TUILanguage, get_translation
+from model import (
+    FuncToTest,
+    UserCancelledException,
+)
+from propose_test import propose_test
+from tools.file_util import retrieve_file_content
+from write_tests import write_and_print_tests
+
+from lib.chatmark import Checkbox, Form, Step, TextEditor
+
+
+class UnitTestsWorkflow:
+    def __init__(
+        self,
+        user_prompt: str,
+        func_to_test: FuncToTest,
+        repo_root: str,
+        tui_lang: TUILanguage,
+        local_cache: LocalCache,
+    ):
+        self.user_prompt = user_prompt
+        self.func_to_test = func_to_test
+        self.repo_root = repo_root
+        self.tui_lang = tui_lang
+        self.local_cache = local_cache
+
+    def run(self):
+        """
+        Run the workflow to generate unit tests.
+        """
+        _i = get_translation(self.tui_lang)
+        msg = _i("Analyzing the function and current unit tests...")
+
+        with Step(msg):
+            print("\n- Analyzing context for the function...", flush=True)
+            symbol_context = self.step_1_find_symbol_context()
+
+            contexts = set()
+            for _, v in symbol_context.items():
+                contexts.update(v)
+            contexts = list(contexts)
+
+            print("- Finding reference files...", flush=True)
+            reference_files = self.step_2_find_reference_files()
+
+            print("- Proposing test cases...", flush=True)
+            cases = self.step_3_propose_cases(contexts)
+
+        res = self.step_4_user_interaction(cases, reference_files)
+        cases = res[0]
+        files = res[1]
+        requirements = res[2]
+
+        self.step_5_print_test_summary(cases, files, requirements, contexts)
+
+        self.step_6_write_and_print_tests(cases, files, contexts, requirements)
+
+    def step_1_find_symbol_context(self) -> Dict[str, List[Context]]:
+        symbol_context = find_symbol_context_by_static_analysis(
+            self.func_to_test, self.tui_lang.chat_language
+        )
+
+        known_context_for_llm: List[Context] = []
+        if self.func_to_test.container_content is not None:
+            known_context_for_llm.append(
+                Context(
+                    file_path=self.func_to_test.file_path,
+                    content=self.func_to_test.container_content,
+                    range=Range(
+                        start=Position(line=self.func_to_test.container_start_line, character=0),
+                        end=Position(line=self.func_to_test.container_end_line, character=0),
+                    ),
+                )
+            )
+        known_context_for_llm += list(
+            {item for sublist in list(symbol_context.values()) for item in sublist}
+        )
+
+        recommended_context = find_symbol_context_of_llm_recommendation(
+            self.func_to_test, known_context_for_llm
+        )
+
+        symbol_context.update(recommended_context)
+
+        return symbol_context
+
+    def step_2_find_reference_files(self) -> List[str]:
+        """
+        Find reference files for the specified function.
+        """
+        reference_files: List[str] = []
+
+        ref_tests = find_reference_tests(
+            self.repo_root,
+            self.func_to_test.func_name,
+            self.func_to_test.file_path,
+        )
+        if ref_tests:
+            # Only use the most relevant reference test file currently
+            reference_files.append(ref_tests[0])
+
+        return reference_files
+
+    def step_3_propose_cases(
+        self,
+        contexts: List[Context],
+    ) -> List[str]:
+        """
+        Propose test cases for the specified function.
+        """
+        test_cases: List[str] = []
+
+        test_cases = propose_test(
+            user_prompt=self.user_prompt,
+            func_to_test=self.func_to_test,
+            contexts=contexts,
+            chat_language=self.tui_lang.chat_language,
+        )
+
+        return test_cases
+
+    def step2_propose_cases_and_reference_files(
+        self,
+        contexts: List[Context],
+    ) -> Tuple[List[str], List[str]]:
+        """
+        Propose test cases and reference files for a specified function.
+
+        Return: (test_cases, reference_files)
+        """
+        test_cases: List[str] = []
+        reference_files: List[str] = []
+
+        _i = get_translation(self.tui_lang)
+
+        msg = _i("Analyzing the function and current unit tests...")
+
+        with Step(msg):
+            test_cases = propose_test(
+                user_prompt=self.user_prompt,
+                func_to_test=self.func_to_test,
+                contexts=contexts,
+                chat_language=self.tui_lang.chat_language,
+            )
+
+            ref_files = find_reference_tests(
+                self.repo_root,
+                self.func_to_test.func_name,
+                self.func_to_test.file_path,
+            )
+
+        if ref_files:
+            # Only use the most relevant reference file currently
+            reference_files.append(ref_files[0])
+
+        return test_cases, reference_files
+
+    def step_4_user_interaction(
+        self,
+        test_cases: List[str],
+        reference_files: List[str],
+    ) -> Tuple[List[str], List[str], str]:
+        """
+        Edit test cases and reference files by user.
+
+        Return:
+        - the updated cases
+        - valid reference files
+        - customized requirements(prompts)
+        """
+        _i = get_translation(self.tui_lang)
+
+        checkbox = Checkbox(
+            options=test_cases,
+            title=_i("Select test cases to generate"),
+        )
+        case_editor = TextEditor(
+            text="",
+            title=_i(
+                "You can add more test cases here\n"
+                "(Multiple cases can be separated by line breaks)"
+            ),
+        )
+        ref_editor = TextEditor(
+            text=reference_files[0] if reference_files else "",
+            title=_i("Edit reference test file\n(Multiple files can be separated by line breaks)"),
+        )
+
+        cached_requirements = self.local_cache.get("user_requirements") or ""
+        requirements_editor = TextEditor(
+            text=cached_requirements,
+            title=_i(
+                "Write your customized requirements(prompts) for tests here."
+                "\n(For example, what testing framework to use.)"
+            ),
+        )
+
+        form = Form(components=[checkbox, case_editor, ref_editor, requirements_editor])
+        form.render()
+
+        # Check test cases
+        cases = [checkbox.options[idx] for idx in checkbox.selections]
+        user_cases = []
+        if case_editor.new_text:
+            user_cases = [c.strip() for c in case_editor.new_text.split("\n")]
+            user_cases = [c for c in user_cases if c]
+
+        cases.extend(user_cases)
+
+        # Check if any test case is selected
+        if not cases:
+            raise UserCancelledException(_i("No test case is selected. Quit generating tests."))
+
+        # Validate reference files
+        ref_files = [f.strip() for f in ref_editor.new_text.split("\n")]
+        valid_files = []
+        invalid_files = []
+
+        for ref_file in ref_files:
+            if not ref_file:
+                continue
+            try:
+                retrieve_file_content(file_path=ref_file, root_path=self.repo_root)
+                valid_files.append(ref_file)
+            except Exception:
+                invalid_files.append(ref_file)
+
+        # Get customized requirements
+        requirements: str = (
+            requirements_editor.new_text.strip() if requirements_editor.new_text else ""
+        )
+        self.local_cache.set("user_requirements", requirements)
+
+        return cases, valid_files, requirements
+
+    # Tuple[List[str], List[str], str]:
+    def step_5_print_test_summary(
+        self,
+        cases: List[str],
+        valid_files: List[str],
+        requirements: str,
+        contexts: List[Context],
+    ):
+        """
+        Print the summary message in Step
+        """
+        _i = get_translation(self.tui_lang)
+
+        title = _i("Will generate tests for the following cases.")
+        lines = []
+
+        lines.append(_i("\nTest cases:"))
+        width = len(str(len(cases)))
+        lines.extend([f"{(i+1):>{width}}. {c}" for i, c in enumerate(cases)])
+
+        if not valid_files:
+            lines.append(
+                _i(
+                    "\nNo valid reference file is provided. "
+                    "Will not use reference to generate tests."
+                )
+            )
+        else:
+            lines.append(_i("\nWill use the following reference files to generate tests."))
+            # lines.append(_i("\nValid reference files:"))
+            width = len(str(len(valid_files)))
+            lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(valid_files)])
+
+        # if invalid_files:
+        #     lines.append(_i("\nInvalid files:"))
+        #     width = len(str(len(invalid_files)))
+        #     lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(invalid_files)])
+
+        lines.append(_i("\nCustomized requirements(prompts):"))
+        if requirements.strip():
+            lines.append(requirements)
+        else:
+            lines.append(_i("No customized requirements."))
+
+        if contexts:
+            lines.append(_i("\nAdditional context:"))
+            width = len(str(len(contexts)))
+            lines.extend(
+                [
+                    f"{(i+1):>{width}}. {c.file_path}:{c.range.start.line+1}-{c.range.end.line+1}"
+                    for i, c in enumerate(contexts)
+                ]
+            )
+
+        with Step(title):
+            print("\n".join(lines), flush=True)
+
+    def step_6_write_and_print_tests(
+        self,
+        cases: List[str],
+        ref_files: List[str],
+        symbol_contexts: List[Context],
+        user_requirements: str,
+    ):
+        """
+        Write and print tests.
+        """
+
+        write_and_print_tests(
+            root_path=self.repo_root,
+            func_to_test=self.func_to_test,
+            test_cases=cases,
+            reference_files=ref_files,
+            symbol_contexts=symbol_contexts,
+            user_requirements=user_requirements,
+            chat_language=self.tui_lang.chat_language,
+        )
From 60c4c3d882b091c5ad3e0e36945109340b490f47 Mon Sep 17 00:00:00 2001
From: kagami
Date: Sun, 19 May 2024 15:07:33 +0800
Subject: [PATCH 2/2] Print model response time for each step

---
 .../directory_structure/relevant_file_finder.py |  2 ++
 .../assistants/recommend_test_context.py        |  2 ++
 merico/unit_tests/propose_test.py               |  2 ++
 merico/unit_tests/tools/time_util.py            | 18 ++++++++++++++++++
 4 files changed, 24 insertions(+)
 create mode 100644 merico/unit_tests/tools/time_util.py

diff --git a/merico/unit_tests/assistants/directory_structure/relevant_file_finder.py b/merico/unit_tests/assistants/directory_structure/relevant_file_finder.py
index 1aa3acb..2db1da3 100644
--- a/merico/unit_tests/assistants/directory_structure/relevant_file_finder.py
+++ b/merico/unit_tests/assistants/directory_structure/relevant_file_finder.py
@@ -15,6 +15,7 @@ from llm_conf import (
 from openai_util import create_chat_completion_content
 from tools.directory_viewer import ListViewer
 from tools.tiktoken_util import get_encoding
+from tools.time_util import print_exec_time
 
 MODEL = USER_LLM_MODEL if USE_USER_MODEL else "gpt-4-turbo-preview"  # "gpt-3.5-turbo"
 ENCODING = (
@@ -89,6 +90,7 @@ class RelevantFileFinder(DirectoryStructureBase):
 
         return message
 
+    @print_exec_time("Model response time")
     def _find_relevant_files(self, objective: str, dir_structure_pages: List[str]) -> List[str]:
         files: List[str] = []
         for dir_structure in dir_structure_pages:
diff --git a/merico/unit_tests/assistants/recommend_test_context.py b/merico/unit_tests/assistants/recommend_test_context.py
index f7d849f..5f78cb6 100644
--- a/merico/unit_tests/assistants/recommend_test_context.py
+++ b/merico/unit_tests/assistants/recommend_test_context.py
@@ -12,6 +12,7 @@ from llm_conf import (
 from model import FuncToTest
 from openai_util import create_chat_completion_content
 from tools.tiktoken_util import get_encoding
+from tools.time_util import print_exec_time
 
 MODEL = USER_LLM_MODEL if USE_USER_MODEL else "gpt-4-turbo-preview"
 ENCODING = (
@@ -82,6 +83,7 @@ def _mk_user_msg(func_to_test: FuncToTest, contexts: List) -> str:
     return msg
 
 
+@print_exec_time("Model response time")
 def get_recommended_symbols(
     func_to_test: FuncToTest, known_context: Optional[List] = None
 ) -> List[str]:
diff --git a/merico/unit_tests/propose_test.py b/merico/unit_tests/propose_test.py
index f356ce4..1005642 100644
--- a/merico/unit_tests/propose_test.py
+++ b/merico/unit_tests/propose_test.py
@@ -15,6 +15,7 @@ from model import FuncToTest, TokenBudgetExceededException
 from openai_util import create_chat_completion_content
 from prompts import PROPOSE_TEST_PROMPT
 from tools.tiktoken_util import get_encoding
+from tools.time_util import print_exec_time
 
 MODEL = USER_LLM_MODEL if USE_USER_MODEL else "gpt-4-turbo-preview"  # "gpt-3.5-turbo"
 ENCODING = (
@@ -82,6 +83,7 @@ def _mk_user_msg(
 )
 
 
+@print_exec_time("Model response time")
 def propose_test(
     user_prompt: str,
     func_to_test: FuncToTest,
diff --git a/merico/unit_tests/tools/time_util.py b/merico/unit_tests/tools/time_util.py
new file mode 100644
index 0000000..f6d1b08
--- /dev/null
+++ b/merico/unit_tests/tools/time_util.py
@@ -0,0 +1,18 @@
+import time
+from functools import wraps
+
+
+def print_exec_time(message: str):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            start_time = time.time()
+            result = func(*args, **kwargs)
+            end_time = time.time()
+            duration = end_time - start_time
+            print(f"{message} ({duration:.3f} s)", flush=True)
+            return result
+
+        return wrapper
+
+    return decorator
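The decorator added in tools/time_util.py is generic: it wraps any function and prints the elapsed wall-clock time with a caller-supplied label after each call. A standalone usage sketch follows; fake_model_call is invented for illustration, and the import works from inside merico/unit_tests because main.py appends that directory to sys.path.

    import time

    from tools.time_util import print_exec_time


    @print_exec_time("Model response time")
    def fake_model_call() -> str:
        # Stand-in for a chat-completion request.
        time.sleep(0.25)
        return "ok"


    fake_model_call()
    # Prints something like: Model response time (0.250 s)

Because the wrapper applies functools.wraps, decorated functions keep their names and docstrings, and since timing brackets the entire call, the printed duration includes any retries or response streaming inside the wrapped function.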