313 lines
7.0 KiB
Go
Raw Normal View History

2021-12-04 16:42:11 +00:00
// Copyright 2020 lesismal. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package websocket
import (
"bytes"
"encoding/binary"
"errors"
"math/rand"
"net"
"sync"
"github.com/lesismal/nbio/mempool"
"github.com/lesismal/nbio/nbhttp"
)
const (
// maxControlFramePayloadSize is the largest payload, in bytes, that
// WriteMessage accepts for ping, pong and close messages.
maxControlFramePayloadSize = 125
)
// MessageType identifies the kind of a websocket message. Its numeric value
// is what writeFrame places in the opcode bits of the first frame byte.
type MessageType int8
// The message types use the opcode values registered in RFC 6455, section 11.8.
const (
// FragmentMessage is the continuation opcode used for follow-up fragments.
FragmentMessage MessageType = 0 // Must be preceded by a Text or Binary message
// TextMessage denotes a text data message.
TextMessage MessageType = 1
// BinaryMessage denotes a binary data message.
BinaryMessage MessageType = 2
// CloseMessage denotes a close control message.
CloseMessage MessageType = 8
// PingMessage denotes a ping control message.
PingMessage MessageType = 9
// PongMessage denotes a pong control message.
PongMessage MessageType = 10
)
const (
// maskBit is the flag writeFrame ORs into the second header byte of a
// frame when the connection is a client and the payload is masked.
maskBit = 1 << 7
)
// Conn is a websocket connection layered on top of a net.Conn.
type Conn struct {
// Conn is the underlying transport; writeFrame sends encoded frames
// through its Write method.
net.Conn
// mux serializes WriteMessage calls and the handler installed by OnClose.
mux sync.Mutex
// isClient makes writeFrame set the mask bit and mask outgoing payloads.
isClient bool
// onCloseCalled records that the handler installed by OnClose already ran.
onCloseCalled bool
// remoteCompressionEnabled must be true before EnableWriteCompression
// will switch write compression on.
remoteCompressionEnabled bool
// enableWriteCompression makes WriteMessage compress text and binary
// messages.
enableWriteCompression bool
// compressionLevel is the level WriteMessage hands to compressWriter.
compressionLevel int
// subprotocol is the value given to newConn; presumably the subprotocol
// negotiated during the upgrade (confirm against the upgrader).
subprotocol string
// session is an arbitrary user value, see Session and SetSession.
session interface{}
// onClose is the close callback; newConn installs a no-op and OnClose
// replaces it.
onClose func(c *Conn, err error)
// Engine supplies MaxWebsocketFramePayloadSize, the per-frame payload
// limit WriteMessage uses when splitting a message into frames.
Engine *nbhttp.Engine
}
// validCloseCode reports whether code may appear as the status code carried
// in the payload of a Close frame.
//
// Accepted are the RFC 6455 codes an endpoint is allowed to send (1000-1003
// and 1007-1011) plus the ranges 3000-3999 (registered with IANA) and
// 4000-4999 (private use). The reserved codes 1004, 1005, 1006 and 1015 are
// rejected: RFC 6455 section 7.4.1 states that they MUST NOT be set as a
// status code in a Close control frame. Every other value is rejected too.
func validCloseCode(code int) bool {
	switch code {
	case 1000, // Normal Closure
		1001, // Going Away
		1002, // Protocol error
		1003, // Unsupported Data
		1007, // Invalid frame payload data
		1008, // Policy Violation
		1009, // Message Too Big
		1010, // Mandatory Extension
		1011: // Internal Server Error
		return true
	case 1004, // Reserved
		1005, // No Status Rcvd (never sent on the wire)
		1006, // Abnormal Closure (never sent on the wire)
		1015: // TLS handshake (never sent on the wire)
		return false
	}
	// 3000-3999: reserved for libraries/frameworks, registered with IANA.
	// 4000-4999: reserved for private use.
	return code >= 3000 && code < 5000
}
// OnClose installs h as the connection's close handler.
//
// h is wrapped so that it runs at most once, tracked by onCloseCalled. The
// wrapper holds c.mux while h executes; WriteMessage takes the same mutex, so
// h must not call WriteMessage on this connection. Passing a nil h leaves the
// currently installed handler untouched.
func (c *Conn) OnClose(h func(*Conn, error)) {
	if h == nil {
		return
	}

	c.onClose = func(conn *Conn, err error) {
		conn.mux.Lock()
		defer conn.mux.Unlock()

		if conn.onCloseCalled {
			return
		}
		conn.onCloseCalled = true
		h(conn, err)
	}

	// Upgrade, frames/messages and close are now all called in order; the
	// immediate check for an already closed connection is kept disabled:
	// nbc, ok := c.Conn.(*nbio.Conn)
	// if ok {
	// 	nbc.Lock()
	// 	defer nbc.Unlock()
	// 	closed, err := nbc.IsClosed()
	// 	if closed {
	// 		c.onClose(c, err)
	// 	}
	// }
}
// WriteMessage sends data as one websocket message of the given type.
//
// Ping, pong and close messages whose payload exceeds
// maxControlFramePayloadSize are rejected with ErrInvalidControlFrame. When
// write compression is enabled, text and binary payloads are compressed
// first; if compressing fails the original payload is sent uncompressed.
// A payload larger than c.Engine.MaxWebsocketFramePayloadSize is split into
// several frames; a non-positive limit disables splitting. The returned
// error is the first one reported by writeFrame. Calls are serialized by
// c.mux.
func (c *Conn) WriteMessage(messageType MessageType, data []byte) error {
	c.mux.Lock()
	defer c.mux.Unlock()

	switch messageType {
	case PingMessage, PongMessage, CloseMessage:
		if len(data) > maxControlFramePayloadSize {
			return ErrInvalidControlFrame
		}
	}

	compress := c.enableWriteCompression && (messageType == TextMessage || messageType == BinaryMessage)
	if compress {
		w := &writeBuffer{
			Buffer: bytes.NewBuffer(mempool.Malloc(len(data))),
		}
		// data may point into w until the last frame has been written, so the
		// buffer is only released when WriteMessage returns.
		defer w.Close()
		w.Reset()
		cw := compressWriter(w, c.compressionLevel)
		_, err := cw.Write(data)
		if err != nil {
			// Fall back to sending the original payload uncompressed.
			compress = false
		} else {
			cw.Close()
			data = w.Bytes()
		}
	}

	if len(data) == 0 {
		return c.writeFrame(messageType, true, true, nil, compress)
	}

	// A non-positive limit would otherwise loop forever (0) or panic on the
	// slice expression (negative); treat it as "do not fragment".
	limit := c.Engine.MaxWebsocketFramePayloadSize
	if limit <= 0 {
		limit = len(data)
	}

	first := true
	for len(data) > 0 {
		n := len(data)
		if n > limit {
			n = limit
		}
		// Only the first frame carries the opcode and the "per-message
		// compressed" bit (RSV1): RFC 7692 forbids setting RSV1 on the
		// continuation frames of a fragmented message.
		if err := c.writeFrame(messageType, first, n == len(data), data[:n], compress && first); err != nil {
			return err
		}
		first = false
		data = data[n:]
	}
	return nil
}
// Session returns the user value last stored with SetSession, or nil when
// nothing has been stored. The field is read without taking c.mux.
func (c *Conn) Session() interface{} {
	return c.session
}
// SetSession stores an arbitrary user value on the connection, replacing any
// previous one; Session returns it. The field is written without taking
// c.mux.
func (c *Conn) SetSession(session interface{}) {
	c.session = session
}
// writeBuffer embeds a bytes.Buffer and adds a Close method that hands the
// buffer's bytes to mempool.Free. WriteMessage uses it as the destination of
// the compress writer.
type writeBuffer struct {
*bytes.Buffer
}
// Close passes the buffer's current contents to mempool.Free. It always
// returns nil.
func (w *writeBuffer) Close() error {
	contents := w.Bytes()
	mempool.Free(contents)
	return nil
}
// WriteFrame writes a single frame without compression.
//
// messageType is used as the opcode when sendOpcode is true; otherwise the
// frame is sent with opcode 0. fin marks the final frame of a message.
// Unlike WriteMessage, this method does not take c.mux.
func (c *Conn) WriteFrame(messageType MessageType, sendOpcode, fin bool, data []byte) error {
	const compress = false
	return c.writeFrame(messageType, sendOpcode, fin, data, compress)
}
// writeFrame encodes one websocket frame and writes it to the underlying
// connection, returning the error from that write.
//
// The opcode is messageType when sendOpcode is true and 0 otherwise; fin sets
// the FIN bit (0x80) and compress sets RSV1 (0x40) in the first byte. The
// payload length is encoded in its 7-bit, 16-bit or 64-bit form depending on
// len(data). Client connections set the mask bit and XOR the payload with a
// 4-byte key. The frame buffer is obtained from mempool.Malloc.
func (c *Conn) writeFrame(messageType MessageType, sendOpcode, fin bool, data []byte, compress bool) error {
	payloadLen := len(data)

	// Pick the length encoding: the value itself, or a marker (126/127)
	// followed by a 2- or 8-byte big-endian extended length.
	var (
		lenField byte
		extLen   int
	)
	switch {
	case payloadLen < 126:
		lenField = byte(payloadLen)
	case payloadLen <= 65535:
		lenField, extLen = 126, 2
	default:
		lenField, extLen = 127, 8
	}

	maskLen := 0
	if c.isClient {
		maskLen = 4
		lenField |= maskBit
	}

	headLen := 2 + extLen + maskLen
	frame := mempool.Malloc(headLen + payloadLen)

	frame[1] = lenField
	switch extLen {
	case 2:
		binary.BigEndian.PutUint16(frame[2:4], uint16(payloadLen))
	case 8:
		binary.BigEndian.PutUint64(frame[2:10], uint64(payloadLen))
	}

	body := frame[headLen:]
	if c.isClient {
		// The masking key occupies the last four header bytes.
		// NOTE(review): the key comes from math/rand, which is not a
		// cryptographic source; RFC 6455 section 5.3 asks for an
		// unpredictable key.
		key := frame[headLen-4 : headLen]
		binary.LittleEndian.PutUint32(key, rand.Uint32())
		for i, b := range data {
			body[i] = b ^ key[i&3]
		}
	} else {
		copy(body, data)
	}

	// First byte: FIN | RSV1 | opcode.
	var first byte
	if sendOpcode {
		first = byte(messageType)
	}
	if compress {
		first |= 0x40
	}
	if fin {
		first |= 0x80
	}
	frame[0] = first

	_, err := c.Conn.Write(frame)
	return err
}
// Write shadows the Write method of the embedded net.Conn so that raw bytes
// cannot be written past the websocket framing. It never writes anything and
// always returns -1 and ErrInvalidWriteCalling; use WriteMessage or
// WriteFrame instead.
func (c *Conn) Write(data []byte) (int, error) {
	return -1, ErrInvalidWriteCalling
}
// EnableWriteCompression switches compression of outgoing text and binary
// messages on or off.
//
// Disabling always takes effect. Enabling only takes effect when
// remoteCompressionEnabled is set; otherwise the current setting is left
// unchanged.
func (c *Conn) EnableWriteCompression(enable bool) {
	if enable && !c.remoteCompressionEnabled {
		return
	}
	c.enableWriteCompression = enable
}
// SetCompressionLevel sets the level used when compressing outgoing
// messages. A level rejected by isValidCompressionLevel yields an error and
// leaves the current level unchanged.
func (c *Conn) SetCompressionLevel(level int) error {
	if isValidCompressionLevel(level) {
		c.compressionLevel = level
		return nil
	}
	return errors.New("websocket: invalid compression level")
}
// newConn wraps c in a websocket Conn configured from the upgrader u.
//
// The connection starts with a no-op close handler and
// defaultCompressionLevel; the upgrader's write-compression flag and
// compression level are then applied through the regular setters. The Engine
// field is not set here.
func newConn(u *Upgrader, c net.Conn, subprotocol string, remoteCompressionEnabled bool) *Conn {
	wsConn := &Conn{
		Conn:                     c,
		onClose:                  func(*Conn, error) {},
		compressionLevel:         defaultCompressionLevel,
		remoteCompressionEnabled: remoteCompressionEnabled,
		subprotocol:              subprotocol,
	}

	wsConn.EnableWriteCompression(u.enableWriteCompression)
	// The error is discarded: when the upgrader's level is rejected the
	// connection keeps defaultCompressionLevel.
	_ = wsConn.SetCompressionLevel(u.compressionLevel)

	return wsConn
}