2021-12-01 15:43:13 +00:00

306 lines
5.9 KiB
Go

package nbhttp
import (
"io"
"net"
"net/http"
"net/url"
"runtime"
"strings"
"sync"
"time"
"unsafe"
"github.com/lesismal/llib/std/crypto/tls"
"github.com/lesismal/nbio"
"github.com/lesismal/nbio/logging"
"github.com/lesismal/nbio/mempool"
)
// resHandler is one request queued by ClientConn.Do that is still waiting
// for its response.
type resHandler struct {
	// c is the value of ClientConn.conn when Do queued the request; it is nil
	// if the connection had not been dialed yet. Not read in this file.
	c net.Conn
	// t is when Do queued the request; onResponse measures the per-request
	// Timeout of a pipelined request from this instant.
	t time.Time
	// h receives the response, or a nil response together with an error when
	// the connection is closed before the response arrives.
	h func(res *http.Response, conn net.Conn, err error)
}
// ClientConn sends http.Requests over a single connection, dialed lazily by
// the first Do, and delivers the responses to the queued handlers in request
// order.
type ClientConn struct {
	// mux guards conn, handlers and closed.
	mux sync.Mutex
	// conn is the connection requests are written to; nil until dialed.
	conn net.Conn
	// handlers are the callbacks of requests awaiting a response, oldest
	// first: Do appends, onResponse pops from the front.
	handlers []resHandler
	// closed is set when the connection is shut down; Do then rejects new
	// requests with ErrClientClosed.
	closed bool
	// onClose chains the callbacks registered with OnClose.
	onClose func()
	// Engine parses responses and polls the connection.
	Engine *Engine
	// Jar is not read by ClientConn in this file.
	// NOTE(review): presumably consumed by a higher-level client; confirm.
	Jar http.CookieJar
	// Timeout bounds each request, measured from when Do queued it; values
	// <= 0 disable it. It is also used as the dial timeout.
	Timeout time.Duration
	// IdleConnTimeout is the read deadline armed once no responses are
	// pending; values <= 0 mean no idle deadline.
	IdleConnTimeout time.Duration
	// TLSClientConfig is cloned for every https dial; nil means a zero
	// tls.Config.
	TLSClientConfig *tls.Config
	// Proxy, when non-nil, returns the proxy to dial through for a request;
	// a nil URL means dial directly.
	Proxy func(*http.Request) (*url.URL, error)
	// CheckRedirect is not read by ClientConn in this file.
	// NOTE(review): presumably consumed by a higher-level client; confirm.
	CheckRedirect func(req *http.Request, via []*http.Request) error
}
// Reset returns the connection state (conn, pending handlers and the closed
// flag) to its zero values. OnClose callbacks and the exported configuration
// fields are left untouched, and c.mux is not taken.
func (c *ClientConn) Reset() {
	c.conn, c.handlers, c.closed = nil, nil, false
}
// OnClose registers h to run when the connection is closed. Callbacks
// accumulate and run in registration order; a nil h is ignored.
func (c *ClientConn) OnClose(h func()) {
	if h == nil {
		return
	}
	prev := c.onClose
	if prev == nil {
		// Nothing registered yet: h is the whole chain.
		c.onClose = h
		return
	}
	c.onClose = func() {
		prev()
		h()
	}
}
// Close is shorthand for CloseWithError(io.EOF).
func (c *ClientConn) Close() {
	closeErr := io.EOF
	c.CloseWithError(closeErr)
}
// CloseWithError shuts the connection down once: the first call marks c
// closed under c.mux and then, outside the lock, fails every pending handler
// with err, closes the connection and runs the OnClose callbacks. Later calls
// are no-ops.
func (c *ClientConn) CloseWithError(err error) {
	first := func() bool {
		c.mux.Lock()
		defer c.mux.Unlock()
		if c.closed {
			return false
		}
		c.closed = true
		return true
	}()
	if first {
		c.closeWithErrorWithoutLock(err)
	}
}
// closeWithErrorWithoutLock tears the connection down. It marks c closed,
// calls every pending handler with a nil response and err, drops the queue,
// closes the underlying connection (if any) and runs the OnClose callbacks.
// The caller must hold c.mux, or otherwise have exclusive access, as
// CloseWithError does by setting c.closed under the lock first.
func (c *ClientConn) closeWithErrorWithoutLock(err error) {
	// Do and onResponse call this directly on failure. Without marking c
	// closed here, a later CloseWithError (e.g. from the parser's OnClose)
	// would close the connection and run onClose a second time, and later
	// Do calls would keep using the dead connection.
	c.closed = true
	for _, h := range c.handlers {
		h.h(nil, c.conn, err)
	}
	c.handlers = nil
	if c.conn != nil {
		c.conn.Close()
	}
	if c.onClose != nil {
		c.onClose()
	}
}
// onResponse delivers one parsed response (or parse error) to the oldest
// pending handler, then arms the read deadline for what comes next:
//   - another pipelined request: its Timeout, measured from when Do queued it
//     (the connection is closed with ErrClientTimeout if that has already
//     passed), or no deadline when Timeout <= 0;
//   - nothing pending: IdleConnTimeout from now, or no deadline when <= 0.
//
// It does nothing if c is closed or no request is pending. The handler runs
// with c.mux held, so it must not call back into c.
func (c *ClientConn) onResponse(res *http.Response, err error) {
	c.mux.Lock()
	defer c.mux.Unlock()

	if c.closed || len(c.handlers) == 0 {
		return
	}

	// Responses arrive in request order, so the head handler owns this one.
	head := c.handlers[0]
	head.h(res, c.conn, err)
	c.handlers[0] = resHandler{} // release the closure held by the backing array
	c.handlers = c.handlers[1:]

	if len(c.handlers) == 0 {
		c.handlers = nil
		if c.IdleConnTimeout > 0 {
			c.conn.SetReadDeadline(time.Now().Add(c.IdleConnTimeout))
		} else {
			c.conn.SetReadDeadline(time.Time{})
		}
		return
	}

	if c.Timeout <= 0 {
		// No per-request timeout: clear any deadline. Arming one at next.t
		// (already in the past) would kill the connection immediately.
		c.conn.SetReadDeadline(time.Time{})
		return
	}

	deadline := c.handlers[0].t.Add(c.Timeout)
	if time.Now().After(deadline) {
		c.closeWithErrorWithoutLock(ErrClientTimeout)
		return
	}
	c.conn.SetReadDeadline(deadline)
}
// Do writes req on this connection and queues handler to receive its
// response. The first call dials the connection, honouring Proxy, Timeout
// and, for https, TLSClientConfig. Later calls reuse it, pipelining requests
// whose responses onResponse hands to the handlers in request order.
//
// If c is already closed, handler is called with ErrClientClosed. Any failure
// while dialing or writing closes c, and every pending handler (this one
// included) is called with the error. Panics are recovered and logged.
//
// Do holds c.mux for its whole duration, including dialing and the TLS
// handshake.
func (c *ClientConn) Do(req *http.Request, handler func(res *http.Response, conn net.Conn, err error)) {
	c.mux.Lock()
	defer func() {
		c.mux.Unlock()
		if err := recover(); err != nil {
			const size = 64 << 10
			buf := make([]byte, size)
			buf = buf[:runtime.Stack(buf, false)]
			logging.Error("ClientConn Do failed: %v\n%v\n", err, *(*string)(unsafe.Pointer(&buf)))
		}
	}()

	if c.closed {
		handler(nil, nil, ErrClientClosed)
		return
	}

	var engine = c.Engine
	var confTimeout = c.Timeout

	c.handlers = append(c.handlers, resHandler{c: c.conn, t: time.Now(), h: handler})

	// The zero deadline means "no deadline" to SetReadDeadline.
	var deadline time.Time
	if confTimeout > 0 {
		deadline = time.Now().Add(confTimeout)
	}

	sendRequest := func() {
		if err := req.Write(c.conn); err != nil {
			c.closeWithErrorWithoutLock(err)
		}
	}

	if c.conn != nil {
		// Only the oldest pending request owns the read deadline; pipelined
		// requests get theirs from onResponse. Set it even without a Timeout:
		// the zero deadline then clears the IdleConnTimeout deadline armed
		// while the connection was idle, which would otherwise cut this
		// response short.
		if len(c.handlers) == 1 {
			c.conn.SetReadDeadline(deadline)
		}
		sendRequest()
		return
	}

	var timeout time.Duration
	if confTimeout > 0 {
		timeout = time.Until(deadline)
		if timeout <= 0 {
			c.closeWithErrorWithoutLock(ErrClientTimeout)
			return
		}
	}

	// URL schemes are case-insensitive (RFC 3986 §3.1); a hand-built request
	// is not normalized the way url.Parse output is.
	scheme := strings.ToLower(req.URL.Scheme)

	// Hostname/Port understand bracketed IPv6 literals ("[::1]:8080"), which
	// splitting on ":" does not. Without an explicit port, the scheme is used
	// as the service name ("http"/"https") for the dialer to resolve.
	host := req.URL.Hostname()
	port := req.URL.Port()
	if port == "" {
		port = scheme
	}
	addr := net.JoinHostPort(host, port)

	var netDial netDialerFunc
	if confTimeout <= 0 {
		netDial = func(network, addr string) (net.Conn, error) {
			return net.Dial(network, addr)
		}
	} else {
		netDial = func(network, addr string) (net.Conn, error) {
			conn, err := net.DialTimeout(network, addr, timeout)
			if err == nil {
				// Bounds blocking reads on the std socket, e.g. the TLS
				// handshake below, before it is handed to nbio.
				conn.SetReadDeadline(deadline)
			}
			return conn, err
		}
	}

	if c.Proxy != nil {
		proxyURL, err := c.Proxy(req)
		if err != nil {
			c.closeWithErrorWithoutLock(err)
			return
		}
		if proxyURL != nil {
			dialer, err := proxyFromURL(proxyURL, netDial)
			if err != nil {
				c.closeWithErrorWithoutLock(err)
				return
			}
			netDial = dialer.Dial
		}
	}

	netConn, err := netDial(defaultNetwork, addr)
	if err != nil {
		c.closeWithErrorWithoutLock(err)
		return
	}

	// Until c.conn is assigned, closeWithErrorWithoutLock cannot see the
	// freshly dialed socket, so the early error paths below close it
	// themselves instead of leaking it.
	switch scheme {
	case "http":
		var nbc *nbio.Conn
		nbc, err = nbio.NBConn(netConn)
		if err != nil {
			netConn.Close()
			c.closeWithErrorWithoutLock(err)
			return
		}
		c.conn = nbc
		processor := NewClientProcessor(c, c.onResponse)
		parser := NewParser(processor, true, engine.ReadLimit, nbc.Execute)
		parser.Conn = nbc
		parser.Engine = engine
		parser.OnClose(func(p *Parser, err error) {
			c.CloseWithError(err)
		})
		nbc.SetSession(parser)
		nbc.OnData(engine.DataHandler)
		if _, err = engine.AddConn(nbc); err != nil {
			c.closeWithErrorWithoutLock(err)
			return
		}
	case "https":
		tlsConfig := c.TLSClientConfig
		if tlsConfig == nil {
			tlsConfig = &tls.Config{}
		} else {
			tlsConfig = tlsConfig.Clone()
		}
		// SNI and certificate verification need the bare host name, while
		// req.URL.Host may carry a port. As net/http does, a ServerName the
		// caller configured explicitly is kept.
		if tlsConfig.ServerName == "" {
			tlsConfig.ServerName = host
		}
		tlsConn := tls.NewConn(netConn, tlsConfig, true, false, mempool.DefaultMemPool)
		err = tlsConn.Handshake()
		if err != nil {
			netConn.Close()
			c.closeWithErrorWithoutLock(err)
			return
		}
		if !tlsConfig.InsecureSkipVerify {
			if err := tlsConn.VerifyHostname(tlsConfig.ServerName); err != nil {
				netConn.Close()
				c.closeWithErrorWithoutLock(err)
				return
			}
		}
		nbc, err := nbio.NBConn(tlsConn.Conn())
		if err != nil {
			netConn.Close()
			c.closeWithErrorWithoutLock(err)
			return
		}
		isNonblock := true
		tlsConn.ResetConn(nbc, isNonblock)
		c.conn = tlsConn
		processor := NewClientProcessor(c, c.onResponse)
		parser := NewParser(processor, true, engine.ReadLimit, nbc.Execute)
		parser.Conn = tlsConn
		parser.Engine = engine
		parser.OnClose(func(p *Parser, err error) {
			c.CloseWithError(err)
		})
		nbc.SetSession(parser)
		nbc.OnData(engine.TLSDataHandler)
		_, err = engine.AddConn(nbc)
		if err != nil {
			c.closeWithErrorWithoutLock(err)
			return
		}
	default:
		netConn.Close()
		c.closeWithErrorWithoutLock(ErrClientUnsupportedSchema)
		return
	}

	// The deadline netDial set applies to the std-library socket, not to the
	// nbio connection built from it, so arm Timeout on the connection the
	// response is actually read from (after AddConn has registered it).
	if confTimeout > 0 {
		c.conn.SetReadDeadline(deadline)
	}
	sendRequest()
}