From b5ca1316190196b731fbd38d8d2d89bac4343104 Mon Sep 17 00:00:00 2001 From: ken Date: Fri, 4 Apr 2025 20:52:11 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8A=A0=E4=BA=9Blog?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/workflow/chat.go | 10 ++++++---- src/workflow/exec.go | 3 +-- src/workflow/taskPool.go | 2 ++ 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/workflow/chat.go b/src/workflow/chat.go index 3b91fac..27ccfef 100644 --- a/src/workflow/chat.go +++ b/src/workflow/chat.go @@ -24,7 +24,7 @@ func (task *Task) chatWithStream(msg string, withThink bool) error { option := llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error { select { case <-task.ctx.Done(): - return nil + return errors.New("task cancelled") default: } @@ -46,11 +46,14 @@ func (task *Task) chatWithStream(msg string, withThink bool) error { buf.Write(chunk) // 避免client网络太慢 if len(ch) == 0 { - task.client.WriteJson(&message.ResponseExecSuccess{ + err := task.client.WriteJson(&message.ResponseExecSuccess{ RequestId : task.request_id, Msg : buf.String(), Stream_SeqId : &seqId, }, true) + if err != nil { + return err + } buf.Reset() seqId++ } @@ -62,13 +65,12 @@ func (task *Task) chatWithStream(msg string, withThink bool) error { return err } - task.client.WriteJson(&message.ResponseExecSuccess{ + return task.client.WriteJson(&message.ResponseExecSuccess{ RequestId : task.request_id, Msg : buf.String(), Stream_SeqId : &seqId, Stream_Finish: true, }, true) - return nil } func (task *Task) chat(msg string, stream bool) error { diff --git a/src/workflow/exec.go b/src/workflow/exec.go index ddde7a6..b984680 100644 --- a/src/workflow/exec.go +++ b/src/workflow/exec.go @@ -37,11 +37,10 @@ func (task *Task) docstring() error { // TODO extract doc part - task.client.WriteJson(&message.ResponseExecSuccess{ + return task.client.WriteJson(&message.ResponseExecSuccess{ RequestId: task.request_id, Msg : answer, }, false) - return nil } func (task *Task) fix() error { diff --git a/src/workflow/taskPool.go b/src/workflow/taskPool.go index a0caa35..72e6f3e 100644 --- a/src/workflow/taskPool.go +++ b/src/workflow/taskPool.go @@ -3,6 +3,7 @@ package workflow import ( "agent/src/common" "agent/src/llm" + "agent/src/utils/log" "context" "errors" "sync" @@ -81,6 +82,7 @@ func (t *TaskPool) loop() { case <-tick.C: // 清除pending太久的task t.queue.RemoveTimeout(PendingTimeOut) + log.Info("[TaskPool] stat", "queueLen", t.queue.Len()) default: }