// Copyright 2020 lesismal. All rights reserved. // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. package nbhttp import ( "bufio" "errors" "io" "net" "net/http" "os" "strconv" "time" "unsafe" "github.com/lesismal/nbio/logging" "github.com/lesismal/nbio/mempool" ) // Response represents the server side of an HTTP response. type Response struct { parser *Parser request *http.Request // request for this response status string statusCode int // status code passed to WriteHeader header http.Header trailer map[string]string trailerSize int buffer []byte bodyBuffer []byte intFormatBuf [10]byte chunked bool chunkChecked bool headEncoded bool hasBody bool enableSendfile bool hijacked bool } // Hijack . func (res *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { if res.parser.Processor == nil { return nil, nil, errors.New("nil Proccessor") } res.hijacked = true return res.parser.Processor.Conn(), nil, nil } // Header . func (res *Response) Header() http.Header { return res.header } // WriteHeader . func (res *Response) WriteHeader(statusCode int) { if !res.hijacked && res.statusCode == 0 && res.statusCode != statusCode { status := http.StatusText(statusCode) if status != "" { res.status = status res.statusCode = statusCode } if cl := res.header.Get(contentLengthHeader); cl != "" { v, err := strconv.ParseInt(cl, 10, 64) if err == nil && v >= 0 { } else { logging.Error("http: invalid Content-Length of %q", cl) res.header.Del(contentLengthHeader) } } res.checkChunked() } } const maxPacketSize = 65536 // WriteString . func (res *Response) WriteString(s string) (int, error) { x := (*[2]uintptr)(unsafe.Pointer(&s)) h := [3]uintptr{x[0], x[1], x[1]} data := *(*[]byte)(unsafe.Pointer(&h)) return res.Write(data) } // Write . func (res *Response) Write(data []byte) (int, error) { l := len(data) conn := res.parser.Processor.Conn() if l == 0 || conn == nil { return 0, nil } res.WriteHeader(http.StatusOK) res.hasBody = true if res.chunked { res.eoncodeHead() buf := res.buffer hl := len(buf) res.buffer = nil lenStr := res.formatInt(l, 16) size := hl + len(lenStr) + l + 4 if size < maxPacketSize { buf = append(buf, lenStr...) buf = append(buf, "\r\n"...) buf = append(buf, data...) buf = append(buf, "\r\n"...) res.buffer = buf return l, nil } _, err := conn.Write(buf) if err != nil { return 0, err } buf = mempool.Malloc(0) buf = append(buf, lenStr...) buf = append(buf, "\r\n"...) buf = append(buf, data...) buf = append(buf, "\r\n"...) if len(buf) < maxPacketSize { res.buffer = buf return l, nil } return conn.Write(buf) } if len(res.header[contentLengthHeader]) > 0 { res.eoncodeHead() buf := res.buffer res.buffer = nil if buf == nil { buf = mempool.Malloc(l)[0:0] } buf = append(buf, data...) return conn.Write(buf) } if res.bodyBuffer == nil { res.bodyBuffer = mempool.Malloc(l)[0:0] } res.bodyBuffer = append(res.bodyBuffer, data...) // res.header[contentLengthHeader] = []string{res.formatInt(l, 10)} return l, nil } // ReadFrom . func (res *Response) ReadFrom(r io.Reader) (n int64, err error) { c := res.parser.Processor.Conn() if c == nil { return 0, nil } res.hasBody = true res.eoncodeHead() _, err = c.Write(res.buffer) res.buffer = nil if err != nil { return 0, err } if res.enableSendfile { lr, ok := r.(*io.LimitedReader) if ok { n, r = lr.N, lr.R if n <= 0 { return 0, nil } } f, ok := r.(*os.File) if ok { nc, ok := c.(interface { Sendfile(f *os.File, remain int64) (int64, error) }) if ok { ns, err := nc.Sendfile(f, lr.N) return ns, err } } } return io.Copy(c, r) } // checkChunked . func (res *Response) checkChunked() { if res.chunkChecked { return } res.chunkChecked = true // res.WriteHeader(http.StatusOK) if res.request.ProtoAtLeast(1, 1) { for _, v := range res.header[transferEncodingHeader] { if v == "chunked" { res.chunked = true } } if !res.chunked { if len(res.header[trailerHeader]) > 0 { res.chunked = true hs := res.header[transferEncodingHeader] res.header[transferEncodingHeader] = append(hs, "chunked") } } } if res.chunked { delete(res.header, contentLengthHeader) } } // flush . func (res *Response) eoncodeHead() { if res.headEncoded { return } res.WriteHeader(http.StatusOK) res.headEncoded = true status := res.status statusCode := res.statusCode data := mempool.Malloc(1024)[0:0] data = append(data, res.request.Proto...) data = append(data, ' ', '0'+byte(statusCode/100), '0'+byte(statusCode%100)/10, '0'+byte(statusCode%10), ' ') data = append(data, status...) data = append(data, '\r', '\n') if res.hasBody && len(res.header["Content-Type"]) == 0 { const contentType = "Content-Type: text/plain; charset=utf-8\r\n" data = append(data, contentType...) } if !res.chunked { const contentLenthKey = "Content-Length: " if !res.hasBody { data = append(data, contentLenthKey...) data = append(data, '0', '\r', '\n') } else { data = append(data, contentLenthKey...) l := len(res.bodyBuffer) if l > 0 { s := strconv.FormatInt(int64(l), 10) data = append(data, s...) data = append(data, '\r', '\n') } else { data = append(data, '0', '\r', '\n') } } } if res.request.Close && len(res.header["Connection"]) == 0 { const connection = "Connection: close\r\n" data = append(data, connection...) } if len(res.header["Date"]) == 0 { const days = "SunMonTueWedThuFriSat" const months = "JanFebMarAprMayJunJulAugSepOctNovDec" t := time.Now().UTC() yy, mm, dd := t.Date() hh, mn, ss := t.Clock() day := days[3*t.Weekday():] mon := months[3*(mm-1):] data = append(data, 'D', 'a', 't', 'e', ':', ' ', day[0], day[1], day[2], ',', ' ', byte('0'+dd/10), byte('0'+dd%10), ' ', mon[0], mon[1], mon[2], ' ', byte('0'+yy/1000), byte('0'+(yy/100)%10), byte('0'+(yy/10)%10), byte('0'+yy%10), ' ', byte('0'+hh/10), byte('0'+hh%10), ':', byte('0'+mn/10), byte('0'+mn%10), ':', byte('0'+ss/10), byte('0'+ss%10), ' ', 'G', 'M', 'T', '\r', '\n') } res.trailer = map[string]string{} trailers := res.header[trailerHeader] for _, k := range trailers { res.trailer[k] = "" } for k, vv := range res.header { if _, ok := res.trailer[k]; !ok { for _, v := range vv { data = append(data, k...) data = append(data, ':', ' ') data = append(data, v...) data = append(data, '\r', '\n') } } else if len(vv) > 0 { v := res.header.Get(k) res.trailer[k] = v res.trailerSize += (len(k) + len(v) + 4) } } data = append(data, '\r', '\n') res.buffer = data } func (res *Response) flushTrailer(conn io.Writer) error { var err error if !res.chunked { if res.buffer != nil { if res.bodyBuffer != nil { res.buffer = append(res.buffer, res.bodyBuffer...) mempool.Free(res.bodyBuffer) res.bodyBuffer = nil } _, err = conn.Write(res.buffer) res.buffer = nil if err != nil { return err } } if res.bodyBuffer != nil { _, err = conn.Write(res.bodyBuffer) res.bodyBuffer = nil } return err } data := res.buffer res.buffer = nil if data == nil { data = mempool.Malloc(0) } if len(res.trailer) == 0 { data = append(data, "0\r\n\r\n"...) } else { data = append(data, "0\r\n"...) for k, v := range res.trailer { data = append(data, k...) data = append(data, ": "...) data = append(data, v...) data = append(data, "\r\n"...) } data = append(data, "\r\n"...) } _, err = conn.Write(data) return err } var numMap = []byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'} func (res *Response) formatInt(n int, base int) string { if n < 0 || n > 0x7FFFFFFF { return "" } buf := res.intFormatBuf[:] i := len(buf) for { i-- buf[i] = numMap[n%base] n /= base if n <= 0 { break } } return string(buf[i:]) } // NewResponse . func NewResponse(parser *Parser, request *http.Request, enableSendfile bool) *Response { res := responsePool.Get().(*Response) res.parser = parser res.request = request res.header = http.Header{ /*"Server": []string{"nbio"}*/ } res.enableSendfile = enableSendfile return res }