2021-12-04 16:42:11 +00:00

278 lines
5.8 KiB
Go

package main
import (
"context"
"crypto/rand"
"fmt"
"math/big"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"golang.org/x/time/rate"
"nhooyr.io/websocket"
)
// Test_chatServer exercises the chat server end to end: clients subscribe
// over a WebSocket (newClient) and publish via HTTP POST (client.publish),
// and every subscriber is expected to receive every published message.
func Test_chatServer(t *testing.T) {
	t.Parallel()

	// This is a simple echo test with a single client.
	// The client sends a message and ensures it receives
	// it on its WebSocket.
	t.Run("simple", func(t *testing.T) {
		t.Parallel()

		url, closeFn := setupTest(t)
		defer closeFn()

		ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
		defer cancel()

		cl, err := newClient(ctx, url)
		assertSuccess(t, err)
		defer cl.Close()

		expMsg := randString(512)
		err = cl.publish(ctx, expMsg)
		assertSuccess(t, err)

		msg, err := cl.nextMessage()
		assertSuccess(t, err)
		if expMsg != msg {
			t.Fatalf("expected %v but got %v", expMsg, msg)
		}
	})

	// This test is a complex concurrency test.
	// 16 clients are started that concurrently publish 128 distinct
	// messages each, every message being fewer than 128 bytes long.
	//
	// The test verifies that every message is seen by every client
	// and no errors occur anywhere.
	//
	// NOTE: nextMessage reads with context.Background(), so ctx's timeout
	// below does not bound the receiving goroutines. If a message is lost,
	// a receiver blocks until its connection is closed (publish closes it
	// on failure) or the overall `go test` timeout fires.
	t.Run("concurrency", func(t *testing.T) {
		t.Parallel()

		const nmessages = 128
		const maxMessageSize = 128
		const nclients = 16

		url, closeFn := setupTest(t)
		defer closeFn()

		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
		defer cancel()

		var clients []*client
		var clientMsgs []map[string]struct{}
		for i := 0; i < nclients; i++ {
			cl, err := newClient(ctx, url)
			assertSuccess(t, err)
			defer cl.Close()

			clients = append(clients, cl)
			clientMsgs = append(clientMsgs, randMessages(nmessages, maxMessageSize))
		}

		// Union of every client's messages. Messages may repeat across
		// clients, so this set can be smaller than nclients*nmessages; the
		// receivers still read exactly nclients*nmessages messages.
		allMessages := make(map[string]struct{})
		for _, msgs := range clientMsgs {
			for m := range msgs {
				allMessages[m] = struct{}{}
			}
		}

		var wg sync.WaitGroup
		for i, cl := range clients {
			i := i
			cl := cl

			wg.Add(1)
			go func() {
				defer wg.Done()
				err := cl.publishMsgs(ctx, clientMsgs[i])
				if err != nil {
					t.Errorf("client %d failed to publish all messages: %v", i, err)
				}
			}()

			wg.Add(1)
			go func() {
				defer wg.Done()
				err := testAllMessagesReceived(cl, nclients*nmessages, allMessages)
				if err != nil {
					t.Errorf("client %d failed to receive all messages: %v", i, err)
				}
			}()
		}

		wg.Wait()
	})
}
// setupTest starts a chatServer behind an httptest.Server and returns the
// server's base URL together with a function that shuts the server down.
//
// Callers should defer closeFn so the test server is always stopped.
//
// The chatServer's log output is routed to t.Logf.
func setupTest(t *testing.T) (url string, closeFn func()) {
	srv := newChatServer()
	srv.logf = t.Logf

	// A large subscriber buffer and an unlimited publish rate keep the
	// tests fast, even when run with -race.
	srv.subscriberMessageBuffer = 4096
	srv.publishLimiter.SetLimit(rate.Inf)

	ts := httptest.NewServer(srv)
	return ts.URL, ts.Close
}
// testAllMessagesReceived reads exactly n messages from cl and reports an
// error unless every message in msgs was among them. msgs itself is not
// modified; a private copy is consumed instead. Any read error is returned
// immediately.
func testAllMessagesReceived(cl *client, n int, msgs map[string]struct{}) error {
	pending := cloneMessages(msgs)

	for read := 0; read < n; read++ {
		got, err := cl.nextMessage()
		if err != nil {
			return err
		}
		delete(pending, got)
	}

	if len(pending) == 0 {
		return nil
	}
	return fmt.Errorf("did not receive all expected messages: %q", pending)
}
// cloneMessages returns an independent copy of the msgs set. The result is
// always a non-nil map, even when msgs is nil or empty.
func cloneMessages(msgs map[string]struct{}) map[string]struct{} {
	dup := make(map[string]struct{}, len(msgs))
	for key, val := range msgs {
		dup[key] = val
	}
	return dup
}
// randMessages returns a set of n distinct random messages. Each message is
// produced by randString with a length drawn from [0, maxMessageLength), so
// every message is strictly shorter than maxMessageLength bytes.
//
// maxMessageLength must be positive (randInt panics otherwise). n must not
// exceed the number of distinct strings that can be generated under that
// bound, or this never terminates. For example, maxMessageLength == 1 only
// ever yields "".
func randMessages(n, maxMessageLength int) map[string]struct{} {
	msgs := make(map[string]struct{})
	// Keep drawing until the set holds n distinct messages; duplicate draws
	// are absorbed by the set instead of rewinding a loop counter.
	for len(msgs) < n {
		msgs[randString(randInt(maxMessageLength))] = struct{}{}
	}
	return msgs
}
// assertSuccess stops the current test via t.Fatal when err is non-nil.
// It is marked as a helper so the failure is attributed to the caller's line.
func assertSuccess(t *testing.T, err error) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatal(err)
}
// client is a test chat client. It keeps the chat server's base URL, which
// publish uses for HTTP POSTs to /publish, and a WebSocket connection to
// /subscribe (opened by newClient), which nextMessage reads from.
type client struct {
// url is the chat server's base URL, as passed to newClient.
url string
// c is the WebSocket connection dialed to url+"/subscribe".
c *websocket.Conn
}
// newClient dials the chat server's WebSocket endpoint at url+"/subscribe"
// and returns a client that wraps the connection and remembers url for
// later publishes. ctx bounds the dial; the dial error is returned as-is.
func newClient(ctx context.Context, url string) (*client, error) {
	conn, _, err := websocket.Dial(ctx, url+"/subscribe", nil)
	if err != nil {
		return nil, err
	}
	return &client{url: url, c: conn}, nil
}
// publish POSTs msg as the request body to the chat server's /publish
// endpoint and succeeds only on a 202 Accepted response.
//
// On any failure, the client's WebSocket is closed with StatusInternalError.
// Presumably this is meant to make a goroutine blocked in nextMessage on the
// same client fail instead of waiting forever.
func (cl *client) publish(ctx context.Context, msg string) (err error) {
	defer func() {
		if err != nil {
			cl.c.Close(websocket.StatusInternalError, "publish failed")
		}
	}()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, cl.url+"/publish", strings.NewReader(msg))
	if err != nil {
		// Previously this error was discarded; a nil req would make Do panic
		// instead of reporting the malformed URL.
		return fmt.Errorf("building publish request: %w", err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusAccepted {
		return fmt.Errorf("publish request failed: %v", resp.StatusCode)
	}
	return nil
}
// publishMsgs publishes every message in msgs, one at a time, and stops at
// the first error. Map iteration order is unspecified, so the send order is
// arbitrary.
func (cl *client) publishMsgs(ctx context.Context, msgs map[string]struct{}) error {
	for msg := range msgs {
		if err := cl.publish(ctx, msg); err != nil {
			return err
		}
	}
	return nil
}
// nextMessage blocks until the next WebSocket message arrives and returns
// its payload as a string. The read uses context.Background(), so it has no
// deadline of its own. A non-text message closes the connection with
// StatusUnsupportedData and is reported as an error.
func (cl *client) nextMessage() (string, error) {
	typ, payload, err := cl.c.Read(context.Background())
	switch {
	case err != nil:
		return "", err
	case typ != websocket.MessageText:
		cl.c.Close(websocket.StatusUnsupportedData, "expected text message")
		return "", fmt.Errorf("expected text message but got %v", typ)
	default:
		return string(payload), nil
	}
}
// Close closes the client's WebSocket with StatusNormalClosure and an empty
// reason, returning whatever error the underlying Close reports.
func (cl *client) Close() error {
return cl.c.Close(websocket.StatusNormalClosure, "")
}
// randString generates a random string of exactly n bytes that is valid
// UTF-8 and contains no NUL bytes.
//
// n random bytes are drawn from crypto/rand. Each run of invalid UTF-8 is
// collapsed to a single "_", NULs become "_", and the result is right-padded
// with "=" back to n bytes. It panics if the random source fails.
func randString(n int) string {
	b := make([]byte, n)
	// rand.Read is documented to fill b entirely, whereas a single
	// rand.Reader.Read call may legally return fewer bytes (io.Reader
	// contract), which would leave trailing zero bytes in b.
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Sprintf("failed to generate rand bytes: %v", err))
	}

	s := strings.ToValidUTF8(string(b), "_")
	s = strings.ReplaceAll(s, "\x00", "_")

	// Replacing each invalid run (at least one byte) with the one-byte "_"
	// can only shrink the string, and the NUL replacement preserves length,
	// so len(s) <= n here: padding is the only adjustment ever needed.
	if pad := n - len(s); pad > 0 {
		s += strings.Repeat("=", pad)
	}
	return s
}
// randInt returns a uniformly random integer in [0, limit), drawn from
// crypto/rand. It panics if limit <= 0 (a precondition of rand.Int) or if
// the random source fails.
func randInt(limit int) int {
	upper := big.NewInt(int64(limit))
	v, err := rand.Int(rand.Reader, upper)
	if err != nil {
		panic(fmt.Sprintf("failed to get random int: %v", err))
	}
	return int(v.Int64())
}