From 67cfd763c3c80725831df84fe86b7f2dffb9ba8e Mon Sep 17 00:00:00 2001 From: xingchen Date: Sun, 27 Apr 2025 15:16:22 +0800 Subject: [PATCH] =?UTF-8?q?websocket=20to=20=20sse=E5=88=9D=E7=89=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/agent/agent.go | 71 ++++++++-------- src/agent/client.go | 197 +++++++++----------------------------------- 2 files changed, 73 insertions(+), 195 deletions(-) diff --git a/src/agent/agent.go b/src/agent/agent.go index 00ec4a7..9c555d8 100644 --- a/src/agent/agent.go +++ b/src/agent/agent.go @@ -14,10 +14,11 @@ import ( const AgentWorkers = 4 type Agent struct { - ctx context.Context - apicheck *keyChecker - clients *common.ThreadSafeMap[*client, struct{}] + ctx context.Context + apicheck *keyChecker + clients *common.ThreadSafeMap[*client, struct{}] } + var once sync.Once func NewAgent(ctx context.Context, ollamaUrls []string) (*Agent, error) { @@ -29,23 +30,20 @@ func NewAgent(ctx context.Context, ollamaUrls []string) (*Agent, error) { return nil, err } return &Agent{ - ctx : ctx, + ctx: ctx, clients: new(common.ThreadSafeMap[*client, struct{}]).Init(nil, false), }, nil } func (a *Agent) Start(port uint64, crtDir *string) { workflow.InitTaskPool(a.ctx) - - go func(){ + + go func() { sm := http.NewServeMux() - // sm.HandleFunc("/simpletest2025", a.serveTestPage) - // sm.HandleFunc("/doc", a.serveDoc) - // sm.HandleFunc("/assets/", a.serveAssets) - sm.HandleFunc("/ws", a.serveWs) + sm.HandleFunc("/sse", a.serveSSE) addr := fmt.Sprintf(":%d", port) - log.Info("[agent] start websocket server", "addr", addr) + log.Info("[agent] start SSE server", "addr", addr) var err error if crtDir == nil { @@ -60,6 +58,33 @@ func (a *Agent) Start(port uint64, crtDir *string) { }() } +func (a *Agent) serveSSE(w http.ResponseWriter, r *http.Request) { + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) + return + } + + log.Info("[agent] new SSE connection", "remote", r.RemoteAddr) + + // 设置SSE头 + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + + c := newClient(w, flusher, a) + + // check api key + if a.apicheck.check(c, r) { + c.run() + a.clients.Set(c, struct{}{}) + } else { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } +} + func (a *Agent) serveTestPage(w http.ResponseWriter, r *http.Request) { log.Info("serveHome", "url", r.URL) if r.Method != http.MethodGet { @@ -88,30 +113,6 @@ func (a *Agent) serveAssets(w http.ResponseWriter, r *http.Request) { http.ServeFile(w, r, "./"+r.URL.Path[1:]) } -func (a *Agent) serveWs(w http.ResponseWriter, r *http.Request) { - // TODO IP连接数限制 - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - log.Error("[agent] serveWs", "err", err) - return - } - log.Info("[agent] new connection", "remote", conn.RemoteAddr()) - c := newClient(conn, a) - - // check api key - if a.apicheck.check(c, r) { - c.run() - a.clients.Set(c, struct{}{}) - } else { - go func() { - c.SendText("HTTP/1.1 401 Unauthorized\r\n\r\n") - conn.Close() - }() - } -} - func (a *Agent) removeClient(c *client) { a.clients.Delete(c) } - - diff --git a/src/agent/client.go b/src/agent/client.go index 866f3d2..ce338eb 100644 --- a/src/agent/client.go +++ b/src/agent/client.go @@ -4,141 +4,70 @@ import ( "context" "encoding/json" "errors" - "fmt" + "net/http" "sync" "time" - "agent/src/message" "agent/src/utils/log" - "agent/src/workflow" - "github.com/gorilla/websocket" "golang.org/x/time/rate" ) -var upgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true // 允许所有来源(仅限开发环境) - }, - ReadBufferSize: 1024, - WriteBufferSize: 1024, -} - - const ( - // Time allowed to write a message to the peer. writeWait = 15 * time.Second - - // Time allowed to read the next pong message from the peer. - pongWait = 3 * 60 * time.Second - - // Send pings to peer with this period. Must be less than pongWait. - pingPeriod = (pongWait * 9) / 10 - - // Maximum message size allowed from peer. - maxMessageSize = 10*1024*1024 ) type client struct { - ctx context.Context - conn *websocket.Conn - outMsg chan []byte + ctx context.Context + writer http.ResponseWriter + flusher http.Flusher + outMsg chan []byte authenticated bool - apikey string - agent *Agent - fnCancel context.CancelFunc - rate *rate.Limiter - - wantResponseId uint64 - wantResponse chan []byte - lock sync.Mutex - once sync.Once + apikey string + agent *Agent + fnCancel context.CancelFunc + rate *rate.Limiter + + lock sync.Mutex + once sync.Once } -func newClient(conn *websocket.Conn, a *Agent) *client { +func newClient(w http.ResponseWriter, f http.Flusher, a *Agent) *client { ctx, fnCancel := context.WithCancel(a.ctx) return &client{ - ctx : ctx, - conn : conn, - outMsg: make(chan []byte, 32), - agent : a, + ctx: ctx, + writer: w, + flusher: f, + outMsg: make(chan []byte, 32), + agent: a, fnCancel: fnCancel, - rate: rate.NewLimiter(5, 1), // 每秒5次限制 + rate: rate.NewLimiter(5, 1), // 每秒5次限制 } } -func (c *client) setApiKey(apikey string) { + +func (c *client) setApiKey(apikey string) { c.apikey = apikey c.authenticated = true } func (c *client) run() { - go c.readPump() go c.writePump() } -func (c *client) readPump() { - defer c.Close() - - c.conn.SetReadLimit(maxMessageSize) - c.conn.SetReadDeadline(time.Now().Add(pongWait)) - c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) - done := c.ctx.Done() - for { - select { - case <-done: - return - default: - } - - _, msg, err := c.conn.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - log.Error(fmt.Sprintf("error: %v", err)) - } - break - } - - if c.rate.Allow() { - workflow.Handle(c, msg) - } else { - c.WriteJson(&message.ReponseError{ - Error: "maximum 5 requests per second", - }, true) - } - - } -} - - func (c *client) writePump() { - ticker := time.NewTicker(pingPeriod) - defer ticker.Stop() defer c.Close() done := c.ctx.Done() for { - select { + select { case message, ok := <-c.outMsg: - c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if !ok { - c.conn.WriteMessage(websocket.CloseMessage, []byte{}) - return - } - - w, err := c.conn.NextWriter(websocket.TextMessage) - if err != nil { - return - } - w.Write(message) - - if err := w.Close(); err != nil { - return - } - case <-ticker.C: - c.conn.SetWriteDeadline(time.Now().Add(writeWait)) - if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { return } + c.writer.Write([]byte("data: ")) + c.writer.Write(message) + c.writer.Write([]byte("\n\n")) + c.flusher.Flush() case <-done: return } @@ -146,74 +75,40 @@ func (c *client) writePump() { } func (c *client) Close() { - once.Do(func(){ + c.once.Do(func() { c.fnCancel() c.agent.removeClient(c) - c.conn.Close() - log.Info("client close", "remote", c.conn.RemoteAddr().String(), "connected", c.agent.clients.Size()) + log.Info("client close", "connected", c.agent.clients.Size()) }) } func (c *client) WriteJson(data interface{}, block bool) error { - b, err := json.MarshalIndent(data, "", "") + b, err := json.Marshal(data) if err != nil { return err } - if block { - select{ - case <-c.ctx.Done(): - return errors.New("connection closed") - case c.outMsg <- b: - return nil - } - } else { - select { - case <-c.ctx.Done(): - return errors.New("connection closed") - case c.outMsg <- b: - return nil - default: - go func(){ - select{ - case <-c.ctx.Done(): - case c.outMsg <- b: - } - }() - } - } - return nil + return c.WriteText(string(b), block) } func (c *client) WriteText(msg string, block bool) error { if block { select { - case <-c.ctx.Done(): - return errors.New("connection closed") - case c.outMsg <- []byte(msg): - return nil + case <-c.ctx.Done(): + return errors.New("connection closed") + case c.outMsg <- []byte(msg): + return nil } } else { select { - case c.outMsg <- []byte(msg): - default: + case c.outMsg <- []byte(msg): + default: } } return nil } -// send string, block util flush func (c *client) SendText(msg string) error { - c.conn.SetWriteDeadline(time.Now().Add(writeWait)) - w, err := c.conn.NextWriter(websocket.TextMessage) - if err != nil { - return err - } - _, err = w.Write([]byte(msg)) - if err != nil { - return err - } - err = w.Close() - return err + return c.WriteText(msg, true) } func (c *client) GetCtx() context.Context { @@ -223,21 +118,3 @@ func (c *client) GetCtx() context.Context { func (c *client) GetChannel() chan<- []byte { return c.outMsg } - - -func (c *client) WantResponse(requestId uint64, ch chan []byte) { - c.lock.Lock() - c.wantResponseId, c.wantResponse = requestId, ch - c.lock.Unlock() - close(c.wantResponse) -} - -func (c *client) NewWantResponse(requestId uint64, msg []byte) error { - c.lock.Lock() - defer c.lock.Unlock() - if requestId != c.wantResponseId || c.wantResponse == nil { - return errors.New("unexpected msg") - } - c.wantResponse <- msg - return nil -} \ No newline at end of file