313 lines
6.0 KiB
Go
313 lines
6.0 KiB
Go
package nbhttp
|
|
|
|
import (
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
"unsafe"
|
|
|
|
"github.com/lesismal/llib/std/crypto/tls"
|
|
"github.com/lesismal/nbio"
|
|
"github.com/lesismal/nbio/logging"
|
|
"github.com/lesismal/nbio/mempool"
|
|
)
|
|
|
|
// resHandler is one entry in a ClientConn's queue of requests that are
// still waiting for a response.
type resHandler struct {
	// c is the connection the ClientConn held when the request was queued;
	// it is nil if the request is the one that triggers the dial.
	c net.Conn

	// t is the time the request was queued. Timeout deadlines for queued
	// requests are computed relative to it.
	t time.Time

	// h receives the response, or a nil response plus an error when the
	// connection fails before a response arrives.
	h func(res *http.Response, conn net.Conn, err error)
}
|
|
|
|
// ClientConn .
|
|
type ClientConn struct {
|
|
mux sync.Mutex
|
|
conn net.Conn
|
|
handlers []resHandler
|
|
|
|
closed bool
|
|
|
|
onClose func()
|
|
|
|
Engine *Engine
|
|
|
|
Jar http.CookieJar
|
|
|
|
Timeout time.Duration
|
|
|
|
IdleConnTimeout time.Duration
|
|
|
|
TLSClientConfig *tls.Config
|
|
|
|
Proxy func(*http.Request) (*url.URL, error)
|
|
|
|
CheckRedirect func(req *http.Request, via []*http.Request) error
|
|
}
|
|
|
|
// Reset .
|
|
func (c *ClientConn) Reset() {
|
|
c.mux.Lock()
|
|
if c.closed {
|
|
c.conn = nil
|
|
c.handlers = nil
|
|
c.closed = false
|
|
}
|
|
c.mux.Unlock()
|
|
}
|
|
|
|
// OnClose .
|
|
func (c *ClientConn) OnClose(h func()) {
|
|
if h == nil {
|
|
return
|
|
}
|
|
|
|
pre := c.onClose
|
|
c.onClose = func() {
|
|
if pre != nil {
|
|
pre()
|
|
}
|
|
h()
|
|
}
|
|
}
|
|
|
|
// Close .
|
|
func (c *ClientConn) Close() {
|
|
c.CloseWithError(io.EOF)
|
|
}
|
|
|
|
// CloseWithError .
|
|
func (c *ClientConn) CloseWithError(err error) {
|
|
c.mux.Lock()
|
|
defer c.mux.Unlock()
|
|
if !c.closed {
|
|
c.closed = true
|
|
c.closeWithErrorWithoutLock(err)
|
|
}
|
|
}
|
|
|
|
func (c *ClientConn) closeWithErrorWithoutLock(err error) {
|
|
if err == nil {
|
|
err = io.EOF
|
|
}
|
|
for _, h := range c.handlers {
|
|
h.h(nil, c.conn, err)
|
|
}
|
|
c.handlers = nil
|
|
if c.conn != nil {
|
|
c.conn.Close()
|
|
c.conn = nil
|
|
}
|
|
if c.onClose != nil {
|
|
c.onClose()
|
|
}
|
|
}
|
|
|
|
func (c *ClientConn) onResponse(res *http.Response, err error) {
|
|
c.mux.Lock()
|
|
defer c.mux.Unlock()
|
|
|
|
if !c.closed && len(c.handlers) > 0 {
|
|
head := c.handlers[0]
|
|
head.h(res, c.conn, err)
|
|
|
|
c.handlers = c.handlers[1:]
|
|
if len(c.handlers) > 0 {
|
|
next := c.handlers[0]
|
|
timeout := c.Timeout
|
|
deadline := next.t.Add(timeout)
|
|
if timeout > 0 {
|
|
if time.Now().After(deadline) {
|
|
c.closeWithErrorWithoutLock(ErrClientTimeout)
|
|
}
|
|
} else {
|
|
c.conn.SetReadDeadline(deadline)
|
|
}
|
|
} else {
|
|
if c.IdleConnTimeout > 0 {
|
|
c.conn.SetReadDeadline(time.Now().Add(c.IdleConnTimeout))
|
|
} else {
|
|
c.conn.SetReadDeadline(time.Time{})
|
|
}
|
|
}
|
|
if len(c.handlers) == 0 {
|
|
c.handlers = nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// Do .
|
|
func (c *ClientConn) Do(req *http.Request, handler func(res *http.Response, conn net.Conn, err error)) {
|
|
c.mux.Lock()
|
|
defer func() {
|
|
c.mux.Unlock()
|
|
if err := recover(); err != nil {
|
|
const size = 64 << 10
|
|
buf := make([]byte, size)
|
|
buf = buf[:runtime.Stack(buf, false)]
|
|
logging.Error("ClientConn Do failed: %v\n%v\n", err, *(*string)(unsafe.Pointer(&buf)))
|
|
}
|
|
}()
|
|
|
|
if c.closed {
|
|
handler(nil, nil, ErrClientClosed)
|
|
return
|
|
}
|
|
|
|
var engine = c.Engine
|
|
var confTimeout = c.Timeout
|
|
|
|
c.handlers = append(c.handlers, resHandler{c: c.conn, t: time.Now(), h: handler})
|
|
|
|
var deadline time.Time
|
|
if confTimeout > 0 {
|
|
deadline = time.Now().Add(confTimeout)
|
|
}
|
|
|
|
sendRequest := func() {
|
|
err := req.Write(c.conn)
|
|
if err != nil {
|
|
c.closeWithErrorWithoutLock(err)
|
|
return
|
|
}
|
|
}
|
|
|
|
if c.conn != nil {
|
|
if confTimeout > 0 && len(c.handlers) == 1 {
|
|
c.conn.SetReadDeadline(deadline)
|
|
}
|
|
sendRequest()
|
|
} else {
|
|
var timeout time.Duration
|
|
if confTimeout > 0 {
|
|
timeout = time.Until(deadline)
|
|
if timeout <= 0 {
|
|
c.closeWithErrorWithoutLock(ErrClientTimeout)
|
|
return
|
|
}
|
|
}
|
|
|
|
strs := strings.Split(req.URL.Host, ":")
|
|
host := strs[0]
|
|
port := req.URL.Scheme
|
|
if len(strs) >= 2 {
|
|
port = strs[1]
|
|
}
|
|
addr := host + ":" + port
|
|
|
|
var netDial netDialerFunc
|
|
if confTimeout <= 0 {
|
|
netDial = func(network, addr string) (net.Conn, error) {
|
|
return net.Dial(network, addr)
|
|
}
|
|
} else {
|
|
netDial = func(network, addr string) (net.Conn, error) {
|
|
conn, err := net.DialTimeout(network, addr, timeout)
|
|
if err == nil {
|
|
conn.SetReadDeadline(deadline)
|
|
}
|
|
return conn, err
|
|
}
|
|
}
|
|
|
|
if c.Proxy != nil {
|
|
proxyURL, err := c.Proxy(req)
|
|
if err != nil {
|
|
c.closeWithErrorWithoutLock(err)
|
|
return
|
|
}
|
|
if proxyURL != nil {
|
|
dialer, err := proxyFromURL(proxyURL, netDial)
|
|
if err != nil {
|
|
c.closeWithErrorWithoutLock(err)
|
|
return
|
|
}
|
|
netDial = dialer.Dial
|
|
}
|
|
}
|
|
|
|
netConn, err := netDial(defaultNetwork, addr)
|
|
if err != nil {
|
|
c.closeWithErrorWithoutLock(err)
|
|
return
|
|
}
|
|
|
|
switch req.URL.Scheme {
|
|
case "http":
|
|
var nbc *nbio.Conn
|
|
nbc, err = nbio.NBConn(netConn)
|
|
if err != nil {
|
|
c.closeWithErrorWithoutLock(err)
|
|
return
|
|
}
|
|
|
|
c.conn = nbc
|
|
processor := NewClientProcessor(c, c.onResponse)
|
|
parser := NewParser(processor, true, engine.ReadLimit, nbc.Execute)
|
|
parser.Conn = nbc
|
|
parser.Engine = engine
|
|
parser.OnClose(func(p *Parser, err error) {
|
|
c.CloseWithError(err)
|
|
})
|
|
nbc.SetSession(parser)
|
|
|
|
nbc.OnData(engine.DataHandler)
|
|
engine.AddConn(nbc)
|
|
case "https":
|
|
tlsConfig := c.TLSClientConfig
|
|
if tlsConfig == nil {
|
|
tlsConfig = &tls.Config{}
|
|
} else {
|
|
tlsConfig = tlsConfig.Clone()
|
|
}
|
|
tlsConfig.ServerName = req.URL.Host
|
|
tlsConn := tls.NewConn(netConn, tlsConfig, true, false, mempool.DefaultMemPool)
|
|
err = tlsConn.Handshake()
|
|
if err != nil {
|
|
c.closeWithErrorWithoutLock(err)
|
|
return
|
|
}
|
|
if !tlsConfig.InsecureSkipVerify {
|
|
if err := tlsConn.VerifyHostname(tlsConfig.ServerName); err != nil {
|
|
c.closeWithErrorWithoutLock(err)
|
|
return
|
|
}
|
|
}
|
|
|
|
nbc, err := nbio.NBConn(tlsConn.Conn())
|
|
if err != nil {
|
|
c.closeWithErrorWithoutLock(err)
|
|
return
|
|
}
|
|
|
|
isNonblock := true
|
|
tlsConn.ResetConn(nbc, isNonblock)
|
|
|
|
c.conn = tlsConn
|
|
processor := NewClientProcessor(c, c.onResponse)
|
|
parser := NewParser(processor, true, engine.ReadLimit, nbc.Execute)
|
|
parser.Conn = tlsConn
|
|
parser.Engine = engine
|
|
parser.OnClose(func(p *Parser, err error) {
|
|
c.CloseWithError(err)
|
|
})
|
|
nbc.SetSession(parser)
|
|
|
|
nbc.OnData(engine.TLSDataHandler)
|
|
_, err = engine.AddConn(nbc)
|
|
if err != nil {
|
|
c.closeWithErrorWithoutLock(err)
|
|
return
|
|
}
|
|
default:
|
|
c.closeWithErrorWithoutLock(ErrClientUnsupportedSchema)
|
|
return
|
|
}
|
|
|
|
sendRequest()
|
|
}
|
|
}
|