package main import ( "context" "errors" "io/ioutil" "log" "net/http" "sync" "time" "golang.org/x/time/rate" "nhooyr.io/websocket" ) // chatServer enables broadcasting to a set of subscribers. type chatServer struct { // subscriberMessageBuffer controls the max number // of messages that can be queued for a subscriber // before it is kicked. // // Defaults to 16. subscriberMessageBuffer int // publishLimiter controls the rate limit applied to the publish endpoint. // // Defaults to one publish every 100ms with a burst of 8. publishLimiter *rate.Limiter // logf controls where logs are sent. // Defaults to log.Printf. logf func(f string, v ...interface{}) // serveMux routes the various endpoints to the appropriate handler. serveMux http.ServeMux subscribersMu sync.Mutex subscribers map[*subscriber]struct{} } // newChatServer constructs a chatServer with the defaults. func newChatServer() *chatServer { cs := &chatServer{ subscriberMessageBuffer: 16, logf: log.Printf, subscribers: make(map[*subscriber]struct{}), publishLimiter: rate.NewLimiter(rate.Every(time.Millisecond*100), 8), } cs.serveMux.Handle("/", http.FileServer(http.Dir("."))) cs.serveMux.HandleFunc("/subscribe", cs.subscribeHandler) cs.serveMux.HandleFunc("/publish", cs.publishHandler) return cs } // subscriber represents a subscriber. // Messages are sent on the msgs channel and if the client // cannot keep up with the messages, closeSlow is called. type subscriber struct { msgs chan []byte closeSlow func() } func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { cs.serveMux.ServeHTTP(w, r) } // subscribeHandler accepts the WebSocket connection and then subscribes // it to all future messages. func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, nil) if err != nil { cs.logf("%v", err) return } defer c.Close(websocket.StatusInternalError, "") err = cs.subscribe(r.Context(), c) if errors.Is(err, context.Canceled) { return } if websocket.CloseStatus(err) == websocket.StatusNormalClosure || websocket.CloseStatus(err) == websocket.StatusGoingAway { return } if err != nil { cs.logf("%v", err) return } } // publishHandler reads the request body with a limit of 8192 bytes and then publishes // the received message. func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) return } body := http.MaxBytesReader(w, r.Body, 8192) msg, err := ioutil.ReadAll(body) if err != nil { http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) return } cs.publish(msg) w.WriteHeader(http.StatusAccepted) } // subscribe subscribes the given WebSocket to all broadcast messages. // It creates a subscriber with a buffered msgs chan to give some room to slower // connections and then registers the subscriber. It then listens for all messages // and writes them to the WebSocket. If the context is cancelled or // an error occurs, it returns and deletes the subscription. // // It uses CloseRead to keep reading from the connection to process control // messages and cancel the context if the connection drops. func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error { ctx = c.CloseRead(ctx) s := &subscriber{ msgs: make(chan []byte, cs.subscriberMessageBuffer), closeSlow: func() { c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") }, } cs.addSubscriber(s) defer cs.deleteSubscriber(s) for { select { case msg := <-s.msgs: err := writeTimeout(ctx, time.Second*5, c, msg) if err != nil { return err } case <-ctx.Done(): return ctx.Err() } } } // publish publishes the msg to all subscribers. // It never blocks and so messages to slow subscribers // are dropped. func (cs *chatServer) publish(msg []byte) { cs.subscribersMu.Lock() defer cs.subscribersMu.Unlock() cs.publishLimiter.Wait(context.Background()) for s := range cs.subscribers { select { case s.msgs <- msg: default: go s.closeSlow() } } } // addSubscriber registers a subscriber. func (cs *chatServer) addSubscriber(s *subscriber) { cs.subscribersMu.Lock() cs.subscribers[s] = struct{}{} cs.subscribersMu.Unlock() } // deleteSubscriber deletes the given subscriber. func (cs *chatServer) deleteSubscriber(s *subscriber) { cs.subscribersMu.Lock() delete(cs.subscribers, s) cs.subscribersMu.Unlock() } func writeTimeout(ctx context.Context, timeout time.Duration, c *websocket.Conn, msg []byte) error { ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() return c.Write(ctx, websocket.MessageText, msg) }