Merge pull request #110 from devchat-ai/improve-interim-output
Improve interim output to print model response time
This commit is contained in:
commit
2f3a423741
@ -15,6 +15,7 @@ from llm_conf import (
|
|||||||
from openai_util import create_chat_completion_content
|
from openai_util import create_chat_completion_content
|
||||||
from tools.directory_viewer import ListViewer
|
from tools.directory_viewer import ListViewer
|
||||||
from tools.tiktoken_util import get_encoding
|
from tools.tiktoken_util import get_encoding
|
||||||
|
from tools.time_util import print_exec_time
|
||||||
|
|
||||||
MODEL = USER_LLM_MODEL if USE_USER_MODEL else "gpt-4-turbo-preview" # "gpt-3.5-turbo"
|
MODEL = USER_LLM_MODEL if USE_USER_MODEL else "gpt-4-turbo-preview" # "gpt-3.5-turbo"
|
||||||
ENCODING = (
|
ENCODING = (
|
||||||
@ -89,6 +90,7 @@ class RelevantFileFinder(DirectoryStructureBase):
|
|||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
@print_exec_time("Model response time")
|
||||||
def _find_relevant_files(self, objective: str, dir_structure_pages: List[str]) -> List[str]:
|
def _find_relevant_files(self, objective: str, dir_structure_pages: List[str]) -> List[str]:
|
||||||
files: List[str] = []
|
files: List[str] = []
|
||||||
for dir_structure in dir_structure_pages:
|
for dir_structure in dir_structure_pages:
|
||||||
|
@ -12,6 +12,7 @@ from llm_conf import (
|
|||||||
from model import FuncToTest
|
from model import FuncToTest
|
||||||
from openai_util import create_chat_completion_content
|
from openai_util import create_chat_completion_content
|
||||||
from tools.tiktoken_util import get_encoding
|
from tools.tiktoken_util import get_encoding
|
||||||
|
from tools.time_util import print_exec_time
|
||||||
|
|
||||||
MODEL = USER_LLM_MODEL if USE_USER_MODEL else "gpt-4-turbo-preview"
|
MODEL = USER_LLM_MODEL if USE_USER_MODEL else "gpt-4-turbo-preview"
|
||||||
ENCODING = (
|
ENCODING = (
|
||||||
@ -82,6 +83,7 @@ def _mk_user_msg(func_to_test: FuncToTest, contexts: List) -> str:
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
@print_exec_time("Model response time")
|
||||||
def get_recommended_symbols(
|
def get_recommended_symbols(
|
||||||
func_to_test: FuncToTest, known_context: Optional[List] = None
|
func_to_test: FuncToTest, known_context: Optional[List] = None
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
|
@ -1,295 +1,27 @@
|
|||||||
# ruff: noqa: E402
|
# ruff: noqa: E402
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import Dict, List, Tuple
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
import openai
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(__file__))
|
sys.path.append(os.path.dirname(__file__))
|
||||||
|
|
||||||
from cache import LocalCache
|
from cache import LocalCache
|
||||||
from find_context import (
|
|
||||||
Context,
|
|
||||||
Position,
|
|
||||||
Range,
|
|
||||||
find_symbol_context_by_static_analysis,
|
|
||||||
find_symbol_context_of_llm_recommendation,
|
|
||||||
)
|
|
||||||
from find_reference_tests import find_reference_tests
|
|
||||||
from i18n import TUILanguage, get_translation
|
from i18n import TUILanguage, get_translation
|
||||||
from model import (
|
from model import (
|
||||||
FuncToTest,
|
FuncToTest,
|
||||||
TokenBudgetExceededException,
|
TokenBudgetExceededException,
|
||||||
UserCancelledException,
|
UserCancelledException,
|
||||||
)
|
)
|
||||||
from propose_test import propose_test
|
from ut_workflow import UnitTestsWorkflow
|
||||||
from tools.file_util import retrieve_file_content
|
|
||||||
from write_tests import write_and_print_tests
|
|
||||||
|
|
||||||
from lib.chatmark import Checkbox, Form, Step, TextEditor
|
from lib.chatmark import Step
|
||||||
from lib.ide_service import IDEService
|
from lib.ide_service import IDEService
|
||||||
|
|
||||||
CHAT_WORKFLOW_DIR_PATH = [".chat", "workflows"]
|
CHAT_WORKFLOW_DIR_PATH = [".chat", "workflows"]
|
||||||
|
|
||||||
|
|
||||||
class UnitTestsWorkflow:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
user_prompt: str,
|
|
||||||
func_to_test: FuncToTest,
|
|
||||||
repo_root: str,
|
|
||||||
tui_lang: TUILanguage,
|
|
||||||
local_cache: LocalCache,
|
|
||||||
):
|
|
||||||
self.user_prompt = user_prompt
|
|
||||||
self.func_to_test = func_to_test
|
|
||||||
self.repo_root = repo_root
|
|
||||||
self.tui_lang = tui_lang
|
|
||||||
self.local_cache = local_cache
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
"""
|
|
||||||
Run the workflow to generate unit tests.
|
|
||||||
"""
|
|
||||||
symbol_context = self.step1_find_symbol_context()
|
|
||||||
contexts = set()
|
|
||||||
for _, v in symbol_context.items():
|
|
||||||
contexts.update(v)
|
|
||||||
contexts = list(contexts)
|
|
||||||
|
|
||||||
cases, files = self.step2_propose_cases_and_reference_files(contexts)
|
|
||||||
|
|
||||||
res = self.step3_user_interaction(cases, files)
|
|
||||||
cases = res[0]
|
|
||||||
files = res[1]
|
|
||||||
requirements = res[2]
|
|
||||||
|
|
||||||
self.step4_print_test_summary(cases, files, requirements, contexts)
|
|
||||||
|
|
||||||
self.step5_write_and_print_tests(cases, files, contexts, requirements)
|
|
||||||
|
|
||||||
def step1_find_symbol_context(self) -> Dict[str, List[Context]]:
|
|
||||||
symbol_context = find_symbol_context_by_static_analysis(
|
|
||||||
self.func_to_test, self.tui_lang.chat_language
|
|
||||||
)
|
|
||||||
|
|
||||||
known_context_for_llm: List[Context] = []
|
|
||||||
if self.func_to_test.container_content is not None:
|
|
||||||
known_context_for_llm.append(
|
|
||||||
Context(
|
|
||||||
file_path=self.func_to_test.file_path,
|
|
||||||
content=self.func_to_test.container_content,
|
|
||||||
range=Range(
|
|
||||||
start=Position(line=self.func_to_test.container_start_line, character=0),
|
|
||||||
end=Position(line=self.func_to_test.container_end_line, character=0),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
known_context_for_llm += list(
|
|
||||||
{item for sublist in list(symbol_context.values()) for item in sublist}
|
|
||||||
)
|
|
||||||
|
|
||||||
recommended_context = find_symbol_context_of_llm_recommendation(
|
|
||||||
self.func_to_test, known_context_for_llm
|
|
||||||
)
|
|
||||||
|
|
||||||
symbol_context.update(recommended_context)
|
|
||||||
|
|
||||||
return symbol_context
|
|
||||||
|
|
||||||
def step2_propose_cases_and_reference_files(
|
|
||||||
self,
|
|
||||||
contexts: List[Context],
|
|
||||||
) -> Tuple[List[str], List[str]]:
|
|
||||||
"""
|
|
||||||
Propose test cases and reference files for a specified function.
|
|
||||||
|
|
||||||
Return: (test_cases, reference_files)
|
|
||||||
"""
|
|
||||||
test_cases: List[str] = []
|
|
||||||
reference_files: List[str] = []
|
|
||||||
|
|
||||||
_i = get_translation(self.tui_lang)
|
|
||||||
|
|
||||||
msg = _i("Analyzing the function and current unit tests...")
|
|
||||||
|
|
||||||
with Step(msg):
|
|
||||||
test_cases = propose_test(
|
|
||||||
user_prompt=self.user_prompt,
|
|
||||||
func_to_test=self.func_to_test,
|
|
||||||
contexts=contexts,
|
|
||||||
chat_language=self.tui_lang.chat_language,
|
|
||||||
)
|
|
||||||
|
|
||||||
ref_files = find_reference_tests(
|
|
||||||
self.repo_root,
|
|
||||||
self.func_to_test.func_name,
|
|
||||||
self.func_to_test.file_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
if ref_files:
|
|
||||||
# Only use the most relevant reference file currently
|
|
||||||
reference_files.append(ref_files[0])
|
|
||||||
|
|
||||||
return test_cases, reference_files
|
|
||||||
|
|
||||||
def step3_user_interaction(
|
|
||||||
self,
|
|
||||||
test_cases: List[str],
|
|
||||||
reference_files: List[str],
|
|
||||||
) -> Tuple[List[str], List[str], str]:
|
|
||||||
"""
|
|
||||||
Edit test cases and reference files by user.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
- the updated cases
|
|
||||||
- valid reference files
|
|
||||||
- customized requirements(prompts)
|
|
||||||
"""
|
|
||||||
_i = get_translation(self.tui_lang)
|
|
||||||
|
|
||||||
checkbox = Checkbox(
|
|
||||||
options=test_cases,
|
|
||||||
title=_i("Select test cases to generate"),
|
|
||||||
)
|
|
||||||
case_editor = TextEditor(
|
|
||||||
text="",
|
|
||||||
title=_i(
|
|
||||||
"You can add more test cases here\n"
|
|
||||||
"(Multiple cases can be separated by line breaks)"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
ref_editor = TextEditor(
|
|
||||||
text=reference_files[0] if reference_files else "",
|
|
||||||
title=_i("Edit reference test file\n(Multiple files can be separated by line breaks)"),
|
|
||||||
)
|
|
||||||
|
|
||||||
cached_requirements = self.local_cache.get("user_requirements") or ""
|
|
||||||
requirements_editor = TextEditor(
|
|
||||||
text=cached_requirements,
|
|
||||||
title=_i(
|
|
||||||
"Write your customized requirements(prompts) for tests here."
|
|
||||||
"\n(For example, what testing framework to use.)"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
form = Form(components=[checkbox, case_editor, ref_editor, requirements_editor])
|
|
||||||
form.render()
|
|
||||||
|
|
||||||
# Check test cases
|
|
||||||
cases = [checkbox.options[idx] for idx in checkbox.selections]
|
|
||||||
user_cases = []
|
|
||||||
if case_editor.new_text:
|
|
||||||
user_cases = [c.strip() for c in case_editor.new_text.split("\n")]
|
|
||||||
user_cases = [c for c in user_cases if c]
|
|
||||||
|
|
||||||
cases.extend(user_cases)
|
|
||||||
|
|
||||||
# Check if any test case is selected
|
|
||||||
if not cases:
|
|
||||||
raise UserCancelledException(_i("No test case is selected. Quit generating tests."))
|
|
||||||
|
|
||||||
# Validate reference files
|
|
||||||
ref_files = [f.strip() for f in ref_editor.new_text.split("\n")]
|
|
||||||
valid_files = []
|
|
||||||
invalid_files = []
|
|
||||||
|
|
||||||
for ref_file in ref_files:
|
|
||||||
if not ref_file:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
retrieve_file_content(file_path=ref_file, root_path=self.repo_root)
|
|
||||||
valid_files.append(ref_file)
|
|
||||||
except Exception:
|
|
||||||
invalid_files.append(ref_file)
|
|
||||||
|
|
||||||
# Get customized requirements
|
|
||||||
requirements: str = (
|
|
||||||
requirements_editor.new_text.strip() if requirements_editor.new_text else ""
|
|
||||||
)
|
|
||||||
self.local_cache.set("user_requirements", requirements)
|
|
||||||
|
|
||||||
return cases, valid_files, requirements
|
|
||||||
|
|
||||||
# Tuple[List[str], List[str], str]:
|
|
||||||
def step4_print_test_summary(
|
|
||||||
self,
|
|
||||||
cases: List[str],
|
|
||||||
valid_files: List[str],
|
|
||||||
requirements: str,
|
|
||||||
contexts: List[Context],
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Print the summary message in Step
|
|
||||||
"""
|
|
||||||
_i = get_translation(self.tui_lang)
|
|
||||||
|
|
||||||
title = _i("Will generate tests for the following cases.")
|
|
||||||
lines = []
|
|
||||||
|
|
||||||
lines.append(_i("\nTest cases:"))
|
|
||||||
width = len(str(len(cases)))
|
|
||||||
lines.extend([f"{(i+1):>{width}}. {c}" for i, c in enumerate(cases)])
|
|
||||||
|
|
||||||
if not valid_files:
|
|
||||||
lines.append(
|
|
||||||
_i(
|
|
||||||
"\nNo valid reference file is provided. "
|
|
||||||
"Will not use reference to generate tests."
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
lines.append(_i("\nWill use the following reference files to generate tests."))
|
|
||||||
# lines.append(_i("\nValid reference files:"))
|
|
||||||
width = len(str(len(valid_files)))
|
|
||||||
lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(valid_files)])
|
|
||||||
|
|
||||||
# if invalid_files:
|
|
||||||
# lines.append(_i("\nInvalid files:"))
|
|
||||||
# width = len(str(len(invalid_files)))
|
|
||||||
# lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(invalid_files)])
|
|
||||||
|
|
||||||
lines.append(_i("\nCustomized requirements(prompts):"))
|
|
||||||
if requirements.strip():
|
|
||||||
lines.append(requirements)
|
|
||||||
else:
|
|
||||||
lines.append(_i("No customized requirements."))
|
|
||||||
|
|
||||||
if contexts:
|
|
||||||
lines.append(_i("\nAdditional context:"))
|
|
||||||
width = len(str(len(contexts)))
|
|
||||||
lines.extend(
|
|
||||||
[
|
|
||||||
f"{(i+1):>{width}}. {c.file_path}:{c.range.start.line+1}-{c.range.end.line+1}"
|
|
||||||
for i, c in enumerate(contexts)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
with Step(title):
|
|
||||||
print("\n".join(lines), flush=True)
|
|
||||||
|
|
||||||
def step5_write_and_print_tests(
|
|
||||||
self,
|
|
||||||
cases: List[str],
|
|
||||||
ref_files: List[str],
|
|
||||||
symbol_contexts: List[Context],
|
|
||||||
user_requirements: str,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Write and print tests.
|
|
||||||
"""
|
|
||||||
|
|
||||||
write_and_print_tests(
|
|
||||||
root_path=self.repo_root,
|
|
||||||
func_to_test=self.func_to_test,
|
|
||||||
test_cases=cases,
|
|
||||||
reference_files=ref_files,
|
|
||||||
symbol_contexts=symbol_contexts,
|
|
||||||
user_requirements=user_requirements,
|
|
||||||
chat_language=self.tui_lang.chat_language,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.argument("input", required=True)
|
@click.argument("input", required=True)
|
||||||
def main(input: str):
|
def main(input: str):
|
||||||
@ -369,6 +101,11 @@ def main(input: str):
|
|||||||
with Step(f"{e}"):
|
with Step(f"{e}"):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
except (openai.APIConnectionError, openai.APITimeoutError) as e:
|
||||||
|
msg = "Model API connection error. Please try again later."
|
||||||
|
with Step(msg):
|
||||||
|
print(f"\n{e}\n", flush=True)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@ from model import FuncToTest, TokenBudgetExceededException
|
|||||||
from openai_util import create_chat_completion_content
|
from openai_util import create_chat_completion_content
|
||||||
from prompts import PROPOSE_TEST_PROMPT
|
from prompts import PROPOSE_TEST_PROMPT
|
||||||
from tools.tiktoken_util import get_encoding
|
from tools.tiktoken_util import get_encoding
|
||||||
|
from tools.time_util import print_exec_time
|
||||||
|
|
||||||
MODEL = USER_LLM_MODEL if USE_USER_MODEL else "gpt-4-turbo-preview" # "gpt-3.5-turbo"
|
MODEL = USER_LLM_MODEL if USE_USER_MODEL else "gpt-4-turbo-preview" # "gpt-3.5-turbo"
|
||||||
ENCODING = (
|
ENCODING = (
|
||||||
@ -82,6 +83,7 @@ def _mk_user_msg(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@print_exec_time("Model response time")
|
||||||
def propose_test(
|
def propose_test(
|
||||||
user_prompt: str,
|
user_prompt: str,
|
||||||
func_to_test: FuncToTest,
|
func_to_test: FuncToTest,
|
||||||
|
18
merico/unit_tests/tools/time_util.py
Normal file
18
merico/unit_tests/tools/time_util.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import time
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
|
||||||
|
def print_exec_time(message: str):
|
||||||
|
def decorator(func):
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
start_time = time.time()
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
end_time = time.time()
|
||||||
|
duration = end_time - start_time
|
||||||
|
print(f"{message} ({duration:.3f} s)", flush=True)
|
||||||
|
return result
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
324
merico/unit_tests/ut_workflow.py
Normal file
324
merico/unit_tests/ut_workflow.py
Normal file
@ -0,0 +1,324 @@
|
|||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
from cache import LocalCache
|
||||||
|
from find_context import (
|
||||||
|
Context,
|
||||||
|
Position,
|
||||||
|
Range,
|
||||||
|
find_symbol_context_by_static_analysis,
|
||||||
|
find_symbol_context_of_llm_recommendation,
|
||||||
|
)
|
||||||
|
from find_reference_tests import find_reference_tests
|
||||||
|
from i18n import TUILanguage, get_translation
|
||||||
|
from model import (
|
||||||
|
FuncToTest,
|
||||||
|
UserCancelledException,
|
||||||
|
)
|
||||||
|
from propose_test import propose_test
|
||||||
|
from tools.file_util import retrieve_file_content
|
||||||
|
from write_tests import write_and_print_tests
|
||||||
|
|
||||||
|
from lib.chatmark import Checkbox, Form, Step, TextEditor
|
||||||
|
|
||||||
|
|
||||||
|
class UnitTestsWorkflow:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
user_prompt: str,
|
||||||
|
func_to_test: FuncToTest,
|
||||||
|
repo_root: str,
|
||||||
|
tui_lang: TUILanguage,
|
||||||
|
local_cache: LocalCache,
|
||||||
|
):
|
||||||
|
self.user_prompt = user_prompt
|
||||||
|
self.func_to_test = func_to_test
|
||||||
|
self.repo_root = repo_root
|
||||||
|
self.tui_lang = tui_lang
|
||||||
|
self.local_cache = local_cache
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""
|
||||||
|
Run the workflow to generate unit tests.
|
||||||
|
"""
|
||||||
|
_i = get_translation(self.tui_lang)
|
||||||
|
msg = _i("Analyzing the function and current unit tests...")
|
||||||
|
|
||||||
|
with Step(msg):
|
||||||
|
print("\n- Analyzing context for the function...", flush=True)
|
||||||
|
symbol_context = self.step_1_find_symbol_context()
|
||||||
|
|
||||||
|
contexts = set()
|
||||||
|
for _, v in symbol_context.items():
|
||||||
|
contexts.update(v)
|
||||||
|
contexts = list(contexts)
|
||||||
|
|
||||||
|
print("- Finding reference files...", flush=True)
|
||||||
|
reference_files = self.step_2_find_reference_files()
|
||||||
|
|
||||||
|
print("- Proposing test cases...", flush=True)
|
||||||
|
cases = self.step_3_propose_cases(contexts)
|
||||||
|
|
||||||
|
res = self.step_4_user_interaction(cases, reference_files)
|
||||||
|
cases = res[0]
|
||||||
|
files = res[1]
|
||||||
|
requirements = res[2]
|
||||||
|
|
||||||
|
self.step_5_print_test_summary(cases, files, requirements, contexts)
|
||||||
|
|
||||||
|
self.step_6_write_and_print_tests(cases, files, contexts, requirements)
|
||||||
|
|
||||||
|
def step_1_find_symbol_context(self) -> Dict[str, List[Context]]:
|
||||||
|
symbol_context = find_symbol_context_by_static_analysis(
|
||||||
|
self.func_to_test, self.tui_lang.chat_language
|
||||||
|
)
|
||||||
|
|
||||||
|
known_context_for_llm: List[Context] = []
|
||||||
|
if self.func_to_test.container_content is not None:
|
||||||
|
known_context_for_llm.append(
|
||||||
|
Context(
|
||||||
|
file_path=self.func_to_test.file_path,
|
||||||
|
content=self.func_to_test.container_content,
|
||||||
|
range=Range(
|
||||||
|
start=Position(line=self.func_to_test.container_start_line, character=0),
|
||||||
|
end=Position(line=self.func_to_test.container_end_line, character=0),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
known_context_for_llm += list(
|
||||||
|
{item for sublist in list(symbol_context.values()) for item in sublist}
|
||||||
|
)
|
||||||
|
|
||||||
|
recommended_context = find_symbol_context_of_llm_recommendation(
|
||||||
|
self.func_to_test, known_context_for_llm
|
||||||
|
)
|
||||||
|
|
||||||
|
symbol_context.update(recommended_context)
|
||||||
|
|
||||||
|
return symbol_context
|
||||||
|
|
||||||
|
def step_2_find_reference_files(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Find reference files for the specified function.
|
||||||
|
"""
|
||||||
|
reference_files: List[str] = []
|
||||||
|
|
||||||
|
ref_tests = find_reference_tests(
|
||||||
|
self.repo_root,
|
||||||
|
self.func_to_test.func_name,
|
||||||
|
self.func_to_test.file_path,
|
||||||
|
)
|
||||||
|
if ref_tests:
|
||||||
|
# Only use the most relevant reference test file currently
|
||||||
|
reference_files.append(ref_tests[0])
|
||||||
|
|
||||||
|
return reference_files
|
||||||
|
|
||||||
|
def step_3_propose_cases(
|
||||||
|
self,
|
||||||
|
contexts: List[Context],
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
Propose test cases for the specified function.
|
||||||
|
"""
|
||||||
|
test_cases: List[str] = []
|
||||||
|
|
||||||
|
test_cases = propose_test(
|
||||||
|
user_prompt=self.user_prompt,
|
||||||
|
func_to_test=self.func_to_test,
|
||||||
|
contexts=contexts,
|
||||||
|
chat_language=self.tui_lang.chat_language,
|
||||||
|
)
|
||||||
|
|
||||||
|
return test_cases
|
||||||
|
|
||||||
|
def step2_propose_cases_and_reference_files(
|
||||||
|
self,
|
||||||
|
contexts: List[Context],
|
||||||
|
) -> Tuple[List[str], List[str]]:
|
||||||
|
"""
|
||||||
|
Propose test cases and reference files for a specified function.
|
||||||
|
|
||||||
|
Return: (test_cases, reference_files)
|
||||||
|
"""
|
||||||
|
test_cases: List[str] = []
|
||||||
|
reference_files: List[str] = []
|
||||||
|
|
||||||
|
_i = get_translation(self.tui_lang)
|
||||||
|
|
||||||
|
msg = _i("Analyzing the function and current unit tests...")
|
||||||
|
|
||||||
|
with Step(msg):
|
||||||
|
test_cases = propose_test(
|
||||||
|
user_prompt=self.user_prompt,
|
||||||
|
func_to_test=self.func_to_test,
|
||||||
|
contexts=contexts,
|
||||||
|
chat_language=self.tui_lang.chat_language,
|
||||||
|
)
|
||||||
|
|
||||||
|
ref_files = find_reference_tests(
|
||||||
|
self.repo_root,
|
||||||
|
self.func_to_test.func_name,
|
||||||
|
self.func_to_test.file_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
if ref_files:
|
||||||
|
# Only use the most relevant reference file currently
|
||||||
|
reference_files.append(ref_files[0])
|
||||||
|
|
||||||
|
return test_cases, reference_files
|
||||||
|
|
||||||
|
def step_4_user_interaction(
|
||||||
|
self,
|
||||||
|
test_cases: List[str],
|
||||||
|
reference_files: List[str],
|
||||||
|
) -> Tuple[List[str], List[str], str]:
|
||||||
|
"""
|
||||||
|
Edit test cases and reference files by user.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
- the updated cases
|
||||||
|
- valid reference files
|
||||||
|
- customized requirements(prompts)
|
||||||
|
"""
|
||||||
|
_i = get_translation(self.tui_lang)
|
||||||
|
|
||||||
|
checkbox = Checkbox(
|
||||||
|
options=test_cases,
|
||||||
|
title=_i("Select test cases to generate"),
|
||||||
|
)
|
||||||
|
case_editor = TextEditor(
|
||||||
|
text="",
|
||||||
|
title=_i(
|
||||||
|
"You can add more test cases here\n"
|
||||||
|
"(Multiple cases can be separated by line breaks)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
ref_editor = TextEditor(
|
||||||
|
text=reference_files[0] if reference_files else "",
|
||||||
|
title=_i("Edit reference test file\n(Multiple files can be separated by line breaks)"),
|
||||||
|
)
|
||||||
|
|
||||||
|
cached_requirements = self.local_cache.get("user_requirements") or ""
|
||||||
|
requirements_editor = TextEditor(
|
||||||
|
text=cached_requirements,
|
||||||
|
title=_i(
|
||||||
|
"Write your customized requirements(prompts) for tests here."
|
||||||
|
"\n(For example, what testing framework to use.)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
form = Form(components=[checkbox, case_editor, ref_editor, requirements_editor])
|
||||||
|
form.render()
|
||||||
|
|
||||||
|
# Check test cases
|
||||||
|
cases = [checkbox.options[idx] for idx in checkbox.selections]
|
||||||
|
user_cases = []
|
||||||
|
if case_editor.new_text:
|
||||||
|
user_cases = [c.strip() for c in case_editor.new_text.split("\n")]
|
||||||
|
user_cases = [c for c in user_cases if c]
|
||||||
|
|
||||||
|
cases.extend(user_cases)
|
||||||
|
|
||||||
|
# Check if any test case is selected
|
||||||
|
if not cases:
|
||||||
|
raise UserCancelledException(_i("No test case is selected. Quit generating tests."))
|
||||||
|
|
||||||
|
# Validate reference files
|
||||||
|
ref_files = [f.strip() for f in ref_editor.new_text.split("\n")]
|
||||||
|
valid_files = []
|
||||||
|
invalid_files = []
|
||||||
|
|
||||||
|
for ref_file in ref_files:
|
||||||
|
if not ref_file:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
retrieve_file_content(file_path=ref_file, root_path=self.repo_root)
|
||||||
|
valid_files.append(ref_file)
|
||||||
|
except Exception:
|
||||||
|
invalid_files.append(ref_file)
|
||||||
|
|
||||||
|
# Get customized requirements
|
||||||
|
requirements: str = (
|
||||||
|
requirements_editor.new_text.strip() if requirements_editor.new_text else ""
|
||||||
|
)
|
||||||
|
self.local_cache.set("user_requirements", requirements)
|
||||||
|
|
||||||
|
return cases, valid_files, requirements
|
||||||
|
|
||||||
|
# Tuple[List[str], List[str], str]:
|
||||||
|
def step_5_print_test_summary(
|
||||||
|
self,
|
||||||
|
cases: List[str],
|
||||||
|
valid_files: List[str],
|
||||||
|
requirements: str,
|
||||||
|
contexts: List[Context],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Print the summary message in Step
|
||||||
|
"""
|
||||||
|
_i = get_translation(self.tui_lang)
|
||||||
|
|
||||||
|
title = _i("Will generate tests for the following cases.")
|
||||||
|
lines = []
|
||||||
|
|
||||||
|
lines.append(_i("\nTest cases:"))
|
||||||
|
width = len(str(len(cases)))
|
||||||
|
lines.extend([f"{(i+1):>{width}}. {c}" for i, c in enumerate(cases)])
|
||||||
|
|
||||||
|
if not valid_files:
|
||||||
|
lines.append(
|
||||||
|
_i(
|
||||||
|
"\nNo valid reference file is provided. "
|
||||||
|
"Will not use reference to generate tests."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
lines.append(_i("\nWill use the following reference files to generate tests."))
|
||||||
|
# lines.append(_i("\nValid reference files:"))
|
||||||
|
width = len(str(len(valid_files)))
|
||||||
|
lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(valid_files)])
|
||||||
|
|
||||||
|
# if invalid_files:
|
||||||
|
# lines.append(_i("\nInvalid files:"))
|
||||||
|
# width = len(str(len(invalid_files)))
|
||||||
|
# lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(invalid_files)])
|
||||||
|
|
||||||
|
lines.append(_i("\nCustomized requirements(prompts):"))
|
||||||
|
if requirements.strip():
|
||||||
|
lines.append(requirements)
|
||||||
|
else:
|
||||||
|
lines.append(_i("No customized requirements."))
|
||||||
|
|
||||||
|
if contexts:
|
||||||
|
lines.append(_i("\nAdditional context:"))
|
||||||
|
width = len(str(len(contexts)))
|
||||||
|
lines.extend(
|
||||||
|
[
|
||||||
|
f"{(i+1):>{width}}. {c.file_path}:{c.range.start.line+1}-{c.range.end.line+1}"
|
||||||
|
for i, c in enumerate(contexts)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
with Step(title):
|
||||||
|
print("\n".join(lines), flush=True)
|
||||||
|
|
||||||
|
def step_6_write_and_print_tests(
|
||||||
|
self,
|
||||||
|
cases: List[str],
|
||||||
|
ref_files: List[str],
|
||||||
|
symbol_contexts: List[Context],
|
||||||
|
user_requirements: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Write and print tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
write_and_print_tests(
|
||||||
|
root_path=self.repo_root,
|
||||||
|
func_to_test=self.func_to_test,
|
||||||
|
test_cases=cases,
|
||||||
|
reference_files=ref_files,
|
||||||
|
symbol_contexts=symbol_contexts,
|
||||||
|
user_requirements=user_requirements,
|
||||||
|
chat_language=self.tui_lang.chat_language,
|
||||||
|
)
|
Loading…
x
Reference in New Issue
Block a user