Compare commits

...

22 Commits

Author SHA1 Message Date
boob.yang
accf557fe6
Merge pull request #48 from devchat-ai/update_workflows
Update workflows
2024-01-21 17:31:42 +08:00
bobo
4d80089598 add rewrite workflows 2024-01-21 17:29:46 +08:00
bobo
b82ebd8af1 remove ask-code 2024-01-21 17:20:27 +08:00
boob.yang
09daeef589
Merge pull request #47 from devchat-ai/fix_minimax_llm_api
update minimax llm api
2024-01-21 17:18:31 +08:00
bobo
e9a28f6d9a update minimax llm api 2024-01-21 17:11:56 +08:00
kagami
ae12062d34
Merge pull request #45 from devchat-ai/chinese-prompt
Chinese prompt for minimax
2024-01-18 22:27:20 +08:00
kagami
7be108b874 Force TUI language to chinese 2024-01-18 22:26:27 +08:00
kagami
d49b2f4732 Use chinese prompt 2024-01-18 22:26:27 +08:00
boob.yang
5ced3dcfa8
Merge pull request #44 from devchat-ai/fix_path_in_unit_tests
Fix path in unit tests
2024-01-18 19:42:44 +08:00
bobo
9138e1569e fix lint error 2024-01-18 19:40:33 +08:00
bobo
17f4aba1b4 fix: Add cwd to sys.path in unit tests
- Append the current working directory to sys.path for import resolution
- Ensure unit tests function correctly with module imports
- Improve reliability of running unit tests in different environments
2024-01-18 19:37:54 +08:00
boob.yang
b7ca8127d1
Merge pull request #43 from devchat-ai/minimax-issue
Fix minimax request issues
2024-01-18 15:22:00 +08:00
kagami
d3008d5237 Fix minimax request issues 2024-01-18 14:44:46 +08:00
kagami
7c1252aee2
Merge pull request #42 from devchat-ai/unit-tests-on-minimax
Adapt /unit_tests to use minimax
2024-01-18 14:43:54 +08:00
kagami
da51cbfc5f Adapt /unit_tests to use minimax 2024-01-18 14:38:18 +08:00
boob.yang
8dc6c5ffe9
Merge pull request #41 from devchat-ai/update_minimax_api
update minimax llm api
2024-01-18 10:34:26 +08:00
bobo
7652f73537 update minimax llm api 2024-01-18 10:08:14 +08:00
kagami
aff4142d24
Merge pull request #40 from devchat-ai/support_minimax_llm
support minimax llm
2024-01-17 22:53:47 +08:00
bobo
12e5ce6a57 fix lint error 2024-01-17 22:33:28 +08:00
bobo
2daf4ed765 support minimax llm 2024-01-17 22:33:28 +08:00
boob.yang
94249cd6d1
Merge pull request #39 from devchat-ai/revert-38-revert-37-switch-env
Revert "Revert "/unit_tests switch to devchat env""
2024-01-17 22:32:37 +08:00
kagami
bc5f4e23a9
Revert "Revert "/unit_tests switch to devchat env"" 2024-01-17 17:56:29 +08:00
20 changed files with 615 additions and 62 deletions

View File

@@ -0,0 +1,234 @@
import json
import os
import sys
import time

import requests


class StreamIterWrapper:
    def __init__(self, response, is_private=True):
        self.response = response
        self.create_time = int(time.time())
        self.line_iterator = response.iter_lines()
        self.is_private = is_private
        self.stop = False

    def __iter__(self):
        return self

    def __next__(self):
        try:
            if self.stop:
                raise StopIteration
            response_line = next(self.line_iterator)
            if response_line == b"":
                return self.__next__()
            if response_line == b"\n":
                return self.__next__()

            response_line = response_line.replace(b"data: ", b"")
            response_result = json.loads(response_line.decode("utf-8"))
            if self.is_private:
                if "finish" in response_result and response_result["finish"] == True:
                    self.stop = True
                if "err" in response_result and response_result["err"]:
                    raise ValueError(f"minimax api response error: {response_result['err']}")
            if not self.is_private:
                if response_result["choices"][0].get("finish_reason", None):
                    raise StopIteration

            data = {}
            if self.is_private:
                data = json.loads(response_result["data"])

            stream_response = {
                "id": f"minimax_{self.create_time}",
                "created": self.create_time,
                "object": "chat.completion.chunk",
                "model": response_result.get("model", "abab5.5-chat"),
                "choices": [
                    {
                        "index": 0,
                        "finish_reason": "stop",
                        "delta": {
                            "role": "assistant",
                            "content": response_result["choices"][0]["messages"][0]["text"]
                            if not self.is_private
                            else data.get("text", ""),
                        },
                    }
                ],
                "usage": {"prompt_tokens": 10, "completion_tokens": 100},
            }
            return stream_response
        except StopIteration as exc:  # If there is no more event
            raise StopIteration from exc
        except Exception as err:
            print("Exception:", err.__class__.__name__, err, file=sys.stderr, end="\n\n")
            raise StopIteration from err


def chat_completion(messages, llm_config):
    url = _make_api_url()
    headers = _make_header()
    if _is_private_llm():
        payload = _make_private_payload(messages, llm_config)
    else:
        payload = _make_public_payload(messages, llm_config)

    response = requests.post(url, headers=headers, json=payload)
    response_json = json.loads(response.text)
    if not response_json.get("texts", []):
        raise ValueError(f"minimax api response error: {response_json}")
    return {"content": response_json["texts"][0]}


def stream_chat_completion(messages, llm_config):
    url = _make_api_url()
    headers = _make_header()
    if _is_private_llm():
        payload = _make_private_payload(messages, llm_config, True)
    else:
        payload = _make_public_payload(messages, llm_config, True)

    response = requests.post(url, headers=headers, json=payload)
    streamIters = StreamIterWrapper(response, _is_private_llm())
    return streamIters


def _is_private_llm():
    api_base_url = os.environ.get("OPENAI_API_BASE", "")
    return not api_base_url.startswith("https://api.minimax.chat")


def _make_api_url():
    api_base_url = os.environ.get("OPENAI_API_BASE", None)
    if not api_base_url:
        raise ValueError("minimax api url is not set")

    if api_base_url.startswith("https://api.minimax.chat"):
        if api_base_url.endswith("/"):
            api_base_url = api_base_url[:-1]
        if not api_base_url.endswith("/v1"):
            api_base_url = api_base_url + "/v1"
        api_base_url += "/text/chatcompletion_pro"

        api_key = os.environ.get("OPENAI_API_KEY", None)
        if not api_key:
            raise ValueError("minimax api key is not set")
        group_id = api_key.split("##")[0]
        api_base_url += f"?GroupId={group_id}"
        return api_base_url
    else:
        if api_base_url.endswith("/"):
            api_base_url = api_base_url[:-1]
        if not api_base_url.endswith("/interact"):
            api_base_url = api_base_url + "/interact"
        return api_base_url


def _make_api_key():
    if _is_private_llm():
        return ""
    api_key = os.environ.get("OPENAI_API_KEY", None)
    return api_key.split("##")[1]


def _make_header():
    api_key = _make_api_key()
    return {
        **({"Authorization": f"Bearer {api_key}"} if not _is_private_llm() else {}),
        "Content-Type": "application/json",
    }


def _to_private_messages(messages):
    new_messages = []
    for message in messages:
        if message["role"] == "user":
            new_messages.append({"role": "user", "name": "user", "text": message["content"]})
        else:
            new_messages.append({"role": "ai", "name": "ai", "text": message["content"]})
    new_messages.append({"role": "ai", "name": "ai", "text": ""})
    return new_messages


def _make_private_payload(messages, llm_config, stream=False):
    return {
        "data": _to_private_messages(messages),
        "model_control": {
            "system_data": [
                {
                    "role": "system",
                    "ai_setting": "ai",
                    "text": "你是minimax编码助理擅长编写代码编写注释编写测试用例并且很注重编码的规范性。",
                },
            ],
            # "alpha_frequency": 128,
            # "alpha_frequency_src": 1,
            # "alpha_presence": 0,
            # "alpha_presence_src": 0,
            # "block_ngram": 0,
            # "clean_init_no_penalty_list": True,
            # "context_block_ngram": 0,
            # "factual_topp": False,
            # "lamda_decay": 1,
            # "length_penalty": 1,
            # "no_penalty_list": ",",
            # "omega_bound": 0,
            # "repeat_filter": False,
            # "repeat_sampling": 1,
            # "skip_text_mask": True,
            "tokens_to_generate": llm_config.get("max_tokens", 2048),
            # "sampler_type": "nucleus",
            "beam_width": 1,
            # "delimiter": "\n",
            # "min_length": 0,
            # "skip_info_mask": True,
            "stop_sequence": [],
            # "top_p": 0.95,
            "temperature": llm_config.get("temperature", 0.95),
        },
        "stream": stream,
    }


def _to_public_messages(messages):
    new_messages = []
    for message in messages:
        if message["role"] == "user":
            new_messages.append(
                {"sender_type": "USER", "sender_name": "USER", "text": message["content"]}
            )
        else:
            new_messages.append(
                {"sender_type": "BOT", "sender_name": "ai", "text": message["content"]}
            )
    return new_messages


def _make_public_payload(messages, llm_config, stream=False):
    response = {
        "model": "abab5.5-chat",
        "tokens_to_generate": llm_config.get("max_tokens", 2048),
        "temperature": llm_config.get("temperature", 0.1),
        # "top_p": 0.9,
        "reply_constraints": {"sender_type": "BOT", "sender_name": "ai"},
        "sample_messages": [],
        "plugins": [],
        "messages": _to_public_messages(messages),
        "bot_setting": [
            {
                "bot_name": "ai",
                "content": (
                    "MM智能助理是一款由MiniMax自研的"
                    "没有调用其他产品的接口的大型语言模型。"
                    "MiniMax是一家中国科技公司一直致力于进行大模型相关的研究。"
                ),
            }
        ],
        "stream": stream,
    }
    return response
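For orientation, a minimal usage sketch of the module above, assuming the public MiniMax endpoint. The environment values are placeholders (the key format <group_id>##<api_key> is what _make_api_url and _make_api_key split on), and the module name minimax_chat is taken from the import added in the next file's diff; treat this as illustrative, not part of the change:

import os

os.environ["OPENAI_API_BASE"] = "https://api.minimax.chat"    # public-endpoint branch
os.environ["OPENAI_API_KEY"] = "my_group_id##my_minimax_key"  # placeholder credentials

from minimax_chat import stream_chat_completion

messages = [{"role": "user", "content": "Write a bubble sort in Python."}]
llm_config = {"model": "abab5.5-chat", "max_tokens": 1024, "temperature": 0.1}

# StreamIterWrapper yields OpenAI-style chat.completion.chunk dicts until the
# MiniMax response reports a finish_reason, so callers can iterate as usual.
for chunk in stream_chat_completion(messages, llm_config):
    print(chunk["choices"][0]["delta"]["content"], end="")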

View File

@@ -7,8 +7,10 @@ import sys
 import openai
 
 sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.append(os.path.dirname(__file__))
 
 from ide_services.services import log_warn
+from minimax_chat import stream_chat_completion
 
 
 def _try_remove_markdown_block_flag(content):
@@ -46,18 +48,23 @@ def chat_completion_stream(messages, llm_config, error_out: bool = True, stream_
     """
     for try_times in range(3):
         try:
-            client = openai.OpenAI(
-                api_key=os.environ.get("OPENAI_API_KEY", None),
-                base_url=os.environ.get("OPENAI_API_BASE", None),
-            )
-            llm_config["stream"] = True
-            llm_config["timeout"] = 8
-            response = client.chat.completions.create(messages=messages, **llm_config)
+            if llm_config.get("model", "").startswith("abab"):
+                response = stream_chat_completion(messages, llm_config)
+            else:
+                client = openai.OpenAI(
+                    api_key=os.environ.get("OPENAI_API_KEY", None),
+                    base_url=os.environ.get("OPENAI_API_BASE", None),
+                )
+                llm_config["stream"] = True
+                llm_config["timeout"] = 8
+                response = client.chat.completions.create(messages=messages, **llm_config)
 
             response_result = {"content": None, "function_name": None, "parameters": ""}
             for chunk in response:  # pylint: disable=E1133
-                chunk = chunk.dict()
+                if not isinstance(chunk, dict):
+                    chunk = chunk.dict()
                 delta = chunk["choices"][0]["delta"]
                 if "tool_calls" in delta and delta["tool_calls"]:
                     tool_call = delta["tool_calls"][0]["function"]
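In effect, callers of chat_completion_stream are now dispatched by model name. A hypothetical caller, for illustration only (the rest of this file is not shown in the diff):

messages = [{"role": "user", "content": "Explain what this function does."}]

# Model names starting with "abab" (MiniMax) take the new stream_chat_completion path;
# every other model still goes through the OpenAI client as before.
chat_completion_stream(messages, {"model": "abab5.5-chat", "temperature": 0.1})
chat_completion_stream(messages, {"model": "gpt-3.5-turbo-1106", "temperature": 0.1})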

rewrite/command.yml Normal file (+5 lines)
View File

@@ -0,0 +1,5 @@
description: rewrite selected code.
hint: question
input: required
steps:
- run: $devchat_python $command_path/rewrite.py "$input"

View File

@@ -0,0 +1,3 @@
description: add doc comment for selected code
steps:
- run: $devchat_python $command_path/../rewrite.py "add doc comment"

View File

@@ -0,0 +1,4 @@
description: Optimizing variable and function names in code.
steps:
- run: $devchat_python $command_path/../rewrite.py "Refine internal variable and function names within the code to achieve concise and meaningful identifiers that comply with English naming conventions."

View File

@@ -0,0 +1,3 @@
description: improve the readability and conformity of a given string to English language conventions.
steps:
- run: $devchat_python $command_path/../rewrite.py "enhance the given string's readability and ensure it adheres to English linguistic standards"

rewrite/rewrite.py Normal file (+122 lines)
View File

@@ -0,0 +1,122 @@
import os
import sys
import re
import json

home = os.path.expanduser("~")
org_libs_path = os.path.join(home, ".chat", "workflows", "org", "libs")
sys_libs_path = os.path.join(home, ".chat", "workflows", "sys", "libs")
sys.path.append(org_libs_path)
sys.path.append(sys_libs_path)

from llm_api import chat_completion_stream  # noqa: E402
from ide_services.services import visible_lines, selected_lines, diff_apply  # noqa: E402


def create_prompt():
    question = sys.argv[1]

    visible_data = visible_lines()
    selected_data = selected_lines()

    file_path = visible_data["filePath"]
    if not os.path.exists(file_path):
        print("Current file is not valid filename:", file_path, file=sys.stderr, flush=True)
        sys.exit(-1)
    if selected_data["selectedText"] == "":
        print("Please select some text.", file=sys.stderr, flush=True)
        sys.exit(-1)

    prompt = f"""
你的任务是:
{question}
根据任务要求仅修改选中的代码部分请确保修改后的代码段与选中的代码保持相同的缩进\
以便与现有的代码结构无缝集成并保持正确的语法只重构选中的代码保留所有其他信息\
以下是您参考的相关上下文信息
1. 选中代码信息: {selected_data}
2. 可视窗口代码信息: {visible_data}
"""
    return prompt


def extract_markdown_block(text):
    """
    Extracts the first Markdown code block from the given text without the language specifier.

    :param text: A string containing Markdown text
    :return: The content of the first Markdown code block, or the original text if none is found
    """
    # Regular expression matching a Markdown code block, ignoring an optional language tag
    pattern = r"```(?:\w+)?\s*\n(.*?)\n```"
    match = re.search(pattern, text, re.DOTALL)
    if match:
        # Return the content of the first matching block,
        # without the enclosing backticks and language tag
        block_content = match.group(1)
        return block_content
    else:
        # No code block found: return the text unchanged
        return text


def replace_selected(new_code):
    selected_data = selected_lines()
    select_file = selected_data["filePath"]
    select_range = selected_data["selectedRange"]  # [start_line, start_col, end_line, end_col]

    # Read the file
    with open(select_file, "r") as file:
        lines = file.readlines()
    lines.append("\n")

    # Modify the selected lines
    start_line, start_col, end_line, end_col = select_range

    # If the selection spans multiple lines, handle the last line and delete the lines in between
    if start_line != end_line:
        lines[start_line] = lines[start_line][:start_col] + new_code
        # Append the text after the selection on the last line
        lines[start_line] += lines[end_line][end_col:]
        # Delete the lines between start_line and end_line
        del lines[start_line + 1 : end_line + 1]
    else:
        # If the selection is within a single line, remove the selected text
        lines[start_line] = lines[start_line][:start_col] + new_code + lines[end_line][end_col:]

    # Combine everything back together
    modified_text = "".join(lines)

    # Write the changes back to the file
    with open(select_file, "w") as file:
        file.write(modified_text)


def main():
    # messages = json.loads(
    #     os.environ.get("CONTEXT_CONTENTS", json.dumps([{"role": "user", "content": ""}]))
    # )
    messages = [{"role": "user", "content": create_prompt()}]
    response = chat_completion_stream(
        messages, {"model": os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106")}, stream_out=True
    )
    if not response:
        sys.exit(-1)
    print("\n")

    new_code = extract_markdown_block(response["content"])
    # Check if new_code is empty and handle the case appropriately
    if not new_code:
        print("Parsing result failed. Exiting with error code -1", file=sys.stderr)
        sys.exit(-1)

    # replace_selected(new_code)
    selected_data = selected_lines()
    select_file = selected_data["filePath"]
    diff_apply("", new_code)

    # print(response["content"], flush=True)
    sys.exit(0)


if __name__ == "__main__":
    main()
    # print("hello")
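A quick illustration of extract_markdown_block above, with a made-up model reply; the helper keeps only the body of the first fenced block and falls back to the raw text when no fence is present:

fence = "`" * 3  # three backticks, spelled out so the example reads cleanly here
reply = f"Here is the rewritten code:\n{fence}python\ndef add(a, b):\n    return a + b\n{fence}\n"

print(extract_markdown_block(reply))
# -> "def add(a, b):\n    return a + b"

print(extract_markdown_block("no fenced block in this reply"))
# -> the input string, unchanged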

View File

@@ -14,7 +14,6 @@ class DirectoryStructureBase(ABC):
     ) -> None:
         self._root_path = root_path
-        self._client = OpenAI()
         self._chat_language = chat_language
 
     @property

View File

@@ -1,20 +1,20 @@
 import json
+import os
 from pathlib import Path
 from typing import Callable, List
 
-import tiktoken
 from assistants.directory_structure.base import DirectoryStructureBase
 from assistants.rerank_files import rerank_files
+from minimax_util import chat_completion_no_stream_return_json
 from openai_util import create_chat_completion_content
-from tools.directory_viewer import (
-    ListViewer,
-)
+from tools.directory_viewer import ListViewer
+from tools.tiktoken_util import get_encoding
 
 
 class RelevantFileFinder(DirectoryStructureBase):
     model_name = "gpt-3.5-turbo-1106"
     dir_token_budget = 16000 * 0.95
-    encoding: tiktoken.Encoding = tiktoken.encoding_for_model("gpt-3.5-turbo-1106")
+    encoding = get_encoding("cl100k_base")
 
     def _paginate_dir_structure(
         self, criteria: Callable[[Path], bool], style: str = "list"
@@ -79,23 +79,48 @@ class RelevantFileFinder(DirectoryStructureBase):
         return message
 
+    def _mk_message_cn(self, objective: str, dir_structure: str) -> str:
+        message = f"""
+你是一位智能编程助手你的任务是理解代码库的目录结构推测各目录的作用
+并根据用户输入的问题或者目标找到与之最相关的10个文件
+请注意你的目的并不是分析各个文件里的代码而是通过分析项目的结构来判断
+哪个文件最有可能包含用户需要的信息
+以下是代码库的目录结构
+{dir_structure}
+请根据你的理解找出10个和以下问题最相关的文件
+{objective}
+请按以下JSON格式回复
+{{
+    "files": ["<文件1的路径>" , "<文件2的路径>", "<文件3的路径>", ... ]
+}}
+"""
+        return message
+
     def _find_relevant_files(self, objective: str, dir_structure_pages: List[str]) -> List[str]:
         files: List[str] = []
         for dir_structure in dir_structure_pages:
-            user_msg = self._mk_message(objective, dir_structure)
+            # user_msg = self._mk_message(objective, dir_structure)
+            user_msg = self._mk_message_cn(objective, dir_structure)
 
-            response = create_chat_completion_content(
-                client=self._client,
-                model=self.model_name,
-                messages=[
-                    {"role": "user", "content": user_msg},
-                ],
-                response_format={"type": "json_object"},
-                temperature=0.1,
-            )
+            model = os.environ.get("LLM_MODEL", self.model_name)
+
+            json_res = chat_completion_no_stream_return_json(
+                messages=[{"role": "user", "content": user_msg}],
+                llm_config={
+                    "model": model,
+                    "temperature": 0.1,
+                },
+            )
 
-            json_res = json.loads(response)
             files.extend(json_res.get("files", []))
 
         reranked = rerank_files(

View File

@@ -1,6 +1,9 @@
 import json
+import os
+import sys
 from typing import List, Tuple
 
+from minimax_util import chat_completion_no_stream_return_json
 from openai_util import create_chat_completion_content
 
 # ruff: noqa: E501
@@ -28,6 +31,37 @@ Accumulated Knowledge: {accumulated_knowledge}
 Answer:
 """
 
+rerank_file_prompt_cn = """
+你是一位智能编程助手
+用户将给你一个文件列表并提出一个和编程相关的问题
+根据你对问题和每个文件可能包含的内容的理解
+找到文件列表中有助于回答该问题的文件
+并判断文件与问题的相关程度从1到10为文件评分10分表示相关度非常高1分表示相关度很低
+最后将文件按相关度从高到低排序
+请注意返回的所有文件路径都应在用户给出的文件列表里不能额外添加文件也不能修改文件路径
+以下是用户给出的文件列表
+{files}
+用户的问题是: {question}
+目前积累的相关背景知识: {accumulated_knowledge}
+请按以下JSON格式回复
+{{
+    "result": [
+        {{"item": "<最相关的文件路径>", "relevance": 7}},
+        {{"item": "<第二相关的文件路径>", "relevance": 4}},
+        {{"item": "<第三相关的文件路径>", "relevance": 3}}
+    ]
+}}
+"""
+
 RERANK_MODEL = "gpt-3.5-turbo-1106"
@@ -50,25 +84,28 @@ def rerank_files(
         assert isinstance(file, str), "items must be a list of str when item_type is 'file'"
         files_str += f"- {file}\n"
 
-    user_msg = rerank_file_prompt.format(
+    user_msg = rerank_file_prompt_cn.format(
         files=files_str,
         question=question,
         accumulated_knowledge=knowledge,
     )
 
-    response = create_chat_completion_content(
-        model=RERANK_MODEL,
+    model = os.environ.get("LLM_MODEL", RERANK_MODEL)
+    result = chat_completion_no_stream_return_json(
         messages=[
             {
                 "role": "user",
                 "content": user_msg,
             },
        ],
-        response_format={"type": "json_object"},
-        temperature=0.1,
+        llm_config={
+            "model": model,
+            "temperature": 0.1,
+        },
     )
+    if not result:
+        return []
 
-    result = json.loads(response)
     reranked = [(i["item"], i["relevance"]) for i in result["result"]]
 
     return reranked

View File

@@ -1,3 +1,3 @@
 description: Generate unit tests.
 steps:
-  - run: $command_python $command_path/main.py "$input"
+  - run: $devchat_python $command_path/main.py "$input"

View File

@@ -1,7 +1,9 @@
 from typing import List
 
 from assistants.directory_structure.relevant_file_finder import RelevantFileFinder
-from prompts import FIND_REFERENCE_PROMPT
+
+# from prompts import FIND_REFERENCE_PROMPT
+from prompts_cn import FIND_REFERENCE_PROMPT
 from tools.file_util import verify_file_list

View File

@@ -4,9 +4,12 @@ from enum import Enum
 
 class TUILanguage(Enum):
-    EN = ("en", "English")
-    ZH = ("zh", "Chinese")
-    Other = ("en", "English")  # default to show English
+    EN = ("en", "英文")
+    ZH = ("zh", "中文")
+    Other = ("zh", "中文")
+    # EN = ("en", "English")
+    # ZH = ("zh", "Chinese")
+    # Other = ("en", "English")  # default to show English
 
     @classmethod
     def from_str(cls, language: str) -> "TUILanguage":

View File

@@ -2,15 +2,17 @@ import os
 import sys
 
 import click
-from find_reference_tests import find_reference_tests
-from i18n import TUILanguage, get_translation
-from model import FuncToTest, TokenBudgetExceededException, UserCancelledException
-from propose_test import propose_test
-from tools.file_util import retrieve_file_content
-from write_tests import write_and_print_tests
 
+sys.path.append(os.path.dirname(__file__))
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs"))
 
+from find_reference_tests import find_reference_tests  # noqa: E402
+from i18n import TUILanguage, get_translation  # noqa: E402
+from model import FuncToTest, TokenBudgetExceededException, UserCancelledException  # noqa: E402
+from propose_test import propose_test  # noqa: E402
+from tools.file_util import retrieve_file_content  # noqa: E402
+from write_tests import write_and_print_tests  # noqa: E402
 from chatmark import Checkbox, Form, Step, TextEditor  # noqa: E402
 from ide_services import ide_language  # noqa: E402
@@ -131,7 +133,8 @@ def main(input: str):
     user_prompt = f"Help me write unit tests for the `{func_name}` function"
 
     repo_root = os.getcwd()
-    ide_lang = ide_language()
+    # ide_lang = ide_language()
+    ide_lang = "zh"
     tui_lang = TUILanguage.from_str(ide_lang)
     _i = get_translation(tui_lang)

View File

@@ -0,0 +1,9 @@
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs"))

from llm_api import (
    chat_completion_no_stream_return_json,
    chat_completion_stream,
)
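Read this way, the shim puts the shared libs directory on sys.path and re-exports the two llm_api helpers under the minimax_util name that the unit_tests modules below import, so those modules can simply write (my reading of the diff, shown for illustration):

from minimax_util import chat_completion_no_stream_return_json, chat_completion_stream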

unit_tests/prompts_cn.py Normal file (+68 lines)
View File

@@ -0,0 +1,68 @@
# ruff: noqa: E501
# Do not limit the length of each line of the prompts.
PROPOSE_TEST_PROMPT = """
你是一位智能测试用例生成助手
给定一个用户提示和一个目标函数请根据提示为目标函数生成测试用例
用户提示如下
{user_prompt}
目标函数是 `{function_name}` 该函数所在的文件是 {file_path}
以下是与该函数相关的源代码
{relevant_content}
请为每个测试用例提供一句话的描述描述该测试用例所测试的行为
你不需要用代码编写测试用例只需用普通的自然语言描述即可
最多生成 6 个测试用例
请按照以下JSON格式回复
{{
"test_cases": [
{{"description": "<测试用例1的自然语言描述>"}},
{{"description": "<测试用例2的自然语言描述>"}},
]
}}
"""
FIND_REFERENCE_PROMPT = """
请找到一个合适的参考测试文件该文件可用于为文件 {file_path} 中的 `{function_name}` 函数编写测试提供指导
参考文件应该提供一个清晰的示例演示测试类似性质的函数的最佳实践
"""
WRITE_TESTS_PROMPT = """
你是一位智能单元测试生成助手
给定一个目标函数一些参考代码和一系列具体的测试用例描述请完成单元测试的代码编写
每个测试函数都应该是独立且可执行的
请确保与参考代码使用相同的测试框架模拟对象库断言库等并采取相似的模拟策略判断策略等
目标函数是 `{function_name}`该函数位于文件 {file_path}
以下是与该函数相关的源代码
{relevant_content}
{reference_content}
以下是测试用例列表
{test_cases_str}
请按照以下格式回复
测试用例 1. <测试用例1的原始描述>
<测试用例1对应的测试函数代码>
测试用例 2. <测试用例2的原始描述>
<测试用例2对应的测试函数代码>
"""

View File

@@ -1,14 +1,20 @@
 import json
+import os
+import sys
 from functools import partial
 from typing import List
 
-import tiktoken
+from minimax_util import chat_completion_no_stream_return_json
 from model import FuncToTest, TokenBudgetExceededException
 from openai_util import create_chat_completion_content
-from prompts import PROPOSE_TEST_PROMPT
+
+# from prompts import PROPOSE_TEST_PROMPT
+from prompts_cn import PROPOSE_TEST_PROMPT
+from tools.tiktoken_util import get_encoding
 
 MODEL = "gpt-3.5-turbo-1106"
 # MODEL = "gpt-4-1106-preview"
+ENCODING = "cl100k_base"
 
 TOKEN_BUDGET = int(16000 * 0.9)
@@ -20,7 +26,7 @@ def _mk_user_msg(
     """
     Create a user message to be sent to the model within the token budget.
     """
-    encoding: tiktoken.Encoding = tiktoken.encoding_for_model(MODEL)
+    encoding = get_encoding(ENCODING)
 
     func_content = f"function code\n```\n{func_to_test.func_content}\n```\n"
     class_content = ""
@@ -81,14 +87,16 @@ def propose_test(
         chat_language=chat_language,
     )
 
-    content = create_chat_completion_content(
-        model=MODEL,
+    model = os.environ.get("LLM_MODEL", MODEL)
+    content = chat_completion_no_stream_return_json(
         messages=[{"role": "user", "content": user_msg}],
-        response_format={"type": "json_object"},
-        temperature=0.1,
+        llm_config={
+            "model": model,
+            "temperature": 0.1,
+        },
     )
 
-    cases = json.loads(content).get("test_cases", [])
+    cases = content.get("test_cases", [])
 
     descriptions = []
     for case in cases:

View File

@@ -0,0 +1,20 @@
import tiktoken


def get_encoding(encoding_name: str):
    """
    Get a tiktoken encoding by name.
    """
    try:
        return tiktoken.get_encoding(encoding_name)
    except Exception:
        from tiktoken import registry
        from tiktoken.core import Encoding
        from tiktoken.registry import _find_constructors

        def _get_encoding(name: str):
            _find_constructors()
            constructor = registry.ENCODING_CONSTRUCTORS[name]
            return Encoding(**constructor(), use_pure_python=True)

        return _get_encoding(encoding_name)
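A usage sketch for this helper, matching how the callers in this diff use it for token counting (the sample string is illustrative):

from tools.tiktoken_util import get_encoding

encoding = get_encoding("cl100k_base")
# Token counts like this are compared against TOKEN_BUDGET in propose_test.py and write_tests.py.
num_tokens = len(encoding.encode("def propose_test(user_prompt, func_to_test): ..."))
print(num_tokens)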

View File

@@ -1,13 +1,18 @@
+import os
 from functools import partial
 from typing import List, Optional
 
-import tiktoken
+from minimax_util import chat_completion_stream
 from model import FuncToTest, TokenBudgetExceededException
 from openai_util import create_chat_completion_chunks
-from prompts import WRITE_TESTS_PROMPT
+
+# from prompts import WRITE_TESTS_PROMPT
+from prompts_cn import WRITE_TESTS_PROMPT
 from tools.file_util import retrieve_file_content
+from tools.tiktoken_util import get_encoding
 
 MODEL = "gpt-4-1106-preview"
+ENCODING = "cl100k_base"
 TOKEN_BUDGET = int(128000 * 0.9)
@@ -18,7 +23,7 @@ def _mk_write_tests_msg(
     chat_language: str,
     reference_files: Optional[List[str]] = None,
 ) -> Optional[str]:
-    encoding: tiktoken.Encoding = tiktoken.encoding_for_model(MODEL)
+    encoding = get_encoding(ENCODING)
 
     test_cases_str = ""
     for i, test_case in enumerate(test_cases, 1):
@@ -92,13 +97,9 @@ def write_and_print_tests(
         chat_language=chat_language,
     )
 
-    chunks = create_chat_completion_chunks(
-        model=MODEL,
+    model = os.environ.get("LLM_MODEL", MODEL)
+    chat_completion_stream(
         messages=[{"role": "user", "content": user_msg}],
-        temperature=0.1,
+        llm_config={"model": model, "temperature": 0.1},
+        stream_out=True,
     )
-
-    for chunk in chunks:
-        if chunk.choices[0].finish_reason == "stop":
-            break
-        print(chunk.choices[0].delta.content, flush=True, end="")