Cache user requirement of tests for each repo
This commit is contained in:
parent
8d32101c80
commit
991267e9d6
44
unit_tests/cache.py
Normal file
44
unit_tests/cache.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class LocalCache:
|
||||||
|
"""
|
||||||
|
A file-based cache for storing and retrieving simple data in JSON format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DIR = "local_cache"
|
||||||
|
|
||||||
|
def __init__(self, name: str, base_path: str):
|
||||||
|
self._name = name
|
||||||
|
self._cache = {}
|
||||||
|
self._base_path = base_path
|
||||||
|
|
||||||
|
cache_dir = os.path.join(base_path, self.DIR)
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
|
||||||
|
self._cache_file = os.path.join(cache_dir, f"{name}.json")
|
||||||
|
self.load()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
def load(self):
|
||||||
|
try:
|
||||||
|
with open(self._cache_file, "r") as f:
|
||||||
|
self._cache = json.load(f)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
with open(self._cache_file, "w") as f:
|
||||||
|
json.dump(self._cache, f)
|
||||||
|
|
||||||
|
def get(self, key: str):
|
||||||
|
return self._cache.get(key)
|
||||||
|
|
||||||
|
def set(self, key: str, value):
|
||||||
|
if self._cache.get(key) != value:
|
||||||
|
self._cache[key] = value
|
||||||
|
self.save()
|
@ -7,6 +7,7 @@ import click
|
|||||||
sys.path.append(os.path.dirname(__file__))
|
sys.path.append(os.path.dirname(__file__))
|
||||||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs"))
|
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs"))
|
||||||
|
|
||||||
|
from cache import LocalCache
|
||||||
from chatmark import Checkbox, Form, Step, TextEditor # noqa: E402
|
from chatmark import Checkbox, Form, Step, TextEditor # noqa: E402
|
||||||
from find_context import (
|
from find_context import (
|
||||||
Context,
|
Context,
|
||||||
@ -25,6 +26,8 @@ from propose_test import propose_test
|
|||||||
from tools.file_util import retrieve_file_content
|
from tools.file_util import retrieve_file_content
|
||||||
from write_tests import write_and_print_tests
|
from write_tests import write_and_print_tests
|
||||||
|
|
||||||
|
CHAT_WORKFLOW_DIR_PATH = [".chat", "workflows"]
|
||||||
|
|
||||||
|
|
||||||
class UnitTestsWorkflow:
|
class UnitTestsWorkflow:
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -33,11 +36,13 @@ class UnitTestsWorkflow:
|
|||||||
func_to_test: FuncToTest,
|
func_to_test: FuncToTest,
|
||||||
repo_root: str,
|
repo_root: str,
|
||||||
tui_lang: TUILanguage,
|
tui_lang: TUILanguage,
|
||||||
|
local_cache: LocalCache,
|
||||||
):
|
):
|
||||||
self.user_prompt = user_prompt
|
self.user_prompt = user_prompt
|
||||||
self.func_to_test = func_to_test
|
self.func_to_test = func_to_test
|
||||||
self.repo_root = repo_root
|
self.repo_root = repo_root
|
||||||
self.tui_lang = tui_lang
|
self.tui_lang = tui_lang
|
||||||
|
self.local_cache = local_cache
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""
|
"""
|
||||||
@ -100,7 +105,7 @@ class UnitTestsWorkflow:
|
|||||||
Edit test cases and reference files by user.
|
Edit test cases and reference files by user.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
- the updated cases
|
- the updated cases
|
||||||
- valid reference files
|
- valid reference files
|
||||||
- customized requirements(prompts)
|
- customized requirements(prompts)
|
||||||
"""
|
"""
|
||||||
@ -122,9 +127,12 @@ class UnitTestsWorkflow:
|
|||||||
title=_i("Edit reference test file\n(Multiple files can be separated by line breaks)"),
|
title=_i("Edit reference test file\n(Multiple files can be separated by line breaks)"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cached_requirements = self.local_cache.get("user_requirements") or ""
|
||||||
requirements_editor = TextEditor(
|
requirements_editor = TextEditor(
|
||||||
text = "",
|
text=cached_requirements,
|
||||||
title = _i("Write your customized requirements(prompts) for tests here.\n(For example, what testing framework to use.)"),
|
title=_i(
|
||||||
|
"Write your customized requirements(prompts) for tests here.\n(For example, what testing framework to use.)"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
form = Form(components=[checkbox, case_editor, ref_editor, requirements_editor])
|
form = Form(components=[checkbox, case_editor, ref_editor, requirements_editor])
|
||||||
@ -158,7 +166,10 @@ class UnitTestsWorkflow:
|
|||||||
invalid_files.append(ref_file)
|
invalid_files.append(ref_file)
|
||||||
|
|
||||||
# Get customized requirements
|
# Get customized requirements
|
||||||
requirements: str = requirements_editor.new_text.strip() if requirements_editor.new_text else ""
|
requirements: str = (
|
||||||
|
requirements_editor.new_text.strip() if requirements_editor.new_text else ""
|
||||||
|
)
|
||||||
|
self.local_cache.set("user_requirements", requirements)
|
||||||
|
|
||||||
# Print summary
|
# Print summary
|
||||||
title = _i("Will generate tests for the following cases.")
|
title = _i("Will generate tests for the following cases.")
|
||||||
@ -285,6 +296,7 @@ def main(input: str):
|
|||||||
|
|
||||||
repo_root = os.getcwd()
|
repo_root = os.getcwd()
|
||||||
ide_lang = IDEService().ide_language()
|
ide_lang = IDEService().ide_language()
|
||||||
|
local_cache = LocalCache("unit_tests", os.path.join(repo_root, *CHAT_WORKFLOW_DIR_PATH))
|
||||||
|
|
||||||
tui_lang = TUILanguage.from_str(ide_lang)
|
tui_lang = TUILanguage.from_str(ide_lang)
|
||||||
_i = get_translation(tui_lang)
|
_i = get_translation(tui_lang)
|
||||||
@ -304,7 +316,13 @@ def main(input: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
workflow = UnitTestsWorkflow(user_prompt, func_to_test, repo_root, tui_lang)
|
workflow = UnitTestsWorkflow(
|
||||||
|
user_prompt,
|
||||||
|
func_to_test,
|
||||||
|
repo_root,
|
||||||
|
tui_lang,
|
||||||
|
local_cache,
|
||||||
|
)
|
||||||
workflow.run()
|
workflow.run()
|
||||||
|
|
||||||
except TokenBudgetExceededException as e:
|
except TokenBudgetExceededException as e:
|
||||||
|
@ -111,7 +111,7 @@ def write_and_print_tests(
|
|||||||
test_cases: List[str],
|
test_cases: List[str],
|
||||||
reference_files: Optional[List[str]] = None,
|
reference_files: Optional[List[str]] = None,
|
||||||
symbol_contexts: Optional[List[Context]] = None,
|
symbol_contexts: Optional[List[Context]] = None,
|
||||||
user_requirements:str = "",
|
user_requirements: str = "",
|
||||||
chat_language: str = "English",
|
chat_language: str = "English",
|
||||||
) -> None:
|
) -> None:
|
||||||
user_msg = _mk_write_tests_msg(
|
user_msg = _mk_write_tests_msg(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user