236 lines
5.3 KiB
Go
236 lines
5.3 KiB
Go
|
// +build !js
|
||
|
|
||
|
package websocket
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto/rand"
|
||
|
"io"
|
||
|
"io/ioutil"
|
||
|
"net/http"
|
||
|
"net/http/httptest"
|
||
|
"strings"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"nhooyr.io/websocket/internal/test/assert"
|
||
|
)
|
||
|
|
||
|
func TestBadDials(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
t.Run("badReq", func(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
testCases := []struct {
|
||
|
name string
|
||
|
url string
|
||
|
opts *DialOptions
|
||
|
rand readerFunc
|
||
|
}{
|
||
|
{
|
||
|
name: "badURL",
|
||
|
url: "://noscheme",
|
||
|
},
|
||
|
{
|
||
|
name: "badURLScheme",
|
||
|
url: "ftp://nhooyr.io",
|
||
|
},
|
||
|
{
|
||
|
name: "badTLS",
|
||
|
url: "wss://totallyfake.nhooyr.io",
|
||
|
},
|
||
|
{
|
||
|
name: "badReader",
|
||
|
rand: func(p []byte) (int, error) {
|
||
|
return 0, io.EOF
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, tc := range testCases {
|
||
|
tc := tc
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||
|
defer cancel()
|
||
|
|
||
|
if tc.rand == nil {
|
||
|
tc.rand = rand.Reader.Read
|
||
|
}
|
||
|
|
||
|
_, _, err := dial(ctx, tc.url, tc.opts, tc.rand)
|
||
|
assert.Error(t, err)
|
||
|
})
|
||
|
}
|
||
|
})
|
||
|
|
||
|
t.Run("badResponse", func(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||
|
defer cancel()
|
||
|
|
||
|
_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
|
||
|
HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) {
|
||
|
return &http.Response{
|
||
|
Body: ioutil.NopCloser(strings.NewReader("hi")),
|
||
|
}, nil
|
||
|
}),
|
||
|
})
|
||
|
assert.Contains(t, err, "failed to WebSocket dial: expected handshake response status code 101 but got 0")
|
||
|
})
|
||
|
|
||
|
t.Run("badBody", func(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||
|
defer cancel()
|
||
|
|
||
|
rt := func(r *http.Request) (*http.Response, error) {
|
||
|
h := http.Header{}
|
||
|
h.Set("Connection", "Upgrade")
|
||
|
h.Set("Upgrade", "websocket")
|
||
|
h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
|
||
|
|
||
|
return &http.Response{
|
||
|
StatusCode: http.StatusSwitchingProtocols,
|
||
|
Header: h,
|
||
|
Body: ioutil.NopCloser(strings.NewReader("hi")),
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
|
||
|
HTTPClient: mockHTTPClient(rt),
|
||
|
})
|
||
|
assert.Contains(t, err, "response body is not a io.ReadWriteCloser")
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func Test_verifyServerHandshake(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
testCases := []struct {
|
||
|
name string
|
||
|
response func(w http.ResponseWriter)
|
||
|
success bool
|
||
|
}{
|
||
|
{
|
||
|
name: "badStatus",
|
||
|
response: func(w http.ResponseWriter) {
|
||
|
w.WriteHeader(http.StatusOK)
|
||
|
},
|
||
|
success: false,
|
||
|
},
|
||
|
{
|
||
|
name: "badConnection",
|
||
|
response: func(w http.ResponseWriter) {
|
||
|
w.Header().Set("Connection", "???")
|
||
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
||
|
},
|
||
|
success: false,
|
||
|
},
|
||
|
{
|
||
|
name: "badUpgrade",
|
||
|
response: func(w http.ResponseWriter) {
|
||
|
w.Header().Set("Connection", "Upgrade")
|
||
|
w.Header().Set("Upgrade", "???")
|
||
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
||
|
},
|
||
|
success: false,
|
||
|
},
|
||
|
{
|
||
|
name: "badSecWebSocketAccept",
|
||
|
response: func(w http.ResponseWriter) {
|
||
|
w.Header().Set("Connection", "Upgrade")
|
||
|
w.Header().Set("Upgrade", "websocket")
|
||
|
w.Header().Set("Sec-WebSocket-Accept", "xd")
|
||
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
||
|
},
|
||
|
success: false,
|
||
|
},
|
||
|
{
|
||
|
name: "badSecWebSocketProtocol",
|
||
|
response: func(w http.ResponseWriter) {
|
||
|
w.Header().Set("Connection", "Upgrade")
|
||
|
w.Header().Set("Upgrade", "websocket")
|
||
|
w.Header().Set("Sec-WebSocket-Protocol", "xd")
|
||
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
||
|
},
|
||
|
success: false,
|
||
|
},
|
||
|
{
|
||
|
name: "unsupportedExtension",
|
||
|
response: func(w http.ResponseWriter) {
|
||
|
w.Header().Set("Connection", "Upgrade")
|
||
|
w.Header().Set("Upgrade", "websocket")
|
||
|
w.Header().Set("Sec-WebSocket-Extensions", "meow")
|
||
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
||
|
},
|
||
|
success: false,
|
||
|
},
|
||
|
{
|
||
|
name: "unsupportedDeflateParam",
|
||
|
response: func(w http.ResponseWriter) {
|
||
|
w.Header().Set("Connection", "Upgrade")
|
||
|
w.Header().Set("Upgrade", "websocket")
|
||
|
w.Header().Set("Sec-WebSocket-Extensions", "permessage-deflate; meow")
|
||
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
||
|
},
|
||
|
success: false,
|
||
|
},
|
||
|
{
|
||
|
name: "success",
|
||
|
response: func(w http.ResponseWriter) {
|
||
|
w.Header().Set("Connection", "Upgrade")
|
||
|
w.Header().Set("Upgrade", "websocket")
|
||
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
||
|
},
|
||
|
success: true,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, tc := range testCases {
|
||
|
tc := tc
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
w := httptest.NewRecorder()
|
||
|
tc.response(w)
|
||
|
resp := w.Result()
|
||
|
|
||
|
r := httptest.NewRequest("GET", "/", nil)
|
||
|
key, err := secWebSocketKey(rand.Reader)
|
||
|
assert.Success(t, err)
|
||
|
r.Header.Set("Sec-WebSocket-Key", key)
|
||
|
|
||
|
if resp.Header.Get("Sec-WebSocket-Accept") == "" {
|
||
|
resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
|
||
|
}
|
||
|
|
||
|
opts := &DialOptions{
|
||
|
Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","),
|
||
|
}
|
||
|
_, err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp)
|
||
|
if tc.success {
|
||
|
assert.Success(t, err)
|
||
|
} else {
|
||
|
assert.Error(t, err)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func mockHTTPClient(fn roundTripperFunc) *http.Client {
|
||
|
return &http.Client{
|
||
|
Transport: fn,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// roundTripperFunc adapts a plain function to the http.RoundTripper
// interface, much as http.HandlerFunc adapts a function to http.Handler.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip satisfies http.RoundTripper by calling f with r and returning
// f's response and error unchanged.
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	resp, err := f(r)
	return resp, err
}
|