Merge pull request #73 from devchat-ai/user-requirements

Allow user to add customized requirements for writing tests
This commit is contained in:
boob.yang 2024-03-12 09:52:54 +08:00 committed by GitHub
commit 079faae3a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 186 additions and 50 deletions

View File

@ -1,3 +1,5 @@
.PHONY: setup-dev check fix
div = $(shell printf '=%.0s' {1..120}) div = $(shell printf '=%.0s' {1..120})
setup-dev: setup-dev:

44
unit_tests/cache.py Normal file
View File

@ -0,0 +1,44 @@
import json
import os
class LocalCache:
"""
A file-based cache for storing and retrieving simple data in JSON format.
"""
DIR = "local_cache"
def __init__(self, name: str, base_path: str):
self._name = name
self._cache = {}
self._base_path = base_path
cache_dir = os.path.join(base_path, self.DIR)
os.makedirs(cache_dir, exist_ok=True)
self._cache_file = os.path.join(cache_dir, f"{name}.json")
self.load()
@property
def name(self):
return self._name
def load(self):
try:
with open(self._cache_file, "r") as f:
self._cache = json.load(f)
except FileNotFoundError:
pass
def save(self):
with open(self._cache_file, "w") as f:
json.dump(self._cache, f)
def get(self, key: str):
return self._cache.get(key)
def set(self, key: str, value):
if self._cache.get(key) != value:
self._cache[key] = value
self.save()

View File

@ -8,7 +8,7 @@ msgid ""
msgstr "" msgstr ""
"Project-Id-Version: PACKAGE VERSION\n" "Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n" "Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2024-01-18 15:29+0800\n" "POT-Creation-Date: 2024-03-11 22:42+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n" "Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n" "Language-Team: LANGUAGE <LL@li.org>\n"
@ -17,65 +17,77 @@ msgstr ""
"Content-Type: text/plain; charset=CHARSET\n" "Content-Type: text/plain; charset=CHARSET\n"
"Content-Transfer-Encoding: 8bit\n" "Content-Transfer-Encoding: 8bit\n"
#: main.py:53 #: main.py:79
msgid "Analyzing the function and current unit tests..." msgid "Analyzing the function and current unit tests..."
msgstr "" msgstr ""
#: main.py:84 #: main.py:116
msgid "Select test cases to generate" msgid "Select test cases to generate"
msgstr "" msgstr ""
#: main.py:89 #: main.py:121
msgid "" msgid ""
"You can add more test cases here\n" "You can add more test cases here\n"
"(Multiple cases can be separated by line breaks)" "(Multiple cases can be separated by line breaks)"
msgstr "" msgstr ""
#: main.py:95 #: main.py:127
msgid "" msgid ""
"Edit reference test file\n" "Edit reference test file\n"
"(Multiple files can be separated by line breaks)" "(Multiple files can be separated by line breaks)"
msgstr "" msgstr ""
#: main.py:112 #: main.py:134
msgid ""
"Write your customized requirements(prompts) for tests here.\n"
"(For example, what testing framework to use.)"
msgstr ""
#: main.py:153
msgid "No test case is selected. Quit generating tests." msgid "No test case is selected. Quit generating tests."
msgstr "" msgstr ""
#: main.py:129 #: main.py:176
msgid "Will generate tests for the following cases." msgid "Will generate tests for the following cases."
msgstr "" msgstr ""
#: main.py:132 #: main.py:179
msgid "" msgid ""
"\n" "\n"
"Test cases:" "Test cases:"
msgstr "" msgstr ""
#: main.py:139 #: main.py:186
msgid "" msgid ""
"\n" "\n"
"No valid reference file is provided. Will not use reference to generate " "No valid reference file is provided. Will not use reference to generate "
"tests." "tests."
msgstr "" msgstr ""
#: main.py:144 #: main.py:191
msgid "" msgid ""
"\n" "\n"
"Will use the following reference files to generate tests." "Will use the following reference files to generate tests."
msgstr "" msgstr ""
#: main.py:145 #: main.py:192
msgid "" msgid ""
"\n" "\n"
"Valid reference files:" "Valid reference files:"
msgstr "" msgstr ""
#: main.py:150 #: main.py:197
msgid "" msgid ""
"\n" "\n"
"Invalid files:" "Invalid files:"
msgstr "" msgstr ""
#: main.py:232 #: main.py:201
msgid ""
"\n"
"Customized requirements(prompts):"
msgstr ""
#: main.py:330
msgid "The function's size surpasses AI's context capacity." msgid "The function's size surpasses AI's context capacity."
msgstr "" msgstr ""

