Include file path info in Context
This commit is contained in:
parent
f13ae6530b
commit
1d9cc2f85a
@ -25,9 +25,9 @@ Here is the source code of the function:
|
||||
|
||||
And here are some context information that might help you write the test cases:
|
||||
|
||||
```
|
||||
|
||||
{context_content}
|
||||
```
|
||||
|
||||
|
||||
Do you think the context information is enough?
|
||||
If the information is insufficient, recommend which symbols or types you need to know more about.
|
||||
@ -46,10 +46,10 @@ JSON Format Example:
|
||||
|
||||
|
||||
def get_recommended_symbols(
|
||||
func_to_test: FuncToTest, known_context: Optional[List[str]] = None
|
||||
func_to_test: FuncToTest, known_context: Optional[List] = None
|
||||
) -> List[str]:
|
||||
known_context = known_context or []
|
||||
context_content = "\n\n".join(known_context)
|
||||
context_content = "\n\n".join([str(c) for c in known_context])
|
||||
|
||||
msg = recommend_symbol_context_prompt.format(
|
||||
function_content=func_to_test.func_content,
|
||||
|
@ -1,6 +1,7 @@
|
||||
import os
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
from assistants.recommend_test_context import get_recommended_symbols
|
||||
@ -17,9 +18,21 @@ sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||
from libs.ide_services import IDEService, Location, SymbolNode
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
file_path: str # relative path to repo root
|
||||
content: str
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.file_path, self.content))
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"file path:`{self.file_path}`\n```\n{self.content}\n```"
|
||||
|
||||
|
||||
def _extract_referenced_symbols_context(
|
||||
func_to_test: FuncToTest, symbols: List[SymbolNode], depth: int = 0
|
||||
) -> Dict[str, List[str]]:
|
||||
) -> Dict[str, List[Context]]:
|
||||
"""
|
||||
Extract context of the document symbols referenced in the function.
|
||||
Exclude the function itself and symbols whose depth is greater than the specified depth.
|
||||
@ -47,15 +60,18 @@ def _extract_referenced_symbols_context(
|
||||
# Get the content of the symbols
|
||||
for s in referenced_symbols:
|
||||
content = get_symbol_content(s, file_content=func_to_test.file_content)
|
||||
referenced_symbols_context[s.name].append(content)
|
||||
context = Context(file_path=func_to_test.file_path, content=content)
|
||||
referenced_symbols_context[s.name].append(context)
|
||||
return referenced_symbols_context
|
||||
|
||||
|
||||
def _find_children_symbols_type_def_context(func_to_test: FuncToTest, func_symbol: SymbolNode):
|
||||
def _find_children_symbols_type_def_context(
|
||||
func_to_test: FuncToTest, func_symbol: SymbolNode
|
||||
) -> Dict[str, List[Context]]:
|
||||
"""
|
||||
Find the type definitions of the symbols in the function.
|
||||
"""
|
||||
type_defs: Dict[str, List[str]] = defaultdict(list)
|
||||
type_defs: Dict[str, List[Context]] = defaultdict(list)
|
||||
|
||||
client = IDEService()
|
||||
abs_path = os.path.join(func_to_test.repo_root, func_to_test.file_path)
|
||||
@ -89,22 +105,24 @@ def _find_children_symbols_type_def_context(func_to_test: FuncToTest, func_symbo
|
||||
targets = find_symbol_nodes(symbols, line=loc.range.start.line)
|
||||
for t, _ in targets:
|
||||
content = get_symbol_content(t, abspath=loc.abspath)
|
||||
type_defs[symbol_name].append(content)
|
||||
relpath = os.path.relpath(loc.abspath, func_to_test.repo_root)
|
||||
context = Context(file_path=relpath, content=content)
|
||||
type_defs[symbol_name].append(context)
|
||||
|
||||
return type_defs
|
||||
|
||||
|
||||
def _extract_recommended_symbols_context(
|
||||
func_to_test: FuncToTest, symbol_names: List[str]
|
||||
) -> Dict[str, List[str]]:
|
||||
) -> Dict[str, List[Context]]:
|
||||
"""
|
||||
Extract context of the given symbol names.
|
||||
"""
|
||||
abs_path = os.path.join(func_to_test.repo_root, func_to_test.file_path)
|
||||
client = IDEService()
|
||||
|
||||
# symbol name -> a list of context content (source code)
|
||||
recommended_symbols: Dict[str, List[str]] = defaultdict(list)
|
||||
# symbol name -> a list of context
|
||||
recommended_symbols: Dict[str, List[Context]] = defaultdict(list)
|
||||
|
||||
# Will try to find both Definition and Type Definition for a symbol
|
||||
symbol_def_locations: Dict[str, Set[Location]] = {}
|
||||
@ -149,22 +167,24 @@ def _extract_recommended_symbols_context(
|
||||
|
||||
for t, _ in targets:
|
||||
content = get_symbol_content(t, abspath=loc.abspath)
|
||||
recommended_symbols[symbol_name].append(content)
|
||||
relpath = os.path.relpath(loc.abspath, func_to_test.repo_root)
|
||||
context = Context(file_path=relpath, content=content)
|
||||
recommended_symbols[symbol_name].append(context)
|
||||
|
||||
return recommended_symbols
|
||||
|
||||
|
||||
def find_symbol_context_by_static_analysis(
|
||||
func_to_test: FuncToTest, chat_language: str
|
||||
) -> Dict[str, List[str]]:
|
||||
) -> Dict[str, List[Context]]:
|
||||
"""
|
||||
Find the context of symbols in the function to test by static analysis.
|
||||
"""
|
||||
abs_path = os.path.join(func_to_test.repo_root, func_to_test.file_path)
|
||||
client = IDEService()
|
||||
|
||||
# symbol name -> a list of context content (code)
|
||||
symbol_context: Dict[str, List[str]] = defaultdict(list)
|
||||
# symbol name -> a list of context
|
||||
symbol_context: Dict[str, List[Context]] = defaultdict(list)
|
||||
|
||||
# Get all symbols in the file
|
||||
doc_symbols = client.get_document_symbols(abs_path)
|
||||
@ -189,8 +209,8 @@ def find_symbol_context_by_static_analysis(
|
||||
|
||||
|
||||
def find_symbol_context_of_llm_recommendation(
|
||||
func_to_test: FuncToTest, known_context: Optional[List[str]] = None
|
||||
) -> List[str]:
|
||||
func_to_test: FuncToTest, known_context: Optional[List[Context]] = None
|
||||
) -> Dict[str, List[Context]]:
|
||||
"""
|
||||
Find the context of symbols recommended by LLM.
|
||||
"""
|
||||
|
@ -9,6 +9,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs"))
|
||||
|
||||
from chatmark import Checkbox, Form, Step, TextEditor # noqa: E402
|
||||
from find_context import (
|
||||
Context,
|
||||
find_symbol_context_by_static_analysis,
|
||||
find_symbol_context_of_llm_recommendation,
|
||||
)
|
||||
@ -43,21 +44,21 @@ class UnitTestsWorkflow:
|
||||
Run the workflow to generate unit tests.
|
||||
"""
|
||||
symbol_context = self.step1_find_symbol_context()
|
||||
context = set()
|
||||
contexts = set()
|
||||
for _, v in symbol_context.items():
|
||||
context.update(v)
|
||||
contexts.update(v)
|
||||
|
||||
cases, files = self.step2_propose_cases_and_reference_files(list(context))
|
||||
cases, files = self.step2_propose_cases_and_reference_files(list(contexts))
|
||||
|
||||
res = self.step3_edit_cases_and_reference_files(cases, files)
|
||||
cases = res[0]
|
||||
files = res[1]
|
||||
|
||||
self.step4_write_and_print_tests(cases, files, list(context))
|
||||
self.step4_write_and_print_tests(cases, files, list(contexts))
|
||||
|
||||
def step2_propose_cases_and_reference_files(
|
||||
self,
|
||||
context: List[str],
|
||||
contexts: List[Context],
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Propose test cases and reference files for a specified function.
|
||||
@ -75,7 +76,7 @@ class UnitTestsWorkflow:
|
||||
test_cases = propose_test(
|
||||
user_prompt=self.user_prompt,
|
||||
func_to_test=self.func_to_test,
|
||||
context=context,
|
||||
contexts=contexts,
|
||||
chat_language=self.tui_lang.chat_language,
|
||||
)
|
||||
|
||||
@ -178,14 +179,25 @@ class UnitTestsWorkflow:
|
||||
|
||||
return cases, valid_files
|
||||
|
||||
def step1_find_symbol_context(self) -> Dict[str, List[str]]:
|
||||
def step1_find_symbol_context(self) -> Dict[str, List[Context]]:
|
||||
symbol_context = find_symbol_context_by_static_analysis(
|
||||
self.func_to_test, self.tui_lang.chat_language
|
||||
)
|
||||
|
||||
known_context_for_llm = []
|
||||
# with Step("Symbol context"):
|
||||
# for k, v in symbol_context.items():
|
||||
# print(f"\n- {k}: ")
|
||||
# for item in v:
|
||||
# print(f"{item.file_path}\n{item.content}")
|
||||
|
||||
known_context_for_llm: List[Context] = []
|
||||
if self.func_to_test.container_content is not None:
|
||||
known_context_for_llm.append(self.func_to_test.container_content)
|
||||
known_context_for_llm.append(
|
||||
Context(
|
||||
file_path=self.func_to_test.file_path,
|
||||
content=self.func_to_test.container_content,
|
||||
)
|
||||
)
|
||||
known_context_for_llm += list(
|
||||
{item for sublist in list(symbol_context.values()) for item in sublist}
|
||||
)
|
||||
@ -194,6 +206,12 @@ class UnitTestsWorkflow:
|
||||
self.func_to_test, known_context_for_llm
|
||||
)
|
||||
|
||||
# with Step("Recommended context"):
|
||||
# for k, v in recommended_context.items():
|
||||
# print(f"\n- {k}: ")
|
||||
# for item in v:
|
||||
# print(f"{item.file_path}\n{item.content}")
|
||||
|
||||
symbol_context.update(recommended_context)
|
||||
|
||||
return symbol_context
|
||||
@ -202,7 +220,7 @@ class UnitTestsWorkflow:
|
||||
self,
|
||||
cases: List[str],
|
||||
ref_files: List[str],
|
||||
symbol_context: List[str],
|
||||
symbol_contexts: List[Context],
|
||||
):
|
||||
"""
|
||||
Write and print tests.
|
||||
@ -213,7 +231,7 @@ class UnitTestsWorkflow:
|
||||
func_to_test=self.func_to_test,
|
||||
test_cases=cases,
|
||||
reference_files=ref_files,
|
||||
symbol_context=symbol_context,
|
||||
symbol_contexts=symbol_contexts,
|
||||
chat_language=self.tui_lang.chat_language,
|
||||
)
|
||||
|
||||
|
@ -2,6 +2,7 @@ import json
|
||||
from functools import partial
|
||||
from typing import List, Optional
|
||||
|
||||
from find_context import Context
|
||||
from model import FuncToTest, TokenBudgetExceededException
|
||||
from openai_util import create_chat_completion_content
|
||||
from prompts import PROPOSE_TEST_PROMPT
|
||||
@ -16,7 +17,7 @@ TOKEN_BUDGET = int(16000 * 0.9)
|
||||
def _mk_user_msg(
|
||||
user_prompt: str,
|
||||
func_to_test: FuncToTest,
|
||||
context: List[str],
|
||||
contexts: List[Context],
|
||||
chat_language: str,
|
||||
) -> str:
|
||||
"""
|
||||
@ -30,10 +31,10 @@ def _mk_user_msg(
|
||||
class_content = f"class code\n```\n{func_to_test.container_content}\n```\n"
|
||||
|
||||
context_content = ""
|
||||
if context:
|
||||
context_content = "\n\nrelevant context\n```\n"
|
||||
context_content += "\n\n".join(context)
|
||||
context_content += "\n```\n"
|
||||
if contexts:
|
||||
context_content = "\n\nrelevant context\n\n"
|
||||
context_content += "\n\n".join([str(c) for c in contexts])
|
||||
context_content += "\n\n"
|
||||
|
||||
# Prepare a list of user messages to fit the token budget
|
||||
# by adjusting the relevant content
|
||||
@ -74,7 +75,7 @@ def _mk_user_msg(
|
||||
def propose_test(
|
||||
user_prompt: str,
|
||||
func_to_test: FuncToTest,
|
||||
context: Optional[List[str]] = None,
|
||||
contexts: Optional[List[Context]] = None,
|
||||
chat_language: str = "English",
|
||||
) -> List[str]:
|
||||
"""Propose test cases for a specified function based on a user prompt
|
||||
@ -88,11 +89,11 @@ def propose_test(
|
||||
Returns:
|
||||
List[str]: A list of test case descriptions.
|
||||
"""
|
||||
context = context or []
|
||||
contexts = contexts or []
|
||||
user_msg = _mk_user_msg(
|
||||
user_prompt=user_prompt,
|
||||
func_to_test=func_to_test,
|
||||
context=context,
|
||||
contexts=contexts,
|
||||
chat_language=chat_language,
|
||||
)
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
from functools import partial
|
||||
from typing import List, Optional
|
||||
|
||||
from find_context import Context
|
||||
from model import FuncToTest, TokenBudgetExceededException
|
||||
from openai_util import create_chat_completion_chunks
|
||||
from prompts import WRITE_TESTS_PROMPT
|
||||
@ -19,7 +20,7 @@ def _mk_write_tests_msg(
|
||||
chat_language: str,
|
||||
reference_files: Optional[List[str]] = None,
|
||||
# context_files: Optional[List[str]] = None,
|
||||
symbol_context: Optional[List[str]] = None,
|
||||
symbol_contexts: Optional[List[Context]] = None,
|
||||
) -> Optional[str]:
|
||||
encoding = get_encoding(ENCODING)
|
||||
|
||||
@ -42,10 +43,10 @@ def _mk_write_tests_msg(
|
||||
class_content = f"\nclass code\n```\n{func_to_test.container_content}\n```\n"
|
||||
|
||||
context_content = ""
|
||||
if symbol_context:
|
||||
context_content += "\n\nrelevant context\n```\n"
|
||||
context_content += "\n\n".join(symbol_context)
|
||||
context_content += "\n```\n"
|
||||
if symbol_contexts:
|
||||
context_content += "\n\nrelevant context\n\n"
|
||||
context_content += "\n\n".join([str(c) for c in symbol_contexts])
|
||||
context_content += "\n\n"
|
||||
|
||||
# if context_files:
|
||||
# context_content += "\n\nrelevant context files\n\n"
|
||||
@ -105,7 +106,7 @@ def write_and_print_tests(
|
||||
func_to_test: FuncToTest,
|
||||
test_cases: List[str],
|
||||
reference_files: Optional[List[str]] = None,
|
||||
symbol_context: Optional[List[str]] = None,
|
||||
symbol_contexts: Optional[List[Context]] = None,
|
||||
chat_language: str = "English",
|
||||
) -> None:
|
||||
user_msg = _mk_write_tests_msg(
|
||||
@ -113,7 +114,7 @@ def write_and_print_tests(
|
||||
func_to_test=func_to_test,
|
||||
test_cases=test_cases,
|
||||
reference_files=reference_files,
|
||||
symbol_context=symbol_context,
|
||||
symbol_contexts=symbol_contexts,
|
||||
chat_language=chat_language,
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user