diff --git a/unit_tests/assistants/recommend_test_context.py b/unit_tests/assistants/recommend_test_context.py index 9d072bf..45050e9 100644 --- a/unit_tests/assistants/recommend_test_context.py +++ b/unit_tests/assistants/recommend_test_context.py @@ -25,9 +25,9 @@ Here is the source code of the function: And here are some context information that might help you write the test cases: -``` + {context_content} -``` + Do you think the context information is enough? If the information is insufficient, recommend which symbols or types you need to know more about. @@ -46,10 +46,10 @@ JSON Format Example: def get_recommended_symbols( - func_to_test: FuncToTest, known_context: Optional[List[str]] = None + func_to_test: FuncToTest, known_context: Optional[List] = None ) -> List[str]: known_context = known_context or [] - context_content = "\n\n".join(known_context) + context_content = "\n\n".join([str(c) for c in known_context]) msg = recommend_symbol_context_prompt.format( function_content=func_to_test.func_content, diff --git a/unit_tests/find_context.py b/unit_tests/find_context.py index 4cb57f9..3bd6c73 100644 --- a/unit_tests/find_context.py +++ b/unit_tests/find_context.py @@ -1,6 +1,7 @@ import os import sys from collections import defaultdict +from dataclasses import dataclass from typing import Dict, List, Optional, Set from assistants.recommend_test_context import get_recommended_symbols @@ -17,9 +18,21 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from libs.ide_services import IDEService, Location, SymbolNode +@dataclass +class Context: + file_path: str # relative path to repo root + content: str + + def __hash__(self) -> int: + return hash((self.file_path, self.content)) + + def __str__(self) -> str: + return f"file path:`{self.file_path}`\n```\n{self.content}\n```" + + def _extract_referenced_symbols_context( func_to_test: FuncToTest, symbols: List[SymbolNode], depth: int = 0 -) -> Dict[str, List[str]]: +) -> Dict[str, List[Context]]: """ Extract context of the document symbols referenced in the function. Exclude the function itself and symbols whose depth is greater than the specified depth. @@ -47,15 +60,18 @@ def _extract_referenced_symbols_context( # Get the content of the symbols for s in referenced_symbols: content = get_symbol_content(s, file_content=func_to_test.file_content) - referenced_symbols_context[s.name].append(content) + context = Context(file_path=func_to_test.file_path, content=content) + referenced_symbols_context[s.name].append(context) return referenced_symbols_context -def _find_children_symbols_type_def_context(func_to_test: FuncToTest, func_symbol: SymbolNode): +def _find_children_symbols_type_def_context( + func_to_test: FuncToTest, func_symbol: SymbolNode +) -> Dict[str, List[Context]]: """ Find the type definitions of the symbols in the function. """ - type_defs: Dict[str, List[str]] = defaultdict(list) + type_defs: Dict[str, List[Context]] = defaultdict(list) client = IDEService() abs_path = os.path.join(func_to_test.repo_root, func_to_test.file_path) @@ -89,22 +105,24 @@ def _find_children_symbols_type_def_context(func_to_test: FuncToTest, func_symbo targets = find_symbol_nodes(symbols, line=loc.range.start.line) for t, _ in targets: content = get_symbol_content(t, abspath=loc.abspath) - type_defs[symbol_name].append(content) + relpath = os.path.relpath(loc.abspath, func_to_test.repo_root) + context = Context(file_path=relpath, content=content) + type_defs[symbol_name].append(context) return type_defs def _extract_recommended_symbols_context( func_to_test: FuncToTest, symbol_names: List[str] -) -> Dict[str, List[str]]: +) -> Dict[str, List[Context]]: """ Extract context of the given symbol names. """ abs_path = os.path.join(func_to_test.repo_root, func_to_test.file_path) client = IDEService() - # symbol name -> a list of context content (source code) - recommended_symbols: Dict[str, List[str]] = defaultdict(list) + # symbol name -> a list of context + recommended_symbols: Dict[str, List[Context]] = defaultdict(list) # Will try to find both Definition and Type Definition for a symbol symbol_def_locations: Dict[str, Set[Location]] = {} @@ -149,22 +167,24 @@ def _extract_recommended_symbols_context( for t, _ in targets: content = get_symbol_content(t, abspath=loc.abspath) - recommended_symbols[symbol_name].append(content) + relpath = os.path.relpath(loc.abspath, func_to_test.repo_root) + context = Context(file_path=relpath, content=content) + recommended_symbols[symbol_name].append(context) return recommended_symbols def find_symbol_context_by_static_analysis( func_to_test: FuncToTest, chat_language: str -) -> Dict[str, List[str]]: +) -> Dict[str, List[Context]]: """ Find the context of symbols in the function to test by static analysis. """ abs_path = os.path.join(func_to_test.repo_root, func_to_test.file_path) client = IDEService() - # symbol name -> a list of context content (code) - symbol_context: Dict[str, List[str]] = defaultdict(list) + # symbol name -> a list of context + symbol_context: Dict[str, List[Context]] = defaultdict(list) # Get all symbols in the file doc_symbols = client.get_document_symbols(abs_path) @@ -189,8 +209,8 @@ def find_symbol_context_by_static_analysis( def find_symbol_context_of_llm_recommendation( - func_to_test: FuncToTest, known_context: Optional[List[str]] = None -) -> List[str]: + func_to_test: FuncToTest, known_context: Optional[List[Context]] = None +) -> Dict[str, List[Context]]: """ Find the context of symbols recommended by LLM. """ diff --git a/unit_tests/main.py b/unit_tests/main.py index e5fe4a8..d663a69 100644 --- a/unit_tests/main.py +++ b/unit_tests/main.py @@ -9,6 +9,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs")) from chatmark import Checkbox, Form, Step, TextEditor # noqa: E402 from find_context import ( + Context, find_symbol_context_by_static_analysis, find_symbol_context_of_llm_recommendation, ) @@ -43,21 +44,21 @@ class UnitTestsWorkflow: Run the workflow to generate unit tests. """ symbol_context = self.step1_find_symbol_context() - context = set() + contexts = set() for _, v in symbol_context.items(): - context.update(v) + contexts.update(v) - cases, files = self.step2_propose_cases_and_reference_files(list(context)) + cases, files = self.step2_propose_cases_and_reference_files(list(contexts)) res = self.step3_edit_cases_and_reference_files(cases, files) cases = res[0] files = res[1] - self.step4_write_and_print_tests(cases, files, list(context)) + self.step4_write_and_print_tests(cases, files, list(contexts)) def step2_propose_cases_and_reference_files( self, - context: List[str], + contexts: List[Context], ) -> Tuple[List[str], List[str]]: """ Propose test cases and reference files for a specified function. @@ -75,7 +76,7 @@ class UnitTestsWorkflow: test_cases = propose_test( user_prompt=self.user_prompt, func_to_test=self.func_to_test, - context=context, + contexts=contexts, chat_language=self.tui_lang.chat_language, ) @@ -178,14 +179,25 @@ class UnitTestsWorkflow: return cases, valid_files - def step1_find_symbol_context(self) -> Dict[str, List[str]]: + def step1_find_symbol_context(self) -> Dict[str, List[Context]]: symbol_context = find_symbol_context_by_static_analysis( self.func_to_test, self.tui_lang.chat_language ) - known_context_for_llm = [] + # with Step("Symbol context"): + # for k, v in symbol_context.items(): + # print(f"\n- {k}: ") + # for item in v: + # print(f"{item.file_path}\n{item.content}") + + known_context_for_llm: List[Context] = [] if self.func_to_test.container_content is not None: - known_context_for_llm.append(self.func_to_test.container_content) + known_context_for_llm.append( + Context( + file_path=self.func_to_test.file_path, + content=self.func_to_test.container_content, + ) + ) known_context_for_llm += list( {item for sublist in list(symbol_context.values()) for item in sublist} ) @@ -194,6 +206,12 @@ class UnitTestsWorkflow: self.func_to_test, known_context_for_llm ) + # with Step("Recommended context"): + # for k, v in recommended_context.items(): + # print(f"\n- {k}: ") + # for item in v: + # print(f"{item.file_path}\n{item.content}") + symbol_context.update(recommended_context) return symbol_context @@ -202,7 +220,7 @@ class UnitTestsWorkflow: self, cases: List[str], ref_files: List[str], - symbol_context: List[str], + symbol_contexts: List[Context], ): """ Write and print tests. @@ -213,7 +231,7 @@ class UnitTestsWorkflow: func_to_test=self.func_to_test, test_cases=cases, reference_files=ref_files, - symbol_context=symbol_context, + symbol_contexts=symbol_contexts, chat_language=self.tui_lang.chat_language, ) diff --git a/unit_tests/propose_test.py b/unit_tests/propose_test.py index dfab709..bf3e1da 100644 --- a/unit_tests/propose_test.py +++ b/unit_tests/propose_test.py @@ -2,6 +2,7 @@ import json from functools import partial from typing import List, Optional +from find_context import Context from model import FuncToTest, TokenBudgetExceededException from openai_util import create_chat_completion_content from prompts import PROPOSE_TEST_PROMPT @@ -16,7 +17,7 @@ TOKEN_BUDGET = int(16000 * 0.9) def _mk_user_msg( user_prompt: str, func_to_test: FuncToTest, - context: List[str], + contexts: List[Context], chat_language: str, ) -> str: """ @@ -30,10 +31,10 @@ def _mk_user_msg( class_content = f"class code\n```\n{func_to_test.container_content}\n```\n" context_content = "" - if context: - context_content = "\n\nrelevant context\n```\n" - context_content += "\n\n".join(context) - context_content += "\n```\n" + if contexts: + context_content = "\n\nrelevant context\n\n" + context_content += "\n\n".join([str(c) for c in contexts]) + context_content += "\n\n" # Prepare a list of user messages to fit the token budget # by adjusting the relevant content @@ -74,7 +75,7 @@ def _mk_user_msg( def propose_test( user_prompt: str, func_to_test: FuncToTest, - context: Optional[List[str]] = None, + contexts: Optional[List[Context]] = None, chat_language: str = "English", ) -> List[str]: """Propose test cases for a specified function based on a user prompt @@ -88,11 +89,11 @@ def propose_test( Returns: List[str]: A list of test case descriptions. """ - context = context or [] + contexts = contexts or [] user_msg = _mk_user_msg( user_prompt=user_prompt, func_to_test=func_to_test, - context=context, + contexts=contexts, chat_language=chat_language, ) diff --git a/unit_tests/write_tests.py b/unit_tests/write_tests.py index 58ac498..9037a8f 100644 --- a/unit_tests/write_tests.py +++ b/unit_tests/write_tests.py @@ -1,6 +1,7 @@ from functools import partial from typing import List, Optional +from find_context import Context from model import FuncToTest, TokenBudgetExceededException from openai_util import create_chat_completion_chunks from prompts import WRITE_TESTS_PROMPT @@ -19,7 +20,7 @@ def _mk_write_tests_msg( chat_language: str, reference_files: Optional[List[str]] = None, # context_files: Optional[List[str]] = None, - symbol_context: Optional[List[str]] = None, + symbol_contexts: Optional[List[Context]] = None, ) -> Optional[str]: encoding = get_encoding(ENCODING) @@ -42,10 +43,10 @@ def _mk_write_tests_msg( class_content = f"\nclass code\n```\n{func_to_test.container_content}\n```\n" context_content = "" - if symbol_context: - context_content += "\n\nrelevant context\n```\n" - context_content += "\n\n".join(symbol_context) - context_content += "\n```\n" + if symbol_contexts: + context_content += "\n\nrelevant context\n\n" + context_content += "\n\n".join([str(c) for c in symbol_contexts]) + context_content += "\n\n" # if context_files: # context_content += "\n\nrelevant context files\n\n" @@ -105,7 +106,7 @@ def write_and_print_tests( func_to_test: FuncToTest, test_cases: List[str], reference_files: Optional[List[str]] = None, - symbol_context: Optional[List[str]] = None, + symbol_contexts: Optional[List[Context]] = None, chat_language: str = "English", ) -> None: user_msg = _mk_write_tests_msg( @@ -113,7 +114,7 @@ def write_and_print_tests( func_to_test=func_to_test, test_cases=test_cases, reference_files=reference_files, - symbol_context=symbol_context, + symbol_contexts=symbol_contexts, chat_language=chat_language, )