View File

@ -7,7 +7,7 @@ msgid ""
msgstr "" msgstr ""
"Project-Id-Version: PACKAGE VERSION\n" "Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n" "Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2024-01-18 15:29+0800\n" "POT-Creation-Date: 2024-03-11 22:42+0800\n"
"PO-Revision-Date: 2023-12-24 16:51+0800\n" "PO-Revision-Date: 2023-12-24 16:51+0800\n"
"Last-Translator: kagami <mingjing@merico.dev>\n" "Last-Translator: kagami <mingjing@merico.dev>\n"
"Language-Team: English\n" "Language-Team: English\n"
@ -17,66 +17,78 @@ msgstr ""
"Content-Transfer-Encoding: 8bit\n" "Content-Transfer-Encoding: 8bit\n"
"Plural-Forms: nplurals=2; plural=(n != 1);\n" "Plural-Forms: nplurals=2; plural=(n != 1);\n"
#: main.py:53 #: main.py:79
msgid "Analyzing the function and current unit tests..." msgid "Analyzing the function and current unit tests..."
msgstr "" msgstr ""
#: main.py:84 #: main.py:116
msgid "Select test cases to generate" msgid "Select test cases to generate"
msgstr "" msgstr ""
#: main.py:89 #: main.py:121
msgid "" msgid ""
"You can add more test cases here\n" "You can add more test cases here\n"
"(Multiple cases can be separated by line breaks)" "(Multiple cases can be separated by line breaks)"
msgstr "" msgstr ""
#: main.py:95 #: main.py:127
msgid "" msgid ""
"Edit reference test file\n" "Edit reference test file\n"
"(Multiple files can be separated by line breaks)" "(Multiple files can be separated by line breaks)"
msgstr "" msgstr ""
#: main.py:112 #: main.py:134
msgid ""
"Write your customized requirements(prompts) for tests here.\n"
"(For example, what testing framework to use.)"
msgstr ""
#: main.py:153
msgid "No test case is selected. Quit generating tests." msgid "No test case is selected. Quit generating tests."
msgstr "" msgstr ""
#: main.py:129 #: main.py:176
msgid "Will generate tests for the following cases." msgid "Will generate tests for the following cases."
msgstr "" msgstr ""
#: main.py:132 #: main.py:179
msgid "" msgid ""
"\n" "\n"
"Test cases:" "Test cases:"
msgstr "" msgstr ""
#: main.py:139 #: main.py:186
msgid "" msgid ""
"\n" "\n"
"No valid reference file is provided. Will not use reference to generate " "No valid reference file is provided. Will not use reference to generate "
"tests." "tests."
msgstr "" msgstr ""
#: main.py:144 #: main.py:191
msgid "" msgid ""
"\n" "\n"
"Will use the following reference files to generate tests." "Will use the following reference files to generate tests."
msgstr "" msgstr ""
#: main.py:145 #: main.py:192
msgid "" msgid ""
"\n" "\n"
"Valid reference files:" "Valid reference files:"
msgstr "" msgstr ""
#: main.py:150 #: main.py:197
msgid "" msgid ""
"\n" "\n"
"Invalid files:" "Invalid files:"
msgstr "" msgstr ""
#: main.py:232 #: main.py:201
msgid ""
"\n"
"Customized requirements(prompts):"
msgstr ""
#: main.py:330
msgid "The function's size surpasses AI's context capacity." msgid "The function's size surpasses AI's context capacity."
msgstr "" msgstr ""

View File

