from pathlib import Path
from typing import List, Optional

import openai
import tiktoken

from chat.ask_codebase.tools.retrieve_file_content import retrieve_file_content
from chat.util.openai_util import create_chat_completion

MODEL = "gpt-4-1106-preview"
WRITE_TESTS_PROMPT = """
|
|
|
|
You're an advanced AI test case generator.
|
|
|
|
Given a target function, some reference test code, and a list of specific test case descriptions, write the test cases in code.
|
|
|
|
Each test case should be self-contained and executable.
|
|
|
|
Use the content of the reference test cases as a model, ensuring you use the same test framework and mock library,
|
|
|
|
and apply comparable mocking strategies and best practices.
|
|
|
|
|
|
|
|
|
|
|
|
The target function is {function_name}, located in the file {file_path}.
|
|
|
|
Here's the source code of the function:
|
|
|
|
```
|
|
|
|
{function_str}
|
|
|
|
```
|
|
|
|
Content of reference test code:
|
|
|
|
|
|
|
|
{reference_tests_str}
|
|
|
|
|
|
|
|
Here's the list of test case descriptions:
|
|
|
|
|
|
|
|
{test_cases_str}
|
|
|
|
|
2023-12-24 18:38:59 +08:00
|
|
|
Answer in the following format in {chat_language}:
|
2023-12-18 20:46:16 +08:00
|
|
|
|
|
|
|
Test Case 1. <original test case 1 description>
|
|
|
|
|
|
|
|
<test case 1 code>
|
|
|
|
|
|
|
|
Test Case 2. <original test case 2 description>
|
|
|
|
|
|
|
|
<test case 2 code>
|
|
|
|
"""


def _mk_write_tests_msg(
    root_path: str,
    function_name: str,
    function_content: str,
    file_path: str,
    test_cases: List[str],
    chat_language: str,
    reference_files: Optional[List[str]] = None,
) -> Optional[str]:
    """Build the user message for test generation, or None if it exceeds the token budget."""
    encoding: tiktoken.Encoding = tiktoken.encoding_for_model(MODEL)

    # Cost saving: keep the prompt within ~90% of a 16K-token budget,
    # even though the model's context window is larger.
    token_budget = 16000 * 0.9

    test_cases_str = ""
    for i, test_case in enumerate(test_cases, 1):
        test_cases_str += f"{i}. {test_case}\n"

    if reference_files:
        reference_tests_str = ""
        for i, fp in enumerate(reference_files, 1):
            reference_test_content = retrieve_file_content(
                str(Path(root_path) / fp), root_path
            )
            reference_tests_str += f"{i}. {fp}\n\n"
            reference_tests_str += f"```\n{reference_test_content}\n```\n"
    else:
        reference_tests_str = "No reference test cases provided."

    user_msg = WRITE_TESTS_PROMPT.format(
        function_name=function_name,
        file_path=file_path,
        function_str=function_content,
        test_cases_str=test_cases_str,
        chat_language=chat_language,
        reference_tests_str=reference_tests_str,
    )

    tokens = len(encoding.encode(user_msg))
    if tokens > token_budget:
        # TODO: decide how to handle exceeding the token budget;
        # for now the caller receives None and reports the failure.
        return None

    return user_msg
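

# A possible direction for the budget TODO above -- a sketch only, not part
# of the original module, and `_truncate_to_budget` is a hypothetical helper:
# rather than returning None when the prompt is too long, the caller could
# trim the (usually largest) reference-tests section until the message fits.
def _truncate_to_budget(text: str, max_tokens: int, encoding: tiktoken.Encoding) -> str:
    # Encode, clip to the token budget, and decode back to a string.
    tokens = encoding.encode(text)
    if len(tokens) <= max_tokens:
        return text
    return encoding.decode(tokens[:max_tokens])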


def write_and_print_tests(
    root_path: str,
    function_name: str,
    function_content: str,
    file_path: str,
    test_cases: List[str],
    reference_files: Optional[List[str]] = None,
    chat_language: str = "English",
    stream: Optional[bool] = False,
) -> str | None:
    user_msg = _mk_write_tests_msg(
        root_path=root_path,
        function_name=function_name,
        function_content=function_content,
        file_path=file_path,
        test_cases=test_cases,
        reference_files=reference_files,
        chat_language=chat_language,
    )
    if not user_msg:
        # TODO: decide how to handle exceeding the token budget.
        print("Token budget exceeded while generating test cases.", flush=True)
        return None

    if not stream:
        print(
            "\n\n```Step\n# Generating tests...\n",
            flush=True,
        )
        response = create_chat_completion(
            model=MODEL,
            messages=[{"role": "user", "content": user_msg}],
            temperature=0.1,
        )
        print("Complete Generating.\n```", flush=True)

        content = response.choices[0].message.content
        print(content, flush=True)
        return content

    else:
        client = openai.OpenAI()

        chunks = client.chat.completions.create(
            model=MODEL,
            messages=[{"role": "user", "content": user_msg}],
            temperature=0.1,
            stream=True,
        )

        content = ""
        for chunk in chunks:
            # Skip role/stop chunks, whose delta carries no text.
            delta = chunk.choices[0].delta.content if chunk.choices else None
            if delta:
                content += delta
                print(delta, flush=True, end="")
        return content
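

# Illustrative usage (a sketch, not part of the original module). The paths,
# function, and test descriptions below are made up, and an OPENAI_API_KEY
# plus the devchat environment are required for the call to actually run:
#
#     write_and_print_tests(
#         root_path="/path/to/repo",
#         function_name="add",
#         function_content="def add(a, b):\n    return a + b\n",
#         file_path="src/calc.py",
#         test_cases=[
#             "returns the sum of two positive integers",
#             "raises TypeError when given a string operand",
#         ],
#         reference_files=["tests/test_calc.py"],
#         chat_language="English",
#         stream=True,
#     )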