2022-02-06 07:06:32 +00:00

237 lines
6.8 KiB
Go

// Copyright (C) 2021 Michael J. Fromberger. All Rights Reserved.
package jhttp_test
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
"github.com/creachadair/jrpc2"
"github.com/creachadair/jrpc2/handler"
"github.com/creachadair/jrpc2/jhttp"
"github.com/fortytw2/leaktest"
"github.com/google/go-cmp/cmp"
)
// TestGetter serves a Getter wrapping a positional "concat" method over
// HTTP and checks the status code and JSON body produced for a successful
// call, an unknown method, a malformed query string, and unusable params.
func TestGetter(t *testing.T) {
	defer leaktest.Check(t)()

	concat := handler.NewPos(func(ctx context.Context, a, b string) string {
		return a + b
	}, "first", "second")
	g := jhttp.NewGetter(handler.Map{"concat": concat}, &jhttp.GetterOptions{
		Client: &jrpc2.ClientOptions{EncodeContext: checkContext},
	})
	defer checkClose(t, g)

	hsrv := httptest.NewServer(g)
	defer hsrv.Close()

	cases := []struct {
		name      string
		pathQuery string // appended to the server URL after "/"
		code      int    // expected HTTP status
		want      string // expected body (exact) or body fragment
		exact     bool   // true: compare the whole body; false: substring match
	}{
		{"OK", "concat?second=world&first=hello", http.StatusOK, `"helloworld"`, true},
		{"NotFound", "nonesuch", http.StatusNotFound, `"code":-32601`, false}, // MethodNotFound
		// N.B. invalid query string
		{"BadRequest", "concat?x%2", http.StatusBadRequest, `"code":-32700`, false}, // ParseError
		{"InternalError", "concat?third=c", http.StatusInternalServerError, `"code":-32602`, false}, // InvalidParams
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := mustGet(t, hsrv.URL+"/"+tc.pathQuery, tc.code)
			matched := strings.Contains(got, tc.want)
			if tc.exact {
				matched = got == tc.want
			}
			if !matched {
				t.Errorf("Response body: got %#q, want %#q", got, tc.want)
			}
		})
	}
}
// TestParseQuery checks jhttp.ParseQuery against a table of request URLs
// (and optional form-encoded bodies). Each case gives either the expected
// method name and parameter map, or a fragment of the expected error text.
func TestParseQuery(t *testing.T) {
	tests := []struct {
		url     string
		body    string
		method  string
		want    interface{}
		errText string
	}{
		// Error: Missing method name.
		{"http://localhost:2112/", "", "", nil, "empty URL path"},
		// No parameters.
		{"http://localhost/foo", "", "foo", nil, ""},
		// Unbalanced double-quoted string.
		{"https://fuzz.ball/foo?bad=%22xyz", "", "foo", nil, "missing string quote"},
		{"https://fuzz.ball/bar?bad=xyz%22", "", "bar", nil, "missing string quote"},
		// Unbalanced single-quoted string.
		{`http://stripe/sister?bad='invalid`, "", "sister", nil, "missing bytes quote"},
		{`http://stripe/sister?bad=invalid'`, "", "sister", nil, "missing bytes quote"},
		// Invalid byte string.
		{`http://green.as/balls?bad='NOT%20VALID'`, "", "balls", nil, "decoding bytes"},
		// Invalid double-quoted string.
		{`http://black.as/sin?bad=%22a%5Cx25%22`, "", "sin", nil, "invalid character"},
		// Valid: Single-quoted byte string (base64).
		{`http://fast.as.hell/and?twice='YXMgcHJldHR5IGFzIHlvdQ=='`,
			"", "and", map[string]interface{}{
				"twice": []byte("as pretty as you"),
			}, ""},
		// Valid: Unquoted strings and null.
		{`http://head.like/a-hole?black=as&your=null&soul`,
			"", "a-hole", map[string]interface{}{
				"black": "as",
				"your":  nil,
				"soul":  "",
			}, ""},
		// Valid: Quoted strings, numbers, Booleans.
		{`http://foo.com:1999/go/west/?alpha=%22xyz%22&bravo=3&charlie=true&delta=false&echo=3.2`,
			"", "go/west", map[string]interface{}{
				"alpha":   "xyz",
				"bravo":   int64(3),
				"charlie": true,
				"delta":   false,
				"echo":    3.2,
			}, ""},
		// Valid: Form-encoded query in the request body.
		{`http://buz.org:2013/bodyblow`,
			"alpha=%22pdq%22&bravo=-19.4&charlie=false", "bodyblow", map[string]interface{}{
				"alpha":   "pdq",
				"bravo":   float64(-19.4),
				"charlie": false,
			}, ""},
	}
	for _, test := range tests {
		// Name each subtest after its URL so a failure identifies the table
		// row; a shared constant name only yields "ParseQuery#NN" suffixes.
		t.Run(test.url, func(t *testing.T) {
			req, err := http.NewRequest(http.MethodPut, test.url, strings.NewReader(test.body))
			if err != nil {
				t.Fatalf("New request for %q failed: %v", test.url, err)
			}
			if test.body != "" {
				// The body is only treated as form input when the content type says so.
				req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
			}

			method, params, err := jhttp.ParseQuery(req)
			if err != nil {
				if test.errText == "" {
					t.Fatalf("ParseQuery failed: %v", err)
				}
				if !strings.Contains(err.Error(), test.errText) {
					t.Fatalf("ParseQuery: got error %v, want %q", err, test.errText)
				}
				return // got the expected error
			}
			if test.errText != "" {
				t.Fatalf("ParseQuery: got method %q, params %+v, wanted error %q",
					method, params, test.errText)
			}
			if method != test.method {
				t.Errorf("ParseQuery method: got %q, want %q", method, test.method)
			}
			if diff := cmp.Diff(test.want, params); diff != "" {
				t.Errorf("Wrong params: (-want, +got)\n%s", diff)
			}
		})
	}
}
// TestGetter_parseRequest checks that a Getter uses a caller-supplied
// ParseRequest hook. The hook takes the method name from the URL path minus
// a "/x/" prefix. It decodes the "tag" and "value" form fields itself.
func TestGetter_parseRequest(t *testing.T) {
	defer leaktest.Check(t)()
	mux := handler.Map{
		"format": handler.NewPos(func(ctx context.Context, a string, b int) string {
			return fmt.Sprintf("%s-%d", a, b)
		}, "tag", "value"),
	}
	g := jhttp.NewGetter(mux, &jhttp.GetterOptions{
		ParseRequest: func(req *http.Request) (string, interface{}, error) {
			if err := req.ParseForm(); err != nil {
				return "", nil, err
			}
			tag := req.Form.Get("tag")

			// A missing "value" is tolerated (it parses as 0); a present but
			// non-numeric one is rejected by the hook itself.
			raw := req.Form.Get("value")
			val, err := strconv.ParseInt(raw, 10, 64)
			if err != nil && raw != "" {
				return "", nil, fmt.Errorf("invalid number: %w", err)
			}

			// Paths without the "/x/" prefix keep their leading slash, so the
			// resulting method name does not match anything in mux.
			return strings.TrimPrefix(req.URL.Path, "/x/"), map[string]interface{}{
				"tag":   tag,
				"value": val,
			}, nil
		},
	})
	defer checkClose(t, g)
	hsrv := httptest.NewServer(g)
	defer hsrv.Close()
	url := func(pathQuery string) string {
		return hsrv.URL + "/" + pathQuery
	}
	t.Run("OK", func(t *testing.T) {
		got := mustGet(t, url("x/format?tag=foo&value=25"), http.StatusOK)
		const want = `"foo-25"`
		if got != want {
			t.Errorf("Response body: got %#q, want %#q", got, want)
		}
	})
	t.Run("NotFound", func(t *testing.T) {
		// N.B. Missing path prefix.
		got := mustGet(t, url("format"), http.StatusNotFound)
		const want = `"code":-32601` // MethodNotFound
		if !strings.Contains(got, want) {
			t.Errorf("Response body: got %#q, want %#q", got, want)
		}
	})
	t.Run("BadRequest", func(t *testing.T) {
		// N.B. "bar" is not a number, so the ParseRequest hook above returns
		// an error; the Getter reports that as a ParseError with HTTP 400.
		got := mustGet(t, url("x/format?tag=foo&value=bar"), http.StatusBadRequest)
		const want = `"code":-32700` // ParseError
		if !strings.Contains(got, want) {
			t.Errorf("Response body: got %#q, want %#q", got, want)
		}
	})
}
// mustGet sends an HTTP GET request to url and returns the response body as
// a string.
//
// It stops the test immediately if the request cannot be sent. It records a
// non-fatal failure if the status code differs from code or the body cannot
// be read. In either of those cases it still returns whatever body was read.
func mustGet(t *testing.T, url string, code int) string {
	t.Helper()
	rsp, err := http.Get(url)
	if err != nil {
		t.Fatalf("GET request failed: %v", err)
	}
	// The caller owns the body and must close it so the connection is released.
	defer rsp.Body.Close()

	if got := rsp.StatusCode; got != code {
		t.Errorf("GET response code: got %v, want %v", got, code)
	}
	body, err := io.ReadAll(rsp.Body)
	if err != nil {
		t.Errorf("Reading GET body: %v", err)
	}
	return string(body)
}