websocket to sse初版

This commit is contained in:
xingchen 2025-04-27 15:16:22 +08:00
parent f0fb9f19d0
commit 67cfd763c3
2 changed files with 73 additions and 195 deletions

View File

@ -14,10 +14,11 @@ import (
const AgentWorkers = 4 const AgentWorkers = 4
type Agent struct { type Agent struct {
ctx context.Context ctx context.Context
apicheck *keyChecker apicheck *keyChecker
clients *common.ThreadSafeMap[*client, struct{}] clients *common.ThreadSafeMap[*client, struct{}]
} }
var once sync.Once var once sync.Once
func NewAgent(ctx context.Context, ollamaUrls []string) (*Agent, error) { func NewAgent(ctx context.Context, ollamaUrls []string) (*Agent, error) {
@ -29,23 +30,20 @@ func NewAgent(ctx context.Context, ollamaUrls []string) (*Agent, error) {
return nil, err return nil, err
} }
return &Agent{ return &Agent{
ctx : ctx, ctx: ctx,
clients: new(common.ThreadSafeMap[*client, struct{}]).Init(nil, false), clients: new(common.ThreadSafeMap[*client, struct{}]).Init(nil, false),
}, nil }, nil
} }
func (a *Agent) Start(port uint64, crtDir *string) { func (a *Agent) Start(port uint64, crtDir *string) {
workflow.InitTaskPool(a.ctx) workflow.InitTaskPool(a.ctx)
go func(){ go func() {
sm := http.NewServeMux() sm := http.NewServeMux()
// sm.HandleFunc("/simpletest2025", a.serveTestPage) sm.HandleFunc("/sse", a.serveSSE)
// sm.HandleFunc("/doc", a.serveDoc)
// sm.HandleFunc("/assets/", a.serveAssets)
sm.HandleFunc("/ws", a.serveWs)
addr := fmt.Sprintf(":%d", port) addr := fmt.Sprintf(":%d", port)
log.Info("[agent] start websocket server", "addr", addr) log.Info("[agent] start SSE server", "addr", addr)
var err error var err error
if crtDir == nil { if crtDir == nil {
@ -60,6 +58,33 @@ func (a *Agent) Start(port uint64, crtDir *string) {
}() }()
} }
func (a *Agent) serveSSE(w http.ResponseWriter, r *http.Request) {
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
return
}
log.Info("[agent] new SSE connection", "remote", r.RemoteAddr)
// 设置SSE头
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")
c := newClient(w, flusher, a)
// check api key
if a.apicheck.check(c, r) {
c.run()
a.clients.Set(c, struct{}{})
} else {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
}
func (a *Agent) serveTestPage(w http.ResponseWriter, r *http.Request) { func (a *Agent) serveTestPage(w http.ResponseWriter, r *http.Request) {
log.Info("serveHome", "url", r.URL) log.Info("serveHome", "url", r.URL)
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
@ -88,30 +113,6 @@ func (a *Agent) serveAssets(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, "./"+r.URL.Path[1:]) http.ServeFile(w, r, "./"+r.URL.Path[1:])
} }
func (a *Agent) serveWs(w http.ResponseWriter, r *http.Request) {
// TODO IP连接数限制
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Error("[agent] serveWs", "err", err)
return
}
log.Info("[agent] new connection", "remote", conn.RemoteAddr())
c := newClient(conn, a)
// check api key
if a.apicheck.check(c, r) {
c.run()
a.clients.Set(c, struct{}{})
} else {
go func() {
c.SendText("HTTP/1.1 401 Unauthorized\r\n\r\n")
conn.Close()
}()
}
}
func (a *Agent) removeClient(c *client) { func (a *Agent) removeClient(c *client) {
a.clients.Delete(c) a.clients.Delete(c)
} }

View File

@ -4,141 +4,70 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "net/http"
"sync" "sync"
"time" "time"
"agent/src/message"
"agent/src/utils/log" "agent/src/utils/log"
"agent/src/workflow"
"github.com/gorilla/websocket"
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true // 允许所有来源(仅限开发环境)
},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
const ( const (
// Time allowed to write a message to the peer.
writeWait = 15 * time.Second writeWait = 15 * time.Second
// Time allowed to read the next pong message from the peer.
pongWait = 3 * 60 * time.Second
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10
// Maximum message size allowed from peer.
maxMessageSize = 10*1024*1024
) )
type client struct { type client struct {
ctx context.Context ctx context.Context
conn *websocket.Conn writer http.ResponseWriter
outMsg chan []byte flusher http.Flusher
outMsg chan []byte
authenticated bool authenticated bool
apikey string apikey string
agent *Agent agent *Agent
fnCancel context.CancelFunc fnCancel context.CancelFunc
rate *rate.Limiter rate *rate.Limiter
wantResponseId uint64 lock sync.Mutex
wantResponse chan []byte once sync.Once
lock sync.Mutex
once sync.Once
} }
func newClient(conn *websocket.Conn, a *Agent) *client { func newClient(w http.ResponseWriter, f http.Flusher, a *Agent) *client {
ctx, fnCancel := context.WithCancel(a.ctx) ctx, fnCancel := context.WithCancel(a.ctx)
return &client{ return &client{
ctx : ctx, ctx: ctx,
conn : conn, writer: w,
outMsg: make(chan []byte, 32), flusher: f,
agent : a, outMsg: make(chan []byte, 32),
agent: a,
fnCancel: fnCancel, fnCancel: fnCancel,
rate: rate.NewLimiter(5, 1), // 每秒5次限制 rate: rate.NewLimiter(5, 1), // 每秒5次限制
} }
} }
func (c *client) setApiKey(apikey string) {
func (c *client) setApiKey(apikey string) {
c.apikey = apikey c.apikey = apikey
c.authenticated = true c.authenticated = true
} }
func (c *client) run() { func (c *client) run() {
go c.readPump()
go c.writePump() go c.writePump()
} }
func (c *client) readPump() {
defer c.Close()
c.conn.SetReadLimit(maxMessageSize)
c.conn.SetReadDeadline(time.Now().Add(pongWait))
c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
done := c.ctx.Done()
for {
select {
case <-done:
return
default:
}
_, msg, err := c.conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Error(fmt.Sprintf("error: %v", err))
}
break
}
if c.rate.Allow() {
workflow.Handle(c, msg)
} else {
c.WriteJson(&message.ReponseError{
Error: "maximum 5 requests per second",
}, true)
}
}
}
func (c *client) writePump() { func (c *client) writePump() {
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
defer c.Close() defer c.Close()
done := c.ctx.Done() done := c.ctx.Done()
for { for {
select { select {
case message, ok := <-c.outMsg: case message, ok := <-c.outMsg:
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if !ok { if !ok {
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
w, err := c.conn.NextWriter(websocket.TextMessage)
if err != nil {
return
}
w.Write(message)
if err := w.Close(); err != nil {
return
}
case <-ticker.C:
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return return
} }
c.writer.Write([]byte("data: "))
c.writer.Write(message)
c.writer.Write([]byte("\n\n"))
c.flusher.Flush()
case <-done: case <-done:
return return
} }
@ -146,74 +75,40 @@ func (c *client) writePump() {
} }
func (c *client) Close() { func (c *client) Close() {
once.Do(func(){ c.once.Do(func() {
c.fnCancel() c.fnCancel()
c.agent.removeClient(c) c.agent.removeClient(c)
c.conn.Close() log.Info("client close", "connected", c.agent.clients.Size())
log.Info("client close", "remote", c.conn.RemoteAddr().String(), "connected", c.agent.clients.Size())
}) })
} }
func (c *client) WriteJson(data interface{}, block bool) error { func (c *client) WriteJson(data interface{}, block bool) error {
b, err := json.MarshalIndent(data, "", "") b, err := json.Marshal(data)
if err != nil { if err != nil {
return err return err
} }
if block { return c.WriteText(string(b), block)
select{
case <-c.ctx.Done():
return errors.New("connection closed")
case c.outMsg <- b:
return nil
}
} else {
select {
case <-c.ctx.Done():
return errors.New("connection closed")
case c.outMsg <- b:
return nil
default:
go func(){
select{
case <-c.ctx.Done():
case c.outMsg <- b:
}
}()
}
}
return nil
} }
func (c *client) WriteText(msg string, block bool) error { func (c *client) WriteText(msg string, block bool) error {
if block { if block {
select { select {
case <-c.ctx.Done(): case <-c.ctx.Done():
return errors.New("connection closed") return errors.New("connection closed")
case c.outMsg <- []byte(msg): case c.outMsg <- []byte(msg):
return nil return nil
} }
} else { } else {
select { select {
case c.outMsg <- []byte(msg): case c.outMsg <- []byte(msg):
default: default:
} }
} }
return nil return nil
} }
// send string, block util flush
func (c *client) SendText(msg string) error { func (c *client) SendText(msg string) error {
c.conn.SetWriteDeadline(time.Now().Add(writeWait)) return c.WriteText(msg, true)
w, err := c.conn.NextWriter(websocket.TextMessage)
if err != nil {
return err
}
_, err = w.Write([]byte(msg))
if err != nil {
return err
}
err = w.Close()
return err
} }
func (c *client) GetCtx() context.Context { func (c *client) GetCtx() context.Context {
@ -223,21 +118,3 @@ func (c *client) GetCtx() context.Context {
func (c *client) GetChannel() chan<- []byte { func (c *client) GetChannel() chan<- []byte {
return c.outMsg return c.outMsg
} }
func (c *client) WantResponse(requestId uint64, ch chan []byte) {
c.lock.Lock()
c.wantResponseId, c.wantResponse = requestId, ch
c.lock.Unlock()
close(c.wantResponse)
}
func (c *client) NewWantResponse(requestId uint64, msg []byte) error {
c.lock.Lock()
defer c.lock.Unlock()
if requestId != c.wantResponseId || c.wantResponse == nil {
return errors.New("unexpected msg")
}
c.wantResponse <- msg
return nil
}