@ -7,7 +7,7 @@ msgid ""
msgstr "" msgstr ""
"Project-Id-Version: PACKAGE VERSION\n" "Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n" "Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2024-01-18 15:29+0800\n" "POT-Creation-Date: 2024-03-11 22:42+0800\n"
"PO-Revision-Date: 2023-12-24 16:51+0800\n" "PO-Revision-Date: 2023-12-24 16:51+0800\n"
"Last-Translator: kagami <mingjing@merico.dev>\n" "Last-Translator: kagami <mingjing@merico.dev>\n"
"Language-Team: Chinese\n" "Language-Team: Chinese\n"
@ -16,15 +16,15 @@ msgstr ""
"Content-Type: text/plain; charset=UTF-8\n" "Content-Type: text/plain; charset=UTF-8\n"
"Content-Transfer-Encoding: 8bit\n" "Content-Transfer-Encoding: 8bit\n"
#: main.py:53 #: main.py:79
msgid "Analyzing the function and current unit tests..." msgid "Analyzing the function and current unit tests..."
msgstr "正在分析目标函数和项目中现有的单元测试……" msgstr "正在分析目标函数和项目中现有的单元测试……"
#: main.py:84 #: main.py:116
msgid "Select test cases to generate" msgid "Select test cases to generate"
msgstr "选择测试用例以生成单元测试" msgstr "选择测试用例以生成单元测试"
#: main.py:89 #: main.py:121
msgid "" msgid ""
"You can add more test cases here\n" "You can add more test cases here\n"
"(Multiple cases can be separated by line breaks)" "(Multiple cases can be separated by line breaks)"
@ -32,7 +32,7 @@ msgstr ""
"可在输入框中添加更多测试用例\n" "可在输入框中添加更多测试用例\n"
"(多个测试用例用换行分隔)" "(多个测试用例用换行分隔)"
#: main.py:95 #: main.py:127
msgid "" msgid ""
"Edit reference test file\n" "Edit reference test file\n"
"(Multiple files can be separated by line breaks)" "(Multiple files can be separated by line breaks)"
@ -40,34 +40,48 @@ msgstr ""
"参考的测试文件\n" "参考的测试文件\n"
"(多个文件路径用换行分隔)" "(多个文件路径用换行分隔)"
#: main.py:112 #: main.py:134
msgid ""
"Write your customized requirements(prompts) for tests here.\n"
"(For example, what testing framework to use.)"
msgstr ""
"可在以下输入框中填写自定义的测试需求(提示词)\n"
"(例如指定特定的测试框架)"
#: main.py:153
msgid "No test case is selected. Quit generating tests." msgid "No test case is selected. Quit generating tests."
msgstr "未选择测试用例,退出生成单元测试。" msgstr "未选择测试用例,退出生成单元测试。"
#: main.py:129 #: main.py:176
msgid "Will generate tests for the following cases." msgid "Will generate tests for the following cases."
msgstr "将为以下用例生成测试。" msgstr "将为以下用例生成测试。"
#: main.py:132 #: main.py:179
msgid "" msgid ""
"\n" "\n"
"Test cases:" "Test cases:"
msgstr "\n测试用例" msgstr ""
"\n"
"测试用例:"
#: main.py:139 #: main.py:186
msgid "" msgid ""
"\n" "\n"
"No valid reference file is provided. Will not use reference to generate " "No valid reference file is provided. Will not use reference to generate "
"tests." "tests."
msgstr "\n没有提供合法的文件生成单元测试时将不使用参考。" msgstr ""
"\n"
"没有提供合法的文件,生成单元测试时将不使用参考。"
#: main.py:144 #: main.py:191
msgid "" msgid ""
"\n" "\n"
"Will use the following reference files to generate tests." "Will use the following reference files to generate tests."
msgstr "\n将参考以下文件生成单元测试。" msgstr ""
"\n"
"将参考以下文件生成单元测试。"
#: main.py:145 #: main.py:192
msgid "" msgid ""
"\n" "\n"
"Valid reference files:" "Valid reference files:"
@ -75,7 +89,7 @@ msgstr ""
"\n" "\n"
"参考文件:" "参考文件:"
#: main.py:150 #: main.py:197
msgid "" msgid ""
"\n" "\n"
"Invalid files:" "Invalid files:"
@ -83,7 +97,15 @@ msgstr ""
"\n" "\n"
"不合法的文件:" "不合法的文件:"
#: main.py:232 #: main.py:201
msgid ""
"\n"
"Customized requirements(prompts):"
msgstr ""
"\n"
"自定义测试需求(提示词):"
#: main.py:330
msgid "The function's size surpasses AI's context capacity." msgid "The function's size surpasses AI's context capacity."
msgstr "由于当前函数过大AI暂时无法处理。" msgstr "由于当前函数过大AI暂时无法处理。"

View File

