Improve context finding by using both the definition and the type definition

This commit is contained in:
kagami 2024-02-27 16:29:18 +08:00
parent 0196eb019f
commit fccfcafe98
3 changed files with 46 additions and 13 deletions

2
.gitignore vendored
View File

@ -1,4 +1,4 @@
.vscode
__pycache__/
.DS_Store
tmp/

View File

@ -9,6 +9,7 @@ from tools.symbol_util import (
find_symbol_nodes,
get_symbol_content,
locate_symbol_by_name,
split_tokens,
)
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
@ -50,7 +51,9 @@ def _extract_referenced_symbols_context(
return referenced_symbols_context
def _find_children_symbols_type_def_context(func_to_test: FuncToTest, func_symbol: SymbolNode):
def _find_children_symbols_type_def_context(
func_to_test: FuncToTest, func_symbol: SymbolNode
):
"""
Find the type definitions of the symbols in the function.
"""
@ -105,30 +108,48 @@ def _extract_recommended_symbols_context(
# symbol name -> a list of context content (source code)
recommended_symbols: Dict[str, List[str]] = defaultdict(list)
type_def_locations: Dict[str, Set[Location]] = {}
# Will try to find both Definition and Type Definition for a symbol
symbol_def_locations: Dict[str, Set[Location]] = {}
for symbol_name in symbol_names:
type_def_locs = set()
def_locs = set()
symbol_tokens = split_tokens(symbol_name)
# Use the last token to locate the symbol as an approximation
last_token = None
last_char = -1
for s, chars in symbol_tokens.items():
for c in chars:
if c > last_char:
last_char = c
last_token = s
# locate the symbol in the file
positions = locate_symbol_by_name(symbol_name, abs_path)
positions = locate_symbol_by_name(last_token, abs_path)
for pos in positions:
locations = client.find_type_def_locations(abs_path, pos.line, pos.character)
type_locations = client.find_type_def_locations(
abs_path, pos.line, pos.character
)
def_locations = client.find_def_locations(abs_path, pos.line, pos.character)
locations = type_locations + def_locations
for loc in locations:
# check if loc.abspath is in func_to_test.repo_root
if not loc.abspath.startswith(func_to_test.repo_root):
# skip, not in the repo
continue
type_def_locs.add(loc)
type_def_locations[symbol_name] = type_def_locs
def_locs.add(loc)
symbol_def_locations[symbol_name] = def_locs
# Get the content of the type definitions
for symbol_name, locations in type_def_locations.items():
# Get the content of the found definitions
for symbol_name, locations in symbol_def_locations.items():
for loc in locations:
# NOTE: further improvement is needed to
# get the symbol node of a function with a decorator in Python
symbols = client.get_document_symbols(loc.abspath)
targets = find_symbol_nodes(symbols, name=symbol_name, line=loc.range.start.line)
# targets = find_symbol_nodes(symbols, name=symbol_name, line=loc.range.start.line)
targets = find_symbol_nodes(symbols, line=loc.range.start.line)
for t, _ in targets:
content = get_symbol_content(t, abspath=loc.abspath)
@ -163,7 +184,9 @@ def find_symbol_context_by_static_analysis(
context_by_reference = _extract_referenced_symbols_context(
func_to_test, doc_symbols, depth=func_depth
)
context_by_type_def = _find_children_symbols_type_def_context(func_to_test, func_symbol)
context_by_type_def = _find_children_symbols_type_def_context(
func_to_test, func_symbol
)
symbol_context.update(context_by_reference)
symbol_context.update(context_by_type_def)
@ -179,6 +202,8 @@ def find_symbol_context_of_llm_recommendation(
"""
recommended_symbols = get_recommended_symbols(func_to_test, known_context)
recommended_context = _extract_recommended_symbols_context(func_to_test, recommended_symbols)
recommended_context = _extract_recommended_symbols_context(
func_to_test, recommended_symbols
)
return recommended_context

View File

@ -18,6 +18,7 @@ def _mk_write_tests_msg(
test_cases: List[str],
chat_language: str,
reference_files: Optional[List[str]] = None,
# context_files: Optional[List[str]] = None,
symbol_context: Optional[List[str]] = None,
) -> Optional[str]:
encoding = get_encoding(ENCODING)
@ -46,6 +47,13 @@ def _mk_write_tests_msg(
context_content += "\n\n".join(symbol_context)
context_content += "\n```\n"
# if context_files:
# context_content += "\n\nrelevant context files\n\n"
# for i, fp in enumerate(context_files, 1):
# context_file_content = retrieve_file_content(fp, root_path)
# context_content += f"{i}. {fp}\n\n"
# context_content += f"```{context_file_content}```\n\n"
# Prepare a list of user messages to fit the token budget
# by adjusting the relevant content and reference content
content_fmt = partial(