fix: Improve LLM stream handling in devchatComplete function

- Refactor stream processing for better chunk handling
- Add max_tokens parameter to API payload
- Implement simulated stream receive for testing purposes
bobo.yang 2024-08-19 18:29:57 +08:00
parent c0fe7f15ab
commit fa57ebe527
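Conceptually, the stream-handling refactor below stops splitting each network chunk on "data: " in isolation and instead accumulates chunks in a persistent buffer that is scanned for complete SSE events, so a JSON payload split across chunks can still be parsed once the rest of it arrives. The following is a minimal standalone sketch of that buffering idea, with simplified types and a hypothetical parseSSEBuffer helper; it is not the exact code from this commit.

// Sketch of incremental SSE "data:" buffering (illustrative; simplified from the diff below).
interface SketchChunk {
    text: string;
    id: string;
}

function parseSSEBuffer(buffer: string): { events: SketchChunk[]; rest: string } {
    const events: SketchChunk[] = [];
    const dataRegex = /^data: /m; // matches "data: " at the start of a line
    let rest = buffer;

    while (true) {
        const match = dataRegex.exec(rest);
        if (!match) break;

        // Everything between this "data: " and the next one is a single payload.
        const payloadStart = match.index + match[0].length;
        const nextOffset = rest.slice(payloadStart).search(dataRegex);
        const payloadEnd = nextOffset !== -1 ? payloadStart + nextOffset : rest.length;
        const payload = rest.substring(payloadStart, payloadEnd).trim();

        if (payload === "[DONE]") {
            return { events, rest: "" };
        }

        try {
            const data = JSON.parse(payload);
            events.push({ text: data.choices[0].text, id: data.id });
            rest = rest.slice(payloadEnd);
        } catch {
            if (nextOffset === -1) {
                // Likely an incomplete payload: keep it buffered and wait for more data.
                break;
            }
            // Malformed payload followed by another event: skip past it.
            rest = rest.slice(payloadEnd);
        }
    }
    return { events, rest };
}

Each decoded network chunk would be appended to the buffer, the parser called, and only the unparsed remainder carried over to the next iteration, which corresponds to the inline buffer / jsonEnd handling in the diff below.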


@@ -184,6 +184,7 @@ export async function * ollamaDeepseekComplete(prompt: string) : AsyncGenerator<
}
export async function * devchatComplete(prompt: string) : AsyncGenerator<CodeCompletionChunk> {
const devchatEndpoint = DevChatConfig.getInstance().get("providers.devchat.api_base");
const llmApiBase = DevChatConfig.getInstance().get("complete_api_base");
@@ -207,10 +208,12 @@ export async function * devchatComplete(prompt: string) : AsyncGenerator<CodeCom
prompt: prompt,
stream: true,
stop: ["<|endoftext|>", "<|EOT|>", "<file_sep>", "\n\n"],
temperature: 0.2
temperature: 0.2,
max_tokens: 200
};
let idResponse = undefined;
// Internally implemented sleep helper
const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms));
try {
const response = await fetch(completionApiBase, {
@@ -227,8 +230,27 @@ export async function * devchatComplete(prompt: string) : AsyncGenerator<CodeCom
const durationLLM = endTimeLLM[0] + endTimeLLM[1] / 1e9;
logger.channel()?.debug(`LLM api post took ${durationLLM} seconds`);
let hasFirstLine = false;
let hasFirstChunk = false;
let hasFirstLine = false;
let buffer = '';
const dataRegex = /^data: /m; // matches "data: " at the start of a line
// Function that simulates receiving data in smaller parts
async function* simulateStreamReceive(stream: any): AsyncGenerator<Uint8Array> {
for await (const chunk of stream) {
const chunkSize = chunk.length;
const numParts = Math.ceil(Math.random() * 3) + 1; // randomly split the chunk into 1-4 parts
const partSize = Math.ceil(chunkSize / numParts);
for (let i = 0; i < chunkSize; i += partSize) {
const part = chunk.slice(i, Math.min(i + partSize, chunkSize));
logger.channel()?.debug(`Simulated receiving part ${i / partSize + 1}/${numParts} of chunk, size: ${part.length} bytes`);
yield part;
await sleep(Math.random() * 100); // simulate 0-100ms of network latency
}
}
}
for await (const chunk of stream) {
if (!hasFirstChunk) {
hasFirstChunk = true;
@@ -236,33 +258,25 @@ export async function * devchatComplete(prompt: string) : AsyncGenerator<CodeCom
const durationFirstChunk = endTimeFirstChunk[0] + endTimeFirstChunk[1] / 1e9;
logger.channel()?.debug(`LLM first chunk took ${durationFirstChunk} seconds`);
}
const chunkDataText = decoder.decode(chunk).trim();
// split chunkDataText by "data: ", for example:
// "data: 123 data: 456" will split into ["", "123 ", "456"]
const chunkTexts = chunkDataText.split("data: ");
for (const chunkTextSplit of chunkTexts) {
if (chunkTextSplit.trim().length === 0) {
continue;
}
const chunkText = "data: " + chunkTextSplit.trim();
const chunkDataText = decoder.decode(chunk);
buffer += chunkDataText;
// logger.channel()?.info("receive chunk:", chunkText);
// data: {"id": "cmpl-1713846153", "created": 1713846160.292709, "object": "completion.chunk", "model": "ollama/starcoder2:7b", "choices": [{"index": 0, "finish_reason": "stop", "text": "\n});"}]}
// data: {"id": "cmpl-1713846153", "created": 1713846160.366049, "object": "completion.chunk", "model": "ollama/starcoder2:7b", "choices": [{"index": 0, "finish_reason": "stop", "text": ""}], "usage": {"prompt_tokens": 413, "completion_tokens": 16}}
if (!chunkText.startsWith("data:")) {
// log unexpected data
logger.channel()?.warn("Unexpected data: " + chunkText);
return;
}
while (true) {
const match = dataRegex.exec(buffer);
if (!match) break;
const dataStart = match.index;
const nextDataStart = buffer.slice(dataStart + 5).search(dataRegex);
const jsonEnd = nextDataStart !== -1 ? dataStart + 5 + nextDataStart : buffer.length;
const jsonData = buffer.substring(dataStart + 5, jsonEnd).trim();
const jsonData = chunkText.substring(5).trim();
if (jsonData === "[DONE]") {
return;
}
try {
const data = JSON.parse(chunkText.substring(5).trim());
const data = JSON.parse(jsonData);
if (!hasFirstLine && data.choices[0].text.indexOf("\n") !== -1) {
hasFirstLine = true;
const endTimeLine = process.hrtime(startTimeLLM);
@@ -273,9 +287,14 @@ export async function * devchatComplete(prompt: string) : AsyncGenerator<CodeCom
text: data.choices[0].text,
id: data.id
};
} catch (e: any) {
logger.channel()?.info("receve:", chunkText);
logger.channel()?.warn("JSON Parsing Error:", e.message);
buffer = buffer.slice(jsonEnd);
} catch (e) {
// If parsing fails, the data is probably incomplete; continue with the next iteration
if (nextDataStart === -1) {
// If there is no next 'data:', keep the remaining buffer
break;
}
buffer = buffer.slice(dataStart + 5 + nextDataStart);
}
}
}
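For testing, the simulated receive helper would be swapped in for the raw response stream so the parsing loop is exercised against fragmented, delayed chunks. A hypothetical way to wire it up, assuming simulateStreamReceive is reachable where the wrapper is defined; the flag and wrapper names are illustrative assumptions, not part of this commit:

// Hypothetical wiring for the simulated stream (illustrative only).
const SIMULATE_STREAM = false; // enable only when exercising the chunk parser in tests

async function* byteSource(stream: AsyncIterable<Uint8Array>): AsyncGenerator<Uint8Array> {
    if (SIMULATE_STREAM) {
        // Re-chunk and delay the data to mimic a slow, fragmented network.
        yield* simulateStreamReceive(stream);
    } else {
        yield* stream;
    }
}

// The consuming loop stays the same either way:
// for await (const chunk of byteSource(stream)) { ... decode, buffer, and parse ... }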