2021-12-04 16:42:11 +00:00

266 lines
5.3 KiB
Go

// +build !js
package websocket
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"runtime"
"strconv"
"sync"
"sync/atomic"
)
// Conn represents a WebSocket connection.
// All methods may be called concurrently except for Reader and Read.
//
// You must always read from the connection. Otherwise control
// frames will not be handled. See Reader and CloseRead.
//
// Be sure to call Close on the connection when you
// are finished with it to release associated resources.
//
// On any error from any method, the connection is closed
// with an appropriate reason.
type Conn struct {
// subprotocol is the value reported by Subprotocol.
subprotocol string
// rwc is the underlying stream; close closes it.
rwc io.ReadWriteCloser
// client is copied from connConfig; newConn only sets writeBuf when it is true.
client bool
// copts is nil when compression is not in use (see flate).
copts *compressionOptions
// flateThreshold is defaulted by newConn to 128 or 512 when compression is
// in use and no value was configured.
// NOTE(review): presumably the minimum message size worth compressing — confirm.
flateThreshold int
br *bufio.Reader
bw *bufio.Writer
// readTimeout and writeTimeout deliver new deadline contexts to timeoutLoop.
readTimeout chan context.Context
writeTimeout chan context.Context
// Read state.
readMu *mu
readHeaderBuf [8]byte
readControlBuf [maxControlPayload]byte
msgReader *msgReader
readCloseFrameErr error
// Write state.
msgWriterState *msgWriterState
writeFrameMu *mu
// writeBuf is only populated by newConn for client connections.
writeBuf []byte
writeHeaderBuf [8]byte
writeHeader header
// closed is closed by close to signal that the connection is finished.
closed chan struct{}
// closeMu is held by close while it updates the close state.
closeMu sync.Mutex
// closeErr is what ping and mu.lock return once closed has been closed.
closeErr error
// NOTE(review): wroteClose is not used in this part of the file; presumably
// it records that a close frame has been written — confirm.
wroteClose bool
// pingCounter is incremented atomically by Ping to derive each ping payload.
pingCounter int32
// activePingsMu guards activePings.
activePingsMu sync.Mutex
// activePings maps a ping payload to the channel ping waits on for its pong.
activePings map[string]chan<- struct{}
}
// connConfig holds the values newConn copies field-for-field into a new Conn.
type connConfig struct {
subprotocol string
rwc io.ReadWriteCloser
client bool
// copts may be nil; Conn.flate treats nil as "no compression".
copts *compressionOptions
// flateThreshold of zero lets newConn pick a default when compression is in use.
flateThreshold int
br *bufio.Reader
bw *bufio.Writer
}
// newConn builds a Conn from cfg, wires up its read/write helpers, picks a
// default compression threshold when none was configured, registers a
// finalizer that closes the connection if it is garbage collected, and starts
// the timeout goroutine.
func newConn(cfg connConfig) *Conn {
	conn := &Conn{
		subprotocol:    cfg.subprotocol,
		rwc:            cfg.rwc,
		client:         cfg.client,
		copts:          cfg.copts,
		flateThreshold: cfg.flateThreshold,

		br: cfg.br,
		bw: cfg.bw,

		readTimeout:  make(chan context.Context),
		writeTimeout: make(chan context.Context),

		closed:      make(chan struct{}),
		activePings: make(map[string]chan<- struct{}),
	}

	// The helpers below all keep a reference back to conn.
	conn.readMu = newMu(conn)
	conn.writeFrameMu = newMu(conn)

	conn.msgReader = newMsgReader(conn)
	conn.msgWriterState = newMsgWriterState(conn)

	// Only client connections get a write buffer taken from the bufio.Writer.
	if conn.client {
		conn.writeBuf = extractBufioWriterBuf(conn.bw, conn.rwc)
	}

	// Default threshold: 128, raised to 512 when context takeover is off.
	if conn.flate() && conn.flateThreshold == 0 {
		conn.flateThreshold = 128
		if !conn.msgWriterState.flateContextTakeover() {
			conn.flateThreshold = 512
		}
	}

	// Safety net for connections that are dropped without being closed.
	runtime.SetFinalizer(conn, func(gc *Conn) {
		gc.close(errors.New("connection garbage collected"))
	})

	go conn.timeoutLoop()

	return conn
}
// Subprotocol returns the negotiated subprotocol.
// An empty string means the default protocol.
//
// The value is the one supplied through connConfig when the connection
// was created by newConn.
func (c *Conn) Subprotocol() string {
return c.subprotocol
}
// close shuts the connection down with err as the reason.
// It holds closeMu until it returns and does nothing if isClosed already
// reports true, so only the first call has any effect.
func (c *Conn) close(err error) {
c.closeMu.Lock()
defer c.closeMu.Unlock()
if c.isClosed() {
return
}
// Record the reason before signalling; ping and mu.lock read closeErr
// after they observe c.closed.
// NOTE(review): setCloseErrLocked is defined elsewhere; presumably it stores err in closeErr — confirm.
c.setCloseErrLocked(err)
close(c.closed)
// The connection is being closed explicitly, so drop the finalizer set in newConn.
runtime.SetFinalizer(c, nil)
// Have to close after c.closed is closed to ensure any goroutine that wakes up
// from the connection being closed also sees that c.closed is closed and returns
// closeErr.
c.rwc.Close()
// The reader/writer helpers are closed from a separate goroutine.
// NOTE(review): presumably because their close methods may block on locks
// held by in-flight reads or writes — confirm.
go func() {
c.msgWriterState.close()
c.msgReader.close()
}()
}
// timeoutLoop runs in its own goroutine (started by newConn) and enforces
// the read and write deadlines of the connection.
//
// New deadline contexts arrive on c.readTimeout and c.writeTimeout. When the
// current read context is done, the close error is recorded and a
// policy-violation error is written to the peer from a separate goroutine.
// When the current write context is done, the connection is closed
// directly. The loop exits once c.closed is closed.
func (c *Conn) timeoutLoop() {
	readCtx := context.Background()
	writeCtx := context.Background()

	for {
		select {
		case <-c.closed:
			return

		case writeCtx = <-c.writeTimeout:
		case readCtx = <-c.readTimeout:

		case <-readCtx.Done():
			c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
			go c.writeError(StatusPolicyViolation, errors.New("timed out"))
			// The Done channel of an expired context stays ready forever.
			// Disarm it so the loop blocks again instead of spinning and
			// launching a new writeError goroutine on every iteration
			// until c.closed is finally closed.
			readCtx = context.Background()

		case <-writeCtx.Done():
			c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
			return
		}
	}
}
// flate reports whether compression options are set on the connection,
// i.e. whether copts is non-nil.
func (c *Conn) flate() bool {
return c.copts != nil
}
// Ping sends a ping frame to the peer and blocks until the matching pong
// arrives, ctx is done, or the connection is closed.
// It is useful for measuring latency or checking that the peer is responsive.
//
// Ping does not read from the connection itself: it relies on a concurrent
// Reader call to receive the pong, so it must be called concurrently
// with Reader.
//
// TCP Keepalives should suffice for most use cases.
func (c *Conn) Ping(ctx context.Context) error {
	// Each ping carries a distinct payload derived from an atomic counter.
	seq := atomic.AddInt32(&c.pingCounter, 1)
	payload := strconv.Itoa(int(seq))

	if err := c.ping(ctx, payload); err != nil {
		return fmt.Errorf("failed to ping: %w", err)
	}
	return nil
}
// ping registers a pong channel under payload p, writes a ping control frame
// carrying p, and waits for the pong, ctx cancellation, or the connection
// closing. The registration is removed again before ping returns.
// If ctx is done while waiting, the connection is closed with that error.
func (c *Conn) ping(ctx context.Context, p string) error {
	// Buffered with capacity 1.
	// NOTE(review): presumably so the pong sender never blocks — confirm.
	ponged := make(chan struct{}, 1)

	c.activePingsMu.Lock()
	c.activePings[p] = ponged
	c.activePingsMu.Unlock()

	defer func() {
		c.activePingsMu.Lock()
		defer c.activePingsMu.Unlock()
		delete(c.activePings, p)
	}()

	if err := c.writeControl(ctx, opPing, []byte(p)); err != nil {
		return err
	}

	select {
	case <-ponged:
		return nil
	case <-ctx.Done():
		waitErr := fmt.Errorf("failed to wait for pong: %w", ctx.Err())
		c.close(waitErr)
		return waitErr
	case <-c.closed:
		return c.closeErr
	}
}
// mu is a channel-based mutex bound to a Conn.
// Unlike sync.Mutex, acquiring it through lock can be abandoned when the
// context is done or the connection is closed.
type mu struct {
// c is the connection whose closed channel and closeErr lock consults.
c *Conn
// ch holds one value while the mutex is locked; newMu creates it with capacity 1.
ch chan struct{}
}
// newMu returns an unlocked mu bound to c.
// The channel has capacity 1, so a single buffered value represents the
// mutex being held.
func newMu(c *Conn) *mu {
return &mu{
c: c,
ch: make(chan struct{}, 1),
}
}
// forceLock acquires the mutex by sending on m.ch, blocking until there is
// room. Unlike lock, it takes no context and does not check whether the
// connection has been closed.
func (m *mu) forceLock() {
m.ch <- struct{}{}
}
// lock acquires the mutex, giving up if the connection is closed or ctx is
// done first.
//
// It returns nil once the mutex is held. If the connection is closed it
// returns the connection's close error without holding the mutex. If ctx is
// done it closes the connection with a wrapped ctx error and returns it.
func (m *mu) lock(ctx context.Context) error {
	// Phase one: race the acquisition against closure and cancellation.
	select {
	case m.ch <- struct{}{}:
		// Acquired; fall through to the liveness check below.
	case <-ctx.Done():
		lockErr := fmt.Errorf("failed to acquire lock: %w", ctx.Err())
		m.c.close(lockErr)
		return lockErr
	case <-m.c.closed:
		return m.c.closeErr
	}

	// Phase two: select picks among ready cases at random, so the send may
	// have won even though the connection was already closed. Re-check and
	// give the mutex back in that case.
	select {
	case <-m.c.closed:
		m.unlock()
		return m.c.closeErr
	default:
		return nil
	}
}
// unlock releases the mutex by receiving the buffered value from m.ch.
// It never blocks: if the channel is empty (the mutex is not held) the
// default case makes the call a no-op.
func (m *mu) unlock() {
select {
case <-m.ch:
default:
}
}