2021-12-01 15:43:13 +00:00

799 lines
16 KiB
Go

// Copyright 2020 lesismal. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package nbhttp
import (
"fmt"
"net"
"net/http"
"net/textproto"
"strconv"
"strings"
"sync"
"github.com/lesismal/nbio/mempool"
)
const (
	// Names of the framing-related headers that the parser copies into its
	// own header map so the body framing (chunked encoding, trailers, body
	// length) can be validated once the header section ends.
	transferEncodingHeader = "Transfer-Encoding"
	trailerHeader          = "Trailer"
	contentLengthHeader    = "Content-Length"

	// MaxUint is the largest value a uint can hold on this platform.
	MaxUint = ^uint(0)
	// MaxInt is the largest value an int can hold on this platform, typed as
	// int64 so it can be compared directly with strconv.ParseInt results.
	MaxInt = int64(MaxUint >> 1)
)
// Parser is an incremental HTTP message parser. Raw bytes are pushed in
// through Read and recognized elements (method, URL, status, headers, body,
// trailers) are reported to Processor. It parses requests by default and
// responses when isClient is set.
type Parser struct {
	// mux serializes Read and Close.
	mux sync.Mutex

	// cache holds bytes of an incomplete element carried over between Read
	// calls; it is allocated from and returned to mempool.
	cache []byte

	// Scratch values for the request/status line currently being parsed.
	proto      string
	statusCode int
	status     string

	// Key and value of the header (or trailer) line currently being parsed.
	headerKey   string
	headerValue string

	// header collects only Transfer-Encoding, Trailer and Content-Length,
	// which are needed to decide the body framing.
	header http.Header
	// trailer holds the declared trailer names not yet received.
	trailer http.Header

	// contentLength is the parsed Content-Length, or -1 when absent.
	contentLength int
	// chunkSize is the size of the current chunk, or -1 while it is being parsed.
	chunkSize int
	// chunked is set when Transfer-Encoding: chunked was accepted.
	chunked bool
	// headerExists records whether a header line has been seen in the
	// current message.
	headerExists bool

	// state is the current state of the parsing state machine.
	state int8
	// isClient selects response parsing instead of request parsing.
	isClient bool
	// readLimit bounds the size of a partially received message kept in cache.
	readLimit int

	// errClose is the error passed to Close.
	errClose error
	// onClose is invoked at the end of Close when non-nil.
	onClose func(p *Parser, err error)

	// Processor receives the parse events.
	Processor Processor
	// ConnState, when non-nil, receives all further bytes instead of the
	// HTTP state machine (presumably installed on protocol upgrade — confirm).
	ConnState ReadCloser
	// Engine and Conn are not referenced in this file; presumably the owning
	// engine and the underlying connection — confirm with callers.
	Engine *Engine
	Conn   net.Conn
	// Execute runs a function; NewParser defaults it to a synchronous call.
	Execute func(f func())
}
// nextState moves the state machine to state. A closed parser is sticky:
// once p.state is stateClose, transitions are ignored.
func (p *Parser) nextState(state int8) {
	if p.state == stateClose {
		return
	}
	p.state = state
}
// OnClose registers h to be invoked at the end of Close, after ConnState and
// Processor have been closed. Because Close only acts on its first call, h
// runs at most once per parser. Passing nil clears the callback.
func (p *Parser) OnClose(h func(p *Parser, err error)) {
	p.onClose = h
}
// Close marks the parser closed and releases its resources. Only the first
// call has any effect. err is recorded in p.errClose and passed, in order, to
// ConnState.Close, Processor.Close and the callback registered with OnClose.
// Bytes still buffered in p.cache are returned to mempool.
//
// All callbacks run while p.mux is held. sync.Mutex is not reentrant, so
// they must not call Read or Close on the same parser.
func (p *Parser) Close(err error) {
	p.mux.Lock()
	defer p.mux.Unlock()

	if p.state == stateClose {
		return
	}
	p.state = stateClose
	p.errClose = err

	if p.ConnState != nil {
		p.ConnState.Close(p, p.errClose)
	}
	if p.Processor != nil {
		p.Processor.Close(p, p.errClose)
	}
	if len(p.cache) > 0 {
		mempool.Free(p.cache)
		// Drop the reference: the pool may hand this buffer to someone else,
		// and the parser must never read or free it again.
		p.cache = nil
	}
	if p.onClose != nil {
		p.onClose(p, err)
	}
}
// parseAndValidateChunkSize parses the hexadecimal chunk-size field of a
// chunked body and checks that it fits in an int. On failure it returns -1
// and an error describing whether the field was malformed, negative or too
// large.
func parseAndValidateChunkSize(originalStr string) (int, error) {
	chunkSize, err := strconv.ParseInt(originalStr, 16, 63)
	if err != nil {
		return -1, fmt.Errorf("chunk size parse error %v: %w", originalStr, err)
	}
	if chunkSize < 0 {
		// Previously reported as "chunk size zero", which was misleading:
		// zero is a valid (terminating) chunk size.
		return -1, fmt.Errorf("chunk size less than zero (%d)", chunkSize)
	}
	// With bitSize 63 this can only trigger where int is narrower than 64 bits.
	if chunkSize > MaxInt {
		return -1, fmt.Errorf("chunk size greater than max int %d", chunkSize)
	}
	return int(chunkSize), nil
}
// Read feeds the next bytes received on the connection into the parser. It
// drives the state machine over as many messages as data contains, delivering
// events to p.Processor as each element is recognized.
//
// When an element is incomplete at the end of data, its bytes are kept in
// p.cache (allocated from mempool). On the next call the new data is appended
// and scanning resumes where it stopped. ErrTooLong is returned if the
// buffered bytes plus data would exceed p.readLimit.
//
// Once p.ConnState is non-nil (presumably after a protocol upgrade performed
// by the Processor — confirm), every remaining byte is handed to
// ConnState.Read instead of being parsed as HTTP.
//
// Read returns ErrClosed after Close, or the first protocol error found.
func (p *Parser) Read(data []byte) error {
	p.mux.Lock()
	defer p.mux.Unlock()

	if p.state == stateClose {
		return ErrClosed
	}

	if len(data) == 0 {
		return nil
	}

	var c byte
	var start = 0
	var offset = len(p.cache)
	if offset > 0 {
		if offset+len(data) > p.readLimit {
			return ErrTooLong
		}
		p.cache = append(p.cache, data...)
		data = p.cache
	}

UPGRADER:
	// After an upgrade the remaining bytes belong to ConnState, not to HTTP.
	if p.ConnState != nil {
		udata := data
		if start > 0 {
			udata = data[start:]
		}
		err := p.ConnState.Read(p, udata)
		if p.cache != nil {
			mempool.Free(p.cache)
			p.cache = nil
		}
		return err
	}

	// Scanning resumes right after the bytes examined by the previous call.
	// start indexes the first byte of the element currently being parsed.
	for i := offset; i < len(data); i++ {
		if p.ConnState != nil {
			goto UPGRADER
		}
		c = data[i]
		switch p.state {
		case stateClose:
			return ErrClosed
		case stateMethodBefore:
			if isValidMethodChar(c) {
				start = i
				p.nextState(stateMethod)
				continue
			}
			return ErrInvalidMethod
		case stateMethod:
			if c == ' ' {
				var method = strings.ToUpper(string(data[start:i]))
				if !isValidMethod(method) {
					return ErrInvalidMethod
				}
				p.Processor.OnMethod(method)
				start = i + 1
				p.nextState(statePathBefore)
				continue
			}
			if !isAlpha(c) {
				return ErrInvalidMethod
			}
		case statePathBefore:
			switch c {
			case '/', '*':
				start = i
				p.nextState(statePath)
				continue
			}
			switch c {
			case ' ':
			default:
				return ErrInvalidRequestURI
			}
		case statePath:
			if c == ' ' {
				var uri = string(data[start:i])
				if err := p.Processor.OnURL(uri); err != nil {
					return err
				}
				start = i + 1
				p.nextState(stateProtoBefore)
			}
		case stateProtoBefore:
			if c != ' ' {
				start = i
				p.nextState(stateProto)
			}
		case stateProto:
			switch c {
			case ' ':
				if p.proto == "" {
					p.proto = string(data[start:i])
				}
			case '\r':
				if p.proto == "" {
					p.proto = string(data[start:i])
				}
				if err := p.Processor.OnProto(p.proto); err != nil {
					p.proto = ""
					return err
				}
				p.proto = ""
				p.nextState(stateProtoLF)
			}
		case stateClientProtoBefore:
			if c == 'H' {
				start = i
				p.nextState(stateClientProto)
				continue
			}
			return ErrInvalidMethod
		case stateClientProto:
			switch c {
			case ' ':
				if p.proto == "" {
					p.proto = string(data[start:i])
				}
				if err := p.Processor.OnProto(p.proto); err != nil {
					p.proto = ""
					return err
				}
				p.proto = ""
				p.nextState(stateStatusCodeBefore)
			}
		case stateStatusCodeBefore:
			switch c {
			case ' ':
			default:
				if isNum(c) {
					start = i
					p.nextState(stateStatusCode)
				}
				continue
			}
			return ErrInvalidHTTPStatusCode
		case stateStatusCode:
			if c == ' ' {
				cs := string(data[start:i])
				code, err := strconv.Atoi(cs)
				if err != nil {
					return err
				}
				p.statusCode = code
				p.nextState(stateStatusBefore)
				continue
			}
			if !isNum(c) {
				return ErrInvalidHTTPStatusCode
			}
		case stateStatusBefore:
			switch c {
			case ' ':
			case '\r':
				// Empty reason-phrase ("HTTP/1.1 200 \r\n"). Without this case
				// CR and LF were silently skipped and the first header line
				// was taken as the reason-phrase.
				p.Processor.OnStatus(p.statusCode, "")
				p.statusCode = 0
				p.nextState(stateStatusLF)
				continue
			default:
				if isAlpha(c) {
					start = i
					p.nextState(stateStatus)
				}
				continue
			}
			return ErrInvalidHTTPStatus
		case stateStatus:
			// The reason-phrase may contain spaces ("Not Found"), so it runs
			// up to CR; stopping at the first space truncated it.
			switch c {
			case '\r':
				if p.status == "" {
					p.status = string(data[start:i])
				}
				p.Processor.OnStatus(p.statusCode, p.status)
				p.statusCode = 0
				p.status = ""
				p.nextState(stateStatusLF)
			}
		case stateStatusLF:
			if c == '\n' {
				p.nextState(stateHeaderKeyBefore)
				continue
			}
			return ErrLFExpected
		case stateProtoLF:
			if c == '\n' {
				start = i + 1
				p.nextState(stateHeaderKeyBefore)
				continue
			}
			return ErrLFExpected
		case stateHeaderValueLF:
			if c == '\n' {
				start = i + 1
				p.nextState(stateHeaderKeyBefore)
				continue
			}
			return ErrLFExpected
		case stateHeaderKeyBefore:
			switch c {
			case ' ':
				if !p.headerExists {
					return ErrInvalidCharInHeader
				}
			case '\r':
				// Empty line: the header section is complete, so the body
				// framing can be derived from the collected headers.
				err := p.parseTransferEncoding()
				if err != nil {
					return err
				}
				err = p.parseContentLength()
				if err != nil {
					return err
				}
				p.Processor.OnContentLength(p.contentLength)
				err = p.parseTrailer()
				if err != nil {
					return err
				}
				start = i + 1
				p.nextState(stateHeaderOverLF)
			case '\n':
				return ErrInvalidCharInHeader
			default:
				if isAlpha(c) {
					start = i
					p.nextState(stateHeaderKey)
					p.headerExists = true
					continue
				}
				return ErrInvalidCharInHeader
			}
		case stateHeaderKey:
			switch c {
			case ' ':
				if p.headerKey == "" {
					p.headerKey = http.CanonicalHeaderKey(string(data[start:i]))
				}
			case ':':
				if p.headerKey == "" {
					p.headerKey = http.CanonicalHeaderKey(string(data[start:i]))
				}
				start = i + 1
				p.nextState(stateHeaderValueBefore)
			case '\r', '\n':
				return ErrInvalidCharInHeader
			default:
				if !isToken(c) {
					return ErrInvalidCharInHeader
				}
			}
		case stateHeaderValueBefore:
			switch c {
			case ' ':
			case '\r':
				if p.headerValue == "" {
					p.headerValue = string(data[start:i])
				}
				// Framing headers are also kept locally for parse* helpers.
				switch p.headerKey {
				case transferEncodingHeader, trailerHeader, contentLengthHeader:
					if p.header == nil {
						p.header = http.Header{}
					}
					p.header.Add(p.headerKey, p.headerValue)
				default:
				}
				p.Processor.OnHeader(p.headerKey, p.headerValue)
				p.headerKey = ""
				p.headerValue = ""
				start = i + 1
				p.nextState(stateHeaderValueLF)
			case '\n':
				return ErrInvalidCharInHeader
			default:
				// Header value bytes are not validated here.
				start = i
				p.nextState(stateHeaderValue)
			}
		case stateHeaderValue:
			switch c {
			case '\r':
				if p.headerValue == "" {
					p.headerValue = string(data[start:i])
				}
				// Framing headers are also kept locally for parse* helpers.
				switch p.headerKey {
				case transferEncodingHeader, trailerHeader, contentLengthHeader:
					if p.header == nil {
						p.header = http.Header{}
					}
					p.header.Add(p.headerKey, p.headerValue)
				default:
				}
				p.Processor.OnHeader(p.headerKey, p.headerValue)
				p.headerKey = ""
				p.headerValue = ""
				start = i + 1
				p.nextState(stateHeaderValueLF)
			case '\n':
				return ErrInvalidCharInHeader
			default:
			}
		case stateHeaderOverLF:
			if c == '\n' {
				p.headerExists = false
				if p.chunked {
					start = i + 1
					p.nextState(stateBodyChunkSizeBefore)
				} else {
					start = i + 1
					if p.contentLength > 0 {
						p.nextState(stateBodyContentLength)
					} else {
						p.handleMessage()
					}
				}
				continue
			}
			return ErrLFExpected
		case stateBodyContentLength:
			cl := p.contentLength
			left := len(data) - start
			if left >= cl {
				p.Processor.OnBody(data[start : start+cl])
				p.handleMessage()
				start += cl
				// Rewind so the loop's i++ lands on the first byte after the body.
				i = start - 1
			} else {
				goto Exit
			}
		case stateBodyChunkSizeBefore:
			if isHex(c) {
				p.chunkSize = -1
				start = i
				p.nextState(stateBodyChunkSize)
				continue
			}
			return ErrInvalidChunkSize
		case stateBodyChunkSize:
			switch c {
			case ' ':
				if p.chunkSize < 0 {
					chunkSize, err := parseAndValidateChunkSize(string(data[start:i]))
					if err != nil {
						return err
					}
					p.chunkSize = chunkSize
				}
			case '\r':
				if p.chunkSize < 0 {
					chunkSize, err := parseAndValidateChunkSize(string(data[start:i]))
					if err != nil {
						return err
					}
					p.chunkSize = chunkSize
				}
				start = i + 1
				p.nextState(stateBodyChunkSizeLF)
			default:
				// The first non-hex byte ends the size (e.g. a chunk extension).
				if !isHex(c) && p.chunkSize < 0 {
					chunkSize, err := parseAndValidateChunkSize(string(data[start:i]))
					if err != nil {
						return err
					}
					p.chunkSize = chunkSize
				}
			}
		case stateBodyChunkSizeLF:
			if c == '\n' {
				start = i + 1
				if p.chunkSize > 0 {
					p.nextState(stateBodyChunkData)
				} else {
					// chunk size is 0
					if len(p.trailer) > 0 {
						// read trailer headers
						p.nextState(stateBodyTrailerHeaderKeyBefore)
					} else {
						// read tail cr lf
						p.nextState(stateTailCR)
					}
				}
				continue
			}
			return ErrLFExpected
		case stateBodyChunkData:
			cl := p.chunkSize
			left := len(data) - start
			if left >= cl {
				p.Processor.OnBody(data[start : start+cl])
				start += cl
				// Rewind so the loop's i++ lands on the byte after the chunk.
				i = start - 1
				p.nextState(stateBodyChunkDataCR)
			} else {
				goto Exit
			}
		case stateBodyChunkDataCR:
			if c == '\r' {
				p.nextState(stateBodyChunkDataLF)
				continue
			}
			return ErrCRExpected
		case stateBodyChunkDataLF:
			if c == '\n' {
				p.nextState(stateBodyChunkSizeBefore)
				continue
			}
			return ErrLFExpected
		case stateBodyTrailerHeaderValueLF:
			if c == '\n' {
				start = i
				p.nextState(stateBodyTrailerHeaderKeyBefore)
				continue
			}
			return ErrLFExpected
		case stateBodyTrailerHeaderKeyBefore:
			if isAlpha(c) {
				start = i
				p.nextState(stateBodyTrailerHeaderKey)
				continue
			}
			// Empty line: all trailer fields have been read.
			if c == '\r' {
				if len(p.trailer) > 0 {
					return ErrTrailerExpected
				}
				start = i + 1
				p.nextState(stateTailLF)
				continue
			}
		case stateBodyTrailerHeaderKey:
			switch c {
			case ' ':
				if p.headerKey == "" {
					p.headerKey = http.CanonicalHeaderKey(string(data[start:i]))
				}
				continue
			case ':':
				if p.headerKey == "" {
					p.headerKey = http.CanonicalHeaderKey(string(data[start:i]))
				}
				start = i + 1
				p.nextState(stateBodyTrailerHeaderValueBefore)
				continue
			}
			if !isToken(c) {
				return ErrInvalidCharInHeader
			}
		case stateBodyTrailerHeaderValueBefore:
			switch c {
			case ' ':
			case '\r':
				if p.headerValue == "" {
					p.headerValue = string(data[start:i])
				}
				p.Processor.OnTrailerHeader(p.headerKey, p.headerValue)
				p.headerKey = ""
				p.headerValue = ""
				start = i + 1
				p.nextState(stateBodyTrailerHeaderValueLF)
			default:
				// Trailer value bytes are not validated here.
				start = i
				p.nextState(stateBodyTrailerHeaderValue)
			}
		case stateBodyTrailerHeaderValue:
			// Trailer values, like header values, may contain spaces and run
			// up to CR; stopping at the first space truncated them.
			switch c {
			case '\r':
				if p.headerValue == "" {
					p.headerValue = string(data[start:i])
				}
				if len(p.trailer) == 0 {
					return fmt.Errorf("invalid trailer '%v'", p.headerKey)
				}
				delete(p.trailer, p.headerKey)
				p.Processor.OnTrailerHeader(p.headerKey, p.headerValue)
				start = i + 1
				p.headerKey = ""
				p.headerValue = ""
				p.nextState(stateBodyTrailerHeaderValueLF)
			default:
			}
		case stateTailCR:
			if c == '\r' {
				p.nextState(stateTailLF)
				continue
			}
			return ErrCRExpected
		case stateTailLF:
			if c == '\n' {
				start = i + 1
				p.handleMessage()
				continue
			}
			return ErrLFExpected
		default:
		}
	}

Exit:
	// Keep the unconsumed tail (an incomplete element) for the next Read.
	left := len(data) - start
	if left > 0 {
		if p.cache == nil {
			p.cache = mempool.Malloc(left)
			copy(p.cache, data[start:])
		} else if start > 0 {
			oldCache := p.cache
			p.cache = mempool.Malloc(left)
			copy(p.cache, data[start:])
			mempool.Free(oldCache)
		}
	} else if len(p.cache) > 0 {
		mempool.Free(p.cache)
		p.cache = nil
	}
	return nil
}
// parseTransferEncoding consumes the Transfer-Encoding values collected in
// p.header. Exactly one coding, "chunked" (case-insensitive, surrounding
// whitespace ignored), is accepted. On success the parser switches to chunked
// body mode and any collected Content-Length is discarded. Absence of the
// header is not an error.
func (p *Parser) parseTransferEncoding() error {
	codings, present := p.header[transferEncodingHeader]
	if !present {
		return nil
	}
	delete(p.header, transferEncodingHeader)

	switch {
	case len(codings) != 1:
		return fmt.Errorf("too many transfer encodings: %q", codings)
	case strings.ToLower(textproto.TrimString(codings[0])) != "chunked":
		return fmt.Errorf("unsupported transfer encoding: %q", codings[0])
	}

	delete(p.header, contentLengthHeader)
	p.chunked = true
	return nil
}
// parseContentLength derives the body length from the Content-Length values
// collected in p.header and stores it in p.contentLength. The result is -1
// when the header is absent or its first value is empty.
//
// It fails when:
//   - Content-Length appears on a chunked message (ErrUnexpectedContentLength);
//   - several Content-Length fields disagree;
//   - the value is not a decimal integer;
//   - the value is negative or does not fit in an int (wraps ErrInvalidContentLength).
func (p *Parser) parseContentLength() error {
	values := p.header[contentLengthHeader]
	if len(values) == 0 || values[0] == "" {
		p.contentLength = -1
		return nil
	}
	if p.chunked {
		return ErrUnexpectedContentLength
	}

	// The header state machine skips leading spaces only. Trailing spaces,
	// and tabs on either side, are still attached, so trim all optional
	// whitespace rather than just trailing spaces.
	cl := textproto.TrimString(values[0])

	// Disagreeing duplicate fields make the body length ambiguous (a
	// request-smuggling vector), so reject them. Identical duplicates are
	// tolerated, as net/http does.
	for _, v := range values[1:] {
		if textproto.TrimString(v) != cl {
			return fmt.Errorf("%s %q", "bad Content-Length", values)
		}
	}

	l, err := strconv.ParseInt(cl, 10, 63)
	if err != nil {
		return fmt.Errorf("%s %q", "bad Content-Length", cl)
	}
	if l < 0 {
		return fmt.Errorf("length less than zero (%d): %w", l, ErrInvalidContentLength)
	}
	if l > MaxInt {
		return fmt.Errorf("length greater than maxint (%d): %w", l, ErrInvalidContentLength)
	}
	p.contentLength = int(l)
	return nil
}
// parseTrailer reads the Trailer header of a chunked message and records the
// declared field names, canonicalized, in p.trailer (left untouched when none
// are declared). Names are comma-separated and may span several Trailer lines.
// Blank entries are skipped. Transfer-Encoding, Trailer and Content-Length are
// rejected as trailer names. Non-chunked messages are ignored.
func (p *Parser) parseTrailer() error {
	if !p.chunked {
		return nil
	}
	declared, ok := p.header[trailerHeader]
	if !ok {
		return nil
	}
	p.header.Del(trailerHeader)

	names := http.Header{}
	for _, line := range declared {
		// A line without commas splits into itself, so one loop covers both
		// the single-name and the list forms.
		for _, name := range strings.Split(textproto.TrimString(line), ",") {
			name = textproto.TrimString(name)
			if name == "" {
				continue
			}
			name = http.CanonicalHeaderKey(name)
			switch name {
			case transferEncodingHeader, trailerHeader, contentLengthHeader:
				return fmt.Errorf("%s %q", "bad trailer key", name)
			}
			names[name] = nil
		}
	}

	if len(names) > 0 {
		p.trailer = names
	}
	return nil
}
// handleMessage reports a fully parsed message to the Processor and resets
// the per-message framing state. It then rewinds the state machine to the
// start of the next request (server side) or response (client side).
func (p *Parser) handleMessage() {
	p.Processor.OnComplete(p)

	// Framing state must not leak into the next message on the connection.
	// parseTransferEncoding only ever sets chunked to true. Without this
	// reset, one chunked message made every following message parse as
	// chunked, or fail with ErrUnexpectedContentLength.
	p.header = nil
	p.trailer = nil
	p.chunked = false

	if p.isClient {
		p.nextState(stateClientProtoBefore)
	} else {
		p.nextState(stateMethodBefore)
	}
}
// NewParser creates a Parser that reports parse events to processor.
//
// Defaults:
//   - processor: NewEmptyProcessor() when nil.
//   - isClient: selects response parsing instead of request parsing.
//   - readLimit: bounds how large a partially received message may grow
//     while buffered; DefaultHTTPReadLimit when not positive.
//   - executor: stored as Execute; a synchronous call when nil.
func NewParser(processor Processor, isClient bool, readLimit int, executor func(f func())) *Parser {
	if processor == nil {
		processor = NewEmptyProcessor()
	}
	if readLimit <= 0 {
		readLimit = DefaultHTTPReadLimit
	}
	if executor == nil {
		executor = func(f func()) { f() }
	}

	initial := stateMethodBefore
	if isClient {
		initial = stateClientProtoBefore
	}

	return &Parser{
		state:     initial,
		readLimit: readLimit,
		isClient:  isClient,
		Execute:   executor,
		Processor: processor,
	}
}