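"""Decorator helpers for LLM chat completions.

`chat` returns the model reply as plain text (optionally streaming it out),
`chat_json` returns it as parsed JSON; both can record the exchange in a
ChatMemory.
"""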
import json
import os
import sys
from functools import wraps

import openai

from .memory.base import ChatMemory
from .openai import (
    chat_completion_no_stream_return_json,
    chat_completion_stream,
    chat_completion_stream_commit,
    chunks_content,
    retry_timeout,
    stream_out_chunk,
    to_dict_content_and_call,
)
from .pipeline import exception_handle, pipeline, retry
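
# chat_completion_stream_out: a streaming completion composed as a pipeline
# (commit the request, apply the timeout/retry handling, stream each chunk out,
# collect the chunk contents, convert the result into a content/function-call
# dict). The pipeline is retried up to 3 times, and any error is normalized
# into the same dict shape (an openai.APIError is reduced to its error type).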
chat_completion_stream_out = exception_handle(
    retry(
        pipeline(
            chat_completion_stream_commit,
            retry_timeout,
            stream_out_chunk,
            chunks_content,
            to_dict_content_and_call,
        ),
        times=3,
    ),
    lambda err: {
        "content": None,
        "function_name": None,
        "parameters": "",
        "error": err.type if isinstance(err, openai.APIError) else err,
    },
)


def chat(
    prompt,
    memory: ChatMemory = None,
    stream_out: bool = False,
    model: str = os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106"),
    **llm_config,
):
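    """Decorator that sends the formatted prompt to the LLM and returns the reply text.

    The prompt template is filled with the decorated function's keyword
    arguments; the decorated function body itself is never executed. With
    stream_out=True the response is streamed out chunk by chunk (via
    chat_completion_stream_out) while it is generated. Returns the reply
    content, or None on failure. When a memory is given, the prompt/reply
    pair is appended to it.

    Illustrative use (hypothetical prompt and function):

        @chat(prompt="Summarize: {text}", stream_out=True)
        def summarize(text):
            pass

        summary = summarize(text="...")
    """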
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal prompt, memory, model, llm_config
            prompt = prompt.format(**kwargs)
            messages = memory.contexts() if memory else []
            if not any(item["content"] == prompt for item in messages):
                messages.append({"role": "user", "content": prompt})
            if "__user_request__" in kwargs:
                messages.append(kwargs["__user_request__"])
                del kwargs["__user_request__"]
            llm_config["model"] = model
            if not stream_out:
                response = chat_completion_stream(messages, llm_config=llm_config)
            else:
                response = chat_completion_stream_out(messages, llm_config=llm_config)
            if not response.get("content", None):
                print(f"call {func.__name__} failed:", response["error"], file=sys.stderr)
                return None

            if memory:
                memory.append(
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": response["content"]},
                )
            return response["content"]

        return wrapper

    return decorator


def chat_json(
    prompt,
    memory: ChatMemory = None,
    model: str = os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106"),
    **llm_config,
):
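    """Decorator that sends the formatted prompt to the LLM and returns the reply as parsed JSON.

    The prompt template is filled with the decorated function's keyword
    arguments; the decorated function body itself is never executed. The
    non-streaming JSON completion result is returned (it may be empty or None
    on failure). When a memory is given, the prompt and the JSON-serialized
    reply are appended to it.
    """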
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal prompt, memory, model, llm_config
            prompt = prompt.format(**kwargs)
            messages = memory.contexts() if memory else []
            if not any(item["content"] == prompt for item in messages):
                messages.append({"role": "user", "content": prompt})
            llm_config["model"] = model
            response = chat_completion_no_stream_return_json(messages, llm_config=llm_config)
            if not response:
                print(f"call {func.__name__} failed.", file=sys.stderr)

            if memory:
                memory.append(
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": json.dumps(response)},
                )
            return response

        return wrapper

    return decorator