Refactor pipeline imports and function calls

This commit is contained in:
bobo 2024-01-24 21:26:11 +08:00
parent 47887ca6f9
commit 6f1b5ca2ef
2 changed files with 32 additions and 53 deletions

View File

@ -6,14 +6,7 @@ from typing import Dict, List
import openai import openai
from .pipeline import ( from .pipeline import RetryException, exception_err, exception_handle, parallel, pipeline, retry
RetryException,
pipeline,
parallel,
retry,
exception_err,
exception_handle
)
def _try_remove_markdown_block_flag(content): def _try_remove_markdown_block_flag(content):
@ -81,6 +74,7 @@ def retry_timeout(chunks):
def chunk_list(chunks): def chunk_list(chunks):
return [chunk for chunk in chunks] return [chunk for chunk in chunks]
def chunks_content(chunks): def chunks_content(chunks):
content = None content = None
for chunk in chunks: for chunk in chunks:
@ -92,10 +86,11 @@ def chunks_content(chunks):
content += delta["content"] content += delta["content"]
return content return content
def chunks_call(chunks): def chunks_call(chunks):
function_name = None function_name = None
parameters = "" parameters = ""
for chunk in chunks: for chunk in chunks:
chunk = chunk.dict() chunk = chunk.dict()
delta = chunk["choices"][0]["delta"] delta = chunk["choices"][0]["delta"]
@ -107,6 +102,7 @@ def chunks_call(chunks):
parameters += tool_call["arguments"] parameters += tool_call["arguments"]
return {"function_name": function_name, "parameters": parameters} return {"function_name": function_name, "parameters": parameters}
def content_to_json(content): def content_to_json(content):
try: try:
# json will format as ```json ... ``` in 1106 model # json will format as ```json ... ``` in 1106 model
@ -120,48 +116,24 @@ def content_to_json(content):
def to_dict_content_and_call(content, function_call): def to_dict_content_and_call(content, function_call):
return { return {"content": content, **function_call}
"content": content,
**function_call
}
chat_completion_content = retry( chat_completion_content = retry(
pipeline( pipeline(chat_completion_stream_commit, retry_timeout, chunks_content), times=3
chat_completion_stream_commit,
retry_timeout,
chunks_content
),
times=3
) )
chat_completion_stream_content = retry( chat_completion_stream_content = retry(
pipeline( pipeline(chat_completion_stream_commit, retry_timeout, stream_out_chunk, chunks_content),
chat_completion_stream_commit, times=3,
retry_timeout,
stream_out_chunk,
chunks_content
),
times=3
) )
chat_completion_call = retry( chat_completion_call = retry(
pipeline( pipeline(chat_completion_stream_commit, retry_timeout, chunks_call), times=3
chat_completion_stream_commit,
retry_timeout,
chunks_call
),
times=3
) )
chat_completion_no_stream_return_json = retry( chat_completion_no_stream_return_json = retry(
pipeline( pipeline(chat_completion_stream_commit, retry_timeout, chunks_content, content_to_json), times=3
chat_completion_stream_commit,
retry_timeout,
chunks_content,
content_to_json
),
times=3
) )
chat_completion_stream = exception_handle( chat_completion_stream = exception_handle(
@ -169,13 +141,10 @@ chat_completion_stream = exception_handle(
pipeline( pipeline(
chat_completion_stream_commit, chat_completion_stream_commit,
retry_timeout, retry_timeout,
parallel( parallel(chunks_content, chunks_call),
chunks_content, to_dict_content_and_call,
chunks_call
),
to_dict_content_and_call
), ),
times=3 times=3,
), ),
lambda err: {"content": None, "function_name": None, "parameters": "", "error": err} lambda err: {"content": None, "function_name": None, "parameters": "", "error": err},
) )

View File

@ -5,6 +5,7 @@ class RetryException(Exception):
def __init__(self, err): def __init__(self, err):
self.error = err self.error = err
def retry(func, times): def retry(func, times):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
for index in range(times): for index in range(times):
@ -17,8 +18,10 @@ def retry(func, times):
except Exception as err: except Exception as err:
raise err raise err
raise err.error raise err.error
return wrapper return wrapper
def exception_err(func): def exception_err(func):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
try: try:
@ -26,16 +29,20 @@ def exception_err(func):
return True, result return True, result
except Exception as err: except Exception as err:
return False, err return False, err
return wrapper return wrapper
def exception_handle(func, handler): def exception_handle(func, handler):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
try: try:
result = func(*args, **kwargs) result = func(*args, **kwargs)
return result return result
except Exception as err: except Exception as err:
return handler(err) return handler(err)
return wrapper
return wrapper
def pipeline(*funcs): def pipeline(*funcs):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
@ -48,12 +55,15 @@ def pipeline(*funcs):
else: else:
args = func(*args, **kwargs) args = func(*args, **kwargs)
return args return args
return wrapper return wrapper
def parallel(*funcs): def parallel(*funcs):
def wrapper(args): def wrapper(args):
results = {"__type__": "parallel", "value": []} results = {"__type__": "parallel", "value": []}
for func in funcs: for func in funcs:
results["value"].append(func(args)) results["value"].append(func(args))
return results return results
return wrapper return wrapper