@ -7,6 +7,7 @@ import click
sys.path.append(os.path.dirname(__file__)) sys.path.append(os.path.dirname(__file__))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs")) sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs"))
from cache import LocalCache
from chatmark import Checkbox, Form, Step, TextEditor # noqa: E402 from chatmark import Checkbox, Form, Step, TextEditor # noqa: E402
from find_context import ( from find_context import (
Context, Context,
@ -25,6 +26,8 @@ from propose_test import propose_test
from tools.file_util import retrieve_file_content from tools.file_util import retrieve_file_content
from write_tests import write_and_print_tests from write_tests import write_and_print_tests
CHAT_WORKFLOW_DIR_PATH = [".chat", "workflows"]
class UnitTestsWorkflow: class UnitTestsWorkflow:
def __init__( def __init__(
@ -33,11 +36,13 @@ class UnitTestsWorkflow:
func_to_test: FuncToTest, func_to_test: FuncToTest,
repo_root: str, repo_root: str,
tui_lang: TUILanguage, tui_lang: TUILanguage,
local_cache: LocalCache,
): ):
self.user_prompt = user_prompt self.user_prompt = user_prompt
self.func_to_test = func_to_test self.func_to_test = func_to_test
self.repo_root = repo_root self.repo_root = repo_root
self.tui_lang = tui_lang self.tui_lang = tui_lang
self.local_cache = local_cache
def run(self): def run(self):
""" """
@ -50,11 +55,12 @@ class UnitTestsWorkflow:
cases, files = self.step2_propose_cases_and_reference_files(list(contexts)) cases, files = self.step2_propose_cases_and_reference_files(list(contexts))
res = self.step3_edit_cases_and_reference_files(cases, files) res = self.step3_user_interaction(cases, files)
cases = res[0] cases = res[0]
files = res[1] files = res[1]
requirements = res[2]
self.step4_write_and_print_tests(cases, files, list(contexts)) self.step4_write_and_print_tests(cases, files, list(contexts), requirements)
def step2_propose_cases_and_reference_files( def step2_propose_cases_and_reference_files(
self, self,
@ -92,13 +98,16 @@ class UnitTestsWorkflow:
return test_cases, reference_files return test_cases, reference_files
def step3_edit_cases_and_reference_files( def step3_user_interaction(
self, test_cases: List[str], reference_files: List[str] self, test_cases: List[str], reference_files: List[str]
) -> Tuple[List[str], List[str]]: ) -> Tuple[List[str], List[str], str]:
""" """
Edit test cases and reference files by user. Edit test cases and reference files by user.
Return the updated cases and valid reference files. Return:
- the updated cases
- valid reference files
- customized requirements(prompts)
""" """
_i = get_translation(self.tui_lang) _i = get_translation(self.tui_lang)
@ -118,7 +127,16 @@ class UnitTestsWorkflow:
title=_i("Edit reference test file\n(Multiple files can be separated by line breaks)"), title=_i("Edit reference test file\n(Multiple files can be separated by line breaks)"),
) )
form = Form(components=[checkbox, case_editor, ref_editor]) cached_requirements = self.local_cache.get("user_requirements") or ""
requirements_editor = TextEditor(
text=cached_requirements,
title=_i(
"Write your customized requirements(prompts) for tests here."
"\n(For example, what testing framework to use.)"
),
)
form = Form(components=[checkbox, case_editor, ref_editor, requirements_editor])
form.render() form.render()
# Check test cases # Check test cases
@ -148,6 +166,12 @@ class UnitTestsWorkflow:
except Exception: except Exception:
invalid_files.append(ref_file) invalid_files.append(ref_file)
# Get customized requirements
requirements: str = (
requirements_editor.new_text.strip() if requirements_editor.new_text else ""
)
self.local_cache.set("user_requirements", requirements)
# Print summary # Print summary
title = _i("Will generate tests for the following cases.") title = _i("Will generate tests for the following cases.")
lines = [] lines = []
@ -174,10 +198,13 @@ class UnitTestsWorkflow:
width = len(str(len(invalid_files))) width = len(str(len(invalid_files)))
lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(invalid_files)]) lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(invalid_files)])
lines.append(_i("\nCustomized requirements(prompts):"))
lines.append(requirements)
with Step(title): with Step(title):
print("\n".join(lines), flush=True) print("\n".join(lines), flush=True)
return cases, valid_files return cases, valid_files, requirements
def step1_find_symbol_context(self) -> Dict[str, List[Context]]: def step1_find_symbol_context(self) -> Dict[str, List[Context]]:
symbol_context = find_symbol_context_by_static_analysis( symbol_context = find_symbol_context_by_static_analysis(
@ -221,6 +248,7 @@ class UnitTestsWorkflow:
cases: List[str], cases: List[str],
ref_files: List[str], ref_files: List[str],
symbol_contexts: List[Context], symbol_contexts: List[Context],
user_requirements: str,
): ):
""" """
Write and print tests. Write and print tests.
@ -232,6 +260,7 @@ class UnitTestsWorkflow:
test_cases=cases, test_cases=cases,
reference_files=ref_files, reference_files=ref_files,
symbol_contexts=symbol_contexts, symbol_contexts=symbol_contexts,
user_requirements=user_requirements,
chat_language=self.tui_lang.chat_language, chat_language=self.tui_lang.chat_language,
) )
@ -268,6 +297,7 @@ def main(input: str):
repo_root = os.getcwd() repo_root = os.getcwd()
ide_lang = IDEService().ide_language() ide_lang = IDEService().ide_language()
local_cache = LocalCache("unit_tests", os.path.join(repo_root, *CHAT_WORKFLOW_DIR_PATH))
tui_lang = TUILanguage.from_str(ide_lang) tui_lang = TUILanguage.from_str(ide_lang)
_i = get_translation(tui_lang) _i = get_translation(tui_lang)
@ -287,7 +317,13 @@ def main(input: str):
) )
try: try:
workflow = UnitTestsWorkflow(user_prompt, func_to_test, repo_root, tui_lang) workflow = UnitTestsWorkflow(
user_prompt,
func_to_test,
repo_root,
tui_lang,
local_cache,
)
workflow.run() workflow.run()
except TokenBudgetExceededException as e: except TokenBudgetExceededException as e:

