2021-11-22 16:05:02 +00:00

91 lines
1.7 KiB
Go

package wstest
import (
"bytes"
"context"
"fmt"
"io"
"time"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/test/xrand"
"nhooyr.io/websocket/internal/xsync"
)
// EchoLoop reads each message from c and writes it straight back with the
// same message type. It keeps looping until obtaining a reader or writer,
// copying the payload, or closing the writer fails, and returns that error.
//
// Before looping, the read limit on c is raised to 1 << 30 bytes. The loop
// is additionally bounded by a one minute timeout derived from ctx. When
// EchoLoop returns, c is closed with websocket.StatusInternalError.
func EchoLoop(ctx context.Context, c *websocket.Conn) error {
	defer c.Close(websocket.StatusInternalError, "")

	c.SetReadLimit(1 << 30)

	ctx, cancel := context.WithTimeout(ctx, time.Minute)
	defer cancel()

	// A single scratch buffer is reused for every message copy.
	scratch := make([]byte, 32<<10)
	for {
		msgType, src, err := c.Reader(ctx)
		if err != nil {
			return err
		}

		dst, err := c.Writer(ctx, msgType)
		if err != nil {
			return err
		}

		if _, err := io.CopyBuffer(dst, src, scratch); err != nil {
			return err
		}

		// Closing the writer finishes the echoed message.
		if err := dst.Close(); err != nil {
			return err
		}
	}
}
// Echo writes a single random message on c and checks that an identical
// message (same type and same bytes) is read back on c.
//
// The message type is picked at random between websocket.MessageBinary and
// websocket.MessageText. The payload length is drawn with xrand.Int(max)
// (presumably in [0, max); confirm against xrand).
//
// Echo returns the read error if reading fails, otherwise the write error if
// writing failed. If the echoed message differs from what was sent, it
// returns an error describing the mismatch. It returns nil on success.
func Echo(ctx context.Context, c *websocket.Conn, max int) error {
	expType := websocket.MessageBinary
	if xrand.Bool() {
		expType = websocket.MessageText
	}

	msg := randMessage(expType, xrand.Int(max))

	// Write in the background so the echoed message can be read while the
	// write is still in progress.
	writeErr := xsync.Go(func() error {
		return c.Write(ctx, expType, msg)
	})

	actType, act, err := c.Read(ctx)
	if err != nil {
		// NOTE: the pending write result is not awaited on this path.
		return err
	}

	err = <-writeErr
	if err != nil {
		return err
	}

	if expType != actType {
		return fmt.Errorf("unexpected message type: expected %v but got %v", expType, actType)
	}
	if !bytes.Equal(msg, act) {
		return fmt.Errorf("unexpected msg read: %v", assert.Diff(msg, act))
	}
	return nil
}
// randMessage builds a random payload for a message of type typ.
// Binary messages get xrand.Bytes(n). Every other type gets xrand.String(n)
// converted to a byte slice.
func randMessage(typ websocket.MessageType, n int) []byte {
	switch typ {
	case websocket.MessageBinary:
		return xrand.Bytes(n)
	default:
		return []byte(xrand.String(n))
	}
}