// Snapshot metadata: 2021-12-01 15:43:13 +00:00, 895 lines, 20 KiB, Go.

package websocket
import (
"bytes"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
"unicode/utf8"
"github.com/lesismal/llib/std/crypto/tls"
"github.com/lesismal/nbio"
"github.com/lesismal/nbio/logging"
"github.com/lesismal/nbio/mempool"
"github.com/lesismal/nbio/nbhttp"
)
// Hijacker describes a value whose underlying net.Conn can be taken over with
// Hijack.
// NOTE(review): not referenced by the code in this part of the file — Upgrade
// asserts net/http's http.Hijacker instead; confirm whether this interface is
// still needed.
type Hijacker interface {
Hijack() (net.Conn, error)
}
// Upgrader upgrades HTTP requests served by nbhttp to WebSocket connections.
// It holds the configuration and the callbacks shared by every connection it
// upgrades. Create it with NewUpgrader so that the default ping and close
// handlers are installed; the zero value has none.
type Upgrader struct {
// ReadLimit, when > 0, bounds how many bytes a connection's frame parser may
// hold. Read fails with nbhttp.ErrTooLong if the carried-over partial frame,
// plus either the newly received data or the message being assembled, would
// exceed it.
ReadLimit int64
// MessageLengthLimit is the maximum length of websocket message. 0 for unlimited.
MessageLengthLimit int64
// HandshakeTimeout, when > 0, is applied as a write deadline before the
// handshake response is written in Upgrade.
HandshakeTimeout time.Duration
// enableCompression (see EnableCompression) lets Upgrade accept a client's
// permessage-deflate offer and lets the parser accept frames with RSV1 set.
enableCompression bool
// enableWriteCompression and compressionLevel are only stored by their
// setters in the code shown here; presumably they are consumed when
// connections write compressed frames — confirm in the rest of the package.
enableWriteCompression bool
compressionLevel int
// Subprotocols lists the subprotocols the server supports, in preference
// order. When nil, Upgrade echoes the Sec-Websocket-Protocol value from the
// responseHeader it is given.
Subprotocols []string
// CheckOrigin decides whether the request's origin is acceptable. When nil,
// checkSameOrigin is used: the Origin host must match the request Host.
CheckOrigin func(r *http.Request) bool
// Control-frame callbacks. NewUpgrader installs defaults, and the
// Set*Handler methods replace them.
pingMessageHandler func(c *Conn, appData string)
pongMessageHandler func(c *Conn, appData string)
closeMessageHandler func(c *Conn, code int, text string)
// Application callbacks registered with OnOpen, OnMessage, OnDataFrame and
// OnClose.
openHandler func(*Conn)
messageHandler func(c *Conn, messageType MessageType, data []byte)
dataFrameHandler func(c *Conn, messageType MessageType, fin bool, data []byte)
onClose func(c *Conn, err error)
}
// connState is the per-connection WebSocket frame parser state. Upgrade
// installs it as the connection's nbhttp.Parser ConnState; its Read and Close
// methods receive the connection's bytes and close notification.
type connState struct {
// common is the Upgrader holding the shared configuration and callbacks.
common *Upgrader
// conn is the upgraded connection passed to handlers.
conn *Conn
// expectingFragments is true after a non-FIN data frame, until the FIN
// fragment of that message has been processed.
expectingFragments bool
// compress records RSV1 of the first frame of the message being assembled
// (permessage-deflate); reset after each complete message.
compress bool
// opcode is the opcode of the message being assembled; Read treats 0 as
// "no message in progress".
opcode MessageType
// buffer holds the bytes of an incomplete frame carried over between Read
// calls.
buffer []byte
// message accumulates the payload of the current message for the message
// handler.
message []byte
// Engine supplies the BodyAllocator, UTF-8 validation and the
// ReleaseWebsocketPayload setting.
Engine *nbhttp.Engine
}
// CompressionEnabled reports whether the message currently being assembled
// arrived with RSV1 set, i.e. is permessage-deflate compressed. Read resets
// the flag after each complete message.
func (u *connState) CompressionEnabled() bool {
return u.compress
}
// NewUpgrader returns an Upgrader with default control-frame handlers:
//   - ping: replies with a pong carrying the same payload, and closes the
//     connection if the payload exceeds 125 bytes or the pong cannot be sent;
//   - pong: ignored;
//   - close: echoes the received status code and reason back in a close
//     frame, unless the reply would not fit in a control frame.
//
// No message, data-frame, open or close callbacks are set.
func NewUpgrader() *Upgrader {
	return &Upgrader{
		pingMessageHandler: func(c *Conn, data string) {
			if len(data) > 125 {
				c.Close()
				return
			}
			if err := c.WriteMessage(PongMessage, []byte(data)); err != nil {
				logging.Debug("failed to send pong %v", err)
				c.Close()
			}
		},
		pongMessageHandler: func(*Conn, string) {},
		closeMessageHandler: func(c *Conn, code int, text string) {
			payloadLen := len(text) + 2 // 2-byte status code + reason
			if payloadLen > maxControlFramePayloadSize {
				return //ErrInvalidControlFrame
			}
			payload := mempool.Malloc(payloadLen)
			binary.BigEndian.PutUint16(payload[:2], uint16(code))
			copy(payload[2:], text)
			c.WriteMessage(CloseMessage, payload)
			mempool.Free(payload)
		},
	}
}
// SetCloseHandler replaces the callback run when a close frame carrying a
// valid status code and a UTF-8 reason arrives. A nil h is ignored and the
// current handler is kept.
func (u *Upgrader) SetCloseHandler(h func(*Conn, int, string)) {
	if h == nil {
		return
	}
	u.closeMessageHandler = h
}

