91 lines
1.7 KiB
Go
91 lines
1.7 KiB
Go
package wstest
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"time"
|
|
|
|
"nhooyr.io/websocket"
|
|
"nhooyr.io/websocket/internal/test/assert"
|
|
"nhooyr.io/websocket/internal/test/xrand"
|
|
"nhooyr.io/websocket/internal/xsync"
|
|
)
|
|
|
|
// EchoLoop echos every msg received from c until an error
|
|
// occurs or the context expires.
|
|
// The read limit is set to 1 << 30.
|
|
func EchoLoop(ctx context.Context, c *websocket.Conn) error {
|
|
defer c.Close(websocket.StatusInternalError, "")
|
|
|
|
c.SetReadLimit(1 << 30)
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
|
defer cancel()
|
|
|
|
b := make([]byte, 32<<10)
|
|
for {
|
|
typ, r, err := c.Reader(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
w, err := c.Writer(ctx, typ)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = io.CopyBuffer(w, r, b)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = w.Close()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
// Echo writes a message and ensures the same is sent back on c.
|
|
func Echo(ctx context.Context, c *websocket.Conn, max int) error {
|
|
expType := websocket.MessageBinary
|
|
if xrand.Bool() {
|
|
expType = websocket.MessageText
|
|
}
|
|
|
|
msg := randMessage(expType, xrand.Int(max))
|
|
|
|
writeErr := xsync.Go(func() error {
|
|
return c.Write(ctx, expType, msg)
|
|
})
|
|
|
|
actType, act, err := c.Read(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = <-writeErr
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if expType != actType {
|
|
return fmt.Errorf("unexpected message typ (%v): %v", expType, actType)
|
|
}
|
|
|
|
if !bytes.Equal(msg, act) {
|
|
return fmt.Errorf("unexpected msg read: %v", assert.Diff(msg, act))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func randMessage(typ websocket.MessageType, n int) []byte {
|
|
if typ == websocket.MessageBinary {
|
|
return xrand.Bytes(n)
|
|
}
|
|
return []byte(xrand.String(n))
|
|
}
|