183 lines
4.8 KiB
Go
183 lines
4.8 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io/ioutil"
|
|
"log"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/time/rate"
|
|
|
|
"nhooyr.io/websocket"
|
|
)
|
|
|
|
// chatServer enables broadcasting to a set of subscribers.
|
|
type chatServer struct {
|
|
// subscriberMessageBuffer controls the max number
|
|
// of messages that can be queued for a subscriber
|
|
// before it is kicked.
|
|
//
|
|
// Defaults to 16.
|
|
subscriberMessageBuffer int
|
|
|
|
// publishLimiter controls the rate limit applied to the publish endpoint.
|
|
//
|
|
// Defaults to one publish every 100ms with a burst of 8.
|
|
publishLimiter *rate.Limiter
|
|
|
|
// logf controls where logs are sent.
|
|
// Defaults to log.Printf.
|
|
logf func(f string, v ...interface{})
|
|
|
|
// serveMux routes the various endpoints to the appropriate handler.
|
|
serveMux http.ServeMux
|
|
|
|
subscribersMu sync.Mutex
|
|
subscribers map[*subscriber]struct{}
|
|
}
|
|
|
|
// newChatServer constructs a chatServer with the defaults.
|
|
func newChatServer() *chatServer {
|
|
cs := &chatServer{
|
|
subscriberMessageBuffer: 16,
|
|
logf: log.Printf,
|
|
subscribers: make(map[*subscriber]struct{}),
|
|
publishLimiter: rate.NewLimiter(rate.Every(time.Millisecond*100), 8),
|
|
}
|
|
cs.serveMux.Handle("/", http.FileServer(http.Dir(".")))
|
|
cs.serveMux.HandleFunc("/subscribe", cs.subscribeHandler)
|
|
cs.serveMux.HandleFunc("/publish", cs.publishHandler)
|
|
|
|
return cs
|
|
}
|
|
|
|
// subscriber is one connected client. publish delivers to it by sending on
// msgs; when that buffered channel has no room left, publish calls closeSlow
// instead so the lagging client can be disconnected.
type subscriber struct {
	msgs      chan []byte // queued messages awaiting a write to the client
	closeSlow func()      // invoked when the client cannot keep up
}
|
|
|
|
func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
cs.serveMux.ServeHTTP(w, r)
|
|
}
|
|
|
|
// subscribeHandler accepts the WebSocket connection and then subscribes
|
|
// it to all future messages.
|
|
func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
|
|
c, err := websocket.Accept(w, r, nil)
|
|
if err != nil {
|
|
cs.logf("%v", err)
|
|
return
|
|
}
|
|
defer c.Close(websocket.StatusInternalError, "")
|
|
|
|
err = cs.subscribe(r.Context(), c)
|
|
if errors.Is(err, context.Canceled) {
|
|
return
|
|
}
|
|
if websocket.CloseStatus(err) == websocket.StatusNormalClosure ||
|
|
websocket.CloseStatus(err) == websocket.StatusGoingAway {
|
|
return
|
|
}
|
|
if err != nil {
|
|
cs.logf("%v", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
// publishHandler reads the request body with a limit of 8192 bytes and then publishes
|
|
// the received message.
|
|
func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "POST" {
|
|
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
body := http.MaxBytesReader(w, r.Body, 8192)
|
|
msg, err := ioutil.ReadAll(body)
|
|
if err != nil {
|
|
http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge)
|
|
return
|
|
}
|
|
|
|
cs.publish(msg)
|
|
|
|
w.WriteHeader(http.StatusAccepted)
|
|
}
|
|
|
|
// subscribe subscribes the given WebSocket to all broadcast messages.
|
|
// It creates a subscriber with a buffered msgs chan to give some room to slower
|
|
// connections and then registers the subscriber. It then listens for all messages
|
|
// and writes them to the WebSocket. If the context is cancelled or
|
|
// an error occurs, it returns and deletes the subscription.
|
|
//
|
|
// It uses CloseRead to keep reading from the connection to process control
|
|
// messages and cancel the context if the connection drops.
|
|
func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
|
|
ctx = c.CloseRead(ctx)
|
|
|
|
s := &subscriber{
|
|
msgs: make(chan []byte, cs.subscriberMessageBuffer),
|
|
closeSlow: func() {
|
|
c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages")
|
|
},
|
|
}
|
|
cs.addSubscriber(s)
|
|
defer cs.deleteSubscriber(s)
|
|
|
|
for {
|
|
select {
|
|
case msg := <-s.msgs:
|
|
err := writeTimeout(ctx, time.Second*5, c, msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
}
|
|
}
|
|
|
|
// publish publishes the msg to all subscribers.
|
|
// It never blocks and so messages to slow subscribers
|
|
// are dropped.
|
|
func (cs *chatServer) publish(msg []byte) {
|
|
cs.subscribersMu.Lock()
|
|
defer cs.subscribersMu.Unlock()
|
|
|
|
cs.publishLimiter.Wait(context.Background())
|
|
|
|
for s := range cs.subscribers {
|
|
select {
|
|
case s.msgs <- msg:
|
|
default:
|
|
go s.closeSlow()
|
|
}
|
|
}
|
|
}
|
|
|
|
// addSubscriber registers a subscriber.
|
|
func (cs *chatServer) addSubscriber(s *subscriber) {
|
|
cs.subscribersMu.Lock()
|
|
cs.subscribers[s] = struct{}{}
|
|
cs.subscribersMu.Unlock()
|
|
}
|
|
|
|
// deleteSubscriber deletes the given subscriber.
|
|
func (cs *chatServer) deleteSubscriber(s *subscriber) {
|
|
cs.subscribersMu.Lock()
|
|
delete(cs.subscribers, s)
|
|
cs.subscribersMu.Unlock()
|
|
}
|
|
|
|
func writeTimeout(ctx context.Context, timeout time.Duration, c *websocket.Conn, msg []byte) error {
|
|
ctx, cancel := context.WithTimeout(ctx, timeout)
|
|
defer cancel()
|
|
|
|
return c.Write(ctx, websocket.MessageText, msg)
|
|
}
|