2024-01-02 19:37:37 +08:00
|
|
|
from typing import Optional
|
2023-12-18 20:46:16 +08:00
|
|
|
import os
|
|
|
|
import sys
|
|
|
|
import click
|
|
|
|
|
|
|
|
from propose_test import propose_test
|
|
|
|
from find_reference_tests import find_reference_tests
|
2023-12-24 17:21:31 +08:00
|
|
|
from write_tests import write_and_print_tests
|
2023-12-24 16:58:09 +08:00
|
|
|
from i18n import TUILanguage, get_translation
|
2023-12-18 20:46:16 +08:00
|
|
|
|
2023-12-28 15:56:15 +08:00
|
|
|
from model import FuncToTest, TokenBudgetExceededException
|
2023-12-18 20:46:16 +08:00
|
|
|
|
|
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs"))
|
|
|
|
|
2024-01-02 19:37:37 +08:00
|
|
|
from chatmark import Checkbox, TextEditor, Form # noqa: E402
|
2023-12-24 18:38:59 +08:00
|
|
|
from ide_services import ide_language # noqa: E402
|
|
|
|
|
2023-12-18 20:46:16 +08:00
|
|
|
|
2023-12-28 15:56:15 +08:00
|
|
|
def generate_unit_tests_workflow(
    user_prompt: str,
    func_to_test: FuncToTest,
    tui_lang: TUILanguage,
):
    """
    Run the interactive unit-test generation flow for a single function.

    Steps, in order:
      1. Print a "Step" block announcing the analysis.
      2. Obtain candidate test cases from ``propose_test``.
      3. Collect reference test file candidates with ``find_reference_tests``;
         the first one (if any) pre-fills the reference-file editor.
      4. Render a form where the user picks test cases and edits the
         reference test file path.
      5. Pass the picked cases (and the edited reference file, if non-empty)
         to ``write_and_print_tests``.

    Args:
        user_prompt: Free-form instructions from the user.
        func_to_test: Description of the target function and its location.
        tui_lang: UI language; its ``chat_language`` is forwarded to both
            ``propose_test`` and ``write_and_print_tests``.
    """
    # The repository root is taken to be the current working directory.
    root = os.getcwd()
    translate = get_translation(tui_lang)

    step_title = translate("Analyzing the function and current unit tests...")
    print(f"\n\n```Step\n# {step_title}\n```\n", flush=True)

    proposed = propose_test(
        user_prompt=user_prompt,
        func_to_test=func_to_test,
        chat_language=tui_lang.chat_language,
    )

    candidates = find_reference_tests(root, func_to_test.func_name, func_to_test.file_path)
    default_ref = ""
    if candidates:
        default_ref = candidates[0]

    # Let the user choose which proposed cases to generate and adjust the
    # reference test file before generation starts.
    case_picker = Checkbox(
        options=proposed,
        title=translate("Select test cases to generate"),
    )
    ref_editor = TextEditor(text=default_ref, title=translate("Edit reference test file"))
    Form(components=[case_picker, ref_editor]).render()

    chosen = [case_picker.options[i] for i in case_picker.selections]
    edited_ref = ref_editor.new_text
    # An emptied editor means "no reference file".
    refs = [edited_ref] if edited_ref else None

    write_and_print_tests(
        root_path=root,
        func_to_test=func_to_test,
        test_cases=chosen,
        reference_files=refs,
        chat_language=tui_lang.chat_language,
    )
|
|
|
|
|
|
|
|
|
2023-12-28 15:56:15 +08:00
|
|
|
@click.command()
@click.argument("user_prompt", required=True)
@click.option("-fn", "--func_name", required=True, type=str)
@click.option("-fp", "--file_path", required=True, type=str)
@click.option("-fsl", "--func_start_line", required=True, type=int)
@click.option("-fel", "--func_end_line", required=True, type=int)
# An optional container_name is not well supported by the Shortcut button's
# variables, so it is not exposed as an option.
# @click.option("-cn", "--container_name", required=False, type=str)
@click.option("-csl", "--container_start_line", required=False, type=int)
@click.option("-cel", "--container_end_line", required=False, type=int)
def main(
    user_prompt: str,
    func_name: str,
    file_path: str,
    func_start_line: int,  # 0-based, inclusive
    func_end_line: int,  # 0-based, inclusive
    container_start_line: Optional[int],  # 0-based, inclusive
    container_end_line: Optional[int],  # 0-based, inclusive
):
    """Generate unit tests for FUNC_NAME in FILE_PATH following USER_PROMPT."""
    repo_root = os.getcwd()

    # Follow the IDE's language setting. A hard-coded `from_str("zh")` used to
    # overwrite this value, forcing Chinese output for every user.
    ide_lang = ide_language()
    tui_lang = TUILanguage.from_str(ide_lang)
    _i = get_translation(tui_lang)

    # Use relative path in inner logic
    if os.path.isabs(file_path):
        file_path = os.path.relpath(file_path, repo_root)

    func_to_test = FuncToTest(
        func_name=func_name,
        repo_root=repo_root,
        file_path=file_path,
        func_start_line=func_start_line,
        func_end_line=func_end_line,
        container_start_line=container_start_line,
        container_end_line=container_end_line,
    )

    try:
        generate_unit_tests_workflow(user_prompt, func_to_test, tui_lang)
    except TokenBudgetExceededException as e:
        # Report an oversized function as a chat "Step" block and exit normally.
        # NOTE(review): "funciton" is misspelled, but this literal is what _i
        # receives and presumably the translation-table key; fix both together.
        msg = _i("The funciton is too large for AI to handle.")
        info = f"\n\n```Step\n# {msg}\n\n{e}\n```\n"
        print(info, flush=True)
    except Exception as e:
        # Top-level boundary: report the error on stderr and exit non-zero.
        print(e, file=sys.stderr, flush=True)
        sys.exit(1)


if __name__ == "__main__":
    main()
|