import json
from functools import partial
from typing import List, Optional

from devchat.llm.openai import chat_completion_no_stream_return_json

from find_context import Context
from llm_conf import (
    CONTEXT_SIZE,
    DEFAULT_CONTEXT_SIZE,
    DEFAULT_ENCODING,
    USE_USER_MODEL,
    USER_LLM_MODEL,
)
from model import FuncToTest, TokenBudgetExceededException
from openai_util import create_chat_completion_content
from prompts import PROPOSE_TEST_PROMPT
from tools.tiktoken_util import get_encoding

MODEL = USER_LLM_MODEL if USE_USER_MODEL else "gpt-4-turbo-preview"  # "gpt-3.5-turbo"
ENCODING = (
    get_encoding(DEFAULT_ENCODING)  # Use default encoding as an approximation
    if USE_USER_MODEL
    else get_encoding("cl100k_base")
)
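# Reserve ~5% of the model's context window as headroom (e.g. for the reply).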
TOKEN_BUDGET = int(CONTEXT_SIZE.get(MODEL, DEFAULT_CONTEXT_SIZE) * 0.95)


def _mk_user_msg(
    user_prompt: str,
    func_to_test: FuncToTest,
    contexts: List[Context],
    chat_language: str,
) -> str:
    """
    Create a user message to be sent to the model within the token budget.
    """
    func_content = f"function code\n```\n{func_to_test.func_content}\n```\n"
    class_content = ""
    if func_to_test.container_content is not None:
        class_content = f"class code\n```\n{func_to_test.container_content}\n```\n"

    context_content = ""
    if contexts:
        context_content = "\n\nrelevant context\n\n"
        context_content += "\n\n".join([str(c) for c in contexts])
        context_content += "\n\n"

    # Prepare a list of user messages to fit the token budget
    # by adjusting the relevant content
    relevant_content_fmt = partial(
        PROPOSE_TEST_PROMPT.format,
        user_prompt=user_prompt,
        function_name=func_to_test.func_name,
        file_path=func_to_test.file_path,
        chat_language=chat_language,
    )
    # 0. func content & class content & context content
    msg_0 = relevant_content_fmt(
        relevant_content="\n".join([func_content, class_content, context_content]),
    )
    # 1. func content & class content
    msg_1 = relevant_content_fmt(
        relevant_content="\n".join([func_content, class_content]),
    )
    # 2. func content only
    msg_2 = relevant_content_fmt(
        relevant_content=func_content,
    )

    prioritized_msgs = [msg_0, msg_1, msg_2]
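    # Use the first (most complete) message that fits within the budget.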
    for msg in prioritized_msgs:
        token_count = len(ENCODING.encode(msg, disallowed_special=()))
        if token_count <= TOKEN_BUDGET:
            return msg

    # Even the function content alone exceeds the token budget
    raise TokenBudgetExceededException(
        f"Token budget exceeded while proposing test cases for <{func_to_test}>. "
        f"({token_count}/{TOKEN_BUDGET})"
    )


def propose_test(
    user_prompt: str,
    func_to_test: FuncToTest,
    contexts: Optional[List[Context]] = None,
    chat_language: str = "English",
) -> List[str]:
    """Propose test cases for a specified function based on a user prompt.

    Args:
        user_prompt (str): The prompt or description for which test cases need to be generated.
        func_to_test (FuncToTest): The target function to generate test cases for,
            including its file path and source content.
        contexts (Optional[List[Context]]): Relevant code contexts to include in the
            prompt when the token budget allows.
        chat_language (str): The natural language to use for the test case descriptions.

    Returns:
        List[str]: A list of test case descriptions.
    """
    contexts = contexts or []
    user_msg = _mk_user_msg(
        user_prompt=user_prompt,
        func_to_test=func_to_test,
        contexts=contexts,
        chat_language=chat_language,
    )

    json_res = {}
    if USE_USER_MODEL:
        # Use the wrapped API parameters
        json_res = (
            chat_completion_no_stream_return_json(
                messages=[{"role": "user", "content": user_msg}],
                llm_config={
                    "model": MODEL,
                    "temperature": 0.1,
                },
            )
            or {}
        )
    else:
        # Use the OpenAI API parameters
        content = create_chat_completion_content(
            model=MODEL,
            messages=[{"role": "user", "content": user_msg}],
            response_format={"type": "json_object"},
            temperature=0.1,
        )
        json_res = json.loads(content)
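    # Expected response shape, as parsed below:
    #   {"test_cases": [{"description": "...", "category": "..."}, ...]}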
    cases = json_res.get("test_cases", [])

    descriptions = []
    for case in cases:
        description = case.get("description", None)
        category = case.get("category", None)
        if description:
            if category:
                descriptions.append(category + ": " + description)
            else:
                descriptions.append(description)

    return descriptions
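

# Illustrative usage sketch (not part of the module; the keyword-argument
# constructor for FuncToTest is an assumption and may differ in the codebase):
#
#     func = FuncToTest(
#         func_name="add",
#         file_path="/abs/path/to/calc.py",
#         func_content="def add(a, b):\n    return a + b",
#     )
#     for desc in propose_test(
#         user_prompt="Propose tests for the add function",
#         func_to_test=func,
#     ):
#         print(desc)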