Refactor pipeline imports and function calls

This commit is contained in:
bobo 2024-01-24 21:26:11 +08:00
parent 47887ca6f9
commit 6f1b5ca2ef
2 changed files with 32 additions and 53 deletions

View File

@ -6,14 +6,7 @@ from typing import Dict, List
import openai
from .pipeline import (
RetryException,
pipeline,
parallel,
retry,
exception_err,
exception_handle
)
from .pipeline import RetryException, exception_err, exception_handle, parallel, pipeline, retry
def _try_remove_markdown_block_flag(content):
@ -81,6 +74,7 @@ def retry_timeout(chunks):
def chunk_list(chunks):
return [chunk for chunk in chunks]
def chunks_content(chunks):
content = None
for chunk in chunks:
@ -92,10 +86,11 @@ def chunks_content(chunks):
content += delta["content"]
return content
def chunks_call(chunks):
function_name = None
parameters = ""
for chunk in chunks:
chunk = chunk.dict()
delta = chunk["choices"][0]["delta"]
@ -107,6 +102,7 @@ def chunks_call(chunks):
parameters += tool_call["arguments"]
return {"function_name": function_name, "parameters": parameters}
def content_to_json(content):
try:
# json will format as ```json ... ``` in 1106 model
@ -120,48 +116,24 @@ def content_to_json(content):
def to_dict_content_and_call(content, function_call):
return {
"content": content,
**function_call
}
return {"content": content, **function_call}
chat_completion_content = retry(
pipeline(
chat_completion_stream_commit,
retry_timeout,
chunks_content
),
times=3
pipeline(chat_completion_stream_commit, retry_timeout, chunks_content), times=3
)
chat_completion_stream_content = retry(
pipeline(
chat_completion_stream_commit,
retry_timeout,
stream_out_chunk,
chunks_content
),
times=3
pipeline(chat_completion_stream_commit, retry_timeout, stream_out_chunk, chunks_content),
times=3,
)
chat_completion_call = retry(
pipeline(
chat_completion_stream_commit,
retry_timeout,
chunks_call
),
times=3
pipeline(chat_completion_stream_commit, retry_timeout, chunks_call), times=3
)
chat_completion_no_stream_return_json = retry(
pipeline(
chat_completion_stream_commit,
retry_timeout,
chunks_content,
content_to_json
),
times=3
pipeline(chat_completion_stream_commit, retry_timeout, chunks_content, content_to_json), times=3
)
chat_completion_stream = exception_handle(
@ -169,13 +141,10 @@ chat_completion_stream = exception_handle(
pipeline(
chat_completion_stream_commit,
retry_timeout,
parallel(
chunks_content,
chunks_call
),
to_dict_content_and_call
parallel(chunks_content, chunks_call),
to_dict_content_and_call,
),
times=3
times=3,
),
lambda err: {"content": None, "function_name": None, "parameters": "", "error": err}
lambda err: {"content": None, "function_name": None, "parameters": "", "error": err},
)

View File

@ -5,6 +5,7 @@ class RetryException(Exception):
def __init__(self, err):
self.error = err
def retry(func, times):
def wrapper(*args, **kwargs):
for index in range(times):
@ -17,8 +18,10 @@ def retry(func, times):
except Exception as err:
raise err
raise err.error
return wrapper
def exception_err(func):
def wrapper(*args, **kwargs):
try:
@ -26,16 +29,20 @@ def exception_err(func):
return True, result
except Exception as err:
return False, err
return wrapper
def exception_handle(func, handler):
def wrapper(*args, **kwargs):
try:
result = func(*args, **kwargs)
return result
except Exception as err:
return handler(err)
return wrapper
def wrapper(*args, **kwargs):
try:
result = func(*args, **kwargs)
return result
except Exception as err:
return handler(err)
return wrapper
def pipeline(*funcs):
def wrapper(*args, **kwargs):
@ -48,12 +55,15 @@ def pipeline(*funcs):
else:
args = func(*args, **kwargs)
return args
return wrapper
def parallel(*funcs):
def wrapper(args):
results = {"__type__": "parallel", "value": []}
for func in funcs:
results["value"].append(func(args))
return results
return wrapper