Merge pull request #67 from devchat-ai/symbol-context

Add more context for writing tests by analyzing symbols
This commit is contained in:
boob.yang 2024-03-12 07:56:57 +08:00 committed by GitHub
commit db6b2f4b4c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 535 additions and 15 deletions

2
.gitignore vendored
View File

@ -1,4 +1,4 @@
.vscode
__pycache__/
.DS_Store
tmp/

View File

@ -1,4 +1,4 @@
from contextlib import AbstractContextManager, contextmanager
from contextlib import AbstractContextManager
class Step(AbstractContextManager):
@ -25,4 +25,4 @@ class Step(AbstractContextManager):
def __exit__(self, exc_type, exc_val, exc_tb):
# close the step
print(f"\n```", flush=True)
print("\n```", flush=True)

View File

@ -1,5 +1,6 @@
from .service import IDEService
from .types import *
__all__ = [
__all__ = types.__all__ + [
"IDEService",
]

View File

@ -91,3 +91,7 @@ class IDEService:
@rpc_method
def find_type_def_locations(self, abspath: str, line: int, character: int) -> List[Location]:
return [Location.parse_obj(loc) for loc in self._result]
@rpc_method
def find_def_locations(self, abspath: str, line: int, character: int) -> List[Location]:
return [Location.parse_obj(loc) for loc in self._result]

View File

@ -2,21 +2,46 @@ from typing import List
from pydantic import BaseModel
__all__ = [
"Position",
"Range",
"Location",
"SymbolNode",
]
class Position(BaseModel):
line: int # 0-based
character: int # 0-based
def __repr__(self):
return f"Ln{self.line}:Col{self.character}"
def __hash__(self):
return hash(self.__repr__())
class Range(BaseModel):
start: Position
end: Position
def __repr__(self):
return f"{self.start} - {self.end}"
def __hash__(self):
return hash(self.__repr__())
class Location(BaseModel):
abspath: str
range: Range
def __repr__(self):
return f"{self.abspath}::{self.range}"
def __hash__(self):
return hash(self.__repr__())
class SymbolNode(BaseModel):
name: str

View File

@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Optional
from openai import OpenAI
from tools.directory_viewer import mk_repo_file_criteria

View File

@ -0,0 +1,70 @@
import json
from typing import List, Optional
from model import FuncToTest
from openai_util import create_chat_completion_content
# Model used for the symbol-context recommendation request.
MODEL = "gpt-4-1106-preview"
# tiktoken encoding name; presumably matches MODEL — confirm when used.
ENCODING = "cl100k_base"
# TODO: handle token budget
# 90% of the model's 128k-token context window, kept as a safety margin.
# NOTE(review): ENCODING and TOKEN_BUDGET appear unused in this module so far
# (see TODO above) — verify before relying on them.
TOKEN_BUDGET = int(128000 * 0.9)
# ruff: noqa: E501
# Prompt asking the LLM whether the provided context is sufficient for writing
# tests for a function; the reply must be JSON {"key_symbols": [...]} naming at
# most 10 symbols that still lack context (empty list means "enough context").
recommend_symbol_context_prompt = """
You're an advanced AI test generator.
You're about to write test cases for the function `{function_name}` in the file `{file_path}`.
Before you start, you need to check if you have enough context information to write the test cases.
Here is the source code of the function:
```
{function_content}
```
And here are some context information that might help you write the test cases:
{context_content}
Do you think the context information is enough?
If the information is insufficient, recommend which symbols or types you need to know more about.
Return a JSON object with a single key "key_symbols" whose value is a list of strings.
- If the context information is enough, return an empty list.
- Each string is the name of a symbol or type appearing in the function that lacks context information for writing test.
- The list should contain the most important symbols and should not exceed 10 items.
JSON Format Example:
{{
"key_symbols": ["<symbol 1>", "<symbol 2>", "<symbol 3>",...]
}}
"""
def get_recommended_symbols(
    func_to_test: FuncToTest, known_context: Optional[List] = None
) -> List[str]:
    """
    Ask the LLM which symbols still need context for writing tests.

    Args:
        func_to_test: The function we want to generate tests for.
        known_context: Context items already collected (stringified into
            the prompt); may be None.

    Returns:
        A list of symbol names recommended by the model; empty when the
        model judges the known context to be sufficient.
    """
    contexts = known_context if known_context is not None else []
    prompt = recommend_symbol_context_prompt.format(
        function_content=func_to_test.func_content,
        context_content="\n\n".join(str(c) for c in contexts),
        function_name=func_to_test.func_name,
        file_path=func_to_test.file_path,
    )
    # Force a JSON object response so the reply can be parsed directly.
    response = create_chat_completion_content(
        model=MODEL,
        messages=[{"role": "user", "content": prompt}],
        response_format={"type": "json_object"},
        temperature=0.1,
    )
    return json.loads(response).get("key_symbols", [])

221
unit_tests/find_context.py Normal file
View File

@ -0,0 +1,221 @@
import os
import sys
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set
from assistants.recommend_test_context import get_recommended_symbols
from model import FuncToTest
from tools.symbol_util import (
find_symbol_nodes,
get_symbol_content,
locate_symbol_by_name,
split_tokens,
)
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from libs.ide_services import IDEService, Location, SymbolNode
@dataclass
class Context:
    """A snippet of repository content used as extra context for test generation."""

    file_path: str  # relative path to repo root
    content: str  # source text of the snippet

    def __hash__(self) -> int:
        # Explicit hash (consistent with the dataclass-generated __eq__)
        # so Context instances can be deduplicated in sets.
        return hash((self.file_path, self.content))

    def __str__(self) -> str:
        return "file path:`{}`\n```\n{}\n```".format(self.file_path, self.content)
def _extract_referenced_symbols_context(
    func_to_test: FuncToTest, symbols: List[SymbolNode], depth: int = 0
) -> Dict[str, List[Context]]:
    """
    Extract context of the document symbols referenced in the function.

    Exclude the function itself and symbols whose depth is greater than the
    specified depth.

    Args:
        func_to_test: The function whose body is scanned for references.
        symbols: Top-level document symbols of the file (depth 0).
        depth: Maximum symbol nesting depth to consider.

    Returns:
        Mapping of symbol name -> list of Context built from the same file.
    """
    referenced_symbols_context = defaultdict(list)
    func_content = func_to_test.func_content
    referenced_symbols = []
    stack = [(s, 0) for s in symbols]
    while stack:
        s, d = stack.pop()
        if d > depth:
            continue
        if s.name == func_to_test.func_name:
            # Skip the function itself and its children
            continue
        if s.name in func_content:
            # Use simple string matching for now
            referenced_symbols.append(s)
        # BUG FIX: push children with their actual depth `d + 1`, not
        # `depth + 1`.  The old value always exceeded the `depth` cutoff,
        # so nested symbols (e.g. methods inside classes) were never
        # considered whenever depth > 0.
        stack.extend((c, d + 1) for c in reversed(s.children))
    # Get the content of the symbols
    for s in referenced_symbols:
        content = get_symbol_content(s, file_content=func_to_test.file_content)
        context = Context(file_path=func_to_test.file_path, content=content)
        referenced_symbols_context[s.name].append(context)
    return referenced_symbols_context
def _find_children_symbols_type_def_context(
    func_to_test: FuncToTest, func_symbol: SymbolNode
) -> Dict[str, List[Context]]:
    """
    Find the type definitions of the symbols in the function.

    Walks every descendant symbol of ``func_symbol``, asks the IDE service for
    its type-definition locations, and collects the content of the symbols
    found at those locations (only for files inside the repo and different
    from the function's own file).

    Returns:
        Mapping of symbol name -> list of Context from other repo files.
    """
    type_defs: Dict[str, List[Context]] = defaultdict(list)
    client = IDEService()
    abs_path = os.path.join(func_to_test.repo_root, func_to_test.file_path)
    # symbol name -> set of type-definition locations (set for dedup)
    type_def_locations: Dict[str, Set[Location]] = defaultdict(set)
    # find type definitions for symbols in the function
    stack = func_symbol.children[:]
    while stack:
        s = stack.pop()
        locations = client.find_type_def_locations(
            abs_path, s.range.start.line, s.range.start.character
        )
        for loc in locations:
            # check if loc.abspath is in func_to_test.repo_root
            # NOTE(review): a plain string-prefix check can false-positive on
            # sibling dirs (e.g. "/repo2" vs root "/repo") — confirm.
            if not loc.abspath.startswith(func_to_test.repo_root):
                # skip, not in the repo
                continue
            if loc.abspath == abs_path:
                # skip, the symbol is in the same file
                continue
            type_def_locations[s.name].add(loc)
        # Descend into nested symbols as well.
        stack.extend(s.children)
    # Get the content of the type definitions
    for symbol_name, locations in type_def_locations.items():
        for loc in locations:
            symbols = client.get_document_symbols(loc.abspath)
            # Match symbols in the target file by start line only.
            targets = find_symbol_nodes(symbols, line=loc.range.start.line)
            for t, _ in targets:
                content = get_symbol_content(t, abspath=loc.abspath)
                relpath = os.path.relpath(loc.abspath, func_to_test.repo_root)
                context = Context(file_path=relpath, content=content)
                type_defs[symbol_name].append(context)
    return type_defs
def _extract_recommended_symbols_context(
    func_to_test: FuncToTest, symbol_names: List[str]
) -> Dict[str, List[Context]]:
    """
    Extract context of the given symbol names.

    For each name, approximate its position(s) in the target file, then ask
    the IDE service for both Definition and Type Definition locations and
    collect the content of the symbols found there (repo-internal files only).

    Returns:
        Mapping of symbol name -> list of Context.
    """
    abs_path = os.path.join(func_to_test.repo_root, func_to_test.file_path)
    client = IDEService()
    # symbol name -> a list of context
    recommended_symbols: Dict[str, List[Context]] = defaultdict(list)
    # Will try to find both Definition and Type Definition for a symbol
    symbol_def_locations: Dict[str, Set[Location]] = {}
    for symbol_name in symbol_names:
        def_locs = set()
        symbol_tokens = split_tokens(symbol_name)
        # Use the last token to locate the symbol as an approximation
        # (e.g. for a dotted name "a.b.c", search for "c").
        last_token = None
        last_char = -1
        for s, chars in symbol_tokens.items():
            for c in chars:
                if c > last_char:
                    last_char = c
                    last_token = s
        # locate the symbol in the file
        # NOTE(review): if symbol_name has no word tokens, last_token stays
        # None and no positions are found — harmless but silent.
        positions = locate_symbol_by_name(last_token, abs_path)
        for pos in positions:
            type_locations = client.find_type_def_locations(abs_path, pos.line, pos.character)
            def_locations = client.find_def_locations(abs_path, pos.line, pos.character)
            locations = type_locations + def_locations
            for loc in locations:
                # check if loc.abspath is in func_to_test.repo_root
                # NOTE(review): plain prefix match may false-positive on
                # sibling dirs ("/repo2" vs "/repo") — confirm.
                if not loc.abspath.startswith(func_to_test.repo_root):
                    # skip, not in the repo
                    continue
                def_locs.add(loc)
        symbol_def_locations[symbol_name] = def_locs
    # Get the content of the found definitions
    for symbol_name, locations in symbol_def_locations.items():
        for loc in locations:
            # NOTE: further improvement is needed to
            # get the symbol node of function with decorator in Python
            symbols = client.get_document_symbols(loc.abspath)
            # targets = find_symbol_nodes(symbols, name=symbol_name, line=loc.range.start.line)
            targets = find_symbol_nodes(symbols, line=loc.range.start.line)
            for t, _ in targets:
                content = get_symbol_content(t, abspath=loc.abspath)
                relpath = os.path.relpath(loc.abspath, func_to_test.repo_root)
                context = Context(file_path=relpath, content=content)
                recommended_symbols[symbol_name].append(context)
    return recommended_symbols
def find_symbol_context_by_static_analysis(
    func_to_test: FuncToTest, chat_language: str
) -> Dict[str, List[Context]]:
    """
    Find the context of symbols in the function to test by static analysis.

    Combines two sources:
    1. same-file document symbols textually referenced in the function body;
    2. type definitions (in other repo files) of the function's child symbols.

    Returns:
        Mapping of symbol name -> list of Context; empty if the function's
        own symbol cannot be located.

    NOTE(review): ``chat_language`` is currently unused in this function;
    presumably kept for signature symmetry with other steps — confirm
    before removing.
    """
    abs_path = os.path.join(func_to_test.repo_root, func_to_test.file_path)
    client = IDEService()
    # symbol name -> a list of context
    symbol_context: Dict[str, List[Context]] = defaultdict(list)
    # Get all symbols in the file
    doc_symbols = client.get_document_symbols(abs_path)
    # Find the symbol of the function to test
    func_symbols = find_symbol_nodes(
        doc_symbols, name=func_to_test.func_name, line=func_to_test.func_start_line
    )
    if not func_symbols:
        return symbol_context
    # Use the first match; a (name, start line) pair is expected to be unique.
    func_symbol, func_depth = func_symbols[0]
    context_by_reference = _extract_referenced_symbols_context(
        func_to_test, doc_symbols, depth=func_depth
    )
    context_by_type_def = _find_children_symbols_type_def_context(func_to_test, func_symbol)
    symbol_context.update(context_by_reference)
    symbol_context.update(context_by_type_def)
    return symbol_context
def find_symbol_context_of_llm_recommendation(
    func_to_test: FuncToTest, known_context: Optional[List[Context]] = None
) -> Dict[str, List[Context]]:
    """
    Find the context of symbols recommended by LLM.

    Asks the model which symbols still lack context given ``known_context``,
    then extracts context for those symbol names.
    """
    symbol_names = get_recommended_symbols(func_to_test, known_context)
    return _extract_recommended_symbols_context(func_to_test, symbol_names)

View File

@ -1,6 +1,6 @@
import os
import sys
from typing import List, Tuple
from typing import Dict, List, Tuple
import click
@ -8,6 +8,11 @@ sys.path.append(os.path.dirname(__file__))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs"))
from chatmark import Checkbox, Form, Step, TextEditor # noqa: E402
from find_context import (
Context,
find_symbol_context_by_static_analysis,
find_symbol_context_of_llm_recommendation,
)
from find_reference_tests import find_reference_tests
from i18n import TUILanguage, get_translation
from ide_services import IDEService # noqa: E402
@ -38,14 +43,22 @@ class UnitTestsWorkflow:
"""
Run the workflow to generate unit tests.
"""
cases, files = self.step1_propose_cases_and_reference_files()
symbol_context = self.step1_find_symbol_context()
contexts = set()
for _, v in symbol_context.items():
contexts.update(v)
cases, files = self.step2_edit_cases_and_reference_files(cases, files)
cases, files = self.step2_propose_cases_and_reference_files(list(contexts))
self.step3_write_and_print_tests(cases, files)
res = self.step3_edit_cases_and_reference_files(cases, files)
cases = res[0]
files = res[1]
def step1_propose_cases_and_reference_files(
self.step4_write_and_print_tests(cases, files, list(contexts))
def step2_propose_cases_and_reference_files(
self,
contexts: List[Context],
) -> Tuple[List[str], List[str]]:
"""
Propose test cases and reference files for a specified function.
@ -63,6 +76,7 @@ class UnitTestsWorkflow:
test_cases = propose_test(
user_prompt=self.user_prompt,
func_to_test=self.func_to_test,
contexts=contexts,
chat_language=self.tui_lang.chat_language,
)
@ -78,7 +92,7 @@ class UnitTestsWorkflow:
return test_cases, reference_files
def step2_edit_cases_and_reference_files(
def step3_edit_cases_and_reference_files(
self, test_cases: List[str], reference_files: List[str]
) -> Tuple[List[str], List[str]]:
"""
@ -165,10 +179,48 @@ class UnitTestsWorkflow:
return cases, valid_files
def step3_write_and_print_tests(
def step1_find_symbol_context(self) -> Dict[str, List[Context]]:
symbol_context = find_symbol_context_by_static_analysis(
self.func_to_test, self.tui_lang.chat_language
)
# with Step("Symbol context"):
# for k, v in symbol_context.items():
# print(f"\n- {k}: ")
# for item in v:
# print(f"{item.file_path}\n{item.content}")
known_context_for_llm: List[Context] = []
if self.func_to_test.container_content is not None:
known_context_for_llm.append(
Context(
file_path=self.func_to_test.file_path,
content=self.func_to_test.container_content,
)
)
known_context_for_llm += list(
{item for sublist in list(symbol_context.values()) for item in sublist}
)
recommended_context = find_symbol_context_of_llm_recommendation(
self.func_to_test, known_context_for_llm
)
# with Step("Recommended context"):
# for k, v in recommended_context.items():
# print(f"\n- {k}: ")
# for item in v:
# print(f"{item.file_path}\n{item.content}")
symbol_context.update(recommended_context)
return symbol_context
def step4_write_and_print_tests(
self,
cases: List[str],
ref_files: List[str],
symbol_contexts: List[Context],
):
"""
Write and print tests.
@ -179,6 +231,7 @@ class UnitTestsWorkflow:
func_to_test=self.func_to_test,
test_cases=cases,
reference_files=ref_files,
symbol_contexts=symbol_contexts,
chat_language=self.tui_lang.chat_language,
)

View File

@ -1,7 +1,8 @@
import json
from functools import partial
from typing import List
from typing import List, Optional
from find_context import Context
from model import FuncToTest, TokenBudgetExceededException
from openai_util import create_chat_completion_content
from prompts import PROPOSE_TEST_PROMPT
@ -16,6 +17,7 @@ TOKEN_BUDGET = int(16000 * 0.9)
def _mk_user_msg(
user_prompt: str,
func_to_test: FuncToTest,
contexts: List[Context],
chat_language: str,
) -> str:
"""
@ -28,6 +30,12 @@ def _mk_user_msg(
if func_to_test.container_content is not None:
class_content = f"class code\n```\n{func_to_test.container_content}\n```\n"
context_content = ""
if contexts:
context_content = "\n\nrelevant context\n\n"
context_content += "\n\n".join([str(c) for c in contexts])
context_content += "\n\n"
# Prepare a list of user messages to fit the token budget
# by adjusting the relevant content
relevant_content_fmt = partial(
@ -37,6 +45,10 @@ def _mk_user_msg(
file_path=func_to_test.file_path,
chat_language=chat_language,
)
# 0. func content & class content & context content
msg_0 = relevant_content_fmt(
relevant_content="\n".join([func_content, class_content, context_content]),
)
# 1. func content & class content
msg_1 = relevant_content_fmt(
relevant_content="\n".join([func_content, class_content]),
@ -46,7 +58,7 @@ def _mk_user_msg(
relevant_content=func_content,
)
prioritized_msgs = [msg_1, msg_2]
prioritized_msgs = [msg_0, msg_1, msg_2]
for msg in prioritized_msgs:
token_count = len(encoding.encode(msg, disallowed_special=()))
@ -63,6 +75,7 @@ def _mk_user_msg(
def propose_test(
user_prompt: str,
func_to_test: FuncToTest,
contexts: Optional[List[Context]] = None,
chat_language: str = "English",
) -> List[str]:
"""Propose test cases for a specified function based on a user prompt
@ -76,9 +89,11 @@ def propose_test(
Returns:
List[str]: A list of test case descriptions.
"""
contexts = contexts or []
user_msg = _mk_user_msg(
user_prompt=user_prompt,
func_to_test=func_to_test,
contexts=contexts,
chat_language=chat_language,
)

View File

@ -0,0 +1,107 @@
import os
import re
import sys
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
from tools.file_util import retrieve_file_content
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from libs.ide_services import Position, SymbolNode
def split_tokens(text: str) -> Dict[str, List[int]]:
    """
    Split a line of text into word tokens.

    Returns a mapping of token -> list of 0-based start columns.
    Not a perfect tokenizer, but sufficient for symbol lookup.
    """
    result: Dict[str, List[int]] = defaultdict(list)
    for match in re.finditer(r"\b\w+\b", text):
        result[match.group()].append(match.start())
    return result
def locate_symbol_by_name(
    symbol_name: str, abspath: str, line_numbers: Optional[List[int]] = None
) -> List[Position]:
    """
    Find the occurrences of the specified symbol in the specified file.

    Line and column numbers are 0-based.

    Args:
        symbol_name: The name of the symbol to find.
        abspath: The absolute path to the file to search.
        line_numbers: The line numbers to search for the symbol.
            If None, search the entire file.

    Returns:
        A list of Position, one per occurrence.
    """
    wanted_lines = set(line_numbers) if line_numbers else None
    found: List[Position] = []
    with open(abspath, "r") as file:
        for lineno, line in enumerate(file):
            if wanted_lines and lineno not in wanted_lines:
                continue
            # split_tokens gives token -> start columns for this line.
            for col in split_tokens(line).get(symbol_name, []):
                found.append(Position(line=lineno, character=col))
    return found
def find_symbol_nodes(
    symbols: "List[SymbolNode]", name: Optional[str] = None, line: Optional[int] = None
) -> "List[Tuple[SymbolNode, int]]":
    """
    Find the symbols with the specified name and/or start line number.

    A matched symbol's children are not searched further; children of
    non-matching symbols are searched with depth incremented.

    Args:
        symbols: Top-level symbol nodes (depth 0).
        name: Symbol name to match, or None to not filter by name.
        line: 0-based start line to match, or None to not filter by line.

    Returns:
        A list of tuples (symbol, depth).
    """
    assert name is not None or line is not None
    res = []
    stack = [(s, 0) for s in symbols]
    while stack:
        symbol, depth = stack.pop()
        flag = True
        # BUG FIX: compare against None instead of relying on truthiness.
        # With the old `if name and ...` / `if line and ...`, line == 0
        # (a symbol on the first line of a file) or name == "" silently
        # disabled the filter and unrelated symbols were returned.
        if name is not None and symbol.name != name:
            flag = False
        if line is not None and symbol.range.start.line != line:
            flag = False
        if flag:
            res.append((symbol, depth))
        else:
            stack.extend((c, depth + 1) for c in reversed(symbol.children))
    return res
def get_symbol_content(
    symbol: SymbolNode,
    file_content: Optional[str] = None,
    abspath: Optional[str] = None,
) -> str:
    """
    Get the content of the symbol in the file.

    Exactly one of ``file_content`` or ``abspath`` must be usable: the file
    is read from disk only when ``file_content`` is not given.

    Raises:
        ValueError: If neither file_content nor abspath is provided.
    """
    if file_content is None and abspath is None:
        raise ValueError("Either file_content or abspath should be provided")
    if file_content is None:
        file_content = retrieve_file_content(abspath, None)
    lines = file_content.split("\n")
    start, end = symbol.range.start, symbol.range.end
    # Full lines up to (but excluding) the last one, then the last line
    # truncated at the end column.
    # NOTE(review): start.character is ignored — the first line is taken
    # from column 0; presumably fine for whole-line symbol ranges.
    selected = lines[start.line : end.line]
    selected.append(lines[end.line][: end.character])
    return "\n".join(selected)

View File

@ -1,6 +1,7 @@
from functools import partial
from typing import List, Optional
from find_context import Context
from model import FuncToTest, TokenBudgetExceededException
from openai_util import create_chat_completion_chunks
from prompts import WRITE_TESTS_PROMPT
@ -18,6 +19,8 @@ def _mk_write_tests_msg(
test_cases: List[str],
chat_language: str,
reference_files: Optional[List[str]] = None,
# context_files: Optional[List[str]] = None,
symbol_contexts: Optional[List[Context]] = None,
) -> Optional[str]:
encoding = get_encoding(ENCODING)
@ -39,6 +42,19 @@ def _mk_write_tests_msg(
if func_to_test.container_content is not None:
class_content = f"\nclass code\n```\n{func_to_test.container_content}\n```\n"
context_content = ""
if symbol_contexts:
context_content += "\n\nrelevant context\n\n"
context_content += "\n\n".join([str(c) for c in symbol_contexts])
context_content += "\n\n"
# if context_files:
# context_content += "\n\nrelevant context files\n\n"
# for i, fp in enumerate(context_files, 1):
# context_file_content = retrieve_file_content(fp, root_path)
# context_content += f"{i}. {fp}\n\n"
# context_content += f"```{context_file_content}```\n\n"
# Prepare a list of user messages to fit the token budget
# by adjusting the relevant content and reference content
content_fmt = partial(
@ -48,6 +64,13 @@ def _mk_write_tests_msg(
test_cases_str=test_cases_str,
chat_language=chat_language,
)
# NOTE: adjust symbol_context content more flexibly if needed
msg_0 = content_fmt(
relevant_content="\n".join([func_content, class_content, context_content]),
reference_content=reference_content,
)
# 1. func content & class content & reference file content
msg_1 = content_fmt(
relevant_content="\n".join([func_content, class_content]),
@ -64,7 +87,7 @@ def _mk_write_tests_msg(
reference_content="",
)
prioritized_msgs = [msg_1, msg_2, msg_3]
prioritized_msgs = [msg_0, msg_1, msg_2, msg_3]
for msg in prioritized_msgs:
tokens = len(encoding.encode(msg, disallowed_special=()))
@ -83,6 +106,7 @@ def write_and_print_tests(
func_to_test: FuncToTest,
test_cases: List[str],
reference_files: Optional[List[str]] = None,
symbol_contexts: Optional[List[Context]] = None,
chat_language: str = "English",
) -> None:
user_msg = _mk_write_tests_msg(
@ -90,6 +114,7 @@ def write_and_print_tests(
func_to_test=func_to_test,
test_cases=test_cases,
reference_files=reference_files,
symbol_contexts=symbol_contexts,
chat_language=chat_language,
)