Merge pull request #120 from devchat-ai/ref-file-for-test

Improve finding reference files for writing tests
This commit is contained in:
boob.yang 2024-05-24 02:07:56 +00:00 committed by GitHub
commit 619ac6532a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 181 additions and 10 deletions

View File

@ -1,15 +1,112 @@
import json
from pathlib import Path
from typing import List
from assistants.directory_structure.relevant_file_finder import RelevantFileFinder
from prompts import FIND_REFERENCE_PROMPT
from tools.file_util import verify_file_list
from devchat.llm.openai import chat_completion_no_stream_return_json
from llm_conf import (
CONTEXT_SIZE,
DEFAULT_CONTEXT_SIZE,
DEFAULT_ENCODING,
USE_USER_MODEL,
USER_LLM_MODEL,
)
from openai_util import create_chat_completion_content
from tools.file_util import (
is_not_hidden,
is_source_code,
is_test_file,
verify_file_list,
)
from tools.git_util import git_file_of_interest_filter
from tools.tiktoken_util import get_encoding
from tools.time_util import print_exec_time
# Model used for the reference-test search: the user-configured model when one
# is set, otherwise a fixed OpenAI default.
MODEL = USER_LLM_MODEL if USE_USER_MODEL else "gpt-4-turbo-preview"  # "gpt-3.5-turbo"

# Tokenizer used to estimate message sizes against the budget below.
ENCODING = (
    get_encoding(DEFAULT_ENCODING)  # Use default encoding as an approximation
    if USE_USER_MODEL
    else get_encoding("cl100k_base")
)

# Keep 5% of the context window as headroom for the model's reply.
TOKEN_BUDGET = int(CONTEXT_SIZE.get(MODEL, DEFAULT_CONTEXT_SIZE) * 0.95)
# Prompt template asking the model to pick the top 3 reference test files.
# Placeholders: function_name, file_path, test_files_str.  The doubled braces
# around the JSON example survive str.format() as literal braces.
# NOTE: the stale pre-change signature line that the diff scrape left above
# this constant has been dropped; the current entry point is defined below.
FIND_REF_TEST_PROMPT = """
As an advanced AI coding assistant,
you're given the task to identify suitable reference test files that can be used as a guide
for writing test cases for a specific function in the codebase.
You're provided with a list of test files in the repository.
Infer the purpose of each test file and identify the top 3 key files
that may be relevant to the target function and can serve as a reference for writing test cases.
The reference could provide a clear example of best practices
in testing functions of a similar nature.
The target function is {function_name}, located in the file {file_path}.
The list of test files in the repository is as follows:
{test_files_str}
Answer in JSON format with a list of the top 3 key file paths under the key `files`.
Make sure each file path is from the list of test files provided above.
Example:
{{
"files": ["<file path 1>", "<file path 2>", "<file path 3>"]
}}
"""
def get_test_files(repo_root: str) -> List[str]:
    """
    Return the repo-relative paths of all test files in the repository.
    """
    root = Path(repo_root)
    is_git_interest = git_file_of_interest_filter(repo_root)

    def _is_candidate(path: Path, rel: Path) -> bool:
        # Keep only regular, non-hidden, git-tracked source files whose
        # path matches a known test-file naming convention.
        return (
            path.is_file()
            and is_not_hidden(rel)
            and is_git_interest(rel)
            and is_source_code(str(path), only_code=True)
            and is_test_file(str(rel))
        )

    return [
        str(path.relative_to(root))
        for path in root.rglob("*")
        if _is_candidate(path, path.relative_to(root))
    ]
def _mk_user_msg(function_name: str, file_path: str, test_files: List[str]) -> str:
    """
    Create a user message to be sent to the model within the token budget.
    """
    listing = "\n".join(f"- {name}" for name in test_files)
    # TODO: check if the message fits within the token budget
    # and adjust the content accordingly
    return FIND_REF_TEST_PROMPT.format(
        function_name=function_name,
        file_path=file_path,
        test_files_str=listing,
    )
