271 lines
5.9 KiB
Go
Raw Normal View History

2021-12-04 16:42:11 +00:00
package websocket
import (
"context"
"errors"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/lesismal/llib/std/crypto/tls"
"github.com/lesismal/nbio"
"github.com/lesismal/nbio/nbhttp"
)
// Canonical (http.CanonicalHeaderKey form) names of the headers exchanged in
// the websocket opening handshake. This file uses them directly as
// http.Header map keys, which bypasses canonicalization, so they must stay in
// canonical form.
const (
	// Target host of the upgrade request.
	hostHeaderField = "Host"

	// Generic HTTP connection-upgrade negotiation.
	upgradeHeaderField    = "Upgrade"
	connectionHeaderField = "Connection"

	// Websocket-specific handshake headers.
	secWebsocketKeyHeaderField     = "Sec-Websocket-Key"
	secWebsocketVersionHeaderField = "Sec-Websocket-Version"
	secWebsocketProtoHeaderField   = "Sec-Websocket-Protocol"
	secWebsocketExtHeaderField     = "Sec-Websocket-Extensions"
)
// Dialer establishes client-side websocket connections. Its fields configure
// the HTTP upgrade request that Dial/DialContext build and the
// nbhttp.ClientConn used to send it.
type Dialer struct {
	// Engine is passed to the nbhttp.ClientConn that performs the handshake
	// and is assigned to the Engine field of every *Conn a dial produces.
	Engine *nbhttp.Engine

	// Upgrader is required: DialContext returns an error when it is nil.
	// Its onClose and openHandler callbacks are attached to dialed
	// connections.
	Upgrader *Upgrader

	// Jar, if non-nil, supplies cookies for the upgrade request. It also
	// receives any cookies set by the handshake response. It is passed to
	// the nbhttp.ClientConn as well.
	Jar http.CookieJar

	// DialTimeout, when positive, makes Dial bound the handshake with a
	// context deadline. It is also passed unchanged to the nbhttp.ClientConn
	// as its Timeout.
	DialTimeout time.Duration

	// TLSClientConfig is passed to the nbhttp.ClientConn. Note that it is the
	// lesismal/llib tls.Config type, not crypto/tls.
	TLSClientConfig *tls.Config

	// Proxy is passed to the nbhttp.ClientConn. Presumably it selects a proxy
	// per request, as in net/http.Transport; TODO confirm against nbhttp.
	Proxy func(*http.Request) (*url.URL, error)

	// CheckRedirect is passed to the nbhttp.ClientConn.
	CheckRedirect func(req *http.Request, via []*http.Request) error

	// Subprotocols, when non-empty, are sent comma-joined in the
	// Sec-Websocket-Protocol request header. In that case the caller's
	// requestHeader must not also contain that header.
	Subprotocols []string

	// EnableCompression requests permessage-deflate with
	// server_no_context_takeover and client_no_context_takeover.
	EnableCompression bool

	// Cancel, when non-nil, is called (deferred) by DialContext as it returns.
	// NOTE(review): a cancel func is per-call state. Writing this field while
	// the same Dialer is used from several goroutines is a data race.
	Cancel context.CancelFunc
}
// Dial connects to the websocket server at urlStr and performs the opening
// handshake. It is DialContext with a background context. When
// d.DialTimeout is positive, that context expires after DialTimeout.
//
// The timeout context is cancelled as soon as DialContext returns. Its
// cancel function is kept local to this call rather than stored in d.Cancel.
// That way, concurrent Dial calls on one Dialer cannot overwrite or cancel
// each other's contexts.
//
// The optional v is forwarded to DialContext. If v[0] is a
// func(*Conn, *http.Response, error), the dial is asynchronous.
func (d *Dialer) Dial(urlStr string, requestHeader http.Header, v ...interface{}) (*Conn, *http.Response, error) {
	ctx := context.Background()
	if d.DialTimeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, d.DialTimeout)
		defer cancel()
	}
	return d.DialContext(ctx, urlStr, requestHeader, v...)
}
// DialContext connects to the websocket server at urlStr and performs the
// client side of the opening handshake over an nbhttp.ClientConn.
//
// urlStr must use the "ws" or "wss" scheme and must not contain user info.
// Otherwise ErrMalformedURL is returned.
//
// Entries in requestHeader are copied onto the upgrade request, with two
// exceptions:
//   - "Host" overrides the request host.
//   - Handshake headers this method sets itself are rejected: Upgrade,
//     Connection, Sec-Websocket-Key, Sec-Websocket-Version,
//     Sec-Websocket-Extensions, and Sec-Websocket-Protocol when
//     d.Subprotocols is non-empty.
//
// If v[0] is a func(*Conn, *http.Response, error), the dial is asynchronous:
// DialContext returns (nil, nil, nil) immediately and the outcome is passed
// to that function.
//
// Otherwise DialContext blocks until the handshake outcome arrives or ctx is
// done. On failure it returns the error together with the handshake response,
// if any. When ctx is done first, it returns nbhttp.ErrClientTimeout with a
// nil *Conn and *http.Response. A connection that completes after that point
// is closed by the handshake callback.
//
// If d.Cancel is non-nil, it is called when DialContext returns.
func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header, v ...interface{}) (*Conn, *http.Response, error) {
	if d.Cancel != nil {
		defer d.Cancel()
	}

	upgrader := d.Upgrader
	if upgrader == nil {
		return nil, nil, errors.New("invalid Upgrader: nil")
	}

	// Named secKey so it does not shadow the package-level challengeKey func.
	secKey, err := challengeKey()
	if err != nil {
		return nil, nil, err
	}

	u, err := url.Parse(urlStr)
	if err != nil {
		return nil, nil, err
	}
	switch u.Scheme {
	case "ws":
		u.Scheme = "http"
	case "wss":
		u.Scheme = "https"
	default:
		return nil, nil, ErrMalformedURL
	}
	// User info (name/password) in the URL is rejected.
	if u.User != nil {
		return nil, nil, ErrMalformedURL
	}

	req := &http.Request{
		Method:     "GET",
		URL:        u,
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,
		Header:     make(http.Header),
		Host:       u.Host,
	}
	if d.Jar != nil {
		for _, cookie := range d.Jar.Cookies(u) {
			req.AddCookie(cookie)
		}
	}

	// Direct map assignment: the header-name constants are already canonical.
	req.Header[upgradeHeaderField] = []string{"websocket"}
	req.Header[connectionHeaderField] = []string{"Upgrade"}
	req.Header[secWebsocketKeyHeaderField] = []string{secKey}
	req.Header[secWebsocketVersionHeaderField] = []string{"13"}
	if len(d.Subprotocols) > 0 {
		req.Header[secWebsocketProtoHeaderField] = []string{strings.Join(d.Subprotocols, ", ")}
	}
	for k, vs := range requestHeader {
		switch {
		case k == hostHeaderField:
			if len(vs) > 0 {
				req.Host = vs[0]
			}
		case k == upgradeHeaderField ||
			k == connectionHeaderField ||
			k == secWebsocketKeyHeaderField ||
			k == secWebsocketVersionHeaderField ||
			k == secWebsocketExtHeaderField ||
			(k == secWebsocketProtoHeaderField && len(d.Subprotocols) > 0):
			return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
		case k == secWebsocketProtoHeaderField:
			req.Header[secWebsocketProtoHeaderField] = vs
		default:
			req.Header[k] = vs
		}
	}
	if d.EnableCompression {
		req.Header[secWebsocketExtHeaderField] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"}
	}

	var asyncHandler func(*Conn, *http.Response, error)
	if len(v) > 0 {
		if h, ok := v[0].(func(*Conn, *http.Response, error)); ok {
			asyncHandler = h
		}
	}

	// dialResult carries the synchronous outcome from the handshake callback
	// to this goroutine by value, so the two share no variables. Reading
	// shared variables after a timeout would race with a late callback.
	type dialResult struct {
		conn *Conn
		res  *http.Response
		err  error
	}
	var resultCh chan dialResult
	if asyncHandler == nil {
		resultCh = make(chan dialResult)
	}

	cliConn := &nbhttp.ClientConn{
		Engine:          d.Engine,
		Jar:             d.Jar,
		Timeout:         d.DialTimeout,
		TLSClientConfig: d.TLSClientConfig,
		Proxy:           d.Proxy,
		CheckRedirect:   d.CheckRedirect,
	}
	cliConn.Do(req, func(resp *http.Response, conn net.Conn, doErr error) {
		// notify hands the outcome either to asyncHandler or to the
		// DialContext call blocked below.
		notify := func(c *Conn, e error) {
			if asyncHandler != nil {
				asyncHandler(c, resp, e)
				return
			}
			select {
			case resultCh <- dialResult{conn: c, res: resp, err: e}:
			case <-ctx.Done():
				// ctx is done, so DialContext reports a timeout instead of
				// this result. Drop the connection.
				if conn != nil {
					conn.Close()
				}
			}
		}

		if doErr != nil {
			notify(nil, doErr)
			return
		}

		// Locate the underlying *nbio.Conn, unwrapping TLS when present.
		nbc, ok := conn.(*nbio.Conn)
		if !ok {
			tlsConn, isTLS := conn.(*tls.Conn)
			if !isTLS {
				notify(nil, ErrBadHandshake)
				return
			}
			nbc, ok = tlsConn.Conn().(*nbio.Conn)
			if !ok {
				notify(nil, errors.New(http.StatusText(http.StatusInternalServerError)))
				return
			}
		}
		parser, ok := nbc.Session().(*nbhttp.Parser)
		if !ok {
			notify(nil, errors.New(http.StatusText(http.StatusInternalServerError)))
			return
		}
		state := &connState{common: upgrader}
		parser.ConnState = state

		if d.Jar != nil {
			if rc := resp.Cookies(); len(rc) > 0 {
				d.Jar.SetCookies(req.URL, rc)
			}
		}

		if resp.StatusCode != http.StatusSwitchingProtocols ||
			!headerContains(resp.Header, "Upgrade", "websocket") ||
			!headerContains(resp.Header, "Connection", "upgrade") ||
			resp.Header.Get("Sec-Websocket-Accept") != acceptKeyString(secKey) {
			notify(nil, ErrBadHandshake)
			return
		}

		// permessage-deflate is accepted only when the server confirms both
		// no-context-takeover parameters. Any other form is an error.
		remoteCompressionEnabled := false
		for _, ext := range parseExtensions(resp.Header) {
			if ext[""] != "permessage-deflate" {
				continue
			}
			_, snct := ext["server_no_context_takeover"]
			_, cnct := ext["client_no_context_takeover"]
			if !snct || !cnct {
				notify(nil, ErrInvalidCompression)
				return
			}
			remoteCompressionEnabled = true
			break
		}

		wsConn := newConn(upgrader, conn, resp.Header.Get(secWebsocketProtoHeaderField), remoteCompressionEnabled)
		wsConn.isClient = true
		wsConn.Engine = d.Engine
		wsConn.OnClose(upgrader.onClose)
		state.conn = wsConn
		state.Engine = parser.Engine
		if upgrader.openHandler != nil {
			upgrader.openHandler(wsConn)
		}
		notify(wsConn, nil)
	})

	if asyncHandler != nil {
		return nil, nil, nil
	}

	select {
	case r := <-resultCh:
		if r.err != nil {
			cliConn.CloseWithError(r.err)
		}
		return r.conn, r.res, r.err
	case <-ctx.Done():
		// The callback may still be running. Return nothing it produces.
		err = nbhttp.ErrClientTimeout
		cliConn.CloseWithError(err)
		return nil, nil, err
	}
}