View File

@ -55,6 +55,8 @@ Each test case should be self-contained and executable.
Use the content of the reference test cases as a model, ensuring you use the same test framework and mock library, Use the content of the reference test cases as a model, ensuring you use the same test framework and mock library,
and apply comparable mocking strategies and best practices. and apply comparable mocking strategies and best practices.
{additional_requirements}
The target function is {function_name}, located in the file {file_path}. The target function is {function_name}, located in the file {file_path}.
Here's the relevant source code of the function: Here's the relevant source code of the function:

View File

@ -21,9 +21,12 @@ def _mk_write_tests_msg(
reference_files: Optional[List[str]] = None, reference_files: Optional[List[str]] = None,
# context_files: Optional[List[str]] = None, # context_files: Optional[List[str]] = None,
symbol_contexts: Optional[List[Context]] = None, symbol_contexts: Optional[List[Context]] = None,
user_requirements: str = "",
) -> Optional[str]: ) -> Optional[str]:
encoding = get_encoding(ENCODING) encoding = get_encoding(ENCODING)
additional_requirements = user_requirements
test_cases_str = "" test_cases_str = ""
for i, test_case in enumerate(test_cases, 1): for i, test_case in enumerate(test_cases, 1):
test_cases_str += f"{i}. {test_case}\n" test_cases_str += f"{i}. {test_case}\n"
@ -63,6 +66,7 @@ def _mk_write_tests_msg(
file_path=func_to_test.file_path, file_path=func_to_test.file_path,
test_cases_str=test_cases_str, test_cases_str=test_cases_str,
chat_language=chat_language, chat_language=chat_language,
additional_requirements=additional_requirements,
) )
# NOTE: adjust symbol_context content more flexibly if needed # NOTE: adjust symbol_context content more flexibly if needed
@ -107,6 +111,7 @@ def write_and_print_tests(
test_cases: List[str], test_cases: List[str],
reference_files: Optional[List[str]] = None, reference_files: Optional[List[str]] = None,
symbol_contexts: Optional[List[Context]] = None, symbol_contexts: Optional[List[Context]] = None,
user_requirements: str = "",
chat_language: str = "English", chat_language: str = "English",
) -> None: ) -> None:
user_msg = _mk_write_tests_msg( user_msg = _mk_write_tests_msg(
@ -115,6 +120,7 @@ def write_and_print_tests(
test_cases=test_cases, test_cases=test_cases,
reference_files=reference_files, reference_files=reference_files,
symbol_contexts=symbol_contexts, symbol_contexts=symbol_contexts,
user_requirements=user_requirements,
chat_language=chat_language, chat_language=chat_language,
) )