// SetPingHandler replaces the callback run with the payload of each ping
// frame. A nil h is ignored and the current handler is kept.
func (u *Upgrader) SetPingHandler(h func(*Conn, string)) {
	if h == nil {
		return
	}
	u.pingMessageHandler = h
}

// SetPongHandler replaces the callback run with the payload of each pong
// frame. A nil h is ignored and the current handler is kept.
func (u *Upgrader) SetPongHandler(h func(*Conn, string)) {
	if h == nil {
		return
	}
	u.pongMessageHandler = h
}

// OnOpen registers h to be called with each connection created by Upgrade.
// Passing nil clears the callback.
func (u *Upgrader) OnOpen(h func(*Conn)) {
	u.openHandler = h
}

// OnMessage registers h to receive each complete text or binary message.
// When the engine's ReleaseWebsocketPayload option is set, the payload is
// returned to the engine's BodyAllocator as soon as h returns, so h must not
// retain it. A nil h is ignored.
func (u *Upgrader) OnMessage(h func(*Conn, MessageType, []byte)) {
	if h == nil {
		return
	}
	u.messageHandler = func(c *Conn, messageType MessageType, data []byte) {
		if c.Engine.ReleaseWebsocketPayload {
			defer c.Engine.BodyAllocator.Free(data)
		}
		h(c, messageType, data)
	}
}

// OnDataFrame registers h to receive every data frame as it arrives, with its
// FIN flag. When the engine's ReleaseWebsocketPayload option is set, the
// frame payload is returned to the engine's BodyAllocator as soon as h
// returns, so h must not retain it. A nil h is ignored.
func (u *Upgrader) OnDataFrame(h func(*Conn, MessageType, bool, []byte)) {
	if h == nil {
		return
	}
	u.dataFrameHandler = func(c *Conn, messageType MessageType, fin bool, data []byte) {
		if c.Engine.ReleaseWebsocketPayload {
			defer c.Engine.BodyAllocator.Free(data)
		}
		h(c, messageType, fin, data)
	}
}
// OnClose registers h as the close callback attached to every connection that
// Upgrade creates. Unlike the Set*Handler methods, a nil h is stored as-is.
func (u *Upgrader) OnClose(h func(*Conn, error)) {
	u.onClose = h
}

// EnableCompression controls whether Upgrade accepts a client's
// permessage-deflate offer and whether incoming frames may have RSV1 set.
func (u *Upgrader) EnableCompression(enable bool) {
	u.enableCompression = enable
}

// EnableWriteCompression stores whether outgoing messages should be
// compressed.
// NOTE(review): the flag is not read in this part of the file; presumably it
// is consumed when connections write frames — confirm.
func (u *Upgrader) EnableWriteCompression(enable bool) {
	u.enableWriteCompression = enable
}

