websocket to sse初版
This commit is contained in:
parent
f0fb9f19d0
commit
67cfd763c3
@ -18,6 +18,7 @@ type Agent struct {
|
|||||||
apicheck *keyChecker
|
apicheck *keyChecker
|
||||||
clients *common.ThreadSafeMap[*client, struct{}]
|
clients *common.ThreadSafeMap[*client, struct{}]
|
||||||
}
|
}
|
||||||
|
|
||||||
var once sync.Once
|
var once sync.Once
|
||||||
|
|
||||||
func NewAgent(ctx context.Context, ollamaUrls []string) (*Agent, error) {
|
func NewAgent(ctx context.Context, ollamaUrls []string) (*Agent, error) {
|
||||||
@ -29,7 +30,7 @@ func NewAgent(ctx context.Context, ollamaUrls []string) (*Agent, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Agent{
|
return &Agent{
|
||||||
ctx : ctx,
|
ctx: ctx,
|
||||||
clients: new(common.ThreadSafeMap[*client, struct{}]).Init(nil, false),
|
clients: new(common.ThreadSafeMap[*client, struct{}]).Init(nil, false),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@ -37,15 +38,12 @@ func NewAgent(ctx context.Context, ollamaUrls []string) (*Agent, error) {
|
|||||||
func (a *Agent) Start(port uint64, crtDir *string) {
|
func (a *Agent) Start(port uint64, crtDir *string) {
|
||||||
workflow.InitTaskPool(a.ctx)
|
workflow.InitTaskPool(a.ctx)
|
||||||
|
|
||||||
go func(){
|
go func() {
|
||||||
sm := http.NewServeMux()
|
sm := http.NewServeMux()
|
||||||
// sm.HandleFunc("/simpletest2025", a.serveTestPage)
|
sm.HandleFunc("/sse", a.serveSSE)
|
||||||
// sm.HandleFunc("/doc", a.serveDoc)
|
|
||||||
// sm.HandleFunc("/assets/", a.serveAssets)
|
|
||||||
sm.HandleFunc("/ws", a.serveWs)
|
|
||||||
|
|
||||||
addr := fmt.Sprintf(":%d", port)
|
addr := fmt.Sprintf(":%d", port)
|
||||||
log.Info("[agent] start websocket server", "addr", addr)
|
log.Info("[agent] start SSE server", "addr", addr)
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if crtDir == nil {
|
if crtDir == nil {
|
||||||
@ -60,6 +58,33 @@ func (a *Agent) Start(port uint64, crtDir *string) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Agent) serveSSE(w http.ResponseWriter, r *http.Request) {
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("[agent] new SSE connection", "remote", r.RemoteAddr)
|
||||||
|
|
||||||
|
// 设置SSE头
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
|
w.Header().Set("Connection", "keep-alive")
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
|
|
||||||
|
c := newClient(w, flusher, a)
|
||||||
|
|
||||||
|
// check api key
|
||||||
|
if a.apicheck.check(c, r) {
|
||||||
|
c.run()
|
||||||
|
a.clients.Set(c, struct{}{})
|
||||||
|
} else {
|
||||||
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Agent) serveTestPage(w http.ResponseWriter, r *http.Request) {
|
func (a *Agent) serveTestPage(w http.ResponseWriter, r *http.Request) {
|
||||||
log.Info("serveHome", "url", r.URL)
|
log.Info("serveHome", "url", r.URL)
|
||||||
if r.Method != http.MethodGet {
|
if r.Method != http.MethodGet {
|
||||||
@ -88,30 +113,6 @@ func (a *Agent) serveAssets(w http.ResponseWriter, r *http.Request) {
|
|||||||
http.ServeFile(w, r, "./"+r.URL.Path[1:])
|
http.ServeFile(w, r, "./"+r.URL.Path[1:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Agent) serveWs(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// TODO IP连接数限制
|
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("[agent] serveWs", "err", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Info("[agent] new connection", "remote", conn.RemoteAddr())
|
|
||||||
c := newClient(conn, a)
|
|
||||||
|
|
||||||
// check api key
|
|
||||||
if a.apicheck.check(c, r) {
|
|
||||||
c.run()
|
|
||||||
a.clients.Set(c, struct{}{})
|
|
||||||
} else {
|
|
||||||
go func() {
|
|
||||||
c.SendText("HTTP/1.1 401 Unauthorized\r\n\r\n")
|
|
||||||
conn.Close()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Agent) removeClient(c *client) {
|
func (a *Agent) removeClient(c *client) {
|
||||||
a.clients.Delete(c)
|
a.clients.Delete(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,44 +4,23 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"agent/src/message"
|
|
||||||
"agent/src/utils/log"
|
"agent/src/utils/log"
|
||||||
"agent/src/workflow"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
)
|
)
|
||||||
|
|
||||||
var upgrader = websocket.Upgrader{
|
|
||||||
CheckOrigin: func(r *http.Request) bool {
|
|
||||||
return true // 允许所有来源(仅限开发环境)
|
|
||||||
},
|
|
||||||
ReadBufferSize: 1024,
|
|
||||||
WriteBufferSize: 1024,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// Time allowed to write a message to the peer.
|
|
||||||
writeWait = 15 * time.Second
|
writeWait = 15 * time.Second
|
||||||
|
|
||||||
// Time allowed to read the next pong message from the peer.
|
|
||||||
pongWait = 3 * 60 * time.Second
|
|
||||||
|
|
||||||
// Send pings to peer with this period. Must be less than pongWait.
|
|
||||||
pingPeriod = (pongWait * 9) / 10
|
|
||||||
|
|
||||||
// Maximum message size allowed from peer.
|
|
||||||
maxMessageSize = 10*1024*1024
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type client struct {
|
type client struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
conn *websocket.Conn
|
writer http.ResponseWriter
|
||||||
|
flusher http.Flusher
|
||||||
outMsg chan []byte
|
outMsg chan []byte
|
||||||
authenticated bool
|
authenticated bool
|
||||||
apikey string
|
apikey string
|
||||||
@ -49,96 +28,46 @@ type client struct {
|
|||||||
fnCancel context.CancelFunc
|
fnCancel context.CancelFunc
|
||||||
rate *rate.Limiter
|
rate *rate.Limiter
|
||||||
|
|
||||||
wantResponseId uint64
|
|
||||||
wantResponse chan []byte
|
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
once sync.Once
|
once sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClient(conn *websocket.Conn, a *Agent) *client {
|
func newClient(w http.ResponseWriter, f http.Flusher, a *Agent) *client {
|
||||||
ctx, fnCancel := context.WithCancel(a.ctx)
|
ctx, fnCancel := context.WithCancel(a.ctx)
|
||||||
return &client{
|
return &client{
|
||||||
ctx : ctx,
|
ctx: ctx,
|
||||||
conn : conn,
|
writer: w,
|
||||||
|
flusher: f,
|
||||||
outMsg: make(chan []byte, 32),
|
outMsg: make(chan []byte, 32),
|
||||||
agent : a,
|
agent: a,
|
||||||
fnCancel: fnCancel,
|
fnCancel: fnCancel,
|
||||||
rate: rate.NewLimiter(5, 1), // 每秒5次限制
|
rate: rate.NewLimiter(5, 1), // 每秒5次限制
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) setApiKey(apikey string) {
|
func (c *client) setApiKey(apikey string) {
|
||||||
c.apikey = apikey
|
c.apikey = apikey
|
||||||
c.authenticated = true
|
c.authenticated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) run() {
|
func (c *client) run() {
|
||||||
go c.readPump()
|
|
||||||
go c.writePump()
|
go c.writePump()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) readPump() {
|
|
||||||
defer c.Close()
|
|
||||||
|
|
||||||
c.conn.SetReadLimit(maxMessageSize)
|
|
||||||
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
|
||||||
c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
|
|
||||||
done := c.ctx.Done()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
_, msg, err := c.conn.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
|
||||||
log.Error(fmt.Sprintf("error: %v", err))
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.rate.Allow() {
|
|
||||||
workflow.Handle(c, msg)
|
|
||||||
} else {
|
|
||||||
c.WriteJson(&message.ReponseError{
|
|
||||||
Error: "maximum 5 requests per second",
|
|
||||||
}, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
func (c *client) writePump() {
|
func (c *client) writePump() {
|
||||||
ticker := time.NewTicker(pingPeriod)
|
|
||||||
defer ticker.Stop()
|
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
done := c.ctx.Done()
|
done := c.ctx.Done()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case message, ok := <-c.outMsg:
|
case message, ok := <-c.outMsg:
|
||||||
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
|
||||||
if !ok {
|
if !ok {
|
||||||
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w, err := c.conn.NextWriter(websocket.TextMessage)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.Write(message)
|
|
||||||
|
|
||||||
if err := w.Close(); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case <-ticker.C:
|
|
||||||
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
|
||||||
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.writer.Write([]byte("data: "))
|
||||||
|
c.writer.Write(message)
|
||||||
|
c.writer.Write([]byte("\n\n"))
|
||||||
|
c.flusher.Flush()
|
||||||
case <-done:
|
case <-done:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -146,42 +75,19 @@ func (c *client) writePump() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) Close() {
|
func (c *client) Close() {
|
||||||
once.Do(func(){
|
c.once.Do(func() {
|
||||||
c.fnCancel()
|
c.fnCancel()
|
||||||
c.agent.removeClient(c)
|
c.agent.removeClient(c)
|
||||||
c.conn.Close()
|
log.Info("client close", "connected", c.agent.clients.Size())
|
||||||
log.Info("client close", "remote", c.conn.RemoteAddr().String(), "connected", c.agent.clients.Size())
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) WriteJson(data interface{}, block bool) error {
|
func (c *client) WriteJson(data interface{}, block bool) error {
|
||||||
b, err := json.MarshalIndent(data, "", "")
|
b, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if block {
|
return c.WriteText(string(b), block)
|
||||||
select{
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
return errors.New("connection closed")
|
|
||||||
case c.outMsg <- b:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
select {
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
return errors.New("connection closed")
|
|
||||||
case c.outMsg <- b:
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
go func(){
|
|
||||||
select{
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
case c.outMsg <- b:
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) WriteText(msg string, block bool) error {
|
func (c *client) WriteText(msg string, block bool) error {
|
||||||
@ -201,19 +107,8 @@ func (c *client) WriteText(msg string, block bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// send string, block util flush
|
|
||||||
func (c *client) SendText(msg string) error {
|
func (c *client) SendText(msg string) error {
|
||||||
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
return c.WriteText(msg, true)
|
||||||
w, err := c.conn.NextWriter(websocket.TextMessage)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = w.Write([]byte(msg))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = w.Close()
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) GetCtx() context.Context {
|
func (c *client) GetCtx() context.Context {
|
||||||
@ -223,21 +118,3 @@ func (c *client) GetCtx() context.Context {
|
|||||||
func (c *client) GetChannel() chan<- []byte {
|
func (c *client) GetChannel() chan<- []byte {
|
||||||
return c.outMsg
|
return c.outMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func (c *client) WantResponse(requestId uint64, ch chan []byte) {
|
|
||||||
c.lock.Lock()
|
|
||||||
c.wantResponseId, c.wantResponse = requestId, ch
|
|
||||||
c.lock.Unlock()
|
|
||||||
close(c.wantResponse)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) NewWantResponse(requestId uint64, msg []byte) error {
|
|
||||||
c.lock.Lock()
|
|
||||||
defer c.lock.Unlock()
|
|
||||||
if requestId != c.wantResponseId || c.wantResponse == nil {
|
|
||||||
return errors.New("unexpected msg")
|
|
||||||
}
|
|
||||||
c.wantResponse <- msg
|
|
||||||
return nil
|
|
||||||
}
|
|
Loading…
x
Reference in New Issue
Block a user