799 lines
16 KiB
Go
799 lines
16 KiB
Go
// Copyright 2020 lesismal. All rights reserved.
|
|
// Use of this source code is governed by an MIT-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package nbhttp
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/textproto"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/lesismal/nbio/mempool"
|
|
)
|
|
|
|
const (
|
|
transferEncodingHeader = "Transfer-Encoding"
|
|
trailerHeader = "Trailer"
|
|
contentLengthHeader = "Content-Length"
|
|
|
|
// MaxUint .
|
|
MaxUint = ^uint(0)
|
|
// MaxInt .
|
|
MaxInt = int64(int(MaxUint >> 1))
|
|
)
|
|
|
|
// Parser .
|
|
type Parser struct {
|
|
mux sync.Mutex
|
|
|
|
cache []byte
|
|
|
|
proto string
|
|
|
|
statusCode int
|
|
status string
|
|
|
|
headerKey string
|
|
headerValue string
|
|
|
|
header http.Header
|
|
trailer http.Header
|
|
|
|
contentLength int
|
|
chunkSize int
|
|
chunked bool
|
|
headerExists bool
|
|
|
|
state int8
|
|
isClient bool
|
|
|
|
readLimit int
|
|
|
|
errClose error
|
|
|
|
onClose func(p *Parser, err error)
|
|
|
|
Processor Processor
|
|
|
|
ConnState ReadCloser
|
|
|
|
Engine *Engine
|
|
|
|
Conn net.Conn
|
|
|
|
Execute func(f func())
|
|
}
|
|
|
|
func (p *Parser) nextState(state int8) {
|
|
switch p.state {
|
|
case stateClose:
|
|
default:
|
|
p.state = state
|
|
}
|
|
}
|
|
|
|
// OnClose .
|
|
func (p *Parser) OnClose(h func(p *Parser, err error)) {
|
|
p.onClose = h
|
|
}
|
|
|
|
// Close .
|
|
func (p *Parser) Close(err error) {
|
|
p.mux.Lock()
|
|
defer p.mux.Unlock()
|
|
|
|
if p.state == stateClose {
|
|
return
|
|
}
|
|
|
|
p.state = stateClose
|
|
|
|
p.errClose = err
|
|
|
|
if p.ConnState != nil {
|
|
p.ConnState.Close(p, p.errClose)
|
|
}
|
|
if p.Processor != nil {
|
|
p.Processor.Close(p, p.errClose)
|
|
}
|
|
if len(p.cache) > 0 {
|
|
mempool.Free(p.cache)
|
|
}
|
|
if p.onClose != nil {
|
|
p.onClose(p, err)
|
|
}
|
|
}
|
|
|
|
func parseAndValidateChunkSize(originalStr string) (int, error) {
|
|
chunkSize, err := strconv.ParseInt(originalStr, 16, 63)
|
|
if err != nil {
|
|
return -1, fmt.Errorf("chunk size parse error %v: %w", originalStr, err)
|
|
}
|
|
if chunkSize < 0 {
|
|
return -1, fmt.Errorf("chunk size zero")
|
|
}
|
|
if chunkSize > MaxInt {
|
|
return -1, fmt.Errorf("chunk size greater than max int %d", chunkSize)
|
|
}
|
|
return int(chunkSize), nil
|
|
}
|
|
|
|
// Read .
|
|
func (p *Parser) Read(data []byte) error {
|
|
p.mux.Lock()
|
|
defer p.mux.Unlock()
|
|
|
|
if p.state == stateClose {
|
|
return ErrClosed
|
|
}
|
|
|
|
if len(data) == 0 {
|
|
return nil
|
|
}
|
|
|
|
var c byte
|
|
var start = 0
|
|
var offset = len(p.cache)
|
|
if offset > 0 {
|
|
if offset+len(data) > p.readLimit {
|
|
return ErrTooLong
|
|
}
|
|
p.cache = append(p.cache, data...)
|
|
data = p.cache
|
|
}
|
|
|
|
UPGRADER:
|
|
if p.ConnState != nil {
|
|
udata := data
|
|
if start > 0 {
|
|
udata = data[start:]
|
|
}
|
|
err := p.ConnState.Read(p, udata)
|
|
if p.cache != nil {
|
|
mempool.Free(p.cache)
|
|
p.cache = nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
for i := offset; i < len(data); i++ {
|
|
if p.ConnState != nil {
|
|
goto UPGRADER
|
|
}
|
|
c = data[i]
|
|
switch p.state {
|
|
case stateClose:
|
|
return ErrClosed
|
|
case stateMethodBefore:
|
|
if isValidMethodChar(c) {
|
|
start = i
|
|
p.nextState(stateMethod)
|
|
continue
|
|
}
|
|
return ErrInvalidMethod
|
|
case stateMethod:
|
|
if c == ' ' {
|
|
var method = strings.ToUpper(string(data[start:i]))
|
|
if !isValidMethod(method) {
|
|
return ErrInvalidMethod
|
|
}
|
|
p.Processor.OnMethod(method)
|
|
start = i + 1
|
|
p.nextState(statePathBefore)
|
|
continue
|
|
}
|
|
if !isAlpha(c) {
|
|
return ErrInvalidMethod
|
|
}
|
|
case statePathBefore:
|
|
switch c {
|
|
case '/', '*':
|
|
start = i
|
|
p.nextState(statePath)
|
|
continue
|
|
}
|
|
switch c {
|
|
case ' ':
|
|
default:
|
|
return ErrInvalidRequestURI
|
|
}
|
|
case statePath:
|
|
if c == ' ' {
|
|
var uri = string(data[start:i])
|
|
if err := p.Processor.OnURL(uri); err != nil {
|
|
return err
|
|
}
|
|
start = i + 1
|
|
p.nextState(stateProtoBefore)
|
|
}
|
|
case stateProtoBefore:
|
|
if c != ' ' {
|
|
start = i
|
|
p.nextState(stateProto)
|
|
}
|
|
case stateProto:
|
|
switch c {
|
|
case ' ':
|
|
if p.proto == "" {
|
|
p.proto = string(data[start:i])
|
|
}
|
|
case '\r':
|
|
if p.proto == "" {
|
|
p.proto = string(data[start:i])
|
|
}
|
|
if err := p.Processor.OnProto(p.proto); err != nil {
|
|
p.proto = ""
|
|
return err
|
|
}
|
|
p.proto = ""
|
|
p.nextState(stateProtoLF)
|
|
}
|
|
case stateClientProtoBefore:
|
|
if c == 'H' {
|
|
start = i
|
|
p.nextState(stateClientProto)
|
|
continue
|
|
}
|
|
return ErrInvalidMethod
|
|
case stateClientProto:
|
|
switch c {
|
|
case ' ':
|
|
if p.proto == "" {
|
|
p.proto = string(data[start:i])
|
|
}
|
|
if err := p.Processor.OnProto(p.proto); err != nil {
|
|
p.proto = ""
|
|
return err
|
|
}
|
|
p.proto = ""
|
|
p.nextState(stateStatusCodeBefore)
|
|
}
|
|
case stateStatusCodeBefore:
|
|
switch c {
|
|
case ' ':
|
|
default:
|
|
if isNum(c) {
|
|
start = i
|
|
p.nextState(stateStatusCode)
|
|
}
|
|
continue
|
|
}
|
|
return ErrInvalidHTTPStatusCode
|
|
case stateStatusCode:
|
|
if c == ' ' {
|
|
cs := string(data[start:i])
|
|
code, err := strconv.Atoi(cs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
p.statusCode = code
|
|
p.nextState(stateStatusBefore)
|
|
continue
|
|
}
|
|
if !isNum(c) {
|
|
return ErrInvalidHTTPStatusCode
|
|
}
|
|
case stateStatusBefore:
|
|
switch c {
|
|
case ' ':
|
|
default:
|
|
if isAlpha(c) {
|
|
start = i
|
|
p.nextState(stateStatus)
|
|
}
|
|
continue
|
|
}
|
|
return ErrInvalidHTTPStatus
|
|
case stateStatus:
|
|
switch c {
|
|
case ' ':
|
|
if p.status == "" {
|
|
p.status = string(data[start:i])
|
|
}
|
|
case '\r':
|
|
if p.status == "" {
|
|
p.status = string(data[start:i])
|
|
}
|
|
p.Processor.OnStatus(p.statusCode, p.status)
|
|
p.statusCode = 0
|
|
p.status = ""
|
|
p.nextState(stateStatusLF)
|
|
}
|
|
case stateStatusLF:
|
|
if c == '\n' {
|
|
p.nextState(stateHeaderKeyBefore)
|
|
continue
|
|
}
|
|
return ErrLFExpected
|
|
case stateProtoLF:
|
|
if c == '\n' {
|
|
start = i + 1
|
|
p.nextState(stateHeaderKeyBefore)
|
|
continue
|
|
}
|
|
return ErrLFExpected
|
|
case stateHeaderValueLF:
|
|
if c == '\n' {
|
|
start = i + 1
|
|
p.nextState(stateHeaderKeyBefore)
|
|
continue
|
|
}
|
|
return ErrLFExpected
|
|
case stateHeaderKeyBefore:
|
|
switch c {
|
|
case ' ':
|
|
if !p.headerExists {
|
|
return ErrInvalidCharInHeader
|
|
}
|
|
case '\r':
|
|
err := p.parseTransferEncoding()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = p.parseContentLength()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
p.Processor.OnContentLength(p.contentLength)
|
|
err = p.parseTrailer()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
start = i + 1
|
|
p.nextState(stateHeaderOverLF)
|
|
case '\n':
|
|
return ErrInvalidCharInHeader
|
|
default:
|
|
if isAlpha(c) {
|
|
start = i
|
|
p.nextState(stateHeaderKey)
|
|
p.headerExists = true
|
|
continue
|
|
}
|
|
return ErrInvalidCharInHeader
|
|
}
|
|
case stateHeaderKey:
|
|
switch c {
|
|
case ' ':
|
|
if p.headerKey == "" {
|
|
p.headerKey = http.CanonicalHeaderKey(string(data[start:i]))
|
|
}
|
|
case ':':
|
|
if p.headerKey == "" {
|
|
p.headerKey = http.CanonicalHeaderKey(string(data[start:i]))
|
|
}
|
|
start = i + 1
|
|
p.nextState(stateHeaderValueBefore)
|
|
case '\r', '\n':
|
|
return ErrInvalidCharInHeader
|
|
default:
|
|
if !isToken(c) {
|
|
return ErrInvalidCharInHeader
|
|
}
|
|
}
|
|
case stateHeaderValueBefore:
|
|
switch c {
|
|
case ' ':
|
|
case '\r':
|
|
if p.headerValue == "" {
|
|
p.headerValue = string(data[start:i])
|
|
}
|
|
switch p.headerKey {
|
|
case transferEncodingHeader, trailerHeader, contentLengthHeader:
|
|
if p.header == nil {
|
|
p.header = http.Header{}
|
|
}
|
|
p.header.Add(p.headerKey, p.headerValue)
|
|
default:
|
|
}
|
|
|
|
p.Processor.OnHeader(p.headerKey, p.headerValue)
|
|
p.headerKey = ""
|
|
p.headerValue = ""
|
|
|
|
start = i + 1
|
|
p.nextState(stateHeaderValueLF)
|
|
case '\n':
|
|
return ErrInvalidCharInHeader
|
|
default:
|
|
// if !isToken(c) {
|
|
// return ErrInvalidCharInHeader
|
|
// }
|
|
start = i
|
|
p.nextState(stateHeaderValue)
|
|
}
|
|
case stateHeaderValue:
|
|
switch c {
|
|
case '\r':
|
|
if p.headerValue == "" {
|
|
p.headerValue = string(data[start:i])
|
|
}
|
|
switch p.headerKey {
|
|
case transferEncodingHeader, trailerHeader, contentLengthHeader:
|
|
if p.header == nil {
|
|
p.header = http.Header{}
|
|
}
|
|
p.header.Add(p.headerKey, p.headerValue)
|
|
default:
|
|
}
|
|
|
|
p.Processor.OnHeader(p.headerKey, p.headerValue)
|
|
p.headerKey = ""
|
|
p.headerValue = ""
|
|
|
|
start = i + 1
|
|
p.nextState(stateHeaderValueLF)
|
|
case '\n':
|
|
return ErrInvalidCharInHeader
|
|
default:
|
|
}
|
|
case stateHeaderOverLF:
|
|
if c == '\n' {
|
|
p.headerExists = false
|
|
if p.chunked {
|
|
start = i + 1
|
|
p.nextState(stateBodyChunkSizeBefore)
|
|
} else {
|
|
start = i + 1
|
|
if p.contentLength > 0 {
|
|
p.nextState(stateBodyContentLength)
|
|
} else {
|
|
p.handleMessage()
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
return ErrLFExpected
|
|
case stateBodyContentLength:
|
|
cl := p.contentLength
|
|
left := len(data) - start
|
|
if left >= cl {
|
|
p.Processor.OnBody(data[start : start+cl])
|
|
p.handleMessage()
|
|
start += cl
|
|
i = start - 1
|
|
} else {
|
|
goto Exit
|
|
}
|
|
case stateBodyChunkSizeBefore:
|
|
if isHex(c) {
|
|
p.chunkSize = -1
|
|
start = i
|
|
p.nextState(stateBodyChunkSize)
|
|
continue
|
|
}
|
|
return ErrInvalidChunkSize
|
|
case stateBodyChunkSize:
|
|
switch c {
|
|
case ' ':
|
|
if p.chunkSize < 0 {
|
|
chunkSize, err := parseAndValidateChunkSize(string(data[start:i]))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
p.chunkSize = chunkSize
|
|
}
|
|
case '\r':
|
|
if p.chunkSize < 0 {
|
|
chunkSize, err := parseAndValidateChunkSize(string(data[start:i]))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
p.chunkSize = chunkSize
|
|
}
|
|
start = i + 1
|
|
p.nextState(stateBodyChunkSizeLF)
|
|
default:
|
|
if !isHex(c) && p.chunkSize < 0 {
|
|
chunkSize, err := parseAndValidateChunkSize(string(data[start:i]))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
p.chunkSize = chunkSize
|
|
}
|
|
}
|
|
case stateBodyChunkSizeLF:
|
|
if c == '\n' {
|
|
start = i + 1
|
|
if p.chunkSize > 0 {
|
|
p.nextState(stateBodyChunkData)
|
|
} else {
|
|
// chunk size is 0
|
|
|
|
if len(p.trailer) > 0 {
|
|
// read trailer headers
|
|
p.nextState(stateBodyTrailerHeaderKeyBefore)
|
|
} else {
|
|
// read tail cr lf
|
|
p.nextState(stateTailCR)
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
return ErrLFExpected
|
|
case stateBodyChunkData:
|
|
cl := p.chunkSize
|
|
left := len(data) - start
|
|
if left >= cl {
|
|
p.Processor.OnBody(data[start : start+cl])
|
|
start += cl
|
|
i = start - 1
|
|
p.nextState(stateBodyChunkDataCR)
|
|
} else {
|
|
goto Exit
|
|
}
|
|
case stateBodyChunkDataCR:
|
|
if c == '\r' {
|
|
p.nextState(stateBodyChunkDataLF)
|
|
continue
|
|
}
|
|
return ErrCRExpected
|
|
case stateBodyChunkDataLF:
|
|
if c == '\n' {
|
|
p.nextState(stateBodyChunkSizeBefore)
|
|
continue
|
|
}
|
|
return ErrLFExpected
|
|
case stateBodyTrailerHeaderValueLF:
|
|
if c == '\n' {
|
|
start = i
|
|
p.nextState(stateBodyTrailerHeaderKeyBefore)
|
|
continue
|
|
}
|
|
return ErrLFExpected
|
|
case stateBodyTrailerHeaderKeyBefore:
|
|
if isAlpha(c) {
|
|
start = i
|
|
p.nextState(stateBodyTrailerHeaderKey)
|
|
continue
|
|
}
|
|
|
|
// all trailer header readed
|
|
if c == '\r' {
|
|
if len(p.trailer) > 0 {
|
|
return ErrTrailerExpected
|
|
}
|
|
start = i + 1
|
|
p.nextState(stateTailLF)
|
|
continue
|
|
}
|
|
case stateBodyTrailerHeaderKey:
|
|
switch c {
|
|
case ' ':
|
|
if p.headerKey == "" {
|
|
p.headerKey = http.CanonicalHeaderKey(string(data[start:i]))
|
|
}
|
|
continue
|
|
case ':':
|
|
if p.headerKey == "" {
|
|
p.headerKey = http.CanonicalHeaderKey(string(data[start:i]))
|
|
}
|
|
start = i + 1
|
|
p.nextState(stateBodyTrailerHeaderValueBefore)
|
|
continue
|
|
}
|
|
if !isToken(c) {
|
|
return ErrInvalidCharInHeader
|
|
}
|
|
case stateBodyTrailerHeaderValueBefore:
|
|
switch c {
|
|
case ' ':
|
|
case '\r':
|
|
if p.headerValue == "" {
|
|
p.headerValue = string(data[start:i])
|
|
}
|
|
p.Processor.OnTrailerHeader(p.headerKey, p.headerValue)
|
|
p.headerKey = ""
|
|
p.headerValue = ""
|
|
|
|
start = i + 1
|
|
p.nextState(stateBodyTrailerHeaderValueLF)
|
|
default:
|
|
// if !isToken(c) {
|
|
// return ErrInvalidCharInHeader
|
|
// }
|
|
start = i
|
|
p.nextState(stateBodyTrailerHeaderValue)
|
|
}
|
|
case stateBodyTrailerHeaderValue:
|
|
switch c {
|
|
case ' ':
|
|
if p.headerValue == "" {
|
|
p.headerValue = string(data[start:i])
|
|
}
|
|
case '\r':
|
|
if p.headerValue == "" {
|
|
p.headerValue = string(data[start:i])
|
|
}
|
|
if len(p.trailer) == 0 {
|
|
return fmt.Errorf("invalid trailer '%v'", p.headerKey)
|
|
}
|
|
delete(p.trailer, p.headerKey)
|
|
|
|
p.Processor.OnTrailerHeader(p.headerKey, p.headerValue)
|
|
start = i + 1
|
|
p.headerKey = ""
|
|
p.headerValue = ""
|
|
p.nextState(stateBodyTrailerHeaderValueLF)
|
|
default:
|
|
// if !isToken(c) {
|
|
// return ErrInvalidCharInHeader
|
|
// }
|
|
}
|
|
case stateTailCR:
|
|
if c == '\r' {
|
|
p.nextState(stateTailLF)
|
|
continue
|
|
}
|
|
return ErrCRExpected
|
|
case stateTailLF:
|
|
if c == '\n' {
|
|
start = i + 1
|
|
p.handleMessage()
|
|
continue
|
|
}
|
|
return ErrLFExpected
|
|
default:
|
|
}
|
|
}
|
|
|
|
Exit:
|
|
left := len(data) - start
|
|
if left > 0 {
|
|
if p.cache == nil {
|
|
p.cache = mempool.Malloc(left)
|
|
copy(p.cache, data[start:])
|
|
} else if start > 0 {
|
|
oldCache := p.cache
|
|
p.cache = mempool.Malloc(left)
|
|
copy(p.cache, data[start:])
|
|
mempool.Free(oldCache)
|
|
}
|
|
} else if len(p.cache) > 0 {
|
|
mempool.Free(p.cache)
|
|
p.cache = nil
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *Parser) parseTransferEncoding() error {
|
|
raw, present := p.header[transferEncodingHeader]
|
|
if !present {
|
|
return nil
|
|
}
|
|
delete(p.header, transferEncodingHeader)
|
|
|
|
if len(raw) != 1 {
|
|
return fmt.Errorf("too many transfer encodings: %q", raw)
|
|
}
|
|
if strings.ToLower(textproto.TrimString(raw[0])) != "chunked" {
|
|
return fmt.Errorf("unsupported transfer encoding: %q", raw[0])
|
|
}
|
|
delete(p.header, contentLengthHeader)
|
|
p.chunked = true
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *Parser) parseContentLength() (err error) {
|
|
if cl := p.header.Get(contentLengthHeader); cl != "" {
|
|
if p.chunked {
|
|
return ErrUnexpectedContentLength
|
|
}
|
|
end := len(cl) - 1
|
|
for i := end; i >= 0; i-- {
|
|
if cl[i] != ' ' {
|
|
if i != end {
|
|
cl = cl[:i+1]
|
|
}
|
|
break
|
|
}
|
|
}
|
|
l, err := strconv.ParseInt(cl, 10, 63)
|
|
if err != nil {
|
|
return fmt.Errorf("%s %q", "bad Content-Length", cl)
|
|
}
|
|
if l < 0 {
|
|
return fmt.Errorf("length less than zero (%d): %w", l, ErrInvalidContentLength)
|
|
}
|
|
if l > MaxInt {
|
|
return fmt.Errorf("length greater than maxint (%d): %w", l, ErrInvalidContentLength)
|
|
}
|
|
p.contentLength = int(l)
|
|
} else {
|
|
p.contentLength = -1
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *Parser) parseTrailer() error {
|
|
if !p.chunked {
|
|
return nil
|
|
}
|
|
header := p.header
|
|
|
|
trailers, ok := header[trailerHeader]
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
header.Del(trailerHeader)
|
|
|
|
trailer := http.Header{}
|
|
for _, key := range trailers {
|
|
key = textproto.TrimString(key)
|
|
if key == "" {
|
|
continue
|
|
}
|
|
if !strings.Contains(key, ",") {
|
|
key = http.CanonicalHeaderKey(key)
|
|
switch key {
|
|
case transferEncodingHeader, trailerHeader, contentLengthHeader:
|
|
return fmt.Errorf("%s %q", "bad trailer key", key)
|
|
default:
|
|
trailer[key] = nil
|
|
}
|
|
continue
|
|
}
|
|
for _, k := range strings.Split(key, ",") {
|
|
if k = textproto.TrimString(k); k != "" {
|
|
k = http.CanonicalHeaderKey(k)
|
|
switch k {
|
|
case transferEncodingHeader, trailerHeader, contentLengthHeader:
|
|
return fmt.Errorf("%s %q", "bad trailer key", k)
|
|
default:
|
|
trailer[k] = nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if len(trailer) > 0 {
|
|
p.trailer = trailer
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *Parser) handleMessage() {
|
|
p.Processor.OnComplete(p)
|
|
p.header = nil
|
|
|
|
if !p.isClient {
|
|
p.nextState(stateMethodBefore)
|
|
} else {
|
|
p.nextState(stateClientProtoBefore)
|
|
}
|
|
}
|
|
|
|
// NewParser .
|
|
func NewParser(processor Processor, isClient bool, readLimit int, executor func(f func())) *Parser {
|
|
if processor == nil {
|
|
processor = NewEmptyProcessor()
|
|
}
|
|
state := stateMethodBefore
|
|
if isClient {
|
|
state = stateClientProtoBefore
|
|
}
|
|
if readLimit <= 0 {
|
|
readLimit = DefaultHTTPReadLimit
|
|
}
|
|
if executor == nil {
|
|
executor = func(f func()) {
|
|
f()
|
|
}
|
|
}
|
|
p := &Parser{
|
|
state: state,
|
|
readLimit: readLimit,
|
|
isClient: isClient,
|
|
Execute: executor,
|
|
Processor: processor,
|
|
}
|
|
return p
|
|
}
|