// SetCompressionLevel sets the deflate compression level.
//
// level must be a valid compress/flate level, from flate.HuffmanOnly (-2) to
// flate.BestCompression (9). For any other value the current level is kept
// and an error is returned.
func (u *Upgrader) SetCompressionLevel(level int) error {
	// Bounds mirror compress/flate's HuffmanOnly and BestCompression. They are
	// spelled out to avoid importing compress/flate just for this check.
	const minLevel, maxLevel = -2, 9
	if level < minLevel || level > maxLevel {
		return fmt.Errorf("websocket: invalid compression level %d", level)
	}
	u.compressionLevel = level
	return nil
}
// Upgrade performs the server side of the WebSocket opening handshake for a
// request served by nbhttp.
//
// Before touching the connection, it checks:
//   - the Connection/Upgrade tokens and the GET method;
//   - Sec-Websocket-Version 13;
//   - the origin (CheckOrigin, or checkSameOrigin when nil);
//   - that Sec-Websocket-Key is present.
//
// It then selects a subprotocol and, when compression is enabled and offered,
// accepts permessage-deflate. responseHeader must not contain
// Sec-Websocket-Extensions. Its other entries are copied into the response,
// with control characters in values replaced by spaces.
//
// Next it hijacks the connection, installs a connState on the connection's
// *nbhttp.Parser and writes the "101 Switching Protocols" response.
//
// Errors found before hijacking are also reported to the client through w.
// After hijacking, w can no longer be used, so failures close the connection
// and return a non-nil error instead. On success the returned net.Conn is the
// websocket *Conn. The OnOpen handler, if any, runs only after the handshake
// response has been written, so frames it sends follow the 101 response.
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (net.Conn, error) {
	if !headerContains(r.Header, "Connection", "upgrade") {
		return nil, u.returnError(w, r, http.StatusBadRequest, ErrUpgradeTokenNotFound)
	}
	if !headerContains(r.Header, "Upgrade", "websocket") {
		return nil, u.returnError(w, r, http.StatusBadRequest, ErrUpgradeTokenNotFound)
	}
	if r.Method != "GET" {
		return nil, u.returnError(w, r, http.StatusMethodNotAllowed, ErrUpgradeMethodIsGet)
	}
	if !headerContains(r.Header, "Sec-Websocket-Version", "13") {
		return nil, u.returnError(w, r, http.StatusBadRequest, ErrUpgradeInvalidWebsocketVersion)
	}
	if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
		return nil, u.returnError(w, r, http.StatusInternalServerError, ErrUpgradeUnsupportedExtensions)
	}

	checkOrigin := u.CheckOrigin
	if checkOrigin == nil {
		checkOrigin = checkSameOrigin
	}
	if !checkOrigin(r) {
		return nil, u.returnError(w, r, http.StatusForbidden, ErrUpgradeOriginNotAllowed)
	}

	challengeKey := r.Header.Get("Sec-Websocket-Key")
	if challengeKey == "" {
		return nil, u.returnError(w, r, http.StatusBadRequest, ErrUpgradeMissingWebsocketKey)
	}

	subprotocol := u.selectSubprotocol(r, responseHeader)

	// Negotiate PMCE (permessage-deflate).
	var compress bool
	if u.enableCompression {
		for _, ext := range parseExtensions(r.Header) {
			if ext[""] == "permessage-deflate" {
				compress = true
				break
			}
		}
	}

	h, ok := w.(http.Hijacker)
	if !ok {
		return nil, u.returnError(w, r, http.StatusInternalServerError, ErrUpgradeNotHijacker)
	}
	conn, _, err := h.Hijack()
	if err != nil {
		return nil, u.returnError(w, r, http.StatusInternalServerError, err)
	}

	// From here on the connection is ours: w must not be written to, and every
	// failure must close conn and return a non-nil error (err is nil here).
	nbc, ok := conn.(*nbio.Conn)
	if !ok {
		tlsConn, tlsOk := conn.(*tls.Conn)
		if !tlsOk {
			conn.Close()
			return nil, fmt.Errorf("websocket: hijacked connection has unsupported type %T", conn)
		}
		nbc, tlsOk = tlsConn.Conn().(*nbio.Conn)
		if !tlsOk {
			conn.Close()
			return nil, fmt.Errorf("websocket: tls connection wraps unsupported type %T", tlsConn.Conn())
		}
	}
	parser, ok := nbc.Session().(*nbhttp.Parser)
	if !ok {
		conn.Close()
		return nil, fmt.Errorf("websocket: connection session has unsupported type %T", nbc.Session())
	}

	state := &connState{common: u}
	parser.ConnState = state

	buf := mempool.Malloc(1024)[0:0]
	buf = append(buf, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
	buf = append(buf, acceptKeyBytes(challengeKey)...)
	buf = append(buf, "\r\n"...)
	if subprotocol != "" {
		buf = append(buf, "Sec-WebSocket-Protocol: "...)
		buf = append(buf, subprotocol...)
		buf = append(buf, "\r\n"...)
	}
	if compress {
		buf = append(buf, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
	}
	for k, vs := range responseHeader {
		if k == "Sec-Websocket-Protocol" {
			continue // already written above
		}
		for _, v := range vs {
			buf = append(buf, k...)
			buf = append(buf, ": "...)
			for i := 0; i < len(v); i++ {
				b := v[i]
				if b <= 31 {
					// prevent response splitting.
					b = ' '
				}
				buf = append(buf, b)
			}
			buf = append(buf, "\r\n"...)
		}
	}
	buf = append(buf, "\r\n"...)

	if u.HandshakeTimeout > 0 {
		conn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
	}

	state.conn = newConn(u, conn, subprotocol, compress)
	state.Engine = parser.Engine
	state.conn.Engine = parser.Engine

	if _, err = conn.Write(buf); err != nil {
		conn.Close()
		return nil, err
	}
	if u.HandshakeTimeout > 0 {
		// A net.Conn deadline applies to all future writes. Clear it so that
		// writes after the handshake are not failed by the handshake timeout.
		conn.SetWriteDeadline(time.Time{})
	}

	state.conn.OnClose(u.onClose)
	// Only now that the 101 response has been written may the application
	// start sending frames.
	if u.openHandler != nil {
		u.openHandler(state.conn)
	}
	return state.conn, nil
}
// validFrame checks a decoded frame header against the protocol rules this
// parser enforces. The checks run in this order:
//   - RSV1 is only allowed when compression is enabled; RSV2/RSV3 never are
//     (ErrReserveBitSet);
//   - opcodes between BinaryMessage and CloseMessage are reserved
//     (ErrReservedOpcodeSet);
//   - only continuation, text and binary frames may be non-FIN
//     (ErrControlMessageFragmented);
//   - while a fragmented message is in progress, a new text or binary opcode
//     is invalid (ErrFragmentsShouldNotHaveBinaryOrTextOpcode).
//
// It returns nil for an acceptable frame.
func (u *connState) validFrame(opcode MessageType, fin, res1, res2, res3, expectingFragments bool) error {
	isData := opcode == FragmentMessage || opcode == TextMessage || opcode == BinaryMessage
	startsMessage := opcode == TextMessage || opcode == BinaryMessage
	switch {
	case (res1 && !u.common.enableCompression) || res2 || res3:
		return ErrReserveBitSet
	case opcode > BinaryMessage && opcode < CloseMessage:
		return fmt.Errorf("%w: opcode=%d", ErrReservedOpcodeSet, opcode)
	case !fin && !isData:
		return fmt.Errorf("%w: opcode=%d", ErrControlMessageFragmented, opcode)
	case expectingFragments && startsMessage:
		return ErrFragmentsShouldNotHaveBinaryOrTextOpcode
	default:
		return nil
	}
}
// isMessageTooLarge reports whether a payload of n bytes exceeds the
// Upgrader's MessageLengthLimit. A limit of 0 means unlimited, in which case
// it always reports false.
func (u *connState) isMessageTooLarge(n int) bool {
	limit := u.common.MessageLengthLimit
	if limit == 0 {
		// 0 means unlimited size
		return false
	}
	// Compare in int64: converting the limit to int would truncate it on
	// 32-bit platforms.
	return int64(n) > limit
}
// Read consumes bytes received on an upgraded connection. This connState is
// installed as the parser's ConnState by Upgrade.
//
// It appends data to any partial frame left over from the previous call, then
// decodes, validates and dispatches every complete frame:
//   - data frames are copied to the data-frame handler, if one is set;
//   - data frames are also accumulated into u.message for the message
//     handler until FIN. Compressed messages are inflated before delivery.
//   - control frames are copied and passed to handleProtocolMessage.
//
// The bytes of an incomplete trailing frame are copied into a mempool buffer
// kept in u.buffer for the next call.
//
// Read returns nbhttp.ErrTooLong when ReadLimit would be exceeded, and
// ErrMessageTooLarge when MessageLengthLimit is. Otherwise it returns the
// first validation or decompression error, or nil.
func (u *connState) Read(p *nbhttp.Parser, data []byte) error {
	bufLen := len(u.buffer)
	if u.common.ReadLimit > 0 && (int64(bufLen+len(data)) > u.common.ReadLimit || int64(bufLen+len(u.message)) > u.common.ReadLimit) {
		return nbhttp.ErrTooLong
	}

	var oldBuffer []byte
	if bufLen == 0 {
		// Parse straight out of the caller's slice; everything that must
		// outlive this call is copied.
		u.buffer = data
	} else {
		u.buffer = append(u.buffer, data...)
		oldBuffer = u.buffer
	}

	var err error
	for {
		opcode, body, ok, fin, res1, res2, res3 := u.nextFrame()
		if !ok {
			break
		}
		if err = u.validFrame(opcode, fin, res1, res2, res3, u.expectingFragments); err != nil {
			break
		}
		if opcode == FragmentMessage || opcode == TextMessage || opcode == BinaryMessage {
			// The first frame of a message fixes its opcode and whether it is
			// compressed (RSV1).
			if u.opcode == 0 {
				u.opcode = opcode
				u.compress = res1
			}
			bl := len(body)
			if u.common.dataFrameHandler != nil {
				var frame []byte
				if bl > 0 {
					if u.isMessageTooLarge(bl) {
						err = ErrMessageTooLarge
						break
					}
					frame = u.Engine.BodyAllocator.Malloc(bl)
					copy(frame, body)
				}
				if u.opcode == TextMessage && len(frame) > 0 && !u.Engine.CheckUtf8(frame) {
					// The frame is dropped, so nothing else will release it.
					u.Engine.BodyAllocator.Free(frame)
					u.conn.Close()
				} else {
					u.common.handleDataFrame(p, u.conn, u.opcode, fin, frame)
				}
			}
			if bl > 0 && u.common.messageHandler != nil {
				if u.message == nil {
					// Check the limit before allocating, so a rejected frame
					// does not leave a fresh allocation behind in u.message.
					if u.isMessageTooLarge(bl) {
						err = ErrMessageTooLarge
						break
					}
					u.message = u.Engine.BodyAllocator.Malloc(bl)
					copy(u.message, body)
				} else {
					if u.isMessageTooLarge(len(u.message) + bl) {
						err = ErrMessageTooLarge
						break
					}
					u.message = append(u.message, body...)
				}
			}
			if fin {
				if u.common.messageHandler != nil {
					if u.compress {
						var b []byte
						rc := decompressReader(io.MultiReader(bytes.NewBuffer(u.message), strings.NewReader(flateReaderTail)))
						b, err = u.readAll(rc, len(u.message)*2)
						u.Engine.BodyAllocator.Free(u.message)
						u.message = b
						rc.Close()
						if err != nil {
							break
						}
					}
					u.handleMessage(p, u.opcode, u.message)
				}
				u.compress = false
				u.expectingFragments = false
				u.message = nil
				u.opcode = 0
			} else {
				u.expectingFragments = true
			}
		} else {
			var frame []byte
			if len(body) > 0 {
				if u.isMessageTooLarge(len(body)) {
					err = ErrMessageTooLarge
					break
				}
				frame = u.Engine.BodyAllocator.Malloc(len(body))
				copy(frame, body)
			}
			u.handleProtocolMessage(p, opcode, frame)
		}
		if len(u.buffer) == 0 {
			break
		}
	}

	// Keep any incomplete trailing frame in a buffer owned by this connState,
	// and drop empty tails instead of allocating zero-length buffers or
	// aliasing the caller's data.
	if bufLen == 0 {
		if len(u.buffer) > 0 {
			tmp := u.buffer
			u.buffer = mempool.Malloc(len(tmp))
			copy(u.buffer, tmp)
		} else {
			u.buffer = nil
		}
	} else if len(u.buffer) < len(oldBuffer) {
		if len(u.buffer) > 0 {
			tmp := u.buffer
			u.buffer = mempool.Malloc(len(tmp))
			copy(u.buffer, tmp)
		} else {
			u.buffer = nil
		}
		mempool.Free(oldBuffer)
	}
	return err
}
// Close is the ConnState counterpart of Read, presumably invoked by the
// nbhttp parser when the connection goes away.
//
// It runs the connection's close callback with err and releases the buffers
// this connState still owns. The fields are cleared afterwards, so a repeated
// Close cannot free them twice.
func (u *connState) Close(p *nbhttp.Parser, err error) {
	if u.conn != nil {
		u.conn.onClose(u.conn, err)
	}
	if len(u.buffer) > 0 {
		// Read keeps the carried-over partial frame in a mempool buffer.
		mempool.Free(u.buffer)
	}
	u.buffer = nil
	if len(u.message) > 0 && u.Engine != nil {
		// The partially assembled message comes from the engine's
		// BodyAllocator (Read / readAll), so it must be returned there.
		u.Engine.BodyAllocator.Free(u.message)
	}
	u.message = nil
}
// handleDataFrame schedules delivery of one data frame to the data-frame
// handler through p.Execute. The handler is looked up now, when the frame is
// parsed, not when the scheduled function eventually runs.
func (u *Upgrader) handleDataFrame(p *nbhttp.Parser, c *Conn, opcode MessageType, fin bool, data []byte) {
	deliver := u.dataFrameHandler
	job := func() {
		deliver(c, opcode, fin, data)
	}
	p.Execute(job)
}
// handleMessage schedules delivery of a complete data message through
// p.Execute.
//
// A text message whose payload is not valid UTF-8 is dropped and the
// connection is closed instead. When the engine has ReleaseWebsocketPayload
// set, the dropped payload is returned to the BodyAllocator, as the OnMessage
// wrapper would have done after delivery.
func (u *connState) handleMessage(p *nbhttp.Parser, opcode MessageType, body []byte) {
	// Validate the arguments being delivered rather than u.opcode/u.message,
	// so the check cannot diverge from what the handler receives.
	if opcode == TextMessage && !u.Engine.CheckUtf8(body) {
		if len(body) > 0 && u.Engine.ReleaseWebsocketPayload {
			u.Engine.BodyAllocator.Free(body)
		}
		u.conn.Close()
		return
	}
	p.Execute(func() {
		u.common.handleWsMessage(u.conn, opcode, body)
	})
}
// handleProtocolMessage schedules handling of a control frame (close, ping or
// pong) through p.Execute. After the frame has been handled, a non-empty
// payload is returned to the engine's BodyAllocator when
// ReleaseWebsocketPayload is set.
func (u *connState) handleProtocolMessage(p *nbhttp.Parser, opcode MessageType, body []byte) {
	p.Execute(func() {
		u.common.handleWsMessage(u.conn, opcode, body)
		if len(body) == 0 || !u.Engine.ReleaseWebsocketPayload {
			return
		}
		u.Engine.BodyAllocator.Free(body)
	})
}
// handleWsMessage dispatches one complete message or control frame:
//   - text and binary messages go to messageHandler;
//   - ping and pong payloads go to their handlers;
//   - a close frame is answered and the connection is then closed at once,
//     without waiting for the reply to be flushed. A frame with a valid code
//     and UTF-8 reason goes to closeMessageHandler; an invalid one gets a
//     1002 (protocol error) close; a frame without a status code gets an
//     empty close.
//   - anything else, including a bare continuation frame, closes the
//     connection.
func (u *Upgrader) handleWsMessage(c *Conn, opcode MessageType, data []byte) {
	switch opcode {
	case TextMessage, BinaryMessage:
		u.messageHandler(c, opcode, data)
	case PingMessage:
		u.pingMessageHandler(c, string(data))
	case PongMessage:
		u.pongMessageHandler(c, string(data))
	case CloseMessage:
		if len(data) < 2 {
			c.WriteMessage(CloseMessage, nil)
		} else if code, reason := int(binary.BigEndian.Uint16(data[:2])), data[2:]; validCloseCode(code) && c.Engine.CheckUtf8(reason) {
			u.closeMessageHandler(c, code, string(reason))
		} else {
			var reply [2]byte
			binary.BigEndian.PutUint16(reply[:], 1002)
			c.WriteMessage(CloseMessage, reply[:])
		}
		// close immediately, no need to wait for data flushed on a blocked conn
		c.Close()
	case FragmentMessage:
		logging.Debug("invalid fragment message")
		c.Close()
	default:
		c.Close()
	}
}
// nextFrame tries to decode one complete frame from the front of u.buffer.
//
// When a whole frame is buffered, it does three things: unmasks the payload
// in place, advances u.buffer past the frame, and returns ok=true. In that
// case body aliases the buffer.
//
// Otherwise it returns ok=false and leaves u.buffer untouched. The header
// fields may already be filled in then, and the caller must ignore them.
func (u *connState) nextFrame() (opcode MessageType, body []byte, ok, fin, res1, res2, res3 bool) {
	l := int64(len(u.buffer))
	headLen := int64(2)
	if l >= 2 {
		b0 := u.buffer[0]
		opcode = MessageType(b0 & 0xF)
		res1 = b0&0x40 != 0
		res2 = b0&0x20 != 0
		res3 = b0&0x10 != 0
		fin = b0&0x80 != 0
		payloadLen := u.buffer[1] & 0x7F
		bodyLen := int64(-1) // -1: extended length not fully buffered yet
		switch payloadLen {
		case 126:
			if l >= 4 {
				bodyLen = int64(binary.BigEndian.Uint16(u.buffer[2:4]))
				headLen = 4
			}
		case 127:
			if l >= 10 {
				// RFC 6455 requires the top bit of this length to be zero. If a
				// peer sets it, the conversion yields a negative length and the
				// frame never completes; ReadLimit, when configured, eventually
				// rejects the connection.
				bodyLen = int64(binary.BigEndian.Uint64(u.buffer[2:10]))
				headLen = 10
			}
		default:
			bodyLen = int64(payloadLen)
		}
		if bodyLen >= 0 {
			masked := (u.buffer[1] & 0x80) != 0
			if masked {
				headLen += 4
			}
			// Compare against the bytes actually available instead of first
			// computing headLen+bodyLen. A huge 64-bit length would overflow
			// that sum to a negative value, pass the size check and make the
			// slice expression below panic.
			if bodyLen <= l-headLen {
				total := headLen + bodyLen
				body = u.buffer[headLen:total]
				if masked {
					maskKey := u.buffer[headLen-4 : headLen]
					for i := 0; i < len(body); i++ {
						body[i] ^= maskKey[i%4]
					}
				}
				ok = true
				u.buffer = u.buffer[total:l]
			}
		}
	}
	return opcode, body, ok, fin, res1, res2, res3
}
// returnError rejects a handshake. It advertises the supported protocol
// version (13) in Sec-Websocket-Version and replies with status, using the
// status text as the body. It returns err unchanged, so callers can write
// `return nil, u.returnError(...)`.
func (u *Upgrader) returnError(w http.ResponseWriter, _ *http.Request, status int, err error) error {
	hdr := w.Header()
	hdr.Set("Sec-Websocket-Version", "13")
	text := http.StatusText(status)
	http.Error(w, text, status)
	return err
}
// selectSubprotocol picks the subprotocol to announce in the handshake
// response.
//
// When Subprotocols is configured, it returns the first server protocol, in
// server preference order, that the client also offered, or "" if none
// match. When Subprotocols is nil, it returns the Sec-Websocket-Protocol
// value of responseHeader, or "" if responseHeader is nil.
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
	if u.Subprotocols == nil {
		if responseHeader == nil {
			return ""
		}
		return responseHeader.Get("Sec-Websocket-Protocol")
	}
	offered := subprotocols(r)
	for _, want := range u.Subprotocols {
		for _, got := range offered {
			if got == want {
				return got
			}
		}
	}
	return ""
}
// subprotocols returns the comma-separated entries of the request's
// Sec-Websocket-Protocol header, each trimmed of surrounding whitespace.
// It returns nil when the header is absent or blank. Empty entries (e.g. from
// "a,,b") are kept as "".
func subprotocols(r *http.Request) []string {
	raw := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol"))
	if raw == "" {
		return nil
	}
	var offered []string
	for _, entry := range strings.Split(raw, ",") {
		offered = append(offered, strings.TrimSpace(entry))
	}
	return offered
}
// keyGUID is the fixed GUID that RFC 6455 appends to Sec-WebSocket-Key when
// computing Sec-WebSocket-Accept.
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// acceptKeyString returns the Sec-WebSocket-Accept value for challengeKey:
// base64(SHA-1(challengeKey + keyGUID)).
func acceptKeyString(challengeKey string) string {
	return string(acceptKeyBytes(challengeKey))
}

