Merge pull request #120 from devchat-ai/ref-file-for-test
Improve finding reference file for writing test
This commit is contained in:
commit
619ac6532a
@ -1,15 +1,112 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from assistants.directory_structure.relevant_file_finder import RelevantFileFinder
|
||||
from prompts import FIND_REFERENCE_PROMPT
|
||||
from tools.file_util import verify_file_list
|
||||
from devchat.llm.openai import chat_completion_no_stream_return_json
|
||||
from llm_conf import (
|
||||
CONTEXT_SIZE,
|
||||
DEFAULT_CONTEXT_SIZE,
|
||||
DEFAULT_ENCODING,
|
||||
USE_USER_MODEL,
|
||||
USER_LLM_MODEL,
|
||||
)
|
||||
from openai_util import create_chat_completion_content
|
||||
from tools.file_util import (
|
||||
is_not_hidden,
|
||||
is_source_code,
|
||||
is_test_file,
|
||||
verify_file_list,
|
||||
)
|
||||
from tools.git_util import git_file_of_interest_filter
|
||||
from tools.tiktoken_util import get_encoding
|
||||
from tools.time_util import print_exec_time
|
||||
|
||||
MODEL = USER_LLM_MODEL if USE_USER_MODEL else "gpt-4-turbo-preview" # "gpt-3.5-turbo"
|
||||
ENCODING = (
|
||||
get_encoding(DEFAULT_ENCODING) # Use default encoding as an approximation
|
||||
if USE_USER_MODEL
|
||||
else get_encoding("cl100k_base")
|
||||
)
|
||||
TOKEN_BUDGET = int(CONTEXT_SIZE.get(MODEL, DEFAULT_CONTEXT_SIZE) * 0.95)
|
||||
|
||||
|
||||
def find_reference_tests(root_path: str, function_name: str, file_path: str) -> List[str]:
|
||||
FIND_REF_TEST_PROMPT = """
|
||||
As an advanced AI coding assistant,
|
||||
you're given the task to identify suitable reference test files that can be used as a guide
|
||||
for writing test cases for a specific function in the codebase.
|
||||
|
||||
You're provided with a list of test files in the repository.
|
||||
Infer the purpose of each test file and identify the top 3 key files
|
||||
that may be relevant to the target function and can serve as a reference for writing test cases.
|
||||
The reference could provide a clear example of best practices
|
||||
in testing functions of a similar nature.
|
||||
|
||||
The target function is {function_name}, located in the file {file_path}.
|
||||
The list of test files in the repository is as follows:
|
||||
|
||||
{test_files_str}
|
||||
|
||||
|
||||
Answer in JSON format with a list of the top 3 key file paths under the key `files`.
|
||||
Make sure each file path is from the list of test files provided above.
|
||||
|
||||
Example:
|
||||
{{
|
||||
"files": ["<file path 1>", "<file path 2>", "<file path 3>"]
|
||||
}}
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def get_test_files(repo_root: str) -> List[str]:
|
||||
"""
|
||||
Get all test files in the repository.
|
||||
"""
|
||||
root = Path(repo_root)
|
||||
is_git_interest = git_file_of_interest_filter(repo_root)
|
||||
|
||||
files = []
|
||||
for filepath in root.rglob("*"):
|
||||
relpath = filepath.relative_to(root)
|
||||
|
||||
is_candidate = (
|
||||
filepath.is_file()
|
||||
and is_not_hidden(relpath)
|
||||
and is_git_interest(relpath)
|
||||
and is_source_code(str(filepath), only_code=True)
|
||||
and is_test_file(str(relpath))
|
||||
)
|
||||
|
||||
if not is_candidate:
|
||||
continue
|
||||
|
||||
files.append(str(relpath))
|
||||
|
||||
return files
|
||||
|
||||
|
||||
def _mk_user_msg(function_name: str, file_path: str, test_files: List[str]) -> str:
|
||||
"""
|
||||
Create a user message to be sent to the model within the token budget.
|
||||
"""
|
||||
test_files_str = "\n".join([f"- {f}" for f in test_files])
|
||||
msg = FIND_REF_TEST_PROMPT.format(
|
||||
function_name=function_name,
|
||||
file_path=file_path,
|
||||
test_files_str=test_files_str,
|
||||
)
|
||||
|
||||
# TODO: check if the message fits within the token budget
|
||||
# and adjust the content accordingly
|
||||
return msg
|
||||
|
||||
|
||||
@print_exec_time("Model response time")
|
||||
def find_reference_tests(repo_root: str, function_name: str, file_path: str) -> List[str]:
|
||||
"""Find reference tests for a specified function
|
||||
|
||||
Args:
|
||||
root_path (str): The path to the root directory of the codebase.
|
||||
repo_root (str): The path to the root directory of the codebase.
|
||||
function_name (str): The name of the function to generate test cases for.
|
||||
file_path (str): The path to the file containing the target function
|
||||
for which test cases will be generated.
|
||||
@ -18,11 +115,39 @@ def find_reference_tests(root_path: str, function_name: str, file_path: str) ->
|
||||
List[str]: A list of paths to files that may contain a reference test
|
||||
for the specified function.
|
||||
"""
|
||||
finder = RelevantFileFinder(root_path=root_path)
|
||||
objective = FIND_REFERENCE_PROMPT.format(
|
||||
test_files = get_test_files(repo_root)
|
||||
|
||||
user_msg = _mk_user_msg(
|
||||
function_name=function_name,
|
||||
file_path=file_path,
|
||||
test_files=test_files,
|
||||
)
|
||||
|
||||
test_paths = finder.analyze(objective)
|
||||
return verify_file_list(test_paths, root_path)
|
||||
json_res = {}
|
||||
if USE_USER_MODEL:
|
||||
# Use the wrapped api parameters
|
||||
json_res = (
|
||||
chat_completion_no_stream_return_json(
|
||||
messages=[{"role": "user", "content": user_msg}],
|
||||
llm_config={
|
||||
"model": MODEL,
|
||||
"temperature": 0.1,
|
||||
},
|
||||
)
|
||||
or {}
|
||||
)
|
||||
|
||||
else:
|
||||
# Use the openai api parameters
|
||||
content = create_chat_completion_content(
|
||||
model=MODEL,
|
||||
messages=[{"role": "user", "content": user_msg}],
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0.1,
|
||||
)
|
||||
json_res = json.loads(content)
|
||||
|
||||
files = json_res.get("files", [])
|
||||
ref_files = verify_file_list(files, repo_root)
|
||||
|
||||
return ref_files
|
||||
|
@ -76,6 +76,7 @@ Basic requirements
|
||||
5. If two or more test cases share the same setup, you should reuse the setup code.
|
||||
6. If two or more test cases share the same test logic, you should reuse the test logic.
|
||||
7. Use TODO comments or FIXME comments to indicate any missing parts of the code or any parts that need to be improved.
|
||||
8. Write as much production-ready code as possible, leaving less TODO comments or FIXME comments.
|
||||
|
||||
{additional_requirements}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def retrieve_file_content(file_path: str, root_path: str) -> str:
|
||||
@ -154,3 +155,47 @@ def is_source_code(file_name: str, only_code=False) -> bool:
|
||||
_, extension = os.path.splitext(file_name)
|
||||
|
||||
return extension in source_code_extensions
|
||||
|
||||
|
||||
DEFAULT_TEST_REGS = [r"^(.+/)*[Tt]ests?/"] # C, C++, OBJC
|
||||
TEST_PATH_PATTERNS: Dict[str, List[str]] = {
|
||||
"C": DEFAULT_TEST_REGS,
|
||||
"C++": DEFAULT_TEST_REGS,
|
||||
"Objective-C": DEFAULT_TEST_REGS,
|
||||
# Gradle https://docs.gradle.org/current/userguide/java_testing.html#sec:test_detection
|
||||
"Java": [r"^(.+/)*src/test/.*Tests?.java$"],
|
||||
# jest
|
||||
"JavaScript": [r"(.+/)*(__[Tt]ests__/.*|((.*\.)?(test|spec)))\.[jt]sx?$"],
|
||||
# pytest https://docs.pytest.org/en/stable/goodpractices.html#conventions-for-python-test-discovery
|
||||
"Python": [r"(.*_test|.*/?test_[^/]*)\.py$"],
|
||||
"Ruby": [r"^(.+/)*(spec/.*_spec.rb|test/.*_test.rb)$"],
|
||||
# golang, from `go help test`
|
||||
"Go": [r"^(.+/)*[^_\.][^/]*_test.go$"],
|
||||
"PHP": [r"^(.+/)*[Tt]ests?/(.+/)*([^/]*[Tt]ests?\.php|[Ff]ixtures?/(.+/)*.+\.php)"],
|
||||
"Kotlin": [r"^(.+/)*src/test/.*Tests?.kt$"],
|
||||
"C#": [r"^(.+/)[^/]+[Tt]ests?.cs$"],
|
||||
"Swift": [r"^(.+/)*[^/]*Tests?.swift"],
|
||||
"Scala": [r"^(.+/)*src/test/.*(scala|sc)"],
|
||||
"Dart": [r"^(.+/)*[Tt]ests?/(.+/)*[^/]*[Tt]ests?.dart"],
|
||||
"Lua": [r"^(.+/)*(specs?/.*_spec|tests?/(.*_test|test_[^/]*))\.lua$"],
|
||||
}
|
||||
LANG_TEST_REGS: Dict[str, List] = {
|
||||
k: [re.compile(r) for r in v] for k, v in TEST_PATH_PATTERNS.items()
|
||||
}
|
||||
|
||||
|
||||
def is_test_file(file_path: str) -> bool:
|
||||
"""
|
||||
Check if a given file is a test file based on its path.
|
||||
|
||||
Args:
|
||||
file_path (str): The path to the file to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the file is a test file, False otherwise.
|
||||
"""
|
||||
for _, regs in LANG_TEST_REGS.items():
|
||||
for reg in regs:
|
||||
if reg.match(file_path):
|
||||
return True
|
||||
return False
|
||||
|
Loading…
x
Reference in New Issue
Block a user