From 882c4af9f4b96d1f2433c245580021bc17bf7157 Mon Sep 17 00:00:00 2001 From: xingchen Date: Sun, 27 Apr 2025 16:17:36 +0800 Subject: [PATCH] =?UTF-8?q?SSE=E8=AF=B7=E6=B1=82=E6=96=B9=E5=BC=8F?= =?UTF-8?q?=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/agent/agent.go | 35 +++++++++++++++++++++++++---------- src/agent/client.go | 13 +++++++++---- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/src/agent/agent.go b/src/agent/agent.go index 9c555d8..1e7b609 100644 --- a/src/agent/agent.go +++ b/src/agent/agent.go @@ -7,6 +7,7 @@ import ( "agent/src/workflow" "context" "fmt" + "io" "net/http" "sync" ) @@ -61,28 +62,36 @@ func (a *Agent) Start(port uint64, crtDir *string) { func (a *Agent) serveSSE(w http.ResponseWriter, r *http.Request) { flusher, ok := w.(http.Flusher) if !ok { - http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) + http.Error(w, "Streaming unsupported", http.StatusInternalServerError) return } - log.Info("[agent] new SSE connection", "remote", r.RemoteAddr) + // 创建新 client + client := newClient(w, flusher, a) + a.addClient(client) - // 设置SSE头 + // 设置 SSE 必备的响应头 w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") w.Header().Set("Access-Control-Allow-Origin", "*") - c := newClient(w, flusher, a) + // 启动写消息循环 + client.run() - // check api key - if a.apicheck.check(c, r) { - c.run() - a.clients.Set(c, struct{}{}) - } else { - http.Error(w, "Unauthorized", http.StatusUnauthorized) + // 读取请求体 + msg, err := io.ReadAll(r.Body) + if err != nil { + client.SendText("Invalid request body") return } + defer r.Body.Close() + + // 处理请求 + client.processSSERequest(msg) + + // 阻塞直到 client 关闭 + <-client.GetCtx().Done() } func (a *Agent) serveTestPage(w http.ResponseWriter, r *http.Request) { @@ -113,6 +122,12 @@ func (a *Agent) serveAssets(w http.ResponseWriter, r *http.Request) { http.ServeFile(w, r, "./"+r.URL.Path[1:]) } +// 给Agent添加addClient方法 +func (a *Agent) addClient(c *client) { + a.clients.Set(c, struct{}{}) // 放进去 + log.Info("client added", "connected", a.clients.Size()) +} + func (a *Agent) removeClient(c *client) { a.clients.Delete(c) } diff --git a/src/agent/client.go b/src/agent/client.go index fc483b9..6f02d07 100644 --- a/src/agent/client.go +++ b/src/agent/client.go @@ -158,10 +158,15 @@ func (c *client) processSSERequest(msg []byte) { // 设置希望接收响应的请求ID及对应的通道 func (c *client) WantResponse(requestId uint64, ch chan []byte) { c.lock.Lock() - c.wantResponseId, c.wantResponse = requestId, ch - c.lock.Unlock() - // 确保在设置 wantResponse 之后关闭旧的通道 - close(c.wantResponse) + defer c.lock.Unlock() + + // 关闭之前的通道(如果有的话) + if c.wantResponse != nil { + close(c.wantResponse) + } + + c.wantResponseId = requestId + c.wantResponse = ch } // 根据请求ID发送响应消息