Compare commits
22 Commits
scripts ... minimax_ma
Commit SHA1s (author and date columns did not survive extraction):
accf557fe6, 4d80089598, b82ebd8af1, 09daeef589, e9a28f6d9a, ae12062d34, 7be108b874, d49b2f4732, 5ced3dcfa8, 9138e1569e, 17f4aba1b4, b7ca8127d1, d3008d5237, 7c1252aee2, da51cbfc5f, 8dc6c5ffe9, 7652f73537, aff4142d24, 12e5ce6a57, 2daf4ed765, 94249cd6d1, bc5f4e23a9
234
libs/llm_api/minimax_chat.py
Normal file
@@ -0,0 +1,234 @@
import json
import os
import sys
import time

import requests


class StreamIterWrapper:
    def __init__(self, response, is_private=True):
        self.response = response
        self.create_time = int(time.time())
        self.line_iterator = response.iter_lines()
        self.is_private = is_private
        self.stop = False

    def __iter__(self):
        return self

    def __next__(self):
        try:
            if self.stop:
                raise StopIteration
            response_line = next(self.line_iterator)
            if response_line == b"":
                return self.__next__()
            if response_line == b"\n":
                return self.__next__()

            response_line = response_line.replace(b"data: ", b"")
            response_result = json.loads(response_line.decode("utf-8"))
            if self.is_private:
                if "finish" in response_result and response_result["finish"] == True:
                    self.stop = True
                if "err" in response_result and response_result["err"]:
                    raise ValueError(f"minimax api response error: {response_result['err']}")
            if not self.is_private:
                if response_result["choices"][0].get("finish_reason", None):
                    raise StopIteration

            data = {}
            if self.is_private:
                data = json.loads(response_result["data"])

            stream_response = {
                "id": f"minimax_{self.create_time}",
                "created": self.create_time,
                "object": "chat.completion.chunk",
                "model": response_result.get("model", "abab5.5-chat"),
                "choices": [
                    {
                        "index": 0,
                        "finish_reason": "stop",
                        "delta": {
                            "role": "assistant",
                            "content": response_result["choices"][0]["messages"][0]["text"] if not self.is_private else data.get("text", ""),
                        },
                    }
                ],
                "usage": {"prompt_tokens": 10, "completion_tokens": 100},
            }

            return stream_response
        except StopIteration as exc:  # If there is no more event
            raise StopIteration from exc
        except Exception as err:
            print("Exception:", err.__class__.__name__, err, file=sys.stderr, end="\n\n")
            raise StopIteration from err


def chat_completion(messages, llm_config):
    url = _make_api_url()
    headers = _make_header()
    if _is_private_llm():
        payload = _make_private_payload(messages, llm_config)
    else:
        payload = _make_public_payload(messages, llm_config)

    response = requests.post(url, headers=headers, json=payload)
    response_json = json.loads(response.text)
    if not response_json.get("texts", []):
        raise ValueError(f"minimax api response error: {response_json}")
    return {"content": response_json["texts"][0]}


def stream_chat_completion(messages, llm_config):
    url = _make_api_url()
    headers = _make_header()
    if _is_private_llm():
        payload = _make_private_payload(messages, llm_config, True)
    else:
        payload = _make_public_payload(messages, llm_config, True)

    response = requests.post(url, headers=headers, json=payload)
    streamIters = StreamIterWrapper(response, _is_private_llm())
    return streamIters


def _is_private_llm():
    api_base_url = os.environ.get("OPENAI_API_BASE", "")
    return not api_base_url.startswith("https://api.minimax.chat")


def _make_api_url():
    api_base_url = os.environ.get("OPENAI_API_BASE", None)
    if not api_base_url:
        raise ValueError("minimax api url is not set")

    if api_base_url.startswith("https://api.minimax.chat"):
        if api_base_url.endswith("/"):
            api_base_url = api_base_url[:-1]
        if not api_base_url.endswith("/v1"):
            api_base_url = api_base_url + "/v1"
        api_base_url += "/text/chatcompletion_pro"

        api_key = os.environ.get("OPENAI_API_KEY", None)
        if not api_key:
            raise ValueError("minimax api key is not set")

        group_id = api_key.split("##")[0]
        api_base_url += f"?GroupId={group_id}"
        return api_base_url
    else:
        if api_base_url.endswith("/"):
            api_base_url = api_base_url[:-1]
        if not api_base_url.endswith("/interact"):
            api_base_url = api_base_url + "/interact"
        return api_base_url


def _make_api_key():
    if _is_private_llm():
        return ""

    api_key = os.environ.get("OPENAI_API_KEY", None)
    return api_key.split("##")[1]


def _make_header():
    api_key = _make_api_key()
    return {
        **({"Authorization": f"Bearer {api_key}"} if not _is_private_llm() else {}),
        "Content-Type": "application/json",
    }


def _to_private_messages(messages):
    new_messages = []
    for message in messages:
        if message["role"] == "user":
            new_messages.append({"role": "user", "name": "user", "text": message["content"]})
        else:
            new_messages.append({"role": "ai", "name": "ai", "text": message["content"]})
    new_messages.append({"role": "ai", "name": "ai", "text": ""})
    return new_messages


def _make_private_payload(messages, llm_config, stream=False):
    return {
        "data": _to_private_messages(messages),
        "model_control": {
            "system_data": [
                {
                    "role": "system",
                    "ai_setting": "ai",
                    "text": "你是minimax编码助理,擅长编写代码,编写注释,编写测试用例,并且很注重编码的规范性。",
                },
            ],
            # "alpha_frequency": 128,
            # "alpha_frequency_src": 1,
            # "alpha_presence": 0,
            # "alpha_presence_src": 0,
            # "block_ngram": 0,
            # "clean_init_no_penalty_list": True,
            # "context_block_ngram": 0,
            # "factual_topp": False,
            # "lamda_decay": 1,
            # "length_penalty": 1,
            # "no_penalty_list": ",",
            # "omega_bound": 0,
            # "repeat_filter": False,
            # "repeat_sampling": 1,
            # "skip_text_mask": True,
            "tokens_to_generate": llm_config.get("max_tokens", 2048),
            # "sampler_type": "nucleus",
            "beam_width": 1,
            # "delimiter": "\n",
            # "min_length": 0,
            # "skip_info_mask": True,
            "stop_sequence": [],
            # "top_p": 0.95,
            "temperature": llm_config.get("temperature", 0.95),
        },
        "stream": stream,
    }


def _to_public_messages(messages):
    new_messages = []
    for message in messages:
        if message["role"] == "user":
            new_messages.append(
                {"sender_type": "USER", "sender_name": "USER", "text": message["content"]}
            )
        else:
            new_messages.append(
                {"sender_type": "BOT", "sender_name": "ai", "text": message["content"]}
            )
    return new_messages


def _make_public_payload(messages, llm_config, stream=False):
    response = {
        "model": "abab5.5-chat",
        "tokens_to_generate": llm_config.get("max_tokens", 2048),
        "temperature": llm_config.get("temperature", 0.1),
        # "top_p": 0.9,
        "reply_constraints": {"sender_type": "BOT", "sender_name": "ai"},
        "sample_messages": [],
        "plugins": [],
        "messages": _to_public_messages(messages),
        "bot_setting": [
            {
                "bot_name": "ai",
                "content": (
                    "MM智能助理是一款由MiniMax自研的,"
                    "没有调用其他产品的接口的大型语言模型。"
                    "MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。"
                ),
            }
        ],
        "stream": stream,
    }
    return response
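Below is a minimal usage sketch, not part of the diff, showing how the two new entry points might be called once `OPENAI_API_BASE` points at a MiniMax endpoint and `OPENAI_API_KEY` holds the `GroupId##key` pair that the `split("##")` calls above expect; the import path, prompt, and config values are illustrative assumptions.

```python
# Illustrative only: assumes the environment variables described above are set
# and that minimax_chat is importable (e.g. from within libs/llm_api).
from minimax_chat import chat_completion, stream_chat_completion

messages = [{"role": "user", "content": "Write a Python function that reverses a string."}]
llm_config = {"max_tokens": 1024, "temperature": 0.1}

# Blocking call: returns {"content": "..."} or raises ValueError on an API error.
result = chat_completion(messages, llm_config)
print(result["content"])

# Streaming call: StreamIterWrapper yields OpenAI-style "chat.completion.chunk" dicts.
for chunk in stream_chat_completion(messages, llm_config):
    print(chunk["choices"][0]["delta"]["content"], end="", flush=True)
```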
@@ -7,8 +7,10 @@ import sys
import openai

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(os.path.dirname(__file__))

from ide_services.services import log_warn
from minimax_chat import stream_chat_completion


def _try_remove_markdown_block_flag(content):

@@ -46,6 +48,9 @@ def chat_completion_stream(messages, llm_config, error_out: bool = True, stream_
    """
    for try_times in range(3):
        try:
            if llm_config.get("model", "").startswith("abab"):
                response = stream_chat_completion(messages, llm_config)
            else:
                client = openai.OpenAI(
                    api_key=os.environ.get("OPENAI_API_KEY", None),
                    base_url=os.environ.get("OPENAI_API_BASE", None),

@@ -57,7 +62,9 @@ def chat_completion_stream(messages, llm_config, error_out: bool = True, stream_

            response_result = {"content": None, "function_name": None, "parameters": ""}
            for chunk in response:  # pylint: disable=E1133
                if not isinstance(chunk, dict):
                    chunk = chunk.dict()

                delta = chunk["choices"][0]["delta"]
                if "tool_calls" in delta and delta["tool_calls"]:
                    tool_call = delta["tool_calls"][0]["function"]
5
rewrite/command.yml
Normal file
@@ -0,0 +1,5 @@
description: rewrite selected code.
hint: question
input: required
steps:
  - run: $devchat_python $command_path/rewrite.py "$input"
3
rewrite/doc_comment/command.yml
Normal file
@@ -0,0 +1,3 @@
description: add doc comment for selected code
steps:
  - run: $devchat_python $command_path/../rewrite.py "add doc comment"
4
rewrite/optimize_names/command.yml
Normal file
@@ -0,0 +1,4 @@
description: Optimizing variable and function names in code.
steps:
  - run: $devchat_python $command_path/../rewrite.py "Refine internal variable and function names within the code to achieve concise and meaningful identifiers that comply with English naming conventions."

3
rewrite/optimize_string/command.yml
Normal file
@@ -0,0 +1,3 @@
description: improve the readability and conformity of a given string to English language conventions.
steps:
  - run: $devchat_python $command_path/../rewrite.py "enhance the given string's readability and ensure it adheres to English linguistic standards"
122
rewrite/rewrite.py
Normal file
@@ -0,0 +1,122 @@
import os
import sys
import re
import json

home = os.path.expanduser("~")
org_libs_path = os.path.join(home, ".chat", "workflows", "org", "libs")
sys_libs_path = os.path.join(home, ".chat", "workflows", "sys", "libs")
sys.path.append(org_libs_path)
sys.path.append(sys_libs_path)

from llm_api import chat_completion_stream  # noqa: E402
from ide_services.services import visible_lines, selected_lines, diff_apply  # noqa: E402


def create_prompt():
    question = sys.argv[1]

    visible_data = visible_lines()
    selected_data = selected_lines()

    file_path = visible_data["filePath"]
    if not os.path.exists(file_path):
        print("Current file is not a valid filename:", file_path, file=sys.stderr, flush=True)
        sys.exit(-1)

    if selected_data["selectedText"] == "":
        print("Please select some text.", file=sys.stderr, flush=True)
        sys.exit(-1)

    prompt = f"""
你的任务是:
{question}
根据任务要求,仅修改选中的代码部分。请确保修改后的代码段与选中的代码保持相同的缩进,\
以便与现有的代码结构无缝集成并保持正确的语法。只重构选中的代码。保留所有其他信息。\
以下是您参考的相关上下文信息:
1. 选中代码信息: {selected_data}
2. 可视窗口代码信息: {visible_data}
"""
    return prompt


def extract_markdown_block(text):
    """
    Extracts the first Markdown code block from the given text without the language specifier.

    :param text: A string containing Markdown text
    :return: The content of the first Markdown code block, or the original text if none is found
    """
    # Regex matching a Markdown code block, ignoring the optional language tag
    pattern = r"```(?:\w+)?\s*\n(.*?)\n```"
    match = re.search(pattern, text, re.DOTALL)

    if match:
        # Return the content of the first matching block,
        # without the surrounding backticks and language tag
        block_content = match.group(1)
        return block_content
    else:
        # No code block found: return the original text unchanged
        return text


def replace_selected(new_code):
    selected_data = selected_lines()
    select_file = selected_data["filePath"]
    select_range = selected_data["selectedRange"]  # [start_line, start_col, end_line, end_col]

    # Read the file
    with open(select_file, "r") as file:
        lines = file.readlines()
    lines.append("\n")

    # Modify the selected lines
    start_line, start_col, end_line, end_col = select_range

    # If the selection spans multiple lines, handle the last line and delete the lines in between
    if start_line != end_line:
        lines[start_line] = lines[start_line][:start_col] + new_code
        # Append the text after the selection on the last line
        lines[start_line] += lines[end_line][end_col:]
        # Delete the lines between start_line and end_line
        del lines[start_line + 1 : end_line + 1]
    else:
        # If the selection is within a single line, remove the selected text
        lines[start_line] = lines[start_line][:start_col] + new_code + lines[end_line][end_col:]

    # Combine everything back together
    modified_text = "".join(lines)

    # Write the changes back to the file
    with open(select_file, "w") as file:
        file.write(modified_text)


def main():
    # messages = json.loads(
    #     os.environ.get("CONTEXT_CONTENTS", json.dumps([{"role": "user", "content": ""}]))
    # )
    messages = [{"role": "user", "content": create_prompt()}]

    response = chat_completion_stream(messages, {"model": os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106")}, stream_out=True)
    if not response:
        sys.exit(-1)
    print("\n")
    new_code = extract_markdown_block(response["content"])
    # Check if new_code is empty and handle the case appropriately
    if not new_code:
        print("Parsing result failed. Exiting with error code -1", file=sys.stderr)
        sys.exit(-1)

    # replace_selected(new_code)
    selected_data = selected_lines()
    select_file = selected_data["filePath"]
    diff_apply("", new_code)
    # print(response["content"], flush=True)
    sys.exit(0)


if __name__ == "__main__":
    main()
    # print("hello")
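A quick sketch, not in the commit, of what `extract_markdown_block` returns for a typical model reply; the function body is copied inline here so the example runs without resolving the script's `~/.chat/workflows/*/libs` imports, and the sample reply is an assumption.

```python
# Illustrative only: condensed copy of extract_markdown_block from rewrite/rewrite.py.
import re

def extract_markdown_block(text):
    pattern = r"```(?:\w+)?\s*\n(.*?)\n```"
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1) if match else text

fence = "```"
reply = f"Here is the refactored code:\n\n{fence}python\ndef add(a, b):\n    return a + b\n{fence}\n"

print(extract_markdown_block(reply))       # body of the first fenced block
print(extract_markdown_block("no fence"))  # falls back to the raw text
```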
@@ -14,7 +14,6 @@ class DirectoryStructureBase(ABC):
    ) -> None:
        self._root_path = root_path

        self._client = OpenAI()
        self._chat_language = chat_language

    @property
@@ -1,20 +1,20 @@
import json
import os
from pathlib import Path
from typing import Callable, List

import tiktoken
from assistants.directory_structure.base import DirectoryStructureBase
from assistants.rerank_files import rerank_files
from minimax_util import chat_completion_no_stream_return_json
from openai_util import create_chat_completion_content
from tools.directory_viewer import (
    ListViewer,
)
from tools.directory_viewer import ListViewer
from tools.tiktoken_util import get_encoding


class RelevantFileFinder(DirectoryStructureBase):
    model_name = "gpt-3.5-turbo-1106"
    dir_token_budget = 16000 * 0.95
    encoding: tiktoken.Encoding = tiktoken.encoding_for_model("gpt-3.5-turbo-1106")
    encoding = get_encoding("cl100k_base")

    def _paginate_dir_structure(
        self, criteria: Callable[[Path], bool], style: str = "list"
@@ -79,23 +79,48 @@ class RelevantFileFinder(DirectoryStructureBase):

        return message

    def _mk_message_cn(self, objective: str, dir_structure: str) -> str:
        message = f"""
你是一位智能编程助手,你的任务是理解代码库的目录结构,推测各目录的作用,
并根据用户输入的问题或者目标,找到与之最相关的10个文件。

请注意,你的目的并不是分析各个文件里的代码,而是通过分析项目的结构来判断
哪个文件最有可能包含用户需要的信息。

以下是代码库的目录结构:

{dir_structure}


请根据你的理解,找出10个和以下问题最相关的文件:

“{objective}”


请按以下JSON格式回复:
{{
    "files": ["<文件1的路径>" , "<文件2的路径>", "<文件3的路径>", ... ]
}}
"""

        return message

    def _find_relevant_files(self, objective: str, dir_structure_pages: List[str]) -> List[str]:
        files: List[str] = []
        for dir_structure in dir_structure_pages:
            user_msg = self._mk_message(objective, dir_structure)
            # user_msg = self._mk_message(objective, dir_structure)
            user_msg = self._mk_message_cn(objective, dir_structure)

            response = create_chat_completion_content(
                client=self._client,
                model=self.model_name,
                messages=[
                    {"role": "user", "content": user_msg},
                ],
                response_format={"type": "json_object"},
                temperature=0.1,
            model = os.environ.get("LLM_MODEL", self.model_name)

            json_res = chat_completion_no_stream_return_json(
                messages=[{"role": "user", "content": user_msg}],
                llm_config={
                    "model": model,
                    "temperature": 0.1,
                },
            )

            json_res = json.loads(response)

            files.extend(json_res.get("files", []))

        reranked = rerank_files(
@@ -1,6 +1,9 @@
import json
import os
import sys
from typing import List, Tuple

from minimax_util import chat_completion_no_stream_return_json
from openai_util import create_chat_completion_content

# ruff: noqa: E501
@@ -28,6 +31,37 @@ Accumulated Knowledge: {accumulated_knowledge}
Answer:
"""

rerank_file_prompt_cn = """
你是一位智能编程助手。
用户将给你一个文件列表并提出一个和编程相关的问题,
根据你对问题和每个文件可能包含的内容的理解,
找到文件列表中有助于回答该问题的文件,
并判断文件与问题的相关程度,从1到10为文件评分(10分表示相关度非常高,1分表示相关度很低),
最后将文件按相关度从高到低排序。

请注意,返回的所有文件路径都应在用户给出的文件列表里,不能额外添加文件,也不能修改文件路径。


以下是用户给出的文件列表:

{files}


用户的问题是: {question}
目前积累的相关背景知识: {accumulated_knowledge}

请按以下JSON格式回复:
{{
"result": [
{{"item": "<最相关的文件路径>", "relevance": 7}},
{{"item": "<第二相关的文件路径>", "relevance": 4}},
{{"item": "<第三相关的文件路径>", "relevance": 3}}
]
}}


"""

RERANK_MODEL = "gpt-3.5-turbo-1106"
||||
@ -50,25 +84,28 @@ def rerank_files(
|
||||
assert isinstance(file, str), "items must be a list of str when item_type is 'file'"
|
||||
files_str += f"- {file}\n"
|
||||
|
||||
user_msg = rerank_file_prompt.format(
|
||||
user_msg = rerank_file_prompt_cn.format(
|
||||
files=files_str,
|
||||
question=question,
|
||||
accumulated_knowledge=knowledge,
|
||||
)
|
||||
|
||||
response = create_chat_completion_content(
|
||||
model=RERANK_MODEL,
|
||||
model = os.environ.get("LLM_MODEL", RERANK_MODEL)
|
||||
result = chat_completion_no_stream_return_json(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_msg,
|
||||
},
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0.1,
|
||||
llm_config={
|
||||
"model": model,
|
||||
"temperature": 0.1,
|
||||
},
|
||||
)
|
||||
if not result:
|
||||
return []
|
||||
|
||||
result = json.loads(response)
|
||||
reranked = [(i["item"], i["relevance"]) for i in result["result"]]
|
||||
|
||||
return reranked
|
||||
|
@@ -1,3 +1,3 @@
description: Generate unit tests.
steps:
  - run: $command_python $command_path/main.py "$input"
  - run: $devchat_python $command_path/main.py "$input"
@@ -1,7 +1,9 @@
from typing import List

from assistants.directory_structure.relevant_file_finder import RelevantFileFinder
from prompts import FIND_REFERENCE_PROMPT

# from prompts import FIND_REFERENCE_PROMPT
from prompts_cn import FIND_REFERENCE_PROMPT
from tools.file_util import verify_file_list

@@ -4,9 +4,12 @@ from enum import Enum


class TUILanguage(Enum):
    EN = ("en", "English")
    ZH = ("zh", "Chinese")
    Other = ("en", "English")  # default to show English
    EN = ("en", "英文")
    ZH = ("zh", "中文")
    Other = ("zh", "中文")
    # EN = ("en", "English")
    # ZH = ("zh", "Chinese")
    # Other = ("en", "English")  # default to show English

    @classmethod
    def from_str(cls, language: str) -> "TUILanguage":
@@ -2,15 +2,17 @@ import os
import sys

import click
from find_reference_tests import find_reference_tests
from i18n import TUILanguage, get_translation
from model import FuncToTest, TokenBudgetExceededException, UserCancelledException
from propose_test import propose_test
from tools.file_util import retrieve_file_content
from write_tests import write_and_print_tests

sys.path.append(os.path.dirname(__file__))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs"))

from find_reference_tests import find_reference_tests  # noqa: E402
from i18n import TUILanguage, get_translation  # noqa: E402
from model import FuncToTest, TokenBudgetExceededException, UserCancelledException  # noqa: E402
from propose_test import propose_test  # noqa: E402
from tools.file_util import retrieve_file_content  # noqa: E402
from write_tests import write_and_print_tests  # noqa: E402

from chatmark import Checkbox, Form, Step, TextEditor  # noqa: E402
from ide_services import ide_language  # noqa: E402

@@ -131,7 +133,8 @@ def main(input: str):
    user_prompt = f"Help me write unit tests for the `{func_name}` function"

    repo_root = os.getcwd()
    ide_lang = ide_language()
    # ide_lang = ide_language()
    ide_lang = "zh"
    tui_lang = TUILanguage.from_str(ide_lang)
    _i = get_translation(tui_lang)
9
unit_tests/minimax_util.py
Normal file
@@ -0,0 +1,9 @@
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs"))

from llm_api import (
    chat_completion_no_stream_return_json,
    chat_completion_stream,
)
68
unit_tests/prompts_cn.py
Normal file
@@ -0,0 +1,68 @@
# ruff: noqa: E501
# Do not limit the length of each line of the prompts.


PROPOSE_TEST_PROMPT = """
你是一位智能测试用例生成助手。
给定一个用户提示和一个目标函数,请根据提示为目标函数生成测试用例。

用户提示如下:

{user_prompt}

目标函数是 `{function_name}`, 该函数所在的文件是 {file_path}。

以下是与该函数相关的源代码:

{relevant_content}

请为每个测试用例提供一句话的描述,描述该测试用例所测试的行为。
你不需要用代码编写测试用例,只需用普通的自然语言描述即可。
最多生成 6 个测试用例。

请按照以下JSON格式回复:
{{
"test_cases": [
{{"description": "<测试用例1的自然语言描述>"}},
{{"description": "<测试用例2的自然语言描述>"}},
]
}}
"""


FIND_REFERENCE_PROMPT = """
请找到一个合适的参考测试文件,该文件可用于为文件 {file_path} 中的 `{function_name}` 函数编写测试提供指导。
参考文件应该提供一个清晰的示例,演示测试类似性质的函数的最佳实践。
"""


WRITE_TESTS_PROMPT = """
你是一位智能单元测试生成助手。
给定一个目标函数、一些参考代码和一系列具体的测试用例描述,请完成单元测试的代码编写。
每个测试函数都应该是独立且可执行的。

请确保与参考代码使用相同的测试框架、模拟对象库、断言库等,并采取相似的模拟策略、判断策略等。

目标函数是 `{function_name}`,该函数位于文件 {file_path}。

以下是与该函数相关的源代码:

{relevant_content}

{reference_content}

以下是测试用例列表:

{test_cases_str}


请按照以下格式回复:

测试用例 1. <测试用例1的原始描述>

<测试用例1对应的测试函数代码>

测试用例 2. <测试用例2的原始描述>

<测试用例2对应的测试函数代码>
"""
@@ -1,14 +1,20 @@
import json
import os
import sys
from functools import partial
from typing import List

import tiktoken
from minimax_util import chat_completion_no_stream_return_json
from model import FuncToTest, TokenBudgetExceededException
from openai_util import create_chat_completion_content
from prompts import PROPOSE_TEST_PROMPT

# from prompts import PROPOSE_TEST_PROMPT
from prompts_cn import PROPOSE_TEST_PROMPT
from tools.tiktoken_util import get_encoding

MODEL = "gpt-3.5-turbo-1106"
# MODEL = "gpt-4-1106-preview"
ENCODING = "cl100k_base"
TOKEN_BUDGET = int(16000 * 0.9)

@@ -20,7 +26,7 @@ def _mk_user_msg(
    """
    Create a user message to be sent to the model within the token budget.
    """
    encoding: tiktoken.Encoding = tiktoken.encoding_for_model(MODEL)
    encoding = get_encoding(ENCODING)

    func_content = f"function code\n```\n{func_to_test.func_content}\n```\n"
    class_content = ""

@@ -81,14 +87,16 @@ def propose_test(
        chat_language=chat_language,
    )

    content = create_chat_completion_content(
        model=MODEL,
    model = os.environ.get("LLM_MODEL", MODEL)
    content = chat_completion_no_stream_return_json(
        messages=[{"role": "user", "content": user_msg}],
        response_format={"type": "json_object"},
        temperature=0.1,
        llm_config={
            "model": model,
            "temperature": 0.1,
        },
    )

    cases = json.loads(content).get("test_cases", [])
    cases = content.get("test_cases", [])

    descriptions = []
    for case in cases:
20
unit_tests/tools/tiktoken_util.py
Normal file
@@ -0,0 +1,20 @@
import tiktoken


def get_encoding(encoding_name: str):
    """
    Get a tiktoken encoding by name.
    """
    try:
        return tiktoken.get_encoding(encoding_name)
    except Exception:
        from tiktoken import registry
        from tiktoken.core import Encoding
        from tiktoken.registry import _find_constructors

        def _get_encoding(name: str):
            _find_constructors()
            constructor = registry.ENCODING_CONSTRUCTORS[name]
            return Encoding(**constructor(), use_pure_python=True)

        return _get_encoding(encoding_name)
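A short sketch, not in the commit, of how this shared helper is used by `propose_test.py` and `write_tests.py` to keep prompts within their token budgets; the sample string is an arbitrary assumption.

```python
# Illustrative only: token counting with the get_encoding helper defined above.
from tools.tiktoken_util import get_encoding

encoding = get_encoding("cl100k_base")
tokens = encoding.encode("def add(a, b):\n    return a + b")
print(len(tokens))  # compared against TOKEN_BUDGET before building the final prompt
```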
@@ -1,13 +1,18 @@
import os
from functools import partial
from typing import List, Optional

import tiktoken
from minimax_util import chat_completion_stream
from model import FuncToTest, TokenBudgetExceededException
from openai_util import create_chat_completion_chunks
from prompts import WRITE_TESTS_PROMPT

# from prompts import WRITE_TESTS_PROMPT
from prompts_cn import WRITE_TESTS_PROMPT
from tools.file_util import retrieve_file_content
from tools.tiktoken_util import get_encoding

MODEL = "gpt-4-1106-preview"
ENCODING = "cl100k_base"
TOKEN_BUDGET = int(128000 * 0.9)

@@ -18,7 +23,7 @@ def _mk_write_tests_msg(
    chat_language: str,
    reference_files: Optional[List[str]] = None,
) -> Optional[str]:
    encoding: tiktoken.Encoding = tiktoken.encoding_for_model(MODEL)
    encoding = get_encoding(ENCODING)

    test_cases_str = ""
    for i, test_case in enumerate(test_cases, 1):

@@ -92,13 +97,9 @@ def write_and_print_tests(
        chat_language=chat_language,
    )

    chunks = create_chat_completion_chunks(
        model=MODEL,
    model = os.environ.get("LLM_MODEL", MODEL)
    chat_completion_stream(
        messages=[{"role": "user", "content": user_msg}],
        temperature=0.1,
        llm_config={"model": model, "temperature": 0.1},
        stream_out=True,
    )

    for chunk in chunks:
        if chunk.choices[0].finish_reason == "stop":
            break
        print(chunk.choices[0].delta.content, flush=True, end="")