313 lines
7.0 KiB
Go
313 lines
7.0 KiB
Go
|
// Copyright 2020 lesismal. All rights reserved.
|
||
|
// Use of this source code is governed by an MIT-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package websocket
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"encoding/binary"
|
||
|
"errors"
|
||
|
"math/rand"
|
||
|
"net"
|
||
|
"sync"
|
||
|
|
||
|
"github.com/lesismal/nbio/mempool"
|
||
|
"github.com/lesismal/nbio/nbhttp"
|
||
|
)
|
||
|
|
||
|
// maxControlFramePayloadSize is the largest payload a control frame (Close,
// Ping, Pong) may carry: RFC 6455 section 5.5 limits control-frame payloads
// to 125 bytes.
const maxControlFramePayloadSize = 125
|
||
|
|
||
|
// MessageType is a websocket frame opcode; the message-type constants in
// this package are its valid values.
type MessageType int8
|
||
|
|
||
|
// The message types are defined in RFC 6455, section 11.8.t .
|
||
|
const (
|
||
|
// FragmentMessage .
|
||
|
FragmentMessage MessageType = 0 // Must be preceded by Text or Binary message
|
||
|
// TextMessage .
|
||
|
TextMessage MessageType = 1
|
||
|
// BinaryMessage .
|
||
|
BinaryMessage MessageType = 2
|
||
|
// CloseMessage .
|
||
|
CloseMessage MessageType = 8
|
||
|
// PingMessage .
|
||
|
PingMessage MessageType = 9
|
||
|
// PongMessage .
|
||
|
PongMessage MessageType = 10
|
||
|
)
|
||
|
|
||
|
// maskBit is the MASK flag, the high bit of a frame's second header byte.
// writeFrame sets it on every frame sent by a client connection.
const maskBit = 1 << 7
|
||
|
|
||
|
// Conn .
|
||
|
type Conn struct {
|
||
|
net.Conn
|
||
|
|
||
|
mux sync.Mutex
|
||
|
|
||
|
isClient bool
|
||
|
|
||
|
onCloseCalled bool
|
||
|
remoteCompressionEnabled bool
|
||
|
enableWriteCompression bool
|
||
|
compressionLevel int
|
||
|
|
||
|
subprotocol string
|
||
|
|
||
|
session interface{}
|
||
|
|
||
|
onClose func(c *Conn, err error)
|
||
|
Engine *nbhttp.Engine
|
||
|
}
|
||
|
|
||
|
// validCloseCode reports whether code is acceptable as the status code of a
// Close frame (RFC 6455 section 7.4).
//
// Accepted codes:
//   - 1000-1003 and 1007-1011, which are defined by RFC 6455;
//   - 3000-3999, the IANA-registered range;
//   - 4000-4999, the range reserved for private use.
//
// Rejected codes:
//   - 1004, which is reserved;
//   - 1005, 1006 and 1015, which RFC 6455 section 7.4.1 reserves for local
//     reporting and forbids endpoints from putting in a Close frame;
//   - every other code, i.e. below 1000, the rest of 1000-2999 (such as
//     1012-1014 and 1016 upward), and 5000 and above.
func validCloseCode(code int) bool {
	switch code {
	case 1000, // Normal Closure
		1001, // Going Away
		1002, // Protocol Error
		1003, // Unsupported Data
		1007, // Invalid Frame Payload Data
		1008, // Policy Violation
		1009, // Message Too Big
		1010, // Mandatory Extension
		1011: // Internal Server Error
		return true
	case 1004, // Reserved
		1005, // No Status Received: must not be sent on the wire
		1006, // Abnormal Closure: must not be sent on the wire
		1015: // TLS Handshake failure: must not be sent on the wire
		return false
	}
	return code >= 3000 && code <= 4999
}
|
||
|
|
||
|
// OnClose .
|
||
|
func (c *Conn) OnClose(h func(*Conn, error)) {
|
||
|
if h != nil {
|
||
|
c.onClose = func(c *Conn, err error) {
|
||
|
c.mux.Lock()
|
||
|
defer c.mux.Unlock()
|
||
|
if !c.onCloseCalled {
|
||
|
c.onCloseCalled = true
|
||
|
h(c, err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// now all the upgrade, frames/messages and close are called in order
|
||
|
// nbc, ok := c.Conn.(*nbio.Conn)
|
||
|
// if ok {
|
||
|
// nbc.Lock()
|
||
|
// defer nbc.Unlock()
|
||
|
// closed, err := nbc.IsClosed()
|
||
|
// if closed {
|
||
|
// c.onClose(c, err)
|
||
|
// }
|
||
|
// }
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// WriteMessage .
|
||
|
func (c *Conn) WriteMessage(messageType MessageType, data []byte) error {
|
||
|
c.mux.Lock()
|
||
|
defer c.mux.Unlock()
|
||
|
|
||
|
switch messageType {
|
||
|
case TextMessage:
|
||
|
case BinaryMessage:
|
||
|
case PingMessage, PongMessage, CloseMessage:
|
||
|
if len(data) > maxControlFramePayloadSize {
|
||
|
return ErrInvalidControlFrame
|
||
|
}
|
||
|
case FragmentMessage:
|
||
|
default:
|
||
|
}
|
||
|
|
||
|
compress := c.enableWriteCompression && (messageType == TextMessage || messageType == BinaryMessage)
|
||
|
if compress {
|
||
|
compress = true
|
||
|
w := &writeBuffer{
|
||
|
Buffer: bytes.NewBuffer(mempool.Malloc(len(data))),
|
||
|
}
|
||
|
defer w.Close()
|
||
|
w.Reset()
|
||
|
cw := compressWriter(w, c.compressionLevel)
|
||
|
_, err := cw.Write(data)
|
||
|
if err != nil {
|
||
|
compress = false
|
||
|
} else {
|
||
|
cw.Close()
|
||
|
data = w.Bytes()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if len(data) > 0 {
|
||
|
sendOpcode := true
|
||
|
for len(data) > 0 {
|
||
|
n := len(data)
|
||
|
if n > c.Engine.MaxWebsocketFramePayloadSize {
|
||
|
n = c.Engine.MaxWebsocketFramePayloadSize
|
||
|
}
|
||
|
err := c.writeFrame(messageType, sendOpcode, n == len(data), data[:n], compress)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
sendOpcode = false
|
||
|
data = data[n:]
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
return c.writeFrame(messageType, true, true, []byte{}, compress)
|
||
|
}
|
||
|
|
||
|
// Session returns user session.
|
||
|
func (c *Conn) Session() interface{} {
|
||
|
return c.session
|
||
|
}
|
||
|
|
||
|
// SetSession sets user session.
|
||
|
func (c *Conn) SetSession(session interface{}) {
|
||
|
c.session = session
|
||
|
}
|
||
|
|
||
|
type writeBuffer struct {
|
||
|
*bytes.Buffer
|
||
|
}
|
||
|
|
||
|
// Close .
|
||
|
func (w *writeBuffer) Close() error {
|
||
|
mempool.Free(w.Bytes())
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// WriteFrame .
|
||
|
func (c *Conn) WriteFrame(messageType MessageType, sendOpcode, fin bool, data []byte) error {
|
||
|
return c.writeFrame(messageType, sendOpcode, fin, data, false)
|
||
|
}
|
||
|
|
||
|
func (c *Conn) writeFrame(messageType MessageType, sendOpcode, fin bool, data []byte, compress bool) error {
|
||
|
var (
|
||
|
buf []byte
|
||
|
byte1 byte
|
||
|
maskLen int
|
||
|
headLen int
|
||
|
bodyLen = len(data)
|
||
|
)
|
||
|
|
||
|
if c.isClient {
|
||
|
byte1 |= maskBit
|
||
|
maskLen = 4
|
||
|
}
|
||
|
|
||
|
if bodyLen < 126 {
|
||
|
headLen = 2 + maskLen
|
||
|
buf = mempool.Malloc(len(data) + headLen)
|
||
|
buf[0] = 0
|
||
|
buf[1] = (byte1 | byte(bodyLen))
|
||
|
} else if bodyLen <= 65535 {
|
||
|
headLen = 4 + maskLen
|
||
|
buf = mempool.Malloc(len(data) + headLen)
|
||
|
buf[0] = 0
|
||
|
buf[1] = (byte1 | 126)
|
||
|
binary.BigEndian.PutUint16(buf[2:4], uint16(bodyLen))
|
||
|
} else {
|
||
|
headLen = 10 + maskLen
|
||
|
buf = mempool.Malloc(len(data) + headLen)
|
||
|
buf[0] = 0
|
||
|
buf[1] = (byte1 | 127)
|
||
|
binary.BigEndian.PutUint64(buf[2:10], uint64(bodyLen))
|
||
|
}
|
||
|
|
||
|
if c.isClient {
|
||
|
u32 := rand.Uint32()
|
||
|
maskKey := []byte{byte(u32), byte(u32 >> 8), byte(u32 >> 16), byte(u32 >> 24)}
|
||
|
copy(buf[headLen-4:headLen], maskKey)
|
||
|
for i := 0; i < len(data); i++ {
|
||
|
buf[headLen+i] = (data[i] ^ maskKey[i%4])
|
||
|
}
|
||
|
} else {
|
||
|
copy(buf[headLen:], data)
|
||
|
}
|
||
|
|
||
|
// opcode
|
||
|
if sendOpcode {
|
||
|
buf[0] = byte(messageType)
|
||
|
} else {
|
||
|
buf[0] = 0
|
||
|
}
|
||
|
|
||
|
if compress {
|
||
|
buf[0] |= 0x40
|
||
|
}
|
||
|
|
||
|
// fin
|
||
|
if fin {
|
||
|
buf[0] |= byte(0x80)
|
||
|
}
|
||
|
|
||
|
_, err := c.Conn.Write(buf)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Write overwrites nbio.Conn.Write.
|
||
|
func (c *Conn) Write(data []byte) (int, error) {
|
||
|
return -1, ErrInvalidWriteCalling
|
||
|
}
|
||
|
|
||
|
// EnableWriteCompression .
|
||
|
func (c *Conn) EnableWriteCompression(enable bool) {
|
||
|
if enable {
|
||
|
if c.remoteCompressionEnabled {
|
||
|
c.enableWriteCompression = enable
|
||
|
}
|
||
|
} else {
|
||
|
c.enableWriteCompression = enable
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// SetCompressionLevel .
|
||
|
func (c *Conn) SetCompressionLevel(level int) error {
|
||
|
if !isValidCompressionLevel(level) {
|
||
|
return errors.New("websocket: invalid compression level")
|
||
|
}
|
||
|
c.compressionLevel = level
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func newConn(u *Upgrader, c net.Conn, subprotocol string, remoteCompressionEnabled bool) *Conn {
|
||
|
conn := &Conn{
|
||
|
Conn: c,
|
||
|
subprotocol: subprotocol,
|
||
|
remoteCompressionEnabled: remoteCompressionEnabled,
|
||
|
compressionLevel: defaultCompressionLevel,
|
||
|
onClose: func(*Conn, error) {},
|
||
|
}
|
||
|
conn.EnableWriteCompression(u.enableWriteCompression)
|
||
|
conn.SetCompressionLevel(u.compressionLevel)
|
||
|
|
||
|
return conn
|
||
|
}
|