Remove llm_api module
parent b3cc1eb54f
commit c905c7c850
@@ -1,19 +0,0 @@
from .chat import chat, chat_json
from .memory.base import ChatMemory
from .memory.fixsize_memory import FixSizeChatMemory
from .openai import chat_completion_no_stream_return_json, chat_completion_stream
from .text_confirm import llm_edit_confirm
from .tools_call import chat_tools, llm_func, llm_param

__all__ = [
    "chat_completion_stream",
    "chat_completion_no_stream_return_json",
    "chat_json",
    "chat",
    "llm_edit_confirm",
    "llm_func",
    "llm_param",
    "chat_tools",
    "ChatMemory",
    "FixSizeChatMemory",
]
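For context, this package re-exported a small decorator-based API. Assuming the package was importable as llm_api (the commit title names it, but the exact import path is not shown in the diff), a typical import looked like:

# hypothetical import of the removed public API
from llm_api import chat, chat_json, chat_tools, FixSizeChatMemory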
@@ -1,109 +0,0 @@
import json
import os
import sys
from functools import wraps

import openai

from .memory.base import ChatMemory
from .openai import (
    chat_completion_no_stream_return_json,
    chat_completion_stream,
    chat_completion_stream_commit,
    chunks_content,
    retry_timeout,
    stream_out_chunk,
    to_dict_content_and_call,
)
from .pipeline import exception_handle, pipeline, retry

chat_completion_stream_out = exception_handle(
    retry(
        pipeline(
            chat_completion_stream_commit,
            retry_timeout,
            stream_out_chunk,
            chunks_content,
            to_dict_content_and_call,
        ),
        times=3,
    ),
    lambda err: {
        "content": None,
        "function_name": None,
        "parameters": "",
        "error": err.type if isinstance(err, openai.APIError) else err,
    },
)


def chat(
    prompt,
    memory: ChatMemory = None,
    stream_out: bool = False,
    model: str = os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106"),
    **llm_config,
):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal prompt, memory, model, llm_config
            prompt = prompt.format(**kwargs)
            messages = memory.contexts() if memory else []
            if not any(item["content"] == prompt for item in messages) and prompt:
                messages.append({"role": "user", "content": prompt})
            if "__user_request__" in kwargs:
                messages.append(kwargs["__user_request__"])
                del kwargs["__user_request__"]

            llm_config["model"] = model
            if not stream_out:
                response = chat_completion_stream(messages, llm_config=llm_config)
            else:
                response = chat_completion_stream_out(messages, llm_config=llm_config)
            if not response.get("content", None):
                print(f"call {func.__name__} failed:", response["error"], file=sys.stderr)
                return None

            if memory:
                memory.append(
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": response["content"]},
                )
            return response["content"]

        return wrapper

    return decorator


def chat_json(
    prompt,
    memory: ChatMemory = None,
    model: str = os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106"),
    **llm_config,
):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal prompt, memory, model, llm_config
            prompt = prompt.format(**kwargs)
            messages = memory.contexts() if memory else []
            if not any(item["content"] == prompt for item in messages):
                messages.append({"role": "user", "content": prompt})

            llm_config["model"] = model
            response = chat_completion_no_stream_return_json(messages, llm_config=llm_config)
            if not response:
                print(f"call {func.__name__} failed.", file=sys.stderr)

            if memory:
                memory.append(
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": json.dumps(response)},
                )
            return response

        return wrapper

    return decorator
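A minimal usage sketch of the removed chat and chat_json decorators. The prompt text and function names below are invented for illustration, the package is assumed to be importable as llm_api, and OPENAI_API_KEY is assumed to be set; the decorated function bodies are never executed because the wrappers return the model output directly.

# sketch only, not part of the removed code
from llm_api import chat, chat_json


@chat(prompt="Summarize the following text in one sentence:\n{text}", stream_out=False)
def summarize(text):  # body unused; the wrapper returns the LLM's text content (or None on failure)
    pass


@chat_json(prompt="Extract the person's name and age from: {text}\nReturn JSON.")
def extract_fields(text):  # the wrapper returns a dict parsed from the model's JSON reply
    pass


summary = summarize(text="DevChat removed its built-in llm_api module.")
fields = extract_fields(text="Alice is 30 years old.")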
@@ -1,32 +0,0 @@
class ChatMemory:
    """
    ChatMemory is the base class for all chat memory classes.
    """

    def __init__(self):
        pass

    def append(self, request, response):
        """
        Append a request and response to the memory.
        """
        # must be implemented in subclasses
        pass

    def append_request(self, request):
        """
        Append a request to the memory.
        """
        pass

    def append_response(self, response):
        """
        Append a response to the memory.
        """
        pass

    def contexts(self):
        """
        Return the contexts of the memory.
        """
        pass
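The base class is effectively an interface: a subclass only has to store messages and return them from contexts(). A minimal in-memory subclass, written here purely for illustration (it was not part of the removed code):

class ListChatMemory(ChatMemory):
    """Keeps every request/response pair in a plain list (illustrative sketch)."""

    def __init__(self):
        super().__init__()
        self._messages = []

    def append(self, request, response):
        self._messages.extend([request, response])

    def append_request(self, request):
        self._messages.append(request)

    def append_response(self, response):
        self._messages.append(response)

    def contexts(self):
        return list(self._messages)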
@@ -1,51 +0,0 @@
from .base import ChatMemory


class FixSizeChatMemory(ChatMemory):
    """
    FixSizeChatMemory is a memory class that stores up
    to a fixed number of requests and responses.
    """

    def __init__(self, max_size: int = 5, messages=[], system_prompt=None):
        """
        Initialize the memory with a maximum size, optional seed messages,
        and an optional system prompt.
        """
        super().__init__()
        self._max_size = max_size
        # keep only the last max_size request/response pairs
        self._messages = messages[-max_size * 2 :]
        self._system_prompt = system_prompt

    def append(self, request, response):
        """
        Append a request and response to the memory.
        """
        self._messages.append(request)
        self._messages.append(response)
        if len(self._messages) > self._max_size * 2:
            self._messages = self._messages[-self._max_size * 2 :]

    def append_request(self, request):
        """
        Append a request to the memory.
        """
        self._messages.append(request)

    def append_response(self, response):
        """
        Append a response to the memory.
        """
        self._messages.append(response)
        if len(self._messages) > self._max_size * 2:
            self._messages = self._messages[-self._max_size * 2 :]

    def contexts(self):
        """
        Return the contexts of the memory.
        """
        messages = self._messages.copy()
        # insert the system prompt at the beginning, if one was provided
        if self._system_prompt:
            messages = [{"role": "system", "content": self._system_prompt}] + messages
        return messages
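A short sketch of how FixSizeChatMemory was meant to combine with the chat decorator: contexts() seeds the message list (including the system prompt) and append() records each exchange, capped at max_size pairs. The prompt, system text, and questions below are made up, and the llm_api import path is assumed.

from llm_api import FixSizeChatMemory, chat

memory = FixSizeChatMemory(
    max_size=5,  # keep at most the last 5 request/response pairs
    system_prompt="You are a terse assistant.",
)


@chat(prompt="{question}", memory=memory)
def ask(question):
    pass


ask(question="What does a fixed-size chat memory do?")
ask(question="Why cap it at max_size * 2 messages?")  # older pairs are dropped as new ones arrive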
@@ -1,197 +0,0 @@
# flake8: noqa: E402
import json
import os
import re
import sys
from functools import wraps
from typing import Dict, List, Optional

import openai

from .pipeline import (
    RetryException,
    exception_err,
    exception_handle,
    exception_output_handle,
    parallel,
    pipeline,
    retry,
)


def _try_remove_markdown_block_flag(content):
    """
    If content is a markdown code block, strip its leading ```xxx and trailing ``` fences.
    """
    # regex matching the opening and closing fences of a markdown code block
    pattern = r"^\s*```\s*(\w+)\s*\n(.*?)\n\s*```\s*$"

    # try to match the pattern
    match = re.search(pattern, content, re.DOTALL | re.MULTILINE)

    if match:
        # on a match, extract and return the content inside the code block
        _ = match.group(1)  # language
        markdown_content = match.group(2)
        return markdown_content.strip()
    else:
        # otherwise return the original content unchanged
        return content


def chat_completion_stream_commit(
    messages: List[Dict],  # [{"role": "user", "content": "hello"}]
    llm_config: Dict,  # {"model": "...", ...}
):
    client = openai.OpenAI(
        api_key=os.environ.get("OPENAI_API_KEY", None),
        base_url=os.environ.get("OPENAI_API_BASE", None),
    )

    llm_config["stream"] = True
    llm_config["timeout"] = 60
    return client.chat.completions.create(messages=messages, **llm_config)


def chat_completion_stream_raw(**kwargs):
    client = openai.OpenAI(
        api_key=os.environ.get("OPENAI_API_KEY", None),
        base_url=os.environ.get("OPENAI_API_BASE", None),
    )

    kwargs["stream"] = True
    kwargs["timeout"] = 60
    return client.chat.completions.create(**kwargs)


def stream_out_chunk(chunks):
    for chunk in chunks:
        chunk_dict = chunk.dict()
        delta = chunk_dict["choices"][0]["delta"]
        if delta.get("content", None):
            print(delta["content"], end="", flush=True)
        yield chunk


def retry_timeout(chunks):
    try:
        for chunk in chunks:
            yield chunk
    except (openai.APIConnectionError, openai.APITimeoutError) as err:
        raise RetryException(err) from err


def chunk_list(chunks):
    return [chunk for chunk in chunks]


def chunks_content(chunks):
    content = None
    for chunk in chunks:
        chunk_dict = chunk.dict()
        delta = chunk_dict["choices"][0]["delta"]
        if delta.get("content", None):
            if content is None:
                content = ""
            content += delta["content"]
    return content


def chunks_call(chunks):
    tool_calls = []

    for chunk in chunks:
        chunk = chunk.dict()
        delta = chunk["choices"][0]["delta"]
        if "tool_calls" in delta and delta["tool_calls"]:
            tool_call = delta["tool_calls"][0]["function"]
            if delta["tool_calls"][0].get("index", None) is not None:
                index = delta["tool_calls"][0]["index"]
                if index >= len(tool_calls):
                    tool_calls.append({"name": None, "arguments": ""})
            if tool_call.get("name", None):
                tool_calls[-1]["name"] = tool_call["name"]
            if tool_call.get("arguments", None):
                tool_calls[-1]["arguments"] += tool_call["arguments"]
    return tool_calls


def content_to_json(content):
    try:
        # the 1106 models may wrap JSON output in a ```json ... ``` block
        response_content = _try_remove_markdown_block_flag(content)
        response_obj = json.loads(response_content)
        return response_obj
    except json.JSONDecodeError as err:
        raise RetryException(err) from err
    except Exception as err:
        raise err


def to_dict_content_and_call(content, tool_calls=[]):
    return {
        "content": content,
        "function_name": tool_calls[0]["name"] if tool_calls else None,
        "parameters": tool_calls[0]["arguments"] if tool_calls else "",
        "tool_calls": tool_calls,
    }


chat_completion_content = retry(
    pipeline(chat_completion_stream_commit, retry_timeout, chunks_content), times=3
)

chat_completion_stream_content = retry(
    pipeline(chat_completion_stream_commit, retry_timeout, stream_out_chunk, chunks_content),
    times=3,
)

chat_completion_call = retry(
    pipeline(chat_completion_stream_commit, retry_timeout, chunks_call), times=3
)

chat_completion_no_stream_return_json = exception_handle(
    retry(
        pipeline(chat_completion_stream_commit, retry_timeout, chunks_content, content_to_json),
        times=3,
    ),
    exception_output_handle(lambda err: None),
)

chat_completion_stream = exception_handle(
    retry(
        pipeline(
            chat_completion_stream_commit,
            retry_timeout,
            chunks_content,
            to_dict_content_and_call,
        ),
        times=3,
    ),
    lambda err: {
        "content": None,
        "function_name": None,
        "parameters": "",
        "error": err.type if isinstance(err, openai.APIError) else err,
    },
)

chat_call_completion_stream = exception_handle(
    retry(
        pipeline(
            chat_completion_stream_commit,
            retry_timeout,
            chunk_list,
            parallel(chunks_content, chunks_call),
            to_dict_content_and_call,
        ),
        times=3,
    ),
    lambda err: {
        "content": None,
        "function_name": None,
        "parameters": "",
        "tool_calls": [],
        "error": err.type if isinstance(err, openai.APIError) else err,
    },
)
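The module-level names at the bottom are pre-composed pipelines; called directly, chat_completion_stream takes a message list plus an llm_config dict and returns a plain dict with content/function_name/parameters, or an error field if every retry failed. A hedged sketch of a direct call (the model name and message are examples only, and OPENAI_API_KEY, optionally OPENAI_API_BASE, is assumed to be set):

messages = [{"role": "user", "content": "Say hello in one word."}]
result = chat_completion_stream(messages, llm_config={"model": "gpt-3.5-turbo-1106"})

if result.get("content") is not None:
    print(result["content"])
else:
    # on failure the exception handler fills in an "error" field instead
    print("request failed:", result.get("error"))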
@@ -1,94 +0,0 @@
import sys
from typing import Dict

import openai


class RetryException(Exception):
    def __init__(self, err):
        self.error = err


def retry(func, times):
    def wrapper(*args, **kwargs):
        for index in range(times):
            try:
                return func(*args, **kwargs)
            except RetryException as err:
                if index + 1 == times:
                    raise err.error
                continue
            except Exception as err:
                raise err
        raise err.error

    return wrapper


def exception_err(func):
    def wrapper(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
            return True, result
        except Exception as err:
            return False, err

    return wrapper


def exception_output_handle(func):
    def wrapper(err):
        if isinstance(err, openai.APIError):
            print(err.type, file=sys.stderr, flush=True)
        else:
            print(err, file=sys.stderr, flush=True)
        return func(err)

    return wrapper


def exception_output_handle(func):
    def wrapper(err):
        if isinstance(err, openai.APIError):
            print(err.type, file=sys.stderr, flush=True)
        else:
            print(err, file=sys.stderr, flush=True)
        return func(err)

    return wrapper


def exception_handle(func, handler):
    def wrapper(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
            return result
        except Exception as err:
            return handler(err)

    return wrapper


def pipeline(*funcs):
    def wrapper(*args, **kwargs):
        for index, func in enumerate(funcs):
            if index > 0:
                if isinstance(args, Dict) and args.get("__type__", None) == "parallel":
                    args = func(*args["value"])
                else:
                    args = func(args)
            else:
                args = func(*args, **kwargs)
        return args

    return wrapper


def parallel(*funcs):
    def wrapper(args):
        results = {"__type__": "parallel", "value": []}
        for func in funcs:
            results["value"].append(func(args))
        return results

    return wrapper
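These helpers are generic function combinators with no OpenAI dependency: pipeline feeds each stage's return value into the next stage, a parallel stage fans one value out to several functions so the following stage receives the collected results as positional arguments, and retry re-runs the wrapped callable whenever it raises RetryException. A small self-contained sketch with toy stages (all names below are illustrative, not from the removed code):

def fetch(x):
    return x + 1


def double(v):
    return v * 2


def negate(v):
    return -v


combined = pipeline(fetch, parallel(double, negate), lambda a, b: {"double": a, "neg": b})
print(combined(3))  # {'double': 8, 'neg': -4}

flaky = retry(pipeline(fetch, double), times=3)
print(flaky(1))  # 4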
@@ -1,57 +0,0 @@
# flake8: noqa: E402
import os
import sys
from functools import wraps

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from chatmark import Checkbox, Form, TextEditor  # noqa: E402


class MissEditConfirmFieldException(Exception):
    pass


def edit_confirm(response):
    need_regenerate = Checkbox(["Need Regenerate"])
    edit_text = TextEditor(response)
    feedback_text = TextEditor("")
    confirmation_form = Form(
        [
            "Edit AI Response:",
            edit_text,
            "Need Regenerate?",
            need_regenerate,
            "Feedback if Regenerate:",
            feedback_text,
        ]
    )
    confirmation_form.render()
    if len(need_regenerate.selections) > 0:
        return True, feedback_text.new_text
    return False, edit_text.new_text


def llm_edit_confirm(edit_confirm_fun=edit_confirm):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal edit_confirm_fun
            if not edit_confirm_fun:
                raise MissEditConfirmFieldException()

            while True:
                response = func(*args, **kwargs)
                if not response:
                    return response

                do_regenerate, new_response = edit_confirm_fun(response)
                if do_regenerate:
                    kwargs["__user_request__"] = {"role": "user", "content": new_response}
                else:
                    return new_response if new_response else response

        return wrapper

    return decorator
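llm_edit_confirm wraps an already-decorated chat function and loops until the user accepts the response: when "Need Regenerate" is checked, the feedback is fed back through the __user_request__ kwarg that chat() understands. A hypothetical stacking example (prompt text invented, llm_api import path assumed, and the ChatMark form requires the DevChat IDE environment to render):

from llm_api import chat, llm_edit_confirm


@llm_edit_confirm()  # uses the default ChatMark edit_confirm form
@chat(prompt="Draft a short release note for: {change}")
def draft_release_note(change):
    pass


note = draft_release_note(change="remove the built-in llm_api module")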
@@ -1,220 +0,0 @@
import json
import os
import sys
from functools import wraps

from .memory.base import ChatMemory
from .openai import chat_call_completion_stream

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from chatmark import Checkbox, Form, Radio, TextEditor  # noqa: E402
from ide_services import IDEService  # noqa: E402


class MissToolsFieldException(Exception):
    pass


def openai_tool_schema(name, description, parameters, required):
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {"type": "object", "properties": parameters, "required": required},
        },
    }


def openai_function_schema(name, description, properties, required):
    return {
        "name": name,
        "description": description,
        "parameters": {"type": "object", "properties": properties, "required": required},
    }


def llm_func(name, description, schema_fun=openai_tool_schema):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        if not hasattr(func, "llm_metadata"):
            func.llm_metadata = {"properties": {}, "required": []}

        wrapper.function_name = name
        wrapper.json_schema = lambda: schema_fun(
            name,
            description,
            func.llm_metadata.get("properties", {}),
            func.llm_metadata.get("required", []),
        )
        return wrapper

    return decorator


def llm_param(name, description, dtype, **kwargs):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        if hasattr(func, "llm_metadata"):
            wrapper.llm_metadata = func.llm_metadata
        else:
            wrapper.llm_metadata = {"properties": {}, "required": []}

        wrapper.llm_metadata["properties"][name] = {
            "type": dtype,
            "description": description,
            **kwargs,  # add any additional keyword arguments to the parameter schema
        }
        wrapper.llm_metadata["required"].append(name)

        return wrapper

    return decorator


def call_confirm(response):
    """
    Prompt the user to confirm whether a function call should be allowed.

    This function is responsible for asking the user to confirm whether the AI's
    intention to call a function is permissible. It prints out the response content
    and the details of the function calls that the AI intends to make. The user is
    then presented with a choice to either allow or deny the function call.

    Parameters:
        response (dict): A dictionary containing the 'content' and 'all_calls' keys.
                         'content' is a string representing the AI's response, and
                         'all_calls' is a list of dictionaries, each representing a
                         function call with 'function_name' and 'parameters' keys.

    Returns:
        tuple: A tuple containing a boolean and a string. The boolean indicates whether
               the function call is allowed (True) or not (False). The string contains
               additional input from the user if the function call is not allowed.
    """

    def display_response_and_calls(response):
        if response["content"]:
            print(f"AI Response: {response['content']}", end="\n\n", flush=True)
        print("Function Call Requests:", end="\n\n", flush=True)
        for call_request in response["all_calls"]:
            print(
                f"Function: {call_request['function_name']}, "
                f"Parameters: {call_request['parameters']}",
                end="\n\n",
                flush=True,
            )

    def prompt_user_confirmation():
        function_call_radio = Radio(["Allow function call", "Block function call"])
        user_feedback_input = TextEditor("")
        confirmation_form = Form(
            [
                "Permission to proceed with function call?",
                function_call_radio,
                "Provide feedback if blocked:",
                user_feedback_input,
            ]
        )
        confirmation_form.render()
        user_allowed_call = function_call_radio.selection == 0
        user_feedback = user_feedback_input.new_text
        return user_allowed_call, user_feedback

    display_response_and_calls(response)
    return prompt_user_confirmation()


def chat_tools(
    prompt,
    memory: ChatMemory = None,
    model: str = os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106"),
    tools=None,
    call_confirm_fun=call_confirm,
    **llm_config,
):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal prompt, memory, model, tools, call_confirm_fun, llm_config
            prompt = prompt.format(**kwargs)
            if not tools:
                raise MissToolsFieldException()

            messages = memory.contexts() if memory else []
            if not any(item["content"] == prompt for item in messages):
                messages.append({"role": "user", "content": prompt})

            tool_schemas = [fun.json_schema() for fun in tools] if tools else []

            llm_config["model"] = model
            llm_config["tools"] = tool_schemas

            user_request = {"role": "user", "content": prompt}
            while True:
                response = chat_call_completion_stream(messages, llm_config=llm_config)
                if not response.get("content", None) and not response.get("function_name", None):
                    print(f"call {func.__name__} failed:", response["error"], file=sys.stderr)
                    return response

                response_content = (
                    f"{response.get('content', '') or ''}\n\n"
                    f"call function {response.get('function_name', '')} with arguments:"
                    f"{response.get('parameters', '')}"
                )
                if memory:
                    memory.append(user_request, {"role": "assistant", "content": response_content})
                messages.append({"role": "assistant", "content": response_content})

                if not response.get("function_name", None):
                    return response
                if not response.get("all_calls", None):
                    response["all_calls"] = [
                        {
                            "function_name": response["function_name"],
                            "parameters": response["parameters"],
                        }
                    ]

                do_call = True
                if call_confirm_fun:
                    do_call, fix_prompt = call_confirm_fun(response)

                if do_call:
                    # call each requested function tool
                    functions = {tool.function_name: tool for tool in tools}
                    for call in response["all_calls"]:
                        IDEService().ide_logging(
                            "info",
                            f"try to call function tool: {call['function_name']} "
                            f"with {call['parameters']}",
                        )
                        tool = functions[call["function_name"]]
                        result = tool(**json.loads(call["parameters"]))
                        messages.append(
                            {
                                "role": "function",
                                "content": f"function has called, this is the result: {result}",
                                "name": call["function_name"],
                            }
                        )
                        user_request = {
                            "role": "function",
                            "content": f"function has called, this is the result: {result}",
                            "name": call["function_name"],
                        }
                else:
                    # update the prompt with the user's feedback
                    messages.append({"role": "user", "content": fix_prompt})
                    user_request = {"role": "user", "content": fix_prompt}

        return wrapper

    return decorator
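A sketch of how the tool decorators composed: llm_param attaches parameter metadata, llm_func names the tool and builds the OpenAI tool schema from that metadata, and chat_tools drives the call/confirm loop. Tool names, prompts, and the llm_api import path below are invented for illustration; running it would also require the DevChat ChatMark/IDE environment and an OpenAI key.

from llm_api import chat_tools, llm_func, llm_param


@llm_func("get_weather", "Look up the current weather for a city.")
@llm_param("city", "Name of the city to look up", "string")
def get_weather(city):
    # a real implementation would call a weather API; this is a stub
    return f"Sunny in {city}"


@chat_tools(prompt="{request}", tools=[get_weather])
def weather_assistant(request):
    pass


weather_assistant(request="What's the weather like in Paris?")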