2021-12-01 15:43:13 +00:00

500 lines
12 KiB
Go

package nbhttp
import (
"bufio"
"encoding/base64"
"errors"
"io"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/lesismal/llib/std/crypto/tls"
)
// newHostConns builds an empty connection pool owned by cli.
//
// The pool holds at most cli.MaxConnsPerHost connections when that field is
// positive, and 1024 otherwise. The channel of reusable connections is
// buffered to the same limit.
func newHostConns(cli *Client) *hostConns {
	limit := int32(1024) // default cap when the client does not set one
	if cli.MaxConnsPerHost > 0 {
		limit = cli.MaxConnsPerHost
	}
	return &hostConns{
		cli:        cli,
		conns:      make(map[*ClientConn]struct{}),
		maxConnNum: limit,
		chConnss:   make(chan *ClientConn, limit),
	}
}
// hostConns is the pool of client connections kept for one entry of
// Client.connsOfHosts.
type hostConns struct {
// mux guards conns.
mux sync.Mutex
// cli is the Client that owns this pool; new connections copy their settings from it.
cli *Client
// conns records every connection created by getConn so closeWithError can close them.
conns map[*ClientConn]struct{}
// connNum counts created connections; getConn updates it with atomic.AddInt32.
connNum int32
// maxConnNum is the upper bound on connNum (1024 unless Client.MaxConnsPerHost is positive).
maxConnNum int32
// chConnss carries connections handed back by releaseConn; it is buffered to maxConnNum.
chConnss chan *ClientConn
}
// closeWithError closes every connection this pool has created, passing err
// to each connection's CloseWithError.
//
// The pool mutex is released with defer so that a panic raised while closing
// a connection cannot leave the pool permanently locked.
func (hcs *hostConns) closeWithError(err error) {
	hcs.mux.Lock()
	defer hcs.mux.Unlock()

	for hc := range hcs.conns {
		hc.CloseWithError(err)
	}
}
// getConn returns a connection from this pool together with the pool itself,
// so the caller can later hand the connection back with releaseConn.
//
// It proceeds in three steps:
//  1. take an already released connection without blocking;
//  2. otherwise create a new connection if the pool is below maxConnNum;
//  3. otherwise wait up to the client's Timeout for a connection to be released.
//
// It returns ErrClientClosed when the client is closed (or the channel has
// been closed) and ErrClientTimeout when step 3 times out.
func (hcs *hostConns) getConn() (*hostConns, *ClientConn, error) {
	c := hcs.cli
	if c.closed {
		return nil, nil, ErrClientClosed
	}

	// 1. Fast path: grab a released connection if one is ready. No timer is
	// involved here: the select never blocks, and an already expired timer
	// (zero or tiny Timeout) must not be able to win over the default case
	// and fail the call before a connection could be created.
	select {
	case hc, ok := <-hcs.chConnss:
		if !ok {
			return nil, nil, ErrClientClosed
		}
		return hcs, hc, nil
	default:
	}

	// 2. Create a new connection while the pool is below its limit. The
	// counter is bumped optimistically and rolled back if the limit is hit.
	if atomic.AddInt32(&hcs.connNum, 1) <= hcs.maxConnNum {
		hc := &ClientConn{
			Engine:          c.Engine,
			Jar:             c.Jar,
			Timeout:         c.Timeout,
			IdleConnTimeout: c.IdleConnTimeout,
			TLSClientConfig: c.TLSClientConfig,
			Proxy:           c.Proxy,
			CheckRedirect:   c.CheckRedirect,
		}
		hcs.mux.Lock()
		hcs.conns[hc] = struct{}{}
		hcs.mux.Unlock()
		return hcs, hc, nil
	}
	atomic.AddInt32(&hcs.connNum, -1)

	// 3. Pool exhausted: wait for a connection to be released, bounded by the
	// client's Timeout. The timer is only allocated on this slow path.
	timer := time.NewTimer(c.Timeout)
	defer timer.Stop()

	select {
	case hc, ok := <-hcs.chConnss:
		if !ok {
			return nil, nil, ErrClientClosed
		}
		return hcs, hc, nil
	case <-timer.C:
		return nil, nil, ErrClientTimeout
	}
}
// releaseConn hands hc back to the pool by sending it on chConnss, where a
// later getConn call can receive it. The send blocks if the channel's buffer
// is full.
func (hcs *hostConns) releaseConn(hc *ClientConn) {
hcs.chConnss <- hc
}
// Client sends HTTP requests over pooled ClientConn connections, keeping one
// pool (hostConns) per request URL host. Requests are started with Do.
type Client struct {
// mux is held by CloseWithError while the client is being closed.
mux sync.Mutex
// closed is set by CloseWithError; once true, getConn returns ErrClientClosed.
closed bool
// connsMux is locked by getConn around its accesses to connsOfHosts.
connsMux sync.RWMutex
// connsOfHosts maps a host to its connection pool; getConn creates it lazily.
connsOfHosts map[string]*hostConns
// Engine runs the request closure (Do calls Engine.ExecuteClient) and is
// copied into every ClientConn the pools create.
Engine *Engine
// Jar is copied into every ClientConn the pools create.
Jar http.CookieJar
// Timeout is copied into every ClientConn and is also the duration of the
// timer hostConns.getConn uses while obtaining a connection.
Timeout time.Duration
// MaxConnsPerHost, when positive, replaces the default per-host limit of
// 1024 connections (see newHostConns).
MaxConnsPerHost int32
// IdleConnTimeout, TLSClientConfig, Proxy and CheckRedirect are copied
// unchanged into every ClientConn the pools create.
IdleConnTimeout time.Duration
TLSClientConfig *tls.Config
Proxy func(*http.Request) (*url.URL, error)
CheckRedirect func(req *http.Request, via []*http.Request) error
}
// Close closes the client; it is shorthand for CloseWithError(io.EOF).
func (c *Client) Close() {
c.CloseWithError(io.EOF)
}
// CloseWithError marks the client as closed and closes every pooled
// connection, passing err to each of them. Calls after the first are no-ops.
//
// getConn reads c.closed and mutates c.connsOfHosts under connsMux, so the
// flag is set and the pools are snapshotted under that same lock; otherwise a
// concurrent Do could write the map while it is being iterated here. The
// connections themselves are closed after connsMux is released so that
// closing never runs with the map lock held.
func (c *Client) CloseWithError(err error) {
	c.mux.Lock()
	defer c.mux.Unlock()

	if c.closed {
		return
	}

	c.connsMux.Lock()
	c.closed = true
	// Once closed is set under connsMux, getConn can no longer add pools, so
	// this snapshot is complete.
	pools := make([]*hostConns, 0, len(c.connsOfHosts))
	for _, hcs := range c.connsOfHosts {
		pools = append(pools, hcs)
	}
	c.connsMux.Unlock()

	for _, hcs := range pools {
		hcs.closeWithError(err)
	}
}
// getConn looks up (creating it on first use) the connection pool for host
// and asks that pool for a connection. It returns ErrClientClosed if the
// client has already been closed.
func (c *Client) getConn(host string) (*hostConns, *ClientConn, error) {
	c.connsMux.Lock()
	if c.closed {
		c.connsMux.Unlock()
		return nil, nil, ErrClientClosed
	}

	// Reading from a nil map is fine; the map is only allocated when the
	// first pool has to be stored.
	pool, found := c.connsOfHosts[host]
	if !found {
		if c.connsOfHosts == nil {
			c.connsOfHosts = make(map[string]*hostConns)
		}
		pool = newHostConns(c)
		c.connsOfHosts[host] = pool
	}
	c.connsMux.Unlock()

	// The pool lookup is done; obtaining a connection may block, so it runs
	// without connsMux held.
	return pool.getConn()
}
// Do sends req and reports the outcome through handler.
//
// The work is submitted to the Engine via ExecuteClient. If no connection can
// be obtained, handler is called with a nil response, a nil conn and the
// error. Otherwise the request is issued on the obtained connection, and when
// it completes the connection is released back to its pool before handler is
// invoked with the result.
func (c *Client) Do(req *http.Request, handler func(res *http.Response, conn net.Conn, err error)) {
	task := func() {
		pool, cc, getErr := c.getConn(req.URL.Host)
		if getErr != nil {
			handler(nil, nil, getErr)
			return
		}

		// A connection that was closed earlier is reset before reuse.
		if cc.closed {
			cc.Reset()
		}

		cc.Do(req, func(res *http.Response, conn net.Conn, doErr error) {
			pool.releaseConn(cc)
			handler(res, conn, doErr)
		})
	}
	c.Engine.ExecuteClient(task)
}
// netDialerFunc adapts a plain dial function so it can be used wherever a
// value with a Dial method is expected.
type netDialerFunc func(network, addr string) (net.Conn, error)

