Merge pull request #67 from devchat-ai/symbol-context
Add more context for writing tests by analyzing symbols
This commit is contained in:
commit
db6b2f4b4c
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,4 +1,4 @@
|
||||
.vscode
|
||||
__pycache__/
|
||||
.DS_Store
|
||||
|
||||
tmp/
|
||||
|
@ -1,4 +1,4 @@
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from contextlib import AbstractContextManager
|
||||
|
||||
|
||||
class Step(AbstractContextManager):
|
||||
@ -25,4 +25,4 @@ class Step(AbstractContextManager):
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# close the step
|
||||
print(f"\n```", flush=True)
|
||||
print("\n```", flush=True)
|
||||
|
@ -1,5 +1,6 @@
|
||||
from .service import IDEService
|
||||
from .types import *
|
||||
|
||||
__all__ = [
|
||||
__all__ = types.__all__ + [
|
||||
"IDEService",
|
||||
]
|
||||
|
@ -91,3 +91,7 @@ class IDEService:
|
||||
@rpc_method
|
||||
def find_type_def_locations(self, abspath: str, line: int, character: int) -> List[Location]:
|
||||
return [Location.parse_obj(loc) for loc in self._result]
|
||||
|
||||
@rpc_method
|
||||
def find_def_locations(self, abspath: str, line: int, character: int) -> List[Location]:
|
||||
return [Location.parse_obj(loc) for loc in self._result]
|
||||
|
@ -2,21 +2,46 @@ from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
__all__ = [
|
||||
"Position",
|
||||
"Range",
|
||||
"Location",
|
||||
"SymbolNode",
|
||||
]
|
||||
|
||||
|
||||
class Position(BaseModel):
    """A 0-based line/character position in a text document."""

    line: int  # 0-based line number
    character: int  # 0-based character (column) offset within the line

    def __repr__(self):
        """Compact form used for display and hashing, e.g. ``Ln3:Col7``."""
        return f"Ln{self.line}:Col{self.character}"

    def __hash__(self):
        # Derive the hash from the repr (which encodes both fields) so
        # positions can be stored in sets and used as dict keys.
        return hash(self.__repr__())
|
||||
|
||||
|
||||
class Range(BaseModel):
    """A text range delimited by a start and an end :class:`Position`."""

    start: Position  # inclusive start position
    end: Position  # end position of the range

    def __repr__(self):
        """Compact form, e.g. ``Ln1:Col0 - Ln4:Col10``."""
        return f"{self.start} - {self.end}"

    def __hash__(self):
        # Derive the hash from the repr (which encodes both endpoints) so
        # ranges can be stored in sets and used as dict keys.
        return hash(self.__repr__())
|
||||
|
||||
|
||||
class Location(BaseModel):
    """A :class:`Range` inside a concrete file, identified by absolute path."""

    abspath: str  # absolute path of the file containing the range
    range: Range  # the span within that file

    def __repr__(self):
        """Compact form, e.g. ``/repo/file.py::Ln1:Col0 - Ln4:Col10``."""
        return f"{self.abspath}::{self.range}"

    def __hash__(self):
        # Derive the hash from the repr (path + range) so locations can be
        # deduplicated in sets, as done by the context-finding code.
        return hash(self.__repr__())
|
||||
|
||||
|
||||
class SymbolNode(BaseModel):
|
||||
name: str
|
||||
|
@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
from tools.directory_viewer import mk_repo_file_criteria
|
||||
|
||||
|
||||
|
70
unit_tests/assistants/recommend_test_context.py
Normal file
70
unit_tests/assistants/recommend_test_context.py
Normal file
@ -0,0 +1,70 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from model import FuncToTest
|
||||
from openai_util import create_chat_completion_content
|
||||
|
||||
MODEL = "gpt-4-1106-preview"
|
||||
ENCODING = "cl100k_base"
|
||||
# TODO: handle token budget
|
||||
TOKEN_BUDGET = int(128000 * 0.9)
|
||||
|
||||
|
||||
# ruff: noqa: E501
|
||||
recommend_symbol_context_prompt = """
|
||||
You're an advanced AI test generator.
|
||||
|
||||
You're about to write test cases for the function `{function_name}` in the file `{file_path}`.
|
||||
Before you start, you need to check if you have enough context information to write the test cases.
|
||||
|
||||
Here is the source code of the function:
|
||||
|
||||
```
|
||||
{function_content}
|
||||
```
|
||||
|
||||
And here are some context information that might help you write the test cases:
|
||||
|
||||
|
||||
{context_content}
|
||||
|
||||
|
||||
Do you think the context information is enough?
|
||||
If the information is insufficient, recommend which symbols or types you need to know more about.
|
||||
|
||||
Return a JSON object with a single key "key_symbols" whose value is a list of strings.
|
||||
- If the context information is enough, return an empty list.
|
||||
- Each string is the name of a symbol or type appearing in the function that lacks context information for writing test.
|
||||
- The list should contain the most important symbols and should not exceed 10 items.
|
||||
|
||||
JSON Format Example:
|
||||
{{
|
||||
"key_symbols": ["<symbol 1>", "<symbol 2>", "<symbol 3>",...]
|
||||
}}
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def get_recommended_symbols(
    func_to_test: FuncToTest, known_context: Optional[List] = None
) -> List[str]:
    """
    Ask the LLM which symbols of the function still lack context for test writing.

    Formats the recommendation prompt with the function's source and whatever
    context is already known, requests a JSON response from the model, and
    returns the list under the ``"key_symbols"`` key (empty when the model
    judges the context sufficient or omits the key).
    """
    context_content = "\n\n".join(str(item) for item in (known_context or []))

    user_message = recommend_symbol_context_prompt.format(
        function_name=func_to_test.func_name,
        file_path=func_to_test.file_path,
        function_content=func_to_test.func_content,
        context_content=context_content,
    )

    raw_response = create_chat_completion_content(
        model=MODEL,
        messages=[{"role": "user", "content": user_message}],
        response_format={"type": "json_object"},
        temperature=0.1,
    )

    return json.loads(raw_response).get("key_symbols", [])
|
221
unit_tests/find_context.py
Normal file
221
unit_tests/find_context.py
Normal file
@ -0,0 +1,221 @@
|
||||
import os
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
from assistants.recommend_test_context import get_recommended_symbols
|
||||
from model import FuncToTest
|
||||
from tools.symbol_util import (
|
||||
find_symbol_nodes,
|
||||
get_symbol_content,
|
||||
locate_symbol_by_name,
|
||||
split_tokens,
|
||||
)
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
from libs.ide_services import IDEService, Location, SymbolNode
|
||||
|
||||
|
||||
@dataclass
class Context:
    """A snippet of repository code used as context for test generation."""

    file_path: str  # relative path to repo root
    content: str  # the source text of the snippet

    def __hash__(self) -> int:
        # Defined explicitly so instances stay hashable (sets are used to
        # deduplicate contexts) despite the dataclass-generated __eq__.
        return hash((self.file_path, self.content))

    def __str__(self) -> str:
        # Rendered directly into LLM prompts: path line + fenced code block.
        return f"file path:`{self.file_path}`\n```\n{self.content}\n```"
|
||||
|
||||
|
||||
def _extract_referenced_symbols_context(
    func_to_test: FuncToTest, symbols: List[SymbolNode], depth: int = 0
) -> Dict[str, List[Context]]:
    """
    Extract context of the document symbols referenced in the function.

    Walks the symbol tree down to (and including) the given ``depth`` and
    collects every symbol whose name appears in the function's source text
    (simple substring matching). The function itself — and everything nested
    inside it — is excluded.

    Returns a mapping of symbol name -> list of Context snippets taken from
    the function's own file.
    """
    referenced_symbols_context = defaultdict(list)
    func_content = func_to_test.func_content
    referenced_symbols = []
    stack = [(s, 0) for s in symbols]

    while stack:
        s, d = stack.pop()
        if d > depth:
            continue

        if s.name == func_to_test.func_name:
            # Skip the function itself and its children
            continue

        if s.name in func_content:
            # Use simple string matching for now
            referenced_symbols.append(s)

        # BUGFIX: push children with the child's depth (d + 1), not
        # `depth + 1` (the depth limit plus one), which made every child
        # exceed the limit immediately and skipped the whole subtree.
        stack.extend((c, d + 1) for c in reversed(s.children))

    # Get the content of the symbols
    for s in referenced_symbols:
        content = get_symbol_content(s, file_content=func_to_test.file_content)
        context = Context(file_path=func_to_test.file_path, content=content)
        referenced_symbols_context[s.name].append(context)
    return referenced_symbols_context
|
||||
|
||||
|
||||
def _find_children_symbols_type_def_context(
    func_to_test: FuncToTest, func_symbol: SymbolNode
) -> Dict[str, List[Context]]:
    """
    Collect type-definition context for the symbols nested inside the function.

    Visits every descendant of ``func_symbol``, asks the IDE for the type
    definitions at each symbol's start position, and returns the content of
    those definitions grouped by symbol name. Definitions outside the repo,
    or in the function's own file, are ignored.
    """
    ide = IDEService()
    func_abspath = os.path.join(func_to_test.repo_root, func_to_test.file_path)

    # Phase 1: gather candidate type-definition locations per symbol name.
    locations_by_name: Dict[str, Set[Location]] = defaultdict(set)
    pending = func_symbol.children[:]
    while pending:
        node = pending.pop()

        found = ide.find_type_def_locations(
            func_abspath, node.range.start.line, node.range.start.character
        )
        for location in found:
            inside_repo = location.abspath.startswith(func_to_test.repo_root)
            same_file = location.abspath == func_abspath
            if inside_repo and not same_file:
                locations_by_name[node.name].add(location)

        pending.extend(node.children)

    # Phase 2: resolve each location to the enclosing symbol's source text.
    type_defs: Dict[str, List[Context]] = defaultdict(list)
    for symbol_name, locations in locations_by_name.items():
        for location in locations:
            doc_symbols = ide.get_document_symbols(location.abspath)
            targets = find_symbol_nodes(doc_symbols, line=location.range.start.line)
            for target, _depth in targets:
                type_defs[symbol_name].append(
                    Context(
                        file_path=os.path.relpath(location.abspath, func_to_test.repo_root),
                        content=get_symbol_content(target, abspath=location.abspath),
                    )
                )

    return type_defs
|
||||
|
||||
|
||||
def _extract_recommended_symbols_context(
    func_to_test: FuncToTest, symbol_names: List[str]
) -> Dict[str, List[Context]]:
    """
    Extract context of the given symbol names.

    For each name, locates occurrences in the function's file, asks the IDE
    for both Definition and Type Definition locations at those positions, and
    returns the source content of the enclosing symbols, keyed by the
    original symbol name. Locations outside the repo are skipped.
    """
    abs_path = os.path.join(func_to_test.repo_root, func_to_test.file_path)
    client = IDEService()

    # symbol name -> a list of context
    recommended_symbols: Dict[str, List[Context]] = defaultdict(list)

    # Will try to find both Definition and Type Definition for a symbol
    symbol_def_locations: Dict[str, Set[Location]] = {}

    for symbol_name in symbol_names:
        def_locs = set()

        symbol_tokens = split_tokens(symbol_name)
        # Use the last token to locate the symbol as an approximation
        # (e.g. for a dotted name like `a.b.c`, search for `c`).
        last_token = None
        last_char = -1
        for s, chars in symbol_tokens.items():
            for c in chars:
                if c > last_char:
                    last_char = c
                    last_token = s

        # locate the symbol in the file
        # NOTE(review): if symbol_name yields no tokens, last_token stays
        # None and is passed through — presumably returns no positions;
        # confirm locate_symbol_by_name tolerates this.
        positions = locate_symbol_by_name(last_token, abs_path)

        for pos in positions:
            type_locations = client.find_type_def_locations(abs_path, pos.line, pos.character)
            def_locations = client.find_def_locations(abs_path, pos.line, pos.character)
            locations = type_locations + def_locations
            for loc in locations:
                # check if loc.abspath is in func_to_test.repo_root
                if not loc.abspath.startswith(func_to_test.repo_root):
                    # skip, not in the repo
                    continue

                def_locs.add(loc)
        symbol_def_locations[symbol_name] = def_locs

    # Get the content of the found definitions
    for symbol_name, locations in symbol_def_locations.items():
        for loc in locations:
            # NOTE: further improvement is needed to
            # get the symbol node of function with decorator in Python
            symbols = client.get_document_symbols(loc.abspath)
            # targets = find_symbol_nodes(symbols, name=symbol_name, line=loc.range.start.line)
            targets = find_symbol_nodes(symbols, line=loc.range.start.line)

            for t, _ in targets:
                content = get_symbol_content(t, abspath=loc.abspath)
                relpath = os.path.relpath(loc.abspath, func_to_test.repo_root)
                context = Context(file_path=relpath, content=content)
                recommended_symbols[symbol_name].append(context)

    return recommended_symbols
|
||||
|
||||
|
||||
def find_symbol_context_by_static_analysis(
    func_to_test: FuncToTest, chat_language: str
) -> Dict[str, List[Context]]:
    """
    Find the context of symbols in the function to test by static analysis.

    Merges two sources: document symbols referenced in the function's body,
    and type definitions of the symbols nested inside it. Returns a mapping
    of symbol name -> contexts; empty when the function's own symbol cannot
    be located in the file. ``chat_language`` is currently unused but kept
    for interface stability.
    """
    ide = IDEService()
    abs_path = os.path.join(func_to_test.repo_root, func_to_test.file_path)

    # symbol name -> a list of context
    symbol_context: Dict[str, List[Context]] = defaultdict(list)

    # Locate the symbol node of the function under test among the file's symbols.
    doc_symbols = ide.get_document_symbols(abs_path)
    matches = find_symbol_nodes(
        doc_symbols, name=func_to_test.func_name, line=func_to_test.func_start_line
    )
    if not matches:
        return symbol_context

    func_symbol, func_depth = matches[0]

    symbol_context.update(
        _extract_referenced_symbols_context(func_to_test, doc_symbols, depth=func_depth)
    )
    symbol_context.update(_find_children_symbols_type_def_context(func_to_test, func_symbol))

    return symbol_context
|
||||
|
||||
|
||||
def find_symbol_context_of_llm_recommendation(
    func_to_test: FuncToTest, known_context: Optional[List[Context]] = None
) -> Dict[str, List[Context]]:
    """
    Find the context of symbols recommended by LLM.

    Asks the LLM which symbols still lack context (given what is already
    known), then extracts and returns the context for those symbols.
    """
    symbol_names = get_recommended_symbols(func_to_test, known_context)
    return _extract_recommended_symbols_context(func_to_test, symbol_names)
|
@ -1,6 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import click
|
||||
|
||||
@ -8,6 +8,11 @@ sys.path.append(os.path.dirname(__file__))
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs"))
|
||||
|
||||
from chatmark import Checkbox, Form, Step, TextEditor # noqa: E402
|
||||
from find_context import (
|
||||
Context,
|
||||
find_symbol_context_by_static_analysis,
|
||||
find_symbol_context_of_llm_recommendation,
|
||||
)
|
||||
from find_reference_tests import find_reference_tests
|
||||
from i18n import TUILanguage, get_translation
|
||||
from ide_services import IDEService # noqa: E402
|
||||
@ -38,14 +43,22 @@ class UnitTestsWorkflow:
|
||||
"""
|
||||
Run the workflow to generate unit tests.
|
||||
"""
|
||||
cases, files = self.step1_propose_cases_and_reference_files()
|
||||
symbol_context = self.step1_find_symbol_context()
|
||||
contexts = set()
|
||||
for _, v in symbol_context.items():
|
||||
contexts.update(v)
|
||||
|
||||
cases, files = self.step2_edit_cases_and_reference_files(cases, files)
|
||||
cases, files = self.step2_propose_cases_and_reference_files(list(contexts))
|
||||
|
||||
self.step3_write_and_print_tests(cases, files)
|
||||
res = self.step3_edit_cases_and_reference_files(cases, files)
|
||||
cases = res[0]
|
||||
files = res[1]
|
||||
|
||||
def step1_propose_cases_and_reference_files(
|
||||
self.step4_write_and_print_tests(cases, files, list(contexts))
|
||||
|
||||
def step2_propose_cases_and_reference_files(
|
||||
self,
|
||||
contexts: List[Context],
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Propose test cases and reference files for a specified function.
|
||||
@ -63,6 +76,7 @@ class UnitTestsWorkflow:
|
||||
test_cases = propose_test(
|
||||
user_prompt=self.user_prompt,
|
||||
func_to_test=self.func_to_test,
|
||||
contexts=contexts,
|
||||
chat_language=self.tui_lang.chat_language,
|
||||
)
|
||||
|
||||
@ -78,7 +92,7 @@ class UnitTestsWorkflow:
|
||||
|
||||
return test_cases, reference_files
|
||||
|
||||
def step2_edit_cases_and_reference_files(
|
||||
def step3_edit_cases_and_reference_files(
|
||||
self, test_cases: List[str], reference_files: List[str]
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
@ -165,10 +179,48 @@ class UnitTestsWorkflow:
|
||||
|
||||
return cases, valid_files
|
||||
|
||||
def step3_write_and_print_tests(
|
||||
    def step1_find_symbol_context(self) -> Dict[str, List[Context]]:
        """
        Gather context for symbols related to the function under test.

        First collects context via static analysis, then hands everything
        already known (including the containing class source, when available)
        to the LLM so it can recommend further symbols, and merges the
        LLM-recommended context in.

        Returns a mapping of symbol name -> list of Context snippets.
        """
        symbol_context = find_symbol_context_by_static_analysis(
            self.func_to_test, self.tui_lang.chat_language
        )

        # Debug output, kept for troubleshooting:
        # with Step("Symbol context"):
        #     for k, v in symbol_context.items():
        #         print(f"\n- {k}: ")
        #         for item in v:
        #             print(f"{item.file_path}\n{item.content}")

        # Seed the LLM with what is already known so it only asks for more.
        known_context_for_llm: List[Context] = []
        if self.func_to_test.container_content is not None:
            known_context_for_llm.append(
                Context(
                    file_path=self.func_to_test.file_path,
                    content=self.func_to_test.container_content,
                )
            )
        # Deduplicate the statically-found contexts via a set comprehension.
        known_context_for_llm += list(
            {item for sublist in list(symbol_context.values()) for item in sublist}
        )

        recommended_context = find_symbol_context_of_llm_recommendation(
            self.func_to_test, known_context_for_llm
        )

        # Debug output, kept for troubleshooting:
        # with Step("Recommended context"):
        #     for k, v in recommended_context.items():
        #         print(f"\n- {k}: ")
        #         for item in v:
        #             print(f"{item.file_path}\n{item.content}")

        # NOTE(review): dict.update overwrites overlapping keys, so the
        # LLM-recommended context replaces (rather than merges with) the
        # static context for a same-named symbol — confirm this is intended.
        symbol_context.update(recommended_context)

        return symbol_context
|
||||
|
||||
def step4_write_and_print_tests(
|
||||
self,
|
||||
cases: List[str],
|
||||
ref_files: List[str],
|
||||
symbol_contexts: List[Context],
|
||||
):
|
||||
"""
|
||||
Write and print tests.
|
||||
@ -179,6 +231,7 @@ class UnitTestsWorkflow:
|
||||
func_to_test=self.func_to_test,
|
||||
test_cases=cases,
|
||||
reference_files=ref_files,
|
||||
symbol_contexts=symbol_contexts,
|
||||
chat_language=self.tui_lang.chat_language,
|
||||
)
|
||||
|
||||
|
@ -1,7 +1,8 @@
|
||||
import json
|
||||
from functools import partial
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from find_context import Context
|
||||
from model import FuncToTest, TokenBudgetExceededException
|
||||
from openai_util import create_chat_completion_content
|
||||
from prompts import PROPOSE_TEST_PROMPT
|
||||
@ -16,6 +17,7 @@ TOKEN_BUDGET = int(16000 * 0.9)
|
||||
def _mk_user_msg(
|
||||
user_prompt: str,
|
||||
func_to_test: FuncToTest,
|
||||
contexts: List[Context],
|
||||
chat_language: str,
|
||||
) -> str:
|
||||
"""
|
||||
@ -28,6 +30,12 @@ def _mk_user_msg(
|
||||
if func_to_test.container_content is not None:
|
||||
class_content = f"class code\n```\n{func_to_test.container_content}\n```\n"
|
||||
|
||||
context_content = ""
|
||||
if contexts:
|
||||
context_content = "\n\nrelevant context\n\n"
|
||||
context_content += "\n\n".join([str(c) for c in contexts])
|
||||
context_content += "\n\n"
|
||||
|
||||
# Prepare a list of user messages to fit the token budget
|
||||
# by adjusting the relevant content
|
||||
relevant_content_fmt = partial(
|
||||
@ -37,6 +45,10 @@ def _mk_user_msg(
|
||||
file_path=func_to_test.file_path,
|
||||
chat_language=chat_language,
|
||||
)
|
||||
# 0. func content & class content & context content
|
||||
msg_0 = relevant_content_fmt(
|
||||
relevant_content="\n".join([func_content, class_content, context_content]),
|
||||
)
|
||||
# 1. func content & class content
|
||||
msg_1 = relevant_content_fmt(
|
||||
relevant_content="\n".join([func_content, class_content]),
|
||||
@ -46,7 +58,7 @@ def _mk_user_msg(
|
||||
relevant_content=func_content,
|
||||
)
|
||||
|
||||
prioritized_msgs = [msg_1, msg_2]
|
||||
prioritized_msgs = [msg_0, msg_1, msg_2]
|
||||
|
||||
for msg in prioritized_msgs:
|
||||
token_count = len(encoding.encode(msg, disallowed_special=()))
|
||||
@ -63,6 +75,7 @@ def _mk_user_msg(
|
||||
def propose_test(
|
||||
user_prompt: str,
|
||||
func_to_test: FuncToTest,
|
||||
contexts: Optional[List[Context]] = None,
|
||||
chat_language: str = "English",
|
||||
) -> List[str]:
|
||||
"""Propose test cases for a specified function based on a user prompt
|
||||
@ -76,9 +89,11 @@ def propose_test(
|
||||
Returns:
|
||||
List[str]: A list of test case descriptions.
|
||||
"""
|
||||
contexts = contexts or []
|
||||
user_msg = _mk_user_msg(
|
||||
user_prompt=user_prompt,
|
||||
func_to_test=func_to_test,
|
||||
contexts=contexts,
|
||||
chat_language=chat_language,
|
||||
)
|
||||
|
||||
|
107
unit_tests/tools/symbol_util.py
Normal file
107
unit_tests/tools/symbol_util.py
Normal file
@ -0,0 +1,107 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from tools.file_util import retrieve_file_content
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from libs.ide_services import Position, SymbolNode
|
||||
|
||||
|
||||
def split_tokens(text: str) -> Dict[str, List[int]]:
    r"""
    Tokenize a line of text into word-like tokens.

    Returns a mapping of token -> list of 0-based character offsets where the
    token starts. Tokens are maximal ``\w+`` runs, so identifiers with
    underscores and digits stay whole.

    Not a perfect implementation, but may be enough for now.
    """
    occurrences: Dict[str, List[int]] = defaultdict(list)
    for m in re.finditer(r"\b\w+\b", text):
        occurrences[m.group()].append(m.start())
    return occurrences
|
||||
|
||||
|
||||
def locate_symbol_by_name(
    symbol_name: str, abspath: str, line_numbers: Optional[List[int]] = None
) -> List[Position]:
    """
    Find the locations of the specified symbol in the specified file.
    Line and column numbers are 0-based.

    symbol_name: The name of the symbol to find.
    abspath: The absolute path to the file to search.
    line_numbers: The line numbers to search for the symbol.
                  If None, search the entire file.
                  An empty list means "no lines" and yields no positions.

    return: a list of Position
    """
    # BUGFIX: distinguish "no filter" (None) from "empty filter" ([]).
    # The previous truthiness check made an empty list search the whole file.
    line_set = set(line_numbers) if line_numbers is not None else None

    positions: List[Position] = []
    # Explicit encoding for consistent behavior across platforms.
    with open(abspath, "r", encoding="utf-8") as file:
        for i, line in enumerate(file):
            if line_set is not None and i not in line_set:
                continue

            tokens = split_tokens(line)
            chars = tokens.get(symbol_name, [])
            for char in chars:
                positions.append(Position(line=i, character=char))

    return positions
|
||||
|
||||
|
||||
def find_symbol_nodes(
    symbols: List[SymbolNode], name: Optional[str] = None, line: Optional[int] = None
) -> List[Tuple[SymbolNode, int]]:
    """
    Find the symbols with the specified name and line number.

    At least one of ``name`` and ``line`` must be given; when both are given
    a symbol must satisfy both. Children of a matching symbol are not
    searched further.

    return: a list of tuples (symbol, depth)
    """
    # Raise instead of assert: assert is stripped under `python -O`.
    if name is None and line is None:
        raise ValueError("Either name or line must be provided")

    res = []
    stack = [(s, 0) for s in symbols]
    while stack:
        symbol, depth = stack.pop()
        # BUGFIX: use `is not None` rather than truthiness — lines are
        # 0-based, so the old `if line and ...` silently disabled the line
        # filter when searching for a symbol on line 0.
        flag = True
        if name is not None and symbol.name != name:
            flag = False
        if line is not None and symbol.range.start.line != line:
            flag = False

        if flag:
            res.append((symbol, depth))
        else:
            # Only descend into children of non-matching symbols.
            stack.extend((c, depth + 1) for c in reversed(symbol.children))

    return res
|
||||
|
||||
|
||||
def get_symbol_content(
    symbol: SymbolNode,
    file_content: Optional[str] = None,
    abspath: Optional[str] = None,
) -> str:
    """
    Get the content of the symbol in the file.

    Exactly one text source is needed: pass the file's text directly via
    ``file_content``, or pass ``abspath`` to have it read from disk.

    Raises ValueError when neither source is provided.
    """
    if file_content is None and abspath is None:
        raise ValueError("Either file_content or abspath should be provided")

    if file_content is None:
        file_content = retrieve_file_content(abspath, None)

    lines = file_content.split("\n")

    # Whole lines from the symbol's start up to (but excluding) its last line,
    # then the last line truncated at the end character.
    # NOTE(review): start.character is ignored, so any text on the first line
    # before the symbol itself is included — presumably acceptable for
    # whole-line symbols; confirm for mid-line ranges.
    content = lines[symbol.range.start.line : symbol.range.end.line]
    content.append(lines[symbol.range.end.line][: symbol.range.end.character])

    return "\n".join(content)
|
@ -1,6 +1,7 @@
|
||||
from functools import partial
|
||||
from typing import List, Optional
|
||||
|
||||
from find_context import Context
|
||||
from model import FuncToTest, TokenBudgetExceededException
|
||||
from openai_util import create_chat_completion_chunks
|
||||
from prompts import WRITE_TESTS_PROMPT
|
||||
@ -18,6 +19,8 @@ def _mk_write_tests_msg(
|
||||
test_cases: List[str],
|
||||
chat_language: str,
|
||||
reference_files: Optional[List[str]] = None,
|
||||
# context_files: Optional[List[str]] = None,
|
||||
symbol_contexts: Optional[List[Context]] = None,
|
||||
) -> Optional[str]:
|
||||
encoding = get_encoding(ENCODING)
|
||||
|
||||
@ -39,6 +42,19 @@ def _mk_write_tests_msg(
|
||||
if func_to_test.container_content is not None:
|
||||
class_content = f"\nclass code\n```\n{func_to_test.container_content}\n```\n"
|
||||
|
||||
context_content = ""
|
||||
if symbol_contexts:
|
||||
context_content += "\n\nrelevant context\n\n"
|
||||
context_content += "\n\n".join([str(c) for c in symbol_contexts])
|
||||
context_content += "\n\n"
|
||||
|
||||
# if context_files:
|
||||
# context_content += "\n\nrelevant context files\n\n"
|
||||
# for i, fp in enumerate(context_files, 1):
|
||||
# context_file_content = retrieve_file_content(fp, root_path)
|
||||
# context_content += f"{i}. {fp}\n\n"
|
||||
# context_content += f"```{context_file_content}```\n\n"
|
||||
|
||||
# Prepare a list of user messages to fit the token budget
|
||||
# by adjusting the relevant content and reference content
|
||||
content_fmt = partial(
|
||||
@ -48,6 +64,13 @@ def _mk_write_tests_msg(
|
||||
test_cases_str=test_cases_str,
|
||||
chat_language=chat_language,
|
||||
)
|
||||
|
||||
# NOTE: adjust symbol_context content more flexibly if needed
|
||||
msg_0 = content_fmt(
|
||||
relevant_content="\n".join([func_content, class_content, context_content]),
|
||||
reference_content=reference_content,
|
||||
)
|
||||
|
||||
# 1. func content & class content & reference file content
|
||||
msg_1 = content_fmt(
|
||||
relevant_content="\n".join([func_content, class_content]),
|
||||
@ -64,7 +87,7 @@ def _mk_write_tests_msg(
|
||||
reference_content="",
|
||||
)
|
||||
|
||||
prioritized_msgs = [msg_1, msg_2, msg_3]
|
||||
prioritized_msgs = [msg_0, msg_1, msg_2, msg_3]
|
||||
|
||||
for msg in prioritized_msgs:
|
||||
tokens = len(encoding.encode(msg, disallowed_special=()))
|
||||
@ -83,6 +106,7 @@ def write_and_print_tests(
|
||||
func_to_test: FuncToTest,
|
||||
test_cases: List[str],
|
||||
reference_files: Optional[List[str]] = None,
|
||||
symbol_contexts: Optional[List[Context]] = None,
|
||||
chat_language: str = "English",
|
||||
) -> None:
|
||||
user_msg = _mk_write_tests_msg(
|
||||
@ -90,6 +114,7 @@ def write_and_print_tests(
|
||||
func_to_test=func_to_test,
|
||||
test_cases=test_cases,
|
||||
reference_files=reference_files,
|
||||
symbol_contexts=symbol_contexts,
|
||||
chat_language=chat_language,
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user