// acceptKeyBytes is acceptKeyString returning a freshly allocated []byte, so
// it can be appended directly into a response buffer.
func acceptKeyBytes(challengeKey string) []byte {
	input := append([]byte(challengeKey), keyGUID...)
	digest := sha1.Sum(input) //nolint:gosec // per websocket protocol spec
	encoded := make([]byte, base64.StdEncoding.EncodedLen(len(digest)))
	base64.StdEncoding.Encode(encoded, digest[:])
	return encoded
}

// challengeKey returns a new random Sec-WebSocket-Key: 16 bytes from
// crypto/rand, base64-encoded.
func challengeKey() (string, error) {
	var nonce [16]byte
	if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(nonce[:]), nil
}
// checkSameOrigin is the default origin check. Requests without an Origin
// header are accepted. Otherwise the host of the first Origin value must
// parse as a URL and equal the request Host, ignoring ASCII case.
func checkSameOrigin(r *http.Request) bool {
	origins := r.Header["Origin"]
	if len(origins) == 0 {
		return true
	}
	parsed, err := url.Parse(origins[0])
	return err == nil && equalASCIIFold(parsed.Host, r.Host)
}
// headerContains reports whether the comma-separated token list in header
// field name (looked up by its exact, canonical key) contains value, compared
// with ASCII case folding.
//
// Each header line is scanned token by token. At the first malformed element
// (an empty token, or a token followed by something other than a comma), the
// scan abandons that line and moves on to the next one. Every path through the
// inner loop either consumes input, returns, or jumps to the next header line,
// so the scan always terminates.
func headerContains(header http.Header, name string, value string) bool {
headers:
	for _, s := range header[name] {
		for {
			var t string
			t, s = nextToken(skipSpace(s))
			if t == "" {
				// End of line or a non-token byte: nothing more to parse here.
				// (A plain `continue` would spin forever on the same input.)
				continue headers
			}
			s = skipSpace(s)
			if s != "" && s[0] != ',' {
				continue headers
			}
			if equalASCIIFold(t, value) {
				return true
			}
			if s == "" {
				continue headers
			}
			s = s[1:] // skip the ','
		}
	}
	return false
}
// equalASCIIFold reports whether s and t are equal rune by rune, treating
// ASCII letters case-insensitively. Non-ASCII runes must match exactly.
func equalASCIIFold(s, t string) bool {
	lower := func(r rune) rune {
		if r >= 'A' && r <= 'Z' {
			return r + ('a' - 'A')
		}
		return r
	}
	for len(s) > 0 && len(t) > 0 {
		a, na := utf8.DecodeRuneInString(s)
		b, nb := utf8.DecodeRuneInString(t)
		s, t = s[na:], t[nb:]
		if a != b && lower(a) != lower(b) {
			return false
		}
	}
	// Equal only if both strings were exhausted together.
	return s == t
}
// parseExtensions parses all Sec-Websocket-Extensions header lines of a
// request. Each extension becomes a map: the "" key holds the extension name,
// and the other keys are its parameters. A parameter without "=value" maps
// to "". Parameter values may be tokens or quoted strings.
//
// A syntax error abandons the rest of the current header line. Extensions
// already completed on that line are kept, the one being parsed is dropped,
// and parsing continues with the next header line.
func parseExtensions(header http.Header) []map[string]string {
var result []map[string]string
headers:
for _, s := range header["Sec-Websocket-Extensions"] {
// Each header line is a comma-separated list of extensions.
for {
var t string
t, s = nextToken(skipSpace(s))
if t == "" {
continue headers
}
ext := map[string]string{"": t}
// Parameters follow the extension name, each introduced by ';'.
for {
s = skipSpace(s)
if !strings.HasPrefix(s, ";") {
break
}
var k string
k, s = nextToken(skipSpace(s[1:]))
if k == "" {
continue headers
}
s = skipSpace(s)
var v string
if strings.HasPrefix(s, "=") {
v, s = nextTokenOrQuoted(skipSpace(s[1:]))
s = skipSpace(s)
}
// After a parameter, only another parameter (';'), the next extension
// (',') or the end of the line may follow.
if s != "" && s[0] != ',' && s[0] != ';' {
continue headers
}
ext[k] = v
}
if s != "" && s[0] != ',' {
continue headers
}
result = append(result, ext)
if s == "" {
continue headers
}
// Skip the ',' separating this extension from the next one.
s = s[1:]
}
}
return result
}
// isTokenOctet[b] reports whether byte b is a "tchar" that may appear in an
// HTTP token: ALPHA, DIGIT, or one of !#$%&'*+-.^_`|~ (RFC 7230 §3.2.6).
var isTokenOctet = func() (table [256]bool) {
	const tchars = "!#$%&'*+-.^_`|~" +
		"0123456789" +
		"ABCDEFGHIJKLMNOPQRSTUVWXYZ" +
		"abcdefghijklmnopqrstuvwxyz"
	for i := 0; i < len(tchars); i++ {
		table[tchars[i]] = true
	}
	return table
}()

