handle response when stream is False

bobo.yang 2023-11-03 18:10:57 +08:00
parent b9f03afb9d
commit 782ec455d6
11 changed files with 10 additions and 6 deletions

.DS_Store (vendored binary file, not shown)

@@ -3,4 +3,4 @@ import sys
 if __name__ == "__main__":
     from devchat._cli.main import main as _main
-    sys.exit(_main())
\ No newline at end of file
+    sys.exit(_main())


@@ -98,7 +98,8 @@ class Assistant:
             for chunk in self._chat.stream_response(self._prompt):
                 if isinstance(chunk, openai.types.chat.chat_completion_chunk.ChatCompletionChunk):
                     chunk = chunk.dict()
-                if "function_call" in chunk["choices"][0]["delta"] and not chunk["choices"][0]["delta"]["function_call"]:
+                if "function_call" in chunk["choices"][0]["delta"] and \
+                        not chunk["choices"][0]["delta"]["function_call"]:
                     del chunk["choices"][0]["delta"]["function_call"]
                 if not chunk["choices"][0]["delta"]["content"]:
                     chunk["choices"][0]["delta"]["content"] = ""
@@ -117,7 +118,7 @@ class Assistant:
                     first_chunk = False
                     yield self._prompt.formatted_header()
                 yield delta
-            if not self._prompt.responses or len(self._prompt.responses) == 0:
+            if not self._prompt.responses:
                 raise RuntimeError("No responses returned from the chat API")
             self._store.store_prompt(self._prompt)
             yield self._prompt.formatted_footer(0) + '\n'
@@ -126,7 +127,7 @@ class Assistant:
         else:
             response_str = self._chat.complete_response(self._prompt)
             self._prompt.set_response(response_str)
-            if not self._prompt.responses or len(self._prompt.responses) == 0:
+            if not self._prompt.responses:
                 raise RuntimeError("No responses returned from the chat API")
             self._store.store_prompt(self._prompt)
             for index in range(len(self._prompt.responses)):
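
Both hunks above drop the len(...) == 0 clause because an empty list is already falsy in Python, so the shorter check covers it, and a None value too:

responses = []
assert not responses             # an empty list is falsy
assert len(responses) == 0       # so the explicit length check was redundant
responses = None
assert not responses             # None is covered by the shorter check as well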


@@ -1,3 +1,4 @@
+import json
 import os
 from typing import Optional, Union, List, Dict, Iterator
 from pydantic import BaseModel, Field
@@ -76,6 +77,8 @@ class OpenAIChat(Chat):
             messages=prompt.messages,
             **config_params
         )
+        if isinstance(response, openai.types.chat.chat_completion.ChatCompletion):
+            return json.dumps(response.dict())
         return str(response)

     def stream_response(self, prompt: OpenAIPrompt) -> Iterator:
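
The two added lines are the heart of the commit: with stream=False, the OpenAI v1 SDK returns a ChatCompletion object rather than an iterator of chunks, so it is serialized to JSON for the caller. A minimal sketch of that path, assuming OPENAI_API_KEY is set in the environment (the model name is chosen for illustration):

import json
import openai

client = openai.OpenAI()
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
    stream=False,  # non-stream: a single ChatCompletion, not chunks
)
if isinstance(response, openai.types.chat.chat_completion.ChatCompletion):
    print(json.dumps(response.dict()))  # .dict() mirrors the diff above
else:
    print(str(response))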


@@ -92,7 +92,7 @@ class Store:
                 logger.warning("Topic %s not found in graph but added", topic['root'])
             if prompt.parent == topic['root'] or \
                     prompt.parent in nx.ancestors(self._graph, topic['root']):
-                topic['latest_time'] = prompt.timestamp
+                topic['latest_time'] = max(topic.get('latest_time', 0), prompt.timestamp)
                 self._topics_table.update(topic, doc_ids=[topic.doc_id])
                 break
         else:
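
Switching from plain assignment to max() keeps a topic's latest_time from moving backwards when an older prompt is stored, and the .get(..., 0) default covers topics with no latest_time yet. The timestamps below are hypothetical:

topic = {'latest_time': 1699000000}
older = 1698000000                           # a prompt older than the topic's latest
topic['latest_time'] = max(topic.get('latest_time', 0), older)
assert topic['latest_time'] == 1699000000    # unchanged by the older prompt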


@@ -12,7 +12,7 @@ import tiktoken
 try:
     encoding = tiktoken.get_encoding("cl100k_base")
 except Exception:
-    import tiktoken.registry as registry
+    from tiktoken import registry
     from tiktoken.registry import _find_constructors
     from tiktoken.core import Encoding
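
The import rewrite at the end is behavior-preserving: both spellings bind the tiktoken.registry submodule to the local name registry; the from-import form is the one pylint's consider-using-from-import check recommends. A quick demonstration of the equivalence:

import tiktoken.registry as registry_a
from tiktoken import registry as registry_b
assert registry_a is registry_b   # both names refer to the same module object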