Merge pull request #73 from devchat-ai/user-requirements
Allow user to add customized requirements for writing tests
This commit is contained in:
commit
079faae3a6
2
Makefile
2
Makefile
@ -1,3 +1,5 @@
|
|||||||
|
.PHONY: setup-dev check fix
|
||||||
|
|
||||||
div = $(shell printf '=%.0s' {1..120})
|
div = $(shell printf '=%.0s' {1..120})
|
||||||
|
|
||||||
setup-dev:
|
setup-dev:
|
||||||
|
44
unit_tests/cache.py
Normal file
44
unit_tests/cache.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class LocalCache:
|
||||||
|
"""
|
||||||
|
A file-based cache for storing and retrieving simple data in JSON format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DIR = "local_cache"
|
||||||
|
|
||||||
|
def __init__(self, name: str, base_path: str):
|
||||||
|
self._name = name
|
||||||
|
self._cache = {}
|
||||||
|
self._base_path = base_path
|
||||||
|
|
||||||
|
cache_dir = os.path.join(base_path, self.DIR)
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
|
||||||
|
self._cache_file = os.path.join(cache_dir, f"{name}.json")
|
||||||
|
self.load()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
def load(self):
|
||||||
|
try:
|
||||||
|
with open(self._cache_file, "r") as f:
|
||||||
|
self._cache = json.load(f)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
with open(self._cache_file, "w") as f:
|
||||||
|
json.dump(self._cache, f)
|
||||||
|
|
||||||
|
def get(self, key: str):
|
||||||
|
return self._cache.get(key)
|
||||||
|
|
||||||
|
def set(self, key: str, value):
|
||||||
|
if self._cache.get(key) != value:
|
||||||
|
self._cache[key] = value
|
||||||
|
self.save()
|
@ -8,7 +8,7 @@ msgid ""
|
|||||||
msgstr ""
|
msgstr ""
|
||||||
"Project-Id-Version: PACKAGE VERSION\n"
|
"Project-Id-Version: PACKAGE VERSION\n"
|
||||||
"Report-Msgid-Bugs-To: \n"
|
"Report-Msgid-Bugs-To: \n"
|
||||||
"POT-Creation-Date: 2024-01-18 15:29+0800\n"
|
"POT-Creation-Date: 2024-03-11 22:42+0800\n"
|
||||||
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
|
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
|
||||||
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||||
"Language-Team: LANGUAGE <LL@li.org>\n"
|
"Language-Team: LANGUAGE <LL@li.org>\n"
|
||||||
@ -17,65 +17,77 @@ msgstr ""
|
|||||||
"Content-Type: text/plain; charset=CHARSET\n"
|
"Content-Type: text/plain; charset=CHARSET\n"
|
||||||
"Content-Transfer-Encoding: 8bit\n"
|
"Content-Transfer-Encoding: 8bit\n"
|
||||||
|
|
||||||
#: main.py:53
|
#: main.py:79
|
||||||
msgid "Analyzing the function and current unit tests..."
|
msgid "Analyzing the function and current unit tests..."
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:84
|
#: main.py:116
|
||||||
msgid "Select test cases to generate"
|
msgid "Select test cases to generate"
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:89
|
#: main.py:121
|
||||||
msgid ""
|
msgid ""
|
||||||
"You can add more test cases here\n"
|
"You can add more test cases here\n"
|
||||||
"(Multiple cases can be separated by line breaks)"
|
"(Multiple cases can be separated by line breaks)"
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:95
|
#: main.py:127
|
||||||
msgid ""
|
msgid ""
|
||||||
"Edit reference test file\n"
|
"Edit reference test file\n"
|
||||||
"(Multiple files can be separated by line breaks)"
|
"(Multiple files can be separated by line breaks)"
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:112
|
#: main.py:134
|
||||||
|
msgid ""
|
||||||
|
"Write your customized requirements(prompts) for tests here.\n"
|
||||||
|
"(For example, what testing framework to use.)"
|
||||||
|
msgstr ""
|
||||||
|
|
||||||
|
#: main.py:153
|
||||||
msgid "No test case is selected. Quit generating tests."
|
msgid "No test case is selected. Quit generating tests."
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:129
|
#: main.py:176
|
||||||
msgid "Will generate tests for the following cases."
|
msgid "Will generate tests for the following cases."
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:132
|
#: main.py:179
|
||||||
msgid ""
|
msgid ""
|
||||||
"\n"
|
"\n"
|
||||||
"Test cases:"
|
"Test cases:"
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:139
|
#: main.py:186
|
||||||
msgid ""
|
msgid ""
|
||||||
"\n"
|
"\n"
|
||||||
"No valid reference file is provided. Will not use reference to generate "
|
"No valid reference file is provided. Will not use reference to generate "
|
||||||
"tests."
|
"tests."
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:144
|
#: main.py:191
|
||||||
msgid ""
|
msgid ""
|
||||||
"\n"
|
"\n"
|
||||||
"Will use the following reference files to generate tests."
|
"Will use the following reference files to generate tests."
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:145
|
#: main.py:192
|
||||||
msgid ""
|
msgid ""
|
||||||
"\n"
|
"\n"
|
||||||
"Valid reference files:"
|
"Valid reference files:"
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:150
|
#: main.py:197
|
||||||
msgid ""
|
msgid ""
|
||||||
"\n"
|
"\n"
|
||||||
"Invalid files:"
|
"Invalid files:"
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:232
|
#: main.py:201
|
||||||
|
msgid ""
|
||||||
|
"\n"
|
||||||
|
"Customized requirements(prompts):"
|
||||||
|
msgstr ""
|
||||||
|
|
||||||
|
#: main.py:330
|
||||||
msgid "The function's size surpasses AI's context capacity."
|
msgid "The function's size surpasses AI's context capacity."
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
@ -7,7 +7,7 @@ msgid ""
|
|||||||
msgstr ""
|
msgstr ""
|
||||||
"Project-Id-Version: PACKAGE VERSION\n"
|
"Project-Id-Version: PACKAGE VERSION\n"
|
||||||
"Report-Msgid-Bugs-To: \n"
|
"Report-Msgid-Bugs-To: \n"
|
||||||
"POT-Creation-Date: 2024-01-18 15:29+0800\n"
|
"POT-Creation-Date: 2024-03-11 22:42+0800\n"
|
||||||
"PO-Revision-Date: 2023-12-24 16:51+0800\n"
|
"PO-Revision-Date: 2023-12-24 16:51+0800\n"
|
||||||
"Last-Translator: kagami <mingjing@merico.dev>\n"
|
"Last-Translator: kagami <mingjing@merico.dev>\n"
|
||||||
"Language-Team: English\n"
|
"Language-Team: English\n"
|
||||||
@ -17,66 +17,78 @@ msgstr ""
|
|||||||
"Content-Transfer-Encoding: 8bit\n"
|
"Content-Transfer-Encoding: 8bit\n"
|
||||||
"Plural-Forms: nplurals=2; plural=(n != 1);\n"
|
"Plural-Forms: nplurals=2; plural=(n != 1);\n"
|
||||||
|
|
||||||
#: main.py:53
|
#: main.py:79
|
||||||
msgid "Analyzing the function and current unit tests..."
|
msgid "Analyzing the function and current unit tests..."
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:84
|
#: main.py:116
|
||||||
msgid "Select test cases to generate"
|
msgid "Select test cases to generate"
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:89
|
#: main.py:121
|
||||||
msgid ""
|
msgid ""
|
||||||
"You can add more test cases here\n"
|
"You can add more test cases here\n"
|
||||||
"(Multiple cases can be separated by line breaks)"
|
"(Multiple cases can be separated by line breaks)"
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:95
|
#: main.py:127
|
||||||
msgid ""
|
msgid ""
|
||||||
"Edit reference test file\n"
|
"Edit reference test file\n"
|
||||||
"(Multiple files can be separated by line breaks)"
|
"(Multiple files can be separated by line breaks)"
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:112
|
#: main.py:134
|
||||||
|
msgid ""
|
||||||
|
"Write your customized requirements(prompts) for tests here.\n"
|
||||||
|
"(For example, what testing framework to use.)"
|
||||||
|
msgstr ""
|
||||||
|
|
||||||
|
#: main.py:153
|
||||||
msgid "No test case is selected. Quit generating tests."
|
msgid "No test case is selected. Quit generating tests."
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:129
|
#: main.py:176
|
||||||
msgid "Will generate tests for the following cases."
|
msgid "Will generate tests for the following cases."
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:132
|
#: main.py:179
|
||||||
msgid ""
|
msgid ""
|
||||||
"\n"
|
"\n"
|
||||||
"Test cases:"
|
"Test cases:"
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:139
|
#: main.py:186
|
||||||
msgid ""
|
msgid ""
|
||||||
"\n"
|
"\n"
|
||||||
"No valid reference file is provided. Will not use reference to generate "
|
"No valid reference file is provided. Will not use reference to generate "
|
||||||
"tests."
|
"tests."
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:144
|
#: main.py:191
|
||||||
msgid ""
|
msgid ""
|
||||||
"\n"
|
"\n"
|
||||||
"Will use the following reference files to generate tests."
|
"Will use the following reference files to generate tests."
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:145
|
#: main.py:192
|
||||||
msgid ""
|
msgid ""
|
||||||
"\n"
|
"\n"
|
||||||
"Valid reference files:"
|
"Valid reference files:"
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:150
|
#: main.py:197
|
||||||
msgid ""
|
msgid ""
|
||||||
"\n"
|
"\n"
|
||||||
"Invalid files:"
|
"Invalid files:"
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
#: main.py:232
|
#: main.py:201
|
||||||
|
msgid ""
|
||||||
|
"\n"
|
||||||
|
"Customized requirements(prompts):"
|
||||||
|
msgstr ""
|
||||||
|
|
||||||
|
#: main.py:330
|
||||||
msgid "The function's size surpasses AI's context capacity."
|
msgid "The function's size surpasses AI's context capacity."
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
|
Binary file not shown.
@ -7,7 +7,7 @@ msgid ""
|
|||||||
msgstr ""
|
msgstr ""
|
||||||
"Project-Id-Version: PACKAGE VERSION\n"
|
"Project-Id-Version: PACKAGE VERSION\n"
|
||||||
"Report-Msgid-Bugs-To: \n"
|
"Report-Msgid-Bugs-To: \n"
|
||||||
"POT-Creation-Date: 2024-01-18 15:29+0800\n"
|
"POT-Creation-Date: 2024-03-11 22:42+0800\n"
|
||||||
"PO-Revision-Date: 2023-12-24 16:51+0800\n"
|
"PO-Revision-Date: 2023-12-24 16:51+0800\n"
|
||||||
"Last-Translator: kagami <mingjing@merico.dev>\n"
|
"Last-Translator: kagami <mingjing@merico.dev>\n"
|
||||||
"Language-Team: Chinese\n"
|
"Language-Team: Chinese\n"
|
||||||
@ -16,15 +16,15 @@ msgstr ""
|
|||||||
"Content-Type: text/plain; charset=UTF-8\n"
|
"Content-Type: text/plain; charset=UTF-8\n"
|
||||||
"Content-Transfer-Encoding: 8bit\n"
|
"Content-Transfer-Encoding: 8bit\n"
|
||||||
|
|
||||||
#: main.py:53
|
#: main.py:79
|
||||||
msgid "Analyzing the function and current unit tests..."
|
msgid "Analyzing the function and current unit tests..."
|
||||||
msgstr "正在分析目标函数和项目中现有的单元测试……"
|
msgstr "正在分析目标函数和项目中现有的单元测试……"
|
||||||
|
|
||||||
#: main.py:84
|
#: main.py:116
|
||||||
msgid "Select test cases to generate"
|
msgid "Select test cases to generate"
|
||||||
msgstr "选择测试用例以生成单元测试"
|
msgstr "选择测试用例以生成单元测试"
|
||||||
|
|
||||||
#: main.py:89
|
#: main.py:121
|
||||||
msgid ""
|
msgid ""
|
||||||
"You can add more test cases here\n"
|
"You can add more test cases here\n"
|
||||||
"(Multiple cases can be separated by line breaks)"
|
"(Multiple cases can be separated by line breaks)"
|
||||||
@ -32,7 +32,7 @@ msgstr ""
|
|||||||
"可在输入框中添加更多测试用例\n"
|
"可在输入框中添加更多测试用例\n"
|
||||||
"(多个测试用例用换行分隔)"
|
"(多个测试用例用换行分隔)"
|
||||||
|
|
||||||
#: main.py:95
|
#: main.py:127
|
||||||
msgid ""
|
msgid ""
|
||||||
"Edit reference test file\n"
|
"Edit reference test file\n"
|
||||||
"(Multiple files can be separated by line breaks)"
|
"(Multiple files can be separated by line breaks)"
|
||||||
@ -40,34 +40,48 @@ msgstr ""
|
|||||||
"参考的测试文件\n"
|
"参考的测试文件\n"
|
||||||
"(多个文件路径用换行分隔)"
|
"(多个文件路径用换行分隔)"
|
||||||
|
|
||||||
#: main.py:112
|
#: main.py:134
|
||||||
|
msgid ""
|
||||||
|
"Write your customized requirements(prompts) for tests here.\n"
|
||||||
|
"(For example, what testing framework to use.)"
|
||||||
|
msgstr ""
|
||||||
|
"可在以下输入框中填写自定义的测试需求(提示词)\n"
|
||||||
|
"(例如指定特定的测试框架)"
|
||||||
|
|
||||||
|
#: main.py:153
|
||||||
msgid "No test case is selected. Quit generating tests."
|
msgid "No test case is selected. Quit generating tests."
|
||||||
msgstr "未选择测试用例,退出生成单元测试。"
|
msgstr "未选择测试用例,退出生成单元测试。"
|
||||||
|
|
||||||
#: main.py:129
|
#: main.py:176
|
||||||
msgid "Will generate tests for the following cases."
|
msgid "Will generate tests for the following cases."
|
||||||
msgstr "将为以下用例生成测试。"
|
msgstr "将为以下用例生成测试。"
|
||||||
|
|
||||||
#: main.py:132
|
#: main.py:179
|
||||||
msgid ""
|
msgid ""
|
||||||
"\n"
|
"\n"
|
||||||
"Test cases:"
|
"Test cases:"
|
||||||
msgstr "\n测试用例:"
|
msgstr ""
|
||||||
|
"\n"
|
||||||
|
"测试用例:"
|
||||||
|
|
||||||
#: main.py:139
|
#: main.py:186
|
||||||
msgid ""
|
msgid ""
|
||||||
"\n"
|
"\n"
|
||||||
"No valid reference file is provided. Will not use reference to generate "
|
"No valid reference file is provided. Will not use reference to generate "
|
||||||
"tests."
|
"tests."
|
||||||
msgstr "\n没有提供合法的文件,生成单元测试时将不使用参考。"
|
msgstr ""
|
||||||
|
"\n"
|
||||||
|
"没有提供合法的文件,生成单元测试时将不使用参考。"
|
||||||
|
|
||||||
#: main.py:144
|
#: main.py:191
|
||||||
msgid ""
|
msgid ""
|
||||||
"\n"
|
"\n"
|
||||||
"Will use the following reference files to generate tests."
|
"Will use the following reference files to generate tests."
|
||||||
msgstr "\n将参考以下文件生成单元测试。"
|
msgstr ""
|
||||||
|
"\n"
|
||||||
|
"将参考以下文件生成单元测试。"
|
||||||
|
|
||||||
#: main.py:145
|
#: main.py:192
|
||||||
msgid ""
|
msgid ""
|
||||||
"\n"
|
"\n"
|
||||||
"Valid reference files:"
|
"Valid reference files:"
|
||||||
@ -75,7 +89,7 @@ msgstr ""
|
|||||||
"\n"
|
"\n"
|
||||||
"参考文件:"
|
"参考文件:"
|
||||||
|
|
||||||
#: main.py:150
|
#: main.py:197
|
||||||
msgid ""
|
msgid ""
|
||||||
"\n"
|
"\n"
|
||||||
"Invalid files:"
|
"Invalid files:"
|
||||||
@ -83,7 +97,15 @@ msgstr ""
|
|||||||
"\n"
|
"\n"
|
||||||
"不合法的文件:"
|
"不合法的文件:"
|
||||||
|
|
||||||
#: main.py:232
|
#: main.py:201
|
||||||
|
msgid ""
|
||||||
|
"\n"
|
||||||
|
"Customized requirements(prompts):"
|
||||||
|
msgstr ""
|
||||||
|
"\n"
|
||||||
|
"自定义测试需求(提示词):"
|
||||||
|
|
||||||
|
#: main.py:330
|
||||||
msgid "The function's size surpasses AI's context capacity."
|
msgid "The function's size surpasses AI's context capacity."
|
||||||
msgstr "由于当前函数过大,AI暂时无法处理。"
|
msgstr "由于当前函数过大,AI暂时无法处理。"
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ import click
|
|||||||
sys.path.append(os.path.dirname(__file__))
|
sys.path.append(os.path.dirname(__file__))
|
||||||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs"))
|
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs"))
|
||||||
|
|
||||||
|
from cache import LocalCache
|
||||||
from chatmark import Checkbox, Form, Step, TextEditor # noqa: E402
|
from chatmark import Checkbox, Form, Step, TextEditor # noqa: E402
|
||||||
from find_context import (
|
from find_context import (
|
||||||
Context,
|
Context,
|
||||||
@ -25,6 +26,8 @@ from propose_test import propose_test
|
|||||||
from tools.file_util import retrieve_file_content
|
from tools.file_util import retrieve_file_content
|
||||||
from write_tests import write_and_print_tests
|
from write_tests import write_and_print_tests
|
||||||
|
|
||||||
|
CHAT_WORKFLOW_DIR_PATH = [".chat", "workflows"]
|
||||||
|
|
||||||
|
|
||||||
class UnitTestsWorkflow:
|
class UnitTestsWorkflow:
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -33,11 +36,13 @@ class UnitTestsWorkflow:
|
|||||||
func_to_test: FuncToTest,
|
func_to_test: FuncToTest,
|
||||||
repo_root: str,
|
repo_root: str,
|
||||||
tui_lang: TUILanguage,
|
tui_lang: TUILanguage,
|
||||||
|
local_cache: LocalCache,
|
||||||
):
|
):
|
||||||
self.user_prompt = user_prompt
|
self.user_prompt = user_prompt
|
||||||
self.func_to_test = func_to_test
|
self.func_to_test = func_to_test
|
||||||
self.repo_root = repo_root
|
self.repo_root = repo_root
|
||||||
self.tui_lang = tui_lang
|
self.tui_lang = tui_lang
|
||||||
|
self.local_cache = local_cache
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""
|
"""
|
||||||
@ -50,11 +55,12 @@ class UnitTestsWorkflow:
|
|||||||
|
|
||||||
cases, files = self.step2_propose_cases_and_reference_files(list(contexts))
|
cases, files = self.step2_propose_cases_and_reference_files(list(contexts))
|
||||||
|
|
||||||
res = self.step3_edit_cases_and_reference_files(cases, files)
|
res = self.step3_user_interaction(cases, files)
|
||||||
cases = res[0]
|
cases = res[0]
|
||||||
files = res[1]
|
files = res[1]
|
||||||
|
requirements = res[2]
|
||||||
|
|
||||||
self.step4_write_and_print_tests(cases, files, list(contexts))
|
self.step4_write_and_print_tests(cases, files, list(contexts), requirements)
|
||||||
|
|
||||||
def step2_propose_cases_and_reference_files(
|
def step2_propose_cases_and_reference_files(
|
||||||
self,
|
self,
|
||||||
@ -92,13 +98,16 @@ class UnitTestsWorkflow:
|
|||||||
|
|
||||||
return test_cases, reference_files
|
return test_cases, reference_files
|
||||||
|
|
||||||
def step3_edit_cases_and_reference_files(
|
def step3_user_interaction(
|
||||||
self, test_cases: List[str], reference_files: List[str]
|
self, test_cases: List[str], reference_files: List[str]
|
||||||
) -> Tuple[List[str], List[str]]:
|
) -> Tuple[List[str], List[str], str]:
|
||||||
"""
|
"""
|
||||||
Edit test cases and reference files by user.
|
Edit test cases and reference files by user.
|
||||||
|
|
||||||
Return the updated cases and valid reference files.
|
Return:
|
||||||
|
- the updated cases
|
||||||
|
- valid reference files
|
||||||
|
- customized requirements(prompts)
|
||||||
"""
|
"""
|
||||||
_i = get_translation(self.tui_lang)
|
_i = get_translation(self.tui_lang)
|
||||||
|
|
||||||
@ -118,7 +127,16 @@ class UnitTestsWorkflow:
|
|||||||
title=_i("Edit reference test file\n(Multiple files can be separated by line breaks)"),
|
title=_i("Edit reference test file\n(Multiple files can be separated by line breaks)"),
|
||||||
)
|
)
|
||||||
|
|
||||||
form = Form(components=[checkbox, case_editor, ref_editor])
|
cached_requirements = self.local_cache.get("user_requirements") or ""
|
||||||
|
requirements_editor = TextEditor(
|
||||||
|
text=cached_requirements,
|
||||||
|
title=_i(
|
||||||
|
"Write your customized requirements(prompts) for tests here."
|
||||||
|
"\n(For example, what testing framework to use.)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
form = Form(components=[checkbox, case_editor, ref_editor, requirements_editor])
|
||||||
form.render()
|
form.render()
|
||||||
|
|
||||||
# Check test cases
|
# Check test cases
|
||||||
@ -148,6 +166,12 @@ class UnitTestsWorkflow:
|
|||||||
except Exception:
|
except Exception:
|
||||||
invalid_files.append(ref_file)
|
invalid_files.append(ref_file)
|
||||||
|
|
||||||
|
# Get customized requirements
|
||||||
|
requirements: str = (
|
||||||
|
requirements_editor.new_text.strip() if requirements_editor.new_text else ""
|
||||||
|
)
|
||||||
|
self.local_cache.set("user_requirements", requirements)
|
||||||
|
|
||||||
# Print summary
|
# Print summary
|
||||||
title = _i("Will generate tests for the following cases.")
|
title = _i("Will generate tests for the following cases.")
|
||||||
lines = []
|
lines = []
|
||||||
@ -174,10 +198,13 @@ class UnitTestsWorkflow:
|
|||||||
width = len(str(len(invalid_files)))
|
width = len(str(len(invalid_files)))
|
||||||
lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(invalid_files)])
|
lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(invalid_files)])
|
||||||
|
|
||||||
|
lines.append(_i("\nCustomized requirements(prompts):"))
|
||||||
|
lines.append(requirements)
|
||||||
|
|
||||||
with Step(title):
|
with Step(title):
|
||||||
print("\n".join(lines), flush=True)
|
print("\n".join(lines), flush=True)
|
||||||
|
|
||||||
return cases, valid_files
|
return cases, valid_files, requirements
|
||||||
|
|
||||||
def step1_find_symbol_context(self) -> Dict[str, List[Context]]:
|
def step1_find_symbol_context(self) -> Dict[str, List[Context]]:
|
||||||
symbol_context = find_symbol_context_by_static_analysis(
|
symbol_context = find_symbol_context_by_static_analysis(
|
||||||
@ -221,6 +248,7 @@ class UnitTestsWorkflow:
|
|||||||
cases: List[str],
|
cases: List[str],
|
||||||
ref_files: List[str],
|
ref_files: List[str],
|
||||||
symbol_contexts: List[Context],
|
symbol_contexts: List[Context],
|
||||||
|
user_requirements: str,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Write and print tests.
|
Write and print tests.
|
||||||
@ -232,6 +260,7 @@ class UnitTestsWorkflow:
|
|||||||
test_cases=cases,
|
test_cases=cases,
|
||||||
reference_files=ref_files,
|
reference_files=ref_files,
|
||||||
symbol_contexts=symbol_contexts,
|
symbol_contexts=symbol_contexts,
|
||||||
|
user_requirements=user_requirements,
|
||||||
chat_language=self.tui_lang.chat_language,
|
chat_language=self.tui_lang.chat_language,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -268,6 +297,7 @@ def main(input: str):
|
|||||||
|
|
||||||
repo_root = os.getcwd()
|
repo_root = os.getcwd()
|
||||||
ide_lang = IDEService().ide_language()
|
ide_lang = IDEService().ide_language()
|
||||||
|
local_cache = LocalCache("unit_tests", os.path.join(repo_root, *CHAT_WORKFLOW_DIR_PATH))
|
||||||
|
|
||||||
tui_lang = TUILanguage.from_str(ide_lang)
|
tui_lang = TUILanguage.from_str(ide_lang)
|
||||||
_i = get_translation(tui_lang)
|
_i = get_translation(tui_lang)
|
||||||
@ -287,7 +317,13 @@ def main(input: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
workflow = UnitTestsWorkflow(user_prompt, func_to_test, repo_root, tui_lang)
|
workflow = UnitTestsWorkflow(
|
||||||
|
user_prompt,
|
||||||
|
func_to_test,
|
||||||
|
repo_root,
|
||||||
|
tui_lang,
|
||||||
|
local_cache,
|
||||||
|
)
|
||||||
workflow.run()
|
workflow.run()
|
||||||
|
|
||||||
except TokenBudgetExceededException as e:
|
except TokenBudgetExceededException as e:
|
||||||
|
@ -55,6 +55,8 @@ Each test case should be self-contained and executable.
|
|||||||
Use the content of the reference test cases as a model, ensuring you use the same test framework and mock library,
|
Use the content of the reference test cases as a model, ensuring you use the same test framework and mock library,
|
||||||
and apply comparable mocking strategies and best practices.
|
and apply comparable mocking strategies and best practices.
|
||||||
|
|
||||||
|
{additional_requirements}
|
||||||
|
|
||||||
|
|
||||||
The target function is {function_name}, located in the file {file_path}.
|
The target function is {function_name}, located in the file {file_path}.
|
||||||
Here's the relevant source code of the function:
|
Here's the relevant source code of the function:
|
||||||
|
@ -21,9 +21,12 @@ def _mk_write_tests_msg(
|
|||||||
reference_files: Optional[List[str]] = None,
|
reference_files: Optional[List[str]] = None,
|
||||||
# context_files: Optional[List[str]] = None,
|
# context_files: Optional[List[str]] = None,
|
||||||
symbol_contexts: Optional[List[Context]] = None,
|
symbol_contexts: Optional[List[Context]] = None,
|
||||||
|
user_requirements: str = "",
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
encoding = get_encoding(ENCODING)
|
encoding = get_encoding(ENCODING)
|
||||||
|
|
||||||
|
additional_requirements = user_requirements
|
||||||
|
|
||||||
test_cases_str = ""
|
test_cases_str = ""
|
||||||
for i, test_case in enumerate(test_cases, 1):
|
for i, test_case in enumerate(test_cases, 1):
|
||||||
test_cases_str += f"{i}. {test_case}\n"
|
test_cases_str += f"{i}. {test_case}\n"
|
||||||
@ -63,6 +66,7 @@ def _mk_write_tests_msg(
|
|||||||
file_path=func_to_test.file_path,
|
file_path=func_to_test.file_path,
|
||||||
test_cases_str=test_cases_str,
|
test_cases_str=test_cases_str,
|
||||||
chat_language=chat_language,
|
chat_language=chat_language,
|
||||||
|
additional_requirements=additional_requirements,
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE: adjust symbol_context content more flexibly if needed
|
# NOTE: adjust symbol_context content more flexibly if needed
|
||||||
@ -107,6 +111,7 @@ def write_and_print_tests(
|
|||||||
test_cases: List[str],
|
test_cases: List[str],
|
||||||
reference_files: Optional[List[str]] = None,
|
reference_files: Optional[List[str]] = None,
|
||||||
symbol_contexts: Optional[List[Context]] = None,
|
symbol_contexts: Optional[List[Context]] = None,
|
||||||
|
user_requirements: str = "",
|
||||||
chat_language: str = "English",
|
chat_language: str = "English",
|
||||||
) -> None:
|
) -> None:
|
||||||
user_msg = _mk_write_tests_msg(
|
user_msg = _mk_write_tests_msg(
|
||||||
@ -115,6 +120,7 @@ def write_and_print_tests(
|
|||||||
test_cases=test_cases,
|
test_cases=test_cases,
|
||||||
reference_files=reference_files,
|
reference_files=reference_files,
|
||||||
symbol_contexts=symbol_contexts,
|
symbol_contexts=symbol_contexts,
|
||||||
|
user_requirements=user_requirements,
|
||||||
chat_language=chat_language,
|
chat_language=chat_language,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user