// skipSpace returns s without its leading spaces and horizontal tabs.
func skipSpace(s string) (rest string) {
	return strings.TrimLeft(s, " \t")
}

// nextToken splits s into its leading run of token octets and the remainder.
// token is empty when s does not start with a token octet.
func nextToken(s string) (token, rest string) {
	n := 0
	for n < len(s) && isTokenOctet[s[n]] {
		n++
	}
	return s[:n], s[n:]
}

// nextTokenOrQuoted reads either a token or a double-quoted string from the
// start of s.
//
// For a quoted string, value is its unescaped content ('\' escapes the next
// byte) and rest follows the closing quote. An unterminated quoted string
// yields ("", "").
func nextTokenOrQuoted(s string) (value string, rest string) {
	if len(s) == 0 || s[0] != '"' {
		return nextToken(s)
	}
	quoted := s[1:]
	cut := strings.IndexAny(quoted, "\"\\")
	if cut < 0 {
		return "", ""
	}
	if quoted[cut] == '"' {
		// No escapes: the value is a plain substring.
		return quoted[:cut], quoted[cut+1:]
	}
	// A backslash precedes the closing quote, so build the unescaped value.
	out := make([]byte, 0, len(quoted)-1)
	out = append(out, quoted[:cut]...)
	pendingEscape := true // quoted[cut] is the backslash just seen
	for k := cut + 1; k < len(quoted); k++ {
		c := quoted[k]
		switch {
		case pendingEscape:
			pendingEscape = false
			out = append(out, c)
		case c == '\\':
			pendingEscape = true
		case c == '"':
			return string(out), quoted[k+1:]
		default:
			out = append(out, c)
		}
	}
	return "", ""
}
// readAll drains r, which is the decompressing reader for a permessage-deflate
// message, into a buffer taken from the engine's BodyAllocator. The buffer
// starts with capacity size and grows as needed. io.EOF ends the read
// normally and is not returned.
//
// The output is bounded by MessageLengthLimit, since the checks in Read only
// see the compressed size and a tiny deflate stream can inflate enormously.
// When the limit is exceeded, reading stops and ErrMessageTooLarge is
// returned together with the data read so far, so the caller can still
// release the buffer.
func (u *connState) readAll(r io.Reader, size int) ([]byte, error) {
	const maxAppendSize = 1024 * 1024 * 4
	// Minimum growth step, so an empty initial buffer cannot stall the loop
	// (doubling a zero-capacity buffer would never make room).
	const minAppendSize = 512
	buf := u.Engine.BodyAllocator.Malloc(size)[0:0]
	for {
		n, err := r.Read(buf[len(buf):cap(buf)])
		if n > 0 {
			buf = buf[:len(buf)+n]
		}
		if u.isMessageTooLarge(len(buf)) {
			return buf, ErrMessageTooLarge
		}
		if err != nil {
			if err == io.EOF {
				err = nil
			}
			return buf, err
		}
		if len(buf) == cap(buf) {
			// Grow by roughly doubling, capped at maxAppendSize per step.
			l := len(buf)
			al := l
			if al > maxAppendSize {
				al = maxAppendSize
			}
			if al < minAppendSize {
				al = minAppendSize
			}
			buf = append(buf, make([]byte, al)...)[:l]
		}
	}
}