226 lines
7.7 KiB
Python
226 lines
7.7 KiB
Python
import os
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from typing import Dict, List, Optional, Set
|
|
|
|
from assistants.recommend_test_context import get_recommended_symbols
|
|
from model import FuncToTest
|
|
from tools.symbol_util import (
|
|
find_symbol_nodes,
|
|
get_symbol_content,
|
|
locate_symbol_by_name,
|
|
split_tokens,
|
|
)
|
|
|
|
from lib.ide_service import (
|
|
IDEService,
|
|
Location,
|
|
Position,
|
|
Range,
|
|
SymbolNode,
|
|
)
|
|
|
|
|
|
@dataclass
class Context:
    """A piece of file content together with its file path and range."""

    file_path: str  # relative path to repo root
    content: str
    range: Range

    def __hash__(self) -> int:
        # Only the path and the text take part in the hash; `range` is left out.
        identity = (self.file_path, self.content)
        return hash(identity)

    def __str__(self) -> str:
        # Render the path followed by the content inside a fenced block.
        rendered = f"file path:`{self.file_path}`\n```\n{self.content}\n```"
        return rendered
|
|
|
|
|
|
def _extract_referenced_symbols_context(
    func_to_test: FuncToTest, symbols: List[SymbolNode], depth: int = 0
) -> Dict[str, List[Context]]:
    """
    Extract context of the document symbols referenced in the function.

    The symbol tree is walked depth-first with the given ``symbols`` at depth 0.
    A symbol counts as referenced when its name occurs as a substring of the
    function's source text. The function itself (matched by name) is skipped
    together with its whole subtree, and symbols nested deeper than ``depth``
    are never visited.

    Args:
        func_to_test: The function under test; its ``func_name``,
            ``func_content``, ``file_content`` and ``file_path`` are read.
        symbols: Root symbol nodes of the file that contains the function.
        depth: Maximum nesting depth (inclusive) of the symbols to consider.

    Returns:
        A mapping from symbol name to the contexts of the referenced symbols
        carrying that name.
    """
    referenced_symbols_context: Dict[str, List[Context]] = defaultdict(list)
    func_content = func_to_test.func_content
    referenced_symbols: List[SymbolNode] = []
    stack = [(s, 0) for s in symbols]

    while stack:
        s, d = stack.pop()
        if d > depth:
            continue

        if s.name == func_to_test.func_name:
            # Skip the function itself and its children
            continue

        if s.name in func_content:
            # Use simple string matching for now
            referenced_symbols.append(s)

        if d < depth:
            # Children live one level below their parent. Using the parent's own
            # depth (not the `depth` limit) lets symbols nested up to `depth`
            # actually be reached; deeper ones are not pushed at all.
            stack.extend((c, d + 1) for c in reversed(s.children))

    # Get the content of the symbols
    for s in referenced_symbols:
        content = get_symbol_content(s, file_content=func_to_test.file_content)
        context = Context(file_path=func_to_test.file_path, content=content, range=s.range)
        referenced_symbols_context[s.name].append(context)
    return referenced_symbols_context
|
|
|
|
|
|
def _find_children_symbols_type_def_context(
    func_to_test: FuncToTest, func_symbol: SymbolNode
) -> Dict[str, List[Context]]:
    """
    Collect type-definition contexts for the symbols nested in the function.

    Every descendant of ``func_symbol`` is queried through the IDE service for
    its type-definition locations. A location is kept only when its absolute
    path starts with the repo root and differs from the file of the function
    under test. The symbols found at each kept location are then turned into
    ``Context`` objects, grouped by the name of the originating symbol.
    """
    ide = IDEService()
    repo_root = func_to_test.repo_root
    source_path = os.path.join(repo_root, func_to_test.file_path)

    # symbol name -> de-duplicated locations of its type definitions
    locations_by_symbol: Dict[str, Set[Location]] = defaultdict(set)

    # Walk all descendants of the function symbol with an explicit stack.
    pending = func_symbol.children[:]
    while pending:
        node = pending.pop()
        start = node.range.start

        for location in ide.find_type_def_locations(source_path, start.line, start.character):
            inside_repo = location.abspath.startswith(repo_root)
            # Keep only in-repo definitions that live in another file.
            if inside_repo and location.abspath != source_path:
                locations_by_symbol[node.name].add(location)

        pending.extend(node.children)

    # Resolve every location to the content of the symbols found there.
    contexts_by_symbol: Dict[str, List[Context]] = defaultdict(list)
    for name, found_locations in locations_by_symbol.items():
        for location in found_locations:
            doc_symbols = ide.get_document_symbols(location.abspath)
            matches = find_symbol_nodes(doc_symbols, line=location.range.start.line)
            for target, _ in matches:
                text = get_symbol_content(target, abspath=location.abspath)
                rel_path = os.path.relpath(location.abspath, repo_root)
                contexts_by_symbol[name].append(
                    Context(file_path=rel_path, content=text, range=target.range)
                )

    return contexts_by_symbol