@print_exec_time("Model response time")
def find_reference_tests(repo_root: str, function_name: str, file_path: str) -> List[str]:
    """Find reference tests for a specified function.

    Args:
        repo_root (str): The path to the root directory of the codebase.
        function_name (str): The name of the function to generate test cases for.
        file_path (str): The path to the file containing the target function
            for which test cases will be generated.

    Returns:
        List[str]: A list of paths to files that may contain a reference test
            for the specified function.
    """
    # NOTE(review): removed stale pre-change lines the diff scrape left here
    # (`finder = RelevantFileFinder(...)` / `objective = FIND_REFERENCE_PROMPT...`
    # / `finder.analyze(...)`) — they referenced names no longer defined.
    test_files = get_test_files(repo_root)
    user_msg = _mk_user_msg(
        function_name=function_name,
        file_path=file_path,
        test_files=test_files,
    )

    json_res = {}
    if USE_USER_MODEL:
        # Use the wrapped api parameters; the call may return None, so
        # fall back to an empty dict to keep the .get() below safe.
        json_res = (
            chat_completion_no_stream_return_json(
                messages=[{"role": "user", "content": user_msg}],
                llm_config={
                    "model": MODEL,
                    "temperature": 0.1,
                },
            )
            or {}
        )
    else:
        # Use the openai api parameters; JSON mode guarantees parseable content.
        content = create_chat_completion_content(
            model=MODEL,
            messages=[{"role": "user", "content": user_msg}],
            response_format={"type": "json_object"},
            temperature=0.1,
        )
        json_res = json.loads(content)

    files = json_res.get("files", [])
    # Drop any model-suggested paths that do not actually exist in the repo.
    ref_files = verify_file_list(files, repo_root)
    return ref_files

View File

@ -76,6 +76,7 @@ Basic requirements
5. If two or more test cases share the same setup, you should reuse the setup code.
6. If two or more test cases share the same test logic, you should reuse the test logic.
7. Use TODO comments or FIXME comments to indicate any missing parts of the code or any parts that need to be improved.
8. Write as much production-ready code as possible, leaving fewer TODO or FIXME comments.
{additional_requirements}

View File

@ -1,6 +1,7 @@
import os
import re
from pathlib import Path
from typing import List
from typing import Dict, List
def retrieve_file_content(file_path: str, root_path: str) -> str:
@ -154,3 +155,47 @@ def is_source_code(file_name: str, only_code=False) -> bool:
_, extension = os.path.splitext(file_name)
return extension in source_code_extensions
# Shared convention for C-family languages: any path under a Tests/ directory.
DEFAULT_TEST_REGS = [r"^(.+/)*[Tt]ests?/"]  # C, C++, OBJC

# Per-language path patterns that identify test files.  Patterns are matched
# against repo-relative paths with re.match (anchored at the start).
TEST_PATH_PATTERNS: Dict[str, List[str]] = {
    "C": DEFAULT_TEST_REGS,
    "C++": DEFAULT_TEST_REGS,
    "Objective-C": DEFAULT_TEST_REGS,
    # Gradle https://docs.gradle.org/current/userguide/java_testing.html#sec:test_detection
    "Java": [r"^(.+/)*src/test/.*Tests?.java$"],
    # jest
    "JavaScript": [r"(.+/)*(__[Tt]ests__/.*|((.*\.)?(test|spec)))\.[jt]sx?$"],
    # pytest https://docs.pytest.org/en/stable/goodpractices.html#conventions-for-python-test-discovery
    "Python": [r"(.*_test|.*/?test_[^/]*)\.py$"],
    "Ruby": [r"^(.+/)*(spec/.*_spec.rb|test/.*_test.rb)$"],
    # golang, from `go help test`
    "Go": [r"^(.+/)*[^_\.][^/]*_test.go$"],
    "PHP": [r"^(.+/)*[Tt]ests?/(.+/)*([^/]*[Tt]ests?\.php|[Ff]ixtures?/(.+/)*.+\.php)"],
    "Kotlin": [r"^(.+/)*src/test/.*Tests?.kt$"],
    "C#": [r"^(.+/)[^/]+[Tt]ests?.cs$"],
    "Swift": [r"^(.+/)*[^/]*Tests?.swift"],
    "Scala": [r"^(.+/)*src/test/.*(scala|sc)"],
    "Dart": [r"^(.+/)*[Tt]ests?/(.+/)*[^/]*[Tt]ests?.dart"],
    "Lua": [r"^(.+/)*(specs?/.*_spec|tests?/(.*_test|test_[^/]*))\.lua$"],
}

# Compile once at import time; is_test_file is called per file in a repo walk.
LANG_TEST_REGS: Dict[str, List] = {
    k: [re.compile(r) for r in v] for k, v in TEST_PATH_PATTERNS.items()
}


def is_test_file(file_path: str) -> bool:
    """
    Check if a given file is a test file based on its path.

    Args:
        file_path (str): The path to the file to check.

    Returns:
        bool: True if the file is a test file, False otherwise.
    """
    # The language key is irrelevant here — a path matching any language's
    # test convention counts.  (Was `for _, regs in ...items()`.)
    return any(
        reg.match(file_path)
        for regs in LANG_TEST_REGS.values()
        for reg in regs
    )