// Dial invokes the wrapped function with the given network and address and
// returns its results unchanged.
func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
	conn, err := fn(network, addr)
	return conn, err
}
// proxySchemes maps a proxy URL scheme to the factory that builds a dialer
// for it. It stays nil until the first registration.
var proxySchemes map[string]func(*url.URL, proxyDialer) (proxyDialer, error)

// proxyRegisterDialerType registers f as the dialer factory for proxy URLs
// with the given scheme, replacing any factory registered earlier for it.
func proxyRegisterDialerType(scheme string, f func(*url.URL, proxyDialer) (proxyDialer, error)) {
	if proxySchemes != nil {
		proxySchemes[scheme] = f
		return
	}
	// First registration: create the table with its first entry.
	proxySchemes = map[string]func(*url.URL, proxyDialer) (proxyDialer, error){
		scheme: f,
	}
}
// proxyFromURL returns a dialer that connects through the proxy described by
// u, using forward to make the underlying connection.
//
// The "socks5" scheme is handled directly; any other scheme must have been
// registered with proxyRegisterDialerType. Credentials in u, if present, are
// passed to the SOCKS5 dialer. An unregistered scheme yields an error.
func proxyFromURL(u *url.URL, forward proxyDialer) (proxyDialer, error) {
	var auth *proxyAuth
	if info := u.User; info != nil {
		auth = &proxyAuth{User: info.Username()}
		if password, isSet := info.Password(); isSet {
			auth.Password = password
		}
	}

	if u.Scheme == "socks5" {
		return proxySOCKS5("tcp", u.Host, auth, forward)
	}

	// Indexing a nil map is safe, so no separate nil check is needed.
	if factory, registered := proxySchemes[u.Scheme]; registered {
		return factory(u, forward)
	}

	return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
}
// hostPortNoPort splits u.Host into two forms: hostPort always carries a
// port, and hostNoPort never does.
//
// A colon counts as a port separator only when it comes after the last ']',
// so the colons inside a bracketed IPv6 literal are ignored. When u.Host has
// no port, hostPort gets ":443" for the "wss" and "https" schemes and ":80"
// for everything else.
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
	host := u.Host
	colon := strings.LastIndex(host, ":")
	bracket := strings.LastIndex(host, "]")

	// Explicit port present: keep it in hostPort, strip it from hostNoPort.
	if colon > bracket {
		return host, host[:colon]
	}

	defaultPort := ":80"
	if u.Scheme == "wss" || u.Scheme == "https" {
		defaultPort = ":443"
	}
	return host + defaultPort, host
}
// proxyDialer is the dialing abstraction used by the proxy code in this file:
// anything that can open a connection to addr on the given network.
type proxyDialer interface {
// Dial connects to addr on network and returns the connection or an error.
Dial(network, addr string) (c net.Conn, err error)
}
// httpProxyDialer dials targets through an HTTP proxy using the CONNECT method.
type httpProxyDialer struct {
// proxyURL is the proxy's URL; its host is dialed and its user info, if
// any, supplies the Proxy-Authorization credentials.
proxyURL *url.URL
// forwardDial opens the underlying connection to the proxy itself.
forwardDial func(network, addr string) (net.Conn, error)
}
// Dial connects to addr through the HTTP proxy at hpd.proxyURL.
//
// It dials the proxy with forwardDial, sends a CONNECT request for addr
// (with Basic Proxy-Authorization when the proxy URL carries a username and
// password), and reads the proxy's response. On a 200 response the
// connection is returned, ready to carry traffic to addr. On any failure the
// connection is closed and an error is returned; for a non-200 response the
// error text is the response's reason phrase, or the whole status text when
// the proxy sent no reason phrase.
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
	hostPort, _ := hostPortNoPort(hpd.proxyURL)
	conn, err := hpd.forwardDial(network, hostPort)
	if err != nil {
		return nil, err
	}

	connectHeader := make(http.Header)
	if user := hpd.proxyURL.User; user != nil {
		proxyUser := user.Username()
		if proxyPassword, passwordSet := user.Password(); passwordSet {
			credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
			connectHeader.Set("Proxy-Authorization", "Basic "+credential)
		}
	}

	connectReq := &http.Request{
		Method: "CONNECT",
		URL:    &url.URL{Opaque: addr},
		Host:   addr,
		Header: connectHeader,
	}

	if errWrite := connectReq.Write(conn); errWrite != nil {
		conn.Close()
		return nil, errWrite
	}

	br := bufio.NewReader(conn)
	resp, err := http.ReadResponse(br, connectReq)
	if err != nil {
		conn.Close()
		return nil, err
	}
	resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		conn.Close()
		// resp.Status is "<code> <reason>", but the reason phrase is
		// optional on the wire: a status line such as "HTTP/1.1 407" yields
		// Status "407". Indexing the second field unconditionally would
		// panic, so fall back to the full status when no reason is present.
		reason := resp.Status
		if i := strings.IndexByte(resp.Status, ' '); i >= 0 && i+1 < len(resp.Status) {
			reason = resp.Status[i+1:]
		}
		return nil, errors.New(reason)
	}
	return conn, nil
}
// proxyAuth holds the username/password pair taken from a proxy URL.
type proxyAuth struct {
	// User is the proxy username.
	User string
	// Password is the proxy password; empty when the URL did not set one.
	Password string
}
// proxySOCKS5 returns a dialer that reaches its targets through the SOCKS5
// proxy at addr on the given network, using forward to connect to the proxy
// itself. If auth is non-nil its credentials are stored for the handshake.
// The returned error is always nil.
func proxySOCKS5(network, addr string, auth *proxyAuth, forward proxyDialer) (proxyDialer, error) {
	var user, password string
	if auth != nil {
		user, password = auth.User, auth.Password
	}
	return &proxySocks5{
		user:     user,
		password: password,
		network:  network,
		addr:     addr,
		forward:  forward,
	}, nil
}
// proxySocks5 is a proxyDialer that tunnels connections through a SOCKS5 proxy.
type proxySocks5 struct {
// user and password are the optional credentials offered during the handshake.
user, password string
// network and addr identify the SOCKS5 proxy itself.
network, addr string
// forward opens the underlying connection to the proxy.
forward proxyDialer
}
// Protocol constants written to and read from the wire during the SOCKS5
// handshake.
const (
	// proxySocks5Version is the protocol version byte.
	proxySocks5Version = 5

	// Authentication methods offered in the greeting.
	proxySocks5AuthNone     = 0
	proxySocks5AuthPassword = 2

	// proxySocks5Connect is the command byte of the connect request.
	proxySocks5Connect = 1

	// Address types used in the connect request and its reply.
	proxySocks5IP4    = 1
	proxySocks5Domain = 3
	proxySocks5IP6    = 4
)
// proxySocks5Errors maps the reply code of a SOCKS5 connect response to a
// human-readable failure. Index 0 is the empty string, which connect treats
// as "no failure".
var proxySocks5Errors = []string{
	0: "",
	1: "general failure",
	2: "connection forbidden",
	3: "network unreachable",
	4: "host unreachable",
	5: "connection refused",
	6: "TTL expired",
	7: "command not supported",
	8: "address type not supported",
}
// Dial connects to addr through the SOCKS5 proxy.
//
// Only the "tcp", "tcp4" and "tcp6" networks are accepted. The proxy is
// reached with the forward dialer, then the SOCKS5 handshake for addr is run
// on that connection. If the handshake fails the connection is closed and
// the error returned.
func (s *proxySocks5) Dial(network, addr string) (net.Conn, error) {
	if network != "tcp" && network != "tcp6" && network != "tcp4" {
		return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network)
	}

	conn, dialErr := s.forward.Dial(s.network, s.addr)
	if dialErr != nil {
		return nil, dialErr
	}

	handshakeErr := s.connect(conn, addr)
	if handshakeErr != nil {
		conn.Close()
		return nil, handshakeErr
	}
	return conn, nil
}
// connect runs the SOCKS5 handshake on conn, asking the proxy to open a
// connection to target ("host:port").
//
// It sends the greeting with the offered authentication methods, performs
// username/password authentication if the proxy selects it, sends the
// connect request and consumes the proxy's reply. It returns nil when the
// proxy reports success and a descriptive error otherwise. conn is never
// closed here; the caller does that on error.
func (s *proxySocks5) connect(conn net.Conn, target string) error {
host, portStr, err := net.SplitHostPort(target)
if err != nil {
return err
}
port, err := strconv.Atoi(portStr)
if err != nil {
return errors.New("proxy: failed to parse port number: " + portStr)
}
if port < 1 || port > 0xffff {
return errors.New("proxy: port number out of range: " + portStr)
}
// the size here is just an estimate
buf := make([]byte, 0, 6+len(host))
// Greeting: version, number of offered auth methods, then the methods.
// Password auth is only offered when a user name is set and both
// credentials fit in the one-byte length fields used below.
buf = append(buf, proxySocks5Version)
if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
buf = append(buf, 2 /* num auth methods */, proxySocks5AuthNone, proxySocks5AuthPassword)
} else {
buf = append(buf, 1 /* num auth methods */, proxySocks5AuthNone)
}
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
// Greeting reply: buf[0] is the version, buf[1] the method the proxy chose.
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if buf[0] != 5 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
}
// 0xff is reported to the caller as the proxy requiring authentication.
if buf[1] == 0xff {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
}
// See RFC 1929
if buf[1] == proxySocks5AuthPassword {
// Auth request: sub-protocol version, user length, user, password
// length, password. The buffer is reused from the greeting.
buf = buf[:0]
buf = append(buf, 1 /* password protocol version */)
buf = append(buf, uint8(len(s.user)))
buf = append(buf, s.user...)
buf = append(buf, uint8(len(s.password)))
buf = append(buf, s.password...)
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
// Any non-zero status byte is treated as rejected credentials.
if buf[1] != 0 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
}
}
// Connect request: version, command, reserved byte, address type,
// address, then the port in big-endian order.
buf = buf[:0]
buf = append(buf, proxySocks5Version, proxySocks5Connect, 0 /* reserved */)
if ip := net.ParseIP(host); ip != nil {
// IPv4 addresses are sent in their 4-byte form.
if ip4 := ip.To4(); ip4 != nil {
buf = append(buf, proxySocks5IP4)
ip = ip4
} else {
buf = append(buf, proxySocks5IP6)
}
buf = append(buf, ip...)
} else {
// Host names are sent length-prefixed with a single byte, hence the 255 limit.
if len(host) > 255 {
return errors.New("proxy: destination host name too long: " + host)
}
buf = append(buf, proxySocks5Domain)
buf = append(buf, byte(len(host)))
buf = append(buf, host...)
}
buf = append(buf, byte(port>>8), byte(port))
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
// First four reply bytes: buf[1] is the reply code, buf[3] the address type.
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
// Reply code 0 maps to the empty string in proxySocks5Errors and means
// success; codes beyond the table are reported as "unknown error".
failure := "unknown error"
if int(buf[1]) < len(proxySocks5Errors) {
failure = proxySocks5Errors[buf[1]]
}
if len(failure) > 0 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
}
// The rest of the reply is an address and a port. Their values are not
// used, but they must be read off the connection.
var bytesToDiscard int
switch buf[3] {
case proxySocks5IP4:
bytesToDiscard = net.IPv4len
case proxySocks5IP6:
bytesToDiscard = net.IPv6len
case proxySocks5Domain:
// Domain addresses are prefixed with a one-byte length.
_, err := io.ReadFull(conn, buf[:1])
if err != nil {
return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
bytesToDiscard = int(buf[0])
default:
return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
}
// Reuse buf when it is large enough for the address, otherwise allocate.
if cap(buf) < bytesToDiscard {
buf = make([]byte, bytesToDiscard)
} else {
buf = buf[:bytesToDiscard]
}
if _, err := io.ReadFull(conn, buf); err != nil {
return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
// Finally consume the two port bytes.
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
return nil
}
func init() {
proxyRegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxyDialer) (proxyDialer, error) {
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
})
}