Compare commits

22 Commits
scripts...minimax_ma

accf557fe6
4d80089598
b82ebd8af1
09daeef589
e9a28f6d9a
ae12062d34
7be108b874
d49b2f4732
5ced3dcfa8
9138e1569e
17f4aba1b4
b7ca8127d1
d3008d5237
7c1252aee2
da51cbfc5f
8dc6c5ffe9
7652f73537
aff4142d24
12e5ce6a57
2daf4ed765
94249cd6d1
bc5f4e23a9
libs/llm_api/minimax_chat.py (new file, 234 lines)
@@ -0,0 +1,234 @@
import json
import os
import sys
import time

import requests


class StreamIterWrapper:
    """Adapt a MiniMax streaming response into OpenAI-style chat.completion.chunk dicts."""

    def __init__(self, response, is_private=True):
        self.response = response
        self.create_time = int(time.time())
        self.line_iterator = response.iter_lines()
        self.is_private = is_private
        self.stop = False

    def __iter__(self):
        return self

    def __next__(self):
        try:
            if self.stop:
                raise StopIteration
            response_line = next(self.line_iterator)
            # Skip keep-alive blanks between SSE events.
            if response_line == b"":
                return self.__next__()
            if response_line == b"\n":
                return self.__next__()

            response_line = response_line.replace(b"data: ", b"")
            response_result = json.loads(response_line.decode("utf-8"))
            if self.is_private:
                if response_result.get("finish"):
                    self.stop = True
                if response_result.get("err"):
                    raise ValueError(f"minimax api response error: {response_result['err']}")
            if not self.is_private:
                if response_result["choices"][0].get("finish_reason", None):
                    raise StopIteration

            data = {}
            if self.is_private:
                data = json.loads(response_result["data"])

            stream_response = {
                "id": f"minimax_{self.create_time}",
                "created": self.create_time,
                "object": "chat.completion.chunk",
                "model": response_result.get("model", "abab5.5-chat"),
                "choices": [
                    {
                        "index": 0,
                        "finish_reason": "stop",
                        "delta": {
                            "role": "assistant",
                            "content": response_result["choices"][0]["messages"][0]["text"]
                            if not self.is_private
                            else data.get("text", ""),
                        },
                    },
                ],
                "usage": {"prompt_tokens": 10, "completion_tokens": 100},
            }

            return stream_response
        except StopIteration as exc:  # If there is no more event
            raise StopIteration from exc
        except Exception as err:
            print("Exception:", err.__class__.__name__, err, file=sys.stderr, end="\n\n")
            raise StopIteration from err


def chat_completion(messages, llm_config):
    url = _make_api_url()
    headers = _make_header()
    if _is_private_llm():
        payload = _make_private_payload(messages, llm_config)
    else:
        payload = _make_public_payload(messages, llm_config)

    response = requests.post(url, headers=headers, json=payload)
    response_json = json.loads(response.text)
    if not response_json.get("texts", []):
        raise ValueError(f"minimax api response error: {response_json}")
    return {"content": response_json["texts"][0]}


def stream_chat_completion(messages, llm_config):
    url = _make_api_url()
    headers = _make_header()
    if _is_private_llm():
        payload = _make_private_payload(messages, llm_config, True)
    else:
        payload = _make_public_payload(messages, llm_config, True)

    # stream=True lets iter_lines() consume events as they arrive.
    response = requests.post(url, headers=headers, json=payload, stream=True)
    stream_iters = StreamIterWrapper(response, _is_private_llm())
    return stream_iters


def _is_private_llm():
    api_base_url = os.environ.get("OPENAI_API_BASE", "")
    return not api_base_url.startswith("https://api.minimax.chat")


def _make_api_url():
    api_base_url = os.environ.get("OPENAI_API_BASE", None)
    if not api_base_url:
        raise ValueError("minimax api url is not set")

    if api_base_url.startswith("https://api.minimax.chat"):
        if api_base_url.endswith("/"):
            api_base_url = api_base_url[:-1]
        if not api_base_url.endswith("/v1"):
            api_base_url = api_base_url + "/v1"
        api_base_url += "/text/chatcompletion_pro"

        api_key = os.environ.get("OPENAI_API_KEY", None)
        if not api_key:
            raise ValueError("minimax api key is not set")

        group_id = api_key.split("##")[0]
        api_base_url += f"?GroupId={group_id}"
        return api_base_url
    else:
        if api_base_url.endswith("/"):
            api_base_url = api_base_url[:-1]
        if not api_base_url.endswith("/interact"):
            api_base_url = api_base_url + "/interact"
        return api_base_url


def _make_api_key():
    if _is_private_llm():
        return ""

    # Public MiniMax keys are stored as "<group_id>##<api_key>".
    api_key = os.environ.get("OPENAI_API_KEY", None)
    return api_key.split("##")[1]


def _make_header():
    api_key = _make_api_key()
    return {
        **({"Authorization": f"Bearer {api_key}"} if not _is_private_llm() else {}),
        "Content-Type": "application/json",
    }


def _to_private_messages(messages):
    new_messages = []
    for message in messages:
        if message["role"] == "user":
            new_messages.append({"role": "user", "name": "user", "text": message["content"]})
        else:
            new_messages.append({"role": "ai", "name": "ai", "text": message["content"]})
    # Trailing empty "ai" turn marks where the model should continue.
    new_messages.append({"role": "ai", "name": "ai", "text": ""})
    return new_messages


def _make_private_payload(messages, llm_config, stream=False):
    return {
        "data": _to_private_messages(messages),
        "model_control": {
            "system_data": [
                {
                    "role": "system",
                    "ai_setting": "ai",
                    "text": "你是minimax编码助理,擅长编写代码,编写注释,编写测试用例,并且很注重编码的规范性。",
                },
            ],
            # "alpha_frequency": 128,
            # "alpha_frequency_src": 1,
            # "alpha_presence": 0,
            # "alpha_presence_src": 0,
            # "block_ngram": 0,
            # "clean_init_no_penalty_list": True,
            # "context_block_ngram": 0,
            # "factual_topp": False,
            # "lamda_decay": 1,
            # "length_penalty": 1,
            # "no_penalty_list": ",",
            # "omega_bound": 0,
            # "repeat_filter": False,
            # "repeat_sampling": 1,
            # "skip_text_mask": True,
            "tokens_to_generate": llm_config.get("max_tokens", 2048),
            # "sampler_type": "nucleus",
            "beam_width": 1,
            # "delimiter": "\n",
            # "min_length": 0,
            # "skip_info_mask": True,
            "stop_sequence": [],
            # "top_p": 0.95,
            "temperature": llm_config.get("temperature", 0.95),
        },
        "stream": stream,
    }


def _to_public_messages(messages):
    new_messages = []
    for message in messages:
        if message["role"] == "user":
            new_messages.append(
                {"sender_type": "USER", "sender_name": "USER", "text": message["content"]}
            )
        else:
            new_messages.append(
                {"sender_type": "BOT", "sender_name": "ai", "text": message["content"]}
            )
    return new_messages


def _make_public_payload(messages, llm_config, stream=False):
    response = {
        "model": "abab5.5-chat",
        "tokens_to_generate": llm_config.get("max_tokens", 2048),
        "temperature": llm_config.get("temperature", 0.1),
        # "top_p": 0.9,
        "reply_constraints": {"sender_type": "BOT", "sender_name": "ai"},
        "sample_messages": [],
        "plugins": [],
        "messages": _to_public_messages(messages),
        "bot_setting": [
            {
                "bot_name": "ai",
                "content": (
                    "MM智能助理是一款由MiniMax自研的,"
                    "没有调用其他产品的接口的大型语言模型。"
                    "MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。"
                ),
            }
        ],
        "stream": stream,
    }
    return response
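A minimal usage sketch for this module, assuming the public MiniMax endpoint and the "group_id##api_key" key format that _make_api_url and _make_api_key expect; the values below are placeholders, not real credentials:

# Hedged example; endpoint and key values are placeholders.
import os

os.environ["OPENAI_API_BASE"] = "https://api.minimax.chat"
os.environ["OPENAI_API_KEY"] = "my_group_id##my_api_key"

from minimax_chat import stream_chat_completion

messages = [{"role": "user", "content": "Write a one-line docstring for a sort function."}]
for chunk in stream_chat_completion(messages, {"temperature": 0.9, "max_tokens": 256}):
    # Each chunk mimics an OpenAI chat.completion.chunk dict.
    print(chunk["choices"][0]["delta"]["content"], end="", flush=True)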
@@ -7,8 +7,10 @@ import sys
 import openai

 sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.append(os.path.dirname(__file__))

 from ide_services.services import log_warn
+from minimax_chat import stream_chat_completion


 def _try_remove_markdown_block_flag(content):
@@ -46,18 +48,23 @@ def chat_completion_stream(messages, llm_config, error_out: bool = True, stream_
     """
     for try_times in range(3):
         try:
-            client = openai.OpenAI(
-                api_key=os.environ.get("OPENAI_API_KEY", None),
-                base_url=os.environ.get("OPENAI_API_BASE", None),
-            )
+            if llm_config.get("model", "").startswith("abab"):
+                response = stream_chat_completion(messages, llm_config)
+            else:
+                client = openai.OpenAI(
+                    api_key=os.environ.get("OPENAI_API_KEY", None),
+                    base_url=os.environ.get("OPENAI_API_BASE", None),
+                )

-            llm_config["stream"] = True
-            llm_config["timeout"] = 8
-            response = client.chat.completions.create(messages=messages, **llm_config)
+                llm_config["stream"] = True
+                llm_config["timeout"] = 8
+                response = client.chat.completions.create(messages=messages, **llm_config)

             response_result = {"content": None, "function_name": None, "parameters": ""}
             for chunk in response:  # pylint: disable=E1133
-                chunk = chunk.dict()
+                if not isinstance(chunk, dict):
+                    chunk = chunk.dict()
+
                 delta = chunk["choices"][0]["delta"]
                 if "tool_calls" in delta and delta["tool_calls"]:
                     tool_call = delta["tool_calls"][0]["function"]
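The key change in this hunk is the dispatch on the model-name prefix; a small standalone sketch of that routing decision, with model names taken from this changeset:

def pick_backend(llm_config: dict) -> str:
    # "abab*" models go to the MiniMax wrapper; everything else uses openai.OpenAI.
    return "minimax" if llm_config.get("model", "").startswith("abab") else "openai"


assert pick_backend({"model": "abab5.5-chat"}) == "minimax"
assert pick_backend({"model": "gpt-3.5-turbo-1106"}) == "openai"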
rewrite/command.yml (new file, 5 lines)
@@ -0,0 +1,5 @@
description: rewrite selected code.
hint: question
input: required
steps:
  - run: $devchat_python $command_path/rewrite.py "$input"
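For orientation, a hedged sketch of how a run step like the one above could be expanded and executed; the substitution rules and variable values are assumptions, since the workflow runner itself is not part of this diff:

import subprocess

# Placeholder values; the real runner supplies $devchat_python, $command_path, $input.
variables = {
    "devchat_python": "/usr/bin/python3",
    "command_path": "/home/user/.chat/workflows/org/rewrite",
    "input": "rename variables for clarity",
}
cmd = '$devchat_python $command_path/rewrite.py "$input"'
for name, value in variables.items():
    cmd = cmd.replace(f"${name}", value)
subprocess.run(cmd, shell=True, check=True)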
rewrite/doc_comment/command.yml (new file, 3 lines)
@@ -0,0 +1,3 @@
description: add doc comment for selected code
steps:
  - run: $devchat_python $command_path/../rewrite.py "add doc comment"
rewrite/optimize_names/command.yml (new file, 4 lines)
@@ -0,0 +1,4 @@
description: Optimizing variable and function names in code.
steps:
  - run: $devchat_python $command_path/../rewrite.py "Refine internal variable and function names within the code to achieve concise and meaningful identifiers that comply with English naming conventions."
rewrite/optimize_string/command.yml (new file, 3 lines)
@@ -0,0 +1,3 @@
description: improve the readability and conformity of a given string to English language conventions.
steps:
  - run: $devchat_python $command_path/../rewrite.py "enhance the given string's readability and ensure it adheres to English linguistic standards"
rewrite/rewrite.py (new file, 122 lines)
@@ -0,0 +1,122 @@
import os
import sys
import re
import json

home = os.path.expanduser("~")
org_libs_path = os.path.join(home, ".chat", "workflows", "org", "libs")
sys_libs_path = os.path.join(home, ".chat", "workflows", "sys", "libs")
sys.path.append(org_libs_path)
sys.path.append(sys_libs_path)

from llm_api import chat_completion_stream  # noqa: E402
from ide_services.services import visible_lines, selected_lines, diff_apply  # noqa: E402


def create_prompt():
    question = sys.argv[1]

    visible_data = visible_lines()
    selected_data = selected_lines()

    file_path = visible_data["filePath"]
    if not os.path.exists(file_path):
        print("Current file is not a valid filename:", file_path, file=sys.stderr, flush=True)
        sys.exit(-1)

    if selected_data["selectedText"] == "":
        print("Please select some text.", file=sys.stderr, flush=True)
        sys.exit(-1)

    prompt = f"""
你的任务是:
{question}
根据任务要求,仅修改选中的代码部分。请确保修改后的代码段与选中的代码保持相同的缩进,\
以便与现有的代码结构无缝集成并保持正确的语法。只重构选中的代码。保留所有其他信息。\
以下是您参考的相关上下文信息:
1. 选中代码信息: {selected_data}
2. 可视窗口代码信息: {visible_data}
"""
    return prompt


def extract_markdown_block(text):
    """
    Extracts the first Markdown code block from the given text without the language specifier.

    :param text: A string containing Markdown text
    :return: The content of the first Markdown code block, or the original text if none is found
    """
    # Match a Markdown code block, ignoring the optional language tag.
    pattern = r"```(?:\w+)?\s*\n(.*?)\n```"
    match = re.search(pattern, text, re.DOTALL)

    if match:
        # Return the first matched block with the fences and language tag stripped;
        # the newline before the closing fence is dropped, everything else is kept.
        block_content = match.group(1)
        return block_content
    else:
        # Fall back to the raw text when no code block is found.
        return text


def replace_selected(new_code):
    selected_data = selected_lines()
    select_file = selected_data["filePath"]
    select_range = selected_data["selectedRange"]  # [start_line, start_col, end_line, end_col]

    # Read the file
    with open(select_file, "r") as file:
        lines = file.readlines()
    lines.append("\n")

    # Modify the selected lines
    start_line, start_col, end_line, end_col = select_range

    # If the selection spans multiple lines, handle the last line and delete the lines in between
    if start_line != end_line:
        lines[start_line] = lines[start_line][:start_col] + new_code
        # Append the text after the selection on the last line
        lines[start_line] += lines[end_line][end_col:]
        # Delete the lines between start_line and end_line
        del lines[start_line + 1 : end_line + 1]
    else:
        # If the selection is within a single line, remove the selected text
        lines[start_line] = lines[start_line][:start_col] + new_code + lines[end_line][end_col:]

    # Combine everything back together
    modified_text = "".join(lines)

    # Write the changes back to the file
    with open(select_file, "w") as file:
        file.write(modified_text)


def main():
    # messages = json.loads(
    #     os.environ.get("CONTEXT_CONTENTS", json.dumps([{"role": "user", "content": ""}]))
    # )
    messages = [{"role": "user", "content": create_prompt()}]

    response = chat_completion_stream(
        messages,
        {"model": os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106")},
        stream_out=True,
    )
    if not response:
        sys.exit(-1)
    print("\n")
    new_code = extract_markdown_block(response["content"])
    # Check if new_code is empty and handle the case appropriately
    if not new_code:
        print("Parsing result failed. Exiting with error code -1", file=sys.stderr)
        sys.exit(-1)

    # replace_selected(new_code)
    selected_data = selected_lines()
    select_file = selected_data["filePath"]
    diff_apply("", new_code)
    # print(response["content"], flush=True)
    sys.exit(0)


if __name__ == "__main__":
    main()
    # print("hello")
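A quick check of extract_markdown_block on a made-up model reply; the fallback returns the raw text when no fenced block is present:

reply = (
    "Here is the refactored code:\n"
    "```python\n"
    "def add(a, b):\n"
    "    return a + b\n"
    "```\n"
    "Let me know if this works."
)
print(extract_markdown_block(reply))    # -> "def add(a, b):\n    return a + b"
print(extract_markdown_block("plain"))  # no block: returns "plain" unchanged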
@@ -14,7 +14,6 @@ class DirectoryStructureBase(ABC):
     ) -> None:
         self._root_path = root_path

-        self._client = OpenAI()
         self._chat_language = chat_language

     @property
@@ -1,20 +1,20 @@
 import json
+import os
 from pathlib import Path
 from typing import Callable, List

-import tiktoken
 from assistants.directory_structure.base import DirectoryStructureBase
 from assistants.rerank_files import rerank_files
+from minimax_util import chat_completion_no_stream_return_json
 from openai_util import create_chat_completion_content
-from tools.directory_viewer import (
-    ListViewer,
-)
+from tools.directory_viewer import ListViewer
+from tools.tiktoken_util import get_encoding


 class RelevantFileFinder(DirectoryStructureBase):
     model_name = "gpt-3.5-turbo-1106"
     dir_token_budget = 16000 * 0.95
-    encoding: tiktoken.Encoding = tiktoken.encoding_for_model("gpt-3.5-turbo-1106")
+    encoding = get_encoding("cl100k_base")

     def _paginate_dir_structure(
         self, criteria: Callable[[Path], bool], style: str = "list"
@@ -79,23 +79,48 @@ class RelevantFileFinder(DirectoryStructureBase):

         return message

+    def _mk_message_cn(self, objective: str, dir_structure: str) -> str:
+        message = f"""
+你是一位智能编程助手,你的任务是理解代码库的目录结构,推测各目录的作用,
+并根据用户输入的问题或者目标,找到与之最相关的10个文件。
+
+请注意,你的目的并不是分析各个文件里的代码,而是通过分析项目的结构来判断
+哪个文件最有可能包含用户需要的信息。
+
+以下是代码库的目录结构:
+
+{dir_structure}
+
+
+请根据你的理解,找出10个和以下问题最相关的文件:
+
+“{objective}”
+
+
+请按以下JSON格式回复:
+{{
+    "files": ["<文件1的路径>" , "<文件2的路径>", "<文件3的路径>", ... ]
+}}
+"""
+        return message
+
     def _find_relevant_files(self, objective: str, dir_structure_pages: List[str]) -> List[str]:
         files: List[str] = []
         for dir_structure in dir_structure_pages:
-            user_msg = self._mk_message(objective, dir_structure)
+            # user_msg = self._mk_message(objective, dir_structure)
+            user_msg = self._mk_message_cn(objective, dir_structure)

-            response = create_chat_completion_content(
-                client=self._client,
-                model=self.model_name,
-                messages=[
-                    {"role": "user", "content": user_msg},
-                ],
-                response_format={"type": "json_object"},
-                temperature=0.1,
-            )
+            model = os.environ.get("LLM_MODEL", self.model_name)

-            json_res = json.loads(response)
+            json_res = chat_completion_no_stream_return_json(
+                messages=[{"role": "user", "content": user_msg}],
+                llm_config={
+                    "model": model,
+                    "temperature": 0.1,
+                },
+            )

             files.extend(json_res.get("files", []))

         reranked = rerank_files(
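Given the JSON format the Chinese prompt requests, chat_completion_no_stream_return_json is expected to hand back a dict of this shape; the paths below are illustrative:

json_res = {"files": ["libs/llm_api/minimax_chat.py", "rewrite/rewrite.py"]}

files = []
files.extend(json_res.get("files", []))  # same accumulation step as above
assert files == ["libs/llm_api/minimax_chat.py", "rewrite/rewrite.py"]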
@@ -1,6 +1,9 @@
 import json
+import os
+import sys
 from typing import List, Tuple

+from minimax_util import chat_completion_no_stream_return_json
 from openai_util import create_chat_completion_content

 # ruff: noqa: E501
@@ -28,6 +31,37 @@ Accumulated Knowledge: {accumulated_knowledge}

 Answer:
 """
+
+rerank_file_prompt_cn = """
+你是一位智能编程助手。
+用户将给你一个文件列表并提出一个和编程相关的问题,
+根据你对问题和每个文件可能包含的内容的理解,
+找到文件列表中有助于回答该问题的文件,
+并判断文件与问题的相关程度,从1到10为文件评分(10分表示相关度非常高,1分表示相关度很低),
+最后将文件按相关度从高到低排序。
+
+请注意,返回的所有文件路径都应在用户给出的文件列表里,不能额外添加文件,也不能修改文件路径。
+
+
+以下是用户给出的文件列表:
+
+{files}
+
+
+用户的问题是: {question}
+目前积累的相关背景知识: {accumulated_knowledge}
+
+请按以下JSON格式回复:
+{{
+    "result": [
+        {{"item": "<最相关的文件路径>", "relevance": 7}},
+        {{"item": "<第二相关的文件路径>", "relevance": 4}},
+        {{"item": "<第三相关的文件路径>", "relevance": 3}}
+    ]
+}}
+
+
+"""

 RERANK_MODEL = "gpt-3.5-turbo-1106"


@@ -50,25 +84,28 @@ def rerank_files(
         assert isinstance(file, str), "items must be a list of str when item_type is 'file'"
         files_str += f"- {file}\n"

-    user_msg = rerank_file_prompt.format(
+    user_msg = rerank_file_prompt_cn.format(
         files=files_str,
         question=question,
         accumulated_knowledge=knowledge,
     )

-    response = create_chat_completion_content(
-        model=RERANK_MODEL,
+    model = os.environ.get("LLM_MODEL", RERANK_MODEL)
+    result = chat_completion_no_stream_return_json(
         messages=[
             {
                 "role": "user",
                 "content": user_msg,
             },
         ],
-        response_format={"type": "json_object"},
-        temperature=0.1,
+        llm_config={
+            "model": model,
+            "temperature": 0.1,
+        },
     )
+    if not result:
+        return []

-    result = json.loads(response)
     reranked = [(i["item"], i["relevance"]) for i in result["result"]]

     return reranked
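And the matching shape for the rerank prompt: the result list pairs each file path with a 1-10 relevance score, which the final list comprehension flattens into tuples (sample values only):

result = {
    "result": [
        {"item": "unit_tests/propose_test.py", "relevance": 7},
        {"item": "unit_tests/main.py", "relevance": 4},
    ]
}
reranked = [(i["item"], i["relevance"]) for i in result["result"]]
assert reranked == [("unit_tests/propose_test.py", 7), ("unit_tests/main.py", 4)]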
@@ -1,3 +1,3 @@
 description: Generate unit tests.
 steps:
-  - run: $command_python $command_path/main.py "$input"
+  - run: $devchat_python $command_path/main.py "$input"
@@ -1,7 +1,9 @@
 from typing import List

 from assistants.directory_structure.relevant_file_finder import RelevantFileFinder
-from prompts import FIND_REFERENCE_PROMPT
+
+# from prompts import FIND_REFERENCE_PROMPT
+from prompts_cn import FIND_REFERENCE_PROMPT
 from tools.file_util import verify_file_list


@@ -4,9 +4,12 @@ from enum import Enum


 class TUILanguage(Enum):
-    EN = ("en", "English")
-    ZH = ("zh", "Chinese")
-    Other = ("en", "English")  # default to show English
+    EN = ("en", "英文")
+    ZH = ("zh", "中文")
+    Other = ("zh", "中文")
+    # EN = ("en", "English")
+    # ZH = ("zh", "Chinese")
+    # Other = ("en", "English")  # default to show English

     @classmethod
     def from_str(cls, language: str) -> "TUILanguage":
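How the enum is consumed after this change, assuming from_str (whose body is outside this hunk) maps a language code to the matching member:

tui_lang = TUILanguage.from_str("zh")  # presumably TUILanguage.ZH
code, label = tui_lang.value           # -> ("zh", "中文")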
@@ -2,15 +2,17 @@ import os
 import sys

 import click
-from find_reference_tests import find_reference_tests
-from i18n import TUILanguage, get_translation
-from model import FuncToTest, TokenBudgetExceededException, UserCancelledException
-from propose_test import propose_test
-from tools.file_util import retrieve_file_content
-from write_tests import write_and_print_tests

+sys.path.append(os.path.dirname(__file__))
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs"))

+from find_reference_tests import find_reference_tests  # noqa: E402
+from i18n import TUILanguage, get_translation  # noqa: E402
+from model import FuncToTest, TokenBudgetExceededException, UserCancelledException  # noqa: E402
+from propose_test import propose_test  # noqa: E402
+from tools.file_util import retrieve_file_content  # noqa: E402
+from write_tests import write_and_print_tests  # noqa: E402
+
 from chatmark import Checkbox, Form, Step, TextEditor  # noqa: E402
 from ide_services import ide_language  # noqa: E402

@@ -131,7 +133,8 @@ def main(input: str):
     user_prompt = f"Help me write unit tests for the `{func_name}` function"

     repo_root = os.getcwd()
-    ide_lang = ide_language()
+    # ide_lang = ide_language()
+    ide_lang = "zh"
     tui_lang = TUILanguage.from_str(ide_lang)
     _i = get_translation(tui_lang)
unit_tests/minimax_util.py (new file, 9 lines)
@@ -0,0 +1,9 @@
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "libs"))

from llm_api import (
    chat_completion_no_stream_return_json,
    chat_completion_stream,
)
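This shim only fixes up sys.path and re-exports the two llm_api entry points, so callers inside unit_tests/ can simply write the following (config values are illustrative):

from minimax_util import chat_completion_stream

reply = chat_completion_stream(
    messages=[{"role": "user", "content": "hello"}],
    llm_config={"model": "abab5.5-chat"},
)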
unit_tests/prompts_cn.py (new file, 68 lines)
@@ -0,0 +1,68 @@
# ruff: noqa: E501
# Do not limit the length of each line of the prompts.


# Chinese prompt: propose up to 6 test cases for the target function, returned as JSON.
PROPOSE_TEST_PROMPT = """
你是一位智能测试用例生成助手。
给定一个用户提示和一个目标函数,请根据提示为目标函数生成测试用例。

用户提示如下:

{user_prompt}

目标函数是 `{function_name}`, 该函数所在的文件是 {file_path}。

以下是与该函数相关的源代码:

{relevant_content}

请为每个测试用例提供一句话的描述,描述该测试用例所测试的行为。
你不需要用代码编写测试用例,只需用普通的自然语言描述即可。
最多生成 6 个测试用例。

请按照以下JSON格式回复:
{{
    "test_cases": [
        {{"description": "<测试用例1的自然语言描述>"}},
        {{"description": "<测试用例2的自然语言描述>"}},
    ]
}}
"""


# Chinese prompt: find a reference test file to guide writing tests for the target function.
FIND_REFERENCE_PROMPT = """
请找到一个合适的参考测试文件,该文件可用于为文件 {file_path} 中的 `{function_name}` 函数编写测试提供指导。
参考文件应该提供一个清晰的示例,演示测试类似性质的函数的最佳实践。
"""


# Chinese prompt: write runnable unit tests for the listed cases, matching the reference code's frameworks.
WRITE_TESTS_PROMPT = """
你是一位智能单元测试生成助手。
给定一个目标函数、一些参考代码和一系列具体的测试用例描述,请完成单元测试的代码编写。
每个测试函数都应该是独立且可执行的。

请确保与参考代码使用相同的测试框架、模拟对象库、断言库等,并采取相似的模拟策略、判断策略等。

目标函数是 `{function_name}`,该函数位于文件 {file_path}。

以下是与该函数相关的源代码:

{relevant_content}

{reference_content}

以下是测试用例列表:

{test_cases_str}


请按照以下格式回复:

测试用例 1. <测试用例1的原始描述>

<测试用例1对应的测试函数代码>

测试用例 2. <测试用例2的原始描述>

<测试用例2对应的测试函数代码>
"""
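These templates are filled with str.format, with the literal JSON braces escaped as {{ }}; a minimal fill of PROPOSE_TEST_PROMPT (values are illustrative):

msg = PROPOSE_TEST_PROMPT.format(
    user_prompt="Help me write unit tests for the `add` function",
    function_name="add",
    file_path="calc/ops.py",
    relevant_content="def add(a, b):\n    return a + b",
)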
@@ -1,14 +1,20 @@
 import json
+import os
+import sys
 from functools import partial
 from typing import List

-import tiktoken
+from minimax_util import chat_completion_no_stream_return_json
 from model import FuncToTest, TokenBudgetExceededException
 from openai_util import create_chat_completion_content
-from prompts import PROPOSE_TEST_PROMPT
+
+# from prompts import PROPOSE_TEST_PROMPT
+from prompts_cn import PROPOSE_TEST_PROMPT
+from tools.tiktoken_util import get_encoding

 MODEL = "gpt-3.5-turbo-1106"
 # MODEL = "gpt-4-1106-preview"
+ENCODING = "cl100k_base"
 TOKEN_BUDGET = int(16000 * 0.9)


@@ -20,7 +26,7 @@ def _mk_user_msg(
     """
     Create a user message to be sent to the model within the token budget.
     """
-    encoding: tiktoken.Encoding = tiktoken.encoding_for_model(MODEL)
+    encoding = get_encoding(ENCODING)

     func_content = f"function code\n```\n{func_to_test.func_content}\n```\n"
     class_content = ""
@@ -81,14 +87,16 @@ def propose_test(
         chat_language=chat_language,
     )

-    content = create_chat_completion_content(
-        model=MODEL,
+    model = os.environ.get("LLM_MODEL", MODEL)
+    content = chat_completion_no_stream_return_json(
         messages=[{"role": "user", "content": user_msg}],
-        response_format={"type": "json_object"},
-        temperature=0.1,
+        llm_config={
+            "model": model,
+            "temperature": 0.1,
+        },
     )

-    cases = json.loads(content).get("test_cases", [])
+    cases = content.get("test_cases", [])

     descriptions = []
     for case in cases:
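For context, the token budget named above is enforced by counting tokens of the assembled message; a hedged sketch of that check, where the names come from this diff but the exception message is an assumption:

encoding = get_encoding(ENCODING)  # ENCODING = "cl100k_base"
user_msg = "...assembled prompt..."
if len(encoding.encode(user_msg)) > TOKEN_BUDGET:
    raise TokenBudgetExceededException("prompt exceeds the 16k token budget")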
unit_tests/tools/tiktoken_util.py (new file, 20 lines)
@@ -0,0 +1,20 @@
import tiktoken


def get_encoding(encoding_name: str):
    """
    Get a tiktoken encoding by name.
    """
    try:
        return tiktoken.get_encoding(encoding_name)
    except Exception:
        # Fall back to constructing the encoding directly when the
        # packaged loader fails (e.g. no compiled extension available).
        from tiktoken import registry
        from tiktoken.core import Encoding
        from tiktoken.registry import _find_constructors

        def _get_encoding(name: str):
            _find_constructors()
            constructor = registry.ENCODING_CONSTRUCTORS[name]
            return Encoding(**constructor(), use_pure_python=True)

        return _get_encoding(encoding_name)
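Typical use of the helper; encode returns token ids, and the count is what gets compared against the budgets above:

enc = get_encoding("cl100k_base")
tokens = enc.encode("hello world")
print(len(tokens))  # token count, e.g. 2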
@@ -1,13 +1,18 @@
+import os
 from functools import partial
 from typing import List, Optional

-import tiktoken
+from minimax_util import chat_completion_stream
 from model import FuncToTest, TokenBudgetExceededException
 from openai_util import create_chat_completion_chunks
-from prompts import WRITE_TESTS_PROMPT
+
+# from prompts import WRITE_TESTS_PROMPT
+from prompts_cn import WRITE_TESTS_PROMPT
 from tools.file_util import retrieve_file_content
+from tools.tiktoken_util import get_encoding

 MODEL = "gpt-4-1106-preview"
+ENCODING = "cl100k_base"
 TOKEN_BUDGET = int(128000 * 0.9)


@@ -18,7 +23,7 @@ def _mk_write_tests_msg(
     chat_language: str,
     reference_files: Optional[List[str]] = None,
 ) -> Optional[str]:
-    encoding: tiktoken.Encoding = tiktoken.encoding_for_model(MODEL)
+    encoding = get_encoding(ENCODING)

     test_cases_str = ""
     for i, test_case in enumerate(test_cases, 1):
@@ -92,13 +97,9 @@ def write_and_print_tests(
         chat_language=chat_language,
     )

-    chunks = create_chat_completion_chunks(
-        model=MODEL,
+    model = os.environ.get("LLM_MODEL", MODEL)
+    chat_completion_stream(
         messages=[{"role": "user", "content": user_msg}],
-        temperature=0.1,
+        llm_config={"model": model, "temperature": 0.1},
+        stream_out=True,
     )
-
-    for chunk in chunks:
-        if chunk.choices[0].finish_reason == "stop":
-            break
-        print(chunk.choices[0].delta.content, flush=True, end="")