Merge pull request #110 from devchat-ai/improve-interim-output

Improve interim output to print model response time
This commit is contained in:
boob.yang 2024-05-19 08:06:00 +00:00 committed by GitHub
commit 2f3a423741
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 356 additions and 271 deletions

View File

@ -15,6 +15,7 @@ from llm_conf import (
from openai_util import create_chat_completion_content from openai_util import create_chat_completion_content
from tools.directory_viewer import ListViewer from tools.directory_viewer import ListViewer
from tools.tiktoken_util import get_encoding from tools.tiktoken_util import get_encoding
from tools.time_util import print_exec_time
MODEL = USER_LLM_MODEL if USE_USER_MODEL else "gpt-4-turbo-preview" # "gpt-3.5-turbo" MODEL = USER_LLM_MODEL if USE_USER_MODEL else "gpt-4-turbo-preview" # "gpt-3.5-turbo"
ENCODING = ( ENCODING = (
@ -89,6 +90,7 @@ class RelevantFileFinder(DirectoryStructureBase):
return message return message
@print_exec_time("Model response time")
def _find_relevant_files(self, objective: str, dir_structure_pages: List[str]) -> List[str]: def _find_relevant_files(self, objective: str, dir_structure_pages: List[str]) -> List[str]:
files: List[str] = [] files: List[str] = []
for dir_structure in dir_structure_pages: for dir_structure in dir_structure_pages:

View File

@ -12,6 +12,7 @@ from llm_conf import (
from model import FuncToTest from model import FuncToTest
from openai_util import create_chat_completion_content from openai_util import create_chat_completion_content
from tools.tiktoken_util import get_encoding from tools.tiktoken_util import get_encoding
from tools.time_util import print_exec_time
MODEL = USER_LLM_MODEL if USE_USER_MODEL else "gpt-4-turbo-preview" MODEL = USER_LLM_MODEL if USE_USER_MODEL else "gpt-4-turbo-preview"
ENCODING = ( ENCODING = (
@ -82,6 +83,7 @@ def _mk_user_msg(func_to_test: FuncToTest, contexts: List) -> str:
return msg return msg
@print_exec_time("Model response time")
def get_recommended_symbols( def get_recommended_symbols(
func_to_test: FuncToTest, known_context: Optional[List] = None func_to_test: FuncToTest, known_context: Optional[List] = None
) -> List[str]: ) -> List[str]:

View File

@ -1,295 +1,27 @@
# ruff: noqa: E402 # ruff: noqa: E402
import os import os
import sys import sys
from typing import Dict, List, Tuple
import click import click
import openai
sys.path.append(os.path.dirname(__file__)) sys.path.append(os.path.dirname(__file__))
from cache import LocalCache from cache import LocalCache
from find_context import (
Context,
Position,
Range,
find_symbol_context_by_static_analysis,
find_symbol_context_of_llm_recommendation,
)
from find_reference_tests import find_reference_tests
from i18n import TUILanguage, get_translation from i18n import TUILanguage, get_translation
from model import ( from model import (
FuncToTest, FuncToTest,
TokenBudgetExceededException, TokenBudgetExceededException,
UserCancelledException, UserCancelledException,
) )
from propose_test import propose_test from ut_workflow import UnitTestsWorkflow
from tools.file_util import retrieve_file_content
from write_tests import write_and_print_tests
from lib.chatmark import Checkbox, Form, Step, TextEditor from lib.chatmark import Step
from lib.ide_service import IDEService from lib.ide_service import IDEService
CHAT_WORKFLOW_DIR_PATH = [".chat", "workflows"] CHAT_WORKFLOW_DIR_PATH = [".chat", "workflows"]
class UnitTestsWorkflow:
def __init__(
self,
user_prompt: str,
func_to_test: FuncToTest,
repo_root: str,
tui_lang: TUILanguage,
local_cache: LocalCache,
):
self.user_prompt = user_prompt
self.func_to_test = func_to_test
self.repo_root = repo_root
self.tui_lang = tui_lang
self.local_cache = local_cache
def run(self):
"""
Run the workflow to generate unit tests.
"""
symbol_context = self.step1_find_symbol_context()
contexts = set()
for _, v in symbol_context.items():
contexts.update(v)
contexts = list(contexts)
cases, files = self.step2_propose_cases_and_reference_files(contexts)
res = self.step3_user_interaction(cases, files)
cases = res[0]
files = res[1]
requirements = res[2]
self.step4_print_test_summary(cases, files, requirements, contexts)
self.step5_write_and_print_tests(cases, files, contexts, requirements)
def step1_find_symbol_context(self) -> Dict[str, List[Context]]:
symbol_context = find_symbol_context_by_static_analysis(
self.func_to_test, self.tui_lang.chat_language
)
known_context_for_llm: List[Context] = []
if self.func_to_test.container_content is not None:
known_context_for_llm.append(
Context(
file_path=self.func_to_test.file_path,
content=self.func_to_test.container_content,
range=Range(
start=Position(line=self.func_to_test.container_start_line, character=0),
end=Position(line=self.func_to_test.container_end_line, character=0),
),
)
)
known_context_for_llm += list(
{item for sublist in list(symbol_context.values()) for item in sublist}
)
recommended_context = find_symbol_context_of_llm_recommendation(
self.func_to_test, known_context_for_llm
)
symbol_context.update(recommended_context)
return symbol_context
def step2_propose_cases_and_reference_files(
self,
contexts: List[Context],
) -> Tuple[List[str], List[str]]:
"""
Propose test cases and reference files for a specified function.
Return: (test_cases, reference_files)
"""
test_cases: List[str] = []
reference_files: List[str] = []
_i = get_translation(self.tui_lang)
msg = _i("Analyzing the function and current unit tests...")
with Step(msg):
test_cases = propose_test(
user_prompt=self.user_prompt,
func_to_test=self.func_to_test,
contexts=contexts,
chat_language=self.tui_lang.chat_language,
)
ref_files = find_reference_tests(
self.repo_root,
self.func_to_test.func_name,
self.func_to_test.file_path,
)
if ref_files:
# Only use the most relevant reference file currently
reference_files.append(ref_files[0])
return test_cases, reference_files
def step3_user_interaction(
self,
test_cases: List[str],
reference_files: List[str],
) -> Tuple[List[str], List[str], str]:
"""
Edit test cases and reference files by user.
Return:
- the updated cases
- valid reference files
- customized requirements(prompts)
"""
_i = get_translation(self.tui_lang)
checkbox = Checkbox(
options=test_cases,
title=_i("Select test cases to generate"),
)
case_editor = TextEditor(
text="",
title=_i(
"You can add more test cases here\n"
"(Multiple cases can be separated by line breaks)"
),
)
ref_editor = TextEditor(
text=reference_files[0] if reference_files else "",
title=_i("Edit reference test file\n(Multiple files can be separated by line breaks)"),
)
cached_requirements = self.local_cache.get("user_requirements") or ""
requirements_editor = TextEditor(
text=cached_requirements,
title=_i(
"Write your customized requirements(prompts) for tests here."
"\n(For example, what testing framework to use.)"
),
)
form = Form(components=[checkbox, case_editor, ref_editor, requirements_editor])
form.render()
# Check test cases
cases = [checkbox.options[idx] for idx in checkbox.selections]
user_cases = []
if case_editor.new_text:
user_cases = [c.strip() for c in case_editor.new_text.split("\n")]
user_cases = [c for c in user_cases if c]
cases.extend(user_cases)
# Check if any test case is selected
if not cases:
raise UserCancelledException(_i("No test case is selected. Quit generating tests."))
# Validate reference files
ref_files = [f.strip() for f in ref_editor.new_text.split("\n")]
valid_files = []
invalid_files = []
for ref_file in ref_files:
if not ref_file:
continue
try:
retrieve_file_content(file_path=ref_file, root_path=self.repo_root)
valid_files.append(ref_file)
except Exception:
invalid_files.append(ref_file)
# Get customized requirements
requirements: str = (
requirements_editor.new_text.strip() if requirements_editor.new_text else ""
)
self.local_cache.set("user_requirements", requirements)
return cases, valid_files, requirements
# Tuple[List[str], List[str], str]:
def step4_print_test_summary(
self,
cases: List[str],
valid_files: List[str],
requirements: str,
contexts: List[Context],
):
"""
Print the summary message in Step
"""
_i = get_translation(self.tui_lang)
title = _i("Will generate tests for the following cases.")
lines = []
lines.append(_i("\nTest cases:"))
width = len(str(len(cases)))
lines.extend([f"{(i+1):>{width}}. {c}" for i, c in enumerate(cases)])
if not valid_files:
lines.append(
_i(
"\nNo valid reference file is provided. "
"Will not use reference to generate tests."
)
)
else:
lines.append(_i("\nWill use the following reference files to generate tests."))
# lines.append(_i("\nValid reference files:"))
width = len(str(len(valid_files)))
lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(valid_files)])
# if invalid_files:
# lines.append(_i("\nInvalid files:"))
# width = len(str(len(invalid_files)))
# lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(invalid_files)])
lines.append(_i("\nCustomized requirements(prompts):"))
if requirements.strip():
lines.append(requirements)
else:
lines.append(_i("No customized requirements."))
if contexts:
lines.append(_i("\nAdditional context:"))
width = len(str(len(contexts)))
lines.extend(
[
f"{(i+1):>{width}}. {c.file_path}:{c.range.start.line+1}-{c.range.end.line+1}"
for i, c in enumerate(contexts)
]
)
with Step(title):
print("\n".join(lines), flush=True)
def step5_write_and_print_tests(
self,
cases: List[str],
ref_files: List[str],
symbol_contexts: List[Context],
user_requirements: str,
):
"""
Write and print tests.
"""
write_and_print_tests(
root_path=self.repo_root,
func_to_test=self.func_to_test,
test_cases=cases,
reference_files=ref_files,
symbol_contexts=symbol_contexts,
user_requirements=user_requirements,
chat_language=self.tui_lang.chat_language,
)
@click.command() @click.command()
@click.argument("input", required=True) @click.argument("input", required=True)
def main(input: str): def main(input: str):
@ -369,6 +101,11 @@ def main(input: str):
with Step(f"{e}"): with Step(f"{e}"):
pass pass
except (openai.APIConnectionError, openai.APITimeoutError) as e:
msg = "Model API connection error. Please try again later."
with Step(msg):
print(f"\n{e}\n", flush=True)
except Exception as e: except Exception as e:
raise e raise e

View File

@ -15,6 +15,7 @@ from model import FuncToTest, TokenBudgetExceededException
from openai_util import create_chat_completion_content from openai_util import create_chat_completion_content
from prompts import PROPOSE_TEST_PROMPT from prompts import PROPOSE_TEST_PROMPT
from tools.tiktoken_util import get_encoding from tools.tiktoken_util import get_encoding
from tools.time_util import print_exec_time
MODEL = USER_LLM_MODEL if USE_USER_MODEL else "gpt-4-turbo-preview" # "gpt-3.5-turbo" MODEL = USER_LLM_MODEL if USE_USER_MODEL else "gpt-4-turbo-preview" # "gpt-3.5-turbo"
ENCODING = ( ENCODING = (
@ -82,6 +83,7 @@ def _mk_user_msg(
) )
@print_exec_time("Model response time")
def propose_test( def propose_test(
user_prompt: str, user_prompt: str,
func_to_test: FuncToTest, func_to_test: FuncToTest,

View File

@ -0,0 +1,18 @@
import time
from functools import wraps
def print_exec_time(message: str):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
duration = end_time - start_time
print(f"{message} ({duration:.3f} s)", flush=True)
return result
return wrapper
return decorator

View File

@ -0,0 +1,324 @@
from typing import Dict, List, Tuple
from cache import LocalCache
from find_context import (
Context,
Position,
Range,
find_symbol_context_by_static_analysis,
find_symbol_context_of_llm_recommendation,
)
from find_reference_tests import find_reference_tests
from i18n import TUILanguage, get_translation
from model import (
FuncToTest,
UserCancelledException,
)
from propose_test import propose_test
from tools.file_util import retrieve_file_content
from write_tests import write_and_print_tests
from lib.chatmark import Checkbox, Form, Step, TextEditor
class UnitTestsWorkflow:
def __init__(
self,
user_prompt: str,
func_to_test: FuncToTest,
repo_root: str,
tui_lang: TUILanguage,
local_cache: LocalCache,
):
self.user_prompt = user_prompt
self.func_to_test = func_to_test
self.repo_root = repo_root
self.tui_lang = tui_lang
self.local_cache = local_cache
def run(self):
"""
Run the workflow to generate unit tests.
"""
_i = get_translation(self.tui_lang)
msg = _i("Analyzing the function and current unit tests...")
with Step(msg):
print("\n- Analyzing context for the function...", flush=True)
symbol_context = self.step_1_find_symbol_context()
contexts = set()
for _, v in symbol_context.items():
contexts.update(v)
contexts = list(contexts)
print("- Finding reference files...", flush=True)
reference_files = self.step_2_find_reference_files()
print("- Proposing test cases...", flush=True)
cases = self.step_3_propose_cases(contexts)
res = self.step_4_user_interaction(cases, reference_files)
cases = res[0]
files = res[1]
requirements = res[2]
self.step_5_print_test_summary(cases, files, requirements, contexts)
self.step_6_write_and_print_tests(cases, files, contexts, requirements)
def step_1_find_symbol_context(self) -> Dict[str, List[Context]]:
symbol_context = find_symbol_context_by_static_analysis(
self.func_to_test, self.tui_lang.chat_language
)
known_context_for_llm: List[Context] = []
if self.func_to_test.container_content is not None:
known_context_for_llm.append(
Context(
file_path=self.func_to_test.file_path,
content=self.func_to_test.container_content,
range=Range(
start=Position(line=self.func_to_test.container_start_line, character=0),
end=Position(line=self.func_to_test.container_end_line, character=0),
),
)
)
known_context_for_llm += list(
{item for sublist in list(symbol_context.values()) for item in sublist}
)
recommended_context = find_symbol_context_of_llm_recommendation(
self.func_to_test, known_context_for_llm
)
symbol_context.update(recommended_context)
return symbol_context
def step_2_find_reference_files(self) -> List[str]:
"""
Find reference files for the specified function.
"""
reference_files: List[str] = []
ref_tests = find_reference_tests(
self.repo_root,
self.func_to_test.func_name,
self.func_to_test.file_path,
)
if ref_tests:
# Only use the most relevant reference test file currently
reference_files.append(ref_tests[0])
return reference_files
def step_3_propose_cases(
self,
contexts: List[Context],
) -> List[str]:
"""
Propose test cases for the specified function.
"""
test_cases: List[str] = []
test_cases = propose_test(
user_prompt=self.user_prompt,
func_to_test=self.func_to_test,
contexts=contexts,
chat_language=self.tui_lang.chat_language,
)
return test_cases
def step2_propose_cases_and_reference_files(
self,
contexts: List[Context],
) -> Tuple[List[str], List[str]]:
"""
Propose test cases and reference files for a specified function.
Return: (test_cases, reference_files)
"""
test_cases: List[str] = []
reference_files: List[str] = []
_i = get_translation(self.tui_lang)
msg = _i("Analyzing the function and current unit tests...")
with Step(msg):
test_cases = propose_test(
user_prompt=self.user_prompt,
func_to_test=self.func_to_test,
contexts=contexts,
chat_language=self.tui_lang.chat_language,
)
ref_files = find_reference_tests(
self.repo_root,
self.func_to_test.func_name,
self.func_to_test.file_path,
)
if ref_files:
# Only use the most relevant reference file currently
reference_files.append(ref_files[0])
return test_cases, reference_files
def step_4_user_interaction(
self,
test_cases: List[str],
reference_files: List[str],
) -> Tuple[List[str], List[str], str]:
"""
Edit test cases and reference files by user.
Return:
- the updated cases
- valid reference files
- customized requirements(prompts)
"""
_i = get_translation(self.tui_lang)
checkbox = Checkbox(
options=test_cases,
title=_i("Select test cases to generate"),
)
case_editor = TextEditor(
text="",
title=_i(
"You can add more test cases here\n"
"(Multiple cases can be separated by line breaks)"
),
)
ref_editor = TextEditor(
text=reference_files[0] if reference_files else "",
title=_i("Edit reference test file\n(Multiple files can be separated by line breaks)"),
)
cached_requirements = self.local_cache.get("user_requirements") or ""
requirements_editor = TextEditor(
text=cached_requirements,
title=_i(
"Write your customized requirements(prompts) for tests here."
"\n(For example, what testing framework to use.)"
),
)
form = Form(components=[checkbox, case_editor, ref_editor, requirements_editor])
form.render()
# Check test cases
cases = [checkbox.options[idx] for idx in checkbox.selections]
user_cases = []
if case_editor.new_text:
user_cases = [c.strip() for c in case_editor.new_text.split("\n")]
user_cases = [c for c in user_cases if c]
cases.extend(user_cases)
# Check if any test case is selected
if not cases:
raise UserCancelledException(_i("No test case is selected. Quit generating tests."))
# Validate reference files
ref_files = [f.strip() for f in ref_editor.new_text.split("\n")]
valid_files = []
invalid_files = []
for ref_file in ref_files:
if not ref_file:
continue
try:
retrieve_file_content(file_path=ref_file, root_path=self.repo_root)
valid_files.append(ref_file)
except Exception:
invalid_files.append(ref_file)
# Get customized requirements
requirements: str = (
requirements_editor.new_text.strip() if requirements_editor.new_text else ""
)
self.local_cache.set("user_requirements", requirements)
return cases, valid_files, requirements
# Tuple[List[str], List[str], str]:
def step_5_print_test_summary(
self,
cases: List[str],
valid_files: List[str],
requirements: str,
contexts: List[Context],
):
"""
Print the summary message in Step
"""
_i = get_translation(self.tui_lang)
title = _i("Will generate tests for the following cases.")
lines = []
lines.append(_i("\nTest cases:"))
width = len(str(len(cases)))
lines.extend([f"{(i+1):>{width}}. {c}" for i, c in enumerate(cases)])
if not valid_files:
lines.append(
_i(
"\nNo valid reference file is provided. "
"Will not use reference to generate tests."
)
)
else:
lines.append(_i("\nWill use the following reference files to generate tests."))
# lines.append(_i("\nValid reference files:"))
width = len(str(len(valid_files)))
lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(valid_files)])
# if invalid_files:
# lines.append(_i("\nInvalid files:"))
# width = len(str(len(invalid_files)))
# lines.extend([f"{(i+1):>{width}}. {f}" for i, f in enumerate(invalid_files)])
lines.append(_i("\nCustomized requirements(prompts):"))
if requirements.strip():
lines.append(requirements)
else:
lines.append(_i("No customized requirements."))
if contexts:
lines.append(_i("\nAdditional context:"))
width = len(str(len(contexts)))
lines.extend(
[
f"{(i+1):>{width}}. {c.file_path}:{c.range.start.line+1}-{c.range.end.line+1}"
for i, c in enumerate(contexts)
]
)
with Step(title):
print("\n".join(lines), flush=True)
def step_6_write_and_print_tests(
self,
cases: List[str],
ref_files: List[str],
symbol_contexts: List[Context],
user_requirements: str,
):
"""
Write and print tests.
"""
write_and_print_tests(
root_path=self.repo_root,
func_to_test=self.func_to_test,
test_cases=cases,
reference_files=ref_files,
symbol_contexts=symbol_contexts,
user_requirements=user_requirements,
chat_language=self.tui_lang.chat_language,
)