// 2021-11-22 16:05:02 +00:00
//
// 81 lines
// 1.7 KiB
// Go

package main
import (
"context"
"fmt"
"io"
"net/http"
"time"
"golang.org/x/time/rate"
"nhooyr.io/websocket"
)
// echoServer serves WebSocket connections that echo every message back to
// the sender. Through its ServeHTTP method it requires clients to negotiate
// the "echo" subprotocol, and it rate-limits echoing to one message every
// 100ms with bursts of up to 10 messages.
type echoServer struct {
	// logf receives the server's log output, printf-style.
	logf func(format string, args ...interface{})
}
// ServeHTTP upgrades the request to a WebSocket connection that must
// negotiate the "echo" subprotocol. It then echoes messages back one at a
// time until the peer closes with a normal closure status or an error occurs.
// Errors other than a normal closure are reported through s.logf.
func (s echoServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
		Subprotocols: []string{"echo"},
	})
	if err != nil {
		s.logf("%v", err)
		return
	}
	// Safety net: any return path that did not close the connection
	// deliberately ends it with an internal-error status. The error
	// returned by this deferred Close is discarded.
	defer conn.Close(websocket.StatusInternalError, "the sky is falling")

	if proto := conn.Subprotocol(); proto != "echo" {
		conn.Close(websocket.StatusPolicyViolation, "client must speak the echo subprotocol")
		return
	}

	// One echoed message every 100ms on average, with bursts of up to 10.
	limiter := rate.NewLimiter(rate.Every(100*time.Millisecond), 10)
	for {
		err = echo(r.Context(), conn, limiter)
		switch {
		case websocket.CloseStatus(err) == websocket.StatusNormalClosure:
			// The peer said goodbye cleanly; nothing to report.
			return
		case err != nil:
			s.logf("failed to echo with %v: %v", r.RemoteAddr, err)
			return
		}
	}
}
// echo handles exactly one message on c. It first waits for the limiter l to
// admit a message. It then reads the next message and writes its contents
// back to c using the same message type. The limiter wait, the read and the
// write all share one 10 second deadline derived from ctx.
//
// Errors from l.Wait, c.Reader, c.Writer and closing the writer are returned
// unwrapped. A failure while copying the payload is wrapped with %w.
func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error {
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	if err := l.Wait(ctx); err != nil {
		return err
	}

	msgType, src, err := c.Reader(ctx)
	if err != nil {
		return err
	}
	dst, err := c.Writer(ctx, msgType)
	if err != nil {
		return err
	}

	if _, err := io.Copy(dst, src); err != nil {
		return fmt.Errorf("failed to io.Copy: %w", err)
	}
	// Closing the writer is what finishes the outgoing message.
	return dst.Close()
}