Improve finding context by both definition & type definition
This commit is contained in:
parent
0196eb019f
commit
fccfcafe98
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,4 +1,4 @@
|
||||
.vscode
|
||||
__pycache__/
|
||||
.DS_Store
|
||||
|
||||
tmp/
|
||||
|
@ -9,6 +9,7 @@ from tools.symbol_util import (
|
||||
find_symbol_nodes,
|
||||
get_symbol_content,
|
||||
locate_symbol_by_name,
|
||||
split_tokens,
|
||||
)
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||
@ -50,7 +51,9 @@ def _extract_referenced_symbols_context(
|
||||
return referenced_symbols_context
|
||||
|
||||
|
||||
def _find_children_symbols_type_def_context(func_to_test: FuncToTest, func_symbol: SymbolNode):
|
||||
def _find_children_symbols_type_def_context(
|
||||
func_to_test: FuncToTest, func_symbol: SymbolNode
|
||||
):
|
||||
"""
|
||||
Find the type definitions of the symbols in the function.
|
||||
"""
|
||||
@ -105,30 +108,48 @@ def _extract_recommended_symbols_context(
|
||||
# symbol name -> a list of context content (source code)
|
||||
recommended_symbols: Dict[str, List[str]] = defaultdict(list)
|
||||
|
||||
type_def_locations: Dict[str, Set[Location]] = {}
|
||||
# Will try to find both Definition and Type Definition for a symbol
|
||||
symbol_def_locations: Dict[str, Set[Location]] = {}
|
||||
|
||||
for symbol_name in symbol_names:
|
||||
type_def_locs = set()
|
||||
def_locs = set()
|
||||
|
||||
symbol_tokens = split_tokens(symbol_name)
|
||||
# Use the last token to locate the symbol as an approximation
|
||||
last_token = None
|
||||
last_char = -1
|
||||
for s, chars in symbol_tokens.items():
|
||||
for c in chars:
|
||||
if c > last_char:
|
||||
last_char = c
|
||||
last_token = s
|
||||
|
||||
# locate the symbol in the file
|
||||
positions = locate_symbol_by_name(symbol_name, abs_path)
|
||||
positions = locate_symbol_by_name(last_token, abs_path)
|
||||
|
||||
for pos in positions:
|
||||
locations = client.find_type_def_locations(abs_path, pos.line, pos.character)
|
||||
type_locations = client.find_type_def_locations(
|
||||
abs_path, pos.line, pos.character
|
||||
)
|
||||
def_locations = client.find_def_locations(abs_path, pos.line, pos.character)
|
||||
locations = type_locations + def_locations
|
||||
for loc in locations:
|
||||
# check if loc.abspath is in func_to_test.repo_root
|
||||
if not loc.abspath.startswith(func_to_test.repo_root):
|
||||
# skip, not in the repo
|
||||
continue
|
||||
|
||||
type_def_locs.add(loc)
|
||||
type_def_locations[symbol_name] = type_def_locs
|
||||
def_locs.add(loc)
|
||||
symbol_def_locations[symbol_name] = def_locs
|
||||
|
||||
# Get the content of the type definitions
|
||||
for symbol_name, locations in type_def_locations.items():
|
||||
# Get the content of the found definitions
|
||||
for symbol_name, locations in symbol_def_locations.items():
|
||||
for loc in locations:
|
||||
# NOTE: further improvement is needed to
|
||||
# get the symbol node of function with decorator in Python
|
||||
symbols = client.get_document_symbols(loc.abspath)
|
||||
targets = find_symbol_nodes(symbols, name=symbol_name, line=loc.range.start.line)
|
||||
# targets = find_symbol_nodes(symbols, name=symbol_name, line=loc.range.start.line)
|
||||
targets = find_symbol_nodes(symbols, line=loc.range.start.line)
|
||||
|
||||
for t, _ in targets:
|
||||
content = get_symbol_content(t, abspath=loc.abspath)
|
||||
@ -163,7 +184,9 @@ def find_symbol_context_by_static_analysis(
|
||||
context_by_reference = _extract_referenced_symbols_context(
|
||||
func_to_test, doc_symbols, depth=func_depth
|
||||
)
|
||||
context_by_type_def = _find_children_symbols_type_def_context(func_to_test, func_symbol)
|
||||
context_by_type_def = _find_children_symbols_type_def_context(
|
||||
func_to_test, func_symbol
|
||||
)
|
||||
|
||||
symbol_context.update(context_by_reference)
|
||||
symbol_context.update(context_by_type_def)
|
||||
@ -179,6 +202,8 @@ def find_symbol_context_of_llm_recommendation(
|
||||
"""
|
||||
recommended_symbols = get_recommended_symbols(func_to_test, known_context)
|
||||
|
||||
recommended_context = _extract_recommended_symbols_context(func_to_test, recommended_symbols)
|
||||
recommended_context = _extract_recommended_symbols_context(
|
||||
func_to_test, recommended_symbols
|
||||
)
|
||||
|
||||
return recommended_context
|
||||
|
@ -18,6 +18,7 @@ def _mk_write_tests_msg(
|
||||
test_cases: List[str],
|
||||
chat_language: str,
|
||||
reference_files: Optional[List[str]] = None,
|
||||
# context_files: Optional[List[str]] = None,
|
||||
symbol_context: Optional[List[str]] = None,
|
||||
) -> Optional[str]:
|
||||
encoding = get_encoding(ENCODING)
|
||||
@ -46,6 +47,13 @@ def _mk_write_tests_msg(
|
||||
context_content += "\n\n".join(symbol_context)
|
||||
context_content += "\n```\n"
|
||||
|
||||
# if context_files:
|
||||
# context_content += "\n\nrelevant context files\n\n"
|
||||
# for i, fp in enumerate(context_files, 1):
|
||||
# context_file_content = retrieve_file_content(fp, root_path)
|
||||
# context_content += f"{i}. {fp}\n\n"
|
||||
# context_content += f"```{context_file_content}```\n\n"
|
||||
|
||||
# Prepare a list of user messages to fit the token budget
|
||||
# by adjusting the relevant content and reference content
|
||||
content_fmt = partial(
|
||||
|
Loading…
x
Reference in New Issue
Block a user