|
|
|
|
|
|
def _extract_recommended_symbols_context(
    func_to_test: FuncToTest, symbol_names: List[str]
) -> Dict[str, List[Context]]:
    """
    Extract context of the given symbol names.

    For each name, the token with the greatest position in the name is used as
    an approximation to locate the symbol in the file of the function under
    test. At every located position both the type-definition and the definition
    locations reported by the IDE service are collected; only locations whose
    absolute path starts with the repo root are kept. Names for which no token
    can be determined are skipped.

    Args:
        func_to_test: The function under test; its ``repo_root`` and
            ``file_path`` are read.
        symbol_names: Names of the symbols to look up.

    Returns:
        A mapping from symbol name to the contexts found for it. Names without
        any resolved definition do not appear in the mapping.
    """
    abs_path = os.path.join(func_to_test.repo_root, func_to_test.file_path)
    client = IDEService()

    # symbol name -> a list of context
    recommended_symbols: Dict[str, List[Context]] = defaultdict(list)

    # Will try to find both Definition and Type Definition for a symbol
    symbol_def_locations: Dict[str, Set[Location]] = {}

    for symbol_name in symbol_names:
        # NOTE(review): split_tokens is used here as a mapping of
        # token -> positions of that token in the name; confirm in tools.symbol_util.
        symbol_tokens = split_tokens(symbol_name)

        # Use the last token to locate the symbol as an approximation
        last_token = None
        last_char = -1
        for token, chars in symbol_tokens.items():
            for c in chars:
                if c > last_char:
                    last_char = c
                    last_token = token

        if last_token is None:
            # No token was found in the name (e.g. an empty name), so there is
            # nothing to search for; do not pass None to the lookup helper.
            continue

        # locate the symbol in the file
        positions: List[Position] = locate_symbol_by_name(last_token, abs_path)

        def_locs: Set[Location] = set()
        for pos in positions:
            type_locations = client.find_type_def_locations(abs_path, pos.line, pos.character)
            def_locations = client.find_def_locations(abs_path, pos.line, pos.character)
            for loc in type_locations + def_locations:
                # keep only locations whose path starts with the repo root
                if loc.abspath.startswith(func_to_test.repo_root):
                    def_locs.add(loc)
        symbol_def_locations[symbol_name] = def_locs

    # Get the content of the found definitions
    for symbol_name, locations in symbol_def_locations.items():
        for loc in locations:
            # NOTE: further improvement is needed to
            # get the symbol node of function with decorator in Python
            symbols = client.get_document_symbols(loc.abspath)
            # Match by line only; filtering by symbol name as well is currently disabled.
            targets = find_symbol_nodes(symbols, line=loc.range.start.line)

            for t, _ in targets:
                content = get_symbol_content(t, abspath=loc.abspath)
                relpath = os.path.relpath(loc.abspath, func_to_test.repo_root)
                context = Context(file_path=relpath, content=content, range=t.range)
                recommended_symbols[symbol_name].append(context)

    return recommended_symbols
|
|
|
|
|
|
def find_symbol_context_by_static_analysis(
    func_to_test: FuncToTest, chat_language: str
) -> Dict[str, List[Context]]:
    """
    Gather symbol context for the function under test via static analysis.

    The document symbols of the function's file are fetched from the IDE
    service and the symbol of the function itself is looked up by name and
    start line. When it is found, two mappings are merged into the result:
    the symbols referenced by the function, and the type definitions of the
    symbols nested in it. When it is not found, the result is empty.

    ``chat_language`` is accepted but not used by this function.
    """
    source_path = os.path.join(func_to_test.repo_root, func_to_test.file_path)
    ide = IDEService()

    # symbol name -> a list of context
    collected: Dict[str, List[Context]] = defaultdict(list)

    # All symbols of the file, then the one belonging to the function to test.
    all_symbols = ide.get_document_symbols(source_path)
    matches = find_symbol_nodes(
        all_symbols, name=func_to_test.func_name, line=func_to_test.func_start_line
    )

    if matches:
        target_symbol, target_depth = matches[0]

        by_reference = _extract_referenced_symbols_context(
            func_to_test, all_symbols, depth=target_depth
        )
        by_type_def = _find_children_symbols_type_def_context(func_to_test, target_symbol)

        # dict.update replaces the value on a key collision, so entries from the
        # type-definition mapping take precedence over the referenced ones.
        for partial in (by_reference, by_type_def):
            collected.update(partial)

    return collected
|
|
|
|
|
|
def find_symbol_context_of_llm_recommendation(
    func_to_test: FuncToTest, known_context: Optional[List[Context]] = None
) -> Dict[str, List[Context]]:
    """
    Look up the context of the symbols that the LLM recommends.

    The recommended symbol names are obtained from ``get_recommended_symbols``
    (given the function and the already known context) and then resolved to
    contexts, keyed by symbol name.
    """
    suggested_names = get_recommended_symbols(func_to_test, known_context)
    return _extract_recommended_symbols_context(func_to_test, suggested_names)
|