From 9300ef9b14a0b182a45e70eb3c4a7b5139b2f16a Mon Sep 17 00:00:00 2001 From: Captain Date: Wed, 1 Dec 2021 15:43:13 +0000 Subject: [PATCH] DERO-HE STARGATE Testnet Release33 --- blockchain/blockchain.go | 1 - blockchain/hardcoded_sc/nameservice.bas | 1 + blockchain/miner_block.go | 30 +- blockchain/storetopo.go | 16 +- cmd/dero-miner/miner.go | 189 +- cmd/dero-wallet-cli/prompt.go | 2 +- cmd/derod/main.go | 21 +- cmd/derod/rpc/websocket_getwork_server.go | 341 + cmd/derod/rpc/websocket_server.go | 5 +- config/config.go | 5 +- config/version.go | 2 +- rpc/daemon_rpc.go | 20 +- vendor/github.com/lesismal/llib/LICENSE | 27 + vendor/github.com/lesismal/llib/README.md | 12 + .../github.com/lesismal/llib/bytes/buffer.go | 225 + .../lesismal/llib/bytes/buffer_test.go | 108 + vendor/github.com/lesismal/llib/bytes/pool.go | 84 + .../lesismal/llib/bytes/pool_test.go | 22 + .../lesismal/llib/concurrent/batch.go | 60 + .../lesismal/llib/concurrent/batch_test.go | 34 + .../lesismal/llib/concurrent/map.go | 100 + .../lesismal/llib/concurrent/map_test.go | 109 + .../lesismal/llib/concurrent/mutex.go | 56 + .../lesismal/llib/concurrent/mutex_test.go | 26 + .../lesismal/llib/concurrent/rwmutex_test.go | 38 + .../lesismal/llib/concurrent/rwmutext.go | 88 + vendor/github.com/lesismal/llib/go.mod | 9 + vendor/github.com/lesismal/llib/go.sum | 17 + .../lesismal/llib/std/crypto/tls/alert.go | 99 + .../lesismal/llib/std/crypto/tls/auth.go | 289 + .../lesismal/llib/std/crypto/tls/auth_test.go | 168 + .../llib/std/crypto/tls/cipher_suites.go | 516 + .../lesismal/llib/std/crypto/tls/common.go | 1563 +++ .../llib/std/crypto/tls/common_string.go | 116 + .../lesismal/llib/std/crypto/tls/conn.go | 1775 +++ .../lesismal/llib/std/crypto/tls/conn_test.go | 287 + .../llib/std/crypto/tls/example_test.go | 232 + .../llib/std/crypto/tls/generate_cert.go | 172 + .../llib/std/crypto/tls/handshake_client.go | 1002 ++ .../std/crypto/tls/handshake_client_test.go | 2513 ++++ .../std/crypto/tls/handshake_client_tls13.go | 685 + .../llib/std/crypto/tls/handshake_messages.go | 1809 +++ .../std/crypto/tls/handshake_messages_test.go | 465 + .../llib/std/crypto/tls/handshake_server.go | 1106 ++ .../std/crypto/tls/handshake_server_test.go | 1941 +++ .../std/crypto/tls/handshake_server_tls13.go | 971 ++ .../llib/std/crypto/tls/handshake_test.go | 535 + .../std/crypto/tls/handshake_unix_test.go | 18 + .../llib/std/crypto/tls/key_agreement.go | 334 + .../llib/std/crypto/tls/key_schedule.go | 199 + .../llib/std/crypto/tls/key_schedule_test.go | 175 + .../lesismal/llib/std/crypto/tls/link_test.go | 108 + .../lesismal/llib/std/crypto/tls/prf.go | 283 + .../lesismal/llib/std/crypto/tls/prf_test.go | 140 + .../lesismal/llib/std/crypto/tls/ticket.go | 185 + .../lesismal/llib/std/crypto/tls/tls.go | 430 + .../lesismal/llib/std/crypto/tls/tls_test.go | 1477 +++ .../lesismal/llib/std/internal/cpu/cpu.go | 226 + .../lesismal/llib/std/internal/cpu/cpu.s | 6 + .../lesismal/llib/std/internal/cpu/cpu_386.go | 7 + .../llib/std/internal/cpu/cpu_amd64.go | 7 + .../lesismal/llib/std/internal/cpu/cpu_arm.go | 34 + .../llib/std/internal/cpu/cpu_arm64.go | 28 + .../llib/std/internal/cpu/cpu_arm64.s | 18 + .../std/internal/cpu/cpu_arm64_android.go | 11 + .../llib/std/internal/cpu/cpu_arm64_darwin.go | 34 + .../std/internal/cpu/cpu_arm64_freebsd.go | 45 + .../llib/std/internal/cpu/cpu_arm64_hwcap.go | 63 + .../llib/std/internal/cpu/cpu_arm64_linux.go | 13 + .../llib/std/internal/cpu/cpu_arm64_other.go | 17 + .../llib/std/internal/cpu/cpu_mips.go | 10 + .../llib/std/internal/cpu/cpu_mips64x.go | 32 + .../llib/std/internal/cpu/cpu_mipsle.go | 10 + .../llib/std/internal/cpu/cpu_no_name.go | 19 + .../llib/std/internal/cpu/cpu_ppc64x.go | 23 + .../llib/std/internal/cpu/cpu_ppc64x_aix.go | 21 + .../llib/std/internal/cpu/cpu_ppc64x_linux.go | 29 + .../llib/std/internal/cpu/cpu_riscv64.go | 10 + .../llib/std/internal/cpu/cpu_s390x.go | 205 + .../llib/std/internal/cpu/cpu_s390x.s | 63 + .../llib/std/internal/cpu/cpu_s390x_test.go | 63 + .../llib/std/internal/cpu/cpu_test.go | 83 + .../llib/std/internal/cpu/cpu_wasm.go | 10 + .../lesismal/llib/std/internal/cpu/cpu_x86.go | 163 + .../lesismal/llib/std/internal/cpu/cpu_x86.s | 26 + .../llib/std/internal/cpu/cpu_x86_test.go | 54 + .../llib/std/internal/cpu/export_test.go | 9 + .../llib/std/internal/nettrace/nettrace.go | 45 + .../llib/std/internal/testenv/testenv.go | 308 + .../llib/std/internal/testenv/testenv_cgo.go | 11 + .../std/internal/testenv/testenv_notwin.go | 20 + .../std/internal/testenv/testenv_windows.go | 47 + .../lesismal/llib/std/net/http/alpn_test.go | 132 + .../lesismal/llib/std/net/http/cgi/child.go | 220 + .../llib/std/net/http/cgi/child_test.go | 208 + .../lesismal/llib/std/net/http/cgi/host.go | 408 + .../llib/std/net/http/cgi/host_test.go | 578 + .../llib/std/net/http/cgi/integration_test.go | 295 + .../llib/std/net/http/cgi/plan9_test.go | 17 + .../llib/std/net/http/cgi/posix_test.go | 20 + .../llib/std/net/http/cgi/testdata/test.cgi | 95 + .../lesismal/llib/std/net/http/client.go | 1009 ++ .../lesismal/llib/std/net/http/client_test.go | 2084 ++++ .../llib/std/net/http/clientserver_test.go | 1584 +++ .../lesismal/llib/std/net/http/clone.go | 74 + .../lesismal/llib/std/net/http/cookie.go | 433 + .../lesismal/llib/std/net/http/cookie_test.go | 619 + .../http/cookiejar/dummy_publicsuffix_test.go | 21 + .../std/net/http/cookiejar/example_test.go | 65 + .../llib/std/net/http/cookiejar/jar.go | 503 + .../llib/std/net/http/cookiejar/jar_test.go | 1322 ++ .../llib/std/net/http/cookiejar/punycode.go | 159 + .../std/net/http/cookiejar/punycode_test.go | 161 + .../lesismal/llib/std/net/http/doc.go | 107 + .../std/net/http/example_filesystem_test.go | 71 + .../llib/std/net/http/example_handle_test.go | 29 + .../llib/std/net/http/example_test.go | 192 + .../lesismal/llib/std/net/http/export_test.go | 313 + .../lesismal/llib/std/net/http/fcgi/child.go | 405 + .../lesismal/llib/std/net/http/fcgi/fcgi.go | 270 + .../llib/std/net/http/fcgi/fcgi_test.go | 401 + .../llib/std/net/http/filetransport.go | 123 + .../llib/std/net/http/filetransport_test.go | 66 + .../lesismal/llib/std/net/http/fs.go | 970 ++ .../lesismal/llib/std/net/http/fs_test.go | 1414 +++ .../lesismal/llib/std/net/http/h2_bundle.go | 10371 ++++++++++++++++ .../lesismal/llib/std/net/http/header.go | 263 + .../lesismal/llib/std/net/http/header_test.go | 253 + .../lesismal/llib/std/net/http/http.go | 168 + .../lesismal/llib/std/net/http/http_test.go | 158 + .../std/net/http/httptest/example_test.go | 99 + .../llib/std/net/http/httptest/httptest.go | 90 + .../std/net/http/httptest/httptest_test.go | 179 + .../llib/std/net/http/httptest/recorder.go | 234 + .../std/net/http/httptest/recorder_test.go | 347 + .../llib/std/net/http/httptest/server.go | 383 + .../llib/std/net/http/httptest/server_test.go | 240 + .../std/net/http/httptrace/example_test.go | 29 + .../llib/std/net/http/httptrace/trace.go | 256 + .../llib/std/net/http/httptrace/trace_test.go | 89 + .../llib/std/net/http/httputil/dump.go | 340 + .../llib/std/net/http/httputil/dump_test.go | 519 + .../std/net/http/httputil/example_test.go | 123 + .../llib/std/net/http/httputil/httputil.go | 41 + .../llib/std/net/http/httputil/persist.go | 431 + .../std/net/http/httputil/reverseproxy.go | 617 + .../net/http/httputil/reverseproxy_test.go | 1420 +++ .../llib/std/net/http/internal/chunked.go | 255 + .../std/net/http/internal/chunked_test.go | 213 + .../llib/std/net/http/internal/testcert.go | 45 + .../lesismal/llib/std/net/http/jar.go | 27 + .../lesismal/llib/std/net/http/main_test.go | 171 + .../lesismal/llib/std/net/http/method.go | 20 + .../lesismal/llib/std/net/http/omithttp2.go | 71 + .../lesismal/llib/std/net/http/pprof/pprof.go | 449 + .../llib/std/net/http/pprof/pprof_test.go | 258 + .../lesismal/llib/std/net/http/proxy_test.go | 50 + .../lesismal/llib/std/net/http/range_test.go | 79 + .../llib/std/net/http/readrequest_test.go | 474 + .../lesismal/llib/std/net/http/request.go | 1456 +++ .../llib/std/net/http/request_test.go | 1235 ++ .../llib/std/net/http/requestwrite_test.go | 977 ++ .../lesismal/llib/std/net/http/response.go | 372 + .../llib/std/net/http/response_test.go | 1012 ++ .../llib/std/net/http/responsewrite_test.go | 291 + .../lesismal/llib/std/net/http/roundtrip.go | 18 + .../llib/std/net/http/roundtrip_js.go | 307 + .../lesismal/llib/std/net/http/serve_test.go | 6509 ++++++++++ .../lesismal/llib/std/net/http/server.go | 3549 ++++++ .../lesismal/llib/std/net/http/server_test.go | 45 + .../lesismal/llib/std/net/http/sniff.go | 309 + .../lesismal/llib/std/net/http/sniff_test.go | 225 + .../llib/std/net/http/socks_bundle.go | 473 + .../lesismal/llib/std/net/http/status.go | 152 + .../lesismal/llib/std/net/http/testdata/file | 1 + .../llib/std/net/http/testdata/index.html | 1 + .../llib/std/net/http/testdata/style.css | 1 + .../lesismal/llib/std/net/http/transfer.go | 1109 ++ .../llib/std/net/http/transfer_test.go | 363 + .../lesismal/llib/std/net/http/transport.go | 2896 +++++ .../std/net/http/transport_internal_test.go | 262 + .../llib/std/net/http/transport_test.go | 6485 ++++++++++ .../lesismal/llib/std/net/http/triv.go | 138 + .../github.com/lesismal/nbio/.gitattributes | 1 + .../nbio/.github/workflows/autobahn.yml | 33 + .../nbio/.github/workflows/build_bsd.yml | 39 + .../nbio/.github/workflows/build_linux.yml | 39 + .../nbio/.github/workflows/build_windows.yml | 45 + .../.github/workflows/codeql-analysis.yml | 71 + .../nbio/.github/workflows/golangci-lint.yml | 32 + vendor/github.com/lesismal/nbio/.gitignore | 1 + vendor/github.com/lesismal/nbio/.golangci.yml | 660 + vendor/github.com/lesismal/nbio/LICENSE | 21 + vendor/github.com/lesismal/nbio/Makefile | 28 + vendor/github.com/lesismal/nbio/README.md | 335 + .../lesismal/nbio/autobahn/.gitignore | 1 + .../lesismal/nbio/autobahn/README.md | 9 + .../nbio/autobahn/config/client_tests.json | 24 + .../nbio/autobahn/docker/autobahn/Dockerfile | 14 + .../nbio/autobahn/docker/server/Dockerfile | 15 + .../github.com/lesismal/nbio/autobahn/main.go | 265 + .../lesismal/nbio/autobahn/main_go18.go | 11 + .../lesismal/nbio/autobahn/script/test.sh | 127 + .../nbio/autobahn/server/autobahn_test.go | 9 + .../lesismal/nbio/autobahn/server/server.go | 166 + vendor/github.com/lesismal/nbio/conn.go | 129 + vendor/github.com/lesismal/nbio/conn_std.go | 270 + vendor/github.com/lesismal/nbio/conn_unix.go | 479 + vendor/github.com/lesismal/nbio/error.go | 15 + .../lesismal/nbio/extension/tls/tls.go | 87 + vendor/github.com/lesismal/nbio/go.mod | 5 + vendor/github.com/lesismal/nbio/go.sum | 16 + vendor/github.com/lesismal/nbio/gopher.go | 459 + vendor/github.com/lesismal/nbio/gopher_std.go | 114 + .../github.com/lesismal/nbio/gopher_unix.go | 122 + .../github.com/lesismal/nbio/logging/log.go | 132 + .../lesismal/nbio/logging/log_test.go | 62 + .../lesismal/nbio/mempool/mempool.go | 184 + .../lesismal/nbio/mempool/mempool_test.go | 33 + .../github.com/lesismal/nbio/nbhttp/body.go | 99 + .../github.com/lesismal/nbio/nbhttp/client.go | 499 + .../lesismal/nbio/nbhttp/client_conn.go | 305 + .../github.com/lesismal/nbio/nbhttp/engine.go | 713 ++ .../github.com/lesismal/nbio/nbhttp/error.go | 94 + .../github.com/lesismal/nbio/nbhttp/parser.go | 798 ++ .../lesismal/nbio/nbhttp/parser_test.go | 272 + .../lesismal/nbio/nbhttp/processor.go | 446 + .../lesismal/nbio/nbhttp/response.go | 389 + .../github.com/lesismal/nbio/nbhttp/server.go | 51 + .../github.com/lesismal/nbio/nbhttp/state.go | 56 + .../github.com/lesismal/nbio/nbhttp/table.go | 174 + .../lesismal/nbio/nbhttp/tests/poller_test.go | 206 + .../lesismal/nbio/nbhttp/upgrader.go | 23 + .../nbio/nbhttp/websocket/compression.go | 122 + .../lesismal/nbio/nbhttp/websocket/conn.go | 312 + .../lesismal/nbio/nbhttp/websocket/dialer.go | 270 + .../lesismal/nbio/nbhttp/websocket/error.go | 65 + .../nbio/nbhttp/websocket/upgrader.go | 894 ++ .../nbio/nbhttp/websocket/upgrader_test.go | 37 + vendor/github.com/lesismal/nbio/nbio_test.go | 366 + vendor/github.com/lesismal/nbio/net_unix.go | 54 + .../github.com/lesismal/nbio/poller_epoll.go | 307 + .../github.com/lesismal/nbio/poller_kqueue.go | 276 + vendor/github.com/lesismal/nbio/poller_std.go | 144 + .../github.com/lesismal/nbio/sendfile_bsd.go | 62 + .../lesismal/nbio/sendfile_linux.go | 98 + .../github.com/lesismal/nbio/sendfile_std.go | 62 + .../lesismal/nbio/taskpool/caller.go | 24 + .../nbio/taskpool/fixednoorderpool.go | 44 + .../lesismal/nbio/taskpool/fixedpool.go | 124 + .../lesismal/nbio/taskpool/mixedpool.go | 84 + .../lesismal/nbio/taskpool/taskpool.go | 110 + .../lesismal/nbio/taskpool/taskpool_test.go | 139 + vendor/github.com/lesismal/nbio/timer_heap.go | 62 + vendor/golang.org/x/crypto/.gitignore | 2 +- vendor/golang.org/x/crypto/README.md | 2 + vendor/golang.org/x/crypto/acme/acme.go | 778 +- vendor/golang.org/x/crypto/acme/acme_test.go | 756 +- .../x/crypto/acme/autocert/autocert.go | 794 +- .../x/crypto/acme/autocert/autocert_test.go | 826 +- .../x/crypto/acme/autocert/cache.go | 14 +- .../x/crypto/acme/autocert/cache_test.go | 9 + .../x/crypto/acme/autocert/example_test.go | 7 +- .../acme/autocert/internal/acmetest/ca.go | 552 + .../x/crypto/acme/autocert/listener.go | 7 +- .../x/crypto/acme/autocert/renewal.go | 45 +- .../x/crypto/acme/autocert/renewal_test.go | 181 +- vendor/golang.org/x/crypto/acme/http.go | 325 + vendor/golang.org/x/crypto/acme/http_test.go | 255 + .../crypto/acme/internal/acmeprobe/prober.go | 480 + vendor/golang.org/x/crypto/acme/jws.go | 140 +- vendor/golang.org/x/crypto/acme/jws_test.go | 229 +- vendor/golang.org/x/crypto/acme/rfc8555.go | 438 + .../golang.org/x/crypto/acme/rfc8555_test.go | 916 ++ vendor/golang.org/x/crypto/acme/types.go | 395 +- vendor/golang.org/x/crypto/acme/types_test.go | 156 + .../golang.org/x/crypto/acme/version_go112.go | 28 + vendor/golang.org/x/crypto/argon2/argon2.go | 124 +- .../golang.org/x/crypto/argon2/argon2_test.go | 120 + .../x/crypto/argon2/blamka_amd64.go | 12 +- .../golang.org/x/crypto/argon2/blamka_amd64.s | 12 +- .../golang.org/x/crypto/argon2/blamka_ref.go | 3 +- .../golang.org/x/crypto/bcrypt/bcrypt_test.go | 6 +- vendor/golang.org/x/crypto/blake2b/blake2b.go | 78 +- .../x/crypto/blake2b/blake2bAVX2_amd64.go | 29 +- .../x/crypto/blake2b/blake2bAVX2_amd64.s | 109 +- .../x/crypto/blake2b/blake2b_amd64.go | 12 +- .../x/crypto/blake2b/blake2b_amd64.s | 67 +- .../x/crypto/blake2b/blake2b_generic.go | 69 +- .../x/crypto/blake2b/blake2b_ref.go | 3 +- .../x/crypto/blake2b/blake2b_test.go | 49 + vendor/golang.org/x/crypto/blake2b/blake2x.go | 2 +- .../golang.org/x/crypto/blake2b/register.go | 1 + vendor/golang.org/x/crypto/blake2s/blake2s.go | 59 + .../x/crypto/blake2s/blake2s_386.go | 22 +- .../golang.org/x/crypto/blake2s/blake2s_386.s | 118 +- .../x/crypto/blake2s/blake2s_amd64.go | 26 +- .../x/crypto/blake2s/blake2s_amd64.s | 94 +- .../x/crypto/blake2s/blake2s_generic.go | 68 +- .../x/crypto/blake2s/blake2s_ref.go | 3 +- .../x/crypto/blake2s/blake2s_test.go | 48 + vendor/golang.org/x/crypto/blake2s/blake2x.go | 2 +- .../golang.org/x/crypto/blake2s/register.go | 1 + vendor/golang.org/x/crypto/blowfish/cipher.go | 8 + vendor/golang.org/x/crypto/bn256/bn256.go | 53 +- vendor/golang.org/x/crypto/bn256/curve.go | 9 + vendor/golang.org/x/crypto/bn256/gfp12.go | 4 +- vendor/golang.org/x/crypto/bn256/twist.go | 9 + vendor/golang.org/x/crypto/cast5/cast5.go | 11 +- .../x/crypto/chacha20/chacha_arm64.go | 4 +- .../x/crypto/chacha20/chacha_arm64.s | 4 +- .../x/crypto/chacha20/chacha_generic.go | 140 +- .../x/crypto/chacha20/chacha_noasm.go | 3 +- .../x/crypto/chacha20/chacha_ppc64le.go | 3 +- .../x/crypto/chacha20/chacha_ppc64le.s | 3 +- .../x/crypto/chacha20/chacha_s390x.go | 3 +- .../x/crypto/chacha20/chacha_s390x.s | 3 +- .../x/crypto/chacha20/chacha_test.go | 103 +- vendor/golang.org/x/crypto/chacha20/xor.go | 17 +- .../chacha20poly1305/chacha20poly1305.go | 25 +- .../chacha20poly1305_amd64.go | 78 +- .../chacha20poly1305/chacha20poly1305_amd64.s | 26 +- .../chacha20poly1305_generic.go | 79 +- .../chacha20poly1305_noasm.go | 3 +- .../chacha20poly1305/chacha20poly1305_test.go | 236 +- .../chacha20poly1305_vectors_test.go | 399 +- .../chacha20poly1305/xchacha20poly1305.go | 86 + vendor/golang.org/x/crypto/cryptobyte/asn1.go | 128 +- .../x/crypto/cryptobyte/asn1_test.go | 105 + .../golang.org/x/crypto/cryptobyte/builder.go | 38 +- .../x/crypto/cryptobyte/cryptobyte_test.go | 98 +- .../golang.org/x/crypto/cryptobyte/string.go | 36 +- .../x/crypto/curve25519/curve25519.go | 933 +- .../x/crypto/curve25519/curve25519_test.go | 131 +- .../x/crypto/curve25519/internal/field/README | 7 + .../internal/field/_asm/fe_amd64_asm.go | 298 + .../curve25519/internal/field/_asm/go.mod | 10 + .../curve25519/internal/field/_asm/go.sum | 34 + .../x/crypto/curve25519/internal/field/fe.go | 416 + .../internal/field/fe_alias_test.go | 126 + .../curve25519/internal/field/fe_amd64.go | 13 + .../curve25519/internal/field/fe_amd64.s | 379 + .../internal/field/fe_amd64_noasm.go | 12 + .../curve25519/internal/field/fe_arm64.go | 16 + .../curve25519/internal/field/fe_arm64.s | 43 + .../internal/field/fe_arm64_noasm.go | 12 + .../internal/field/fe_bench_test.go | 36 + .../curve25519/internal/field/fe_generic.go | 264 + .../curve25519/internal/field/fe_test.go | 558 + .../curve25519/internal/field/sync.checkpoint | 1 + .../crypto/curve25519/internal/field/sync.sh | 19 + .../x/crypto/curve25519/vectors_test.go | 105 + vendor/golang.org/x/crypto/ed25519/ed25519.go | 66 +- .../x/crypto/ed25519/ed25519_go113.go | 74 + .../x/crypto/ed25519/ed25519_test.go | 78 +- .../golang.org/x/crypto/ed25519/go113_test.go | 25 + .../internal/edwards25519/edwards25519.go | 22 + vendor/golang.org/x/crypto/go.mod | 9 +- vendor/golang.org/x/crypto/go.sum | 19 +- .../golang.org/x/crypto/hkdf/example_test.go | 51 +- vendor/golang.org/x/crypto/hkdf/hkdf.go | 62 +- vendor/golang.org/x/crypto/hkdf/hkdf_test.go | 81 +- .../x/crypto/internal/poly1305/bits_compat.go | 40 + .../x/crypto/internal/poly1305/bits_go1.13.go | 22 + .../x/crypto/internal/poly1305/mac_noasm.go | 10 + .../x/crypto/internal/poly1305/poly1305.go | 99 + .../crypto/internal/poly1305/poly1305_test.go | 276 + .../x/crypto/internal/poly1305/sum_amd64.go | 48 + .../x/crypto/internal/poly1305/sum_amd64.s | 109 + .../x/crypto/internal/poly1305/sum_generic.go | 310 + .../x/crypto/internal/poly1305/sum_ppc64le.go | 48 + .../x/crypto/internal/poly1305/sum_ppc64le.s | 182 + .../x/crypto/internal/poly1305/sum_s390x.go | 76 + .../x/crypto/internal/poly1305/sum_s390x.s | 504 + .../crypto/internal/poly1305/vectors_test.go | 3000 +++++ .../x/crypto/internal/subtle/aliasing.go | 3 +- .../crypto/internal/subtle/aliasing_purego.go | 36 + .../internal/wycheproof/ecdsa_compat_test.go | 1 + .../internal/wycheproof/ecdsa_go115_test.go | 1 + .../crypto/internal/wycheproof/eddsa_test.go | 1 + .../wycheproof/rsa_oaep_decrypt_test.go | 149 + .../internal/wycheproof/rsa_pss_test.go | 2 +- vendor/golang.org/x/crypto/md4/md4.go | 4 + vendor/golang.org/x/crypto/nacl/auth/auth.go | 2 +- vendor/golang.org/x/crypto/nacl/box/box.go | 83 +- .../golang.org/x/crypto/nacl/box/box_test.go | 103 + .../x/crypto/nacl/secretbox/secretbox.go | 11 +- vendor/golang.org/x/crypto/nacl/sign/sign.go | 90 + .../x/crypto/nacl/sign/sign_test.go | 74 + vendor/golang.org/x/crypto/ocsp/ocsp.go | 43 +- vendor/golang.org/x/crypto/ocsp/ocsp_test.go | 413 +- .../x/crypto/openpgp/armor/armor.go | 27 +- .../x/crypto/openpgp/clearsign/clearsign.go | 118 +- .../openpgp/clearsign/clearsign_test.go | 210 +- .../x/crypto/openpgp/elgamal/elgamal.go | 10 +- .../x/crypto/openpgp/elgamal/elgamal_test.go | 15 + .../x/crypto/openpgp/errors/errors.go | 6 + vendor/golang.org/x/crypto/openpgp/keys.go | 179 +- .../x/crypto/openpgp/keys_data_test.go | 200 + .../golang.org/x/crypto/openpgp/keys_test.go | 268 +- .../x/crypto/openpgp/packet/encrypted_key.go | 11 +- .../openpgp/packet/encrypted_key_test.go | 80 +- .../x/crypto/openpgp/packet/opaque_test.go | 2 +- .../x/crypto/openpgp/packet/packet.go | 119 +- .../x/crypto/openpgp/packet/packet_test.go | 38 +- .../x/crypto/openpgp/packet/private_key.go | 33 +- .../crypto/openpgp/packet/private_key_test.go | 27 +- .../x/crypto/openpgp/packet/public_key.go | 11 +- .../crypto/openpgp/packet/public_key_test.go | 26 + .../x/crypto/openpgp/packet/signature.go | 2 +- .../x/crypto/openpgp/packet/userattribute.go | 2 +- vendor/golang.org/x/crypto/openpgp/read.go | 6 + .../golang.org/x/crypto/openpgp/read_test.go | 2 +- vendor/golang.org/x/crypto/openpgp/s2k/s2k.go | 6 + vendor/golang.org/x/crypto/openpgp/write.go | 176 +- .../golang.org/x/crypto/openpgp/write_test.go | 89 + vendor/golang.org/x/crypto/otr/otr.go | 6 +- vendor/golang.org/x/crypto/pkcs12/pkcs12.go | 20 +- .../golang.org/x/crypto/pkcs12/pkcs12_test.go | 3 + .../x/crypto/poly1305/poly1305_compat.go | 91 + .../x/crypto/ripemd160/ripemd160.go | 6 +- .../x/crypto/ripemd160/ripemd160_test.go | 18 +- .../x/crypto/ripemd160/ripemd160block.go | 64 +- .../x/crypto/salsa20/salsa/salsa20_amd64.go | 3 +- .../x/crypto/salsa20/salsa/salsa20_amd64.s | 238 +- .../salsa20/salsa/salsa20_amd64_test.go | 3 +- .../x/crypto/salsa20/salsa/salsa20_noasm.go | 3 +- vendor/golang.org/x/crypto/scrypt/scrypt.go | 124 +- .../x/crypto/sha3/hashes_generic.go | 3 +- vendor/golang.org/x/crypto/sha3/keccakf.go | 3 +- .../golang.org/x/crypto/sha3/keccakf_amd64.go | 3 +- .../golang.org/x/crypto/sha3/keccakf_amd64.s | 3 +- vendor/golang.org/x/crypto/sha3/register.go | 1 + vendor/golang.org/x/crypto/sha3/sha3_s390x.go | 3 +- vendor/golang.org/x/crypto/sha3/sha3_s390x.s | 3 +- vendor/golang.org/x/crypto/sha3/sha3_test.go | 9 +- .../golang.org/x/crypto/sha3/shake_generic.go | 3 +- vendor/golang.org/x/crypto/sha3/xor.go | 3 +- .../golang.org/x/crypto/sha3/xor_generic.go | 2 +- .../golang.org/x/crypto/sha3/xor_unaligned.go | 5 +- .../golang.org/x/crypto/ssh/agent/client.go | 150 +- .../x/crypto/ssh/agent/client_test.go | 205 +- .../x/crypto/ssh/agent/example_test.go | 18 +- .../golang.org/x/crypto/ssh/agent/keyring.go | 30 +- .../golang.org/x/crypto/ssh/agent/server.go | 49 +- .../golang.org/x/crypto/ssh/benchmark_test.go | 6 +- vendor/golang.org/x/crypto/ssh/certs.go | 83 +- vendor/golang.org/x/crypto/ssh/certs_test.go | 78 +- vendor/golang.org/x/crypto/ssh/cipher.go | 276 +- vendor/golang.org/x/crypto/ssh/cipher_test.go | 94 +- vendor/golang.org/x/crypto/ssh/client.go | 23 +- vendor/golang.org/x/crypto/ssh/client_auth.go | 229 +- .../x/crypto/ssh/client_auth_test.go | 288 +- vendor/golang.org/x/crypto/ssh/client_test.go | 182 +- vendor/golang.org/x/crypto/ssh/common.go | 83 +- vendor/golang.org/x/crypto/ssh/common_test.go | 176 + .../golang.org/x/crypto/ssh/example_test.go | 3 +- vendor/golang.org/x/crypto/ssh/handshake.go | 34 +- .../golang.org/x/crypto/ssh/handshake_test.go | 3 + .../ssh/internal/bcrypt_pbkdf/bcrypt_pbkdf.go | 93 + .../bcrypt_pbkdf/bcrypt_pbkdf_test.go | 97 + vendor/golang.org/x/crypto/ssh/kex.go | 244 +- vendor/golang.org/x/crypto/ssh/kex_test.go | 65 +- vendor/golang.org/x/crypto/ssh/keys.go | 590 +- vendor/golang.org/x/crypto/ssh/keys_test.go | 210 +- .../x/crypto/ssh/knownhosts/knownhosts.go | 62 +- .../crypto/ssh/knownhosts/knownhosts_test.go | 29 +- vendor/golang.org/x/crypto/ssh/messages.go | 100 + vendor/golang.org/x/crypto/ssh/mux.go | 23 +- vendor/golang.org/x/crypto/ssh/mux_test.go | 220 +- vendor/golang.org/x/crypto/ssh/server.go | 156 +- .../golang.org/x/crypto/ssh/session_test.go | 20 +- vendor/golang.org/x/crypto/ssh/ssh_gss.go | 139 + .../golang.org/x/crypto/ssh/ssh_gss_test.go | 109 + vendor/golang.org/x/crypto/ssh/streamlocal.go | 1 + vendor/golang.org/x/crypto/ssh/tcpip.go | 9 + .../x/crypto/ssh/terminal/terminal.go | 953 +- .../x/crypto/ssh/test/agent_unix_test.go | 3 +- .../x/crypto/ssh/test/banner_test.go | 3 +- .../golang.org/x/crypto/ssh/test/cert_test.go | 3 +- .../x/crypto/ssh/test/dial_unix_test.go | 7 +- .../x/crypto/ssh/test/forward_unix_test.go | 20 +- .../x/crypto/ssh/test/multi_auth_test.go | 145 + .../x/crypto/ssh/test/session_test.go | 174 +- .../x/crypto/ssh/test/sshd_test_pw.c | 173 + .../x/crypto/ssh/test/test_unix_test.go | 91 +- .../golang.org/x/crypto/ssh/testdata/keys.go | 156 +- .../golang.org/x/crypto/ssh/testdata_test.go | 8 + vendor/golang.org/x/crypto/ssh/transport.go | 70 +- .../golang.org/x/crypto/ssh/transport_test.go | 14 +- vendor/golang.org/x/crypto/tea/cipher.go | 8 + vendor/golang.org/x/crypto/twofish/twofish.go | 6 + vendor/golang.org/x/crypto/xtea/block.go | 2 +- vendor/golang.org/x/crypto/xtea/cipher.go | 12 +- vendor/golang.org/x/crypto/xts/xts.go | 45 +- vendor/golang.org/x/crypto/xts/xts_test.go | 16 + walletapi/daemon_communication.go | 122 +- 497 files changed, 118941 insertions(+), 5537 deletions(-) create mode 100644 cmd/derod/rpc/websocket_getwork_server.go create mode 100644 vendor/github.com/lesismal/llib/LICENSE create mode 100644 vendor/github.com/lesismal/llib/README.md create mode 100644 vendor/github.com/lesismal/llib/bytes/buffer.go create mode 100644 vendor/github.com/lesismal/llib/bytes/buffer_test.go create mode 100644 vendor/github.com/lesismal/llib/bytes/pool.go create mode 100644 vendor/github.com/lesismal/llib/bytes/pool_test.go create mode 100644 vendor/github.com/lesismal/llib/concurrent/batch.go create mode 100644 vendor/github.com/lesismal/llib/concurrent/batch_test.go create mode 100644 vendor/github.com/lesismal/llib/concurrent/map.go create mode 100644 vendor/github.com/lesismal/llib/concurrent/map_test.go create mode 100644 vendor/github.com/lesismal/llib/concurrent/mutex.go create mode 100644 vendor/github.com/lesismal/llib/concurrent/mutex_test.go create mode 100644 vendor/github.com/lesismal/llib/concurrent/rwmutex_test.go create mode 100644 vendor/github.com/lesismal/llib/concurrent/rwmutext.go create mode 100644 vendor/github.com/lesismal/llib/go.mod create mode 100644 vendor/github.com/lesismal/llib/go.sum create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/alert.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/auth.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/auth_test.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/cipher_suites.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/common.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/common_string.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/conn.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/conn_test.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/example_test.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/generate_cert.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/handshake_client.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/handshake_client_test.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/handshake_client_tls13.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/handshake_messages.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/handshake_messages_test.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/handshake_server.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/handshake_server_test.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/handshake_server_tls13.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/handshake_test.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/handshake_unix_test.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/key_agreement.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/key_schedule.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/key_schedule_test.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/link_test.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/prf.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/prf_test.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/ticket.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/tls.go create mode 100644 vendor/github.com/lesismal/llib/std/crypto/tls/tls_test.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu.s create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_386.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_amd64.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64.s create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_android.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_darwin.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_freebsd.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_hwcap.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_linux.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_other.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_mips.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_mips64x.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_mipsle.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_no_name.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_ppc64x.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_ppc64x_aix.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_ppc64x_linux.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_riscv64.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_s390x.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_s390x.s create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_s390x_test.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_test.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_wasm.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_x86.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_x86.s create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/cpu_x86_test.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/cpu/export_test.go create mode 100755 vendor/github.com/lesismal/llib/std/internal/nettrace/nettrace.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/testenv/testenv.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/testenv/testenv_cgo.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/testenv/testenv_notwin.go create mode 100644 vendor/github.com/lesismal/llib/std/internal/testenv/testenv_windows.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/alpn_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/cgi/child.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/cgi/child_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/cgi/host.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/cgi/host_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/cgi/integration_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/cgi/plan9_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/cgi/posix_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/cgi/testdata/test.cgi create mode 100644 vendor/github.com/lesismal/llib/std/net/http/client.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/client_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/clientserver_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/clone.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/cookie.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/cookie_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/cookiejar/dummy_publicsuffix_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/cookiejar/example_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/cookiejar/jar.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/cookiejar/jar_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/cookiejar/punycode.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/cookiejar/punycode_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/doc.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/example_filesystem_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/example_handle_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/example_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/export_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/fcgi/child.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/fcgi/fcgi.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/fcgi/fcgi_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/filetransport.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/filetransport_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/fs.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/fs_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/h2_bundle.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/header.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/header_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/http.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/http_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/httptest/example_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/httptest/httptest.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/httptest/httptest_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/httptest/recorder.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/httptest/recorder_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/httptest/server.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/httptest/server_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/httptrace/example_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/httptrace/trace.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/httptrace/trace_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/httputil/dump.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/httputil/dump_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/httputil/example_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/httputil/httputil.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/httputil/persist.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/httputil/reverseproxy.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/httputil/reverseproxy_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/internal/chunked.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/internal/chunked_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/internal/testcert.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/jar.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/main_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/method.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/omithttp2.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/pprof/pprof.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/pprof/pprof_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/proxy_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/range_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/readrequest_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/request.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/request_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/requestwrite_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/response.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/response_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/responsewrite_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/roundtrip.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/roundtrip_js.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/serve_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/server.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/server_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/sniff.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/sniff_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/socks_bundle.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/status.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/testdata/file create mode 100644 vendor/github.com/lesismal/llib/std/net/http/testdata/index.html create mode 100644 vendor/github.com/lesismal/llib/std/net/http/testdata/style.css create mode 100644 vendor/github.com/lesismal/llib/std/net/http/transfer.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/transfer_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/transport.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/transport_internal_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/transport_test.go create mode 100644 vendor/github.com/lesismal/llib/std/net/http/triv.go create mode 100644 vendor/github.com/lesismal/nbio/.gitattributes create mode 100644 vendor/github.com/lesismal/nbio/.github/workflows/autobahn.yml create mode 100644 vendor/github.com/lesismal/nbio/.github/workflows/build_bsd.yml create mode 100644 vendor/github.com/lesismal/nbio/.github/workflows/build_linux.yml create mode 100644 vendor/github.com/lesismal/nbio/.github/workflows/build_windows.yml create mode 100644 vendor/github.com/lesismal/nbio/.github/workflows/codeql-analysis.yml create mode 100644 vendor/github.com/lesismal/nbio/.github/workflows/golangci-lint.yml create mode 100644 vendor/github.com/lesismal/nbio/.gitignore create mode 100644 vendor/github.com/lesismal/nbio/.golangci.yml create mode 100644 vendor/github.com/lesismal/nbio/LICENSE create mode 100644 vendor/github.com/lesismal/nbio/Makefile create mode 100644 vendor/github.com/lesismal/nbio/README.md create mode 100644 vendor/github.com/lesismal/nbio/autobahn/.gitignore create mode 100644 vendor/github.com/lesismal/nbio/autobahn/README.md create mode 100644 vendor/github.com/lesismal/nbio/autobahn/config/client_tests.json create mode 100644 vendor/github.com/lesismal/nbio/autobahn/docker/autobahn/Dockerfile create mode 100644 vendor/github.com/lesismal/nbio/autobahn/docker/server/Dockerfile create mode 100644 vendor/github.com/lesismal/nbio/autobahn/main.go create mode 100644 vendor/github.com/lesismal/nbio/autobahn/main_go18.go create mode 100755 vendor/github.com/lesismal/nbio/autobahn/script/test.sh create mode 100644 vendor/github.com/lesismal/nbio/autobahn/server/autobahn_test.go create mode 100644 vendor/github.com/lesismal/nbio/autobahn/server/server.go create mode 100644 vendor/github.com/lesismal/nbio/conn.go create mode 100644 vendor/github.com/lesismal/nbio/conn_std.go create mode 100644 vendor/github.com/lesismal/nbio/conn_unix.go create mode 100644 vendor/github.com/lesismal/nbio/error.go create mode 100644 vendor/github.com/lesismal/nbio/extension/tls/tls.go create mode 100644 vendor/github.com/lesismal/nbio/go.mod create mode 100644 vendor/github.com/lesismal/nbio/go.sum create mode 100644 vendor/github.com/lesismal/nbio/gopher.go create mode 100644 vendor/github.com/lesismal/nbio/gopher_std.go create mode 100644 vendor/github.com/lesismal/nbio/gopher_unix.go create mode 100644 vendor/github.com/lesismal/nbio/logging/log.go create mode 100644 vendor/github.com/lesismal/nbio/logging/log_test.go create mode 100644 vendor/github.com/lesismal/nbio/mempool/mempool.go create mode 100644 vendor/github.com/lesismal/nbio/mempool/mempool_test.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/body.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/client.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/client_conn.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/engine.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/error.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/parser.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/parser_test.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/processor.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/response.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/server.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/state.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/table.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/tests/poller_test.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/upgrader.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/websocket/compression.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/websocket/conn.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/websocket/dialer.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/websocket/error.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/websocket/upgrader.go create mode 100644 vendor/github.com/lesismal/nbio/nbhttp/websocket/upgrader_test.go create mode 100644 vendor/github.com/lesismal/nbio/nbio_test.go create mode 100644 vendor/github.com/lesismal/nbio/net_unix.go create mode 100644 vendor/github.com/lesismal/nbio/poller_epoll.go create mode 100644 vendor/github.com/lesismal/nbio/poller_kqueue.go create mode 100644 vendor/github.com/lesismal/nbio/poller_std.go create mode 100644 vendor/github.com/lesismal/nbio/sendfile_bsd.go create mode 100644 vendor/github.com/lesismal/nbio/sendfile_linux.go create mode 100644 vendor/github.com/lesismal/nbio/sendfile_std.go create mode 100644 vendor/github.com/lesismal/nbio/taskpool/caller.go create mode 100644 vendor/github.com/lesismal/nbio/taskpool/fixednoorderpool.go create mode 100644 vendor/github.com/lesismal/nbio/taskpool/fixedpool.go create mode 100644 vendor/github.com/lesismal/nbio/taskpool/mixedpool.go create mode 100644 vendor/github.com/lesismal/nbio/taskpool/taskpool.go create mode 100644 vendor/github.com/lesismal/nbio/taskpool/taskpool_test.go create mode 100644 vendor/github.com/lesismal/nbio/timer_heap.go create mode 100644 vendor/golang.org/x/crypto/acme/autocert/internal/acmetest/ca.go create mode 100644 vendor/golang.org/x/crypto/acme/http.go create mode 100644 vendor/golang.org/x/crypto/acme/http_test.go create mode 100644 vendor/golang.org/x/crypto/acme/internal/acmeprobe/prober.go create mode 100644 vendor/golang.org/x/crypto/acme/rfc8555.go create mode 100644 vendor/golang.org/x/crypto/acme/rfc8555_test.go create mode 100644 vendor/golang.org/x/crypto/acme/version_go112.go create mode 100644 vendor/golang.org/x/crypto/chacha20poly1305/xchacha20poly1305.go create mode 100644 vendor/golang.org/x/crypto/curve25519/internal/field/README create mode 100644 vendor/golang.org/x/crypto/curve25519/internal/field/_asm/fe_amd64_asm.go create mode 100644 vendor/golang.org/x/crypto/curve25519/internal/field/_asm/go.mod create mode 100644 vendor/golang.org/x/crypto/curve25519/internal/field/_asm/go.sum create mode 100644 vendor/golang.org/x/crypto/curve25519/internal/field/fe.go create mode 100644 vendor/golang.org/x/crypto/curve25519/internal/field/fe_alias_test.go create mode 100644 vendor/golang.org/x/crypto/curve25519/internal/field/fe_amd64.go create mode 100644 vendor/golang.org/x/crypto/curve25519/internal/field/fe_amd64.s create mode 100644 vendor/golang.org/x/crypto/curve25519/internal/field/fe_amd64_noasm.go create mode 100644 vendor/golang.org/x/crypto/curve25519/internal/field/fe_arm64.go create mode 100644 vendor/golang.org/x/crypto/curve25519/internal/field/fe_arm64.s create mode 100644 vendor/golang.org/x/crypto/curve25519/internal/field/fe_arm64_noasm.go create mode 100644 vendor/golang.org/x/crypto/curve25519/internal/field/fe_bench_test.go create mode 100644 vendor/golang.org/x/crypto/curve25519/internal/field/fe_generic.go create mode 100644 vendor/golang.org/x/crypto/curve25519/internal/field/fe_test.go create mode 100644 vendor/golang.org/x/crypto/curve25519/internal/field/sync.checkpoint create mode 100755 vendor/golang.org/x/crypto/curve25519/internal/field/sync.sh create mode 100644 vendor/golang.org/x/crypto/curve25519/vectors_test.go create mode 100644 vendor/golang.org/x/crypto/ed25519/ed25519_go113.go create mode 100644 vendor/golang.org/x/crypto/ed25519/go113_test.go create mode 100644 vendor/golang.org/x/crypto/internal/poly1305/bits_compat.go create mode 100644 vendor/golang.org/x/crypto/internal/poly1305/bits_go1.13.go create mode 100644 vendor/golang.org/x/crypto/internal/poly1305/mac_noasm.go create mode 100644 vendor/golang.org/x/crypto/internal/poly1305/poly1305.go create mode 100644 vendor/golang.org/x/crypto/internal/poly1305/poly1305_test.go create mode 100644 vendor/golang.org/x/crypto/internal/poly1305/sum_amd64.go create mode 100644 vendor/golang.org/x/crypto/internal/poly1305/sum_amd64.s create mode 100644 vendor/golang.org/x/crypto/internal/poly1305/sum_generic.go create mode 100644 vendor/golang.org/x/crypto/internal/poly1305/sum_ppc64le.go create mode 100644 vendor/golang.org/x/crypto/internal/poly1305/sum_ppc64le.s create mode 100644 vendor/golang.org/x/crypto/internal/poly1305/sum_s390x.go create mode 100644 vendor/golang.org/x/crypto/internal/poly1305/sum_s390x.s create mode 100644 vendor/golang.org/x/crypto/internal/poly1305/vectors_test.go create mode 100644 vendor/golang.org/x/crypto/internal/subtle/aliasing_purego.go create mode 100644 vendor/golang.org/x/crypto/internal/wycheproof/rsa_oaep_decrypt_test.go create mode 100644 vendor/golang.org/x/crypto/nacl/sign/sign.go create mode 100644 vendor/golang.org/x/crypto/nacl/sign/sign_test.go create mode 100644 vendor/golang.org/x/crypto/openpgp/keys_data_test.go create mode 100644 vendor/golang.org/x/crypto/poly1305/poly1305_compat.go create mode 100644 vendor/golang.org/x/crypto/ssh/common_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/internal/bcrypt_pbkdf/bcrypt_pbkdf.go create mode 100644 vendor/golang.org/x/crypto/ssh/internal/bcrypt_pbkdf/bcrypt_pbkdf_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/ssh_gss.go create mode 100644 vendor/golang.org/x/crypto/ssh/ssh_gss_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/test/multi_auth_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/test/sshd_test_pw.c diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index 1adec41..b620707 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -1283,7 +1283,6 @@ func (chain *Blockchain) IS_TX_Valid(txhash crypto.Hash) (valid_blid crypto.Hash for _, bltxhash := range bl.Tx_hashes { if bltxhash == txhash { exist_list = append(exist_list, blid) - //break , this is removed so as this case can be tested well } } } diff --git a/blockchain/hardcoded_sc/nameservice.bas b/blockchain/hardcoded_sc/nameservice.bas index 8f7d60e..cb91109 100644 --- a/blockchain/hardcoded_sc/nameservice.bas +++ b/blockchain/hardcoded_sc/nameservice.bas @@ -11,6 +11,7 @@ // Register a name, limit names of 5 or less length Function Register(name String) Uint64 10 IF EXISTS(name) THEN GOTO 50 // if name is already used, it cannot reregistered + 15 IF STRLEN(name) >= 64 THEN GOTO 50 // skip names misuse 20 IF STRLEN(name) >= 6 THEN GOTO 40 30 IF SIGNER() == address_raw("deto1qyvyeyzrcm2fzf6kyq7egkes2ufgny5xn77y6typhfx9s7w3mvyd5qqynr5hx") THEN GOTO 40 35 IF SIGNER() != address_raw("deto1qy0ehnqjpr0wxqnknyc66du2fsxyktppkr8m8e6jvplp954klfjz2qqdzcd8p") THEN GOTO 50 diff --git a/blockchain/miner_block.go b/blockchain/miner_block.go index f750bce..d3797ac 100644 --- a/blockchain/miner_block.go +++ b/blockchain/miner_block.go @@ -249,24 +249,24 @@ func (chain *Blockchain) Create_new_miner_block(miner_address rpc.Address) (cbl } if tx.IsProofRequired() && len(bl.Tips) == 2 { - if tx.BLID == bl.Tips[0] || tx.BLID == bl.Tips[1] { + if tx.BLID == bl.Tips[0] || tx.BLID == bl.Tips[1] { // delay txs by a block if they would collide logger.V(8).Info("not selecting tx due to probable collision", "txid", tx_hash_list_sorted[i].Hash) continue } - } else { - version, err := chain.ReadBlockSnapshotVersion(tx.BLID) - if err != nil { - continue - } - hash, err := chain.Load_Merkle_Hash(version) - if err != nil { - continue - } + } - if hash != tx.Payloads[0].Statement.Roothash { - //return fmt.Errorf("Tx statement roothash mismatch expected %x actual %x", tx.Payloads[0].Statement.Roothash, hash[:]) - continue - } + version, err := chain.ReadBlockSnapshotVersion(tx.BLID) + if err != nil { + continue + } + hash, err := chain.Load_Merkle_Hash(version) + if err != nil { + continue + } + + if hash != tx.Payloads[0].Statement.Roothash { + //return fmt.Errorf("Tx statement roothash mismatch expected %x actual %x", tx.Payloads[0].Statement.Roothash, hash[:]) + continue } if height-int64(tx.Height) < TX_VALIDITY_HEIGHT { @@ -353,7 +353,7 @@ func ConvertBlockToMiniblock(bl block.Block, miniblock_miner_address rpc.Address timestamp := uint64(globals.Time().UTC().UnixMilli()) mbl.Timestamp = uint16(timestamp) // this will help us better understand network conditions - + mbl.PastCount = byte(len(bl.Tips)) for i := range bl.Tips { mbl.Past[i] = binary.BigEndian.Uint32(bl.Tips[i][:]) diff --git a/blockchain/storetopo.go b/blockchain/storetopo.go index 7d8920a..c352cf9 100644 --- a/blockchain/storetopo.go +++ b/blockchain/storetopo.go @@ -277,8 +277,22 @@ func (chain *Blockchain) Find_Blocks_Height_Range(startheight, stopheight int64) } _, topos_end := chain.Store.Topo_store.binarySearchHeight(stopheight) + lowest := topos_start[0] + for _, t := range topos_start { + if t < lowest { + lowest = t + } + } + + highest := topos_end[0] + for _, t := range topos_end { + if t > highest { + highest = t + } + } + blid_map := map[crypto.Hash]bool{} - for i := topos_start[0]; i <= topos_end[0]; i++ { + for i := lowest; i <= highest; i++ { if toporecord, err := chain.Store.Topo_store.Read(i); err != nil { panic(err) } else { diff --git a/cmd/dero-miner/miner.go b/cmd/dero-miner/miner.go index 1e0396c..81ebd76 100644 --- a/cmd/dero-miner/miner.go +++ b/cmd/dero-miner/miner.go @@ -20,7 +20,9 @@ import "io" import "os" import "fmt" import "time" +import "net/url" import "crypto/rand" +import "crypto/tls" import "sync" import "runtime" import "math/big" @@ -31,7 +33,6 @@ import "os/signal" import "sync/atomic" import "strings" import "strconv" -import "context" import "github.com/go-logr/logr" @@ -48,13 +49,10 @@ import "github.com/docopt/docopt-go" import "github.com/deroproject/derohe/pow" import "github.com/gorilla/websocket" -import "github.com/deroproject/derohe/glue/rwc" - -import "github.com/creachadair/jrpc2" -import "github.com/creachadair/jrpc2/channel" var mutex sync.RWMutex var job rpc.GetBlockTemplate_Result +var job_counter int64 var maxdelay int = 10000 var threads int var iterations int = 100 @@ -67,8 +65,8 @@ var hash_rate uint64 var Difficulty uint64 var our_height int64 -var block_counter int -var mini_block_counter int +var block_counter uint64 +var mini_block_counter uint64 var logger logr.Logger var command_line string = `dero-miner @@ -96,15 +94,6 @@ If daemon running on local machine no requirement of '--daemon-rpc-address' argu ` var Exit_In_Progress = make(chan bool) -func Notify_broadcaster(req *jrpc2.Request) { - switch req.Method() { - case "Block", "MiniBlock", "Height": - go rpc_client.update_job() - default: - logger.V(1).Info("Notification received but not handled", "method", req.Method()) - } -} - func main() { var err error @@ -165,9 +154,9 @@ func main() { } if !globals.Arguments["--testnet"].(bool) { - daemon_rpc_address = "127.0.0.1:10102" + daemon_rpc_address = "127.0.0.1:10100" } else { - daemon_rpc_address = "127.0.0.1:40402" + daemon_rpc_address = "127.0.0.1:10100" } if globals.Arguments["--daemon-rpc-address"] != nil { @@ -234,17 +223,12 @@ func main() { go func() { last_our_height := int64(0) last_best_height := int64(0) - last_peer_count := uint64(0) - last_topo_height := int64(0) - last_mempool_tx_count := 0 + last_counter := uint64(0) last_counter_time := time.Now() last_mining_state := false _ = last_mining_state - _ = last_peer_count - _ = last_topo_height - _ = last_mempool_tx_count mining := true for { @@ -254,27 +238,12 @@ func main() { default: } - best_height, best_topo_height := int64(0), int64(0) - peer_count := uint64(0) - - mempool_tx_count := 0 - + best_height := int64(0) // only update prompt if needed if last_our_height != our_height || last_best_height != best_height || last_counter != counter { // choose color based on urgency - color := "\033[33m" // default is green color - /*if our_height < best_height { - color = "\033[33m" // make prompt yellow - } else if our_height > best_height { - color = "\033[31m" // make prompt red - }*/ - + color := "\033[33m" // default is green color pcolor := "\033[32m" // default is green color - /*if peer_count < 1 { - pcolor = "\033[31m" // make prompt red - } else if peer_count <= 8 { - pcolor = "\033[33m" // make prompt yellow - }*/ mining_string := "" @@ -313,16 +282,10 @@ func main() { testnet_string = "\033[31m TESTNET" } - extra := fmt.Sprintf("%f", float32(mini_block_counter)/float32(block_counter)) - - l.SetPrompt(fmt.Sprintf("\033[1m\033[32mDERO Miner: \033[0m"+color+"Height %d "+pcolor+" BLOCKS %d MiniBlocks %d \033[32mNW %s %s>%s> avg %s >\033[0m ", our_height, block_counter, mini_block_counter, hash_rate_string, mining_string, testnet_string, extra)) + l.SetPrompt(fmt.Sprintf("\033[1m\033[32mDERO Miner: \033[0m"+color+"Height %d "+pcolor+" BLOCKS %d MiniBlocks %d \033[32mNW %s %s>%s>>\033[0m ", our_height, block_counter, mini_block_counter, hash_rate_string, mining_string, testnet_string)) l.Refresh() last_our_height = our_height last_best_height = best_height - last_peer_count = peer_count - last_mempool_tx_count = mempool_tx_count - last_topo_height = best_topo_height - } time.Sleep(1 * time.Second) } @@ -348,11 +311,10 @@ func main() { threads = 255 } - go increase_delay() - go getwork() + go getwork(wallet_address) for i := 0; i < threads; i++ { - go rpc_client.mineblock(i) + go mineblock(i) } for { @@ -426,91 +388,61 @@ func random_execution(wg *sync.WaitGroup, iterations int) { runtime.UnlockOSThread() } -func increase_delay() { - for { - time.Sleep(time.Second) - maxdelay++ - } -} - -type Client struct { - WS *websocket.Conn - RPC *jrpc2.Client - Connected bool -} - -var rpc_client = &Client{} - // continuously get work -func getwork() { + +var connection *websocket.Conn +var connection_mutex sync.Mutex + +func getwork(wallet_address string) { var err error for { - rpc_client.WS, _, err = websocket.DefaultDialer.Dial("ws://"+daemon_rpc_address+"/ws", nil) + u := url.URL{Scheme: "wss", Host: daemon_rpc_address, Path: "/ws/" + wallet_address} + logger.Info("connecting to ", "url", u.String()) + + dialer := websocket.DefaultDialer + dialer.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + connection, _, err = websocket.DefaultDialer.Dial(u.String(), nil) if err != nil { logger.Error(err, "Error connecting to server", "server adress", daemon_rpc_address) logger.Info("Will try in 10 secs", "server adress", daemon_rpc_address) - rpc_client.Connected = false time.Sleep(10 * time.Second) continue } - input_output := rwc.New(rpc_client.WS) - rpc_client.RPC = jrpc2.NewClient(channel.RawJSON(input_output, input_output), &jrpc2.ClientOptions{OnNotify: Notify_broadcaster}) - rpc_client.Connected = true + var result rpc.GetBlockTemplate_Result + wait_for_another_job: - for { - if err = rpc_client.update_job(); err != nil { - break - } - time.Sleep(100 * time.Millisecond) + if err = connection.ReadJSON(&result); err != nil { + logger.Error(err, "connection error") + continue } - time.Sleep(4 * time.Second) - } - -} - -func (cli *Client) update_job() (err error) { - defer globals.Recover(1) - var result rpc.GetBlockTemplate_Result - - if err = rpc_client.Call("DERO.GetBlockTemplate", rpc.GetBlockTemplate_Params{Wallet_Address: wallet_address}, &result); err == nil { mutex.Lock() job = result - maxdelay = 0 + job_counter++ mutex.Unlock() + if job.LastError != "" { + logger.Error(nil, "received error", "err", job.LastError) + } + + block_counter = job.Blocks + mini_block_counter = job.MiniBlocks hash_rate = job.Difficultyuint64 our_height = int64(job.Height) Difficulty = job.Difficultyuint64 - } else { - rpc_client.WS.Close() - rpc_client.Connected = false - logger.Error(err, "Error receiving block template") - + //fmt.Printf("recv: %s", result) + goto wait_for_another_job } - return err } -func (cli *Client) Call(method string, params interface{}, result interface{}) error { - return cli.RPC.CallResult(context.Background(), method, params, result) -} - -// tests connectivity when connectivity to daemon -func (rpc_client *Client) test_connectivity() (err error) { - var info rpc.GetInfo_Result - if err = rpc_client.Call("DERO.GetInfo", nil, &info); err != nil { - logger.V(1).Error(err, "DERO.GetInfo Call failed:") - return - } - return nil -} - -func (rpc_client *Client) mineblock(tid int) { +func mineblock(tid int) { var diff big.Int var work [block.MINIBLOCK_SIZE]byte @@ -518,19 +450,16 @@ func (rpc_client *Client) mineblock(tid int) { runtime.LockOSThread() threadaffinity() + var local_job_counter int64 + i := uint32(0) for { mutex.RLock() - myjob := job + local_job_counter = job_counter mutex.RUnlock() - if rpc_client.Connected == false { - time.Sleep(10 * time.Millisecond) - continue - } - n, err := hex.Decode(work[:], []byte(myjob.Blockhashing_blob)) if err != nil || n != block.MINIBLOCK_SIZE { logger.Error(err, "Blockwork could not decoded successfully", "blockwork", myjob.Blockhashing_blob, "n", n, "job", myjob) @@ -543,40 +472,27 @@ func (rpc_client *Client) mineblock(tid int) { diff.SetString(myjob.Difficulty, 10) if work[0]&0xf != 1 { // check version - logger.Error(nil, "Unknown version", "version", work[0]&0x1f) + logger.Error(nil, "Unknown version, please check for updates", "version", work[0]&0x1f) time.Sleep(time.Second) continue } - for { + for local_job_counter == job_counter { // update job when it comes, expected rate 1 per second i++ binary.BigEndian.PutUint32(nonce_buf, i) - if i&0x3ff == 0x3ff { // get updated job every 250 millisecs - break - } - powhash := pow.Pow(work[:]) atomic.AddUint64(&counter, 1) if CheckPowHashBig(powhash, &diff) == true { logger.V(1).Info("Successfully found DERO miniblock", "difficulty", myjob.Difficulty, "height", myjob.Height) - maxdelay = 200 - var result rpc.SubmitBlock_Result - if err = rpc_client.Call("DERO.SubmitBlock", rpc.SubmitBlock_Params{JobID: myjob.JobID, MiniBlockhashing_blob: fmt.Sprintf("%x", work[:])}, &result); err == nil { + func() { + defer globals.Recover(1) + connection_mutex.Lock() + defer connection_mutex.Unlock() + connection.WriteJSON(rpc.SubmitBlock_Params{JobID: myjob.JobID, MiniBlockhashing_blob: fmt.Sprintf("%x", work[:])}) + }() - if result.MiniBlock { - mini_block_counter++ - } else { - block_counter++ - } - logger.V(2).Info("submitting block", "result", result) - go rpc_client.update_job() - } else { - logger.Error(err, "error submitting block") - rpc_client.update_job() - break - } } } } @@ -584,7 +500,6 @@ func (rpc_client *Client) mineblock(tid int) { func usage(w io.Writer) { io.WriteString(w, "commands:\n") - //io.WriteString(w, completer.Tree(" ")) io.WriteString(w, "\t\033[1mhelp\033[0m\t\tthis help\n") io.WriteString(w, "\t\033[1mstatus\033[0m\t\tShow general information\n") io.WriteString(w, "\t\033[1mbye\033[0m\t\tQuit the miner\n") diff --git a/cmd/dero-wallet-cli/prompt.go b/cmd/dero-wallet-cli/prompt.go index 877993f..2efcd74 100644 --- a/cmd/dero-wallet-cli/prompt.go +++ b/cmd/dero-wallet-cli/prompt.go @@ -890,7 +890,7 @@ func valid_registration_or_display_error(l *readline.Instance, wallet *walletapi // show the transfers to the user originating from this account func show_transfers(l *readline.Instance, wallet *walletapi.Wallet_Disk, scid crypto.Hash, limit uint64) { - if wallet.GetMode() { // if wallet is in offline mode , we cannot do anything + if wallet.GetMode() && walletapi.IsDaemonOnline() { // if wallet is in offline mode , we cannot do anything if err := wallet.Sync_Wallet_Memory_With_Daemon_internal(scid); err != nil { logger.Error(err, "Error syncing wallet", "scid", scid.String()) return diff --git a/cmd/derod/main.go b/cmd/derod/main.go index dc1b883..85af369 100644 --- a/cmd/derod/main.go +++ b/cmd/derod/main.go @@ -59,7 +59,7 @@ var command_line string = `derod DERO : A secure, private blockchain with smart-contracts Usage: - derod [--help] [--version] [--testnet] [--debug] [--sync-node] [--timeisinsync] [--fastsync] [--socks-proxy=] [--data-dir=] [--p2p-bind=<0.0.0.0:18089>] [--add-exclusive-node=]... [--add-priority-node=]... [--min-peers=<11>] [--rpc-bind=<127.0.0.1:9999>] [--node-tag=] [--prune-history=<50>] [--integrator-address=
] [--clog-level=1] [--flog-level=1] + derod [--help] [--version] [--testnet] [--debug] [--sync-node] [--timeisinsync] [--fastsync] [--socks-proxy=] [--data-dir=] [--p2p-bind=<0.0.0.0:18089>] [--add-exclusive-node=]... [--add-priority-node=]... [--min-peers=<11>] [--rpc-bind=<127.0.0.1:9999>] [--getwork-bind=<0.0.0.0:18089>] [--node-tag=] [--prune-history=<50>] [--integrator-address=
] [--clog-level=1] [--flog-level=1] derod -h | --help derod --version @@ -76,6 +76,7 @@ Options: --data-dir= Store blockchain data at this location --rpc-bind=<127.0.0.1:9999> RPC listens on this ip:port --p2p-bind=<0.0.0.0:18089> p2p server listens on this ip:port, specify port 0 to disable listening server + --getwork-bind=<0.0.0.0:10100> getwork server listens on this ip:port, specify port 0 to disable listening server --add-exclusive-node= Connect to specific peer only --add-priority-node= Maintain persistant connection to specified peer --sync-node Sync node automatically with the seeds nodes. This option is for rare use. @@ -211,6 +212,8 @@ func main() { p2p.P2P_Init(params) rpcserver, _ := derodrpc.RPCServer_Start(params) + go derodrpc.Getwork_server() + // setup function pointers chain.P2P_Block_Relayer = func(cbl *block.Complete_Block, peerid uint64) { p2p.Broadcast_Block(cbl, peerid) @@ -284,7 +287,7 @@ func main() { testnet_string += " " + strconv.Itoa(chain.MiniBlocks.Count()) + " " + globals.GetOffset().Round(time.Millisecond).String() + "|" + globals.GetOffsetNTP().Round(time.Millisecond).String() + "|" + globals.GetOffsetP2P().Round(time.Millisecond).String() - l.SetPrompt(fmt.Sprintf("\033[1m\033[32mDERO HE: \033[0m"+color+"%d/%d [%d/%d] "+pcolor+"P %d TXp %d:%d \033[32mNW %s >%s>>\033[0m ", our_height, topo_height, best_height, best_topo_height, peer_count, mempool_tx_count, regpool_tx_count, hash_rate_string, testnet_string)) + l.SetPrompt(fmt.Sprintf("\033[1m\033[32mDERO HE: \033[0m"+color+"%d/%d [%d/%d] "+pcolor+"P %d TXp %d:%d \033[32mNW %s >Miners %d %s>>\033[0m ", our_height, topo_height, best_height, best_topo_height, peer_count, mempool_tx_count, regpool_tx_count, hash_rate_string, derodrpc.CountMiners(), testnet_string)) l.Refresh() last_second = time.Now().Unix() last_our_height = our_height @@ -491,6 +494,19 @@ func readline_loop(l *readline.Instance, chain *blockchain.Blockchain, logger lo logger.Error(fmt.Errorf("regpool_delete_tx needs a single transaction id as argument"), "") } + case command == "mempool_dump": // dump mempool to directory + tx_hash_list_sorted := chain.Mempool.Mempool_List_TX_SortedInfo() // hash of all tx expected to be included within this block , sorted by fees + + os.Mkdir(filepath.Join(globals.GetDataDirectory(), "mempool"), 0755) + count := 0 + for _, txi := range tx_hash_list_sorted { + if tx := chain.Mempool.Mempool_Get_TX(txi.Hash); tx != nil { + os.WriteFile(filepath.Join(globals.GetDataDirectory(), "mempool", txi.Hash.String()), tx.Serialize(), 0755) + count++ + } + } + logger.Info("flushed mempool to driectory", "count", count, "dir", filepath.Join(globals.GetDataDirectory(), "mempool")) + case command == "mempool_print": chain.Mempool.Mempool_Print() @@ -981,6 +997,7 @@ var completer = readline.NewPrefixCompleter( readline.PcItem("help"), readline.PcItem("diff"), readline.PcItem("gc"), + readline.PcItem("mempool_dump"), readline.PcItem("mempool_flush"), readline.PcItem("mempool_delete_tx"), readline.PcItem("mempool_print"), diff --git a/cmd/derod/rpc/websocket_getwork_server.go b/cmd/derod/rpc/websocket_getwork_server.go new file mode 100644 index 0000000..44ad7ef --- /dev/null +++ b/cmd/derod/rpc/websocket_getwork_server.go @@ -0,0 +1,341 @@ +package rpc + +import ( + "flag" + "fmt" + "net/http" + + "time" + + "github.com/lesismal/llib/std/crypto/tls" + "github.com/lesismal/nbio/nbhttp" + "github.com/lesismal/nbio/nbhttp/websocket" +) + +import "github.com/lesismal/nbio" +import "github.com/lesismal/nbio/logging" + +import "net" +import "bytes" +import "encoding/hex" +import "encoding/json" +import "runtime" +import "strings" +import "math/big" +import "crypto/ecdsa" +import "crypto/elliptic" + +//import "crypto/tls" +import "crypto/rand" +import "crypto/x509" +import "encoding/pem" + +import "github.com/deroproject/derohe/globals" +import "github.com/deroproject/derohe/rpc" +import "github.com/deroproject/graviton" +import "github.com/go-logr/logr" + +// this file implements the non-blocking job streamer +// only job is to stream jobs to thousands of workers, if any is successful,accept and report back + +import "sync" + +var memPool = sync.Pool{ + New: func() interface{} { + return make([]byte, 16*1024) + }, +} + +var logger_getwork logr.Logger +var ( + svr *nbhttp.Server + print = flag.Bool("print", false, "stdout output of echoed data") +) + +type user_session struct { + blocks uint64 + miniblocks uint64 + lasterr string + address rpc.Address + valid_address bool + address_sum [32]byte +} + +var client_list_mutex sync.Mutex +var client_list = map[*websocket.Conn]*user_session{} + +func CountMiners() int { + client_list_mutex.Lock() + defer client_list_mutex.Unlock() + return len(client_list) +} + +func SendJob() { + + var params rpc.GetBlockTemplate_Result + + var buf bytes.Buffer + encoder := json.NewEncoder(&buf) + + // get a block template, and then we will fill the address here as optimization + bl, mbl, _, _, err := chain.Create_new_block_template_mining(chain.IntegratorAddress()) + if err != nil { + return + } + + prev_hash := "" + for i := range bl.Tips { + prev_hash = prev_hash + bl.Tips[i].String() + } + + params.JobID = fmt.Sprintf("%d.%d.%s", bl.Timestamp, 0, "notified") + diff := chain.Get_Difficulty_At_Tips(bl.Tips) + + params.Height = bl.Height + params.Prev_Hash = prev_hash + params.Difficultyuint64 = diff.Uint64() + params.Difficulty = diff.String() + client_list_mutex.Lock() + defer client_list_mutex.Unlock() + + for k, v := range client_list { + if !mbl.Final { //write miners address only if possible + copy(mbl.KeyHash[:], v.address_sum[:]) + } + + for i := range mbl.Nonce { // give each user different work + mbl.Nonce[i] = globals.Global_Random.Uint32() // fill with randomness + } + + if v.lasterr != "" { + params.LastError = v.lasterr + v.lasterr = "" + } + + if !v.valid_address && !chain.IsAddressHashValid(false, v.address_sum) { + params.LastError = "unregistered miner or you need to wait 15 mins" + } else { + v.valid_address = true + } + params.Blockhashing_blob = fmt.Sprintf("%x", mbl.Serialize()) + params.Blocks = v.blocks + params.MiniBlocks = v.miniblocks + + encoder.Encode(params) + k.WriteMessage(websocket.TextMessage, buf.Bytes()) + buf.Reset() + + } + +} + +func newUpgrader() *websocket.Upgrader { + u := websocket.NewUpgrader() + + u.OnMessage(func(c *websocket.Conn, messageType websocket.MessageType, data []byte) { + // echo + c.WriteMessage(messageType, data) + + if messageType != websocket.TextMessage { + return + } + + sess := c.Session().(*user_session) + + client_list_mutex.Lock() + client_list_mutex.Unlock() + + var p rpc.SubmitBlock_Params + + if err := json.Unmarshal(data, &p); err != nil { + + } + + mbl_block_data_bytes, err := hex.DecodeString(p.MiniBlockhashing_blob) + if err != nil { + //logger.Info("Submitting block could not be decoded") + sess.lasterr = fmt.Sprintf("Submitted block could not be decoded. err: %s", err) + return + } + + var tstamp, extra uint64 + fmt.Sscanf(p.JobID, "%d.%d", &tstamp, &extra) + + _, blid, sresult, err := chain.Accept_new_block(tstamp, mbl_block_data_bytes) + + if sresult { + //logger.Infof("Submitted block %s accepted", blid) + if blid.IsZero() { + sess.miniblocks++ + } else { + sess.blocks++ + } + } + + }) + u.OnClose(func(c *websocket.Conn, err error) { + client_list_mutex.Lock() + delete(client_list, c) + client_list_mutex.Unlock() + }) + + return u +} + +func onWebsocket(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/ws/") { + http.NotFound(w, r) + return + } + address := strings.TrimPrefix(r.URL.Path, "/ws/") + + addr, err := globals.ParseValidateAddress(address) + if err != nil { + fmt.Fprintf(w, "err: %s\n", err) + return + } + addr_raw := addr.PublicKey.EncodeCompressed() + + upgrader := newUpgrader() + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + //panic(err) + return + } + wsConn := conn.(*websocket.Conn) + + session := user_session{address: *addr, address_sum: graviton.Sum(addr_raw)} + wsConn.SetSession(&session) + + client_list_mutex.Lock() + client_list[wsConn] = &session + client_list_mutex.Unlock() +} + +func Getwork_server() { + + var err error + + logger_getwork = globals.Logger.WithName("GETWORK") + + logging.SetLevel(logging.LevelNone) //LevelDebug)//LevelNone) + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{generate_random_tls_cert()}, + InsecureSkipVerify: true, + } + + mux := &http.ServeMux{} + mux.HandleFunc("/", onWebsocket) // handle everything + + default_address := fmt.Sprintf("0.0.0.0:%d", globals.Config.GETWORK_Default_Port) + + if _, ok := globals.Arguments["--getwork-bind"]; ok && globals.Arguments["--getwork-bind"] != nil { + addr, err := net.ResolveTCPAddr("tcp", globals.Arguments["--getwork-bind"].(string)) + if err != nil { + logger_getwork.Error(err, "--getwork-bind address is invalid") + return + } else { + if addr.Port == 0 { + logger_getwork.Info("GETWORK server is disabled, No ports will be opened for miners to get work") + return + } else { + default_address = addr.String() + } + } + } + + logger_getwork.Info("GETWORK will listen", "address", default_address) + + svr = nbhttp.NewServer(nbhttp.Config{ + Name: "GETWORK", + Network: "tcp", + AddrsTLS: []string{default_address}, + TLSConfig: tlsConfig, + Handler: mux, + MaxLoad: 10 * 1024, + MaxWriteBufferSize: 32 * 1024, + ReleaseWebsocketPayload: true, + KeepaliveTime: 240 * time.Hour, // we expects all miners to find a block every 10 days, + NPoller: runtime.NumCPU(), + }) + + svr.OnReadBufferAlloc(func(c *nbio.Conn) []byte { + return memPool.Get().([]byte) + }) + svr.OnReadBufferFree(func(c *nbio.Conn, b []byte) { + memPool.Put(b) + }) + + globals.Cron.AddFunc("@every 2s", SendJob) // if daemon restart automaticaly send job + + if err = svr.Start(); err != nil { + logger_getwork.Error(err, "nbio.Start failed.") + return + } + logger.Info("GETWORK/Websocket server started") + svr.Wait() + defer svr.Stop() + +} + +// generate default tls cert to encrypt everything +// NOTE: this does NOT protect from individual active man-in-the-middle attacks +func generate_random_tls_cert() tls.Certificate { + + /* RSA can do only 500 exchange per second, we need to be faster + * reference https://github.com/golang/go/issues/20058 + key, err := rsa.GenerateKey(rand.Reader, 512) // current using minimum size + if err != nil { + log.Fatal("Private key cannot be created.", err.Error()) + } + + // Generate a pem block with the private key + keyPem := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + */ + // EC256 does roughly 20000 exchanges per second + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + b, err := x509.MarshalECPrivateKey(key) + if err != nil { + logger.Error(err, "Unable to marshal ECDSA private key") + panic(err) + } + // Generate a pem block with the private key + keyPem := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: b}) + + tml := x509.Certificate{ + SerialNumber: big.NewInt(int64(time.Now().UnixNano())), + + // TODO do we need to add more parameters to make our certificate more authentic + // and thwart traffic identification as a mass scale + + // you can add any attr that you need + NotBefore: time.Now().AddDate(0, -1, 0), + NotAfter: time.Now().AddDate(1, 0, 0), + // you have to generate a different serial number each execution + /* + Subject: pkix.Name{ + CommonName: "New Name", + Organization: []string{"New Org."}, + }, + BasicConstraintsValid: true, // even basic constraints are not required + */ + } + cert, err := x509.CreateCertificate(rand.Reader, &tml, &tml, &key.PublicKey, key) + if err != nil { + logger.Error(err, "Certificate cannot be created.") + panic(err) + } + + // Generate a pem block with the certificate + certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert}) + tlsCert, err := tls.X509KeyPair(certPem, keyPem) + if err != nil { + logger.Error(err, "Certificate cannot be loaded.") + panic(err) + } + return tlsCert +} diff --git a/cmd/derod/rpc/websocket_server.go b/cmd/derod/rpc/websocket_server.go index 9b8ab70..f32c285 100644 --- a/cmd/derod/rpc/websocket_server.go +++ b/cmd/derod/rpc/websocket_server.go @@ -117,10 +117,7 @@ func Notify_MiniBlock_Addition() { chain.RPC_NotifyNewMiniBlock.L.Unlock() go func() { defer globals.Recover(2) - client_connections.Range(func(key, value interface{}) bool { - key.(*jrpc2.Server).Notify(context.Background(), "MiniBlock", nil) - return true - }) + SendJob() }() } } diff --git a/config/config.go b/config/config.go index ed99b68..f52f9c4 100644 --- a/config/config.go +++ b/config/config.go @@ -76,6 +76,7 @@ type CHAIN_CONFIG struct { Name string Network_ID uuid.UUID // network ID + GETWORK_Default_Port int // used for miner getwork as effeciently as poosible P2P_Default_Port int RPC_Default_Port int Wallet_RPC_Default_Port int @@ -87,6 +88,7 @@ type CHAIN_CONFIG struct { var Mainnet = CHAIN_CONFIG{Name: "mainnet", Network_ID: uuid.FromBytesOrNil([]byte{0x59, 0xd7, 0xf7, 0xe9, 0xdd, 0x48, 0xd5, 0xfd, 0x13, 0x0a, 0xf6, 0xe0, 0x9a, 0x44, 0x45, 0x0}), + GETWORK_Default_Port: 10100, P2P_Default_Port: 10101, RPC_Default_Port: 10102, Wallet_RPC_Default_Port: 10103, @@ -103,7 +105,8 @@ var Mainnet = CHAIN_CONFIG{Name: "mainnet", } var Testnet = CHAIN_CONFIG{Name: "testnet", // testnet will always have last 3 bytes 0 - Network_ID: uuid.FromBytesOrNil([]byte{0x59, 0xd7, 0xf7, 0xe9, 0xdd, 0x48, 0xd5, 0xfd, 0x13, 0x0a, 0xf6, 0xe0, 0x73, 0x00, 0x00, 0x00}), + Network_ID: uuid.FromBytesOrNil([]byte{0x59, 0xd7, 0xf7, 0xe9, 0xdd, 0x48, 0xd5, 0xfd, 0x13, 0x0a, 0xf6, 0xe0, 0x74, 0x00, 0x00, 0x00}), + GETWORK_Default_Port: 10100, P2P_Default_Port: 40401, RPC_Default_Port: 40402, Wallet_RPC_Default_Port: 40403, diff --git a/config/version.go b/config/version.go index e400d0b..87a0cde 100644 --- a/config/version.go +++ b/config/version.go @@ -20,4 +20,4 @@ import "github.com/blang/semver/v4" // right now it has to be manually changed // do we need to include git commitsha?? -var Version = semver.MustParse("3.4.93-1.DEROHE.STARGATE+25112021") +var Version = semver.MustParse("3.4.94-1.DEROHE.STARGATE+25112021") diff --git a/rpc/daemon_rpc.go b/rpc/daemon_rpc.go index 7b39f94..22fe5bf 100644 --- a/rpc/daemon_rpc.go +++ b/rpc/daemon_rpc.go @@ -114,6 +114,9 @@ type ( Height uint64 `json:"height"` Prev_Hash string `json:"prev_hash"` EpochMilli uint64 `json:"epochmilli"` + Blocks uint64 `json:"blocks"` // number of blocks found + MiniBlocks uint64 `json:"miniblocks"` // number of miniblocks found + LastError string `json:"lasterror"` // last error Status string `json:"status"` } ) @@ -197,14 +200,14 @@ type ( } // no params GetTransaction_Result struct { Txs_as_hex []string `json:"txs_as_hex"` - Txs_as_json []string `json:"txs_as_json"` + Txs_as_json []string `json:"txs_as_json,omitempty"` Txs []Tx_Related_Info `json:"txs"` Status string `json:"status"` } Tx_Related_Info struct { As_Hex string `json:"as_hex"` - As_Json string `json:"as_json"` + As_Json string `json:"as_json,omitempty"` Block_Height int64 `json:"block_height"` Reward uint64 `json:"reward"` // miner tx rewards are decided by the protocol during execution Ignored bool `json:"ignored"` // tell whether this tx is okau as per client protocol or bein ignored @@ -261,17 +264,8 @@ type ( Tx_as_hex string `json:"tx_as_hex"` } SendRawTransaction_Result struct { - Status string `json:"status"` - DoubleSpend bool `json:"double_spend"` - FeeTooLow bool `json:"fee_too_low"` - InvalidInput bool `json:"invalid_input"` - InvalidOutput bool `json:"invalid_output"` - Low_Mixin bool `json:"low_mixin"` - Non_rct bool `json:"not_rct"` - NotRelayed bool `json:"not_relayed"` - Overspend bool `json:"overspend"` - TooBig bool `json:"too_big"` - Reason string `json:"string"` + Status string `json:"status"` + Reason string `json:"string"` } ) diff --git a/vendor/github.com/lesismal/llib/LICENSE b/vendor/github.com/lesismal/llib/LICENSE new file mode 100644 index 0000000..c16031d --- /dev/null +++ b/vendor/github.com/lesismal/llib/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2021 lesismal. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/lesismal/llib/README.md b/vendor/github.com/lesismal/llib/README.md new file mode 100644 index 0000000..61f364e --- /dev/null +++ b/vendor/github.com/lesismal/llib/README.md @@ -0,0 +1,12 @@ +# llib - [lesismal](https://github.com/lesismal)'s lib + +[![GoDoc][1]][2] [![MIT licensed][3]][4] [![Go Version][5]][6] + +[1]: https://godoc.org/github.com/lesismal/llib?status.svg +[2]: https://godoc.org/github.com/lesismal/llib +[3]: https://img.shields.io/badge/license-BSD-blue.svg +[4]: LICENSE +[5]: https://img.shields.io/badge/go-%3E%3D1.16-30dff3?style=flat-square&logo=go +[6]: https://github.com/lesismal/llib + +Less Is More :smile: diff --git a/vendor/github.com/lesismal/llib/bytes/buffer.go b/vendor/github.com/lesismal/llib/bytes/buffer.go new file mode 100644 index 0000000..d1dc3a9 --- /dev/null +++ b/vendor/github.com/lesismal/llib/bytes/buffer.go @@ -0,0 +1,225 @@ +package bytes + +import ( + "errors" +) + +var ( + ErrInvalidLength = errors.New("invalid length") + ErrInvalidPosition = errors.New("invalid position") + ErrNotEnougth = errors.New("bytes not enougth") +) + +// Buffer . +type Buffer struct { + total int + buffers [][]byte + onRelease func(b []byte) +} + +// Len . +func (bb *Buffer) Len() int { + return bb.total +} + +// Push . +func (bb *Buffer) Push(b []byte) { + if len(b) == 0 { + return + } + bb.buffers = append(bb.buffers, b) + bb.total += len(b) +} + +// Pop . +func (bb *Buffer) Pop(n int) ([]byte, error) { + if n < 0 { + return nil, ErrInvalidLength + } + if bb.total < n { + return nil, ErrNotEnougth + } + + bb.total -= n + + var buf = bb.buffers[0] + if len(buf) >= n { + ret := buf[:n] + bb.buffers[0] = bb.buffers[0][n:] + if len(bb.buffers[0]) == 0 { + bb.releaseHead() + } + return ret, nil + } + + var ret = make([]byte, n)[0:0] + for n > 0 { + if len(buf) >= n { + ret = append(ret, buf[:n]...) + bb.buffers[0] = bb.buffers[0][n:] + if len(bb.buffers[0]) == 0 { + bb.releaseHead() + } + return ret, nil + } + ret = append(ret, buf...) + bb.releaseHead() + n -= len(buf) + buf = bb.buffers[0] + } + return ret, nil +} + +// Append . +func (bb *Buffer) Append(b []byte) { + if len(b) == 0 { + return + } + + n := len(bb.buffers) + + if n == 0 { + bb.buffers = append(bb.buffers, b) + return + } + bb.buffers[n-1] = append(bb.buffers[n-1], b...) + bb.total += len(b) +} + +// Head . +func (bb *Buffer) Head(n int) ([]byte, error) { + if n < 0 { + return nil, ErrInvalidLength + } + if bb.total < n { + return nil, ErrNotEnougth + } + + if len(bb.buffers[0]) >= n { + return bb.buffers[0][:n], nil + } + + ret := make([]byte, n) + + copied := 0 + for i := 0; n > 0; i++ { + buf := bb.buffers[i] + if len(buf) >= n { + copy(ret[copied:], buf[:n]) + return ret, nil + } else { + copy(ret[copied:], buf) + n -= len(buf) + copied += len(buf) + } + } + + return ret, nil +} + +// Sub . +func (bb *Buffer) Sub(from, to int) ([]byte, error) { + if from < 0 || to < 0 || to < from { + return nil, ErrInvalidPosition + } + if bb.total < to { + return nil, ErrNotEnougth + } + + if len(bb.buffers[0]) >= to { + return bb.buffers[0][from:to], nil + } + + n := to - from + ret := make([]byte, n) + copied := 0 + for i := 0; n > 0; i++ { + buf := bb.buffers[i] + if len(buf) >= from+n { + copy(ret[copied:], buf[from:from+n]) + return ret, nil + } else { + if len(buf) > from { + if from > 0 { + buf = buf[from:] + from = 0 + } + copy(ret[copied:], buf) + copied += len(buf) + n -= len(buf) + } else { + from -= len(buf) + } + } + } + + return ret, nil +} + +// Write . +func (bb *Buffer) Write(b []byte) { + bb.Push(b) +} + +// Read . +func (bb *Buffer) Read(n int) ([]byte, error) { + return bb.Pop(n) +} + +// ReadAll . +func (bb *Buffer) ReadAll() ([]byte, error) { + if len(bb.buffers) == 0 { + return nil, nil + } + + ret := append([]byte{}, bb.buffers[0]...) + if bb.onRelease != nil { + bb.onRelease(bb.buffers[0]) + for i := 1; i < len(bb.buffers); i++ { + ret = append(ret, bb.buffers[i]...) + bb.onRelease(bb.buffers[i]) + + } + } else { + for i := 1; i < len(bb.buffers); i++ { + ret = append(ret, bb.buffers[i]...) + } + } + bb.buffers = nil + bb.total = 0 + + return ret, nil +} + +// Reset . +func (bb *Buffer) Reset() { + if bb.onRelease != nil { + for i := 0; i < len(bb.buffers); i++ { + bb.onRelease(bb.buffers[i]) + + } + } + bb.buffers = nil + bb.total = 0 +} + +func (bb *Buffer) OnRelease(onRelease func(b []byte)) { + bb.onRelease = onRelease +} + +func (bb *Buffer) releaseHead() { + if bb.onRelease != nil { + bb.onRelease(bb.buffers[0]) + } + switch len(bb.buffers) { + case 1: + bb.buffers = nil + default: + bb.buffers = bb.buffers[1:] + } +} + +// NewBuffer . +func NewBuffer() *Buffer { + return &Buffer{} +} diff --git a/vendor/github.com/lesismal/llib/bytes/buffer_test.go b/vendor/github.com/lesismal/llib/bytes/buffer_test.go new file mode 100644 index 0000000..982a3c8 --- /dev/null +++ b/vendor/github.com/lesismal/llib/bytes/buffer_test.go @@ -0,0 +1,108 @@ +package bytes + +import ( + "testing" +) + +func TestBuffer(t *testing.T) { + str := "hello world" + + buffer := NewBuffer() + buffer.Write([]byte("hel")) + buffer.Write([]byte("lo world")) + b, err := buffer.ReadAll() + if err != nil { + t.Fatal(err) + } + if string(b) != str { + t.Fatal(string(b)) + } + + buffer.Write([]byte("hel")) + buffer.Write([]byte("lo ")) + buffer.Write([]byte("wor")) + buffer.Write([]byte("ld")) + for i := 0; i < len(str); i++ { + for j := i; j < len(str); j++ { + sub, err := buffer.Sub(i, j) + if err != nil { + t.Fatal(err) + } + if string(sub) != string([]byte(str)[i:j]) { + t.Fatalf("[%v:%v] %v != %v", i, j, string(sub), string([]byte(str)[i:j])) + } + } + } + + for i := 0; i < len(str); i++ { + for j := i; j < len(str); j++ { + buffer.Write([]byte("hel")) + buffer.Write([]byte("lo ")) + buffer.Write([]byte("wor")) + buffer.Write([]byte("ld")) + + b, err = buffer.Read(j) + if err != nil { + t.Fatal(err) + } + if string(b) != string([]byte(str)[:j]) { + t.Fatalf("[%v:%v] %v != %v", i, j, string(b), string([]byte(str)[:j])) + } + + buffer.Reset() + } + } + + for i := 0; i < len(str); i++ { + for j := i; j < len(str); j++ { + buffer.Write([]byte("hel")) + buffer.Write([]byte("lo ")) + buffer.Write([]byte("wor")) + buffer.Write([]byte("ld")) + + buffer.Read(i) + b, err = buffer.Read(j - i) + if err != nil { + t.Fatal(err) + } + if string(b) != string([]byte(str)[i:j]) { + t.Fatalf("[%v:%v] %v != %v", i, j, string(b), string([]byte(str)[i:j])) + } + + buffer.Reset() + } + } + + buffer.Append([]byte("hello")) + buffer.Append([]byte(" world")) + if string(buffer.buffers[0]) != "hello world" { + t.Fatal(string(buffer.buffers[0])) + } + b, err = buffer.ReadAll() + if err != nil { + t.Fatal(err) + } + if string(b) != "hello world" { + t.Fatal(string(b)) + } + + buffer.Reset() + + buffer.Push([]byte("hello ")) + buffer.Push([]byte("world")) + if string(buffer.buffers[0]) != "hello " { + t.Fatal(string(buffer.buffers[0])) + } + buffer.Pop(1) + if string(buffer.buffers[0]) != "ello " { + t.Fatal(string(buffer.buffers[0])) + } + buffer.Pop(5) + if string(buffer.buffers[0]) != "world" { + t.Fatal(string(buffer.buffers[0])) + } + buffer.ReadAll() + if len(buffer.buffers) != 0 { + t.Fatal(string(buffer.buffers[0])) + } +} diff --git a/vendor/github.com/lesismal/llib/bytes/pool.go b/vendor/github.com/lesismal/llib/bytes/pool.go new file mode 100644 index 0000000..3a6cb6d --- /dev/null +++ b/vendor/github.com/lesismal/llib/bytes/pool.go @@ -0,0 +1,84 @@ +// Copyright 2020 lesismal. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package bytes + +import ( + "sync" +) + +// maxAppendSize represents the max size to append to a slice. +const maxAppendSize = 1024 * 1024 * 4 + +// Pool is the default instance of []byte pool. +// User can customize a Pool implementation and reset this instance if needed. +var Pool interface { + Get() []byte + GetN(size int) []byte + Put(b []byte) +} = NewPool(64) + +// bufferPool is a default implementatiion of []byte Pool. +type bufferPool struct { + sync.Pool + MinSize int +} + +// NewPool creates and returns a bufferPool instance. +// All slice created by this instance has an initial cap of minSize. +func NewPool(minSize int) *bufferPool { + if minSize <= 0 { + minSize = 64 + } + bp := &bufferPool{ + MinSize: minSize, + } + bp.Pool.New = func() interface{} { + buf := make([]byte, bp.MinSize) + return &buf + } + return bp +} + +// Get gets a slice from the pool and returns it with length 0. +// User can append the slice and should Put it back to the pool after being used over. +func (bp *bufferPool) Get() []byte { + pbuf := bp.Pool.Get().(*[]byte) + return (*pbuf)[0:0] +} + +// GetN returns a slice with length size. +// To reuse slices as possible, +// if the cap of the slice got from the pool is not enough, +// It will append the slice, +// or put the slice back to the pool and create a new slice with cap of size. +// +// User can use the slice both by the size or append it, +// and should Put it back to the pool after being used over. +func (bp *bufferPool) GetN(size int) []byte { + pbuf := bp.Pool.Get().(*[]byte) + need := size - cap(*pbuf) + if need > 0 { + if need <= maxAppendSize { + *pbuf = (*pbuf)[:cap(*pbuf)] + *pbuf = append(*pbuf, make([]byte, need)...) + } else { + bp.Pool.Put(pbuf) + newBuf := make([]byte, size) + pbuf = &newBuf + } + } + + return (*pbuf)[:size] +} + +// Put puts a slice back to the pool. +// If the slice's cap is smaller than MinSize, +// it will not be put back to the pool but dropped. +func (bp *bufferPool) Put(b []byte) { + if cap(b) < bp.MinSize { + return + } + bp.Pool.Put(&b) +} diff --git a/vendor/github.com/lesismal/llib/bytes/pool_test.go b/vendor/github.com/lesismal/llib/bytes/pool_test.go new file mode 100644 index 0000000..bf4e987 --- /dev/null +++ b/vendor/github.com/lesismal/llib/bytes/pool_test.go @@ -0,0 +1,22 @@ +package bytes + +import "testing" + +func TestMemPool(t *testing.T) { + const minMemSize = 64 + pool := NewPool(minMemSize) + for i := 0; i < 1024*1024; i++ { + buf := pool.GetN(i) + if len(buf) != i { + t.Fatalf("invalid length: %v != %v", len(buf), i) + } + pool.Put(buf) + } + for i := 1024 * 1024; i < 1024*1024*1024; i += 1024 * 1024 { + buf := pool.GetN(i) + if len(buf) != i { + t.Fatalf("invalid length: %v != %v", len(buf), i) + } + pool.Put(buf) + } +} diff --git a/vendor/github.com/lesismal/llib/concurrent/batch.go b/vendor/github.com/lesismal/llib/concurrent/batch.go new file mode 100644 index 0000000..5ec0244 --- /dev/null +++ b/vendor/github.com/lesismal/llib/concurrent/batch.go @@ -0,0 +1,60 @@ +// Copyright 2020 lesismal. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package concurrent + +import ( + "sync" +) + +var ( + _defaultBatch = NewBatch() +) + +type call struct { + mux sync.RWMutex + ret interface{} + err error +} + +// Batch . +type Batch struct { + _mux sync.Mutex + _callings map[interface{}]*call +} + +// Do . +func (o *Batch) Do(key interface{}, f func() (interface{}, error)) (interface{}, error) { + o._mux.Lock() + c, ok := o._callings[key] + if ok { + o._mux.Unlock() + c.mux.RLock() + c.mux.RUnlock() + return c.ret, c.err + } + + c = &call{} + c.mux.Lock() + o._callings[key] = c + o._mux.Unlock() + c.ret, c.err = f() + c.mux.Unlock() + + o._mux.Lock() + delete(o._callings, key) + o._mux.Unlock() + + return c.ret, c.err +} + +// NewBatch . +func NewBatch() *Batch { + return &Batch{_callings: map[interface{}]*call{}} +} + +// Do . +func Do(key interface{}, f func() (interface{}, error)) (interface{}, error) { + return _defaultBatch.Do(key, f) +} diff --git a/vendor/github.com/lesismal/llib/concurrent/batch_test.go b/vendor/github.com/lesismal/llib/concurrent/batch_test.go new file mode 100644 index 0000000..a803d4f --- /dev/null +++ b/vendor/github.com/lesismal/llib/concurrent/batch_test.go @@ -0,0 +1,34 @@ +// Copyright 2020 lesismal. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package concurrent + +import ( + "log" + "testing" + "time" +) + +func TestBatch(t *testing.T) { + batchCall := func() (interface{}, error) { + time.Sleep(time.Second) + return time.Now().Format("2006/01/02 15:04:05.000"), nil + } + for i := 0; i < 10; i++ { + go func(id int) { + ret, err := Do(3, batchCall) + log.Println("Batch().Do():", id, ret, err) + }(2) + } + func(id int) { + ret, err := Do(3, batchCall) + log.Println("Batch().Do():", id, ret, err) + }(1) + + func(id int) { + ret, err := Do(3, batchCall) + log.Println("Batch().Do():", id, ret, err) + }(3) + time.Sleep(time.Second) +} diff --git a/vendor/github.com/lesismal/llib/concurrent/map.go b/vendor/github.com/lesismal/llib/concurrent/map.go new file mode 100644 index 0000000..cf4aa0b --- /dev/null +++ b/vendor/github.com/lesismal/llib/concurrent/map.go @@ -0,0 +1,100 @@ +package concurrent + +import ( + "sync" + "sync/atomic" + + "github.com/cespare/xxhash" +) + +type bucket struct { + mux sync.RWMutex + values map[string]interface{} +} + +func (b *bucket) Get(k string) (interface{}, bool) { + b.mux.RLock() + v, ok := b.values[k] + b.mux.RUnlock() + return v, ok +} + +func (b *bucket) Set(k string, v interface{}) bool { + b.mux.Lock() + _, exsist := b.values[k] + b.values[k] = v + b.mux.Unlock() + return !exsist +} + +func (b *bucket) Delete(k string) bool { + b.mux.Lock() + _, exsist := b.values[k] + delete(b.values, k) + b.mux.Unlock() + return exsist +} + +func (b *bucket) forEach(f func(k string, v interface{}) bool) bool { + success := false + b.mux.RLock() + for k, v := range b.values { + success = f(k, v) + if !success { + break + } + } + b.mux.RUnlock() + return success +} + +type Map struct { + size int64 + buckets []*bucket +} + +func (m *Map) Get(k string) (interface{}, bool) { + i := hash(k) % uint64(len(m.buckets)) + return m.buckets[i].Get(k) +} + +func (m *Map) Set(k string, v interface{}) { + i := hash(k) % uint64(len(m.buckets)) + if m.buckets[i].Set(k, v) { + atomic.AddInt64(&m.size, 1) + } +} + +func (m *Map) Delete(k string) { + i := hash(k) % uint64(len(m.buckets)) + if m.buckets[i].Delete(k) { + atomic.AddInt64(&m.size, -1) + } +} + +func (m *Map) Size() int64 { + return atomic.LoadInt64(&m.size) +} + +func (m *Map) ForEach(f func(k string, v interface{}) bool) { + for _, b := range m.buckets { + if !b.forEach(f) { + return + } + } +} + +func NewMap(bucketNum int) *Map { + if bucketNum <= 0 { + bucketNum = 64 + } + m := &Map{buckets: make([]*bucket, bucketNum)} + for i := 0; i < bucketNum; i++ { + m.buckets[i] = &bucket{values: map[string]interface{}{}} + } + return m +} + +func hash(k string) uint64 { + return xxhash.Sum64String(k) +} diff --git a/vendor/github.com/lesismal/llib/concurrent/map_test.go b/vendor/github.com/lesismal/llib/concurrent/map_test.go new file mode 100644 index 0000000..c3d67c1 --- /dev/null +++ b/vendor/github.com/lesismal/llib/concurrent/map_test.go @@ -0,0 +1,109 @@ +// Copyright 2020 lesismal. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package concurrent + +import ( + "fmt" + "log" + "testing" +) + +func TestMap(t *testing.T) { + m := NewMap(64) + size := 100000 + for i := 0; i < size; i++ { + k := fmt.Sprintf("key_%d", i) + v := fmt.Sprintf("value_%d", i) + vv, ok := m.Get(k) + if ok { + log.Fatalf("[%v] exists: '%v'", k, vv) + } + m.Set(k, v) + vv, ok = m.Get(k) + if !ok { + log.Fatalf("[%v] does not exist: '%v'", k, vv) + } + if v != vv { + log.Fatalf("invalid value: '%v' for key [%v] ", vv, k) + } + } + cnt := 0 + m.ForEach(func(k string, v interface{}) bool { + if k[3:] != (v.(string))[5:] { + log.Fatalf("invalid key-value: '%v', '%v'", k, v) + } + cnt++ + return true + }) + if cnt != size { + log.Fatalf("invalid ForEach num: %v, want: %v", cnt, size) + } + if m.Size() != int64(size) { + log.Fatalf("invalid size: %v, want: %v", m.Size(), size) + } + for i := 0; i < size; i++ { + k := fmt.Sprintf("key_%d", i) + v := fmt.Sprintf("value_%d", i) + vv, ok := m.Get(k) + if !ok { + log.Fatalf("[%v] does not exist: '%v'", k, vv) + } + if v != vv { + log.Fatalf("invalid value: '%v' for key [%v]", vv, k) + } + m.Delete(k) + if m.Size() != int64(size-i-1) { + log.Fatalf("invalid size: %v, want: %v", m.Size(), int64(size-i-1)) + } + } + for i := 0; i < size; i++ { + k := fmt.Sprintf("key_%d", i) + vv, ok := m.Get(k) + if ok { + log.Fatalf("[%v] exists: '%v'", k, vv) + } + if m.Size() != 0 { + log.Fatalf("invalid size: %v, want: %v", m.Size(), 0) + } + } +} + +func BenchmarkMapSet(b *testing.B) { + m := NewMap(64) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + k := fmt.Sprintf("key_%d", i) + v := fmt.Sprintf("value_%d", i) + m.Set(k, v) + } +} + +func BenchmarkMapGet(b *testing.B) { + m := NewMap(64) + + for i := 0; i < b.N; i++ { + k := fmt.Sprintf("key_%d", i) + v := fmt.Sprintf("value_%d", i) + m.Set(k, v) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + k := fmt.Sprintf("key_%d", i) + v := fmt.Sprintf("value_%d", i) + vv, ok := m.Get(k) + if !ok { + log.Fatalf("[%v] does not exist: '%v'", k, vv) + } + if v != vv { + log.Fatalf("invalid value: '%v' for key [%v], want: %v", vv, k, v) + } + } +} diff --git a/vendor/github.com/lesismal/llib/concurrent/mutex.go b/vendor/github.com/lesismal/llib/concurrent/mutex.go new file mode 100644 index 0000000..886b7b6 --- /dev/null +++ b/vendor/github.com/lesismal/llib/concurrent/mutex.go @@ -0,0 +1,56 @@ +// Copyright 2020 lesismal. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package concurrent + +import ( + "sync" +) + +var ( + _defaultMux = NewMutex() +) + +// Mutex . +type Mutex struct { + _mux sync.Mutex + _muxes map[interface{}]*sync.Mutex +} + +// Lock . +func (m *Mutex) Lock(key interface{}) { + m._mux.Lock() + mux, ok := m._muxes[key] + if !ok { + mux = &sync.Mutex{} + m._muxes[key] = mux + } + m._mux.Unlock() + mux.Lock() +} + +// Unlock . +func (m *Mutex) Unlock(key interface{}) { + m._mux.Lock() + mux, ok := m._muxes[key] + m._mux.Unlock() + if ok { + mux.Unlock() + } +} + +// NewMutex . +func NewMutex() *Mutex { + return &Mutex{_muxes: map[interface{}]*sync.Mutex{}} +} + +// // Lock . +// func Lock(key interface{}) { +// _defaultMux.Lock(key) +// } + +// // Unlock . +// func Unlock(key interface{}) { +// _defaultMux.Unlock(key) +// } diff --git a/vendor/github.com/lesismal/llib/concurrent/mutex_test.go b/vendor/github.com/lesismal/llib/concurrent/mutex_test.go new file mode 100644 index 0000000..10564d3 --- /dev/null +++ b/vendor/github.com/lesismal/llib/concurrent/mutex_test.go @@ -0,0 +1,26 @@ +// Copyright 2020 lesismal. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package concurrent + +import ( + "log" + "testing" + "time" +) + +func TestMutex(t *testing.T) { + mux := NewMutex() + muxPrint := func(id int) { + for i := 0; i < 3; i++ { + mux.Lock(1) + time.Sleep(time.Second / 100) + log.Println("mux print:", id, i) + mux.Unlock(1) + } + } + go muxPrint(2) + muxPrint(1) + time.Sleep(time.Second / 10) +} diff --git a/vendor/github.com/lesismal/llib/concurrent/rwmutex_test.go b/vendor/github.com/lesismal/llib/concurrent/rwmutex_test.go new file mode 100644 index 0000000..23bd5ec --- /dev/null +++ b/vendor/github.com/lesismal/llib/concurrent/rwmutex_test.go @@ -0,0 +1,38 @@ +// Copyright 2020 lesismal. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package concurrent + +import ( + "log" + "testing" + "time" +) + +func TestRWMutex(t *testing.T) { + rwmux := NewRWMutex() + rwmuxRLockPrint := func(id int) { + for i := 0; i < 3; i++ { + rwmux.RLock(2) + time.Sleep(time.Second / 100) + log.Println("rwmux print:", id, i) + rwmux.RUnlock(2) + } + } + go rwmuxRLockPrint(2) + rwmuxRLockPrint(1) + + rwmuxLockPrint := func(id int) { + for i := 0; i < 3; i++ { + rwmux.Lock(2) + time.Sleep(time.Second / 100) + log.Println("rwmux print:", id, i) + rwmux.Unlock(2) + } + } + go rwmuxLockPrint(2) + rwmuxLockPrint(1) + + time.Sleep(time.Second / 10) +} diff --git a/vendor/github.com/lesismal/llib/concurrent/rwmutext.go b/vendor/github.com/lesismal/llib/concurrent/rwmutext.go new file mode 100644 index 0000000..30d94db --- /dev/null +++ b/vendor/github.com/lesismal/llib/concurrent/rwmutext.go @@ -0,0 +1,88 @@ +// Copyright 2020 lesismal. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package concurrent + +import ( + "sync" +) + +var ( + _defaultRWMux = NewRWMutex() +) + +// RWMutex . +type RWMutex struct { + _mux sync.Mutex + _rwmuxes map[interface{}]*sync.RWMutex +} + +// Lock . +func (m *RWMutex) Lock(key interface{}) { + m._mux.Lock() + mux, ok := m._rwmuxes[key] + if !ok { + mux = &sync.RWMutex{} + m._rwmuxes[key] = mux + } + m._mux.Unlock() + mux.Lock() +} + +// Unlock . +func (m *RWMutex) Unlock(key interface{}) { + m._mux.Lock() + mux, ok := m._rwmuxes[key] + m._mux.Unlock() + if ok { + mux.Unlock() + } +} + +// RLock . +func (m *RWMutex) RLock(key interface{}) { + m._mux.Lock() + mux, ok := m._rwmuxes[key] + if !ok { + mux = &sync.RWMutex{} + m._rwmuxes[key] = mux + } + m._mux.Unlock() + mux.RLock() +} + +// RUnlock . +func (m *RWMutex) RUnlock(key interface{}) { + m._mux.Lock() + mux, ok := m._rwmuxes[key] + m._mux.Unlock() + if ok { + mux.RUnlock() + } +} + +// NewRWMutex . +func NewRWMutex() *RWMutex { + return &RWMutex{_rwmuxes: map[interface{}]*sync.RWMutex{}} +} + +// Lock . +func Lock(key interface{}) { + _defaultRWMux.Lock(key) +} + +// Unlock . +func Unlock(key interface{}) { + _defaultRWMux.Unlock(key) +} + +// RLock . +func RLock(key interface{}) { + _defaultRWMux.RLock(key) +} + +// RUnlock . +func RUnlock(key interface{}) { + _defaultRWMux.RUnlock(key) +} diff --git a/vendor/github.com/lesismal/llib/go.mod b/vendor/github.com/lesismal/llib/go.mod new file mode 100644 index 0000000..7fa6efb --- /dev/null +++ b/vendor/github.com/lesismal/llib/go.mod @@ -0,0 +1,9 @@ +module github.com/lesismal/llib + +go 1.16 + +require ( + github.com/cespare/xxhash v1.1.0 // indirect + golang.org/x/crypto v0.0.0-20210513122933-cd7d49e622d5 + golang.org/x/net v0.0.0-20210510120150-4163338589ed +) diff --git a/vendor/github.com/lesismal/llib/go.sum b/vendor/github.com/lesismal/llib/go.sum new file mode 100644 index 0000000..2969d96 --- /dev/null +++ b/vendor/github.com/lesismal/llib/go.sum @@ -0,0 +1,17 @@ +github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= +github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= +github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= +github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= +golang.org/x/crypto v0.0.0-20210513122933-cd7d49e622d5 h1:N6Jp/LCiEoIBX56BZSR2bepK5GtbSC2DDOYT742mMfE= +golang.org/x/crypto v0.0.0-20210513122933-cd7d49e622d5/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210510120150-4163338589ed h1:p9UgmWI9wKpfYmgaV/IZKGdXc5qEK45tDwwwDyjS26I= +golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da h1:b3NXsE2LusjYGGjL5bxEVZZORm/YEFFrWFjR8eFrw/c= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/alert.go b/vendor/github.com/lesismal/llib/std/crypto/tls/alert.go new file mode 100644 index 0000000..4790b73 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/alert.go @@ -0,0 +1,99 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import "strconv" + +type alert uint8 + +const ( + // alert level + alertLevelWarning = 1 + alertLevelError = 2 +) + +const ( + alertCloseNotify alert = 0 + alertUnexpectedMessage alert = 10 + alertBadRecordMAC alert = 20 + alertDecryptionFailed alert = 21 + alertRecordOverflow alert = 22 + alertDecompressionFailure alert = 30 + alertHandshakeFailure alert = 40 + alertBadCertificate alert = 42 + alertUnsupportedCertificate alert = 43 + alertCertificateRevoked alert = 44 + alertCertificateExpired alert = 45 + alertCertificateUnknown alert = 46 + alertIllegalParameter alert = 47 + alertUnknownCA alert = 48 + alertAccessDenied alert = 49 + alertDecodeError alert = 50 + alertDecryptError alert = 51 + alertExportRestriction alert = 60 + alertProtocolVersion alert = 70 + alertInsufficientSecurity alert = 71 + alertInternalError alert = 80 + alertInappropriateFallback alert = 86 + alertUserCanceled alert = 90 + alertNoRenegotiation alert = 100 + alertMissingExtension alert = 109 + alertUnsupportedExtension alert = 110 + alertCertificateUnobtainable alert = 111 + alertUnrecognizedName alert = 112 + alertBadCertificateStatusResponse alert = 113 + alertBadCertificateHashValue alert = 114 + alertUnknownPSKIdentity alert = 115 + alertCertificateRequired alert = 116 + alertNoApplicationProtocol alert = 120 +) + +var alertText = map[alert]string{ + alertCloseNotify: "close notify", + alertUnexpectedMessage: "unexpected message", + alertBadRecordMAC: "bad record MAC", + alertDecryptionFailed: "decryption failed", + alertRecordOverflow: "record overflow", + alertDecompressionFailure: "decompression failure", + alertHandshakeFailure: "handshake failure", + alertBadCertificate: "bad certificate", + alertUnsupportedCertificate: "unsupported certificate", + alertCertificateRevoked: "revoked certificate", + alertCertificateExpired: "expired certificate", + alertCertificateUnknown: "unknown certificate", + alertIllegalParameter: "illegal parameter", + alertUnknownCA: "unknown certificate authority", + alertAccessDenied: "access denied", + alertDecodeError: "error decoding message", + alertDecryptError: "error decrypting message", + alertExportRestriction: "export restriction", + alertProtocolVersion: "protocol version not supported", + alertInsufficientSecurity: "insufficient security level", + alertInternalError: "internal error", + alertInappropriateFallback: "inappropriate fallback", + alertUserCanceled: "user canceled", + alertNoRenegotiation: "no renegotiation", + alertMissingExtension: "missing extension", + alertUnsupportedExtension: "unsupported extension", + alertCertificateUnobtainable: "certificate unobtainable", + alertUnrecognizedName: "unrecognized name", + alertBadCertificateStatusResponse: "bad certificate status response", + alertBadCertificateHashValue: "bad certificate hash value", + alertUnknownPSKIdentity: "unknown PSK identity", + alertCertificateRequired: "certificate required", + alertNoApplicationProtocol: "no application protocol", +} + +func (e alert) String() string { + s, ok := alertText[e] + if ok { + return "tls: " + s + } + return "tls: alert(" + strconv.Itoa(int(e)) + ")" +} + +func (e alert) Error() string { + return e.String() +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/auth.go b/vendor/github.com/lesismal/llib/std/crypto/tls/auth.go new file mode 100644 index 0000000..a9df0da --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/auth.go @@ -0,0 +1,289 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rsa" + "errors" + "fmt" + "hash" + "io" +) + +// verifyHandshakeSignature verifies a signature against pre-hashed +// (if required) handshake contents. +func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error { + switch sigType { + case signatureECDSA: + pubKey, ok := pubkey.(*ecdsa.PublicKey) + if !ok { + return fmt.Errorf("expected an ECDSA public key, got %T", pubkey) + } + if !ecdsa.VerifyASN1(pubKey, signed, sig) { + return errors.New("ECDSA verification failure") + } + case signatureEd25519: + pubKey, ok := pubkey.(ed25519.PublicKey) + if !ok { + return fmt.Errorf("expected an Ed25519 public key, got %T", pubkey) + } + if !ed25519.Verify(pubKey, signed, sig) { + return errors.New("Ed25519 verification failure") + } + case signaturePKCS1v15: + pubKey, ok := pubkey.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("expected an RSA public key, got %T", pubkey) + } + if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, signed, sig); err != nil { + return err + } + case signatureRSAPSS: + pubKey, ok := pubkey.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("expected an RSA public key, got %T", pubkey) + } + signOpts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash} + if err := rsa.VerifyPSS(pubKey, hashFunc, signed, sig, signOpts); err != nil { + return err + } + default: + return errors.New("internal error: unknown signature type") + } + return nil +} + +const ( + serverSignatureContext = "TLS 1.3, server CertificateVerify\x00" + clientSignatureContext = "TLS 1.3, client CertificateVerify\x00" +) + +var signaturePadding = []byte{ + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, +} + +// signedMessage returns the pre-hashed (if necessary) message to be signed by +// certificate keys in TLS 1.3. See RFC 8446, Section 4.4.3. +func signedMessage(sigHash crypto.Hash, context string, transcript hash.Hash) []byte { + if sigHash == directSigning { + b := &bytes.Buffer{} + b.Write(signaturePadding) + io.WriteString(b, context) + b.Write(transcript.Sum(nil)) + return b.Bytes() + } + h := sigHash.New() + h.Write(signaturePadding) + io.WriteString(h, context) + h.Write(transcript.Sum(nil)) + return h.Sum(nil) +} + +// typeAndHashFromSignatureScheme returns the corresponding signature type and +// crypto.Hash for a given TLS SignatureScheme. +func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType uint8, hash crypto.Hash, err error) { + switch signatureAlgorithm { + case PKCS1WithSHA1, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512: + sigType = signaturePKCS1v15 + case PSSWithSHA256, PSSWithSHA384, PSSWithSHA512: + sigType = signatureRSAPSS + case ECDSAWithSHA1, ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512: + sigType = signatureECDSA + case Ed25519: + sigType = signatureEd25519 + default: + return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) + } + switch signatureAlgorithm { + case PKCS1WithSHA1, ECDSAWithSHA1: + hash = crypto.SHA1 + case PKCS1WithSHA256, PSSWithSHA256, ECDSAWithP256AndSHA256: + hash = crypto.SHA256 + case PKCS1WithSHA384, PSSWithSHA384, ECDSAWithP384AndSHA384: + hash = crypto.SHA384 + case PKCS1WithSHA512, PSSWithSHA512, ECDSAWithP521AndSHA512: + hash = crypto.SHA512 + case Ed25519: + hash = directSigning + default: + return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) + } + return sigType, hash, nil +} + +// legacyTypeAndHashFromPublicKey returns the fixed signature type and crypto.Hash for +// a given public key used with TLS 1.0 and 1.1, before the introduction of +// signature algorithm negotiation. +func legacyTypeAndHashFromPublicKey(pub crypto.PublicKey) (sigType uint8, hash crypto.Hash, err error) { + switch pub.(type) { + case *rsa.PublicKey: + return signaturePKCS1v15, crypto.MD5SHA1, nil + case *ecdsa.PublicKey: + return signatureECDSA, crypto.SHA1, nil + case ed25519.PublicKey: + // RFC 8422 specifies support for Ed25519 in TLS 1.0 and 1.1, + // but it requires holding on to a handshake transcript to do a + // full signature, and not even OpenSSL bothers with the + // complexity, so we can't even test it properly. + return 0, 0, fmt.Errorf("tls: Ed25519 public keys are not supported before TLS 1.2") + default: + return 0, 0, fmt.Errorf("tls: unsupported public key: %T", pub) + } +} + +var rsaSignatureSchemes = []struct { + scheme SignatureScheme + minModulusBytes int + maxVersion uint16 +}{ + // RSA-PSS is used with PSSSaltLengthEqualsHash, and requires + // emLen >= hLen + sLen + 2 + {PSSWithSHA256, crypto.SHA256.Size()*2 + 2, VersionTLS13}, + {PSSWithSHA384, crypto.SHA384.Size()*2 + 2, VersionTLS13}, + {PSSWithSHA512, crypto.SHA512.Size()*2 + 2, VersionTLS13}, + // PKCS #1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires + // emLen >= len(prefix) + hLen + 11 + // TLS 1.3 dropped support for PKCS #1 v1.5 in favor of RSA-PSS. + {PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11, VersionTLS12}, + {PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11, VersionTLS12}, + {PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11, VersionTLS12}, + {PKCS1WithSHA1, 15 + crypto.SHA1.Size() + 11, VersionTLS12}, +} + +// signatureSchemesForCertificate returns the list of supported SignatureSchemes +// for a given certificate, based on the public key and the protocol version, +// and optionally filtered by its explicit SupportedSignatureAlgorithms. +// +// This function must be kept in sync with supportedSignatureAlgorithms. +func signatureSchemesForCertificate(version uint16, cert *Certificate) []SignatureScheme { + priv, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return nil + } + + var sigAlgs []SignatureScheme + switch pub := priv.Public().(type) { + case *ecdsa.PublicKey: + if version != VersionTLS13 { + // In TLS 1.2 and earlier, ECDSA algorithms are not + // constrained to a single curve. + sigAlgs = []SignatureScheme{ + ECDSAWithP256AndSHA256, + ECDSAWithP384AndSHA384, + ECDSAWithP521AndSHA512, + ECDSAWithSHA1, + } + break + } + switch pub.Curve { + case elliptic.P256(): + sigAlgs = []SignatureScheme{ECDSAWithP256AndSHA256} + case elliptic.P384(): + sigAlgs = []SignatureScheme{ECDSAWithP384AndSHA384} + case elliptic.P521(): + sigAlgs = []SignatureScheme{ECDSAWithP521AndSHA512} + default: + return nil + } + case *rsa.PublicKey: + size := pub.Size() + sigAlgs = make([]SignatureScheme, 0, len(rsaSignatureSchemes)) + for _, candidate := range rsaSignatureSchemes { + if size >= candidate.minModulusBytes && version <= candidate.maxVersion { + sigAlgs = append(sigAlgs, candidate.scheme) + } + } + case ed25519.PublicKey: + sigAlgs = []SignatureScheme{Ed25519} + default: + return nil + } + + if cert.SupportedSignatureAlgorithms != nil { + var filteredSigAlgs []SignatureScheme + for _, sigAlg := range sigAlgs { + if isSupportedSignatureAlgorithm(sigAlg, cert.SupportedSignatureAlgorithms) { + filteredSigAlgs = append(filteredSigAlgs, sigAlg) + } + } + return filteredSigAlgs + } + return sigAlgs +} + +// selectSignatureScheme picks a SignatureScheme from the peer's preference list +// that works with the selected certificate. It's only called for protocol +// versions that support signature algorithms, so TLS 1.2 and 1.3. +func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureScheme) (SignatureScheme, error) { + supportedAlgs := signatureSchemesForCertificate(vers, c) + if len(supportedAlgs) == 0 { + return 0, unsupportedCertificateError(c) + } + if len(peerAlgs) == 0 && vers == VersionTLS12 { + // For TLS 1.2, if the client didn't send signature_algorithms then we + // can assume that it supports SHA1. See RFC 5246, Section 7.4.1.4.1. + peerAlgs = []SignatureScheme{PKCS1WithSHA1, ECDSAWithSHA1} + } + // Pick signature scheme in the peer's preference order, as our + // preference order is not configurable. + for _, preferredAlg := range peerAlgs { + if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) { + return preferredAlg, nil + } + } + return 0, errors.New("tls: peer doesn't support any of the certificate's signature algorithms") +} + +// unsupportedCertificateError returns a helpful error for certificates with +// an unsupported private key. +func unsupportedCertificateError(cert *Certificate) error { + switch cert.PrivateKey.(type) { + case rsa.PrivateKey, ecdsa.PrivateKey: + return fmt.Errorf("tls: unsupported certificate: private key is %T, expected *%T", + cert.PrivateKey, cert.PrivateKey) + case *ed25519.PrivateKey: + return fmt.Errorf("tls: unsupported certificate: private key is *ed25519.PrivateKey, expected ed25519.PrivateKey") + } + + signer, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return fmt.Errorf("tls: certificate private key (%T) does not implement crypto.Signer", + cert.PrivateKey) + } + + switch pub := signer.Public().(type) { + case *ecdsa.PublicKey: + switch pub.Curve { + case elliptic.P256(): + case elliptic.P384(): + case elliptic.P521(): + default: + return fmt.Errorf("tls: unsupported certificate curve (%s)", pub.Curve.Params().Name) + } + case *rsa.PublicKey: + return fmt.Errorf("tls: certificate RSA key size too small for supported signature algorithms") + case ed25519.PublicKey: + default: + return fmt.Errorf("tls: unsupported certificate key (%T)", pub) + } + + if cert.SupportedSignatureAlgorithms != nil { + return fmt.Errorf("tls: peer doesn't support the certificate custom signature algorithms") + } + + return fmt.Errorf("tls: internal error: unsupported key (%T)", cert.PrivateKey) +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/auth_test.go b/vendor/github.com/lesismal/llib/std/crypto/tls/auth_test.go new file mode 100644 index 0000000..c42e349 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/auth_test.go @@ -0,0 +1,168 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "crypto" + "testing" +) + +func TestSignatureSelection(t *testing.T) { + rsaCert := &Certificate{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: testRSAPrivateKey, + } + pkcs1Cert := &Certificate{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: testRSAPrivateKey, + SupportedSignatureAlgorithms: []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256}, + } + ecdsaCert := &Certificate{ + Certificate: [][]byte{testP256Certificate}, + PrivateKey: testP256PrivateKey, + } + ed25519Cert := &Certificate{ + Certificate: [][]byte{testEd25519Certificate}, + PrivateKey: testEd25519PrivateKey, + } + + tests := []struct { + cert *Certificate + peerSigAlgs []SignatureScheme + tlsVersion uint16 + + expectedSigAlg SignatureScheme + expectedSigType uint8 + expectedHash crypto.Hash + }{ + {rsaCert, []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256}, VersionTLS12, PKCS1WithSHA1, signaturePKCS1v15, crypto.SHA1}, + {rsaCert, []SignatureScheme{PKCS1WithSHA512, PKCS1WithSHA1}, VersionTLS12, PKCS1WithSHA512, signaturePKCS1v15, crypto.SHA512}, + {rsaCert, []SignatureScheme{PSSWithSHA256, PKCS1WithSHA256}, VersionTLS12, PSSWithSHA256, signatureRSAPSS, crypto.SHA256}, + {pkcs1Cert, []SignatureScheme{PSSWithSHA256, PKCS1WithSHA256}, VersionTLS12, PKCS1WithSHA256, signaturePKCS1v15, crypto.SHA256}, + {rsaCert, []SignatureScheme{PSSWithSHA384, PKCS1WithSHA1}, VersionTLS13, PSSWithSHA384, signatureRSAPSS, crypto.SHA384}, + {ecdsaCert, []SignatureScheme{ECDSAWithSHA1}, VersionTLS12, ECDSAWithSHA1, signatureECDSA, crypto.SHA1}, + {ecdsaCert, []SignatureScheme{ECDSAWithP256AndSHA256}, VersionTLS12, ECDSAWithP256AndSHA256, signatureECDSA, crypto.SHA256}, + {ecdsaCert, []SignatureScheme{ECDSAWithP256AndSHA256}, VersionTLS13, ECDSAWithP256AndSHA256, signatureECDSA, crypto.SHA256}, + {ed25519Cert, []SignatureScheme{Ed25519}, VersionTLS12, Ed25519, signatureEd25519, directSigning}, + {ed25519Cert, []SignatureScheme{Ed25519}, VersionTLS13, Ed25519, signatureEd25519, directSigning}, + + // TLS 1.2 without signature_algorithms extension + {rsaCert, nil, VersionTLS12, PKCS1WithSHA1, signaturePKCS1v15, crypto.SHA1}, + {ecdsaCert, nil, VersionTLS12, ECDSAWithSHA1, signatureECDSA, crypto.SHA1}, + + // TLS 1.2 does not restrict the ECDSA curve (our ecdsaCert is P-256) + {ecdsaCert, []SignatureScheme{ECDSAWithP384AndSHA384}, VersionTLS12, ECDSAWithP384AndSHA384, signatureECDSA, crypto.SHA384}, + } + + for testNo, test := range tests { + sigAlg, err := selectSignatureScheme(test.tlsVersion, test.cert, test.peerSigAlgs) + if err != nil { + t.Errorf("test[%d]: unexpected selectSignatureScheme error: %v", testNo, err) + } + if test.expectedSigAlg != sigAlg { + t.Errorf("test[%d]: expected signature scheme %v, got %v", testNo, test.expectedSigAlg, sigAlg) + } + sigType, hashFunc, err := typeAndHashFromSignatureScheme(sigAlg) + if err != nil { + t.Errorf("test[%d]: unexpected typeAndHashFromSignatureScheme error: %v", testNo, err) + } + if test.expectedSigType != sigType { + t.Errorf("test[%d]: expected signature algorithm %#x, got %#x", testNo, test.expectedSigType, sigType) + } + if test.expectedHash != hashFunc { + t.Errorf("test[%d]: expected hash function %#x, got %#x", testNo, test.expectedHash, hashFunc) + } + } + + brokenCert := &Certificate{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: testRSAPrivateKey, + SupportedSignatureAlgorithms: []SignatureScheme{Ed25519}, + } + + badTests := []struct { + cert *Certificate + peerSigAlgs []SignatureScheme + tlsVersion uint16 + }{ + {rsaCert, []SignatureScheme{ECDSAWithP256AndSHA256, ECDSAWithSHA1}, VersionTLS12}, + {ecdsaCert, []SignatureScheme{PKCS1WithSHA256, PKCS1WithSHA1}, VersionTLS12}, + {rsaCert, []SignatureScheme{0}, VersionTLS12}, + {ed25519Cert, []SignatureScheme{ECDSAWithP256AndSHA256, ECDSAWithSHA1}, VersionTLS12}, + {ecdsaCert, []SignatureScheme{Ed25519}, VersionTLS12}, + {brokenCert, []SignatureScheme{Ed25519}, VersionTLS12}, + {brokenCert, []SignatureScheme{PKCS1WithSHA256}, VersionTLS12}, + // RFC 5246, Section 7.4.1.4.1, says to only consider {sha1,ecdsa} as + // default when the extension is missing, and RFC 8422 does not update + // it. Anyway, if a stack supports Ed25519 it better support sigalgs. + {ed25519Cert, nil, VersionTLS12}, + // TLS 1.3 has no default signature_algorithms. + {rsaCert, nil, VersionTLS13}, + {ecdsaCert, nil, VersionTLS13}, + {ed25519Cert, nil, VersionTLS13}, + // Wrong curve, which TLS 1.3 checks + {ecdsaCert, []SignatureScheme{ECDSAWithP384AndSHA384}, VersionTLS13}, + // TLS 1.3 does not support PKCS1v1.5 or SHA-1. + {rsaCert, []SignatureScheme{PKCS1WithSHA256}, VersionTLS13}, + {pkcs1Cert, []SignatureScheme{PSSWithSHA256, PKCS1WithSHA256}, VersionTLS13}, + {ecdsaCert, []SignatureScheme{ECDSAWithSHA1}, VersionTLS13}, + // The key can be too small for the hash. + {rsaCert, []SignatureScheme{PSSWithSHA512}, VersionTLS12}, + } + + for testNo, test := range badTests { + sigAlg, err := selectSignatureScheme(test.tlsVersion, test.cert, test.peerSigAlgs) + if err == nil { + t.Errorf("test[%d]: unexpected success, got %v", testNo, sigAlg) + } + } +} + +func TestLegacyTypeAndHash(t *testing.T) { + sigType, hashFunc, err := legacyTypeAndHashFromPublicKey(testRSAPrivateKey.Public()) + if err != nil { + t.Errorf("RSA: unexpected error: %v", err) + } + if expectedSigType := signaturePKCS1v15; expectedSigType != sigType { + t.Errorf("RSA: expected signature type %#x, got %#x", expectedSigType, sigType) + } + if expectedHashFunc := crypto.MD5SHA1; expectedHashFunc != hashFunc { + t.Errorf("RSA: expected hash %#x, got %#x", expectedHashFunc, hashFunc) + } + + sigType, hashFunc, err = legacyTypeAndHashFromPublicKey(testECDSAPrivateKey.Public()) + if err != nil { + t.Errorf("ECDSA: unexpected error: %v", err) + } + if expectedSigType := signatureECDSA; expectedSigType != sigType { + t.Errorf("ECDSA: expected signature type %#x, got %#x", expectedSigType, sigType) + } + if expectedHashFunc := crypto.SHA1; expectedHashFunc != hashFunc { + t.Errorf("ECDSA: expected hash %#x, got %#x", expectedHashFunc, hashFunc) + } + + // Ed25519 is not supported by TLS 1.0 and 1.1. + _, _, err = legacyTypeAndHashFromPublicKey(testEd25519PrivateKey.Public()) + if err == nil { + t.Errorf("Ed25519: unexpected success") + } +} + +// TestSupportedSignatureAlgorithms checks that all supportedSignatureAlgorithms +// have valid type and hash information. +func TestSupportedSignatureAlgorithms(t *testing.T) { + for _, sigAlg := range supportedSignatureAlgorithms { + sigType, hash, err := typeAndHashFromSignatureScheme(sigAlg) + if err != nil { + t.Errorf("%v: unexpected error: %v", sigAlg, err) + } + if sigType == 0 { + t.Errorf("%v: missing signature type", sigAlg) + } + if hash == 0 && sigAlg != Ed25519 { + t.Errorf("%v: missing hash", sigAlg) + } + } +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/cipher_suites.go b/vendor/github.com/lesismal/llib/std/crypto/tls/cipher_suites.go new file mode 100644 index 0000000..9a35675 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/cipher_suites.go @@ -0,0 +1,516 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/hmac" + "crypto/rc4" + "crypto/sha1" + "crypto/sha256" + "crypto/x509" + "fmt" + "hash" + + "golang.org/x/crypto/chacha20poly1305" +) + +// CipherSuite is a TLS cipher suite. Note that most functions in this package +// accept and expose cipher suite IDs instead of this type. +type CipherSuite struct { + ID uint16 + Name string + + // Supported versions is the list of TLS protocol versions that can + // negotiate this cipher suite. + SupportedVersions []uint16 + + // Insecure is true if the cipher suite has known security issues + // due to its primitives, design, or implementation. + Insecure bool +} + +var ( + supportedUpToTLS12 = []uint16{VersionTLS10, VersionTLS11, VersionTLS12} + supportedOnlyTLS12 = []uint16{VersionTLS12} + supportedOnlyTLS13 = []uint16{VersionTLS13} +) + +// CipherSuites returns a list of cipher suites currently implemented by this +// package, excluding those with security issues, which are returned by +// InsecureCipherSuites. +// +// The list is sorted by ID. Note that the default cipher suites selected by +// this package might depend on logic that can't be captured by a static list. +func CipherSuites() []*CipherSuite { + return []*CipherSuite{ + {TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, false}, + {TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, + {TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, + {TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, + {TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, + + {TLS_AES_128_GCM_SHA256, "TLS_AES_128_GCM_SHA256", supportedOnlyTLS13, false}, + {TLS_AES_256_GCM_SHA384, "TLS_AES_256_GCM_SHA384", supportedOnlyTLS13, false}, + {TLS_CHACHA20_POLY1305_SHA256, "TLS_CHACHA20_POLY1305_SHA256", supportedOnlyTLS13, false}, + + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, + {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, + } +} + +// InsecureCipherSuites returns a list of cipher suites currently implemented by +// this package and which have security issues. +// +// Most applications should not use the cipher suites in this list, and should +// only use those returned by CipherSuites. +func InsecureCipherSuites() []*CipherSuite { + // RC4 suites are broken because RC4 is. + // CBC-SHA256 suites have no Lucky13 countermeasures. + return []*CipherSuite{ + {TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, + {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, + } +} + +// CipherSuiteName returns the standard name for the passed cipher suite ID +// (e.g. "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"), or a fallback representation +// of the ID value if the cipher suite is not implemented by this package. +func CipherSuiteName(id uint16) string { + for _, c := range CipherSuites() { + if c.ID == id { + return c.Name + } + } + for _, c := range InsecureCipherSuites() { + if c.ID == id { + return c.Name + } + } + return fmt.Sprintf("0x%04X", id) +} + +// a keyAgreement implements the client and server side of a TLS key agreement +// protocol by generating and processing key exchange messages. +type keyAgreement interface { + // On the server side, the first two methods are called in order. + + // In the case that the key agreement protocol doesn't use a + // ServerKeyExchange message, generateServerKeyExchange can return nil, + // nil. + generateServerKeyExchange(*Config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error) + processClientKeyExchange(*Config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error) + + // On the client side, the next two methods are called in order. + + // This method may not be called if the server doesn't send a + // ServerKeyExchange message. + processServerKeyExchange(*Config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error + generateClientKeyExchange(*Config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) +} + +const ( + // suiteECDHE indicates that the cipher suite involves elliptic curve + // Diffie-Hellman. This means that it should only be selected when the + // client indicates that it supports ECC with a curve and point format + // that we're happy with. + suiteECDHE = 1 << iota + // suiteECSign indicates that the cipher suite involves an ECDSA or + // EdDSA signature and therefore may only be selected when the server's + // certificate is ECDSA or EdDSA. If this is not set then the cipher suite + // is RSA based. + suiteECSign + // suiteTLS12 indicates that the cipher suite should only be advertised + // and accepted when using TLS 1.2. + suiteTLS12 + // suiteSHA384 indicates that the cipher suite uses SHA384 as the + // handshake hash. + suiteSHA384 + // suiteDefaultOff indicates that this cipher suite is not included by + // default. + suiteDefaultOff +) + +// A cipherSuite is a specific combination of key agreement, cipher and MAC function. +type cipherSuite struct { + id uint16 + // the lengths, in bytes, of the key material needed for each component. + keyLen int + macLen int + ivLen int + ka func(version uint16) keyAgreement + // flags is a bitmask of the suite* values, above. + flags int + cipher func(key, iv []byte, isRead bool) interface{} + mac func(key []byte) hash.Hash + aead func(key, fixedNonce []byte) aead +} + +var cipherSuites = []*cipherSuite{ + // Ciphersuite order is chosen so that ECDHE comes before plain RSA and + // AEADs are the top preference. + {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, + {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, + {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, + {TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil}, + {TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, + {TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil}, + {TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil}, + + // RC4-based cipher suites are disabled by default. + {TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, suiteDefaultOff, cipherRC4, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE | suiteDefaultOff, cipherRC4, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteDefaultOff, cipherRC4, macSHA1, nil}, +} + +// selectCipherSuite returns the first cipher suite from ids which is also in +// supportedIDs and passes the ok filter. +func selectCipherSuite(ids, supportedIDs []uint16, ok func(*cipherSuite) bool) *cipherSuite { + for _, id := range ids { + candidate := cipherSuiteByID(id) + if candidate == nil || !ok(candidate) { + continue + } + + for _, suppID := range supportedIDs { + if id == suppID { + return candidate + } + } + } + return nil +} + +// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash +// algorithm to be used with HKDF. See RFC 8446, Appendix B.4. +type cipherSuiteTLS13 struct { + id uint16 + keyLen int + aead func(key, fixedNonce []byte) aead + hash crypto.Hash +} + +var cipherSuitesTLS13 = []*cipherSuiteTLS13{ + {TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256}, + {TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256}, + {TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384}, +} + +func cipherRC4(key, iv []byte, isRead bool) interface{} { + cipher, _ := rc4.NewCipher(key) + return cipher +} + +func cipher3DES(key, iv []byte, isRead bool) interface{} { + block, _ := des.NewTripleDESCipher(key) + if isRead { + return cipher.NewCBCDecrypter(block, iv) + } + return cipher.NewCBCEncrypter(block, iv) +} + +func cipherAES(key, iv []byte, isRead bool) interface{} { + block, _ := aes.NewCipher(key) + if isRead { + return cipher.NewCBCDecrypter(block, iv) + } + return cipher.NewCBCEncrypter(block, iv) +} + +// macSHA1 returns a SHA-1 based constant time MAC. +func macSHA1(key []byte) hash.Hash { + return hmac.New(newConstantTimeHash(sha1.New), key) +} + +// macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and +// is currently only used in disabled-by-default cipher suites. +func macSHA256(key []byte) hash.Hash { + return hmac.New(sha256.New, key) +} + +type aead interface { + cipher.AEAD + + // explicitNonceLen returns the number of bytes of explicit nonce + // included in each record. This is eight for older AEADs and + // zero for modern ones. + explicitNonceLen() int +} + +const ( + aeadNonceLength = 12 + noncePrefixLength = 4 +) + +// prefixNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to +// each call. +type prefixNonceAEAD struct { + // nonce contains the fixed part of the nonce in the first four bytes. + nonce [aeadNonceLength]byte + aead cipher.AEAD +} + +func (f *prefixNonceAEAD) NonceSize() int { return aeadNonceLength - noncePrefixLength } +func (f *prefixNonceAEAD) Overhead() int { return f.aead.Overhead() } +func (f *prefixNonceAEAD) explicitNonceLen() int { return f.NonceSize() } + +func (f *prefixNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { + copy(f.nonce[4:], nonce) + return f.aead.Seal(out, f.nonce[:], plaintext, additionalData) +} + +func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { + copy(f.nonce[4:], nonce) + return f.aead.Open(out, f.nonce[:], ciphertext, additionalData) +} + +// xoredNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce +// before each call. +type xorNonceAEAD struct { + nonceMask [aeadNonceLength]byte + aead cipher.AEAD +} + +func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number +func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() } +func (f *xorNonceAEAD) explicitNonceLen() int { return 0 } + +func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData) + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + + return result +} + +func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData) + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + + return result, err +} + +func aeadAESGCM(key, noncePrefix []byte) aead { + if len(noncePrefix) != noncePrefixLength { + panic("tls: internal error: wrong nonce length") + } + aes, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + + ret := &prefixNonceAEAD{aead: aead} + copy(ret.nonce[:], noncePrefix) + return ret +} + +func aeadAESGCMTLS13(key, nonceMask []byte) aead { + if len(nonceMask) != aeadNonceLength { + panic("tls: internal error: wrong nonce length") + } + aes, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + + ret := &xorNonceAEAD{aead: aead} + copy(ret.nonceMask[:], nonceMask) + return ret +} + +func aeadChaCha20Poly1305(key, nonceMask []byte) aead { + if len(nonceMask) != aeadNonceLength { + panic("tls: internal error: wrong nonce length") + } + aead, err := chacha20poly1305.New(key) + if err != nil { + panic(err) + } + + ret := &xorNonceAEAD{aead: aead} + copy(ret.nonceMask[:], nonceMask) + return ret +} + +type constantTimeHash interface { + hash.Hash + ConstantTimeSum(b []byte) []byte +} + +// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, and replaces +// with that all calls to Sum. It's used to obtain a ConstantTimeSum-based HMAC. +type cthWrapper struct { + h constantTimeHash +} + +func (c *cthWrapper) Size() int { return c.h.Size() } +func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() } +func (c *cthWrapper) Reset() { c.h.Reset() } +func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) } +func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) } + +func newConstantTimeHash(h func() hash.Hash) func() hash.Hash { + return func() hash.Hash { + return &cthWrapper{h().(constantTimeHash)} + } +} + +// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3. +func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte { + h.Reset() + h.Write(seq) + h.Write(header) + h.Write(data) + res := h.Sum(out) + if extra != nil { + h.Write(extra) + } + return res +} + +func rsaKA(version uint16) keyAgreement { + return rsaKeyAgreement{} +} + +func ecdheECDSAKA(version uint16) keyAgreement { + return &ecdheKeyAgreement{ + isRSA: false, + version: version, + } +} + +func ecdheRSAKA(version uint16) keyAgreement { + return &ecdheKeyAgreement{ + isRSA: true, + version: version, + } +} + +// mutualCipherSuite returns a cipherSuite given a list of supported +// ciphersuites and the id requested by the peer. +func mutualCipherSuite(have []uint16, want uint16) *cipherSuite { + for _, id := range have { + if id == want { + return cipherSuiteByID(id) + } + } + return nil +} + +func cipherSuiteByID(id uint16) *cipherSuite { + for _, cipherSuite := range cipherSuites { + if cipherSuite.id == id { + return cipherSuite + } + } + return nil +} + +func mutualCipherSuiteTLS13(have []uint16, want uint16) *cipherSuiteTLS13 { + for _, id := range have { + if id == want { + return cipherSuiteTLS13ByID(id) + } + } + return nil +} + +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 { + for _, cipherSuite := range cipherSuitesTLS13 { + if cipherSuite.id == id { + return cipherSuite + } + } + return nil +} + +// A list of cipher suite IDs that are, or have been, implemented by this +// package. +// +// See https://www.iana.org/assignments/tls-parameters/tls-parameters.xml +const ( + // TLS 1.0 - 1.2 cipher suites. + TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 + TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a + TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f + TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 + TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003c + TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009c + TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009d + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009 + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a + TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011 + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012 + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013 + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc023 + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc027 + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030 + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca8 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca9 + + // TLS 1.3 cipher suites. + TLS_AES_128_GCM_SHA256 uint16 = 0x1301 + TLS_AES_256_GCM_SHA384 uint16 = 0x1302 + TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303 + + // TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator + // that the client is doing version fallback. See RFC 7507. + TLS_FALLBACK_SCSV uint16 = 0x5600 + + // Legacy names for the corresponding cipher suites with the correct _SHA256 + // suffix, retained for backward compatibility. + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 +) diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/common.go b/vendor/github.com/lesismal/llib/std/crypto/tls/common.go new file mode 100644 index 0000000..4ee294a --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/common.go @@ -0,0 +1,1563 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "container/list" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/sha512" + "crypto/x509" + "errors" + "fmt" + "io" + "net" + "runtime" + "sort" + "strings" + "sync" + "time" + + "github.com/lesismal/llib/std/internal/cpu" +) + +const ( + VersionTLS10 = 0x0301 + VersionTLS11 = 0x0302 + VersionTLS12 = 0x0303 + VersionTLS13 = 0x0304 + + // Deprecated: SSLv3 is cryptographically broken, and is no longer + // supported by this package. See golang.org/issue/32716. + VersionSSL30 = 0x0300 +) + +const ( + maxPlaintext = 16384 // maximum plaintext payload length + maxCiphertext = 16384 + 2048 // maximum ciphertext payload length + maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3 + recordHeaderLen = 5 // record header length + maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) + maxUselessRecords = 16 // maximum number of consecutive non-advancing records +) + +// TLS record types. +type recordType uint8 + +const ( + recordTypeChangeCipherSpec recordType = 20 + recordTypeAlert recordType = 21 + recordTypeHandshake recordType = 22 + recordTypeApplicationData recordType = 23 +) + +// TLS handshake message types. +const ( + typeHelloRequest uint8 = 0 + typeClientHello uint8 = 1 + typeServerHello uint8 = 2 + typeNewSessionTicket uint8 = 4 + typeEndOfEarlyData uint8 = 5 + typeEncryptedExtensions uint8 = 8 + typeCertificate uint8 = 11 + typeServerKeyExchange uint8 = 12 + typeCertificateRequest uint8 = 13 + typeServerHelloDone uint8 = 14 + typeCertificateVerify uint8 = 15 + typeClientKeyExchange uint8 = 16 + typeFinished uint8 = 20 + typeCertificateStatus uint8 = 22 + typeKeyUpdate uint8 = 24 + typeNextProtocol uint8 = 67 // Not IANA assigned + typeMessageHash uint8 = 254 // synthetic message +) + +// TLS compression types. +const ( + compressionNone uint8 = 0 +) + +// TLS extension numbers +const ( + extensionServerName uint16 = 0 + extensionStatusRequest uint16 = 5 + extensionSupportedCurves uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7 + extensionSupportedPoints uint16 = 11 + extensionSignatureAlgorithms uint16 = 13 + extensionALPN uint16 = 16 + extensionSCT uint16 = 18 + extensionSessionTicket uint16 = 35 + extensionPreSharedKey uint16 = 41 + extensionEarlyData uint16 = 42 + extensionSupportedVersions uint16 = 43 + extensionCookie uint16 = 44 + extensionPSKModes uint16 = 45 + extensionCertificateAuthorities uint16 = 47 + extensionSignatureAlgorithmsCert uint16 = 50 + extensionKeyShare uint16 = 51 + extensionRenegotiationInfo uint16 = 0xff01 +) + +// TLS signaling cipher suite values +const ( + scsvRenegotiation uint16 = 0x00ff +) + +// CurveID is the type of a TLS identifier for an elliptic curve. See +// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8. +// +// In TLS 1.3, this type is called NamedGroup, but at this time this library +// only supports Elliptic Curve based groups. See RFC 8446, Section 4.2.7. +type CurveID uint16 + +const ( + CurveP256 CurveID = 23 + CurveP384 CurveID = 24 + CurveP521 CurveID = 25 + X25519 CurveID = 29 +) + +// TLS 1.3 Key Share. See RFC 8446, Section 4.2.8. +type keyShare struct { + group CurveID + data []byte +} + +// TLS 1.3 PSK Key Exchange Modes. See RFC 8446, Section 4.2.9. +const ( + pskModePlain uint8 = 0 + pskModeDHE uint8 = 1 +) + +// TLS 1.3 PSK Identity. Can be a Session Ticket, or a reference to a saved +// session. See RFC 8446, Section 4.2.11. +type pskIdentity struct { + label []byte + obfuscatedTicketAge uint32 +} + +// TLS Elliptic Curve Point Formats +// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 +const ( + pointFormatUncompressed uint8 = 0 +) + +// TLS CertificateStatusType (RFC 3546) +const ( + statusTypeOCSP uint8 = 1 +) + +// Certificate types (for certificateRequestMsg) +const ( + certTypeRSASign = 1 + certTypeECDSASign = 64 // ECDSA or EdDSA keys, see RFC 8422, Section 3. +) + +// Signature algorithms (for internal signaling use). Starting at 225 to avoid overlap with +// TLS 1.2 codepoints (RFC 5246, Appendix A.4.1), with which these have nothing to do. +const ( + signaturePKCS1v15 uint8 = iota + 225 + signatureRSAPSS + signatureECDSA + signatureEd25519 +) + +// directSigning is a standard Hash value that signals that no pre-hashing +// should be performed, and that the input should be signed directly. It is the +// hash function associated with the Ed25519 signature scheme. +var directSigning crypto.Hash = 0 + +// supportedSignatureAlgorithms contains the signature and hash algorithms that +// the code advertises as supported in a TLS 1.2+ ClientHello and in a TLS 1.2+ +// CertificateRequest. The two fields are merged to match with TLS 1.3. +// Note that in TLS 1.2, the ECDSA algorithms are not constrained to P-256, etc. +var supportedSignatureAlgorithms = []SignatureScheme{ + PSSWithSHA256, + ECDSAWithP256AndSHA256, + Ed25519, + PSSWithSHA384, + PSSWithSHA512, + PKCS1WithSHA256, + PKCS1WithSHA384, + PKCS1WithSHA512, + ECDSAWithP384AndSHA384, + ECDSAWithP521AndSHA512, + PKCS1WithSHA1, + ECDSAWithSHA1, +} + +// helloRetryRequestRandom is set as the Random value of a ServerHello +// to signal that the message is actually a HelloRetryRequest. +var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3. + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, + 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, + 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, + 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, +} + +const ( + // downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server + // random as a downgrade protection if the server would be capable of + // negotiating a higher version. See RFC 8446, Section 4.1.3. + downgradeCanaryTLS12 = "DOWNGRD\x01" + downgradeCanaryTLS11 = "DOWNGRD\x00" +) + +// testingOnlyForceDowngradeCanary is set in tests to force the server side to +// include downgrade canaries even if it's using its highers supported version. +var testingOnlyForceDowngradeCanary bool + +// ConnectionState records basic TLS details about the connection. +type ConnectionState struct { + // Version is the TLS version used by the connection (e.g. VersionTLS12). + Version uint16 + + // HandshakeComplete is true if the handshake has concluded. + HandshakeComplete bool + + // DidResume is true if this connection was successfully resumed from a + // previous session with a session ticket or similar mechanism. + DidResume bool + + // CipherSuite is the cipher suite negotiated for the connection (e.g. + // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_AES_128_GCM_SHA256). + CipherSuite uint16 + + // NegotiatedProtocol is the application protocol negotiated with ALPN. + NegotiatedProtocol string + + // NegotiatedProtocolIsMutual used to indicate a mutual NPN negotiation. + // + // Deprecated: this value is always true. + NegotiatedProtocolIsMutual bool + + // ServerName is the value of the Server Name Indication extension sent by + // the client. It's available both on the server and on the client side. + ServerName string + + // PeerCertificates are the parsed certificates sent by the peer, in the + // order in which they were sent. The first element is the leaf certificate + // that the connection is verified against. + // + // On the client side, it can't be empty. On the server side, it can be + // empty if Config.ClientAuth is not RequireAnyClientCert or + // RequireAndVerifyClientCert. + PeerCertificates []*x509.Certificate + + // VerifiedChains is a list of one or more chains where the first element is + // PeerCertificates[0] and the last element is from Config.RootCAs (on the + // client side) or Config.ClientCAs (on the server side). + // + // On the client side, it's set if Config.InsecureSkipVerify is false. On + // the server side, it's set if Config.ClientAuth is VerifyClientCertIfGiven + // (and the peer provided a certificate) or RequireAndVerifyClientCert. + VerifiedChains [][]*x509.Certificate + + // SignedCertificateTimestamps is a list of SCTs provided by the peer + // through the TLS handshake for the leaf certificate, if any. + SignedCertificateTimestamps [][]byte + + // OCSPResponse is a stapled Online Certificate Status Protocol (OCSP) + // response provided by the peer for the leaf certificate, if any. + OCSPResponse []byte + + // TLSUnique contains the "tls-unique" channel binding value (see RFC 5929, + // Section 3). This value will be nil for TLS 1.3 connections and for all + // resumed connections. + // + // Deprecated: there are conditions in which this value might not be unique + // to a connection. See the Security Considerations sections of RFC 5705 and + // RFC 7627, and https://mitls.org/pages/attacks/3SHAKE#channelbindings. + TLSUnique []byte + + // ekm is a closure exposed via ExportKeyingMaterial. + ekm func(label string, context []byte, length int) ([]byte, error) +} + +// ExportKeyingMaterial returns length bytes of exported key material in a new +// slice as defined in RFC 5705. If context is nil, it is not used as part of +// the seed. If the connection was set to allow renegotiation via +// Config.Renegotiation, this function will return an error. +func (cs *ConnectionState) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) { + return cs.ekm(label, context, length) +} + +// ClientAuthType declares the policy the server will follow for +// TLS Client Authentication. +type ClientAuthType int + +const ( + // NoClientCert indicates that no client certificate should be requested + // during the handshake, and if any certificates are sent they will not + // be verified. + NoClientCert ClientAuthType = iota + // RequestClientCert indicates that a client certificate should be requested + // during the handshake, but does not require that the client send any + // certificates. + RequestClientCert + // RequireAnyClientCert indicates that a client certificate should be requested + // during the handshake, and that at least one certificate is required to be + // sent by the client, but that certificate is not required to be valid. + RequireAnyClientCert + // VerifyClientCertIfGiven indicates that a client certificate should be requested + // during the handshake, but does not require that the client sends a + // certificate. If the client does send a certificate it is required to be + // valid. + VerifyClientCertIfGiven + // RequireAndVerifyClientCert indicates that a client certificate should be requested + // during the handshake, and that at least one valid certificate is required + // to be sent by the client. + RequireAndVerifyClientCert +) + +// requiresClientCert reports whether the ClientAuthType requires a client +// certificate to be provided. +func requiresClientCert(c ClientAuthType) bool { + switch c { + case RequireAnyClientCert, RequireAndVerifyClientCert: + return true + default: + return false + } +} + +// ClientSessionState contains the state needed by clients to resume TLS +// sessions. +type ClientSessionState struct { + sessionTicket []uint8 // Encrypted ticket used for session resumption with server + vers uint16 // TLS version negotiated for the session + cipherSuite uint16 // Ciphersuite negotiated for the session + masterSecret []byte // Full handshake MasterSecret, or TLS 1.3 resumption_master_secret + serverCertificates []*x509.Certificate // Certificate chain presented by the server + verifiedChains [][]*x509.Certificate // Certificate chains we built for verification + receivedAt time.Time // When the session ticket was received from the server + ocspResponse []byte // Stapled OCSP response presented by the server + scts [][]byte // SCTs presented by the server + + // TLS 1.3 fields. + nonce []byte // Ticket nonce sent by the server, to derive PSK + useBy time.Time // Expiration of the ticket lifetime as set by the server + ageAdd uint32 // Random obfuscation factor for sending the ticket age +} + +// ClientSessionCache is a cache of ClientSessionState objects that can be used +// by a client to resume a TLS session with a given server. ClientSessionCache +// implementations should expect to be called concurrently from different +// goroutines. Up to TLS 1.2, only ticket-based resumption is supported, not +// SessionID-based resumption. In TLS 1.3 they were merged into PSK modes, which +// are supported via this interface. +type ClientSessionCache interface { + // Get searches for a ClientSessionState associated with the given key. + // On return, ok is true if one was found. + Get(sessionKey string) (session *ClientSessionState, ok bool) + + // Put adds the ClientSessionState to the cache with the given key. It might + // get called multiple times in a connection if a TLS 1.3 server provides + // more than one session ticket. If called with a nil *ClientSessionState, + // it should remove the cache entry. + Put(sessionKey string, cs *ClientSessionState) +} + +//go:generate stringer -type=SignatureScheme,CurveID,ClientAuthType -output=common_string.go + +// SignatureScheme identifies a signature algorithm supported by TLS. See +// RFC 8446, Section 4.2.3. +type SignatureScheme uint16 + +const ( + // RSASSA-PKCS1-v1_5 algorithms. + PKCS1WithSHA256 SignatureScheme = 0x0401 + PKCS1WithSHA384 SignatureScheme = 0x0501 + PKCS1WithSHA512 SignatureScheme = 0x0601 + + // RSASSA-PSS algorithms with public key OID rsaEncryption. + PSSWithSHA256 SignatureScheme = 0x0804 + PSSWithSHA384 SignatureScheme = 0x0805 + PSSWithSHA512 SignatureScheme = 0x0806 + + // ECDSA algorithms. Only constrained to a specific curve in TLS 1.3. + ECDSAWithP256AndSHA256 SignatureScheme = 0x0403 + ECDSAWithP384AndSHA384 SignatureScheme = 0x0503 + ECDSAWithP521AndSHA512 SignatureScheme = 0x0603 + + // EdDSA algorithms. + Ed25519 SignatureScheme = 0x0807 + + // Legacy signature and hash algorithms for TLS 1.2. + PKCS1WithSHA1 SignatureScheme = 0x0201 + ECDSAWithSHA1 SignatureScheme = 0x0203 +) + +// ClientHelloInfo contains information from a ClientHello message in order to +// guide application logic in the GetCertificate and GetConfigForClient callbacks. +type ClientHelloInfo struct { + // CipherSuites lists the CipherSuites supported by the client (e.g. + // TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256). + CipherSuites []uint16 + + // ServerName indicates the name of the server requested by the client + // in order to support virtual hosting. ServerName is only set if the + // client is using SNI (see RFC 4366, Section 3.1). + ServerName string + + // SupportedCurves lists the elliptic curves supported by the client. + // SupportedCurves is set only if the Supported Elliptic Curves + // Extension is being used (see RFC 4492, Section 5.1.1). + SupportedCurves []CurveID + + // SupportedPoints lists the point formats supported by the client. + // SupportedPoints is set only if the Supported Point Formats Extension + // is being used (see RFC 4492, Section 5.1.2). + SupportedPoints []uint8 + + // SignatureSchemes lists the signature and hash schemes that the client + // is willing to verify. SignatureSchemes is set only if the Signature + // Algorithms Extension is being used (see RFC 5246, Section 7.4.1.4.1). + SignatureSchemes []SignatureScheme + + // SupportedProtos lists the application protocols supported by the client. + // SupportedProtos is set only if the Application-Layer Protocol + // Negotiation Extension is being used (see RFC 7301, Section 3.1). + // + // Servers can select a protocol by setting Config.NextProtos in a + // GetConfigForClient return value. + SupportedProtos []string + + // SupportedVersions lists the TLS versions supported by the client. + // For TLS versions less than 1.3, this is extrapolated from the max + // version advertised by the client, so values other than the greatest + // might be rejected if used. + SupportedVersions []uint16 + + // Conn is the underlying net.Conn for the connection. Do not read + // from, or write to, this connection; that will cause the TLS + // connection to fail. + Conn net.Conn + + // config is embedded by the GetCertificate or GetConfigForClient caller, + // for use with SupportsCertificate. + config *Config +} + +// CertificateRequestInfo contains information from a server's +// CertificateRequest message, which is used to demand a certificate and proof +// of control from a client. +type CertificateRequestInfo struct { + // AcceptableCAs contains zero or more, DER-encoded, X.501 + // Distinguished Names. These are the names of root or intermediate CAs + // that the server wishes the returned certificate to be signed by. An + // empty slice indicates that the server has no preference. + AcceptableCAs [][]byte + + // SignatureSchemes lists the signature schemes that the server is + // willing to verify. + SignatureSchemes []SignatureScheme + + // Version is the TLS version that was negotiated for this connection. + Version uint16 +} + +// RenegotiationSupport enumerates the different levels of support for TLS +// renegotiation. TLS renegotiation is the act of performing subsequent +// handshakes on a connection after the first. This significantly complicates +// the state machine and has been the source of numerous, subtle security +// issues. Initiating a renegotiation is not supported, but support for +// accepting renegotiation requests may be enabled. +// +// Even when enabled, the server may not change its identity between handshakes +// (i.e. the leaf certificate must be the same). Additionally, concurrent +// handshake and application data flow is not permitted so renegotiation can +// only be used with protocols that synchronise with the renegotiation, such as +// HTTPS. +// +// Renegotiation is not defined in TLS 1.3. +type RenegotiationSupport int + +const ( + // RenegotiateNever disables renegotiation. + RenegotiateNever RenegotiationSupport = iota + + // RenegotiateOnceAsClient allows a remote server to request + // renegotiation once per connection. + RenegotiateOnceAsClient + + // RenegotiateFreelyAsClient allows a remote server to repeatedly + // request renegotiation. + RenegotiateFreelyAsClient +) + +// A Config structure is used to configure a TLS client or server. +// After one has been passed to a TLS function it must not be +// modified. A Config may be reused; the tls package will also not +// modify it. +type Config struct { + // Rand provides the source of entropy for nonces and RSA blinding. + // If Rand is nil, TLS uses the cryptographic random reader in package + // crypto/rand. + // The Reader must be safe for use by multiple goroutines. + Rand io.Reader + + // Time returns the current time as the number of seconds since the epoch. + // If Time is nil, TLS uses time.Now. + Time func() time.Time + + // Certificates contains one or more certificate chains to present to the + // other side of the connection. The first certificate compatible with the + // peer's requirements is selected automatically. + // + // Server configurations must set one of Certificates, GetCertificate or + // GetConfigForClient. Clients doing client-authentication may set either + // Certificates or GetClientCertificate. + // + // Note: if there are multiple Certificates, and they don't have the + // optional field Leaf set, certificate selection will incur a significant + // per-handshake performance cost. + Certificates []Certificate + + // NameToCertificate maps from a certificate name to an element of + // Certificates. Note that a certificate name can be of the form + // '*.example.com' and so doesn't have to be a domain name as such. + // + // Deprecated: NameToCertificate only allows associating a single + // certificate with a given name. Leave this field nil to let the library + // select the first compatible chain from Certificates. + NameToCertificate map[string]*Certificate + + // GetCertificate returns a Certificate based on the given + // ClientHelloInfo. It will only be called if the client supplies SNI + // information or if Certificates is empty. + // + // If GetCertificate is nil or returns nil, then the certificate is + // retrieved from NameToCertificate. If NameToCertificate is nil, the + // best element of Certificates will be used. + GetCertificate func(*ClientHelloInfo) (*Certificate, error) + + // GetClientCertificate, if not nil, is called when a server requests a + // certificate from a client. If set, the contents of Certificates will + // be ignored. + // + // If GetClientCertificate returns an error, the handshake will be + // aborted and that error will be returned. Otherwise + // GetClientCertificate must return a non-nil Certificate. If + // Certificate.Certificate is empty then no certificate will be sent to + // the server. If this is unacceptable to the server then it may abort + // the handshake. + // + // GetClientCertificate may be called multiple times for the same + // connection if renegotiation occurs or if TLS 1.3 is in use. + GetClientCertificate func(*CertificateRequestInfo) (*Certificate, error) + + // GetConfigForClient, if not nil, is called after a ClientHello is + // received from a client. It may return a non-nil Config in order to + // change the Config that will be used to handle this connection. If + // the returned Config is nil, the original Config will be used. The + // Config returned by this callback may not be subsequently modified. + // + // If GetConfigForClient is nil, the Config passed to Server() will be + // used for all connections. + // + // If SessionTicketKey was explicitly set on the returned Config, or if + // SetSessionTicketKeys was called on the returned Config, those keys will + // be used. Otherwise, the original Config keys will be used (and possibly + // rotated if they are automatically managed). + GetConfigForClient func(*ClientHelloInfo) (*Config, error) + + // VerifyPeerCertificate, if not nil, is called after normal + // certificate verification by either a TLS client or server. It + // receives the raw ASN.1 certificates provided by the peer and also + // any verified chains that normal processing found. If it returns a + // non-nil error, the handshake is aborted and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. If normal verification is disabled by + // setting InsecureSkipVerify, or (for a server) when ClientAuth is + // RequestClientCert or RequireAnyClientCert, then this callback will + // be considered but the verifiedChains argument will always be nil. + VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + + // VerifyConnection, if not nil, is called after normal certificate + // verification and after VerifyPeerCertificate by either a TLS client + // or server. If it returns a non-nil error, the handshake is aborted + // and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. This callback will run for all connections + // regardless of InsecureSkipVerify or ClientAuth settings. + VerifyConnection func(ConnectionState) error + + // RootCAs defines the set of root certificate authorities + // that clients use when verifying server certificates. + // If RootCAs is nil, TLS uses the host's root CA set. + RootCAs *x509.CertPool + + // NextProtos is a list of supported application level protocols, in + // order of preference. + NextProtos []string + + // ServerName is used to verify the hostname on the returned + // certificates unless InsecureSkipVerify is given. It is also included + // in the client's handshake to support virtual hosting unless it is + // an IP address. + ServerName string + + // ClientAuth determines the server's policy for + // TLS Client Authentication. The default is NoClientCert. + ClientAuth ClientAuthType + + // ClientCAs defines the set of root certificate authorities + // that servers use if required to verify a client certificate + // by the policy in ClientAuth. + ClientCAs *x509.CertPool + + // InsecureSkipVerify controls whether a client verifies the server's + // certificate chain and host name. If InsecureSkipVerify is true, crypto/tls + // accepts any certificate presented by the server and any host name in that + // certificate. In this mode, TLS is susceptible to machine-in-the-middle + // attacks unless custom verification is used. This should be used only for + // testing or in combination with VerifyConnection or VerifyPeerCertificate. + InsecureSkipVerify bool + + // CipherSuites is a list of supported cipher suites for TLS versions up to + // TLS 1.2. If CipherSuites is nil, a default list of secure cipher suites + // is used, with a preference order based on hardware performance. The + // default cipher suites might change over Go versions. Note that TLS 1.3 + // ciphersuites are not configurable. + CipherSuites []uint16 + + // PreferServerCipherSuites controls whether the server selects the + // client's most preferred ciphersuite, or the server's most preferred + // ciphersuite. If true then the server's preference, as expressed in + // the order of elements in CipherSuites, is used. + PreferServerCipherSuites bool + + // SessionTicketsDisabled may be set to true to disable session ticket and + // PSK (resumption) support. Note that on clients, session ticket support is + // also disabled if ClientSessionCache is nil. + SessionTicketsDisabled bool + + // SessionTicketKey is used by TLS servers to provide session resumption. + // See RFC 5077 and the PSK mode of RFC 8446. If zero, it will be filled + // with random data before the first server handshake. + // + // Deprecated: if this field is left at zero, session ticket keys will be + // automatically rotated every day and dropped after seven days. For + // customizing the rotation schedule or synchronizing servers that are + // terminating connections for the same host, use SetSessionTicketKeys. + SessionTicketKey [32]byte + + // ClientSessionCache is a cache of ClientSessionState entries for TLS + // session resumption. It is only used by clients. + ClientSessionCache ClientSessionCache + + // MinVersion contains the minimum TLS version that is acceptable. + // If zero, TLS 1.0 is currently taken as the minimum. + MinVersion uint16 + + // MaxVersion contains the maximum TLS version that is acceptable. + // If zero, the maximum version supported by this package is used, + // which is currently TLS 1.3. + MaxVersion uint16 + + // CurvePreferences contains the elliptic curves that will be used in + // an ECDHE handshake, in preference order. If empty, the default will + // be used. The client will use the first preference as the type for + // its key share in TLS 1.3. This may change in the future. + CurvePreferences []CurveID + + // DynamicRecordSizingDisabled disables adaptive sizing of TLS records. + // When true, the largest possible TLS record size is always used. When + // false, the size of TLS records may be adjusted in an attempt to + // improve latency. + DynamicRecordSizingDisabled bool + + // Renegotiation controls what types of renegotiation are supported. + // The default, none, is correct for the vast majority of applications. + Renegotiation RenegotiationSupport + + // KeyLogWriter optionally specifies a destination for TLS master secrets + // in NSS key log format that can be used to allow external programs + // such as Wireshark to decrypt TLS connections. + // See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format. + // Use of KeyLogWriter compromises security and should only be + // used for debugging. + KeyLogWriter io.Writer + + // mutex protects sessionTicketKeys and autoSessionTicketKeys. + mutex sync.RWMutex + // sessionTicketKeys contains zero or more ticket keys. If set, it means the + // the keys were set with SessionTicketKey or SetSessionTicketKeys. The + // first key is used for new tickets and any subsequent keys can be used to + // decrypt old tickets. The slice contents are not protected by the mutex + // and are immutable. + sessionTicketKeys []ticketKey + // autoSessionTicketKeys is like sessionTicketKeys but is owned by the + // auto-rotation logic. See Config.ticketKeys. + autoSessionTicketKeys []ticketKey +} + +const ( + // ticketKeyNameLen is the number of bytes of identifier that is prepended to + // an encrypted session ticket in order to identify the key used to encrypt it. + ticketKeyNameLen = 16 + + // ticketKeyLifetime is how long a ticket key remains valid and can be used to + // resume a client connection. + ticketKeyLifetime = 7 * 24 * time.Hour // 7 days + + // ticketKeyRotation is how often the server should rotate the session ticket key + // that is used for new tickets. + ticketKeyRotation = 24 * time.Hour +) + +// ticketKey is the internal representation of a session ticket key. +type ticketKey struct { + // keyName is an opaque byte string that serves to identify the session + // ticket key. It's exposed as plaintext in every session ticket. + keyName [ticketKeyNameLen]byte + aesKey [16]byte + hmacKey [16]byte + // created is the time at which this ticket key was created. See Config.ticketKeys. + created time.Time +} + +// ticketKeyFromBytes converts from the external representation of a session +// ticket key to a ticketKey. Externally, session ticket keys are 32 random +// bytes and this function expands that into sufficient name and key material. +func (c *Config) ticketKeyFromBytes(b [32]byte) (key ticketKey) { + hashed := sha512.Sum512(b[:]) + copy(key.keyName[:], hashed[:ticketKeyNameLen]) + copy(key.aesKey[:], hashed[ticketKeyNameLen:ticketKeyNameLen+16]) + copy(key.hmacKey[:], hashed[ticketKeyNameLen+16:ticketKeyNameLen+32]) + key.created = c.time() + return key +} + +// maxSessionTicketLifetime is the maximum allowed lifetime of a TLS 1.3 session +// ticket, and the lifetime we set for tickets we send. +const maxSessionTicketLifetime = 7 * 24 * time.Hour + +// Clone returns a shallow clone of c or nil if c is nil. It is safe to clone a Config that is +// being used concurrently by a TLS client or server. +func (c *Config) Clone() *Config { + if c == nil { + return nil + } + c.mutex.RLock() + defer c.mutex.RUnlock() + return &Config{ + Rand: c.Rand, + Time: c.Time, + Certificates: c.Certificates, + NameToCertificate: c.NameToCertificate, + GetCertificate: c.GetCertificate, + GetClientCertificate: c.GetClientCertificate, + GetConfigForClient: c.GetConfigForClient, + VerifyPeerCertificate: c.VerifyPeerCertificate, + VerifyConnection: c.VerifyConnection, + RootCAs: c.RootCAs, + NextProtos: c.NextProtos, + ServerName: c.ServerName, + ClientAuth: c.ClientAuth, + ClientCAs: c.ClientCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + CipherSuites: c.CipherSuites, + PreferServerCipherSuites: c.PreferServerCipherSuites, + SessionTicketsDisabled: c.SessionTicketsDisabled, + SessionTicketKey: c.SessionTicketKey, + ClientSessionCache: c.ClientSessionCache, + MinVersion: c.MinVersion, + MaxVersion: c.MaxVersion, + CurvePreferences: c.CurvePreferences, + DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, + Renegotiation: c.Renegotiation, + KeyLogWriter: c.KeyLogWriter, + sessionTicketKeys: c.sessionTicketKeys, + autoSessionTicketKeys: c.autoSessionTicketKeys, + } +} + +// deprecatedSessionTicketKey is set as the prefix of SessionTicketKey if it was +// randomized for backwards compatibility but is not in use. +var deprecatedSessionTicketKey = []byte("DEPRECATED") + +// initLegacySessionTicketKeyRLocked ensures the legacy SessionTicketKey field is +// randomized if empty, and that sessionTicketKeys is populated from it otherwise. +func (c *Config) initLegacySessionTicketKeyRLocked() { + // Don't write if SessionTicketKey is already defined as our deprecated string, + // or if it is defined by the user but sessionTicketKeys is already set. + if c.SessionTicketKey != [32]byte{} && + (bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) || len(c.sessionTicketKeys) > 0) { + return + } + + // We need to write some data, so get an exclusive lock and re-check any conditions. + c.mutex.RUnlock() + defer c.mutex.RLock() + c.mutex.Lock() + defer c.mutex.Unlock() + if c.SessionTicketKey == [32]byte{} { + if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil { + panic(fmt.Sprintf("tls: unable to generate random session ticket key: %v", err)) + } + // Write the deprecated prefix at the beginning so we know we created + // it. This key with the DEPRECATED prefix isn't used as an actual + // session ticket key, and is only randomized in case the application + // reuses it for some reason. + copy(c.SessionTicketKey[:], deprecatedSessionTicketKey) + } else if !bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) && len(c.sessionTicketKeys) == 0 { + c.sessionTicketKeys = []ticketKey{c.ticketKeyFromBytes(c.SessionTicketKey)} + } + +} + +// ticketKeys returns the ticketKeys for this connection. +// If configForClient has explicitly set keys, those will +// be returned. Otherwise, the keys on c will be used and +// may be rotated if auto-managed. +// During rotation, any expired session ticket keys are deleted from +// c.sessionTicketKeys. If the session ticket key that is currently +// encrypting tickets (ie. the first ticketKey in c.sessionTicketKeys) +// is not fresh, then a new session ticket key will be +// created and prepended to c.sessionTicketKeys. +func (c *Config) ticketKeys(configForClient *Config) []ticketKey { + // If the ConfigForClient callback returned a Config with explicitly set + // keys, use those, otherwise just use the original Config. + if configForClient != nil { + configForClient.mutex.RLock() + if configForClient.SessionTicketsDisabled { + return nil + } + configForClient.initLegacySessionTicketKeyRLocked() + if len(configForClient.sessionTicketKeys) != 0 { + ret := configForClient.sessionTicketKeys + configForClient.mutex.RUnlock() + return ret + } + configForClient.mutex.RUnlock() + } + + c.mutex.RLock() + defer c.mutex.RUnlock() + if c.SessionTicketsDisabled { + return nil + } + c.initLegacySessionTicketKeyRLocked() + if len(c.sessionTicketKeys) != 0 { + return c.sessionTicketKeys + } + // Fast path for the common case where the key is fresh enough. + if len(c.autoSessionTicketKeys) > 0 && c.time().Sub(c.autoSessionTicketKeys[0].created) < ticketKeyRotation { + return c.autoSessionTicketKeys + } + + // autoSessionTicketKeys are managed by auto-rotation. + c.mutex.RUnlock() + defer c.mutex.RLock() + c.mutex.Lock() + defer c.mutex.Unlock() + // Re-check the condition in case it changed since obtaining the new lock. + if len(c.autoSessionTicketKeys) == 0 || c.time().Sub(c.autoSessionTicketKeys[0].created) >= ticketKeyRotation { + var newKey [32]byte + if _, err := io.ReadFull(c.rand(), newKey[:]); err != nil { + panic(fmt.Sprintf("unable to generate random session ticket key: %v", err)) + } + valid := make([]ticketKey, 0, len(c.autoSessionTicketKeys)+1) + valid = append(valid, c.ticketKeyFromBytes(newKey)) + for _, k := range c.autoSessionTicketKeys { + // While rotating the current key, also remove any expired ones. + if c.time().Sub(k.created) < ticketKeyLifetime { + valid = append(valid, k) + } + } + c.autoSessionTicketKeys = valid + } + return c.autoSessionTicketKeys +} + +// SetSessionTicketKeys updates the session ticket keys for a server. +// +// The first key will be used when creating new tickets, while all keys can be +// used for decrypting tickets. It is safe to call this function while the +// server is running in order to rotate the session ticket keys. The function +// will panic if keys is empty. +// +// Calling this function will turn off automatic session ticket key rotation. +// +// If multiple servers are terminating connections for the same host they should +// all have the same session ticket keys. If the session ticket keys leaks, +// previously recorded and future TLS connections using those keys might be +// compromised. +func (c *Config) SetSessionTicketKeys(keys [][32]byte) { + if len(keys) == 0 { + panic("tls: keys must have at least one key") + } + + newKeys := make([]ticketKey, len(keys)) + for i, bytes := range keys { + newKeys[i] = c.ticketKeyFromBytes(bytes) + } + + c.mutex.Lock() + c.sessionTicketKeys = newKeys + c.mutex.Unlock() +} + +func (c *Config) rand() io.Reader { + r := c.Rand + if r == nil { + return rand.Reader + } + return r +} + +func (c *Config) time() time.Time { + t := c.Time + if t == nil { + t = time.Now + } + return t() +} + +func (c *Config) cipherSuites() []uint16 { + s := c.CipherSuites + if s == nil { + s = defaultCipherSuites() + } + return s +} + +var supportedVersions = []uint16{ + VersionTLS13, + VersionTLS12, + VersionTLS11, + VersionTLS10, +} + +func (c *Config) supportedVersions() []uint16 { + versions := make([]uint16, 0, len(supportedVersions)) + for _, v := range supportedVersions { + if c != nil && c.MinVersion != 0 && v < c.MinVersion { + continue + } + if c != nil && c.MaxVersion != 0 && v > c.MaxVersion { + continue + } + versions = append(versions, v) + } + return versions +} + +func (c *Config) maxSupportedVersion() uint16 { + supportedVersions := c.supportedVersions() + if len(supportedVersions) == 0 { + return 0 + } + return supportedVersions[0] +} + +// supportedVersionsFromMax returns a list of supported versions derived from a +// legacy maximum version value. Note that only versions supported by this +// library are returned. Any newer peer will use supportedVersions anyway. +func supportedVersionsFromMax(maxVersion uint16) []uint16 { + versions := make([]uint16, 0, len(supportedVersions)) + for _, v := range supportedVersions { + if v > maxVersion { + continue + } + versions = append(versions, v) + } + return versions +} + +var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521} + +func (c *Config) curvePreferences() []CurveID { + if c == nil || len(c.CurvePreferences) == 0 { + return defaultCurvePreferences + } + return c.CurvePreferences +} + +func (c *Config) supportsCurve(curve CurveID) bool { + for _, cc := range c.curvePreferences() { + if cc == curve { + return true + } + } + return false +} + +// mutualVersion returns the protocol version to use given the advertised +// versions of the peer. Priority is given to the peer preference order. +func (c *Config) mutualVersion(peerVersions []uint16) (uint16, bool) { + supportedVersions := c.supportedVersions() + for _, peerVersion := range peerVersions { + for _, v := range supportedVersions { + if v == peerVersion { + return v, true + } + } + } + return 0, false +} + +var errNoCertificates = errors.New("tls: no certificates configured") + +// getCertificate returns the best certificate for the given ClientHelloInfo, +// defaulting to the first element of c.Certificates. +func (c *Config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) { + if c.GetCertificate != nil && + (len(c.Certificates) == 0 || len(clientHello.ServerName) > 0) { + cert, err := c.GetCertificate(clientHello) + if cert != nil || err != nil { + return cert, err + } + } + + if len(c.Certificates) == 0 { + return nil, errNoCertificates + } + + if len(c.Certificates) == 1 { + // There's only one choice, so no point doing any work. + return &c.Certificates[0], nil + } + + if c.NameToCertificate != nil { + name := strings.ToLower(clientHello.ServerName) + if cert, ok := c.NameToCertificate[name]; ok { + return cert, nil + } + if len(name) > 0 { + labels := strings.Split(name, ".") + labels[0] = "*" + wildcardName := strings.Join(labels, ".") + if cert, ok := c.NameToCertificate[wildcardName]; ok { + return cert, nil + } + } + } + + for _, cert := range c.Certificates { + if err := clientHello.SupportsCertificate(&cert); err == nil { + return &cert, nil + } + } + + // If nothing matches, return the first certificate. + return &c.Certificates[0], nil +} + +// SupportsCertificate returns nil if the provided certificate is supported by +// the client that sent the ClientHello. Otherwise, it returns an error +// describing the reason for the incompatibility. +// +// If this ClientHelloInfo was passed to a GetConfigForClient or GetCertificate +// callback, this method will take into account the associated Config. Note that +// if GetConfigForClient returns a different Config, the change can't be +// accounted for by this method. +// +// This function will call x509.ParseCertificate unless c.Leaf is set, which can +// incur a significant performance cost. +func (chi *ClientHelloInfo) SupportsCertificate(c *Certificate) error { + // Note we don't currently support certificate_authorities nor + // signature_algorithms_cert, and don't check the algorithms of the + // signatures on the chain (which anyway are a SHOULD, see RFC 8446, + // Section 4.4.2.2). + + config := chi.config + if config == nil { + config = &Config{} + } + vers, ok := config.mutualVersion(chi.SupportedVersions) + if !ok { + return errors.New("no mutually supported protocol versions") + } + + // If the client specified the name they are trying to connect to, the + // certificate needs to be valid for it. + if chi.ServerName != "" { + x509Cert, err := c.leaf() + if err != nil { + return fmt.Errorf("failed to parse certificate: %w", err) + } + if err := x509Cert.VerifyHostname(chi.ServerName); err != nil { + return fmt.Errorf("certificate is not valid for requested server name: %w", err) + } + } + + // supportsRSAFallback returns nil if the certificate and connection support + // the static RSA key exchange, and unsupported otherwise. The logic for + // supporting static RSA is completely disjoint from the logic for + // supporting signed key exchanges, so we just check it as a fallback. + supportsRSAFallback := func(unsupported error) error { + // TLS 1.3 dropped support for the static RSA key exchange. + if vers == VersionTLS13 { + return unsupported + } + // The static RSA key exchange works by decrypting a challenge with the + // RSA private key, not by signing, so check the PrivateKey implements + // crypto.Decrypter, like *rsa.PrivateKey does. + if priv, ok := c.PrivateKey.(crypto.Decrypter); ok { + if _, ok := priv.Public().(*rsa.PublicKey); !ok { + return unsupported + } + } else { + return unsupported + } + // Finally, there needs to be a mutual cipher suite that uses the static + // RSA key exchange instead of ECDHE. + rsaCipherSuite := selectCipherSuite(chi.CipherSuites, config.cipherSuites(), func(c *cipherSuite) bool { + if c.flags&suiteECDHE != 0 { + return false + } + if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { + return false + } + return true + }) + if rsaCipherSuite == nil { + return unsupported + } + return nil + } + + // If the client sent the signature_algorithms extension, ensure it supports + // schemes we can use with this certificate and TLS version. + if len(chi.SignatureSchemes) > 0 { + if _, err := selectSignatureScheme(vers, c, chi.SignatureSchemes); err != nil { + return supportsRSAFallback(err) + } + } + + // In TLS 1.3 we are done because supported_groups is only relevant to the + // ECDHE computation, point format negotiation is removed, cipher suites are + // only relevant to the AEAD choice, and static RSA does not exist. + if vers == VersionTLS13 { + return nil + } + + // The only signed key exchange we support is ECDHE. + if !supportsECDHE(config, chi.SupportedCurves, chi.SupportedPoints) { + return supportsRSAFallback(errors.New("client doesn't support ECDHE, can only use legacy RSA key exchange")) + } + + var ecdsaCipherSuite bool + if priv, ok := c.PrivateKey.(crypto.Signer); ok { + switch pub := priv.Public().(type) { + case *ecdsa.PublicKey: + var curve CurveID + switch pub.Curve { + case elliptic.P256(): + curve = CurveP256 + case elliptic.P384(): + curve = CurveP384 + case elliptic.P521(): + curve = CurveP521 + default: + return supportsRSAFallback(unsupportedCertificateError(c)) + } + var curveOk bool + for _, c := range chi.SupportedCurves { + if c == curve && config.supportsCurve(c) { + curveOk = true + break + } + } + if !curveOk { + return errors.New("client doesn't support certificate curve") + } + ecdsaCipherSuite = true + case ed25519.PublicKey: + if vers < VersionTLS12 || len(chi.SignatureSchemes) == 0 { + return errors.New("connection doesn't support Ed25519") + } + ecdsaCipherSuite = true + case *rsa.PublicKey: + default: + return supportsRSAFallback(unsupportedCertificateError(c)) + } + } else { + return supportsRSAFallback(unsupportedCertificateError(c)) + } + + // Make sure that there is a mutually supported cipher suite that works with + // this certificate. Cipher suite selection will then apply the logic in + // reverse to pick it. See also serverHandshakeState.cipherSuiteOk. + cipherSuite := selectCipherSuite(chi.CipherSuites, config.cipherSuites(), func(c *cipherSuite) bool { + if c.flags&suiteECDHE == 0 { + return false + } + if c.flags&suiteECSign != 0 { + if !ecdsaCipherSuite { + return false + } + } else { + if ecdsaCipherSuite { + return false + } + } + if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { + return false + } + return true + }) + if cipherSuite == nil { + return supportsRSAFallback(errors.New("client doesn't support any cipher suites compatible with the certificate")) + } + + return nil +} + +// SupportsCertificate returns nil if the provided certificate is supported by +// the server that sent the CertificateRequest. Otherwise, it returns an error +// describing the reason for the incompatibility. +func (cri *CertificateRequestInfo) SupportsCertificate(c *Certificate) error { + if _, err := selectSignatureScheme(cri.Version, c, cri.SignatureSchemes); err != nil { + return err + } + + if len(cri.AcceptableCAs) == 0 { + return nil + } + + for j, cert := range c.Certificate { + x509Cert := c.Leaf + // Parse the certificate if this isn't the leaf node, or if + // chain.Leaf was nil. + if j != 0 || x509Cert == nil { + var err error + if x509Cert, err = x509.ParseCertificate(cert); err != nil { + return fmt.Errorf("failed to parse certificate #%d in the chain: %w", j, err) + } + } + + for _, ca := range cri.AcceptableCAs { + if bytes.Equal(x509Cert.RawIssuer, ca) { + return nil + } + } + } + return errors.New("chain is not signed by an acceptable CA") +} + +// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate +// from the CommonName and SubjectAlternateName fields of each of the leaf +// certificates. +// +// Deprecated: NameToCertificate only allows associating a single certificate +// with a given name. Leave that field nil to let the library select the first +// compatible chain from Certificates. +func (c *Config) BuildNameToCertificate() { + c.NameToCertificate = make(map[string]*Certificate) + for i := range c.Certificates { + cert := &c.Certificates[i] + x509Cert, err := cert.leaf() + if err != nil { + continue + } + // If SANs are *not* present, some clients will consider the certificate + // valid for the name in the Common Name. + if x509Cert.Subject.CommonName != "" && len(x509Cert.DNSNames) == 0 { + c.NameToCertificate[x509Cert.Subject.CommonName] = cert + } + for _, san := range x509Cert.DNSNames { + c.NameToCertificate[san] = cert + } + } +} + +const ( + keyLogLabelTLS12 = "CLIENT_RANDOM" + keyLogLabelClientHandshake = "CLIENT_HANDSHAKE_TRAFFIC_SECRET" + keyLogLabelServerHandshake = "SERVER_HANDSHAKE_TRAFFIC_SECRET" + keyLogLabelClientTraffic = "CLIENT_TRAFFIC_SECRET_0" + keyLogLabelServerTraffic = "SERVER_TRAFFIC_SECRET_0" +) + +func (c *Config) writeKeyLog(label string, clientRandom, secret []byte) error { + if c.KeyLogWriter == nil { + return nil + } + + logLine := []byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret)) + + writerMutex.Lock() + _, err := c.KeyLogWriter.Write(logLine) + writerMutex.Unlock() + + return err +} + +// writerMutex protects all KeyLogWriters globally. It is rarely enabled, +// and is only for debugging, so a global mutex saves space. +var writerMutex sync.Mutex + +// A Certificate is a chain of one or more certificates, leaf first. +type Certificate struct { + Certificate [][]byte + // PrivateKey contains the private key corresponding to the public key in + // Leaf. This must implement crypto.Signer with an RSA, ECDSA or Ed25519 PublicKey. + // For a server up to TLS 1.2, it can also implement crypto.Decrypter with + // an RSA PublicKey. + PrivateKey crypto.PrivateKey + // SupportedSignatureAlgorithms is an optional list restricting what + // signature algorithms the PrivateKey can be used for. + SupportedSignatureAlgorithms []SignatureScheme + // OCSPStaple contains an optional OCSP response which will be served + // to clients that request it. + OCSPStaple []byte + // SignedCertificateTimestamps contains an optional list of Signed + // Certificate Timestamps which will be served to clients that request it. + SignedCertificateTimestamps [][]byte + // Leaf is the parsed form of the leaf certificate, which may be initialized + // using x509.ParseCertificate to reduce per-handshake processing. If nil, + // the leaf certificate will be parsed as needed. + Leaf *x509.Certificate +} + +// leaf returns the parsed leaf certificate, either from c.Leaf or by parsing +// the corresponding c.Certificate[0]. +func (c *Certificate) leaf() (*x509.Certificate, error) { + if c.Leaf != nil { + return c.Leaf, nil + } + return x509.ParseCertificate(c.Certificate[0]) +} + +type handshakeMessage interface { + marshal() []byte + unmarshal([]byte) bool +} + +// lruSessionCache is a ClientSessionCache implementation that uses an LRU +// caching strategy. +type lruSessionCache struct { + sync.Mutex + + m map[string]*list.Element + q *list.List + capacity int +} + +type lruSessionCacheEntry struct { + sessionKey string + state *ClientSessionState +} + +// NewLRUClientSessionCache returns a ClientSessionCache with the given +// capacity that uses an LRU strategy. If capacity is < 1, a default capacity +// is used instead. +func NewLRUClientSessionCache(capacity int) ClientSessionCache { + const defaultSessionCacheCapacity = 64 + + if capacity < 1 { + capacity = defaultSessionCacheCapacity + } + return &lruSessionCache{ + m: make(map[string]*list.Element), + q: list.New(), + capacity: capacity, + } +} + +// Put adds the provided (sessionKey, cs) pair to the cache. If cs is nil, the entry +// corresponding to sessionKey is removed from the cache instead. +func (c *lruSessionCache) Put(sessionKey string, cs *ClientSessionState) { + c.Lock() + defer c.Unlock() + + if elem, ok := c.m[sessionKey]; ok { + if cs == nil { + c.q.Remove(elem) + delete(c.m, sessionKey) + } else { + entry := elem.Value.(*lruSessionCacheEntry) + entry.state = cs + c.q.MoveToFront(elem) + } + return + } + + if c.q.Len() < c.capacity { + entry := &lruSessionCacheEntry{sessionKey, cs} + c.m[sessionKey] = c.q.PushFront(entry) + return + } + + elem := c.q.Back() + entry := elem.Value.(*lruSessionCacheEntry) + delete(c.m, entry.sessionKey) + entry.sessionKey = sessionKey + entry.state = cs + c.q.MoveToFront(elem) + c.m[sessionKey] = elem +} + +// Get returns the ClientSessionState value associated with a given key. It +// returns (nil, false) if no value is found. +func (c *lruSessionCache) Get(sessionKey string) (*ClientSessionState, bool) { + c.Lock() + defer c.Unlock() + + if elem, ok := c.m[sessionKey]; ok { + c.q.MoveToFront(elem) + return elem.Value.(*lruSessionCacheEntry).state, true + } + return nil, false +} + +var emptyConfig Config + +func defaultConfig() *Config { + return &emptyConfig +} + +var ( + once sync.Once + varDefaultCipherSuites []uint16 + varDefaultCipherSuitesTLS13 []uint16 +) + +func defaultCipherSuites() []uint16 { + once.Do(initDefaultCipherSuites) + return varDefaultCipherSuites +} + +func defaultCipherSuitesTLS13() []uint16 { + once.Do(initDefaultCipherSuites) + return varDefaultCipherSuitesTLS13 +} + +var ( + hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ + hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL + // Keep in sync with crypto/aes/cipher_s390x.go. + hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM) + + hasAESGCMHardwareSupport = runtime.GOARCH == "amd64" && hasGCMAsmAMD64 || + runtime.GOARCH == "arm64" && hasGCMAsmARM64 || + runtime.GOARCH == "s390x" && hasGCMAsmS390X +) + +func initDefaultCipherSuites() { + var topCipherSuites []uint16 + + if hasAESGCMHardwareSupport { + // If AES-GCM hardware is provided then prioritise AES-GCM + // cipher suites. + topCipherSuites = []uint16{ + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + } + varDefaultCipherSuitesTLS13 = []uint16{ + TLS_AES_128_GCM_SHA256, + TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_256_GCM_SHA384, + } + } else { + // Without AES-GCM hardware, we put the ChaCha20-Poly1305 + // cipher suites first. + topCipherSuites = []uint16{ + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + } + varDefaultCipherSuitesTLS13 = []uint16{ + TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, + } + } + + varDefaultCipherSuites = make([]uint16, 0, len(cipherSuites)) + varDefaultCipherSuites = append(varDefaultCipherSuites, topCipherSuites...) + +NextCipherSuite: + for _, suite := range cipherSuites { + if suite.flags&suiteDefaultOff != 0 { + continue + } + for _, existing := range varDefaultCipherSuites { + if existing == suite.id { + continue NextCipherSuite + } + } + varDefaultCipherSuites = append(varDefaultCipherSuites, suite.id) + } +} + +func unexpectedMessageError(wanted, got interface{}) error { + return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted) +} + +func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlgorithms []SignatureScheme) bool { + for _, s := range supportedSignatureAlgorithms { + if s == sigAlg { + return true + } + } + return false +} + +var aesgcmCiphers = map[uint16]bool{ + // 1.2 + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: true, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: true, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: true, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: true, + // 1.3 + TLS_AES_128_GCM_SHA256: true, + TLS_AES_256_GCM_SHA384: true, +} + +var nonAESGCMAEADCiphers = map[uint16]bool{ + // 1.2 + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: true, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: true, + // 1.3 + TLS_CHACHA20_POLY1305_SHA256: true, +} + +// aesgcmPreferred returns whether the first valid cipher in the preference list +// is an AES-GCM cipher, implying the peer has hardware support for it. +func aesgcmPreferred(ciphers []uint16) bool { + for _, cID := range ciphers { + c := cipherSuiteByID(cID) + if c == nil { + c13 := cipherSuiteTLS13ByID(cID) + if c13 == nil { + continue + } + return aesgcmCiphers[cID] + } + return aesgcmCiphers[cID] + } + return false +} + +// deprioritizeAES reorders cipher preference lists by rearranging +// adjacent AEAD ciphers such that AES-GCM based ciphers are moved +// after other AEAD ciphers. It returns a fresh slice. +func deprioritizeAES(ciphers []uint16) []uint16 { + reordered := make([]uint16, len(ciphers)) + copy(reordered, ciphers) + sort.SliceStable(reordered, func(i, j int) bool { + return nonAESGCMAEADCiphers[reordered[i]] && aesgcmCiphers[reordered[j]] + }) + return reordered +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/common_string.go b/vendor/github.com/lesismal/llib/std/crypto/tls/common_string.go new file mode 100644 index 0000000..2381088 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/common_string.go @@ -0,0 +1,116 @@ +// Code generated by "stringer -type=SignatureScheme,CurveID,ClientAuthType -output=common_string.go"; DO NOT EDIT. + +package tls + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[PKCS1WithSHA256-1025] + _ = x[PKCS1WithSHA384-1281] + _ = x[PKCS1WithSHA512-1537] + _ = x[PSSWithSHA256-2052] + _ = x[PSSWithSHA384-2053] + _ = x[PSSWithSHA512-2054] + _ = x[ECDSAWithP256AndSHA256-1027] + _ = x[ECDSAWithP384AndSHA384-1283] + _ = x[ECDSAWithP521AndSHA512-1539] + _ = x[Ed25519-2055] + _ = x[PKCS1WithSHA1-513] + _ = x[ECDSAWithSHA1-515] +} + +const ( + _SignatureScheme_name_0 = "PKCS1WithSHA1" + _SignatureScheme_name_1 = "ECDSAWithSHA1" + _SignatureScheme_name_2 = "PKCS1WithSHA256" + _SignatureScheme_name_3 = "ECDSAWithP256AndSHA256" + _SignatureScheme_name_4 = "PKCS1WithSHA384" + _SignatureScheme_name_5 = "ECDSAWithP384AndSHA384" + _SignatureScheme_name_6 = "PKCS1WithSHA512" + _SignatureScheme_name_7 = "ECDSAWithP521AndSHA512" + _SignatureScheme_name_8 = "PSSWithSHA256PSSWithSHA384PSSWithSHA512Ed25519" +) + +var ( + _SignatureScheme_index_8 = [...]uint8{0, 13, 26, 39, 46} +) + +func (i SignatureScheme) String() string { + switch { + case i == 513: + return _SignatureScheme_name_0 + case i == 515: + return _SignatureScheme_name_1 + case i == 1025: + return _SignatureScheme_name_2 + case i == 1027: + return _SignatureScheme_name_3 + case i == 1281: + return _SignatureScheme_name_4 + case i == 1283: + return _SignatureScheme_name_5 + case i == 1537: + return _SignatureScheme_name_6 + case i == 1539: + return _SignatureScheme_name_7 + case 2052 <= i && i <= 2055: + i -= 2052 + return _SignatureScheme_name_8[_SignatureScheme_index_8[i]:_SignatureScheme_index_8[i+1]] + default: + return "SignatureScheme(" + strconv.FormatInt(int64(i), 10) + ")" + } +} +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[CurveP256-23] + _ = x[CurveP384-24] + _ = x[CurveP521-25] + _ = x[X25519-29] +} + +const ( + _CurveID_name_0 = "CurveP256CurveP384CurveP521" + _CurveID_name_1 = "X25519" +) + +var ( + _CurveID_index_0 = [...]uint8{0, 9, 18, 27} +) + +func (i CurveID) String() string { + switch { + case 23 <= i && i <= 25: + i -= 23 + return _CurveID_name_0[_CurveID_index_0[i]:_CurveID_index_0[i+1]] + case i == 29: + return _CurveID_name_1 + default: + return "CurveID(" + strconv.FormatInt(int64(i), 10) + ")" + } +} +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[NoClientCert-0] + _ = x[RequestClientCert-1] + _ = x[RequireAnyClientCert-2] + _ = x[VerifyClientCertIfGiven-3] + _ = x[RequireAndVerifyClientCert-4] +} + +const _ClientAuthType_name = "NoClientCertRequestClientCertRequireAnyClientCertVerifyClientCertIfGivenRequireAndVerifyClientCert" + +var _ClientAuthType_index = [...]uint8{0, 12, 29, 49, 72, 98} + +func (i ClientAuthType) String() string { + if i < 0 || i >= ClientAuthType(len(_ClientAuthType_index)-1) { + return "ClientAuthType(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _ClientAuthType_name[_ClientAuthType_index[i]:_ClientAuthType_index[i+1]] +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/conn.go b/vendor/github.com/lesismal/llib/std/crypto/tls/conn.go new file mode 100644 index 0000000..65472fa --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/conn.go @@ -0,0 +1,1775 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TLS low level connection and record layer + +package tls + +import ( + "bytes" + "crypto/cipher" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "hash" + "io" + "net" + "sync" + "sync/atomic" + "time" +) + +var ( + defaultReadBufferSize = 4096 + errDataNotEnough = errors.New("data not enough") +) + +type Allocator interface { + Malloc(size int) []byte + Realloc(buf []byte, size int) []byte + Free(buf []byte) +} + +type NativeAllocator struct{} + +// Malloc . +func (a *NativeAllocator) Malloc(size int) []byte { + return make([]byte, size) +} + +// Realloc . +func (a *NativeAllocator) Realloc(buf []byte, size int) []byte { + if size <= cap(buf) { + return buf[:size] + } + newBuf := make([]byte, size) + copy(newBuf, buf) + return newBuf +} + +// Free . +func (a *NativeAllocator) Free(buf []byte) { +} + +// A Conn represents a secured connection. +// It implements the net.Conn interface. +type Conn struct { + // constant + conn net.Conn + handshakeFn func() error // (*Conn).clientHandshake or serverHandshake + + // handshakeStatus is 1 if the connection is currently transferring + // application data (i.e. is not currently processing a handshake). + // This field is only to be accessed with sync/atomic. + handshakeStatus uint32 + // constant after handshake; protected by handshakeMutex + // handshakeMutex sync.Mutex + handshakeErr error // error resulting from handshake + vers uint16 // TLS version + config *Config // configuration passed to constructor + // handshakes counts the number of handshakes performed on the + // connection so far. If renegotiation is disabled then this is either + // zero or one. + handshakes int + cipherSuite uint16 + ocspResponse []byte // stapled OCSP response + scts [][]byte // signed certificate timestamps from server + peerCertificates []*x509.Certificate + // verifiedChains contains the certificate chains that we built, as + // opposed to the ones presented by the server. + verifiedChains [][]*x509.Certificate + // serverName contains the server name indicated by the client, if any. + serverName string + + // ekm is a closure for exporting keying material. + ekm func(label string, context []byte, length int) ([]byte, error) + // resumptionSecret is the resumption_master_secret for handling + // NewSessionTicket messages. nil if config.SessionTicketsDisabled. + resumptionSecret []byte + + // ticketKeys is the set of active session ticket keys for this + // connection. The first one is used to encrypt new tickets and + // all are tried to decrypt tickets. + ticketKeys []ticketKey + + // clientFinished and serverFinished contain the Finished message sent + // by the client or server in the most recent handshake. This is + // retained to support the renegotiation extension and tls-unique + // channel-binding. + clientFinished [12]byte + serverFinished [12]byte + + // clientProtocol is the negotiated ALPN protocol. + clientProtocol string + + // closeNotifyErr is any error from sending the alertCloseNotify record. + closeNotifyErr error + + closeMux sync.Mutex + closed bool + // secureRenegotiation is true if the server echoed the secure + // renegotiation extension. (This is meaningless as a server because + // renegotiation is not supported in that case.) + secureRenegotiation bool + + // clientFinishedIsFirst is true if the client sent the first Finished + // message during the most recent handshake. This is recorded because + // the first transmitted Finished message is the tls-unique + // channel-binding value. + clientFinishedIsFirst bool + + haveVers bool // version has been negotiated + didResume bool // whether this connection was a session resumption + + // closeNotifySent is true if the Conn attempted to send an + // alertCloseNotify record. + closeNotifySent bool + + isClient bool + isNonBlock bool + + // input/output + buffering bool // whether records are buffered in sendBuf + sendBuf []byte // a buffer of records waiting to be sent + in, out halfConn + rawInputOff int + rawInput []byte // bytes.Buffer // raw input, starting with a record header + input bytes.Reader // application data waiting to be read, from rawInput.Next + handOff int + hand []byte // bytes.Buffer // handshake data waiting to be read + + // bytesSent counts the bytes of application data sent. + // packetsSent counts packets. + bytesSent int64 + packetsSent int64 + + // retryCount counts the number of consecutive non-advancing records + // received by Conn.readRecord. That is, records that neither advance the + // handshake, nor deliver application data. Protected by in.Mutex. + retryCount int + + // activeCall is an atomic int32; the low bit is whether Close has + // been called. the rest of the bits are the number of goroutines + // in Conn.Write. + // activeCall int32 + + tmp [16]byte + + handshakeStatusAsync uint32 + clientHello *clientHelloMsg + serverHello *serverHelloMsg + hs *serverHandshakeState + hs13 *serverHandshakeStateTLS13 + certMsg *certificateMsgTLS13 + certMsgVerified []bool + + allocator Allocator + session interface{} +} + +// Access to net.Conn methods. +// Cannot just embed net.Conn because that would +// export the struct field too. + +// Conn returns conn +func (c *Conn) Conn() net.Conn { + return c.conn +} + +// ResetConn resets conn +func (c *Conn) ResetConn(conn net.Conn, nonBlock bool, v ...interface{}) { + c.conn = conn + c.isNonBlock = nonBlock + if len(v) > 0 { + if allocator, ok := v[0].(Allocator); ok { + c.allocator = allocator + } + } + if c.allocator == nil { + c.allocator = &NativeAllocator{} + } +} + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetDeadline sets the read and write deadlines associated with the connection. +// A zero value for t means Read and Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +// SetReadDeadline sets the read deadline on the underlying connection. +// A zero value for t means Read will not time out. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline on the underlying connection. +// A zero value for t means Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +// A halfConn represents one direction of the record layer +// connection, either sending or receiving. +type halfConn struct { + // sync.Mutex + + err error // first permanent error + version uint16 // protocol version + cipher interface{} // cipher algorithm + mac hash.Hash + seq [8]byte // 64-bit sequence number + + scratchBuf [13]byte // to avoid allocs; interface method args escape + + nextCipher interface{} // next encryption state + nextMac hash.Hash // next MAC algorithm + + trafficSecret []byte // current TLS 1.3 traffic secret +} + +type permanentError struct { + err net.Error +} + +func (e *permanentError) Error() string { return e.err.Error() } +func (e *permanentError) Unwrap() error { return e.err } +func (e *permanentError) Timeout() bool { return e.err.Timeout() } +func (e *permanentError) Temporary() bool { return false } + +func (hc *halfConn) setErrorLocked(err error) error { + if e, ok := err.(net.Error); ok { + hc.err = &permanentError{err: e} + } else { + hc.err = err + } + return hc.err +} + +// prepareCipherSpec sets the encryption and MAC states +// that a subsequent changeCipherSpec will use. +func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac hash.Hash) { + hc.version = version + hc.nextCipher = cipher + hc.nextMac = mac +} + +// changeCipherSpec changes the encryption and MAC states +// to the ones previously passed to prepareCipherSpec. +func (hc *halfConn) changeCipherSpec() error { + if hc.nextCipher == nil || hc.version == VersionTLS13 { + return alertInternalError + } + hc.cipher = hc.nextCipher + hc.mac = hc.nextMac + hc.nextCipher = nil + hc.nextMac = nil + for i := range hc.seq { + hc.seq[i] = 0 + } + return nil +} + +func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) { + hc.trafficSecret = secret + key, iv := suite.trafficKey(secret) + hc.cipher = suite.aead(key, iv) + for i := range hc.seq { + hc.seq[i] = 0 + } +} + +// incSeq increments the sequence number. +func (hc *halfConn) incSeq() { + for i := 7; i >= 0; i-- { + hc.seq[i]++ + if hc.seq[i] != 0 { + return + } + } + + // Not allowed to let sequence number wrap. + // Instead, must renegotiate before it does. + // Not likely enough to bother. + panic("TLS: sequence number wraparound") +} + +// explicitNonceLen returns the number of bytes of explicit nonce or IV included +// in each record. Explicit nonces are present only in CBC modes after TLS 1.0 +// and in certain AEAD modes in TLS 1.2. +func (hc *halfConn) explicitNonceLen() int { + if hc.cipher == nil { + return 0 + } + + switch c := hc.cipher.(type) { + case cipher.Stream: + return 0 + case aead: + return c.explicitNonceLen() + case cbcMode: + // TLS 1.1 introduced a per-record explicit IV to fix the BEAST attack. + if hc.version >= VersionTLS11 { + return c.BlockSize() + } + return 0 + default: + // panic("unknown cipher type") + } + return -1 +} + +// extractPadding returns, in constant time, the length of the padding to remove +// from the end of payload. It also returns a byte which is equal to 255 if the +// padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2. +func extractPadding(payload []byte) (toRemove int, good byte) { + if len(payload) < 1 { + return 0, 0 + } + + paddingLen := payload[len(payload)-1] + t := uint(len(payload)-1) - uint(paddingLen) + // if len(payload) >= (paddingLen - 1) then the MSB of t is zero + good = byte(int32(^t) >> 31) + + // The maximum possible padding length plus the actual length field + toCheck := 256 + // The length of the padded data is public, so we can use an if here + if toCheck > len(payload) { + toCheck = len(payload) + } + + for i := 0; i < toCheck; i++ { + t := uint(paddingLen) - uint(i) + // if i <= paddingLen then the MSB of t is zero + mask := byte(int32(^t) >> 31) + b := payload[len(payload)-1-i] + good &^= mask&paddingLen ^ mask&b + } + + // We AND together the bits of good and replicate the result across + // all the bits. + good &= good << 4 + good &= good << 2 + good &= good << 1 + good = uint8(int8(good) >> 7) + + // Zero the padding length on error. This ensures any unchecked bytes + // are included in the MAC. Otherwise, an attacker that could + // distinguish MAC failures from padding failures could mount an attack + // similar to POODLE in SSL 3.0: given a good ciphertext that uses a + // full block's worth of padding, replace the final block with another + // block. If the MAC check passed but the padding check failed, the + // last byte of that block decrypted to the block size. + // + // See also macAndPaddingGood logic below. + paddingLen &= good + + toRemove = int(paddingLen) + 1 + return +} + +func roundUp(a, b int) int { + return a + (b-a%b)%b +} + +// cbcMode is an interface for block ciphers using cipher block chaining. +type cbcMode interface { + cipher.BlockMode + SetIV([]byte) +} + +// decrypt authenticates and decrypts the record if protection is active at +// this stage. The returned plaintext might overlap with the input. +func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) { + var plaintext []byte + typ := recordType(record[0]) + payload := record[recordHeaderLen:] + + // In TLS 1.3, change_cipher_spec messages are to be ignored without being + // decrypted. See RFC 8446, Appendix D.4. + if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec { + return payload, typ, nil + } + + paddingGood := byte(255) + paddingLen := 0 + + explicitNonceLen := hc.explicitNonceLen() + if explicitNonceLen < 0 { + return nil, 0, errors.New("unknown cipher type") + } + + if hc.cipher != nil { + switch c := hc.cipher.(type) { + case cipher.Stream: + c.XORKeyStream(payload, payload) + case aead: + if len(payload) < explicitNonceLen { + return nil, 0, alertBadRecordMAC + } + nonce := payload[:explicitNonceLen] + if len(nonce) == 0 { + nonce = hc.seq[:] + } + payload = payload[explicitNonceLen:] + + var additionalData []byte + if hc.version == VersionTLS13 { + additionalData = record[:recordHeaderLen] + } else { + additionalData = append(hc.scratchBuf[:0], hc.seq[:]...) + additionalData = append(additionalData, record[:3]...) + n := len(payload) - c.Overhead() + additionalData = append(additionalData, byte(n>>8), byte(n)) + } + + var err error + plaintext, err = c.Open(payload[:0], nonce, payload, additionalData) + if err != nil { + return nil, 0, alertBadRecordMAC + } + case cbcMode: + blockSize := c.BlockSize() + minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize) + if len(payload)%blockSize != 0 || len(payload) < minPayload { + return nil, 0, alertBadRecordMAC + } + + if explicitNonceLen > 0 { + c.SetIV(payload[:explicitNonceLen]) + payload = payload[explicitNonceLen:] + } + c.CryptBlocks(payload, payload) + + // In a limited attempt to protect against CBC padding oracles like + // Lucky13, the data past paddingLen (which is secret) is passed to + // the MAC function as extra data, to be fed into the HMAC after + // computing the digest. This makes the MAC roughly constant time as + // long as the digest computation is constant time and does not + // affect the subsequent write, modulo cache effects. + paddingLen, paddingGood = extractPadding(payload) + default: + panic("unknown cipher type") + } + + if hc.version == VersionTLS13 { + if typ != recordTypeApplicationData { + return nil, 0, alertUnexpectedMessage + } + if len(plaintext) > maxPlaintext+1 { + return nil, 0, alertRecordOverflow + } + // Remove padding and find the ContentType scanning from the end. + for i := len(plaintext) - 1; i >= 0; i-- { + if plaintext[i] != 0 { + typ = recordType(plaintext[i]) + plaintext = plaintext[:i] + break + } + if i == 0 { + return nil, 0, alertUnexpectedMessage + } + } + } + } else { + plaintext = payload + } + + if hc.mac != nil { + macSize := hc.mac.Size() + if len(payload) < macSize { + return nil, 0, alertBadRecordMAC + } + + n := len(payload) - macSize - paddingLen + n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 } + record[3] = byte(n >> 8) + record[4] = byte(n) + remoteMAC := payload[n : n+macSize] + localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:]) + + // This is equivalent to checking the MACs and paddingGood + // separately, but in constant-time to prevent distinguishing + // padding failures from MAC failures. Depending on what value + // of paddingLen was returned on bad padding, distinguishing + // bad MAC from bad padding can lead to an attack. + // + // See also the logic at the end of extractPadding. + macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood) + if macAndPaddingGood != 1 { + return nil, 0, alertBadRecordMAC + } + + plaintext = payload[:n] + } + + hc.incSeq() + return plaintext, typ, nil +} + +// sliceForAppend extends the input slice by n bytes. head is the full extended +// slice, while tail is the appended part. If the original slice has sufficient +// capacity no allocation is performed. +func sliceForAppend(in []byte, n int) (head, tail []byte) { + if total := len(in) + n; cap(in) >= total { + head = in[:total] + } else { + head = make([]byte, total) + copy(head, in) + } + tail = head[len(in):] + return +} + +// encrypt encrypts payload, adding the appropriate nonce and/or MAC, and +// appends it to record, which must already contain the record header. +func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { + if hc.cipher == nil { + return append(record, payload...), nil + } + + var explicitNonce []byte + if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 { + record, explicitNonce = sliceForAppend(record, explicitNonceLen) + if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 { + // The AES-GCM construction in TLS has an explicit nonce so that the + // nonce can be random. However, the nonce is only 8 bytes which is + // too small for a secure, random nonce. Therefore we use the + // sequence number as the nonce. The 3DES-CBC construction also has + // an 8 bytes nonce but its nonces must be unpredictable (see RFC + // 5246, Appendix F.3), forcing us to use randomness. That's not + // 3DES' biggest problem anyway because the birthday bound on block + // collision is reached first due to its similarly small block size + // (see the Sweet32 attack). + copy(explicitNonce, hc.seq[:]) + } else { + if _, err := io.ReadFull(rand, explicitNonce); err != nil { + return nil, err + } + } + } + + var dst []byte + switch c := hc.cipher.(type) { + case cipher.Stream: + mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) + record, dst = sliceForAppend(record, len(payload)+len(mac)) + c.XORKeyStream(dst[:len(payload)], payload) + c.XORKeyStream(dst[len(payload):], mac) + case aead: + nonce := explicitNonce + if len(nonce) == 0 { + nonce = hc.seq[:] + } + + if hc.version == VersionTLS13 { + record = append(record, payload...) + + // Encrypt the actual ContentType and replace the plaintext one. + record = append(record, record[0]) + record[0] = byte(recordTypeApplicationData) + + n := len(payload) + 1 + c.Overhead() + record[3] = byte(n >> 8) + record[4] = byte(n) + + record = c.Seal(record[:recordHeaderLen], + nonce, record[recordHeaderLen:], record[:recordHeaderLen]) + } else { + additionalData := append(hc.scratchBuf[:0], hc.seq[:]...) + additionalData = append(additionalData, record[:recordHeaderLen]...) + record = c.Seal(record, nonce, payload, additionalData) + } + case cbcMode: + mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) + blockSize := c.BlockSize() + plaintextLen := len(payload) + len(mac) + paddingLen := blockSize - plaintextLen%blockSize + record, dst = sliceForAppend(record, plaintextLen+paddingLen) + copy(dst, payload) + copy(dst[len(payload):], mac) + for i := plaintextLen; i < len(dst); i++ { + dst[i] = byte(paddingLen - 1) + } + if len(explicitNonce) > 0 { + c.SetIV(explicitNonce) + } + c.CryptBlocks(dst, dst) + default: + panic("unknown cipher type") + } + + // Update length to include nonce, MAC and any block padding needed. + n := len(record) - recordHeaderLen + record[3] = byte(n >> 8) + record[4] = byte(n) + hc.incSeq() + + return record, nil +} + +// RecordHeaderError is returned when a TLS record header is invalid. +type RecordHeaderError struct { + // Msg contains a human readable string that describes the error. + Msg string + // RecordHeader contains the five bytes of TLS record header that + // triggered the error. + RecordHeader [5]byte + // Conn provides the underlying net.Conn in the case that a client + // sent an initial handshake that didn't look like TLS. + // It is nil if there's already been a handshake or a TLS alert has + // been written to the connection. + Conn net.Conn +} + +func (e RecordHeaderError) Error() string { return "tls: " + e.Msg } + +func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) { + err.Msg = msg + err.Conn = conn + copy(err.RecordHeader[:], c.rawInput[c.rawInputOff:]) + return err +} + +func (c *Conn) readRecord() error { + return c.readRecordOrCCS(false) +} + +func (c *Conn) readChangeCipherSpec() error { + return c.readRecordOrCCS(true) +} + +func (c *Conn) ResetOrFreeBuffer() { + c.closeMux.Lock() + defer c.closeMux.Unlock() + if c.closed { + return + } + + remain := len(c.rawInput) - c.rawInputOff + switch remain { + case 0: + if c.rawInput != nil { + c.allocator.Free(c.rawInput) + c.rawInput = nil + } + default: + copy(c.rawInput, c.rawInput[c.rawInputOff:]) + c.rawInput = c.rawInput[:remain] + c.rawInputOff = 0 + } +} + +// readRecordOrCCS reads one or more TLS records from the connection and +// updates the record layer state. Some invariants: +// * c.in must be locked +// * c.input must be empty +// During the handshake one and only one of the following will happen: +// - c.hand grows +// - c.in.changeCipherSpec is called +// - an error is returned +// After the handshake one and only one of the following will happen: +// - c.hand grows +// - c.input is set +// - an error is returned +func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { + if c.in.err != nil { + return c.in.err + } + handshakeComplete := c.handshakeComplete() + + if c.isNonBlock { + if len(c.rawInput)-c.rawInputOff < recordHeaderLen { + return errDataNotEnough + } + } else { + // This function modifies c.rawInput, which owns the c.input memory. + if c.input.Len() != 0 { + return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data")) + } + c.input.Reset(nil) + + // Read header, payload. + if err := c.readFromUntil(c.conn, recordHeaderLen, 0); err != nil { + // RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify + // is an error, but popular web sites seem to do this, so we accept it + // if and only if at the record boundary. + if err == io.ErrUnexpectedEOF && len(c.rawInput)-c.rawInputOff == 0 { + err = io.EOF + } + if e, ok := err.(net.Error); !ok || !e.Temporary() { + c.in.setErrorLocked(err) + } + return err + } + } + + hdr := c.rawInput[c.rawInputOff : c.rawInputOff+recordHeaderLen] + typ := recordType(hdr[0]) + + // No valid TLS record has a type of 0x80, however SSLv2 handshakes + // start with a uint16 length where the MSB is set and the first record + // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests + // an SSLv2 client. + if !handshakeComplete && typ == 0x80 { + c.sendAlert(alertProtocolVersion) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received")) + } + + vers := uint16(hdr[1])<<8 | uint16(hdr[2]) + n := int(hdr[3])<<8 | int(hdr[4]) + if c.haveVers && c.vers != VersionTLS13 && vers != c.vers { + c.sendAlert(alertProtocolVersion) + msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) + } + if !c.haveVers { + // First message, be extra suspicious: this might not be a TLS + // client. Bail out before reading a full 'body', if possible. + // The current max version is 3.3 so if the version is >= 16.0, + // it's probably not real. + if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 { + return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake")) + } + } + if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext { + c.sendAlert(alertRecordOverflow) + msg := fmt.Sprintf("oversized record received with length %d", n) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) + } + + if c.isNonBlock { + if len(c.rawInput)-c.rawInputOff < recordHeaderLen+n { + return errDataNotEnough + } + } else { + if err := c.readFromUntil(c.conn, n, recordHeaderLen); err != nil { + if e, ok := err.(net.Error); !ok || !e.Temporary() { + c.in.setErrorLocked(err) + } + return err + } + } + + // Process message. + record := c.rawInput[c.rawInputOff : c.rawInputOff+recordHeaderLen+n] + c.rawInputOff += (recordHeaderLen + n) + data, typ, err := c.in.decrypt(record) + if err != nil { + return c.in.setErrorLocked(c.sendAlert(err.(alert))) + } + if len(data) > maxPlaintext { + return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow)) + } + if len(c.rawInput) == c.rawInputOff { + c.allocator.Free(c.rawInput) + c.rawInput = nil + c.rawInputOff = 0 + } + + // Application Data messages are always protected. + if c.in.cipher == nil && typ == recordTypeApplicationData { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 { + // This is a state-advancing message: reset the retry count. + c.retryCount = 0 + } + + // Handshake messages MUST NOT be interleaved with other record types in TLS 1.3. + if c.vers == VersionTLS13 && typ != recordTypeHandshake && len(c.hand)-c.handOff > 0 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + switch typ { + default: + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + + case recordTypeAlert: + if len(data) != 2 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + if alert(data[1]) == alertCloseNotify { + return c.in.setErrorLocked(io.EOF) + } + if c.vers == VersionTLS13 { + return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) + } + switch data[0] { + case alertLevelWarning: + // Drop the record on the floor and retry. + return c.retryReadRecord(expectChangeCipherSpec) + case alertLevelError: + return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) + default: + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + case recordTypeChangeCipherSpec: + if len(data) != 1 || data[0] != 1 { + return c.in.setErrorLocked(c.sendAlert(alertDecodeError)) + } + // Handshake messages are not allowed to fragment across the CCS. + if len(c.hand)-c.handOff > 0 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + // In TLS 1.3, change_cipher_spec records are ignored until the + // Finished. See RFC 8446, Appendix D.4. Note that according to Section + // 5, a server can send a ChangeCipherSpec before its ServerHello, when + // c.vers is still unset. That's not useful though and suspicious if the + // server then selects a lower protocol version, so don't allow that. + if c.vers == VersionTLS13 { + return c.retryReadRecord(expectChangeCipherSpec) + } + if !expectChangeCipherSpec { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + if err := c.in.changeCipherSpec(); err != nil { + return c.in.setErrorLocked(c.sendAlert(err.(alert))) + } + + case recordTypeApplicationData: + if !handshakeComplete || expectChangeCipherSpec { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + // Some OpenSSL servers send empty records in order to randomize the + // CBC IV. Ignore a limited number of empty records. + if len(data) == 0 { + return c.retryReadRecord(expectChangeCipherSpec) + } + // Note that data is owned by c.rawInput, following the Next call above, + // to avoid copying the plaintext. This is safe because c.rawInput is + // not read from or written to until c.input is drained. + c.input.Reset(data) + + case recordTypeHandshake: + if len(data) == 0 || expectChangeCipherSpec { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + if c.hand == nil { + c.hand = c.allocator.Malloc(len(data))[0:0] + c.handOff = 0 + } + c.hand = append(c.hand, data...) + } + + return nil +} + +// retryReadRecord recurses into readRecordOrCCS to drop a non-advancing record, like +// a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3. +func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error { + c.retryCount++ + if c.retryCount > maxUselessRecords { + c.sendAlert(alertUnexpectedMessage) + return c.in.setErrorLocked(errors.New("tls: too many ignored records")) + } + return c.readRecordOrCCS(expectChangeCipherSpec) +} + +// atLeastReader reads from R, stopping with EOF once at least N bytes have been +// read. It is different from an io.LimitedReader in that it doesn't cut short +// the last Read call, and in that it considers an early EOF an error. +type atLeastReader struct { + R io.Reader + N int64 +} + +func (r *atLeastReader) Read(p []byte) (int, error) { + if r.N <= 0 { + return 0, io.EOF + } + n, err := r.R.Read(p) + r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809 + if r.N > 0 && err == io.EOF { + return n, io.ErrUnexpectedEOF + } + if r.N <= 0 && err == nil { + return n, io.EOF + } + return n, err +} + +// readFromUntil reads from r into c.rawInput until c.rawInput contains +// at least n bytes or else returns an error. +func (c *Conn) readFromUntil(r io.Reader, n int, from int) error { + // if len(c.rawInput)-c.rawInputOff >= n { + // return nil + // } + // if len(c.rawInput) == c.rawInputOff { + // c.rawInput = c.rawInput[0:0] + // c.rawInputOff = 0 + // } + + needs := from + n - cap(c.rawInput) + // There might be extra input waiting on the wire. Make a best effort + // attempt to fetch it so that it can be used in (*Conn).Read to + // "predict" closeNotify alerts. + if needs > 0 { + buf := c.allocator.Malloc(needs) + c.rawInput = append(c.rawInput[:cap(c.rawInput)], buf...) + c.allocator.Free(buf) + } + c.rawInput = c.rawInput[:from+n] + _, err := io.ReadFull(r, c.rawInput[from:]) + return err +} + +// sendAlert sends a TLS alert message. +func (c *Conn) sendAlertLocked(err alert) error { + switch err { + case alertNoRenegotiation, alertCloseNotify: + c.tmp[0] = alertLevelWarning + default: + c.tmp[0] = alertLevelError + } + c.tmp[1] = byte(err) + + _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2]) + if err == alertCloseNotify { + // closeNotify is a special case in that it isn't an error. + return writeErr + } + + return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) +} + +// sendAlert sends a TLS alert message. +func (c *Conn) sendAlert(err alert) error { + // c.out.Lock() + // defer c.out.Unlock() + return c.sendAlertLocked(err) +} + +const ( + // tcpMSSEstimate is a conservative estimate of the TCP maximum segment + // size (MSS). A constant is used, rather than querying the kernel for + // the actual MSS, to avoid complexity. The value here is the IPv6 + // minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40 + // bytes) and a TCP header with timestamps (32 bytes). + tcpMSSEstimate = 1208 + + // recordSizeBoostThreshold is the number of bytes of application data + // sent after which the TLS record size will be increased to the + // maximum. + recordSizeBoostThreshold = 128 * 1024 +) + +// maxPayloadSizeForWrite returns the maximum TLS payload size to use for the +// next application data record. There is the following trade-off: +// +// - For latency-sensitive applications, such as web browsing, each TLS +// record should fit in one TCP segment. +// - For throughput-sensitive applications, such as large file transfers, +// larger TLS records better amortize framing and encryption overheads. +// +// A simple heuristic that works well in practice is to use small records for +// the first 1MB of data, then use larger records for subsequent data, and +// reset back to smaller records after the connection becomes idle. See "High +// Performance Web Networking", Chapter 4, or: +// https://www.igvita.com/2013/10/24/optimizing-tls-record-size-and-buffering-latency/ +// +// In the interests of simplicity and determinism, this code does not attempt +// to reset the record size once the connection is idle, however. +func (c *Conn) maxPayloadSizeForWrite(typ recordType) int { + if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData { + return maxPlaintext + } + + if c.bytesSent >= recordSizeBoostThreshold { + return maxPlaintext + } + + // Subtract TLS overheads to get the maximum payload size. + payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen() + if c.out.cipher != nil { + switch ciph := c.out.cipher.(type) { + case cipher.Stream: + payloadBytes -= c.out.mac.Size() + case cipher.AEAD: + payloadBytes -= ciph.Overhead() + case cbcMode: + blockSize := ciph.BlockSize() + // The payload must fit in a multiple of blockSize, with + // room for at least one padding byte. + payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1 + // The MAC is appended before padding so affects the + // payload size directly. + payloadBytes -= c.out.mac.Size() + default: + panic("unknown cipher type") + } + } + if c.vers == VersionTLS13 { + payloadBytes-- // encrypted ContentType + } + + // Allow packet growth in arithmetic progression up to max. + pkt := c.packetsSent + c.packetsSent++ + if pkt > 1000 { + return maxPlaintext // avoid overflow in multiply below + } + + n := payloadBytes * int(pkt+1) + if n > maxPlaintext { + n = maxPlaintext + } + return n +} + +func (c *Conn) write(data []byte) (int, error) { + if c.buffering { + if len(c.sendBuf) == 0 { + c.sendBuf = data + copy(c.sendBuf, data) + } else { + c.sendBuf = append(c.sendBuf, data...) + c.allocator.Free(data) + } + return len(data), nil + } + + n, err := c.conn.Write(data) + c.bytesSent += int64(n) + return n, err +} + +func (c *Conn) flush() (int, error) { + c.buffering = false + + if len(c.sendBuf) == 0 { + return 0, nil + } + + n, err := c.conn.Write(c.sendBuf) + c.bytesSent += int64(n) + c.sendBuf = nil + return n, err +} + +// outBufPool pools the record-sized scratch buffers used by writeRecordLocked. +var outBufPool = sync.Pool{ + New: func() interface{} { + return new([]byte) + }, +} + +// writeRecordLocked writes a TLS record with the given type and payload to the +// connection and updates the record layer state. +func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { + // outBufPtr := outBufPool.Get().(*[]byte) + // outBuf := *outBufPtr + // defer func() { + // // You might be tempted to simplify this by just passing &outBuf to Put, + // // but that would make the local copy of the outBuf slice header escape + // // to the heap, causing an allocation. Instead, we keep around the + // // pointer to the slice header returned by Get, which is already on the + // // heap, and overwrite and return that. + // *outBufPtr = outBuf + // outBufPool.Put(outBufPtr) + // }() + + var n int + var outBuf []byte + // var outBuf = c.allocator.Malloc(recordHeaderLen + len(data))[0:0] + for len(data) > 0 { + m := len(data) + if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload { + m = maxPayload + } + + outBuf = c.allocator.Malloc(recordHeaderLen) + // _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen) + outBuf[0] = byte(typ) + vers := c.vers + if vers == 0 { + // Some TLS servers fail if the record version is + // greater than TLS 1.0 for the initial ClientHello. + vers = VersionTLS10 + } else if vers == VersionTLS13 { + // TLS 1.3 froze the record layer version to 1.2. + // See RFC 8446, Section 5.1. + vers = VersionTLS12 + } + outBuf[1] = byte(vers >> 8) + outBuf[2] = byte(vers) + outBuf[3] = byte(m >> 8) + outBuf[4] = byte(m) + + var err error + outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand()) + if err != nil { + return n, err + } + if _, err := c.write(outBuf); err != nil { + return n, err + } + n += m + data = data[m:] + } + + if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 { + if err := c.out.changeCipherSpec(); err != nil { + return n, c.sendAlertLocked(err.(alert)) + } + } + + return n, nil +} + +// writeRecord writes a TLS record with the given type and payload to the +// connection and updates the record layer state. +func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { + // c.out.Lock() + // defer c.out.Unlock() + + return c.writeRecordLocked(typ, data) +} + +// readHandshake reads the next handshake message from +// the record layer. +func (c *Conn) readHandshake() (interface{}, error) { + if c.isNonBlock { + if len(c.hand)-c.handOff < 4 { + if err := c.readRecord(); err != nil { + return nil, err + } + } + } else { + for len(c.hand)-c.handOff < 4 { + if err := c.readRecord(); err != nil { + return nil, err + } + } + } + + data := c.hand[c.handOff:] + n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if n > maxHandshake { + c.sendAlertLocked(alertInternalError) + return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)) + } + if c.isNonBlock { + if len(c.hand)-c.handOff < 4+n { + if err := c.readRecord(); err != nil { + return nil, err + } + } + } else { + for len(c.hand)-c.handOff < 4+n { + if err := c.readRecord(); err != nil { + return nil, err + } + } + } + data = c.hand[c.handOff : c.handOff+4+n] + if c.handOff+4+n == len(c.hand) { + c.hand = c.hand[0:0] + c.handOff = 0 + } else { + c.handOff += (4 + n) + } + + var m handshakeMessage + switch data[0] { + case typeHelloRequest: + m = new(helloRequestMsg) + case typeClientHello: + m = new(clientHelloMsg) + case typeServerHello: + m = new(serverHelloMsg) + case typeNewSessionTicket: + if c.vers == VersionTLS13 { + m = new(newSessionTicketMsgTLS13) + } else { + m = new(newSessionTicketMsg) + } + case typeCertificate: + if c.vers == VersionTLS13 { + m = new(certificateMsgTLS13) + } else { + m = new(certificateMsg) + } + case typeCertificateRequest: + if c.vers == VersionTLS13 { + m = new(certificateRequestMsgTLS13) + } else { + m = &certificateRequestMsg{ + hasSignatureAlgorithm: c.vers >= VersionTLS12, + } + } + case typeCertificateStatus: + m = new(certificateStatusMsg) + case typeServerKeyExchange: + m = new(serverKeyExchangeMsg) + case typeServerHelloDone: + m = new(serverHelloDoneMsg) + case typeClientKeyExchange: + m = new(clientKeyExchangeMsg) + case typeCertificateVerify: + m = &certificateVerifyMsg{ + hasSignatureAlgorithm: c.vers >= VersionTLS12, + } + case typeFinished: + m = new(finishedMsg) + case typeEncryptedExtensions: + m = new(encryptedExtensionsMsg) + case typeEndOfEarlyData: + m = new(endOfEarlyDataMsg) + case typeKeyUpdate: + m = new(keyUpdateMsg) + default: + return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + // The handshake message unmarshalers + // expect to be able to keep references to data, + // so pass in a fresh copy that won't be overwritten. + data = append([]byte(nil), data...) + + if !m.unmarshal(data) { + return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + return m, nil +} + +var ( + errShutdown = errors.New("tls: protocol is shutdown") +) + +// Write writes data to the connection. +// +// As Write calls Handshake, in order to prevent indefinite blocking a deadline +// must be set for both Read and Write before Write is called when the handshake +// has not yet completed. See SetDeadline, SetReadDeadline, and +// SetWriteDeadline. +func (c *Conn) Write(b []byte) (int, error) { + defer c.allocator.Free(b) + + if len(b) == 0 { + return 0, nil + } + + c.closeMux.Lock() + defer c.closeMux.Unlock() + + if c.closed { + return 0, net.ErrClosed + } + + if err := c.Handshake(); err != nil { + return 0, err + } + + // c.out.Lock() + // defer c.out.Unlock() + + if err := c.out.err; err != nil { + return 0, err + } + + if !c.handshakeComplete() { + return 0, alertInternalError + } + + if c.closeNotifySent { + return 0, errShutdown + } + + // TLS 1.0 is susceptible to a chosen-plaintext + // attack when using block mode ciphers due to predictable IVs. + // This can be prevented by splitting each Application Data + // record into two records, effectively randomizing the IV. + // + // https://www.openssl.org/~bodo/tls-cbc.txt + // https://bugzilla.mozilla.org/show_bug.cgi?id=665814 + // https://www.imperialviolet.org/2012/01/15/beastfollowup.html + + var m int + if len(b) > 1 && c.vers == VersionTLS10 { + if _, ok := c.out.cipher.(cipher.BlockMode); ok { + n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1]) + if err != nil { + return n, c.out.setErrorLocked(err) + } + m, b = 1, b[1:] + } + } + + n, err := c.writeRecordLocked(recordTypeApplicationData, b) + return n + m, c.out.setErrorLocked(err) +} + +// handleRenegotiation processes a HelloRequest handshake message. +func (c *Conn) handleRenegotiation() error { + if c.vers == VersionTLS13 { + return errors.New("tls: internal error: unexpected renegotiation") + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + helloReq, ok := msg.(*helloRequestMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(helloReq, msg) + } + + if !c.isClient { + return c.sendAlert(alertNoRenegotiation) + } + + switch c.config.Renegotiation { + case RenegotiateNever: + return c.sendAlert(alertNoRenegotiation) + case RenegotiateOnceAsClient: + if c.handshakes > 1 { + return c.sendAlert(alertNoRenegotiation) + } + case RenegotiateFreelyAsClient: + // Ok. + default: + c.sendAlert(alertInternalError) + return errors.New("tls: unknown Renegotiation value") + } + + // c.handshakeMutex.Lock() + // defer c.handshakeMutex.Unlock() + + atomic.StoreUint32(&c.handshakeStatus, 0) + if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { + c.handshakes++ + } + return c.handshakeErr +} + +// handlePostHandshakeMessage processes a handshake message arrived after the +// handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation. +func (c *Conn) handlePostHandshakeMessage() error { + if c.vers != VersionTLS13 { + return c.handleRenegotiation() + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + c.retryCount++ + if c.retryCount > maxUselessRecords { + c.sendAlert(alertUnexpectedMessage) + return c.in.setErrorLocked(errors.New("tls: too many non-advancing records")) + } + + switch msg := msg.(type) { + case *newSessionTicketMsgTLS13: + return c.handleNewSessionTicket(msg) + case *keyUpdateMsg: + return c.handleKeyUpdate(msg) + default: + c.sendAlert(alertUnexpectedMessage) + return fmt.Errorf("tls: received unexpected handshake message of type %T", msg) + } +} + +func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { + cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) + if cipherSuite == nil { + return c.in.setErrorLocked(c.sendAlert(alertInternalError)) + } + + newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret) + c.in.setTrafficSecret(cipherSuite, newSecret) + + if keyUpdate.updateRequested { + // c.out.Lock() + // defer c.out.Unlock() + + msg := &keyUpdateMsg{} + _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal()) + if err != nil { + // Surface the error at the next write. + c.out.setErrorLocked(err) + return nil + } + + newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret) + c.out.setTrafficSecret(cipherSuite, newSecret) + } + + return nil +} + +// Append . +func (c *Conn) Append(b []byte) (int, error) { + c.closeMux.Lock() + defer c.closeMux.Unlock() + + if c.closed { + return 0, net.ErrClosed + } + + // c.in.Lock() + // defer c.in.Unlock() + + if len(b) > 0 { + if cap(c.rawInput) == 0 { + needs := len(b) + if needs < bytes.MinRead { + needs = bytes.MinRead + } + c.rawInput = c.allocator.Malloc(needs)[0:0] + c.rawInputOff = 0 + } else if len(c.rawInput) == c.rawInputOff { + c.rawInput = c.rawInput[0:0] + c.rawInputOff = 0 + } + c.rawInput = append(c.rawInput, b...) + } + return 0, nil +} + +// Read reads data from the connection. +// +// As Read calls Handshake, in order to prevent indefinite blocking a deadline +// must be set for both Read and Write before Read is called when the handshake +// has not yet completed. See SetDeadline, SetReadDeadline, and +// SetWriteDeadline. +func (c *Conn) Read(b []byte) (int, error) { + c.closeMux.Lock() + defer c.closeMux.Unlock() + if c.closed { + return 0, net.ErrClosed + } + + if err := c.Handshake(); err != nil { + if c.isNonBlock && err == errDataNotEnough { + return 0, nil + } + return 0, err + } + if len(b) == 0 { + // Put this after Handshake, in case people were calling + // Read(nil) for the side effect of the Handshake. + return 0, nil + } + + // c.in.Lock() + // defer c.in.Unlock() + + for c.input.Len() == 0 { + if err := c.readRecord(); err != nil { + if c.isNonBlock && err == errDataNotEnough { + return 0, nil + } + return 0, err + } + for len(c.hand)-c.handOff > 0 { + if err := c.handlePostHandshakeMessage(); err != nil { + if c.isNonBlock && err == errDataNotEnough { + return 0, nil + } + return 0, err + } + } + } + + n, _ := c.input.Read(b) + + // If a close-notify alert is waiting, read it so that we can return (n, + // EOF) instead of (n, nil), to signal to the HTTP response reading + // goroutine that the connection is now closed. This eliminates a race + // where the HTTP response reading goroutine would otherwise not observe + // the EOF until its next read, by which time a client goroutine might + // have already tried to reuse the HTTP connection for a new request. + // See https://golang.org/cl/76400046 and https://golang.org/issue/3514 + if n != 0 && c.input.Len() == 0 && len(c.rawInput)-c.rawInputOff > 0 && + recordType(c.rawInput[c.rawInputOff]) == recordTypeAlert { + if err := c.readRecord(); err != nil { + if c.isNonBlock && err == errDataNotEnough { + return 0, nil + } + return n, err // will be io.EOF on closeNotify + } + } + + return n, nil +} + +// Append . +func (c *Conn) AppendAndRead(bufAppend []byte, bufRead []byte) (int, int, error) { + c.closeMux.Lock() + defer c.closeMux.Unlock() + + if c.closed { + return 0, 0, net.ErrClosed + } + + // c.in.Lock() + // defer c.in.Unlock() + + if len(bufAppend) > 0 { + if cap(c.rawInput) == 0 { + needs := len(bufAppend) + if needs < bytes.MinRead { + needs = bytes.MinRead + } + c.rawInput = c.allocator.Malloc(needs)[0:0] + c.rawInputOff = 0 + } else if len(c.rawInput) == c.rawInputOff { + c.rawInput = c.rawInput[0:0] + c.rawInputOff = 0 + } + c.rawInput = append(c.rawInput, bufAppend...) + } + + if err := c.Handshake(); err != nil { + if c.isNonBlock && err == errDataNotEnough { + return len(bufAppend), 0, nil + } + return len(bufAppend), 0, err + } + + if len(bufRead) == 0 { + // Put this after Handshake, in case people were calling + // Read(nil) for the side effect of the Handshake. + return len(bufAppend), 0, nil + } + + for c.input.Len() == 0 { + if err := c.readRecord(); err != nil { + if c.isNonBlock && err == errDataNotEnough { + return len(bufAppend), 0, nil + } + return len(bufAppend), 0, err + } + for len(c.hand)-c.handOff > 0 { + if err := c.handlePostHandshakeMessage(); err != nil { + if c.isNonBlock && err == errDataNotEnough { + return len(bufAppend), 0, nil + } + return len(bufAppend), 0, err + } + } + } + + n, _ := c.input.Read(bufRead) + + // If a close-notify alert is waiting, read it so that we can return (n, + // EOF) instead of (n, nil), to signal to the HTTP response reading + // goroutine that the connection is now closed. This eliminates a race + // where the HTTP response reading goroutine would otherwise not observe + // the EOF until its next read, by which time a client goroutine might + // have already tried to reuse the HTTP connection for a new request. + // See https://golang.org/cl/76400046 and https://golang.org/issue/3514 + if n != 0 && c.input.Len() == 0 && len(c.rawInput)-c.rawInputOff > 0 && + recordType(c.rawInput[c.rawInputOff]) == recordTypeAlert { + if err := c.readRecord(); err != nil { + if c.isNonBlock && err == errDataNotEnough { + return len(bufAppend), 0, nil + } + return len(bufAppend), n, err // will be io.EOF on closeNotify + } + } + + return len(bufAppend), n, nil +} + +func (c *Conn) release() { + if cap(c.hand) > 0 { + c.allocator.Free(c.hand) + } + if cap(c.rawInput) > 0 { + c.allocator.Free(c.rawInput) + } +} + +// Close closes the connection. +func (c *Conn) Close() error { + c.closeMux.Lock() + closed := c.closed + c.closed = true + c.closeMux.Unlock() + + if closed { + return net.ErrClosed + } + + c.release() + + var alertErr error + if c.handshakeComplete() { + if err := c.closeNotify(); err != nil { + alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err) + } + } + + if err := c.conn.Close(); err != nil { + return err + } + return alertErr +} + +var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete") + +// CloseWrite shuts down the writing side of the connection. It should only be +// called once the handshake has completed and does not call CloseWrite on the +// underlying connection. Most callers should just use Close. +func (c *Conn) CloseWrite() error { + if !c.handshakeComplete() { + return errEarlyCloseWrite + } + + return c.closeNotify() +} + +func (c *Conn) closeNotify() error { + // c.out.Lock() + // defer c.out.Unlock() + + if !c.closeNotifySent { + // Set a Write Deadline to prevent possibly blocking forever. + c.SetWriteDeadline(time.Now().Add(time.Second * 5)) + c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify) + c.closeNotifySent = true + // Any subsequent writes will fail. + c.SetWriteDeadline(time.Now()) + } + return c.closeNotifyErr +} + +// Handshake runs the client or server handshake +// protocol if it has not yet been run. +// +// Most uses of this package need not call Handshake explicitly: the +// first Read or Write will call it automatically. +// +// For control over canceling or setting a timeout on a handshake, use +// the Dialer's DialContext method. +func (c *Conn) Handshake() error { + // c.handshakeMutex.Lock() + // defer c.handshakeMutex.Unlock() + + if err := c.handshakeErr; err != nil { + return err + } + if c.handshakeComplete() { + return nil + } + + // c.in.Lock() + // defer c.in.Unlock() + + c.handshakeErr = c.handshakeFn() + if c.isNonBlock && c.handshakeErr == errDataNotEnough { + c.handshakeErr = nil + return errDataNotEnough + } + if c.handshakeErr == nil { + c.handshakes++ + } else { + // If an error occurred during the handshake try to flush the + // alert that might be left in the buffer. + c.flush() + } + + if c.handshakeErr == nil && !c.handshakeComplete() { + c.handshakeErr = errors.New("tls: internal error: handshake should have had a result") + } + + return c.handshakeErr +} + +// ConnectionState returns basic TLS details about the connection. +func (c *Conn) ConnectionState() ConnectionState { + // c.handshakeMutex.Lock() + // defer c.handshakeMutex.Unlock() + return c.connectionStateLocked() +} + +func (c *Conn) connectionStateLocked() ConnectionState { + var state ConnectionState + state.HandshakeComplete = c.handshakeComplete() + state.Version = c.vers + state.NegotiatedProtocol = c.clientProtocol + state.DidResume = c.didResume + state.NegotiatedProtocolIsMutual = true + state.ServerName = c.serverName + state.CipherSuite = c.cipherSuite + state.PeerCertificates = c.peerCertificates + state.VerifiedChains = c.verifiedChains + state.SignedCertificateTimestamps = c.scts + state.OCSPResponse = c.ocspResponse + if !c.didResume && c.vers != VersionTLS13 { + if c.clientFinishedIsFirst { + state.TLSUnique = c.clientFinished[:] + } else { + state.TLSUnique = c.serverFinished[:] + } + } + if c.config.Renegotiation != RenegotiateNever { + state.ekm = noExportedKeyingMaterial + } else { + state.ekm = c.ekm + } + return state +} + +// OCSPResponse returns the stapled OCSP response from the TLS server, if +// any. (Only valid for client connections.) +func (c *Conn) OCSPResponse() []byte { + // c.handshakeMutex.Lock() + // defer c.handshakeMutex.Unlock() + + return c.ocspResponse +} + +// VerifyHostname checks that the peer certificate chain is valid for +// connecting to host. If so, it returns nil; if not, it returns an error +// describing the problem. +func (c *Conn) VerifyHostname(host string) error { + // c.handshakeMutex.Lock() + // defer c.handshakeMutex.Unlock() + if !c.isClient { + return errors.New("tls: VerifyHostname called on TLS server connection") + } + if !c.handshakeComplete() { + return errors.New("tls: handshake has not yet been performed") + } + if len(c.verifiedChains) == 0 { + return errors.New("tls: handshake did not verify certificate chain") + } + return c.peerCertificates[0].VerifyHostname(host) +} + +func (c *Conn) handshakeComplete() bool { + return atomic.LoadUint32(&c.handshakeStatus) == 1 +} + +// Session returns user session +func (c *Conn) Session() interface{} { + return c.session +} + +// SetSession sets user session +func (c *Conn) SetSession(session interface{}) { + c.session = session +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/conn_test.go b/vendor/github.com/lesismal/llib/std/crypto/tls/conn_test.go new file mode 100644 index 0000000..78935b1 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/conn_test.go @@ -0,0 +1,287 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "io" + "net" + "testing" +) + +func TestRoundUp(t *testing.T) { + if roundUp(0, 16) != 0 || + roundUp(1, 16) != 16 || + roundUp(15, 16) != 16 || + roundUp(16, 16) != 16 || + roundUp(17, 16) != 32 { + t.Error("roundUp broken") + } +} + +// will be initialized with {0, 255, 255, ..., 255} +var padding255Bad = [256]byte{} + +// will be initialized with {255, 255, 255, ..., 255} +var padding255Good = [256]byte{255} + +var paddingTests = []struct { + in []byte + good bool + expectedLen int +}{ + {[]byte{1, 2, 3, 4, 0}, true, 4}, + {[]byte{1, 2, 3, 4, 0, 1}, false, 0}, + {[]byte{1, 2, 3, 4, 99, 99}, false, 0}, + {[]byte{1, 2, 3, 4, 1, 1}, true, 4}, + {[]byte{1, 2, 3, 2, 2, 2}, true, 3}, + {[]byte{1, 2, 3, 3, 3, 3}, true, 2}, + {[]byte{1, 2, 3, 4, 3, 3}, false, 0}, + {[]byte{1, 4, 4, 4, 4, 4}, true, 1}, + {[]byte{5, 5, 5, 5, 5, 5}, true, 0}, + {[]byte{6, 6, 6, 6, 6, 6}, false, 0}, + {padding255Bad[:], false, 0}, + {padding255Good[:], true, 0}, +} + +func TestRemovePadding(t *testing.T) { + for i := 1; i < len(padding255Bad); i++ { + padding255Bad[i] = 255 + padding255Good[i] = 255 + } + for i, test := range paddingTests { + paddingLen, good := extractPadding(test.in) + expectedGood := byte(255) + if !test.good { + expectedGood = 0 + } + if good != expectedGood { + t.Errorf("#%d: wrong validity, want:%d got:%d", i, expectedGood, good) + } + if good == 255 && len(test.in)-paddingLen != test.expectedLen { + t.Errorf("#%d: got %d, want %d", i, len(test.in)-paddingLen, test.expectedLen) + } + } +} + +var certExampleCom = `308201713082011ba003020102021005a75ddf21014d5f417083b7a010ba2e300d06092a864886f70d01010b050030123110300e060355040a130741636d6520436f301e170d3136303831373231343135335a170d3137303831373231343135335a30123110300e060355040a130741636d6520436f305c300d06092a864886f70d0101010500034b003048024100b37f0fdd67e715bf532046ac34acbd8fdc4dabe2b598588f3f58b1f12e6219a16cbfe54d2b4b665396013589262360b6721efa27d546854f17cc9aeec6751db10203010001a34d304b300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff0402300030160603551d11040f300d820b6578616d706c652e636f6d300d06092a864886f70d01010b050003410059fc487866d3d855503c8e064ca32aac5e9babcece89ec597f8b2b24c17867f4a5d3b4ece06e795bfc5448ccbd2ffca1b3433171ebf3557a4737b020565350a0` + +var certWildcardExampleCom = `308201743082011ea003020102021100a7aa6297c9416a4633af8bec2958c607300d06092a864886f70d01010b050030123110300e060355040a130741636d6520436f301e170d3136303831373231343231395a170d3137303831373231343231395a30123110300e060355040a130741636d6520436f305c300d06092a864886f70d0101010500034b003048024100b105afc859a711ee864114e7d2d46c2dcbe392d3506249f6c2285b0eb342cc4bf2d803677c61c0abde443f084745c1a6d62080e5664ef2cc8f50ad8a0ab8870b0203010001a34f304d300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff0402300030180603551d110411300f820d2a2e6578616d706c652e636f6d300d06092a864886f70d01010b0500034100af26088584d266e3f6566360cf862c7fecc441484b098b107439543144a2b93f20781988281e108c6d7656934e56950e1e5f2bcf38796b814ccb729445856c34` + +var certFooExampleCom = `308201753082011fa00302010202101bbdb6070b0aeffc49008cde74deef29300d06092a864886f70d01010b050030123110300e060355040a130741636d6520436f301e170d3136303831373231343234345a170d3137303831373231343234345a30123110300e060355040a130741636d6520436f305c300d06092a864886f70d0101010500034b003048024100f00ac69d8ca2829f26216c7b50f1d4bbabad58d447706476cd89a2f3e1859943748aa42c15eedc93ac7c49e40d3b05ed645cb6b81c4efba60d961f44211a54eb0203010001a351304f300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff04023000301a0603551d1104133011820f666f6f2e6578616d706c652e636f6d300d06092a864886f70d01010b0500034100a0957fca6d1e0f1ef4b247348c7a8ca092c29c9c0ecc1898ea6b8065d23af6d922a410dd2335a0ea15edd1394cef9f62c9e876a21e35250a0b4fe1ddceba0f36` + +func TestCertificateSelection(t *testing.T) { + config := Config{ + Certificates: []Certificate{ + { + Certificate: [][]byte{fromHex(certExampleCom)}, + }, + { + Certificate: [][]byte{fromHex(certWildcardExampleCom)}, + }, + { + Certificate: [][]byte{fromHex(certFooExampleCom)}, + }, + }, + } + + config.BuildNameToCertificate() + + pointerToIndex := func(c *Certificate) int { + for i := range config.Certificates { + if c == &config.Certificates[i] { + return i + } + } + return -1 + } + + certificateForName := func(name string) *Certificate { + clientHello := &ClientHelloInfo{ + ServerName: name, + } + if cert, err := config.getCertificate(clientHello); err != nil { + t.Errorf("unable to get certificate for name '%s': %s", name, err) + return nil + } else { + return cert + } + } + + if n := pointerToIndex(certificateForName("example.com")); n != 0 { + t.Errorf("example.com returned certificate %d, not 0", n) + } + if n := pointerToIndex(certificateForName("bar.example.com")); n != 1 { + t.Errorf("bar.example.com returned certificate %d, not 1", n) + } + if n := pointerToIndex(certificateForName("foo.example.com")); n != 2 { + t.Errorf("foo.example.com returned certificate %d, not 2", n) + } + if n := pointerToIndex(certificateForName("foo.bar.example.com")); n != 0 { + t.Errorf("foo.bar.example.com returned certificate %d, not 0", n) + } +} + +// Run with multiple crypto configs to test the logic for computing TLS record overheads. +func runDynamicRecordSizingTest(t *testing.T, config *Config) { + clientConn, serverConn := localPipe(t) + + serverConfig := config.Clone() + serverConfig.DynamicRecordSizingDisabled = false + tlsConn := Server(serverConn, serverConfig) + + handshakeDone := make(chan struct{}) + recordSizesChan := make(chan []int, 1) + defer func() { <-recordSizesChan }() // wait for the goroutine to exit + go func() { + // This goroutine performs a TLS handshake over clientConn and + // then reads TLS records until EOF. It writes a slice that + // contains all the record sizes to recordSizesChan. + defer close(recordSizesChan) + defer clientConn.Close() + + tlsConn := Client(clientConn, config) + if err := tlsConn.Handshake(); err != nil { + t.Errorf("Error from client handshake: %v", err) + return + } + close(handshakeDone) + + var recordHeader [recordHeaderLen]byte + var record []byte + var recordSizes []int + + for { + n, err := io.ReadFull(clientConn, recordHeader[:]) + if err == io.EOF { + break + } + if err != nil || n != len(recordHeader) { + t.Errorf("io.ReadFull = %d, %v", n, err) + return + } + + length := int(recordHeader[3])<<8 | int(recordHeader[4]) + if len(record) < length { + record = make([]byte, length) + } + + n, err = io.ReadFull(clientConn, record[:length]) + if err != nil || n != length { + t.Errorf("io.ReadFull = %d, %v", n, err) + return + } + + recordSizes = append(recordSizes, recordHeaderLen+length) + } + + recordSizesChan <- recordSizes + }() + + if err := tlsConn.Handshake(); err != nil { + t.Fatalf("Error from server handshake: %s", err) + } + <-handshakeDone + + // The server writes these plaintexts in order. + plaintext := bytes.Join([][]byte{ + bytes.Repeat([]byte("x"), recordSizeBoostThreshold), + bytes.Repeat([]byte("y"), maxPlaintext*2), + bytes.Repeat([]byte("z"), maxPlaintext), + }, nil) + + if _, err := tlsConn.Write(plaintext); err != nil { + t.Fatalf("Error from server write: %s", err) + } + if err := tlsConn.Close(); err != nil { + t.Fatalf("Error from server close: %s", err) + } + + recordSizes := <-recordSizesChan + if recordSizes == nil { + t.Fatalf("Client encountered an error") + } + + // Drop the size of the second to last record, which is likely to be + // truncated, and the last record, which is a close_notify alert. + recordSizes = recordSizes[:len(recordSizes)-2] + + // recordSizes should contain a series of records smaller than + // tcpMSSEstimate followed by some larger than maxPlaintext. + seenLargeRecord := false + for i, size := range recordSizes { + if !seenLargeRecord { + if size > (i+1)*tcpMSSEstimate { + t.Fatalf("Record #%d has size %d, which is too large too soon", i, size) + } + if size >= maxPlaintext { + seenLargeRecord = true + } + } else if size <= maxPlaintext { + t.Fatalf("Record #%d has size %d but should be full sized", i, size) + } + } + + if !seenLargeRecord { + t.Fatalf("No large records observed") + } +} + +func TestDynamicRecordSizingWithStreamCipher(t *testing.T) { + config := testConfig.Clone() + config.MaxVersion = VersionTLS12 + config.CipherSuites = []uint16{TLS_RSA_WITH_RC4_128_SHA} + runDynamicRecordSizingTest(t, config) +} + +func TestDynamicRecordSizingWithCBC(t *testing.T) { + config := testConfig.Clone() + config.MaxVersion = VersionTLS12 + config.CipherSuites = []uint16{TLS_RSA_WITH_AES_256_CBC_SHA} + runDynamicRecordSizingTest(t, config) +} + +func TestDynamicRecordSizingWithAEAD(t *testing.T) { + config := testConfig.Clone() + config.MaxVersion = VersionTLS12 + config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256} + runDynamicRecordSizingTest(t, config) +} + +func TestDynamicRecordSizingWithTLSv13(t *testing.T) { + config := testConfig.Clone() + runDynamicRecordSizingTest(t, config) +} + +// hairpinConn is a net.Conn that makes a “hairpin” call when closed, back into +// the tls.Conn which is calling it. +type hairpinConn struct { + net.Conn + tlsConn *Conn +} + +func (conn *hairpinConn) Close() error { + conn.tlsConn.ConnectionState() + return nil +} + +func TestHairpinInClose(t *testing.T) { + // This tests that the underlying net.Conn can call back into the + // tls.Conn when being closed without deadlocking. + client, server := localPipe(t) + defer server.Close() + defer client.Close() + + conn := &hairpinConn{client, nil} + tlsConn := Server(conn, &Config{ + GetCertificate: func(*ClientHelloInfo) (*Certificate, error) { + panic("unreachable") + }, + }) + conn.tlsConn = tlsConn + + // This call should not deadlock. + tlsConn.Close() +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/example_test.go b/vendor/github.com/lesismal/llib/std/crypto/tls/example_test.go new file mode 100644 index 0000000..6389fd7 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/example_test.go @@ -0,0 +1,232 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls_test + +import ( + "crypto/tls" + "crypto/x509" + "log" + "net/http" + "net/http/httptest" + "os" + "time" +) + +// zeroSource is an io.Reader that returns an unlimited number of zero bytes. +type zeroSource struct{} + +func (zeroSource) Read(b []byte) (n int, err error) { + for i := range b { + b[i] = 0 + } + + return len(b), nil +} + +func ExampleDial() { + // Connecting with a custom root-certificate set. + + const rootPEM = ` +-- GlobalSign Root R2, valid until Dec 15, 2021 +-----BEGIN CERTIFICATE----- +MIIDujCCAqKgAwIBAgILBAAAAAABD4Ym5g0wDQYJKoZIhvcNAQEFBQAwTDEgMB4G +A1UECxMXR2xvYmFsU2lnbiBSb290IENBIC0gUjIxEzARBgNVBAoTCkdsb2JhbFNp +Z24xEzARBgNVBAMTCkdsb2JhbFNpZ24wHhcNMDYxMjE1MDgwMDAwWhcNMjExMjE1 +MDgwMDAwWjBMMSAwHgYDVQQLExdHbG9iYWxTaWduIFJvb3QgQ0EgLSBSMjETMBEG +A1UEChMKR2xvYmFsU2lnbjETMBEGA1UEAxMKR2xvYmFsU2lnbjCCASIwDQYJKoZI +hvcNAQEBBQADggEPADCCAQoCggEBAKbPJA6+Lm8omUVCxKs+IVSbC9N/hHD6ErPL +v4dfxn+G07IwXNb9rfF73OX4YJYJkhD10FPe+3t+c4isUoh7SqbKSaZeqKeMWhG8 +eoLrvozps6yWJQeXSpkqBy+0Hne/ig+1AnwblrjFuTosvNYSuetZfeLQBoZfXklq +tTleiDTsvHgMCJiEbKjNS7SgfQx5TfC4LcshytVsW33hoCmEofnTlEnLJGKRILzd +C9XZzPnqJworc5HGnRusyMvo4KD0L5CLTfuwNhv2GXqF4G3yYROIXJ/gkwpRl4pa +zq+r1feqCapgvdzZX99yqWATXgAByUr6P6TqBwMhAo6CygPCm48CAwEAAaOBnDCB +mTAOBgNVHQ8BAf8EBAMCAQYwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQUm+IH +V2ccHsBqBt5ZtJot39wZhi4wNgYDVR0fBC8wLTAroCmgJ4YlaHR0cDovL2NybC5n +bG9iYWxzaWduLm5ldC9yb290LXIyLmNybDAfBgNVHSMEGDAWgBSb4gdXZxwewGoG +3lm0mi3f3BmGLjANBgkqhkiG9w0BAQUFAAOCAQEAmYFThxxol4aR7OBKuEQLq4Gs +J0/WwbgcQ3izDJr86iw8bmEbTUsp9Z8FHSbBuOmDAGJFtqkIk7mpM0sYmsL4h4hO +291xNBrBVNpGP+DTKqttVCL1OmLNIG+6KYnX3ZHu01yiPqFbQfXf5WRDLenVOavS +ot+3i9DAgBkcRcAtjOj4LaR0VknFBbVPFd5uRHg5h6h+u/N5GJG79G+dwfCMNYxd +AfvDbbnvRG15RjF+Cv6pgsH/76tuIMRQyV+dTZsXjAzlAcmgQWpzU/qlULRuJQ/7 +TBj0/VLZjmmx6BEP3ojY+x1J96relc8geMJgEtslQIxq/H5COEBkEveegeGTLg== +-----END CERTIFICATE-----` + + // First, create the set of root certificates. For this example we only + // have one. It's also possible to omit this in order to use the + // default root set of the current operating system. + roots := x509.NewCertPool() + ok := roots.AppendCertsFromPEM([]byte(rootPEM)) + if !ok { + panic("failed to parse root certificate") + } + + conn, err := tls.Dial("tcp", "mail.google.com:443", &tls.Config{ + RootCAs: roots, + }) + if err != nil { + panic("failed to connect: " + err.Error()) + } + conn.Close() +} + +func ExampleConfig_keyLogWriter() { + // Debugging TLS applications by decrypting a network traffic capture. + + // WARNING: Use of KeyLogWriter compromises security and should only be + // used for debugging. + + // Dummy test HTTP server for the example with insecure random so output is + // reproducible. + server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + server.TLS = &tls.Config{ + Rand: zeroSource{}, // for example only; don't do this. + } + server.StartTLS() + defer server.Close() + + // Typically the log would go to an open file: + // w, err := os.OpenFile("tls-secrets.txt", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + w := os.Stdout + + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + KeyLogWriter: w, + + Rand: zeroSource{}, // for reproducible output; don't do this. + InsecureSkipVerify: true, // test server certificate is not trusted. + }, + }, + } + resp, err := client.Get(server.URL) + if err != nil { + log.Fatalf("Failed to get URL: %v", err) + } + resp.Body.Close() + + // The resulting file can be used with Wireshark to decrypt the TLS + // connection by setting (Pre)-Master-Secret log filename in SSL Protocol + // preferences. +} + +func ExampleLoadX509KeyPair() { + cert, err := tls.LoadX509KeyPair("testdata/example-cert.pem", "testdata/example-key.pem") + if err != nil { + log.Fatal(err) + } + cfg := &tls.Config{Certificates: []tls.Certificate{cert}} + listener, err := tls.Listen("tcp", ":2000", cfg) + if err != nil { + log.Fatal(err) + } + _ = listener +} + +func ExampleX509KeyPair() { + certPem := []byte(`-----BEGIN CERTIFICATE----- +MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw +DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow +EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d +7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B +5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr +BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1 +NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l +Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc +6MF9+Yw1Yy0t +-----END CERTIFICATE-----`) + keyPem := []byte(`-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49 +AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q +EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA== +-----END EC PRIVATE KEY-----`) + cert, err := tls.X509KeyPair(certPem, keyPem) + if err != nil { + log.Fatal(err) + } + cfg := &tls.Config{Certificates: []tls.Certificate{cert}} + listener, err := tls.Listen("tcp", ":2000", cfg) + if err != nil { + log.Fatal(err) + } + _ = listener +} + +func ExampleX509KeyPair_httpServer() { + certPem := []byte(`-----BEGIN CERTIFICATE----- +MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw +DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow +EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d +7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B +5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr +BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1 +NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l +Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc +6MF9+Yw1Yy0t +-----END CERTIFICATE-----`) + keyPem := []byte(`-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49 +AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q +EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA== +-----END EC PRIVATE KEY-----`) + cert, err := tls.X509KeyPair(certPem, keyPem) + if err != nil { + log.Fatal(err) + } + cfg := &tls.Config{Certificates: []tls.Certificate{cert}} + srv := &http.Server{ + TLSConfig: cfg, + ReadTimeout: time.Minute, + WriteTimeout: time.Minute, + } + log.Fatal(srv.ListenAndServeTLS("", "")) +} + +func ExampleConfig_verifyConnection() { + // VerifyConnection can be used to replace and customize connection + // verification. This example shows a VerifyConnection implementation that + // will be approximately equivalent to what crypto/tls does normally to + // verify the peer's certificate. + + // Client side configuration. + _ = &tls.Config{ + // Set InsecureSkipVerify to skip the default validation we are + // replacing. This will not disable VerifyConnection. + InsecureSkipVerify: true, + VerifyConnection: func(cs tls.ConnectionState) error { + opts := x509.VerifyOptions{ + DNSName: cs.ServerName, + Intermediates: x509.NewCertPool(), + } + for _, cert := range cs.PeerCertificates[1:] { + opts.Intermediates.AddCert(cert) + } + _, err := cs.PeerCertificates[0].Verify(opts) + return err + }, + } + + // Server side configuration. + _ = &tls.Config{ + // Require client certificates (or VerifyConnection will run anyway and + // panic accessing cs.PeerCertificates[0]) but don't verify them with the + // default verifier. This will not disable VerifyConnection. + ClientAuth: tls.RequireAnyClientCert, + VerifyConnection: func(cs tls.ConnectionState) error { + opts := x509.VerifyOptions{ + DNSName: cs.ServerName, + Intermediates: x509.NewCertPool(), + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + for _, cert := range cs.PeerCertificates[1:] { + opts.Intermediates.AddCert(cert) + } + _, err := cs.PeerCertificates[0].Verify(opts) + return err + }, + } + + // Note that when certificates are not handled by the default verifier + // ConnectionState.VerifiedChains will be nil. +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/generate_cert.go b/vendor/github.com/lesismal/llib/std/crypto/tls/generate_cert.go new file mode 100644 index 0000000..1857185 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/generate_cert.go @@ -0,0 +1,172 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build ignore + +// Generate a self-signed X.509 certificate for a TLS server. Outputs to +// 'cert.pem' and 'key.pem' and will overwrite existing files. + +package main + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "flag" + "log" + "math/big" + "net" + "os" + "strings" + "time" +) + +var ( + host = flag.String("host", "", "Comma-separated hostnames and IPs to generate a certificate for") + validFrom = flag.String("start-date", "", "Creation date formatted as Jan 1 15:04:05 2011") + validFor = flag.Duration("duration", 365*24*time.Hour, "Duration that certificate is valid for") + isCA = flag.Bool("ca", false, "whether this cert should be its own Certificate Authority") + rsaBits = flag.Int("rsa-bits", 2048, "Size of RSA key to generate. Ignored if --ecdsa-curve is set") + ecdsaCurve = flag.String("ecdsa-curve", "", "ECDSA curve to use to generate a key. Valid values are P224, P256 (recommended), P384, P521") + ed25519Key = flag.Bool("ed25519", false, "Generate an Ed25519 key") +) + +func publicKey(priv interface{}) interface{} { + switch k := priv.(type) { + case *rsa.PrivateKey: + return &k.PublicKey + case *ecdsa.PrivateKey: + return &k.PublicKey + case ed25519.PrivateKey: + return k.Public().(ed25519.PublicKey) + default: + return nil + } +} + +func main() { + flag.Parse() + + if len(*host) == 0 { + log.Fatalf("Missing required --host parameter") + } + + var priv interface{} + var err error + switch *ecdsaCurve { + case "": + if *ed25519Key { + _, priv, err = ed25519.GenerateKey(rand.Reader) + } else { + priv, err = rsa.GenerateKey(rand.Reader, *rsaBits) + } + case "P224": + priv, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + case "P256": + priv, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + case "P384": + priv, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + case "P521": + priv, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + default: + log.Fatalf("Unrecognized elliptic curve: %q", *ecdsaCurve) + } + if err != nil { + log.Fatalf("Failed to generate private key: %v", err) + } + + // ECDSA, ED25519 and RSA subject keys should have the DigitalSignature + // KeyUsage bits set in the x509.Certificate template + keyUsage := x509.KeyUsageDigitalSignature + // Only RSA subject keys should have the KeyEncipherment KeyUsage bits set. In + // the context of TLS this KeyUsage is particular to RSA key exchange and + // authentication. + if _, isRSA := priv.(*rsa.PrivateKey); isRSA { + keyUsage |= x509.KeyUsageKeyEncipherment + } + + var notBefore time.Time + if len(*validFrom) == 0 { + notBefore = time.Now() + } else { + notBefore, err = time.Parse("Jan 2 15:04:05 2006", *validFrom) + if err != nil { + log.Fatalf("Failed to parse creation date: %v", err) + } + } + + notAfter := notBefore.Add(*validFor) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + log.Fatalf("Failed to generate serial number: %v", err) + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Acme Co"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: keyUsage, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + hosts := strings.Split(*host, ",") + for _, h := range hosts { + if ip := net.ParseIP(h); ip != nil { + template.IPAddresses = append(template.IPAddresses, ip) + } else { + template.DNSNames = append(template.DNSNames, h) + } + } + + if *isCA { + template.IsCA = true + template.KeyUsage |= x509.KeyUsageCertSign + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv) + if err != nil { + log.Fatalf("Failed to create certificate: %v", err) + } + + certOut, err := os.Create("cert.pem") + if err != nil { + log.Fatalf("Failed to open cert.pem for writing: %v", err) + } + if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { + log.Fatalf("Failed to write data to cert.pem: %v", err) + } + if err := certOut.Close(); err != nil { + log.Fatalf("Error closing cert.pem: %v", err) + } + log.Print("wrote cert.pem\n") + + keyOut, err := os.OpenFile("key.pem", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + log.Fatalf("Failed to open key.pem for writing: %v", err) + return + } + privBytes, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + log.Fatalf("Unable to marshal private key: %v", err) + } + if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { + log.Fatalf("Failed to write data to key.pem: %v", err) + } + if err := keyOut.Close(); err != nil { + log.Fatalf("Error closing key.pem: %v", err) + } + log.Print("wrote key.pem\n") +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_client.go b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_client.go new file mode 100644 index 0000000..e684b21 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_client.go @@ -0,0 +1,1002 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "hash" + "io" + "net" + "strings" + "sync/atomic" + "time" +) + +type clientHandshakeState struct { + c *Conn + serverHello *serverHelloMsg + hello *clientHelloMsg + suite *cipherSuite + finishedHash finishedHash + masterSecret []byte + session *ClientSessionState +} + +func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { + config := c.config + if len(config.ServerName) == 0 && !config.InsecureSkipVerify { + return nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") + } + + nextProtosLength := 0 + for _, proto := range config.NextProtos { + if l := len(proto); l == 0 || l > 255 { + return nil, nil, errors.New("tls: invalid NextProtos value") + } else { + nextProtosLength += 1 + l + } + } + if nextProtosLength > 0xffff { + return nil, nil, errors.New("tls: NextProtos values too large") + } + + supportedVersions := config.supportedVersions() + if len(supportedVersions) == 0 { + return nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion") + } + + clientHelloVersion := config.maxSupportedVersion() + // The version at the beginning of the ClientHello was capped at TLS 1.2 + // for compatibility reasons. The supported_versions extension is used + // to negotiate versions now. See RFC 8446, Section 4.2.1. + if clientHelloVersion > VersionTLS12 { + clientHelloVersion = VersionTLS12 + } + + hello := &clientHelloMsg{ + vers: clientHelloVersion, + compressionMethods: []uint8{compressionNone}, + random: make([]byte, 32), + sessionId: make([]byte, 32), + ocspStapling: true, + scts: true, + serverName: hostnameInSNI(config.ServerName), + supportedCurves: config.curvePreferences(), + supportedPoints: []uint8{pointFormatUncompressed}, + secureRenegotiationSupported: true, + alpnProtocols: config.NextProtos, + supportedVersions: supportedVersions, + } + + if c.handshakes > 0 { + hello.secureRenegotiation = c.clientFinished[:] + } + + possibleCipherSuites := config.cipherSuites() + hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites)) + + for _, suiteId := range possibleCipherSuites { + for _, suite := range cipherSuites { + if suite.id != suiteId { + continue + } + // Don't advertise TLS 1.2-only cipher suites unless + // we're attempting TLS 1.2. + if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 { + break + } + hello.cipherSuites = append(hello.cipherSuites, suiteId) + break + } + } + + _, err := io.ReadFull(config.rand(), hello.random) + if err != nil { + return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) + } + + // A random session ID is used to detect when the server accepted a ticket + // and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as + // a compatibility measure (see RFC 8446, Section 4.1.2). + if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil { + return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) + } + + if hello.vers >= VersionTLS12 { + hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms + } + + var params ecdheParameters + if hello.supportedVersions[0] == VersionTLS13 { + hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13()...) + + curveID := config.curvePreferences()[0] + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err = generateECDHEParameters(config.rand(), curveID) + if err != nil { + return nil, nil, err + } + hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} + } + + return hello, params, nil +} + +func (c *Conn) clientHandshake() (err error) { + if c.config == nil { + c.config = defaultConfig() + } + + // This may be a renegotiation handshake, in which case some fields + // need to be reset. + c.didResume = false + + hello, ecdheParams, err := c.makeClientHello() + if err != nil { + return err + } + c.serverName = hello.serverName + + cacheKey, session, earlySecret, binderKey := c.loadSession(hello) + if cacheKey != "" && session != nil { + defer func() { + // If we got a handshake failure when resuming a session, throw away + // the session ticket. See RFC 5077, Section 3.2. + // + // RFC 8446 makes no mention of dropping tickets on failure, but it + // does require servers to abort on invalid binders, so we need to + // delete tickets to recover from a corrupted PSK. + if err != nil { + c.config.ClientSessionCache.Put(cacheKey, nil) + } + }() + } + + if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + serverHello, ok := msg.(*serverHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverHello, msg) + } + + if err := c.pickTLSVersion(serverHello); err != nil { + return err + } + + // If we are negotiating a protocol version that's lower than what we + // support, check for the server downgrade canaries. + // See RFC 8446, Section 4.1.3. + maxVers := c.config.maxSupportedVersion() + tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12 + tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11 + if maxVers == VersionTLS13 && c.vers <= VersionTLS12 && (tls12Downgrade || tls11Downgrade) || + maxVers == VersionTLS12 && c.vers <= VersionTLS11 && tls11Downgrade { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox") + } + + if c.vers == VersionTLS13 { + hs := &clientHandshakeStateTLS13{ + c: c, + serverHello: serverHello, + hello: hello, + ecdheParams: ecdheParams, + session: session, + earlySecret: earlySecret, + binderKey: binderKey, + } + + // In TLS 1.3, session tickets are delivered after the handshake. + return hs.handshake() + } + + hs := &clientHandshakeState{ + c: c, + serverHello: serverHello, + hello: hello, + session: session, + } + + if err := hs.handshake(); err != nil { + return err + } + + // If we had a successful handshake and hs.session is different from + // the one already cached - cache a new one. + if cacheKey != "" && hs.session != nil && session != hs.session { + c.config.ClientSessionCache.Put(cacheKey, hs.session) + } + + return nil +} + +func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, + session *ClientSessionState, earlySecret, binderKey []byte) { + if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { + return "", nil, nil, nil + } + + hello.ticketSupported = true + + if hello.supportedVersions[0] == VersionTLS13 { + // Require DHE on resumption as it guarantees forward secrecy against + // compromise of the session ticket key. See RFC 8446, Section 4.2.9. + hello.pskModes = []uint8{pskModeDHE} + } + + // Session resumption is not allowed if renegotiating because + // renegotiation is primarily used to allow a client to send a client + // certificate, which would be skipped if session resumption occurred. + if c.handshakes != 0 { + return "", nil, nil, nil + } + + // Try to resume a previously negotiated TLS session, if available. + cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) + session, ok := c.config.ClientSessionCache.Get(cacheKey) + if !ok || session == nil { + return cacheKey, nil, nil, nil + } + + // Check that version used for the previous session is still valid. + versOk := false + for _, v := range hello.supportedVersions { + if v == session.vers { + versOk = true + break + } + } + if !versOk { + return cacheKey, nil, nil, nil + } + + // Check that the cached server certificate is not expired, and that it's + // valid for the ServerName. This should be ensured by the cache key, but + // protect the application from a faulty ClientSessionCache implementation. + if !c.config.InsecureSkipVerify { + if len(session.verifiedChains) == 0 { + // The original connection had InsecureSkipVerify, while this doesn't. + return cacheKey, nil, nil, nil + } + serverCert := session.serverCertificates[0] + if c.config.time().After(serverCert.NotAfter) { + // Expired certificate, delete the entry. + c.config.ClientSessionCache.Put(cacheKey, nil) + return cacheKey, nil, nil, nil + } + if err := serverCert.VerifyHostname(c.config.ServerName); err != nil { + return cacheKey, nil, nil, nil + } + } + + if session.vers != VersionTLS13 { + // In TLS 1.2 the cipher suite must match the resumed session. Ensure we + // are still offering it. + if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil { + return cacheKey, nil, nil, nil + } + + hello.sessionTicket = session.sessionTicket + return + } + + // Check that the session ticket is not expired. + if c.config.time().After(session.useBy) { + c.config.ClientSessionCache.Put(cacheKey, nil) + return cacheKey, nil, nil, nil + } + + // In TLS 1.3 the KDF hash must match the resumed session. Ensure we + // offer at least one cipher suite with that hash. + cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite) + if cipherSuite == nil { + return cacheKey, nil, nil, nil + } + cipherSuiteOk := false + for _, offeredID := range hello.cipherSuites { + offeredSuite := cipherSuiteTLS13ByID(offeredID) + if offeredSuite != nil && offeredSuite.hash == cipherSuite.hash { + cipherSuiteOk = true + break + } + } + if !cipherSuiteOk { + return cacheKey, nil, nil, nil + } + + // Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1. + ticketAge := uint32(c.config.time().Sub(session.receivedAt) / time.Millisecond) + identity := pskIdentity{ + label: session.sessionTicket, + obfuscatedTicketAge: ticketAge + session.ageAdd, + } + hello.pskIdentities = []pskIdentity{identity} + hello.pskBinders = [][]byte{make([]byte, cipherSuite.hash.Size())} + + // Compute the PSK binders. See RFC 8446, Section 4.2.11.2. + psk := cipherSuite.expandLabel(session.masterSecret, "resumption", + session.nonce, cipherSuite.hash.Size()) + earlySecret = cipherSuite.extract(psk, nil) + binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil) + transcript := cipherSuite.hash.New() + transcript.Write(hello.marshalWithoutBinders()) + pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)} + hello.updateBinders(pskBinders) + + return +} + +func (c *Conn) pickTLSVersion(serverHello *serverHelloMsg) error { + peerVersion := serverHello.vers + if serverHello.supportedVersion != 0 { + peerVersion = serverHello.supportedVersion + } + + vers, ok := c.config.mutualVersion([]uint16{peerVersion}) + if !ok { + c.sendAlert(alertProtocolVersion) + return fmt.Errorf("tls: server selected unsupported protocol version %x", peerVersion) + } + + c.vers = vers + c.haveVers = true + c.in.version = vers + c.out.version = vers + + return nil +} + +// Does the handshake, either a full one or resumes old session. Requires hs.c, +// hs.hello, hs.serverHello, and, optionally, hs.session to be set. +func (hs *clientHandshakeState) handshake() error { + c := hs.c + + isResume, err := hs.processServerHello() + if err != nil { + return err + } + + hs.finishedHash = newFinishedHash(c.vers, hs.suite) + + // No signatures of the handshake are needed in a resumption. + // Otherwise, in a full handshake, if we don't have any certificates + // configured then we will never send a CertificateVerify message and + // thus no signatures are needed in that case either. + if isResume || (len(c.config.Certificates) == 0 && c.config.GetClientCertificate == nil) { + hs.finishedHash.discardHandshakeBuffer() + } + + hs.finishedHash.Write(hs.hello.marshal()) + hs.finishedHash.Write(hs.serverHello.marshal()) + + c.buffering = true + c.didResume = isResume + if isResume { + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.readSessionTicket(); err != nil { + return err + } + if err := hs.readFinished(c.serverFinished[:]); err != nil { + return err + } + c.clientFinishedIsFirst = false + // Make sure the connection is still being verified whether or not this + // is a resumption. Resumptions currently don't reverify certificates so + // they don't call verifyServerCertificate. See Issue 31641. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + if err := hs.sendFinished(c.clientFinished[:]); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + } else { + if err := hs.doFullHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.sendFinished(c.clientFinished[:]); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + c.clientFinishedIsFirst = true + if err := hs.readSessionTicket(); err != nil { + return err + } + if err := hs.readFinished(c.serverFinished[:]); err != nil { + return err + } + } + + c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random) + atomic.StoreUint32(&c.handshakeStatus, 1) + + return nil +} + +func (hs *clientHandshakeState) pickCipherSuite() error { + if hs.suite = mutualCipherSuite(hs.hello.cipherSuites, hs.serverHello.cipherSuite); hs.suite == nil { + hs.c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server chose an unconfigured cipher suite") + } + + hs.c.cipherSuite = hs.suite.id + return nil +} + +func (hs *clientHandshakeState) doFullHandshake() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + certMsg, ok := msg.(*certificateMsg) + if !ok || len(certMsg.certificates) == 0 { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.finishedHash.Write(certMsg.marshal()) + + msg, err = c.readHandshake() + if err != nil { + return err + } + + cs, ok := msg.(*certificateStatusMsg) + if ok { + // RFC4366 on Certificate Status Request: + // The server MAY return a "certificate_status" message. + + if !hs.serverHello.ocspStapling { + // If a server returns a "CertificateStatus" message, then the + // server MUST have included an extension of type "status_request" + // with empty "extension_data" in the extended server hello. + + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: received unexpected CertificateStatus message") + } + hs.finishedHash.Write(cs.marshal()) + + c.ocspResponse = cs.response + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + if c.handshakes == 0 { + // If this is the first handshake on a connection, process and + // (optionally) verify the server's certificates. + if err := c.verifyServerCertificate(certMsg.certificates); err != nil { + return err + } + } else { + // This is a renegotiation handshake. We require that the + // server's identity (i.e. leaf certificate) is unchanged and + // thus any previous trust decision is still valid. + // + // See https://mitls.org/pages/attacks/3SHAKE for the + // motivation behind this requirement. + if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) { + c.sendAlert(alertBadCertificate) + return errors.New("tls: server's identity changed during renegotiation") + } + } + + keyAgreement := hs.suite.ka(c.vers) + + skx, ok := msg.(*serverKeyExchangeMsg) + if ok { + hs.finishedHash.Write(skx.marshal()) + err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx) + if err != nil { + c.sendAlert(alertUnexpectedMessage) + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + var chainToSend *Certificate + var certRequested bool + certReq, ok := msg.(*certificateRequestMsg) + if ok { + certRequested = true + hs.finishedHash.Write(certReq.marshal()) + + cri := certificateRequestInfoFromMsg(c.vers, certReq) + if chainToSend, err = c.getClientCertificate(cri); err != nil { + c.sendAlert(alertInternalError) + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + shd, ok := msg.(*serverHelloDoneMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(shd, msg) + } + hs.finishedHash.Write(shd.marshal()) + + // If the server requested a certificate then we have to send a + // Certificate message, even if it's empty because we don't have a + // certificate to send. + if certRequested { + certMsg = new(certificateMsg) + certMsg.certificates = chainToSend.Certificate + hs.finishedHash.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + } + + preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, c.peerCertificates[0]) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + if ckx != nil { + hs.finishedHash.Write(ckx.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil { + return err + } + } + + if chainToSend != nil && len(chainToSend.Certificate) > 0 { + certVerify := &certificateVerifyMsg{} + + key, ok := chainToSend.PrivateKey.(crypto.Signer) + if !ok { + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey) + } + + var sigType uint8 + var sigHash crypto.Hash + if c.vers >= VersionTLS12 { + signatureAlgorithm, err := selectSignatureScheme(c.vers, chainToSend, certReq.supportedSignatureAlgorithms) + if err != nil { + c.sendAlert(alertIllegalParameter) + return err + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + certVerify.hasSignatureAlgorithm = true + certVerify.signatureAlgorithm = signatureAlgorithm + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(key.Public()) + if err != nil { + c.sendAlert(alertIllegalParameter) + return err + } + } + + signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + certVerify.signature, err = key.Sign(c.config.rand(), signed, signOpts) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + hs.finishedHash.Write(certVerify.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil { + return err + } + } + + hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random) + if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.hello.random, hs.masterSecret); err != nil { + c.sendAlert(alertInternalError) + return errors.New("tls: failed to write to key log: " + err.Error()) + } + + hs.finishedHash.discardHandshakeBuffer() + + return nil +} + +func (hs *clientHandshakeState) establishKeys() error { + c := hs.c + + clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := + keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) + var clientCipher, serverCipher interface{} + var clientHash, serverHash hash.Hash + if hs.suite.cipher != nil { + clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */) + clientHash = hs.suite.mac(clientMAC) + serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */) + serverHash = hs.suite.mac(serverMAC) + } else { + clientCipher = hs.suite.aead(clientKey, clientIV) + serverCipher = hs.suite.aead(serverKey, serverIV) + } + + c.in.prepareCipherSpec(c.vers, serverCipher, serverHash) + c.out.prepareCipherSpec(c.vers, clientCipher, clientHash) + return nil +} + +func (hs *clientHandshakeState) serverResumedSession() bool { + // If the server responded with the same sessionId then it means the + // sessionTicket is being used to resume a TLS session. + return hs.session != nil && hs.hello.sessionId != nil && + bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId) +} + +func (hs *clientHandshakeState) processServerHello() (bool, error) { + c := hs.c + + if err := hs.pickCipherSuite(); err != nil { + return false, err + } + + if hs.serverHello.compressionMethod != compressionNone { + c.sendAlert(alertUnexpectedMessage) + return false, errors.New("tls: server selected unsupported compression format") + } + + if c.handshakes == 0 && hs.serverHello.secureRenegotiationSupported { + c.secureRenegotiation = true + if len(hs.serverHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: initial handshake had non-empty renegotiation extension") + } + } + + if c.handshakes > 0 && c.secureRenegotiation { + var expectedSecureRenegotiation [24]byte + copy(expectedSecureRenegotiation[:], c.clientFinished[:]) + copy(expectedSecureRenegotiation[12:], c.serverFinished[:]) + if !bytes.Equal(hs.serverHello.secureRenegotiation, expectedSecureRenegotiation[:]) { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: incorrect renegotiation extension contents") + } + } + + if hs.serverHello.alpnProtocol != "" { + if len(hs.hello.alpnProtocols) == 0 { + c.sendAlert(alertUnsupportedExtension) + return false, errors.New("tls: server advertised unrequested ALPN extension") + } + if mutualProtocol([]string{hs.serverHello.alpnProtocol}, hs.hello.alpnProtocols) == "" { + c.sendAlert(alertUnsupportedExtension) + return false, errors.New("tls: server selected unadvertised ALPN protocol") + } + c.clientProtocol = hs.serverHello.alpnProtocol + } + + c.scts = hs.serverHello.scts + + if !hs.serverResumedSession() { + return false, nil + } + + if hs.session.vers != c.vers { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: server resumed a session with a different version") + } + + if hs.session.cipherSuite != hs.suite.id { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: server resumed a session with a different cipher suite") + } + + // Restore masterSecret, peerCerts, and ocspResponse from previous state + hs.masterSecret = hs.session.masterSecret + c.peerCertificates = hs.session.serverCertificates + c.verifiedChains = hs.session.verifiedChains + c.ocspResponse = hs.session.ocspResponse + // Let the ServerHello SCTs override the session SCTs from the original + // connection, if any are provided + if len(c.scts) == 0 && len(hs.session.scts) != 0 { + c.scts = hs.session.scts + } + + return true, nil +} + +func (hs *clientHandshakeState) readFinished(out []byte) error { + c := hs.c + + if err := c.readChangeCipherSpec(); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + serverFinished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverFinished, msg) + } + + verify := hs.finishedHash.serverSum(hs.masterSecret) + if len(verify) != len(serverFinished.verifyData) || + subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server's Finished message was incorrect") + } + hs.finishedHash.Write(serverFinished.marshal()) + copy(out, verify) + return nil +} + +func (hs *clientHandshakeState) readSessionTicket() error { + if !hs.serverHello.ticketSupported { + return nil + } + + c := hs.c + msg, err := c.readHandshake() + if err != nil { + return err + } + sessionTicketMsg, ok := msg.(*newSessionTicketMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(sessionTicketMsg, msg) + } + hs.finishedHash.Write(sessionTicketMsg.marshal()) + + hs.session = &ClientSessionState{ + sessionTicket: sessionTicketMsg.ticket, + vers: c.vers, + cipherSuite: hs.suite.id, + masterSecret: hs.masterSecret, + serverCertificates: c.peerCertificates, + verifiedChains: c.verifiedChains, + receivedAt: c.config.time(), + ocspResponse: c.ocspResponse, + scts: c.scts, + } + + return nil +} + +func (hs *clientHandshakeState) sendFinished(out []byte) error { + c := hs.c + + if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + return err + } + + finished := new(finishedMsg) + finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) + hs.finishedHash.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + copy(out, finished.verifyData) + return nil +} + +// verifyServerCertificate parses and verifies the provided chain, setting +// c.verifiedChains and c.peerCertificates or sending the appropriate alert. +func (c *Conn) verifyServerCertificate(certificates [][]byte) error { + certs := make([]*x509.Certificate, len(certificates)) + for i, asn1Data := range certificates { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to parse certificate from server: " + err.Error()) + } + certs[i] = cert + } + + if !c.config.InsecureSkipVerify { + opts := x509.VerifyOptions{ + Roots: c.config.RootCAs, + CurrentTime: c.config.time(), + DNSName: c.config.ServerName, + Intermediates: x509.NewCertPool(), + } + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + var err error + c.verifiedChains, err = certs[0].Verify(opts) + if err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + switch certs[0].PublicKey.(type) { + case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: + break + default: + c.sendAlert(alertUnsupportedCertificate) + return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey) + } + + c.peerCertificates = certs + + if c.config.VerifyPeerCertificate != nil { + if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + return nil +} + +// certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS +// <= 1.2 CertificateRequest, making an effort to fill in missing information. +func certificateRequestInfoFromMsg(vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { + cri := &CertificateRequestInfo{ + AcceptableCAs: certReq.certificateAuthorities, + Version: vers, + } + + var rsaAvail, ecAvail bool + for _, certType := range certReq.certificateTypes { + switch certType { + case certTypeRSASign: + rsaAvail = true + case certTypeECDSASign: + ecAvail = true + } + } + + if !certReq.hasSignatureAlgorithm { + // Prior to TLS 1.2, signature schemes did not exist. In this case we + // make up a list based on the acceptable certificate types, to help + // GetClientCertificate and SupportsCertificate select the right certificate. + // The hash part of the SignatureScheme is a lie here, because + // TLS 1.0 and 1.1 always use MD5+SHA1 for RSA and SHA1 for ECDSA. + switch { + case rsaAvail && ecAvail: + cri.SignatureSchemes = []SignatureScheme{ + ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, + PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, + } + case rsaAvail: + cri.SignatureSchemes = []SignatureScheme{ + PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, + } + case ecAvail: + cri.SignatureSchemes = []SignatureScheme{ + ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, + } + } + return cri + } + + // Filter the signature schemes based on the certificate types. + // See RFC 5246, Section 7.4.4 (where it calls this "somewhat complicated"). + cri.SignatureSchemes = make([]SignatureScheme, 0, len(certReq.supportedSignatureAlgorithms)) + for _, sigScheme := range certReq.supportedSignatureAlgorithms { + sigType, _, err := typeAndHashFromSignatureScheme(sigScheme) + if err != nil { + continue + } + switch sigType { + case signatureECDSA, signatureEd25519: + if ecAvail { + cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) + } + case signatureRSAPSS, signaturePKCS1v15: + if rsaAvail { + cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) + } + } + } + + return cri +} + +func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate, error) { + if c.config.GetClientCertificate != nil { + return c.config.GetClientCertificate(cri) + } + + for _, chain := range c.config.Certificates { + if err := cri.SupportsCertificate(&chain); err != nil { + continue + } + return &chain, nil + } + + // No acceptable certificate found. Don't send a certificate. + return new(Certificate), nil +} + +// clientSessionCacheKey returns a key used to cache sessionTickets that could +// be used to resume previously negotiated TLS sessions with a server. +func clientSessionCacheKey(serverAddr net.Addr, config *Config) string { + if len(config.ServerName) > 0 { + return config.ServerName + } + return serverAddr.String() +} + +// mutualProtocol finds the mutual ALPN protocol given list of possible +// protocols and a list of the preference order. +func mutualProtocol(protos, preferenceProtos []string) string { + for _, s := range preferenceProtos { + for _, c := range protos { + if s == c { + return s + } + } + } + return "" +} + +// hostnameInSNI converts name into an appropriate hostname for SNI. +// Literal IP addresses and absolute FQDNs are not permitted as SNI values. +// See RFC 6066, Section 3. +func hostnameInSNI(name string) string { + host := name + if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' { + host = host[1 : len(host)-1] + } + if i := strings.LastIndex(host, "%"); i > 0 { + host = host[:i] + } + if net.ParseIP(host) != nil { + return "" + } + for len(name) > 0 && name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + return name +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_client_test.go b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_client_test.go new file mode 100644 index 0000000..12b0254 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_client_test.go @@ -0,0 +1,2513 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "crypto/rsa" + "crypto/x509" + "encoding/base64" + "encoding/binary" + "encoding/pem" + "errors" + "fmt" + "io" + "math/big" + "net" + "os" + "os/exec" + "path/filepath" + "reflect" + "strconv" + "strings" + "testing" + "time" +) + +// Note: see comment in handshake_test.go for details of how the reference +// tests work. + +// opensslInputEvent enumerates possible inputs that can be sent to an `openssl +// s_client` process. +type opensslInputEvent int + +const ( + // opensslRenegotiate causes OpenSSL to request a renegotiation of the + // connection. + opensslRenegotiate opensslInputEvent = iota + + // opensslSendBanner causes OpenSSL to send the contents of + // opensslSentinel on the connection. + opensslSendSentinel + + // opensslKeyUpdate causes OpenSSL to send send a key update message to the + // client and request one back. + opensslKeyUpdate +) + +const opensslSentinel = "SENTINEL\n" + +type opensslInput chan opensslInputEvent + +func (i opensslInput) Read(buf []byte) (n int, err error) { + for event := range i { + switch event { + case opensslRenegotiate: + return copy(buf, []byte("R\n")), nil + case opensslKeyUpdate: + return copy(buf, []byte("K\n")), nil + case opensslSendSentinel: + return copy(buf, []byte(opensslSentinel)), nil + default: + panic("unknown event") + } + } + + return 0, io.EOF +} + +// opensslOutputSink is an io.Writer that receives the stdout and stderr from an +// `openssl` process and sends a value to handshakeComplete or readKeyUpdate +// when certain messages are seen. +type opensslOutputSink struct { + handshakeComplete chan struct{} + readKeyUpdate chan struct{} + all []byte + line []byte +} + +func newOpensslOutputSink() *opensslOutputSink { + return &opensslOutputSink{make(chan struct{}), make(chan struct{}), nil, nil} +} + +// opensslEndOfHandshake is a message that the “openssl s_server” tool will +// print when a handshake completes if run with “-state”. +const opensslEndOfHandshake = "SSL_accept:SSLv3/TLS write finished" + +// opensslReadKeyUpdate is a message that the “openssl s_server” tool will +// print when a KeyUpdate message is received if run with “-state”. +const opensslReadKeyUpdate = "SSL_accept:TLSv1.3 read client key update" + +func (o *opensslOutputSink) Write(data []byte) (n int, err error) { + o.line = append(o.line, data...) + o.all = append(o.all, data...) + + for { + i := bytes.IndexByte(o.line, '\n') + if i < 0 { + break + } + + if bytes.Equal([]byte(opensslEndOfHandshake), o.line[:i]) { + o.handshakeComplete <- struct{}{} + } + if bytes.Equal([]byte(opensslReadKeyUpdate), o.line[:i]) { + o.readKeyUpdate <- struct{}{} + } + o.line = o.line[i+1:] + } + + return len(data), nil +} + +func (o *opensslOutputSink) String() string { + return string(o.all) +} + +// clientTest represents a test of the TLS client handshake against a reference +// implementation. +type clientTest struct { + // name is a freeform string identifying the test and the file in which + // the expected results will be stored. + name string + // args, if not empty, contains a series of arguments for the + // command to run for the reference server. + args []string + // config, if not nil, contains a custom Config to use for this test. + config *Config + // cert, if not empty, contains a DER-encoded certificate for the + // reference server. + cert []byte + // key, if not nil, contains either a *rsa.PrivateKey, ed25519.PrivateKey or + // *ecdsa.PrivateKey which is the private key for the reference server. + key interface{} + // extensions, if not nil, contains a list of extension data to be returned + // from the ServerHello. The data should be in standard TLS format with + // a 2-byte uint16 type, 2-byte data length, followed by the extension data. + extensions [][]byte + // validate, if not nil, is a function that will be called with the + // ConnectionState of the resulting connection. It returns a non-nil + // error if the ConnectionState is unacceptable. + validate func(ConnectionState) error + // numRenegotiations is the number of times that the connection will be + // renegotiated. + numRenegotiations int + // renegotiationExpectedToFail, if not zero, is the number of the + // renegotiation attempt that is expected to fail. + renegotiationExpectedToFail int + // checkRenegotiationError, if not nil, is called with any error + // arising from renegotiation. It can map expected errors to nil to + // ignore them. + checkRenegotiationError func(renegotiationNum int, err error) error + // sendKeyUpdate will cause the server to send a KeyUpdate message. + sendKeyUpdate bool +} + +var serverCommand = []string{"openssl", "s_server", "-no_ticket", "-num_tickets", "0"} + +// connFromCommand starts the reference server process, connects to it and +// returns a recordingConn for the connection. The stdin return value is an +// opensslInput for the stdin of the child process. It must be closed before +// Waiting for child. +func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin opensslInput, stdout *opensslOutputSink, err error) { + cert := testRSACertificate + if len(test.cert) > 0 { + cert = test.cert + } + certPath := tempFile(string(cert)) + defer os.Remove(certPath) + + var key interface{} = testRSAPrivateKey + if test.key != nil { + key = test.key + } + derBytes, err := x509.MarshalPKCS8PrivateKey(key) + if err != nil { + panic(err) + } + + var pemOut bytes.Buffer + pem.Encode(&pemOut, &pem.Block{Type: "PRIVATE KEY", Bytes: derBytes}) + + keyPath := tempFile(pemOut.String()) + defer os.Remove(keyPath) + + var command []string + command = append(command, serverCommand...) + command = append(command, test.args...) + command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath) + // serverPort contains the port that OpenSSL will listen on. OpenSSL + // can't take "0" as an argument here so we have to pick a number and + // hope that it's not in use on the machine. Since this only occurs + // when -update is given and thus when there's a human watching the + // test, this isn't too bad. + const serverPort = 24323 + command = append(command, "-accept", strconv.Itoa(serverPort)) + + if len(test.extensions) > 0 { + var serverInfo bytes.Buffer + for _, ext := range test.extensions { + pem.Encode(&serverInfo, &pem.Block{ + Type: fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)), + Bytes: ext, + }) + } + serverInfoPath := tempFile(serverInfo.String()) + defer os.Remove(serverInfoPath) + command = append(command, "-serverinfo", serverInfoPath) + } + + if test.numRenegotiations > 0 || test.sendKeyUpdate { + found := false + for _, flag := range command[1:] { + if flag == "-state" { + found = true + break + } + } + + if !found { + panic("-state flag missing to OpenSSL, you need this if testing renegotiation or KeyUpdate") + } + } + + cmd := exec.Command(command[0], command[1:]...) + stdin = opensslInput(make(chan opensslInputEvent)) + cmd.Stdin = stdin + out := newOpensslOutputSink() + cmd.Stdout = out + cmd.Stderr = out + if err := cmd.Start(); err != nil { + return nil, nil, nil, nil, err + } + + // OpenSSL does print an "ACCEPT" banner, but it does so *before* + // opening the listening socket, so we can't use that to wait until it + // has started listening. Thus we are forced to poll until we get a + // connection. + var tcpConn net.Conn + for i := uint(0); i < 5; i++ { + tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: serverPort, + }) + if err == nil { + break + } + time.Sleep((1 << i) * 5 * time.Millisecond) + } + if err != nil { + close(stdin) + cmd.Process.Kill() + err = fmt.Errorf("error connecting to the OpenSSL server: %v (%v)\n\n%s", err, cmd.Wait(), out) + return nil, nil, nil, nil, err + } + + record := &recordingConn{ + Conn: tcpConn, + } + + return record, cmd, stdin, out, nil +} + +func (test *clientTest) dataPath() string { + return filepath.Join("testdata", "Client-"+test.name) +} + +func (test *clientTest) loadData() (flows [][]byte, err error) { + in, err := os.Open(test.dataPath()) + if err != nil { + return nil, err + } + defer in.Close() + return parseTestData(in) +} + +func (test *clientTest) run(t *testing.T, write bool) { + var clientConn, serverConn net.Conn + var recordingConn *recordingConn + var childProcess *exec.Cmd + var stdin opensslInput + var stdout *opensslOutputSink + + if write { + var err error + recordingConn, childProcess, stdin, stdout, err = test.connFromCommand() + if err != nil { + t.Fatalf("Failed to start subcommand: %s", err) + } + clientConn = recordingConn + defer func() { + if t.Failed() { + t.Logf("OpenSSL output:\n\n%s", stdout.all) + } + }() + } else { + clientConn, serverConn = localPipe(t) + } + + doneChan := make(chan bool) + defer func() { + clientConn.Close() + <-doneChan + }() + go func() { + defer close(doneChan) + + config := test.config + if config == nil { + config = testConfig + } + client := Client(clientConn, config) + defer client.Close() + + if _, err := client.Write([]byte("hello\n")); err != nil { + t.Errorf("Client.Write failed: %s", err) + return + } + + for i := 1; i <= test.numRenegotiations; i++ { + // The initial handshake will generate a + // handshakeComplete signal which needs to be quashed. + if i == 1 && write { + <-stdout.handshakeComplete + } + + // OpenSSL will try to interleave application data and + // a renegotiation if we send both concurrently. + // Therefore: ask OpensSSL to start a renegotiation, run + // a goroutine to call client.Read and thus process the + // renegotiation request, watch for OpenSSL's stdout to + // indicate that the handshake is complete and, + // finally, have OpenSSL write something to cause + // client.Read to complete. + if write { + stdin <- opensslRenegotiate + } + + signalChan := make(chan struct{}) + + go func() { + defer close(signalChan) + + buf := make([]byte, 256) + n, err := client.Read(buf) + + if test.checkRenegotiationError != nil { + newErr := test.checkRenegotiationError(i, err) + if err != nil && newErr == nil { + return + } + err = newErr + } + + if err != nil { + t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err) + return + } + + buf = buf[:n] + if !bytes.Equal([]byte(opensslSentinel), buf) { + t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel) + } + + if expected := i + 1; client.handshakes != expected { + t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes) + } + }() + + if write && test.renegotiationExpectedToFail != i { + <-stdout.handshakeComplete + stdin <- opensslSendSentinel + } + <-signalChan + } + + if test.sendKeyUpdate { + if write { + <-stdout.handshakeComplete + stdin <- opensslKeyUpdate + } + + doneRead := make(chan struct{}) + + go func() { + defer close(doneRead) + + buf := make([]byte, 256) + n, err := client.Read(buf) + + if err != nil { + t.Errorf("Client.Read failed after KeyUpdate: %s", err) + return + } + + buf = buf[:n] + if !bytes.Equal([]byte(opensslSentinel), buf) { + t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel) + } + }() + + if write { + // There's no real reason to wait for the client KeyUpdate to + // send data with the new server keys, except that s_server + // drops writes if they are sent at the wrong time. + <-stdout.readKeyUpdate + stdin <- opensslSendSentinel + } + <-doneRead + + if _, err := client.Write([]byte("hello again\n")); err != nil { + t.Errorf("Client.Write failed: %s", err) + return + } + } + + if test.validate != nil { + if err := test.validate(client.ConnectionState()); err != nil { + t.Errorf("validate callback returned error: %s", err) + } + } + + // If the server sent us an alert after our last flight, give it a + // chance to arrive. + if write && test.renegotiationExpectedToFail == 0 { + if err := peekError(client); err != nil { + t.Errorf("final Read returned an error: %s", err) + } + } + }() + + if !write { + flows, err := test.loadData() + if err != nil { + t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err) + } + for i, b := range flows { + if i%2 == 1 { + if *fast { + serverConn.SetWriteDeadline(time.Now().Add(1 * time.Second)) + } else { + serverConn.SetWriteDeadline(time.Now().Add(1 * time.Minute)) + } + serverConn.Write(b) + continue + } + bb := make([]byte, len(b)) + if *fast { + serverConn.SetReadDeadline(time.Now().Add(1 * time.Second)) + } else { + serverConn.SetReadDeadline(time.Now().Add(1 * time.Minute)) + } + _, err := io.ReadFull(serverConn, bb) + if err != nil { + t.Fatalf("%s, flow %d: %s", test.name, i+1, err) + } + if !bytes.Equal(b, bb) { + t.Fatalf("%s, flow %d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b) + } + } + } + + <-doneChan + if !write { + serverConn.Close() + } + + if write { + path := test.dataPath() + out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + t.Fatalf("Failed to create output file: %s", err) + } + defer out.Close() + recordingConn.Close() + close(stdin) + childProcess.Process.Kill() + childProcess.Wait() + if len(recordingConn.flows) < 3 { + t.Fatalf("Client connection didn't work") + } + recordingConn.WriteTo(out) + t.Logf("Wrote %s\n", path) + } +} + +// peekError does a read with a short timeout to check if the next read would +// cause an error, for example if there is an alert waiting on the wire. +func peekError(conn net.Conn) error { + conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + if n, err := conn.Read(make([]byte, 1)); n != 0 { + return errors.New("unexpectedly read data") + } else if err != nil { + if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + return err + } + } + return nil +} + +func runClientTestForVersion(t *testing.T, template *clientTest, version, option string) { + // Make a deep copy of the template before going parallel. + test := *template + if template.config != nil { + test.config = template.config.Clone() + } + test.name = version + "-" + test.name + test.args = append([]string{option}, test.args...) + + runTestAndUpdateIfNeeded(t, version, test.run, false) +} + +func runClientTestTLS10(t *testing.T, template *clientTest) { + runClientTestForVersion(t, template, "TLSv10", "-tls1") +} + +func runClientTestTLS11(t *testing.T, template *clientTest) { + runClientTestForVersion(t, template, "TLSv11", "-tls1_1") +} + +func runClientTestTLS12(t *testing.T, template *clientTest) { + runClientTestForVersion(t, template, "TLSv12", "-tls1_2") +} + +func runClientTestTLS13(t *testing.T, template *clientTest) { + runClientTestForVersion(t, template, "TLSv13", "-tls1_3") +} + +func TestHandshakeClientRSARC4(t *testing.T) { + test := &clientTest{ + name: "RSA-RC4", + args: []string{"-cipher", "RC4-SHA"}, + } + runClientTestTLS10(t, test) + runClientTestTLS11(t, test) + runClientTestTLS12(t, test) +} + +func TestHandshakeClientRSAAES128GCM(t *testing.T) { + test := &clientTest{ + name: "AES128-GCM-SHA256", + args: []string{"-cipher", "AES128-GCM-SHA256"}, + } + runClientTestTLS12(t, test) +} + +func TestHandshakeClientRSAAES256GCM(t *testing.T) { + test := &clientTest{ + name: "AES256-GCM-SHA384", + args: []string{"-cipher", "AES256-GCM-SHA384"}, + } + runClientTestTLS12(t, test) +} + +func TestHandshakeClientECDHERSAAES(t *testing.T) { + test := &clientTest{ + name: "ECDHE-RSA-AES", + args: []string{"-cipher", "ECDHE-RSA-AES128-SHA"}, + } + runClientTestTLS10(t, test) + runClientTestTLS11(t, test) + runClientTestTLS12(t, test) +} + +func TestHandshakeClientECDHEECDSAAES(t *testing.T) { + test := &clientTest{ + name: "ECDHE-ECDSA-AES", + args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA"}, + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + runClientTestTLS10(t, test) + runClientTestTLS11(t, test) + runClientTestTLS12(t, test) +} + +func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) { + test := &clientTest{ + name: "ECDHE-ECDSA-AES-GCM", + args: []string{"-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"}, + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + runClientTestTLS12(t, test) +} + +func TestHandshakeClientAES256GCMSHA384(t *testing.T) { + test := &clientTest{ + name: "ECDHE-ECDSA-AES256-GCM-SHA384", + args: []string{"-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"}, + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + runClientTestTLS12(t, test) +} + +func TestHandshakeClientAES128CBCSHA256(t *testing.T) { + test := &clientTest{ + name: "AES128-SHA256", + args: []string{"-cipher", "AES128-SHA256"}, + } + runClientTestTLS12(t, test) +} + +func TestHandshakeClientECDHERSAAES128CBCSHA256(t *testing.T) { + test := &clientTest{ + name: "ECDHE-RSA-AES128-SHA256", + args: []string{"-cipher", "ECDHE-RSA-AES128-SHA256"}, + } + runClientTestTLS12(t, test) +} + +func TestHandshakeClientECDHEECDSAAES128CBCSHA256(t *testing.T) { + test := &clientTest{ + name: "ECDHE-ECDSA-AES128-SHA256", + args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA256"}, + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + runClientTestTLS12(t, test) +} + +func TestHandshakeClientX25519(t *testing.T) { + config := testConfig.Clone() + config.CurvePreferences = []CurveID{X25519} + + test := &clientTest{ + name: "X25519-ECDHE", + args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "X25519"}, + config: config, + } + + runClientTestTLS12(t, test) + runClientTestTLS13(t, test) +} + +func TestHandshakeClientP256(t *testing.T) { + config := testConfig.Clone() + config.CurvePreferences = []CurveID{CurveP256} + + test := &clientTest{ + name: "P256-ECDHE", + args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"}, + config: config, + } + + runClientTestTLS12(t, test) + runClientTestTLS13(t, test) +} + +func TestHandshakeClientHelloRetryRequest(t *testing.T) { + config := testConfig.Clone() + config.CurvePreferences = []CurveID{X25519, CurveP256} + + test := &clientTest{ + name: "HelloRetryRequest", + args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"}, + config: config, + } + + runClientTestTLS13(t, test) +} + +func TestHandshakeClientECDHERSAChaCha20(t *testing.T) { + config := testConfig.Clone() + config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305} + + test := &clientTest{ + name: "ECDHE-RSA-CHACHA20-POLY1305", + args: []string{"-cipher", "ECDHE-RSA-CHACHA20-POLY1305"}, + config: config, + } + + runClientTestTLS12(t, test) +} + +func TestHandshakeClientECDHEECDSAChaCha20(t *testing.T) { + config := testConfig.Clone() + config.CipherSuites = []uint16{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305} + + test := &clientTest{ + name: "ECDHE-ECDSA-CHACHA20-POLY1305", + args: []string{"-cipher", "ECDHE-ECDSA-CHACHA20-POLY1305"}, + config: config, + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + + runClientTestTLS12(t, test) +} + +func TestHandshakeClientAES128SHA256(t *testing.T) { + test := &clientTest{ + name: "AES128-SHA256", + args: []string{"-ciphersuites", "TLS_AES_128_GCM_SHA256"}, + } + runClientTestTLS13(t, test) +} +func TestHandshakeClientAES256SHA384(t *testing.T) { + test := &clientTest{ + name: "AES256-SHA384", + args: []string{"-ciphersuites", "TLS_AES_256_GCM_SHA384"}, + } + runClientTestTLS13(t, test) +} +func TestHandshakeClientCHACHA20SHA256(t *testing.T) { + test := &clientTest{ + name: "CHACHA20-SHA256", + args: []string{"-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256"}, + } + runClientTestTLS13(t, test) +} + +func TestHandshakeClientECDSATLS13(t *testing.T) { + test := &clientTest{ + name: "ECDSA", + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + runClientTestTLS13(t, test) +} + +func TestHandshakeClientEd25519(t *testing.T) { + test := &clientTest{ + name: "Ed25519", + cert: testEd25519Certificate, + key: testEd25519PrivateKey, + } + runClientTestTLS12(t, test) + runClientTestTLS13(t, test) + + config := testConfig.Clone() + cert, _ := X509KeyPair([]byte(clientEd25519CertificatePEM), []byte(clientEd25519KeyPEM)) + config.Certificates = []Certificate{cert} + + test = &clientTest{ + name: "ClientCert-Ed25519", + args: []string{"-Verify", "1"}, + config: config, + } + + runClientTestTLS12(t, test) + runClientTestTLS13(t, test) +} + +func TestHandshakeClientCertRSA(t *testing.T) { + config := testConfig.Clone() + cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM)) + config.Certificates = []Certificate{cert} + + test := &clientTest{ + name: "ClientCert-RSA-RSA", + args: []string{"-cipher", "AES128", "-Verify", "1"}, + config: config, + } + + runClientTestTLS10(t, test) + runClientTestTLS12(t, test) + + test = &clientTest{ + name: "ClientCert-RSA-ECDSA", + args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"}, + config: config, + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + + runClientTestTLS10(t, test) + runClientTestTLS12(t, test) + runClientTestTLS13(t, test) + + test = &clientTest{ + name: "ClientCert-RSA-AES256-GCM-SHA384", + args: []string{"-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-Verify", "1"}, + config: config, + cert: testRSACertificate, + key: testRSAPrivateKey, + } + + runClientTestTLS12(t, test) +} + +func TestHandshakeClientCertECDSA(t *testing.T) { + config := testConfig.Clone() + cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM)) + config.Certificates = []Certificate{cert} + + test := &clientTest{ + name: "ClientCert-ECDSA-RSA", + args: []string{"-cipher", "AES128", "-Verify", "1"}, + config: config, + } + + runClientTestTLS10(t, test) + runClientTestTLS12(t, test) + runClientTestTLS13(t, test) + + test = &clientTest{ + name: "ClientCert-ECDSA-ECDSA", + args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"}, + config: config, + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + + runClientTestTLS10(t, test) + runClientTestTLS12(t, test) +} + +// TestHandshakeClientCertRSAPSS tests rsa_pss_rsae_sha256 signatures from both +// client and server certificates. It also serves from both sides a certificate +// signed itself with RSA-PSS, mostly to check that crypto/x509 chain validation +// works. +func TestHandshakeClientCertRSAPSS(t *testing.T) { + cert, err := x509.ParseCertificate(testRSAPSSCertificate) + if err != nil { + panic(err) + } + rootCAs := x509.NewCertPool() + rootCAs.AddCert(cert) + + config := testConfig.Clone() + // Use GetClientCertificate to bypass the client certificate selection logic. + config.GetClientCertificate = func(*CertificateRequestInfo) (*Certificate, error) { + return &Certificate{ + Certificate: [][]byte{testRSAPSSCertificate}, + PrivateKey: testRSAPrivateKey, + }, nil + } + config.RootCAs = rootCAs + + test := &clientTest{ + name: "ClientCert-RSA-RSAPSS", + args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs", + "rsa_pss_rsae_sha256", "-sigalgs", "rsa_pss_rsae_sha256"}, + config: config, + cert: testRSAPSSCertificate, + key: testRSAPrivateKey, + } + runClientTestTLS12(t, test) + runClientTestTLS13(t, test) +} + +func TestHandshakeClientCertRSAPKCS1v15(t *testing.T) { + config := testConfig.Clone() + cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM)) + config.Certificates = []Certificate{cert} + + test := &clientTest{ + name: "ClientCert-RSA-RSAPKCS1v15", + args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs", + "rsa_pkcs1_sha256", "-sigalgs", "rsa_pkcs1_sha256"}, + config: config, + } + + runClientTestTLS12(t, test) +} + +func TestClientKeyUpdate(t *testing.T) { + test := &clientTest{ + name: "KeyUpdate", + args: []string{"-state"}, + sendKeyUpdate: true, + } + runClientTestTLS13(t, test) +} + +func TestResumption(t *testing.T) { + t.Run("TLSv12", func(t *testing.T) { testResumption(t, VersionTLS12) }) + t.Run("TLSv13", func(t *testing.T) { testResumption(t, VersionTLS13) }) +} + +func testResumption(t *testing.T, version uint16) { + if testing.Short() { + t.Skip("skipping in -short mode") + } + serverConfig := &Config{ + MaxVersion: version, + CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA}, + Certificates: testConfig.Certificates, + } + + issuer, err := x509.ParseCertificate(testRSACertificateIssuer) + if err != nil { + panic(err) + } + + rootCAs := x509.NewCertPool() + rootCAs.AddCert(issuer) + + clientConfig := &Config{ + MaxVersion: version, + CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + ClientSessionCache: NewLRUClientSessionCache(32), + RootCAs: rootCAs, + ServerName: "example.golang", + } + + testResumeState := func(test string, didResume bool) { + _, hs, err := testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("%s: handshake failed: %s", test, err) + } + if hs.DidResume != didResume { + t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume) + } + if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) { + t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains) + } + if got, want := hs.ServerName, clientConfig.ServerName; got != want { + t.Errorf("%s: server name %s, want %s", test, got, want) + } + } + + getTicket := func() []byte { + return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.sessionTicket + } + deleteTicket := func() { + ticketKey := clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).sessionKey + clientConfig.ClientSessionCache.Put(ticketKey, nil) + } + corruptTicket := func() { + clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.masterSecret[0] ^= 0xff + } + randomKey := func() [32]byte { + var k [32]byte + if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil { + t.Fatalf("Failed to read new SessionTicketKey: %s", err) + } + return k + } + + testResumeState("Handshake", false) + ticket := getTicket() + testResumeState("Resume", true) + if !bytes.Equal(ticket, getTicket()) && version != VersionTLS13 { + t.Fatal("first ticket doesn't match ticket after resumption") + } + if bytes.Equal(ticket, getTicket()) && version == VersionTLS13 { + t.Fatal("ticket didn't change after resumption") + } + + // An old session ticket can resume, but the server will provide a ticket encrypted with a fresh key. + serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) } + testResumeState("ResumeWithOldTicket", true) + if bytes.Equal(ticket[:ticketKeyNameLen], getTicket()[:ticketKeyNameLen]) { + t.Fatal("old first ticket matches the fresh one") + } + + // Now the session tickey key is expired, so a full handshake should occur. + serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) } + testResumeState("ResumeWithExpiredTicket", false) + if bytes.Equal(ticket, getTicket()) { + t.Fatal("expired first ticket matches the fresh one") + } + + serverConfig.Time = func() time.Time { return time.Now() } // reset the time back + key1 := randomKey() + serverConfig.SetSessionTicketKeys([][32]byte{key1}) + + testResumeState("InvalidSessionTicketKey", false) + testResumeState("ResumeAfterInvalidSessionTicketKey", true) + + key2 := randomKey() + serverConfig.SetSessionTicketKeys([][32]byte{key2, key1}) + ticket = getTicket() + testResumeState("KeyChange", true) + if bytes.Equal(ticket, getTicket()) { + t.Fatal("new ticket wasn't included while resuming") + } + testResumeState("KeyChangeFinish", true) + + // Age the session ticket a bit, but not yet expired. + serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) } + testResumeState("OldSessionTicket", true) + ticket = getTicket() + // Expire the session ticket, which would force a full handshake. + serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) } + testResumeState("ExpiredSessionTicket", false) + if bytes.Equal(ticket, getTicket()) { + t.Fatal("new ticket wasn't provided after old ticket expired") + } + + // Age the session ticket a bit at a time, but don't expire it. + d := 0 * time.Hour + for i := 0; i < 13; i++ { + d += 12 * time.Hour + serverConfig.Time = func() time.Time { return time.Now().Add(d) } + testResumeState("OldSessionTicket", true) + } + // Expire it (now a little more than 7 days) and make sure a full + // handshake occurs for TLS 1.2. Resumption should still occur for + // TLS 1.3 since the client should be using a fresh ticket sent over + // by the server. + d += 12 * time.Hour + serverConfig.Time = func() time.Time { return time.Now().Add(d) } + if version == VersionTLS13 { + testResumeState("ExpiredSessionTicket", true) + } else { + testResumeState("ExpiredSessionTicket", false) + } + if bytes.Equal(ticket, getTicket()) { + t.Fatal("new ticket wasn't provided after old ticket expired") + } + + // Reset serverConfig to ensure that calling SetSessionTicketKeys + // before the serverConfig is used works. + serverConfig = &Config{ + MaxVersion: version, + CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA}, + Certificates: testConfig.Certificates, + } + serverConfig.SetSessionTicketKeys([][32]byte{key2}) + + testResumeState("FreshConfig", true) + + // In TLS 1.3, cross-cipher suite resumption is allowed as long as the KDF + // hash matches. Also, Config.CipherSuites does not apply to TLS 1.3. + if version != VersionTLS13 { + clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA} + testResumeState("DifferentCipherSuite", false) + testResumeState("DifferentCipherSuiteRecovers", true) + } + + deleteTicket() + testResumeState("WithoutSessionTicket", false) + + // Session resumption should work when using client certificates + deleteTicket() + serverConfig.ClientCAs = rootCAs + serverConfig.ClientAuth = RequireAndVerifyClientCert + clientConfig.Certificates = serverConfig.Certificates + testResumeState("InitialHandshake", false) + testResumeState("WithClientCertificates", true) + serverConfig.ClientAuth = NoClientCert + + // Tickets should be removed from the session cache on TLS handshake + // failure, and the client should recover from a corrupted PSK + testResumeState("FetchTicketToCorrupt", false) + corruptTicket() + _, _, err = testHandshake(t, clientConfig, serverConfig) + if err == nil { + t.Fatalf("handshake did not fail with a corrupted client secret") + } + testResumeState("AfterHandshakeFailure", false) + + clientConfig.ClientSessionCache = nil + testResumeState("WithoutSessionCache", false) +} + +func TestLRUClientSessionCache(t *testing.T) { + // Initialize cache of capacity 4. + cache := NewLRUClientSessionCache(4) + cs := make([]ClientSessionState, 6) + keys := []string{"0", "1", "2", "3", "4", "5", "6"} + + // Add 4 entries to the cache and look them up. + for i := 0; i < 4; i++ { + cache.Put(keys[i], &cs[i]) + } + for i := 0; i < 4; i++ { + if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] { + t.Fatalf("session cache failed lookup for added key: %s", keys[i]) + } + } + + // Add 2 more entries to the cache. First 2 should be evicted. + for i := 4; i < 6; i++ { + cache.Put(keys[i], &cs[i]) + } + for i := 0; i < 2; i++ { + if s, ok := cache.Get(keys[i]); ok || s != nil { + t.Fatalf("session cache should have evicted key: %s", keys[i]) + } + } + + // Touch entry 2. LRU should evict 3 next. + cache.Get(keys[2]) + cache.Put(keys[0], &cs[0]) + if s, ok := cache.Get(keys[3]); ok || s != nil { + t.Fatalf("session cache should have evicted key 3") + } + + // Update entry 0 in place. + cache.Put(keys[0], &cs[3]) + if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] { + t.Fatalf("session cache failed update for key 0") + } + + // Calling Put with a nil entry deletes the key. + cache.Put(keys[0], nil) + if _, ok := cache.Get(keys[0]); ok { + t.Fatalf("session cache failed to delete key 0") + } + + // Delete entry 2. LRU should keep 4 and 5 + cache.Put(keys[2], nil) + if _, ok := cache.Get(keys[2]); ok { + t.Fatalf("session cache failed to delete key 4") + } + for i := 4; i < 6; i++ { + if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] { + t.Fatalf("session cache should not have deleted key: %s", keys[i]) + } + } +} + +func TestKeyLogTLS12(t *testing.T) { + var serverBuf, clientBuf bytes.Buffer + + clientConfig := testConfig.Clone() + clientConfig.KeyLogWriter = &clientBuf + clientConfig.MaxVersion = VersionTLS12 + + serverConfig := testConfig.Clone() + serverConfig.KeyLogWriter = &serverBuf + serverConfig.MaxVersion = VersionTLS12 + + c, s := localPipe(t) + done := make(chan bool) + + go func() { + defer close(done) + + if err := Server(s, serverConfig).Handshake(); err != nil { + t.Errorf("server: %s", err) + return + } + s.Close() + }() + + if err := Client(c, clientConfig).Handshake(); err != nil { + t.Fatalf("client: %s", err) + } + + c.Close() + <-done + + checkKeylogLine := func(side, loggedLine string) { + if len(loggedLine) == 0 { + t.Fatalf("%s: no keylog line was produced", side) + } + const expectedLen = 13 /* "CLIENT_RANDOM" */ + + 1 /* space */ + + 32*2 /* hex client nonce */ + + 1 /* space */ + + 48*2 /* hex master secret */ + + 1 /* new line */ + if len(loggedLine) != expectedLen { + t.Fatalf("%s: keylog line has incorrect length (want %d, got %d): %q", side, expectedLen, len(loggedLine), loggedLine) + } + if !strings.HasPrefix(loggedLine, "CLIENT_RANDOM "+strings.Repeat("0", 64)+" ") { + t.Fatalf("%s: keylog line has incorrect structure or nonce: %q", side, loggedLine) + } + } + + checkKeylogLine("client", clientBuf.String()) + checkKeylogLine("server", serverBuf.String()) +} + +func TestKeyLogTLS13(t *testing.T) { + var serverBuf, clientBuf bytes.Buffer + + clientConfig := testConfig.Clone() + clientConfig.KeyLogWriter = &clientBuf + + serverConfig := testConfig.Clone() + serverConfig.KeyLogWriter = &serverBuf + + c, s := localPipe(t) + done := make(chan bool) + + go func() { + defer close(done) + + if err := Server(s, serverConfig).Handshake(); err != nil { + t.Errorf("server: %s", err) + return + } + s.Close() + }() + + if err := Client(c, clientConfig).Handshake(); err != nil { + t.Fatalf("client: %s", err) + } + + c.Close() + <-done + + checkKeylogLines := func(side, loggedLines string) { + loggedLines = strings.TrimSpace(loggedLines) + lines := strings.Split(loggedLines, "\n") + if len(lines) != 4 { + t.Errorf("Expected the %s to log 4 lines, got %d", side, len(lines)) + } + } + + checkKeylogLines("client", clientBuf.String()) + checkKeylogLines("server", serverBuf.String()) +} + +func TestHandshakeClientALPNMatch(t *testing.T) { + config := testConfig.Clone() + config.NextProtos = []string{"proto2", "proto1"} + + test := &clientTest{ + name: "ALPN", + // Note that this needs OpenSSL 1.0.2 because that is the first + // version that supports the -alpn flag. + args: []string{"-alpn", "proto1,proto2"}, + config: config, + validate: func(state ConnectionState) error { + // The server's preferences should override the client. + if state.NegotiatedProtocol != "proto1" { + return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol) + } + return nil + }, + } + runClientTestTLS12(t, test) + runClientTestTLS13(t, test) +} + +// sctsBase64 contains data from `openssl s_client -serverinfo 18 -connect ritter.vg:443` +const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0=" + +func TestHandshakClientSCTs(t *testing.T) { + config := testConfig.Clone() + + scts, err := base64.StdEncoding.DecodeString(sctsBase64) + if err != nil { + t.Fatal(err) + } + + // Note that this needs OpenSSL 1.0.2 because that is the first + // version that supports the -serverinfo flag. + test := &clientTest{ + name: "SCT", + config: config, + extensions: [][]byte{scts}, + validate: func(state ConnectionState) error { + expectedSCTs := [][]byte{ + scts[8:125], + scts[127:245], + scts[247:], + } + if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) { + return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs)) + } + for i, expected := range expectedSCTs { + if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) { + return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected) + } + } + return nil + }, + } + runClientTestTLS12(t, test) + + // TLS 1.3 moved SCTs to the Certificate extensions and -serverinfo only + // supports ServerHello extensions. +} + +func TestRenegotiationRejected(t *testing.T) { + config := testConfig.Clone() + test := &clientTest{ + name: "RenegotiationRejected", + args: []string{"-state"}, + config: config, + numRenegotiations: 1, + renegotiationExpectedToFail: 1, + checkRenegotiationError: func(renegotiationNum int, err error) error { + if err == nil { + return errors.New("expected error from renegotiation but got nil") + } + if !strings.Contains(err.Error(), "no renegotiation") { + return fmt.Errorf("expected renegotiation to be rejected but got %q", err) + } + return nil + }, + } + runClientTestTLS12(t, test) +} + +func TestRenegotiateOnce(t *testing.T) { + config := testConfig.Clone() + config.Renegotiation = RenegotiateOnceAsClient + + test := &clientTest{ + name: "RenegotiateOnce", + args: []string{"-state"}, + config: config, + numRenegotiations: 1, + } + + runClientTestTLS12(t, test) +} + +func TestRenegotiateTwice(t *testing.T) { + config := testConfig.Clone() + config.Renegotiation = RenegotiateFreelyAsClient + + test := &clientTest{ + name: "RenegotiateTwice", + args: []string{"-state"}, + config: config, + numRenegotiations: 2, + } + + runClientTestTLS12(t, test) +} + +func TestRenegotiateTwiceRejected(t *testing.T) { + config := testConfig.Clone() + config.Renegotiation = RenegotiateOnceAsClient + + test := &clientTest{ + name: "RenegotiateTwiceRejected", + args: []string{"-state"}, + config: config, + numRenegotiations: 2, + renegotiationExpectedToFail: 2, + checkRenegotiationError: func(renegotiationNum int, err error) error { + if renegotiationNum == 1 { + return err + } + + if err == nil { + return errors.New("expected error from renegotiation but got nil") + } + if !strings.Contains(err.Error(), "no renegotiation") { + return fmt.Errorf("expected renegotiation to be rejected but got %q", err) + } + return nil + }, + } + + runClientTestTLS12(t, test) +} + +func TestHandshakeClientExportKeyingMaterial(t *testing.T) { + test := &clientTest{ + name: "ExportKeyingMaterial", + config: testConfig.Clone(), + validate: func(state ConnectionState) error { + if km, err := state.ExportKeyingMaterial("test", nil, 42); err != nil { + return fmt.Errorf("ExportKeyingMaterial failed: %v", err) + } else if len(km) != 42 { + return fmt.Errorf("Got %d bytes from ExportKeyingMaterial, wanted %d", len(km), 42) + } + return nil + }, + } + runClientTestTLS10(t, test) + runClientTestTLS12(t, test) + runClientTestTLS13(t, test) +} + +var hostnameInSNITests = []struct { + in, out string +}{ + // Opaque string + {"", ""}, + {"localhost", "localhost"}, + {"foo, bar, baz and qux", "foo, bar, baz and qux"}, + + // DNS hostname + {"golang.org", "golang.org"}, + {"golang.org.", "golang.org"}, + + // Literal IPv4 address + {"1.2.3.4", ""}, + + // Literal IPv6 address + {"::1", ""}, + {"::1%lo0", ""}, // with zone identifier + {"[::1]", ""}, // as per RFC 5952 we allow the [] style as IPv6 literal + {"[::1%lo0]", ""}, +} + +func TestHostnameInSNI(t *testing.T) { + for _, tt := range hostnameInSNITests { + c, s := localPipe(t) + + go func(host string) { + Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake() + }(tt.in) + + var header [5]byte + if _, err := io.ReadFull(s, header[:]); err != nil { + t.Fatal(err) + } + recordLen := int(header[3])<<8 | int(header[4]) + + record := make([]byte, recordLen) + if _, err := io.ReadFull(s, record[:]); err != nil { + t.Fatal(err) + } + + c.Close() + s.Close() + + var m clientHelloMsg + if !m.unmarshal(record) { + t.Errorf("unmarshaling ClientHello for %q failed", tt.in) + continue + } + if tt.in != tt.out && m.serverName == tt.in { + t.Errorf("prohibited %q found in ClientHello: %x", tt.in, record) + } + if m.serverName != tt.out { + t.Errorf("expected %q not found in ClientHello: %x", tt.out, record) + } + } +} + +func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) { + // This checks that the server can't select a cipher suite that the + // client didn't offer. See #13174. + + c, s := localPipe(t) + errChan := make(chan error, 1) + + go func() { + client := Client(c, &Config{ + ServerName: "foo", + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, + }) + errChan <- client.Handshake() + }() + + var header [5]byte + if _, err := io.ReadFull(s, header[:]); err != nil { + t.Fatal(err) + } + recordLen := int(header[3])<<8 | int(header[4]) + + record := make([]byte, recordLen) + if _, err := io.ReadFull(s, record); err != nil { + t.Fatal(err) + } + + // Create a ServerHello that selects a different cipher suite than the + // sole one that the client offered. + serverHello := &serverHelloMsg{ + vers: VersionTLS12, + random: make([]byte, 32), + cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384, + } + serverHelloBytes := serverHello.marshal() + + s.Write([]byte{ + byte(recordTypeHandshake), + byte(VersionTLS12 >> 8), + byte(VersionTLS12 & 0xff), + byte(len(serverHelloBytes) >> 8), + byte(len(serverHelloBytes)), + }) + s.Write(serverHelloBytes) + s.Close() + + if err := <-errChan; !strings.Contains(err.Error(), "unconfigured cipher") { + t.Fatalf("Expected error about unconfigured cipher suite but got %q", err) + } +} + +func TestVerifyConnection(t *testing.T) { + t.Run("TLSv12", func(t *testing.T) { testVerifyConnection(t, VersionTLS12) }) + t.Run("TLSv13", func(t *testing.T) { testVerifyConnection(t, VersionTLS13) }) +} + +func testVerifyConnection(t *testing.T, version uint16) { + checkFields := func(c ConnectionState, called *int, errorType string) error { + if c.Version != version { + return fmt.Errorf("%s: got Version %v, want %v", errorType, c.Version, version) + } + if c.HandshakeComplete { + return fmt.Errorf("%s: got HandshakeComplete, want false", errorType) + } + if c.ServerName != "example.golang" { + return fmt.Errorf("%s: got ServerName %s, want %s", errorType, c.ServerName, "example.golang") + } + if c.NegotiatedProtocol != "protocol1" { + return fmt.Errorf("%s: got NegotiatedProtocol %s, want %s", errorType, c.NegotiatedProtocol, "protocol1") + } + if c.CipherSuite == 0 { + return fmt.Errorf("%s: got CipherSuite 0, want non-zero", errorType) + } + wantDidResume := false + if *called == 2 { // if this is the second time, then it should be a resumption + wantDidResume = true + } + if c.DidResume != wantDidResume { + return fmt.Errorf("%s: got DidResume %t, want %t", errorType, c.DidResume, wantDidResume) + } + return nil + } + + tests := []struct { + name string + configureServer func(*Config, *int) + configureClient func(*Config, *int) + }{ + { + name: "RequireAndVerifyClientCert", + configureServer: func(config *Config, called *int) { + config.ClientAuth = RequireAndVerifyClientCert + config.VerifyConnection = func(c ConnectionState) error { + *called++ + if l := len(c.PeerCertificates); l != 1 { + return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l) + } + if len(c.VerifiedChains) == 0 { + return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero") + } + return checkFields(c, called, "server") + } + }, + configureClient: func(config *Config, called *int) { + config.VerifyConnection = func(c ConnectionState) error { + *called++ + if l := len(c.PeerCertificates); l != 1 { + return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l) + } + if len(c.VerifiedChains) == 0 { + return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero") + } + if c.DidResume { + return nil + // The SCTs and OCSP Responce are dropped on resumption. + // See http://golang.org/issue/39075. + } + if len(c.OCSPResponse) == 0 { + return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero") + } + if len(c.SignedCertificateTimestamps) == 0 { + return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero") + } + return checkFields(c, called, "client") + } + }, + }, + { + name: "InsecureSkipVerify", + configureServer: func(config *Config, called *int) { + config.ClientAuth = RequireAnyClientCert + config.InsecureSkipVerify = true + config.VerifyConnection = func(c ConnectionState) error { + *called++ + if l := len(c.PeerCertificates); l != 1 { + return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l) + } + if c.VerifiedChains != nil { + return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains) + } + return checkFields(c, called, "server") + } + }, + configureClient: func(config *Config, called *int) { + config.InsecureSkipVerify = true + config.VerifyConnection = func(c ConnectionState) error { + *called++ + if l := len(c.PeerCertificates); l != 1 { + return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l) + } + if c.VerifiedChains != nil { + return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains) + } + if c.DidResume { + return nil + // The SCTs and OCSP Responce are dropped on resumption. + // See http://golang.org/issue/39075. + } + if len(c.OCSPResponse) == 0 { + return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero") + } + if len(c.SignedCertificateTimestamps) == 0 { + return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero") + } + return checkFields(c, called, "client") + } + }, + }, + { + name: "NoClientCert", + configureServer: func(config *Config, called *int) { + config.ClientAuth = NoClientCert + config.VerifyConnection = func(c ConnectionState) error { + *called++ + return checkFields(c, called, "server") + } + }, + configureClient: func(config *Config, called *int) { + config.VerifyConnection = func(c ConnectionState) error { + *called++ + return checkFields(c, called, "client") + } + }, + }, + { + name: "RequestClientCert", + configureServer: func(config *Config, called *int) { + config.ClientAuth = RequestClientCert + config.VerifyConnection = func(c ConnectionState) error { + *called++ + return checkFields(c, called, "server") + } + }, + configureClient: func(config *Config, called *int) { + config.Certificates = nil // clear the client cert + config.VerifyConnection = func(c ConnectionState) error { + *called++ + if l := len(c.PeerCertificates); l != 1 { + return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l) + } + if len(c.VerifiedChains) == 0 { + return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero") + } + if c.DidResume { + return nil + // The SCTs and OCSP Responce are dropped on resumption. + // See http://golang.org/issue/39075. + } + if len(c.OCSPResponse) == 0 { + return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero") + } + if len(c.SignedCertificateTimestamps) == 0 { + return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero") + } + return checkFields(c, called, "client") + } + }, + }, + } + for _, test := range tests { + issuer, err := x509.ParseCertificate(testRSACertificateIssuer) + if err != nil { + panic(err) + } + rootCAs := x509.NewCertPool() + rootCAs.AddCert(issuer) + + var serverCalled, clientCalled int + + serverConfig := &Config{ + MaxVersion: version, + Certificates: []Certificate{testConfig.Certificates[0]}, + ClientCAs: rootCAs, + NextProtos: []string{"protocol1"}, + } + serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")} + serverConfig.Certificates[0].OCSPStaple = []byte("dummy ocsp") + test.configureServer(serverConfig, &serverCalled) + + clientConfig := &Config{ + MaxVersion: version, + ClientSessionCache: NewLRUClientSessionCache(32), + RootCAs: rootCAs, + ServerName: "example.golang", + Certificates: []Certificate{testConfig.Certificates[0]}, + NextProtos: []string{"protocol1"}, + } + test.configureClient(clientConfig, &clientCalled) + + testHandshakeState := func(name string, didResume bool) { + _, hs, err := testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("%s: handshake failed: %s", name, err) + } + if hs.DidResume != didResume { + t.Errorf("%s: resumed: %v, expected: %v", name, hs.DidResume, didResume) + } + wantCalled := 1 + if didResume { + wantCalled = 2 // resumption would mean this is the second time it was called in this test + } + if clientCalled != wantCalled { + t.Errorf("%s: expected client VerifyConnection called %d times, did %d times", name, wantCalled, clientCalled) + } + if serverCalled != wantCalled { + t.Errorf("%s: expected server VerifyConnection called %d times, did %d times", name, wantCalled, serverCalled) + } + } + testHandshakeState(fmt.Sprintf("%s-FullHandshake", test.name), false) + testHandshakeState(fmt.Sprintf("%s-Resumption", test.name), true) + } +} + +func TestVerifyPeerCertificate(t *testing.T) { + t.Run("TLSv12", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS12) }) + t.Run("TLSv13", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS13) }) +} + +func testVerifyPeerCertificate(t *testing.T, version uint16) { + issuer, err := x509.ParseCertificate(testRSACertificateIssuer) + if err != nil { + panic(err) + } + + rootCAs := x509.NewCertPool() + rootCAs.AddCert(issuer) + + now := func() time.Time { return time.Unix(1476984729, 0) } + + sentinelErr := errors.New("TestVerifyPeerCertificate") + + verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + if l := len(rawCerts); l != 1 { + return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l) + } + if len(validatedChains) == 0 { + return errors.New("got len(validatedChains) = 0, wanted non-zero") + } + *called = true + return nil + } + verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error { + if l := len(c.PeerCertificates); l != 1 { + return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l) + } + if len(c.VerifiedChains) == 0 { + return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero") + } + if isClient && len(c.OCSPResponse) == 0 { + return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero") + } + *called = true + return nil + } + + tests := []struct { + configureServer func(*Config, *bool) + configureClient func(*Config, *bool) + validate func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) + }{ + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return verifyPeerCertificateCallback(called, rawCerts, validatedChains) + } + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return verifyPeerCertificateCallback(called, rawCerts, validatedChains) + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != nil { + t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr) + } + if serverErr != nil { + t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr) + } + if !clientCalled { + t.Errorf("test[%d]: client did not call callback", testNo) + } + if !serverCalled { + t.Errorf("test[%d]: server did not call callback", testNo) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return sentinelErr + } + }, + configureClient: func(config *Config, called *bool) { + config.VerifyPeerCertificate = nil + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if serverErr != sentinelErr { + t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + }, + configureClient: func(config *Config, called *bool) { + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return sentinelErr + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != sentinelErr { + t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = true + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + if l := len(rawCerts); l != 1 { + return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l) + } + // With InsecureSkipVerify set, this + // callback should still be called but + // validatedChains must be empty. + if l := len(validatedChains); l != 0 { + return fmt.Errorf("got len(validatedChains) = %d, wanted zero", l) + } + *called = true + return nil + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != nil { + t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr) + } + if serverErr != nil { + t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr) + } + if !clientCalled { + t.Errorf("test[%d]: client did not call callback", testNo) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = func(c ConnectionState) error { + return verifyConnectionCallback(called, false, c) + } + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = func(c ConnectionState) error { + return verifyConnectionCallback(called, true, c) + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != nil { + t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr) + } + if serverErr != nil { + t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr) + } + if !clientCalled { + t.Errorf("test[%d]: client did not call callback", testNo) + } + if !serverCalled { + t.Errorf("test[%d]: server did not call callback", testNo) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = func(c ConnectionState) error { + return sentinelErr + } + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = nil + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if serverErr != sentinelErr { + t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = nil + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = func(c ConnectionState) error { + return sentinelErr + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != sentinelErr { + t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return verifyPeerCertificateCallback(called, rawCerts, validatedChains) + } + config.VerifyConnection = func(c ConnectionState) error { + return sentinelErr + } + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = nil + config.VerifyConnection = nil + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if serverErr != sentinelErr { + t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr) + } + if !serverCalled { + t.Errorf("test[%d]: server did not call callback", testNo) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = nil + config.VerifyConnection = nil + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return verifyPeerCertificateCallback(called, rawCerts, validatedChains) + } + config.VerifyConnection = func(c ConnectionState) error { + return sentinelErr + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != sentinelErr { + t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr) + } + if !clientCalled { + t.Errorf("test[%d]: client did not call callback", testNo) + } + }, + }, + } + + for i, test := range tests { + c, s := localPipe(t) + done := make(chan error) + + var clientCalled, serverCalled bool + + go func() { + config := testConfig.Clone() + config.ServerName = "example.golang" + config.ClientAuth = RequireAndVerifyClientCert + config.ClientCAs = rootCAs + config.Time = now + config.MaxVersion = version + config.Certificates = make([]Certificate, 1) + config.Certificates[0].Certificate = [][]byte{testRSACertificate} + config.Certificates[0].PrivateKey = testRSAPrivateKey + config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")} + config.Certificates[0].OCSPStaple = []byte("dummy ocsp") + test.configureServer(config, &serverCalled) + + err = Server(s, config).Handshake() + s.Close() + done <- err + }() + + config := testConfig.Clone() + config.ServerName = "example.golang" + config.RootCAs = rootCAs + config.Time = now + config.MaxVersion = version + test.configureClient(config, &clientCalled) + clientErr := Client(c, config).Handshake() + c.Close() + serverErr := <-done + + test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr) + } +} + +// brokenConn wraps a net.Conn and causes all Writes after a certain number to +// fail with brokenConnErr. +type brokenConn struct { + net.Conn + + // breakAfter is the number of successful writes that will be allowed + // before all subsequent writes fail. + breakAfter int + + // numWrites is the number of writes that have been done. + numWrites int +} + +// brokenConnErr is the error that brokenConn returns once exhausted. +var brokenConnErr = errors.New("too many writes to brokenConn") + +func (b *brokenConn) Write(data []byte) (int, error) { + if b.numWrites >= b.breakAfter { + return 0, brokenConnErr + } + + b.numWrites++ + return b.Conn.Write(data) +} + +func TestFailedWrite(t *testing.T) { + // Test that a write error during the handshake is returned. + for _, breakAfter := range []int{0, 1} { + c, s := localPipe(t) + done := make(chan bool) + + go func() { + Server(s, testConfig).Handshake() + s.Close() + done <- true + }() + + brokenC := &brokenConn{Conn: c, breakAfter: breakAfter} + err := Client(brokenC, testConfig).Handshake() + if err != brokenConnErr { + t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err) + } + brokenC.Close() + + <-done + } +} + +// writeCountingConn wraps a net.Conn and counts the number of Write calls. +type writeCountingConn struct { + net.Conn + + // numWrites is the number of writes that have been done. + numWrites int +} + +func (wcc *writeCountingConn) Write(data []byte) (int, error) { + wcc.numWrites++ + return wcc.Conn.Write(data) +} + +func TestBuffering(t *testing.T) { + t.Run("TLSv12", func(t *testing.T) { testBuffering(t, VersionTLS12) }) + t.Run("TLSv13", func(t *testing.T) { testBuffering(t, VersionTLS13) }) +} + +func testBuffering(t *testing.T, version uint16) { + c, s := localPipe(t) + done := make(chan bool) + + clientWCC := &writeCountingConn{Conn: c} + serverWCC := &writeCountingConn{Conn: s} + + go func() { + config := testConfig.Clone() + config.MaxVersion = version + Server(serverWCC, config).Handshake() + serverWCC.Close() + done <- true + }() + + err := Client(clientWCC, testConfig).Handshake() + if err != nil { + t.Fatal(err) + } + clientWCC.Close() + <-done + + var expectedClient, expectedServer int + if version == VersionTLS13 { + expectedClient = 2 + expectedServer = 1 + } else { + expectedClient = 2 + expectedServer = 2 + } + + if n := clientWCC.numWrites; n != expectedClient { + t.Errorf("expected client handshake to complete with %d writes, but saw %d", expectedClient, n) + } + + if n := serverWCC.numWrites; n != expectedServer { + t.Errorf("expected server handshake to complete with %d writes, but saw %d", expectedServer, n) + } +} + +func TestAlertFlushing(t *testing.T) { + c, s := localPipe(t) + done := make(chan bool) + + clientWCC := &writeCountingConn{Conn: c} + serverWCC := &writeCountingConn{Conn: s} + + serverConfig := testConfig.Clone() + + // Cause a signature-time error + brokenKey := rsa.PrivateKey{PublicKey: testRSAPrivateKey.PublicKey} + brokenKey.D = big.NewInt(42) + serverConfig.Certificates = []Certificate{{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: &brokenKey, + }} + + go func() { + Server(serverWCC, serverConfig).Handshake() + serverWCC.Close() + done <- true + }() + + err := Client(clientWCC, testConfig).Handshake() + if err == nil { + t.Fatal("client unexpectedly returned no error") + } + + const expectedError = "remote error: tls: internal error" + if e := err.Error(); !strings.Contains(e, expectedError) { + t.Fatalf("expected to find %q in error but error was %q", expectedError, e) + } + clientWCC.Close() + <-done + + if n := serverWCC.numWrites; n != 1 { + t.Errorf("expected server handshake to complete with one write, but saw %d", n) + } +} + +func TestHandshakeRace(t *testing.T) { + if testing.Short() { + t.Skip("skipping in -short mode") + } + t.Parallel() + // This test races a Read and Write to try and complete a handshake in + // order to provide some evidence that there are no races or deadlocks + // in the handshake locking. + for i := 0; i < 32; i++ { + c, s := localPipe(t) + + go func() { + server := Server(s, testConfig) + if err := server.Handshake(); err != nil { + panic(err) + } + + var request [1]byte + if n, err := server.Read(request[:]); err != nil || n != 1 { + panic(err) + } + + server.Write(request[:]) + server.Close() + }() + + startWrite := make(chan struct{}) + startRead := make(chan struct{}) + readDone := make(chan struct{}, 1) + + client := Client(c, testConfig) + go func() { + <-startWrite + var request [1]byte + client.Write(request[:]) + }() + + go func() { + <-startRead + var reply [1]byte + if _, err := io.ReadFull(client, reply[:]); err != nil { + panic(err) + } + c.Close() + readDone <- struct{}{} + }() + + if i&1 == 1 { + startWrite <- struct{}{} + startRead <- struct{}{} + } else { + startRead <- struct{}{} + startWrite <- struct{}{} + } + <-readDone + } +} + +var getClientCertificateTests = []struct { + setup func(*Config, *Config) + expectedClientError string + verify func(*testing.T, int, *ConnectionState) +}{ + { + func(clientConfig, serverConfig *Config) { + // Returning a Certificate with no certificate data + // should result in an empty message being sent to the + // server. + serverConfig.ClientCAs = nil + clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) { + if len(cri.SignatureSchemes) == 0 { + panic("empty SignatureSchemes") + } + if len(cri.AcceptableCAs) != 0 { + panic("AcceptableCAs should have been empty") + } + return new(Certificate), nil + } + }, + "", + func(t *testing.T, testNum int, cs *ConnectionState) { + if l := len(cs.PeerCertificates); l != 0 { + t.Errorf("#%d: expected no certificates but got %d", testNum, l) + } + }, + }, + { + func(clientConfig, serverConfig *Config) { + // With TLS 1.1, the SignatureSchemes should be + // synthesised from the supported certificate types. + clientConfig.MaxVersion = VersionTLS11 + clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) { + if len(cri.SignatureSchemes) == 0 { + panic("empty SignatureSchemes") + } + return new(Certificate), nil + } + }, + "", + func(t *testing.T, testNum int, cs *ConnectionState) { + if l := len(cs.PeerCertificates); l != 0 { + t.Errorf("#%d: expected no certificates but got %d", testNum, l) + } + }, + }, + { + func(clientConfig, serverConfig *Config) { + // Returning an error should abort the handshake with + // that error. + clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) { + return nil, errors.New("GetClientCertificate") + } + }, + "GetClientCertificate", + func(t *testing.T, testNum int, cs *ConnectionState) { + }, + }, + { + func(clientConfig, serverConfig *Config) { + clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) { + if len(cri.AcceptableCAs) == 0 { + panic("empty AcceptableCAs") + } + cert := &Certificate{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: testRSAPrivateKey, + } + return cert, nil + } + }, + "", + func(t *testing.T, testNum int, cs *ConnectionState) { + if len(cs.VerifiedChains) == 0 { + t.Errorf("#%d: expected some verified chains, but found none", testNum) + } + }, + }, +} + +func TestGetClientCertificate(t *testing.T) { + t.Run("TLSv12", func(t *testing.T) { testGetClientCertificate(t, VersionTLS12) }) + t.Run("TLSv13", func(t *testing.T) { testGetClientCertificate(t, VersionTLS13) }) +} + +func testGetClientCertificate(t *testing.T, version uint16) { + issuer, err := x509.ParseCertificate(testRSACertificateIssuer) + if err != nil { + panic(err) + } + + for i, test := range getClientCertificateTests { + serverConfig := testConfig.Clone() + serverConfig.ClientAuth = VerifyClientCertIfGiven + serverConfig.RootCAs = x509.NewCertPool() + serverConfig.RootCAs.AddCert(issuer) + serverConfig.ClientCAs = serverConfig.RootCAs + serverConfig.Time = func() time.Time { return time.Unix(1476984729, 0) } + serverConfig.MaxVersion = version + + clientConfig := testConfig.Clone() + clientConfig.MaxVersion = version + + test.setup(clientConfig, serverConfig) + + type serverResult struct { + cs ConnectionState + err error + } + + c, s := localPipe(t) + done := make(chan serverResult) + + go func() { + defer s.Close() + server := Server(s, serverConfig) + err := server.Handshake() + + var cs ConnectionState + if err == nil { + cs = server.ConnectionState() + } + done <- serverResult{cs, err} + }() + + clientErr := Client(c, clientConfig).Handshake() + c.Close() + + result := <-done + + if clientErr != nil { + if len(test.expectedClientError) == 0 { + t.Errorf("#%d: client error: %v", i, clientErr) + } else if got := clientErr.Error(); got != test.expectedClientError { + t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got) + } else { + test.verify(t, i, &result.cs) + } + } else if len(test.expectedClientError) > 0 { + t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError) + } else if err := result.err; err != nil { + t.Errorf("#%d: server error: %v", i, err) + } else { + test.verify(t, i, &result.cs) + } + } +} + +func TestRSAPSSKeyError(t *testing.T) { + // crypto/tls does not support the rsa_pss_pss_* SignatureSchemes. If support for + // public keys with OID RSASSA-PSS is added to crypto/x509, they will be misused with + // the rsa_pss_rsae_* SignatureSchemes. Assert that RSASSA-PSS certificates don't + // parse, or that they don't carry *rsa.PublicKey keys. + b, _ := pem.Decode([]byte(` +-----BEGIN CERTIFICATE----- +MIIDZTCCAhygAwIBAgIUCF2x0FyTgZG0CC9QTDjGWkB5vgEwPgYJKoZIhvcNAQEK +MDGgDTALBglghkgBZQMEAgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQC +AgDeMBIxEDAOBgNVBAMMB1JTQS1QU1MwHhcNMTgwNjI3MjI0NDM2WhcNMTgwNzI3 +MjI0NDM2WjASMRAwDgYDVQQDDAdSU0EtUFNTMIIBIDALBgkqhkiG9w0BAQoDggEP +ADCCAQoCggEBANxDm0f76JdI06YzsjB3AmmjIYkwUEGxePlafmIASFjDZl/elD0Z +/a7xLX468b0qGxLS5al7XCcEprSdsDR6DF5L520+pCbpfLyPOjuOvGmk9KzVX4x5 +b05YXYuXdsQ0Kjxcx2i3jjCday6scIhMJVgBZxTEyMj1thPQM14SHzKCd/m6HmCL +QmswpH2yMAAcBRWzRpp/vdH5DeOJEB3aelq7094no731mrLUCHRiZ1htq8BDB3ou +czwqgwspbqZ4dnMXl2MvfySQ5wJUxQwILbiuAKO2lVVPUbFXHE9pgtznNoPvKwQT +JNcX8ee8WIZc2SEGzofjk3NpjR+2ADB2u3sCAwEAAaNTMFEwHQYDVR0OBBYEFNEz +AdyJ2f+fU+vSCS6QzohnOnprMB8GA1UdIwQYMBaAFNEzAdyJ2f+fU+vSCS6Qzohn +OnprMA8GA1UdEwEB/wQFMAMBAf8wPgYJKoZIhvcNAQEKMDGgDTALBglghkgBZQME +AgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQCAgDeA4IBAQCjEdrR5aab +sZmCwrMeKidXgfkmWvfuLDE+TCbaqDZp7BMWcMQXT9O0UoUT5kqgKj2ARm2pEW0Z +H3Z1vj3bbds72qcDIJXp+l0fekyLGeCrX/CbgnMZXEP7+/+P416p34ChR1Wz4dU1 +KD3gdsUuTKKeMUog3plxlxQDhRQmiL25ygH1LmjLd6dtIt0GVRGr8lj3euVeprqZ +bZ3Uq5eLfsn8oPgfC57gpO6yiN+UURRTlK3bgYvLh4VWB3XXk9UaQZ7Mq1tpXjoD +HYFybkWzibkZp4WRo+Fa28rirH+/wHt0vfeN7UCceURZEx4JaxIIfe4ku7uDRhJi +RwBA9Xk1KBNF +-----END CERTIFICATE-----`)) + if b == nil { + t.Fatal("Failed to decode certificate") + } + cert, err := x509.ParseCertificate(b.Bytes) + if err != nil { + return + } + if _, ok := cert.PublicKey.(*rsa.PublicKey); ok { + t.Error("A RSASSA-PSS certificate was parsed like a PKCS#1 v1.5 one, and it will be mistakenly used with rsa_pss_rsae_* signature algorithms") + } +} + +func TestCloseClientConnectionOnIdleServer(t *testing.T) { + clientConn, serverConn := localPipe(t) + client := Client(clientConn, testConfig.Clone()) + go func() { + var b [1]byte + serverConn.Read(b[:]) + client.Close() + }() + client.SetWriteDeadline(time.Now().Add(time.Minute)) + err := client.Handshake() + if err != nil { + if err, ok := err.(net.Error); ok && err.Timeout() { + t.Errorf("Expected a closed network connection error but got '%s'", err.Error()) + } + } else { + t.Errorf("Error expected, but no error returned") + } +} + +func testDowngradeCanary(t *testing.T, clientVersion, serverVersion uint16) error { + defer func() { testingOnlyForceDowngradeCanary = false }() + testingOnlyForceDowngradeCanary = true + + clientConfig := testConfig.Clone() + clientConfig.MaxVersion = clientVersion + serverConfig := testConfig.Clone() + serverConfig.MaxVersion = serverVersion + _, _, err := testHandshake(t, clientConfig, serverConfig) + return err +} + +func TestDowngradeCanary(t *testing.T) { + if err := testDowngradeCanary(t, VersionTLS13, VersionTLS12); err == nil { + t.Errorf("downgrade from TLS 1.3 to TLS 1.2 was not detected") + } + if testing.Short() { + t.Skip("skipping the rest of the checks in short mode") + } + if err := testDowngradeCanary(t, VersionTLS13, VersionTLS11); err == nil { + t.Errorf("downgrade from TLS 1.3 to TLS 1.1 was not detected") + } + if err := testDowngradeCanary(t, VersionTLS13, VersionTLS10); err == nil { + t.Errorf("downgrade from TLS 1.3 to TLS 1.0 was not detected") + } + if err := testDowngradeCanary(t, VersionTLS12, VersionTLS11); err == nil { + t.Errorf("downgrade from TLS 1.2 to TLS 1.1 was not detected") + } + if err := testDowngradeCanary(t, VersionTLS12, VersionTLS10); err == nil { + t.Errorf("downgrade from TLS 1.2 to TLS 1.0 was not detected") + } + if err := testDowngradeCanary(t, VersionTLS13, VersionTLS13); err != nil { + t.Errorf("server unexpectedly sent downgrade canary for TLS 1.3") + } + if err := testDowngradeCanary(t, VersionTLS12, VersionTLS12); err != nil { + t.Errorf("client didn't ignore expected TLS 1.2 canary") + } + if err := testDowngradeCanary(t, VersionTLS11, VersionTLS11); err != nil { + t.Errorf("client unexpectedly reacted to a canary in TLS 1.1") + } + if err := testDowngradeCanary(t, VersionTLS10, VersionTLS10); err != nil { + t.Errorf("client unexpectedly reacted to a canary in TLS 1.0") + } +} + +func TestResumptionKeepsOCSPAndSCT(t *testing.T) { + t.Run("TLSv12", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS12) }) + t.Run("TLSv13", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS13) }) +} + +func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) { + issuer, err := x509.ParseCertificate(testRSACertificateIssuer) + if err != nil { + t.Fatalf("failed to parse test issuer") + } + roots := x509.NewCertPool() + roots.AddCert(issuer) + clientConfig := &Config{ + MaxVersion: ver, + ClientSessionCache: NewLRUClientSessionCache(32), + ServerName: "example.golang", + RootCAs: roots, + } + serverConfig := testConfig.Clone() + serverConfig.MaxVersion = ver + serverConfig.Certificates[0].OCSPStaple = []byte{1, 2, 3} + serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{4, 5, 6}} + + _, ccs, err := testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("handshake failed: %s", err) + } + // after a new session we expect to see OCSPResponse and + // SignedCertificateTimestamps populated as usual + if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) { + t.Errorf("client ConnectionState contained unexpected OCSPResponse: wanted %v, got %v", + serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse) + } + if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) { + t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps: wanted %v, got %v", + serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps) + } + + // if the server doesn't send any SCTs, repopulate the old SCTs + oldSCTs := serverConfig.Certificates[0].SignedCertificateTimestamps + serverConfig.Certificates[0].SignedCertificateTimestamps = nil + _, ccs, err = testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("handshake failed: %s", err) + } + if !ccs.DidResume { + t.Fatalf("expected session to be resumed") + } + // after a resumed session we also expect to see OCSPResponse + // and SignedCertificateTimestamps populated + if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) { + t.Errorf("client ConnectionState contained unexpected OCSPResponse after resumption: wanted %v, got %v", + serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse) + } + if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, oldSCTs) { + t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v", + oldSCTs, ccs.SignedCertificateTimestamps) + } + + // Only test overriding the SCTs for TLS 1.2, since in 1.3 + // the server won't send the message containing them + if ver == VersionTLS13 { + return + } + + // if the server changes the SCTs it sends, they should override the saved SCTs + serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{7, 8, 9}} + _, ccs, err = testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("handshake failed: %s", err) + } + if !ccs.DidResume { + t.Fatalf("expected session to be resumed") + } + if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) { + t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v", + serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps) + } +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_client_tls13.go b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_client_tls13.go new file mode 100644 index 0000000..daa5d97 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_client_tls13.go @@ -0,0 +1,685 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "crypto" + "crypto/hmac" + "crypto/rsa" + "errors" + "hash" + "sync/atomic" + "time" +) + +type clientHandshakeStateTLS13 struct { + c *Conn + serverHello *serverHelloMsg + hello *clientHelloMsg + ecdheParams ecdheParameters + + session *ClientSessionState + earlySecret []byte + binderKey []byte + + certReq *certificateRequestMsgTLS13 + usingPSK bool + sentDummyCCS bool + suite *cipherSuiteTLS13 + transcript hash.Hash + masterSecret []byte + trafficSecret []byte // client_application_traffic_secret_0 +} + +// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheParams, and, +// optionally, hs.session, hs.earlySecret and hs.binderKey to be set. +func (hs *clientHandshakeStateTLS13) handshake() error { + c := hs.c + + // The server must not select TLS 1.3 in a renegotiation. See RFC 8446, + // sections 4.1.2 and 4.1.3. + if c.handshakes > 0 { + c.sendAlert(alertProtocolVersion) + return errors.New("tls: server selected TLS 1.3 in a renegotiation") + } + + // Consistency check on the presence of a keyShare and its parameters. + if hs.ecdheParams == nil || len(hs.hello.keyShares) != 1 { + return c.sendAlert(alertInternalError) + } + + if err := hs.checkServerHelloOrHRR(); err != nil { + return err + } + + hs.transcript = hs.suite.hash.New() + hs.transcript.Write(hs.hello.marshal()) + + if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + if err := hs.processHelloRetryRequest(); err != nil { + return err + } + } + + hs.transcript.Write(hs.serverHello.marshal()) + + c.buffering = true + if err := hs.processServerHello(); err != nil { + return err + } + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + if err := hs.establishHandshakeKeys(); err != nil { + return err + } + if err := hs.readServerParameters(); err != nil { + return err + } + if err := hs.readServerCertificate(); err != nil { + return err + } + if err := hs.readServerFinished(); err != nil { + return err + } + if err := hs.sendClientCertificate(); err != nil { + return err + } + if err := hs.sendClientFinished(); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + + atomic.StoreUint32(&c.handshakeStatus, 1) + + return nil +} + +// checkServerHelloOrHRR does validity checks that apply to both ServerHello and +// HelloRetryRequest messages. It sets hs.suite. +func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error { + c := hs.c + + if hs.serverHello.supportedVersion == 0 { + c.sendAlert(alertMissingExtension) + return errors.New("tls: server selected TLS 1.3 using the legacy version field") + } + + if hs.serverHello.supportedVersion != VersionTLS13 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected an invalid version after a HelloRetryRequest") + } + + if hs.serverHello.vers != VersionTLS12 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an incorrect legacy version") + } + + if hs.serverHello.ocspStapling || + hs.serverHello.ticketSupported || + hs.serverHello.secureRenegotiationSupported || + len(hs.serverHello.secureRenegotiation) != 0 || + len(hs.serverHello.alpnProtocol) != 0 || + len(hs.serverHello.scts) != 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server sent a ServerHello extension forbidden in TLS 1.3") + } + + if !bytes.Equal(hs.hello.sessionId, hs.serverHello.sessionId) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server did not echo the legacy session ID") + } + + if hs.serverHello.compressionMethod != compressionNone { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported compression format") + } + + selectedSuite := mutualCipherSuiteTLS13(hs.hello.cipherSuites, hs.serverHello.cipherSuite) + if hs.suite != nil && selectedSuite != hs.suite { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server changed cipher suite after a HelloRetryRequest") + } + if selectedSuite == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server chose an unconfigured cipher suite") + } + hs.suite = selectedSuite + c.cipherSuite = hs.suite.id + + return nil +} + +// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility +// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. +func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error { + if hs.sentDummyCCS { + return nil + } + hs.sentDummyCCS = true + + _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + return err +} + +// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and +// resends hs.hello, and reads the new ServerHello into hs.serverHello. +func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { + c := hs.c + + // The first ClientHello gets double-hashed into the transcript upon a + // HelloRetryRequest. (The idea is that the server might offload transcript + // storage to the client in the cookie.) See RFC 8446, Section 4.4.1. + chHash := hs.transcript.Sum(nil) + hs.transcript.Reset() + hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) + hs.transcript.Write(chHash) + hs.transcript.Write(hs.serverHello.marshal()) + + // The only HelloRetryRequest extensions we support are key_share and + // cookie, and clients must abort the handshake if the HRR would not result + // in any change in the ClientHello. + if hs.serverHello.selectedGroup == 0 && hs.serverHello.cookie == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an unnecessary HelloRetryRequest message") + } + + if hs.serverHello.cookie != nil { + hs.hello.cookie = hs.serverHello.cookie + } + + if hs.serverHello.serverShare.group != 0 { + c.sendAlert(alertDecodeError) + return errors.New("tls: received malformed key_share extension") + } + + // If the server sent a key_share extension selecting a group, ensure it's + // a group we advertised but did not send a key share for, and send a key + // share for it this time. + if curveID := hs.serverHello.selectedGroup; curveID != 0 { + curveOK := false + for _, id := range hs.hello.supportedCurves { + if id == curveID { + curveOK = true + break + } + } + if !curveOK { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported group") + } + if hs.ecdheParams.CurveID() == curveID { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") + } + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + c.sendAlert(alertInternalError) + return errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err := generateECDHEParameters(c.config.rand(), curveID) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + hs.ecdheParams = params + hs.hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} + } + + hs.hello.raw = nil + if len(hs.hello.pskIdentities) > 0 { + pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) + if pskSuite == nil { + return c.sendAlert(alertInternalError) + } + if pskSuite.hash == hs.suite.hash { + // Update binders and obfuscated_ticket_age. + ticketAge := uint32(c.config.time().Sub(hs.session.receivedAt) / time.Millisecond) + hs.hello.pskIdentities[0].obfuscatedTicketAge = ticketAge + hs.session.ageAdd + + transcript := hs.suite.hash.New() + transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) + transcript.Write(chHash) + transcript.Write(hs.serverHello.marshal()) + transcript.Write(hs.hello.marshalWithoutBinders()) + pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)} + hs.hello.updateBinders(pskBinders) + } else { + // Server selected a cipher suite incompatible with the PSK. + hs.hello.pskIdentities = nil + hs.hello.pskBinders = nil + } + } + + hs.transcript.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + serverHello, ok := msg.(*serverHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverHello, msg) + } + hs.serverHello = serverHello + + if err := hs.checkServerHelloOrHRR(); err != nil { + return err + } + + return nil +} + +func (hs *clientHandshakeStateTLS13) processServerHello() error { + c := hs.c + + if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: server sent two HelloRetryRequest messages") + } + + if len(hs.serverHello.cookie) != 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server sent a cookie in a normal ServerHello") + } + + if hs.serverHello.selectedGroup != 0 { + c.sendAlert(alertDecodeError) + return errors.New("tls: malformed key_share extension") + } + + if hs.serverHello.serverShare.group == 0 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server did not send a key share") + } + if hs.serverHello.serverShare.group != hs.ecdheParams.CurveID() { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported group") + } + + if !hs.serverHello.selectedIdentityPresent { + return nil + } + + if int(hs.serverHello.selectedIdentity) >= len(hs.hello.pskIdentities) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected an invalid PSK") + } + + if len(hs.hello.pskIdentities) != 1 || hs.session == nil { + return c.sendAlert(alertInternalError) + } + pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) + if pskSuite == nil { + return c.sendAlert(alertInternalError) + } + if pskSuite.hash != hs.suite.hash { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected an invalid PSK and cipher suite pair") + } + + hs.usingPSK = true + c.didResume = true + c.peerCertificates = hs.session.serverCertificates + c.verifiedChains = hs.session.verifiedChains + c.ocspResponse = hs.session.ocspResponse + c.scts = hs.session.scts + return nil +} + +func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { + c := hs.c + + sharedKey := hs.ecdheParams.SharedKey(hs.serverHello.serverShare.data) + if sharedKey == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid server key share") + } + + earlySecret := hs.earlySecret + if !hs.usingPSK { + earlySecret = hs.suite.extract(nil, nil) + } + handshakeSecret := hs.suite.extract(sharedKey, + hs.suite.deriveSecret(earlySecret, "derived", nil)) + + clientSecret := hs.suite.deriveSecret(handshakeSecret, + clientHandshakeTrafficLabel, hs.transcript) + c.out.setTrafficSecret(hs.suite, clientSecret) + serverSecret := hs.suite.deriveSecret(handshakeSecret, + serverHandshakeTrafficLabel, hs.transcript) + c.in.setTrafficSecret(hs.suite, serverSecret) + + err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.hello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + hs.masterSecret = hs.suite.extract(nil, + hs.suite.deriveSecret(handshakeSecret, "derived", nil)) + + return nil +} + +func (hs *clientHandshakeStateTLS13) readServerParameters() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + + encryptedExtensions, ok := msg.(*encryptedExtensionsMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(encryptedExtensions, msg) + } + hs.transcript.Write(encryptedExtensions.marshal()) + + if encryptedExtensions.alpnProtocol != "" { + if len(hs.hello.alpnProtocols) == 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server advertised unrequested ALPN extension") + } + if mutualProtocol([]string{encryptedExtensions.alpnProtocol}, hs.hello.alpnProtocols) == "" { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server selected unadvertised ALPN protocol") + } + c.clientProtocol = encryptedExtensions.alpnProtocol + } + + return nil +} + +func (hs *clientHandshakeStateTLS13) readServerCertificate() error { + c := hs.c + + // Either a PSK or a certificate is always used, but not both. + // See RFC 8446, Section 4.1.1. + if hs.usingPSK { + // Make sure the connection is still being verified whether or not this + // is a resumption. Resumptions currently don't reverify certificates so + // they don't call verifyServerCertificate. See Issue 31641. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + return nil + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + certReq, ok := msg.(*certificateRequestMsgTLS13) + if ok { + hs.transcript.Write(certReq.marshal()) + + hs.certReq = certReq + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + certMsg, ok := msg.(*certificateMsgTLS13) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + if len(certMsg.certificate.Certificate) == 0 { + c.sendAlert(alertDecodeError) + return errors.New("tls: received empty certificates message") + } + hs.transcript.Write(certMsg.marshal()) + + c.scts = certMsg.certificate.SignedCertificateTimestamps + c.ocspResponse = certMsg.certificate.OCSPStaple + + if err := c.verifyServerCertificate(certMsg.certificate.Certificate); err != nil { + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) + } + + // See RFC 8446, Section 4.4.3. + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: certificate used with invalid signature algorithm") + } + sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: certificate used with invalid signature algorithm") + } + signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) + if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, + sigHash, signed, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid signature by the server certificate: " + err.Error()) + } + + hs.transcript.Write(certVerify.marshal()) + + return nil +} + +func (hs *clientHandshakeStateTLS13) readServerFinished() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + + finished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(finished, msg) + } + + expectedMAC := hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) + if !hmac.Equal(expectedMAC, finished.verifyData) { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid server finished hash") + } + + hs.transcript.Write(finished.marshal()) + + // Derive secrets that take context through the server Finished. + + hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, + clientApplicationTrafficLabel, hs.transcript) + serverSecret := hs.suite.deriveSecret(hs.masterSecret, + serverApplicationTrafficLabel, hs.transcript) + c.in.setTrafficSecret(hs.suite, serverSecret) + + err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.hello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) + + return nil +} + +func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { + c := hs.c + + if hs.certReq == nil { + return nil + } + + cert, err := c.getClientCertificate(&CertificateRequestInfo{ + AcceptableCAs: hs.certReq.certificateAuthorities, + SignatureSchemes: hs.certReq.supportedSignatureAlgorithms, + Version: c.vers, + }) + if err != nil { + return err + } + + certMsg := new(certificateMsgTLS13) + + certMsg.certificate = *cert + certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0 + certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0 + + hs.transcript.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + // If we sent an empty certificate message, skip the CertificateVerify. + if len(cert.Certificate) == 0 { + return nil + } + + certVerifyMsg := new(certificateVerifyMsg) + certVerifyMsg.hasSignatureAlgorithm = true + + certVerifyMsg.signatureAlgorithm, err = selectSignatureScheme(c.vers, cert, hs.certReq.supportedSignatureAlgorithms) + if err != nil { + // getClientCertificate returned a certificate incompatible with the + // CertificateRequestInfo supported signature algorithms. + c.sendAlert(alertHandshakeFailure) + return err + } + + sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerifyMsg.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + + signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) + if err != nil { + c.sendAlert(alertInternalError) + return errors.New("tls: failed to sign handshake: " + err.Error()) + } + certVerifyMsg.signature = sig + + hs.transcript.Write(certVerifyMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *clientHandshakeStateTLS13) sendClientFinished() error { + c := hs.c + + finished := &finishedMsg{ + verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), + } + + hs.transcript.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + c.out.setTrafficSecret(hs.suite, hs.trafficSecret) + + if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil { + c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, + resumptionLabel, hs.transcript) + } + + return nil +} + +func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error { + if !c.isClient { + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: received new session ticket from a client") + } + + if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { + return nil + } + + // See RFC 8446, Section 4.6.1. + if msg.lifetime == 0 { + return nil + } + lifetime := time.Duration(msg.lifetime) * time.Second + if lifetime > maxSessionTicketLifetime { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: received a session ticket with invalid lifetime") + } + + cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) + if cipherSuite == nil || c.resumptionSecret == nil { + return c.sendAlert(alertInternalError) + } + + // Save the resumption_master_secret and nonce instead of deriving the PSK + // to do the least amount of work on NewSessionTicket messages before we + // know if the ticket will be used. Forward secrecy of resumed connections + // is guaranteed by the requirement for pskModeDHE. + session := &ClientSessionState{ + sessionTicket: msg.label, + vers: c.vers, + cipherSuite: c.cipherSuite, + masterSecret: c.resumptionSecret, + serverCertificates: c.peerCertificates, + verifiedChains: c.verifiedChains, + receivedAt: c.config.time(), + nonce: msg.nonce, + useBy: c.config.time().Add(lifetime), + ageAdd: msg.ageAdd, + ocspResponse: c.ocspResponse, + scts: c.scts, + } + + cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config) + c.config.ClientSessionCache.Put(cacheKey, session) + + return nil +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_messages.go b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_messages.go new file mode 100644 index 0000000..b5f81e4 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_messages.go @@ -0,0 +1,1809 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "fmt" + "strings" + + "golang.org/x/crypto/cryptobyte" +) + +// The marshalingFunction type is an adapter to allow the use of ordinary +// functions as cryptobyte.MarshalingValue. +type marshalingFunction func(b *cryptobyte.Builder) error + +func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error { + return f(b) +} + +// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If +// the length of the sequence is not the value specified, it produces an error. +func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) { + b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error { + if len(v) != n { + return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v)) + } + b.AddBytes(v) + return nil + })) +} + +// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder. +func addUint64(b *cryptobyte.Builder, v uint64) { + b.AddUint32(uint32(v >> 32)) + b.AddUint32(uint32(v)) +} + +// readUint64 decodes a big-endian, 64-bit value into out and advances over it. +// It reports whether the read was successful. +func readUint64(s *cryptobyte.String, out *uint64) bool { + var hi, lo uint32 + if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) { + return false + } + *out = uint64(hi)<<32 | uint64(lo) + return true +} + +// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out)) +} + +// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out)) +} + +// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out)) +} + +type clientHelloMsg struct { + raw []byte + vers uint16 + random []byte + sessionId []byte + cipherSuites []uint16 + compressionMethods []uint8 + serverName string + ocspStapling bool + supportedCurves []CurveID + supportedPoints []uint8 + ticketSupported bool + sessionTicket []uint8 + supportedSignatureAlgorithms []SignatureScheme + supportedSignatureAlgorithmsCert []SignatureScheme + secureRenegotiationSupported bool + secureRenegotiation []byte + alpnProtocols []string + scts bool + supportedVersions []uint16 + cookie []byte + keyShares []keyShare + earlyData bool + pskModes []uint8 + pskIdentities []pskIdentity + pskBinders [][]byte +} + +func (m *clientHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeClientHello) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.vers) + addBytesWithLength(b, m.random, 32) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionId) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, suite := range m.cipherSuites { + b.AddUint16(suite) + } + }) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.compressionMethods) + }) + + // If extensions aren't present, omit them. + var extensionsPresent bool + bWithoutExtensions := *b + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if len(m.serverName) > 0 { + // RFC 6066, Section 3 + b.AddUint16(extensionServerName) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(0) // name_type = host_name + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.serverName)) + }) + }) + }) + } + if m.ocspStapling { + // RFC 4366, Section 3.6 + b.AddUint16(extensionStatusRequest) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(1) // status_type = ocsp + b.AddUint16(0) // empty responder_id_list + b.AddUint16(0) // empty request_extensions + }) + } + if len(m.supportedCurves) > 0 { + // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 + b.AddUint16(extensionSupportedCurves) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, curve := range m.supportedCurves { + b.AddUint16(uint16(curve)) + } + }) + }) + } + if len(m.supportedPoints) > 0 { + // RFC 4492, Section 5.1.2 + b.AddUint16(extensionSupportedPoints) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.supportedPoints) + }) + }) + } + if m.ticketSupported { + // RFC 5077, Section 3.2 + b.AddUint16(extensionSessionTicket) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionTicket) + }) + } + if len(m.supportedSignatureAlgorithms) > 0 { + // RFC 5246, Section 7.4.1.4.1 + b.AddUint16(extensionSignatureAlgorithms) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithms { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.supportedSignatureAlgorithmsCert) > 0 { + // RFC 8446, Section 4.2.3 + b.AddUint16(extensionSignatureAlgorithmsCert) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if m.secureRenegotiationSupported { + // RFC 5746, Section 3.2 + b.AddUint16(extensionRenegotiationInfo) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocols) > 0 { + // RFC 7301, Section 3.1 + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, proto := range m.alpnProtocols { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(proto)) + }) + } + }) + }) + } + if m.scts { + // RFC 6962, Section 3.3.1 + b.AddUint16(extensionSCT) + b.AddUint16(0) // empty extension_data + } + if len(m.supportedVersions) > 0 { + // RFC 8446, Section 4.2.1 + b.AddUint16(extensionSupportedVersions) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + for _, vers := range m.supportedVersions { + b.AddUint16(vers) + } + }) + }) + } + if len(m.cookie) > 0 { + // RFC 8446, Section 4.2.2 + b.AddUint16(extensionCookie) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.cookie) + }) + }) + } + if len(m.keyShares) > 0 { + // RFC 8446, Section 4.2.8 + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, ks := range m.keyShares { + b.AddUint16(uint16(ks.group)) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ks.data) + }) + } + }) + }) + } + if m.earlyData { + // RFC 8446, Section 4.2.10 + b.AddUint16(extensionEarlyData) + b.AddUint16(0) // empty extension_data + } + if len(m.pskModes) > 0 { + // RFC 8446, Section 4.2.9 + b.AddUint16(extensionPSKModes) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.pskModes) + }) + }) + } + if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension + // RFC 8446, Section 4.2.11 + b.AddUint16(extensionPreSharedKey) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, psk := range m.pskIdentities { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(psk.label) + }) + b.AddUint32(psk.obfuscatedTicketAge) + } + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, binder := range m.pskBinders { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(binder) + }) + } + }) + }) + } + + extensionsPresent = len(b.BytesOrPanic()) > 2 + }) + + if !extensionsPresent { + *b = bWithoutExtensions + } + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +// marshalWithoutBinders returns the ClientHello through the +// PreSharedKeyExtension.identities field, according to RFC 8446, Section +// 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length. +func (m *clientHelloMsg) marshalWithoutBinders() []byte { + bindersLen := 2 // uint16 length prefix + for _, binder := range m.pskBinders { + bindersLen += 1 // uint8 length prefix + bindersLen += len(binder) + } + + fullMessage := m.marshal() + return fullMessage[:len(fullMessage)-bindersLen] +} + +// updateBinders updates the m.pskBinders field, if necessary updating the +// cached marshaled representation. The supplied binders must have the same +// length as the current m.pskBinders. +func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) { + if len(pskBinders) != len(m.pskBinders) { + panic("tls: internal error: pskBinders length mismatch") + } + for i := range m.pskBinders { + if len(pskBinders[i]) != len(m.pskBinders[i]) { + panic("tls: internal error: pskBinders length mismatch") + } + } + m.pskBinders = pskBinders + if m.raw != nil { + lenWithoutBinders := len(m.marshalWithoutBinders()) + // TODO(filippo): replace with NewFixedBuilder once CL 148882 is imported. + b := cryptobyte.NewBuilder(m.raw[:lenWithoutBinders]) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, binder := range m.pskBinders { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(binder) + }) + } + }) + if len(b.BytesOrPanic()) != len(m.raw) { + panic("tls: internal error: failed to update binders") + } + } +} + +func (m *clientHelloMsg) unmarshal(data []byte) bool { + *m = clientHelloMsg{raw: data} + s := cryptobyte.String(data) + + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || + !readUint8LengthPrefixed(&s, &m.sessionId) { + return false + } + + var cipherSuites cryptobyte.String + if !s.ReadUint16LengthPrefixed(&cipherSuites) { + return false + } + m.cipherSuites = []uint16{} + m.secureRenegotiationSupported = false + for !cipherSuites.Empty() { + var suite uint16 + if !cipherSuites.ReadUint16(&suite) { + return false + } + if suite == scsvRenegotiation { + m.secureRenegotiationSupported = true + } + m.cipherSuites = append(m.cipherSuites, suite) + } + + if !readUint8LengthPrefixed(&s, &m.compressionMethods) { + return false + } + + if s.Empty() { + // ClientHello is optionally followed by extension data + return true + } + + var extensions cryptobyte.String + if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionServerName: + // RFC 6066, Section 3 + var nameList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() { + return false + } + for !nameList.Empty() { + var nameType uint8 + var serverName cryptobyte.String + if !nameList.ReadUint8(&nameType) || + !nameList.ReadUint16LengthPrefixed(&serverName) || + serverName.Empty() { + return false + } + if nameType != 0 { + continue + } + if len(m.serverName) != 0 { + // Multiple names of the same name_type are prohibited. + return false + } + m.serverName = string(serverName) + // An SNI value may not include a trailing dot. + if strings.HasSuffix(m.serverName, ".") { + return false + } + } + case extensionStatusRequest: + // RFC 4366, Section 3.6 + var statusType uint8 + var ignored cryptobyte.String + if !extData.ReadUint8(&statusType) || + !extData.ReadUint16LengthPrefixed(&ignored) || + !extData.ReadUint16LengthPrefixed(&ignored) { + return false + } + m.ocspStapling = statusType == statusTypeOCSP + case extensionSupportedCurves: + // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 + var curves cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() { + return false + } + for !curves.Empty() { + var curve uint16 + if !curves.ReadUint16(&curve) { + return false + } + m.supportedCurves = append(m.supportedCurves, CurveID(curve)) + } + case extensionSupportedPoints: + // RFC 4492, Section 5.1.2 + if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || + len(m.supportedPoints) == 0 { + return false + } + case extensionSessionTicket: + // RFC 5077, Section 3.2 + m.ticketSupported = true + extData.ReadBytes(&m.sessionTicket, len(extData)) + case extensionSignatureAlgorithms: + // RFC 5246, Section 7.4.1.4.1 + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithms = append( + m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) + } + case extensionSignatureAlgorithmsCert: + // RFC 8446, Section 4.2.3 + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithmsCert = append( + m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) + } + case extensionRenegotiationInfo: + // RFC 5746, Section 3.2 + if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { + return false + } + m.secureRenegotiationSupported = true + case extensionALPN: + // RFC 7301, Section 3.1 + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + for !protoList.Empty() { + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() { + return false + } + m.alpnProtocols = append(m.alpnProtocols, string(proto)) + } + case extensionSCT: + // RFC 6962, Section 3.3.1 + m.scts = true + case extensionSupportedVersions: + // RFC 8446, Section 4.2.1 + var versList cryptobyte.String + if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() { + return false + } + for !versList.Empty() { + var vers uint16 + if !versList.ReadUint16(&vers) { + return false + } + m.supportedVersions = append(m.supportedVersions, vers) + } + case extensionCookie: + // RFC 8446, Section 4.2.2 + if !readUint16LengthPrefixed(&extData, &m.cookie) || + len(m.cookie) == 0 { + return false + } + case extensionKeyShare: + // RFC 8446, Section 4.2.8 + var clientShares cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&clientShares) { + return false + } + for !clientShares.Empty() { + var ks keyShare + if !clientShares.ReadUint16((*uint16)(&ks.group)) || + !readUint16LengthPrefixed(&clientShares, &ks.data) || + len(ks.data) == 0 { + return false + } + m.keyShares = append(m.keyShares, ks) + } + case extensionEarlyData: + // RFC 8446, Section 4.2.10 + m.earlyData = true + case extensionPSKModes: + // RFC 8446, Section 4.2.9 + if !readUint8LengthPrefixed(&extData, &m.pskModes) { + return false + } + case extensionPreSharedKey: + // RFC 8446, Section 4.2.11 + if !extensions.Empty() { + return false // pre_shared_key must be the last extension + } + var identities cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() { + return false + } + for !identities.Empty() { + var psk pskIdentity + if !readUint16LengthPrefixed(&identities, &psk.label) || + !identities.ReadUint32(&psk.obfuscatedTicketAge) || + len(psk.label) == 0 { + return false + } + m.pskIdentities = append(m.pskIdentities, psk) + } + var binders cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() { + return false + } + for !binders.Empty() { + var binder []byte + if !readUint8LengthPrefixed(&binders, &binder) || + len(binder) == 0 { + return false + } + m.pskBinders = append(m.pskBinders, binder) + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type serverHelloMsg struct { + raw []byte + vers uint16 + random []byte + sessionId []byte + cipherSuite uint16 + compressionMethod uint8 + ocspStapling bool + ticketSupported bool + secureRenegotiationSupported bool + secureRenegotiation []byte + alpnProtocol string + scts [][]byte + supportedVersion uint16 + serverShare keyShare + selectedIdentityPresent bool + selectedIdentity uint16 + supportedPoints []uint8 + + // HelloRetryRequest extensions + cookie []byte + selectedGroup CurveID +} + +func (m *serverHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeServerHello) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.vers) + addBytesWithLength(b, m.random, 32) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionId) + }) + b.AddUint16(m.cipherSuite) + b.AddUint8(m.compressionMethod) + + // If extensions aren't present, omit them. + var extensionsPresent bool + bWithoutExtensions := *b + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.ocspStapling { + b.AddUint16(extensionStatusRequest) + b.AddUint16(0) // empty extension_data + } + if m.ticketSupported { + b.AddUint16(extensionSessionTicket) + b.AddUint16(0) // empty extension_data + } + if m.secureRenegotiationSupported { + b.AddUint16(extensionRenegotiationInfo) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocol) > 0 { + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.alpnProtocol)) + }) + }) + }) + } + if len(m.scts) > 0 { + b.AddUint16(extensionSCT) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sct := range m.scts { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(sct) + }) + } + }) + }) + } + if m.supportedVersion != 0 { + b.AddUint16(extensionSupportedVersions) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.supportedVersion) + }) + } + if m.serverShare.group != 0 { + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(uint16(m.serverShare.group)) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.serverShare.data) + }) + }) + } + if m.selectedIdentityPresent { + b.AddUint16(extensionPreSharedKey) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.selectedIdentity) + }) + } + + if len(m.cookie) > 0 { + b.AddUint16(extensionCookie) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.cookie) + }) + }) + } + if m.selectedGroup != 0 { + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(uint16(m.selectedGroup)) + }) + } + if len(m.supportedPoints) > 0 { + b.AddUint16(extensionSupportedPoints) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.supportedPoints) + }) + }) + } + + extensionsPresent = len(b.BytesOrPanic()) > 2 + }) + + if !extensionsPresent { + *b = bWithoutExtensions + } + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *serverHelloMsg) unmarshal(data []byte) bool { + *m = serverHelloMsg{raw: data} + s := cryptobyte.String(data) + + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || + !readUint8LengthPrefixed(&s, &m.sessionId) || + !s.ReadUint16(&m.cipherSuite) || + !s.ReadUint8(&m.compressionMethod) { + return false + } + + if s.Empty() { + // ServerHello is optionally followed by extension data + return true + } + + var extensions cryptobyte.String + if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionStatusRequest: + m.ocspStapling = true + case extensionSessionTicket: + m.ticketSupported = true + case extensionRenegotiationInfo: + if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { + return false + } + m.secureRenegotiationSupported = true + case extensionALPN: + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || + proto.Empty() || !protoList.Empty() { + return false + } + m.alpnProtocol = string(proto) + case extensionSCT: + var sctList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { + return false + } + for !sctList.Empty() { + var sct []byte + if !readUint16LengthPrefixed(&sctList, &sct) || + len(sct) == 0 { + return false + } + m.scts = append(m.scts, sct) + } + case extensionSupportedVersions: + if !extData.ReadUint16(&m.supportedVersion) { + return false + } + case extensionCookie: + if !readUint16LengthPrefixed(&extData, &m.cookie) || + len(m.cookie) == 0 { + return false + } + case extensionKeyShare: + // This extension has different formats in SH and HRR, accept either + // and let the handshake logic decide. See RFC 8446, Section 4.2.8. + if len(extData) == 2 { + if !extData.ReadUint16((*uint16)(&m.selectedGroup)) { + return false + } + } else { + if !extData.ReadUint16((*uint16)(&m.serverShare.group)) || + !readUint16LengthPrefixed(&extData, &m.serverShare.data) { + return false + } + } + case extensionPreSharedKey: + m.selectedIdentityPresent = true + if !extData.ReadUint16(&m.selectedIdentity) { + return false + } + case extensionSupportedPoints: + // RFC 4492, Section 5.1.2 + if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || + len(m.supportedPoints) == 0 { + return false + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type encryptedExtensionsMsg struct { + raw []byte + alpnProtocol string +} + +func (m *encryptedExtensionsMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeEncryptedExtensions) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if len(m.alpnProtocol) > 0 { + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.alpnProtocol)) + }) + }) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { + *m = encryptedExtensionsMsg{raw: data} + s := cryptobyte.String(data) + + var extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionALPN: + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || + proto.Empty() || !protoList.Empty() { + return false + } + m.alpnProtocol = string(proto) + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type endOfEarlyDataMsg struct{} + +func (m *endOfEarlyDataMsg) marshal() []byte { + x := make([]byte, 4) + x[0] = typeEndOfEarlyData + return x +} + +func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + +type keyUpdateMsg struct { + raw []byte + updateRequested bool +} + +func (m *keyUpdateMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeKeyUpdate) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + if m.updateRequested { + b.AddUint8(1) + } else { + b.AddUint8(0) + } + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *keyUpdateMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + + var updateRequested uint8 + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8(&updateRequested) || !s.Empty() { + return false + } + switch updateRequested { + case 0: + m.updateRequested = false + case 1: + m.updateRequested = true + default: + return false + } + return true +} + +type newSessionTicketMsgTLS13 struct { + raw []byte + lifetime uint32 + ageAdd uint32 + nonce []byte + label []byte + maxEarlyData uint32 +} + +func (m *newSessionTicketMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeNewSessionTicket) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint32(m.lifetime) + b.AddUint32(m.ageAdd) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.nonce) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.label) + }) + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.maxEarlyData > 0 { + b.AddUint16(extensionEarlyData) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint32(m.maxEarlyData) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { + *m = newSessionTicketMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint32(&m.lifetime) || + !s.ReadUint32(&m.ageAdd) || + !readUint8LengthPrefixed(&s, &m.nonce) || + !readUint16LengthPrefixed(&s, &m.label) || + !s.ReadUint16LengthPrefixed(&extensions) || + !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionEarlyData: + if !extData.ReadUint32(&m.maxEarlyData) { + return false + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type certificateRequestMsgTLS13 struct { + raw []byte + ocspStapling bool + scts bool + supportedSignatureAlgorithms []SignatureScheme + supportedSignatureAlgorithmsCert []SignatureScheme + certificateAuthorities [][]byte +} + +func (m *certificateRequestMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificateRequest) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + // certificate_request_context (SHALL be zero length unless used for + // post-handshake authentication) + b.AddUint8(0) + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.ocspStapling { + b.AddUint16(extensionStatusRequest) + b.AddUint16(0) // empty extension_data + } + if m.scts { + // RFC 8446, Section 4.4.2.1 makes no mention of + // signed_certificate_timestamp in CertificateRequest, but + // "Extensions in the Certificate message from the client MUST + // correspond to extensions in the CertificateRequest message + // from the server." and it appears in the table in Section 4.2. + b.AddUint16(extensionSCT) + b.AddUint16(0) // empty extension_data + } + if len(m.supportedSignatureAlgorithms) > 0 { + b.AddUint16(extensionSignatureAlgorithms) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithms { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.supportedSignatureAlgorithmsCert) > 0 { + b.AddUint16(extensionSignatureAlgorithmsCert) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.certificateAuthorities) > 0 { + b.AddUint16(extensionCertificateAuthorities) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, ca := range m.certificateAuthorities { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ca) + }) + } + }) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { + *m = certificateRequestMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var context, extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || + !s.ReadUint16LengthPrefixed(&extensions) || + !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionStatusRequest: + m.ocspStapling = true + case extensionSCT: + m.scts = true + case extensionSignatureAlgorithms: + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithms = append( + m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) + } + case extensionSignatureAlgorithmsCert: + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithmsCert = append( + m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) + } + case extensionCertificateAuthorities: + var auths cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() { + return false + } + for !auths.Empty() { + var ca []byte + if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 { + return false + } + m.certificateAuthorities = append(m.certificateAuthorities, ca) + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type certificateMsg struct { + raw []byte + certificates [][]byte +} + +func (m *certificateMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + var i int + for _, slice := range m.certificates { + i += len(slice) + } + + length := 3 + 3*len(m.certificates) + i + x = make([]byte, 4+length) + x[0] = typeCertificate + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + certificateOctets := length - 3 + x[4] = uint8(certificateOctets >> 16) + x[5] = uint8(certificateOctets >> 8) + x[6] = uint8(certificateOctets) + + y := x[7:] + for _, slice := range m.certificates { + y[0] = uint8(len(slice) >> 16) + y[1] = uint8(len(slice) >> 8) + y[2] = uint8(len(slice)) + copy(y[3:], slice) + y = y[3+len(slice):] + } + + m.raw = x + return +} + +func (m *certificateMsg) unmarshal(data []byte) bool { + if len(data) < 7 { + return false + } + + m.raw = data + certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]) + if uint32(len(data)) != certsLen+7 { + return false + } + + numCerts := 0 + d := data[7:] + for certsLen > 0 { + if len(d) < 4 { + return false + } + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + if uint32(len(d)) < 3+certLen { + return false + } + d = d[3+certLen:] + certsLen -= 3 + certLen + numCerts++ + } + + m.certificates = make([][]byte, numCerts) + d = data[7:] + for i := 0; i < numCerts; i++ { + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + m.certificates[i] = d[3 : 3+certLen] + d = d[3+certLen:] + } + + return true +} + +type certificateMsgTLS13 struct { + raw []byte + certificate Certificate + ocspStapling bool + scts bool +} + +func (m *certificateMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificate) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(0) // certificate_request_context + + certificate := m.certificate + if !m.ocspStapling { + certificate.OCSPStaple = nil + } + if !m.scts { + certificate.SignedCertificateTimestamps = nil + } + marshalCertificate(b, certificate) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + for i, cert := range certificate.Certificate { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(cert) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if i > 0 { + // This library only supports OCSP and SCT for leaf certificates. + return + } + if certificate.OCSPStaple != nil { + b.AddUint16(extensionStatusRequest) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(statusTypeOCSP) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(certificate.OCSPStaple) + }) + }) + } + if certificate.SignedCertificateTimestamps != nil { + b.AddUint16(extensionSCT) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sct := range certificate.SignedCertificateTimestamps { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(sct) + }) + } + }) + }) + } + }) + } + }) +} + +func (m *certificateMsgTLS13) unmarshal(data []byte) bool { + *m = certificateMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var context cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || + !unmarshalCertificate(&s, &m.certificate) || + !s.Empty() { + return false + } + + m.scts = m.certificate.SignedCertificateTimestamps != nil + m.ocspStapling = m.certificate.OCSPStaple != nil + + return true +} + +func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool { + var certList cryptobyte.String + if !s.ReadUint24LengthPrefixed(&certList) { + return false + } + for !certList.Empty() { + var cert []byte + var extensions cryptobyte.String + if !readUint24LengthPrefixed(&certList, &cert) || + !certList.ReadUint16LengthPrefixed(&extensions) { + return false + } + certificate.Certificate = append(certificate.Certificate, cert) + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + if len(certificate.Certificate) > 1 { + // This library only supports OCSP and SCT for leaf certificates. + continue + } + + switch extension { + case extensionStatusRequest: + var statusType uint8 + if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP || + !readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) || + len(certificate.OCSPStaple) == 0 { + return false + } + case extensionSCT: + var sctList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { + return false + } + for !sctList.Empty() { + var sct []byte + if !readUint16LengthPrefixed(&sctList, &sct) || + len(sct) == 0 { + return false + } + certificate.SignedCertificateTimestamps = append( + certificate.SignedCertificateTimestamps, sct) + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + } + return true +} + +type serverKeyExchangeMsg struct { + raw []byte + key []byte +} + +func (m *serverKeyExchangeMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + length := len(m.key) + x := make([]byte, length+4) + x[0] = typeServerKeyExchange + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + copy(x[4:], m.key) + + m.raw = x + return x +} + +func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 4 { + return false + } + m.key = data[4:] + return true +} + +type certificateStatusMsg struct { + raw []byte + response []byte +} + +func (m *certificateStatusMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificateStatus) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(statusTypeOCSP) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.response) + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateStatusMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + + var statusType uint8 + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8(&statusType) || statusType != statusTypeOCSP || + !readUint24LengthPrefixed(&s, &m.response) || + len(m.response) == 0 || !s.Empty() { + return false + } + return true +} + +type serverHelloDoneMsg struct{} + +func (m *serverHelloDoneMsg) marshal() []byte { + x := make([]byte, 4) + x[0] = typeServerHelloDone + return x +} + +func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + +type clientKeyExchangeMsg struct { + raw []byte + ciphertext []byte +} + +func (m *clientKeyExchangeMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + length := len(m.ciphertext) + x := make([]byte, length+4) + x[0] = typeClientKeyExchange + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + copy(x[4:], m.ciphertext) + + m.raw = x + return x +} + +func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 4 { + return false + } + l := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if l != len(data)-4 { + return false + } + m.ciphertext = data[4:] + return true +} + +type finishedMsg struct { + raw []byte + verifyData []byte +} + +func (m *finishedMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeFinished) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.verifyData) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *finishedMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + return s.Skip(1) && + readUint24LengthPrefixed(&s, &m.verifyData) && + s.Empty() +} + +type certificateRequestMsg struct { + raw []byte + // hasSignatureAlgorithm indicates whether this message includes a list of + // supported signature algorithms. This change was introduced with TLS 1.2. + hasSignatureAlgorithm bool + + certificateTypes []byte + supportedSignatureAlgorithms []SignatureScheme + certificateAuthorities [][]byte +} + +func (m *certificateRequestMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See RFC 4346, Section 7.4.4. + length := 1 + len(m.certificateTypes) + 2 + casLength := 0 + for _, ca := range m.certificateAuthorities { + casLength += 2 + len(ca) + } + length += casLength + + if m.hasSignatureAlgorithm { + length += 2 + 2*len(m.supportedSignatureAlgorithms) + } + + x = make([]byte, 4+length) + x[0] = typeCertificateRequest + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + x[4] = uint8(len(m.certificateTypes)) + + copy(x[5:], m.certificateTypes) + y := x[5+len(m.certificateTypes):] + + if m.hasSignatureAlgorithm { + n := len(m.supportedSignatureAlgorithms) * 2 + y[0] = uint8(n >> 8) + y[1] = uint8(n) + y = y[2:] + for _, sigAlgo := range m.supportedSignatureAlgorithms { + y[0] = uint8(sigAlgo >> 8) + y[1] = uint8(sigAlgo) + y = y[2:] + } + } + + y[0] = uint8(casLength >> 8) + y[1] = uint8(casLength) + y = y[2:] + for _, ca := range m.certificateAuthorities { + y[0] = uint8(len(ca) >> 8) + y[1] = uint8(len(ca)) + y = y[2:] + copy(y, ca) + y = y[len(ca):] + } + + m.raw = x + return +} + +func (m *certificateRequestMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 5 { + return false + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return false + } + + numCertTypes := int(data[4]) + data = data[5:] + if numCertTypes == 0 || len(data) <= numCertTypes { + return false + } + + m.certificateTypes = make([]byte, numCertTypes) + if copy(m.certificateTypes, data) != numCertTypes { + return false + } + + data = data[numCertTypes:] + + if m.hasSignatureAlgorithm { + if len(data) < 2 { + return false + } + sigAndHashLen := uint16(data[0])<<8 | uint16(data[1]) + data = data[2:] + if sigAndHashLen&1 != 0 { + return false + } + if len(data) < int(sigAndHashLen) { + return false + } + numSigAlgos := sigAndHashLen / 2 + m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos) + for i := range m.supportedSignatureAlgorithms { + m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1]) + data = data[2:] + } + } + + if len(data) < 2 { + return false + } + casLength := uint16(data[0])<<8 | uint16(data[1]) + data = data[2:] + if len(data) < int(casLength) { + return false + } + cas := make([]byte, casLength) + copy(cas, data) + data = data[casLength:] + + m.certificateAuthorities = nil + for len(cas) > 0 { + if len(cas) < 2 { + return false + } + caLen := uint16(cas[0])<<8 | uint16(cas[1]) + cas = cas[2:] + + if len(cas) < int(caLen) { + return false + } + + m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen]) + cas = cas[caLen:] + } + + return len(data) == 0 +} + +type certificateVerifyMsg struct { + raw []byte + hasSignatureAlgorithm bool // format change introduced in TLS 1.2 + signatureAlgorithm SignatureScheme + signature []byte +} + +func (m *certificateVerifyMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificateVerify) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + if m.hasSignatureAlgorithm { + b.AddUint16(uint16(m.signatureAlgorithm)) + } + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.signature) + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateVerifyMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + + if !s.Skip(4) { // message type and uint24 length field + return false + } + if m.hasSignatureAlgorithm { + if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) { + return false + } + } + return readUint16LengthPrefixed(&s, &m.signature) && s.Empty() +} + +type newSessionTicketMsg struct { + raw []byte + ticket []byte +} + +func (m *newSessionTicketMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See RFC 5077, Section 3.3. + ticketLen := len(m.ticket) + length := 2 + 4 + ticketLen + x = make([]byte, 4+length) + x[0] = typeNewSessionTicket + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + x[8] = uint8(ticketLen >> 8) + x[9] = uint8(ticketLen) + copy(x[10:], m.ticket) + + m.raw = x + + return +} + +func (m *newSessionTicketMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 10 { + return false + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return false + } + + ticketLen := int(data[8])<<8 + int(data[9]) + if len(data)-10 != ticketLen { + return false + } + + m.ticket = data[10:] + + return true +} + +type helloRequestMsg struct { +} + +func (*helloRequestMsg) marshal() []byte { + return []byte{typeHelloRequest, 0, 0, 0} +} + +func (*helloRequestMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_messages_test.go b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_messages_test.go new file mode 100644 index 0000000..bb8aea8 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_messages_test.go @@ -0,0 +1,465 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "math/rand" + "reflect" + "strings" + "testing" + "testing/quick" + "time" +) + +var tests = []interface{}{ + &clientHelloMsg{}, + &serverHelloMsg{}, + &finishedMsg{}, + + &certificateMsg{}, + &certificateRequestMsg{}, + &certificateVerifyMsg{ + hasSignatureAlgorithm: true, + }, + &certificateStatusMsg{}, + &clientKeyExchangeMsg{}, + &newSessionTicketMsg{}, + &sessionState{}, + &sessionStateTLS13{}, + &encryptedExtensionsMsg{}, + &endOfEarlyDataMsg{}, + &keyUpdateMsg{}, + &newSessionTicketMsgTLS13{}, + &certificateRequestMsgTLS13{}, + &certificateMsgTLS13{}, +} + +func TestMarshalUnmarshal(t *testing.T) { + rand := rand.New(rand.NewSource(time.Now().UnixNano())) + + for i, iface := range tests { + ty := reflect.ValueOf(iface).Type() + + n := 100 + if testing.Short() { + n = 5 + } + for j := 0; j < n; j++ { + v, ok := quick.Value(ty, rand) + if !ok { + t.Errorf("#%d: failed to create value", i) + break + } + + m1 := v.Interface().(handshakeMessage) + marshaled := m1.marshal() + m2 := iface.(handshakeMessage) + if !m2.unmarshal(marshaled) { + t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) + break + } + m2.marshal() // to fill any marshal cache in the message + + if !reflect.DeepEqual(m1, m2) { + t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled) + break + } + + if i >= 3 { + // The first three message types (ClientHello, + // ServerHello and Finished) are allowed to + // have parsable prefixes because the extension + // data is optional and the length of the + // Finished varies across versions. + for j := 0; j < len(marshaled); j++ { + if m2.unmarshal(marshaled[0:j]) { + t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) + break + } + } + } + } + } +} + +func TestFuzz(t *testing.T) { + rand := rand.New(rand.NewSource(0)) + for _, iface := range tests { + m := iface.(handshakeMessage) + + for j := 0; j < 1000; j++ { + len := rand.Intn(100) + bytes := randomBytes(len, rand) + // This just looks for crashes due to bounds errors etc. + m.unmarshal(bytes) + } + } +} + +func randomBytes(n int, rand *rand.Rand) []byte { + r := make([]byte, n) + if _, err := rand.Read(r); err != nil { + panic("rand.Read failed: " + err.Error()) + } + return r +} + +func randomString(n int, rand *rand.Rand) string { + b := randomBytes(n, rand) + return string(b) +} + +func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &clientHelloMsg{} + m.vers = uint16(rand.Intn(65536)) + m.random = randomBytes(32, rand) + m.sessionId = randomBytes(rand.Intn(32), rand) + m.cipherSuites = make([]uint16, rand.Intn(63)+1) + for i := 0; i < len(m.cipherSuites); i++ { + cs := uint16(rand.Int31()) + if cs == scsvRenegotiation { + cs += 1 + } + m.cipherSuites[i] = cs + } + m.compressionMethods = randomBytes(rand.Intn(63)+1, rand) + if rand.Intn(10) > 5 { + m.serverName = randomString(rand.Intn(255), rand) + for strings.HasSuffix(m.serverName, ".") { + m.serverName = m.serverName[:len(m.serverName)-1] + } + } + m.ocspStapling = rand.Intn(10) > 5 + m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) + m.supportedCurves = make([]CurveID, rand.Intn(5)+1) + for i := range m.supportedCurves { + m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1) + } + if rand.Intn(10) > 5 { + m.ticketSupported = true + if rand.Intn(10) > 5 { + m.sessionTicket = randomBytes(rand.Intn(300), rand) + } else { + m.sessionTicket = make([]byte, 0) + } + } + if rand.Intn(10) > 5 { + m.supportedSignatureAlgorithms = supportedSignatureAlgorithms + } + if rand.Intn(10) > 5 { + m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms + } + for i := 0; i < rand.Intn(5); i++ { + m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand)) + } + if rand.Intn(10) > 5 { + m.scts = true + } + if rand.Intn(10) > 5 { + m.secureRenegotiationSupported = true + m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) + } + for i := 0; i < rand.Intn(5); i++ { + m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1)) + } + if rand.Intn(10) > 5 { + m.cookie = randomBytes(rand.Intn(500)+1, rand) + } + for i := 0; i < rand.Intn(5); i++ { + var ks keyShare + ks.group = CurveID(rand.Intn(30000) + 1) + ks.data = randomBytes(rand.Intn(200)+1, rand) + m.keyShares = append(m.keyShares, ks) + } + switch rand.Intn(3) { + case 1: + m.pskModes = []uint8{pskModeDHE} + case 2: + m.pskModes = []uint8{pskModeDHE, pskModePlain} + } + for i := 0; i < rand.Intn(5); i++ { + var psk pskIdentity + psk.obfuscatedTicketAge = uint32(rand.Intn(500000)) + psk.label = randomBytes(rand.Intn(500)+1, rand) + m.pskIdentities = append(m.pskIdentities, psk) + m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand)) + } + if rand.Intn(10) > 5 { + m.earlyData = true + } + + return reflect.ValueOf(m) +} + +func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &serverHelloMsg{} + m.vers = uint16(rand.Intn(65536)) + m.random = randomBytes(32, rand) + m.sessionId = randomBytes(rand.Intn(32), rand) + m.cipherSuite = uint16(rand.Int31()) + m.compressionMethod = uint8(rand.Intn(256)) + m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) + + if rand.Intn(10) > 5 { + m.ocspStapling = true + } + if rand.Intn(10) > 5 { + m.ticketSupported = true + } + if rand.Intn(10) > 5 { + m.alpnProtocol = randomString(rand.Intn(32)+1, rand) + } + + for i := 0; i < rand.Intn(4); i++ { + m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand)) + } + + if rand.Intn(10) > 5 { + m.secureRenegotiationSupported = true + m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) + } + if rand.Intn(10) > 5 { + m.supportedVersion = uint16(rand.Intn(0xffff) + 1) + } + if rand.Intn(10) > 5 { + m.cookie = randomBytes(rand.Intn(500)+1, rand) + } + if rand.Intn(10) > 5 { + for i := 0; i < rand.Intn(5); i++ { + m.serverShare.group = CurveID(rand.Intn(30000) + 1) + m.serverShare.data = randomBytes(rand.Intn(200)+1, rand) + } + } else if rand.Intn(10) > 5 { + m.selectedGroup = CurveID(rand.Intn(30000) + 1) + } + if rand.Intn(10) > 5 { + m.selectedIdentityPresent = true + m.selectedIdentity = uint16(rand.Intn(0xffff)) + } + + return reflect.ValueOf(m) +} + +func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &encryptedExtensionsMsg{} + + if rand.Intn(10) > 5 { + m.alpnProtocol = randomString(rand.Intn(32)+1, rand) + } + + return reflect.ValueOf(m) +} + +func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &certificateMsg{} + numCerts := rand.Intn(20) + m.certificates = make([][]byte, numCerts) + for i := 0; i < numCerts; i++ { + m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) + } + return reflect.ValueOf(m) +} + +func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &certificateRequestMsg{} + m.certificateTypes = randomBytes(rand.Intn(5)+1, rand) + for i := 0; i < rand.Intn(100); i++ { + m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand)) + } + return reflect.ValueOf(m) +} + +func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &certificateVerifyMsg{} + m.hasSignatureAlgorithm = true + m.signatureAlgorithm = SignatureScheme(rand.Intn(30000)) + m.signature = randomBytes(rand.Intn(15)+1, rand) + return reflect.ValueOf(m) +} + +func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &certificateStatusMsg{} + m.response = randomBytes(rand.Intn(10)+1, rand) + return reflect.ValueOf(m) +} + +func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &clientKeyExchangeMsg{} + m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) + return reflect.ValueOf(m) +} + +func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &finishedMsg{} + m.verifyData = randomBytes(12, rand) + return reflect.ValueOf(m) +} + +func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &newSessionTicketMsg{} + m.ticket = randomBytes(rand.Intn(4), rand) + return reflect.ValueOf(m) +} + +func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { + s := &sessionState{} + s.vers = uint16(rand.Intn(10000)) + s.cipherSuite = uint16(rand.Intn(10000)) + s.masterSecret = randomBytes(rand.Intn(100)+1, rand) + s.createdAt = uint64(rand.Int63()) + for i := 0; i < rand.Intn(20); i++ { + s.certificates = append(s.certificates, randomBytes(rand.Intn(500)+1, rand)) + } + return reflect.ValueOf(s) +} + +func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value { + s := &sessionStateTLS13{} + s.cipherSuite = uint16(rand.Intn(10000)) + s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand) + s.createdAt = uint64(rand.Int63()) + for i := 0; i < rand.Intn(2)+1; i++ { + s.certificate.Certificate = append( + s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) + } + if rand.Intn(10) > 5 { + s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) + } + if rand.Intn(10) > 5 { + for i := 0; i < rand.Intn(2)+1; i++ { + s.certificate.SignedCertificateTimestamps = append( + s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) + } + } + return reflect.ValueOf(s) +} + +func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &endOfEarlyDataMsg{} + return reflect.ValueOf(m) +} + +func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &keyUpdateMsg{} + m.updateRequested = rand.Intn(10) > 5 + return reflect.ValueOf(m) +} + +func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { + m := &newSessionTicketMsgTLS13{} + m.lifetime = uint32(rand.Intn(500000)) + m.ageAdd = uint32(rand.Intn(500000)) + m.nonce = randomBytes(rand.Intn(100), rand) + m.label = randomBytes(rand.Intn(1000), rand) + if rand.Intn(10) > 5 { + m.maxEarlyData = uint32(rand.Intn(500000)) + } + return reflect.ValueOf(m) +} + +func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { + m := &certificateRequestMsgTLS13{} + if rand.Intn(10) > 5 { + m.ocspStapling = true + } + if rand.Intn(10) > 5 { + m.scts = true + } + if rand.Intn(10) > 5 { + m.supportedSignatureAlgorithms = supportedSignatureAlgorithms + } + if rand.Intn(10) > 5 { + m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms + } + if rand.Intn(10) > 5 { + m.certificateAuthorities = make([][]byte, 3) + for i := 0; i < 3; i++ { + m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand) + } + } + return reflect.ValueOf(m) +} + +func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { + m := &certificateMsgTLS13{} + for i := 0; i < rand.Intn(2)+1; i++ { + m.certificate.Certificate = append( + m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) + } + if rand.Intn(10) > 5 { + m.ocspStapling = true + m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) + } + if rand.Intn(10) > 5 { + m.scts = true + for i := 0; i < rand.Intn(2)+1; i++ { + m.certificate.SignedCertificateTimestamps = append( + m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) + } + } + return reflect.ValueOf(m) +} + +func TestRejectEmptySCTList(t *testing.T) { + // RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid. + + var random [32]byte + sct := []byte{0x42, 0x42, 0x42, 0x42} + serverHello := serverHelloMsg{ + vers: VersionTLS12, + random: random[:], + scts: [][]byte{sct}, + } + serverHelloBytes := serverHello.marshal() + + var serverHelloCopy serverHelloMsg + if !serverHelloCopy.unmarshal(serverHelloBytes) { + t.Fatal("Failed to unmarshal initial message") + } + + // Change serverHelloBytes so that the SCT list is empty + i := bytes.Index(serverHelloBytes, sct) + if i < 0 { + t.Fatal("Cannot find SCT in ServerHello") + } + + var serverHelloEmptySCT []byte + serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...) + // Append the extension length and SCT list length for an empty list. + serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...) + serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...) + + // Update the handshake message length. + serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16) + serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8) + serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4) + + // Update the extensions length + serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8) + serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44)) + + if serverHelloCopy.unmarshal(serverHelloEmptySCT) { + t.Fatal("Unmarshaled ServerHello with empty SCT list") + } +} + +func TestRejectEmptySCT(t *testing.T) { + // Not only must the SCT list be non-empty, but the SCT elements must + // not be zero length. + + var random [32]byte + serverHello := serverHelloMsg{ + vers: VersionTLS12, + random: random[:], + scts: [][]byte{nil}, + } + serverHelloBytes := serverHello.marshal() + + var serverHelloCopy serverHelloMsg + if serverHelloCopy.unmarshal(serverHelloBytes) { + t.Fatal("Unmarshaled ServerHello with zero-length SCT") + } +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_server.go b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_server.go new file mode 100644 index 0000000..0d67d5a --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_server.go @@ -0,0 +1,1106 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "hash" + "io" + "sync/atomic" + "time" +) + +const ( + stateServerHandshakeReadClientHello uint32 = 1 + + // TLS 1.3 + stateServerHandshake13ProcessClientHello uint32 = 101 + stateServerHandshake13CheckForResumption uint32 = 102 + stateServerHandshake13PickCertificate uint32 = 103 + stateServerHandshake13SendServerParameters uint32 = 104 + stateServerHandshake13SendServerCertificate uint32 = 105 + stateServerHandshake13SendServerFinished uint32 = 106 + stateServerHandshake13ReadClientCertificate uint32 = 107 + stateServerHandshake13ReadClientFinished uint32 = 108 + stateServerHandshake13HandshakeDone uint32 = 109 + + // below TLS 1.3 + stateServerHandshakeProcessClientHello uint32 = 201 + stateServerHandshakeCheckForResumption uint32 = 202 + + stateServerHandshakeDoResumeHandshake uint32 = 203 + stateServerHandshakeEstablishKeys uint32 = 204 + stateServerHandshakeSendSessionTicket uint32 = 205 + stateServerHandshakeSendFinished uint32 = 206 + + stateServerHandshakePickCipherSuite2 uint32 = 208 + stateServerHandshakeDoFullHandshake2 uint32 = 209 + stateServerHandshakeDoFullHandshake2ReadHandshake1 uint32 = 210 + stateServerHandshakeDoFullHandshake2HandleCertificateMsg uint32 = 211 + stateServerHandshakeDoFullHandshake2ReadHandshake2 uint32 = 212 + stateServerHandshakeDoFullHandshake2HandleVerifyConnection uint32 = 213 + stateServerHandshakeDoFullHandshake2ReadHandshake3 uint32 = 214 + stateServerHandshakeEstablishKeys2 uint32 = 215 + + stateServerHandshakeReadFinishedReadChangeCipherSpec uint32 = 216 + stateServerHandshakeReadFinishedDone uint32 = 217 + + stateServerHandshakeSendSessionTicket2 uint32 = 218 + stateServerHandshakeSendFinished2 uint32 = 219 + + stateServerHandshakeHandshakeDone uint32 = 220 +) + +// serverHandshakeState contains details of a server handshake in progress. +// It's discarded once the handshake has completed. +type serverHandshakeState struct { + c *Conn + ok bool + // msg interface{} + ka keyAgreement + clientHello *clientHelloMsg + hello *serverHelloMsg + suite *cipherSuite + ecdheOk bool + ecSignOk bool + rsaDecryptOk bool + rsaSignOk bool + sessionState *sessionState + finishedHash finishedHash + masterSecret []byte + cert *Certificate + + err error +} + +// serverHandshake performs a TLS handshake as a server. +func (c *Conn) serverHandshake() error { + var err error + + if c.handshakeStatusAsync < stateServerHandshakeReadClientHello { + c.clientHello, err = c.readClientHello() + if err != nil { + return err + } + c.handshakeStatusAsync = stateServerHandshakeReadClientHello + } + if c.vers == VersionTLS13 { + hs := c.hs13 + if hs == nil { + hs = &serverHandshakeStateTLS13{ + c: c, + clientHello: c.clientHello, + } + c.hs13 = hs + } + return hs.handshake() + } + + hs := c.hs + if hs == nil { + hs = &serverHandshakeState{ + c: c, + clientHello: c.clientHello, + } + c.hs = hs + } + return hs.handshake() +} + +func (hs *serverHandshakeState) handshake() error { + c := hs.c + if c.handshakeStatusAsync >= stateServerHandshakeHandshakeDone { + return nil + } + if hs.err != nil && hs.err != errDataNotEnough { + return hs.err + } + + if err := hs.processClientHello(); err != nil { + hs.err = err + return err + } + + // For an overview of TLS handshaking, see RFC 5246, Section 7.3. + if hs.checkForResumption() { + // The client has included a session ticket and so we do an abbreviated handshake. + c.didResume = true + if err := hs.doResumeHandshake(); err != nil { + hs.err = err + return err + } + if err := hs.establishKeys(); err != nil { + hs.err = err + return err + } + c.buffering = true + if err := hs.sendSessionTicket(); err != nil { + hs.err = err + return err + } + if err := hs.sendFinished(c.serverFinished[:]); err != nil { + hs.err = err + return err + } + if _, err := c.flush(); err != nil { + hs.err = err + return err + } + c.clientFinishedIsFirst = false + if err := hs.readFinished(nil); err != nil { + hs.err = err + return err + } + } else { + // The client didn't include a session ticket, or it wasn't + // valid so we do a full handshake. + if err := hs.pickCipherSuite(); err != nil { + hs.err = err + return err + } + if err := hs.doFullHandshake(); err != nil { + hs.err = err + if err != errDataNotEnough { + } + return err + } + if err := hs.establishKeys(); err != nil { + hs.err = err + return err + } + if err := hs.readFinished(c.clientFinished[:]); err != nil { + hs.err = err + if err != errDataNotEnough { + } + return err + } + c.clientFinishedIsFirst = true + c.buffering = true + if err := hs.sendSessionTicket2(); err != nil { + hs.err = err + return err + } + if err := hs.sendFinished2(nil); err != nil { + hs.err = err + return err + } + if _, err := c.flush(); err != nil { + hs.err = err + return err + } + } + + c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random) + atomic.StoreUint32(&c.handshakeStatus, 1) + + c.handshakeStatusAsync = stateServerHandshakeHandshakeDone + + return nil +} + +// readClientHello reads a ClientHello message and selects the protocol version. +func (c *Conn) readClientHello() (*clientHelloMsg, error) { + msg, err := c.readHandshake() + if err != nil { + return nil, err + } + clientHello, ok := msg.(*clientHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return nil, unexpectedMessageError(clientHello, msg) + } + + var configForClient *Config + originalConfig := c.config + if c.config.GetConfigForClient != nil { + chi := clientHelloInfo(c, clientHello) + if configForClient, err = c.config.GetConfigForClient(chi); err != nil { + c.sendAlert(alertInternalError) + return nil, err + } else if configForClient != nil { + c.config = configForClient + } + } + c.ticketKeys = originalConfig.ticketKeys(configForClient) + + clientVersions := clientHello.supportedVersions + if len(clientHello.supportedVersions) == 0 { + clientVersions = supportedVersionsFromMax(clientHello.vers) + } + c.vers, ok = c.config.mutualVersion(clientVersions) + if !ok { + c.sendAlert(alertProtocolVersion) + return nil, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions) + } + c.haveVers = true + c.in.version = c.vers + c.out.version = c.vers + + return clientHello, nil +} + +func (hs *serverHandshakeState) processClientHello() error { + c := hs.c + + if c.handshakeStatusAsync >= stateServerHandshakeProcessClientHello { + return nil + } + c.handshakeStatusAsync = stateServerHandshakeProcessClientHello + + hs.hello = new(serverHelloMsg) + hs.hello.vers = c.vers + + foundCompression := false + // We only support null compression, so check that the client offered it. + for _, compression := range hs.clientHello.compressionMethods { + if compression == compressionNone { + foundCompression = true + break + } + } + + if !foundCompression { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: client does not support uncompressed connections") + } + + hs.hello.random = make([]byte, 32) + serverRandom := hs.hello.random + // Downgrade protection canaries. See RFC 8446, Section 4.1.3. + maxVers := c.config.maxSupportedVersion() + if maxVers >= VersionTLS12 && c.vers < maxVers || testingOnlyForceDowngradeCanary { + if c.vers == VersionTLS12 { + copy(serverRandom[24:], downgradeCanaryTLS12) + } else { + copy(serverRandom[24:], downgradeCanaryTLS11) + } + serverRandom = serverRandom[:24] + } + _, err := io.ReadFull(c.config.rand(), serverRandom) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + if len(hs.clientHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: initial handshake had non-empty renegotiation extension") + } + + hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported + hs.hello.compressionMethod = compressionNone + if len(hs.clientHello.serverName) > 0 { + c.serverName = hs.clientHello.serverName + } + + if len(hs.clientHello.alpnProtocols) > 0 { + if selectedProto := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); selectedProto != "" { + hs.hello.alpnProtocol = selectedProto + c.clientProtocol = selectedProto + } + } + + hs.cert, err = c.config.getCertificate(clientHelloInfo(c, hs.clientHello)) + if err != nil { + if err == errNoCertificates { + c.sendAlert(alertUnrecognizedName) + } else { + c.sendAlert(alertInternalError) + } + return err + } + if hs.clientHello.scts { + hs.hello.scts = hs.cert.SignedCertificateTimestamps + } + + hs.ecdheOk = supportsECDHE(c.config, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints) + + if hs.ecdheOk { + // Although omitting the ec_point_formats extension is permitted, some + // old OpenSSL version will refuse to handshake if not present. + // + // Per RFC 4492, section 5.1.2, implementations MUST support the + // uncompressed point format. See golang.org/issue/31943. + hs.hello.supportedPoints = []uint8{pointFormatUncompressed} + } + + if priv, ok := hs.cert.PrivateKey.(crypto.Signer); ok { + switch priv.Public().(type) { + case *ecdsa.PublicKey: + hs.ecSignOk = true + case ed25519.PublicKey: + hs.ecSignOk = true + case *rsa.PublicKey: + hs.rsaSignOk = true + default: + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: unsupported signing key type (%T)", priv.Public()) + } + } + if priv, ok := hs.cert.PrivateKey.(crypto.Decrypter); ok { + switch priv.Public().(type) { + case *rsa.PublicKey: + hs.rsaDecryptOk = true + default: + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: unsupported decryption key type (%T)", priv.Public()) + } + } + + return nil +} + +// supportsECDHE returns whether ECDHE key exchanges can be used with this +// pre-TLS 1.3 client. +func supportsECDHE(c *Config, supportedCurves []CurveID, supportedPoints []uint8) bool { + supportsCurve := false + for _, curve := range supportedCurves { + if c.supportsCurve(curve) { + supportsCurve = true + break + } + } + + supportsPointFormat := false + for _, pointFormat := range supportedPoints { + if pointFormat == pointFormatUncompressed { + supportsPointFormat = true + break + } + } + + return supportsCurve && supportsPointFormat +} + +func (hs *serverHandshakeState) pickCipherSuite() error { + c := hs.c + + if c.handshakeStatusAsync >= stateServerHandshakePickCipherSuite2 { + return nil + } + c.handshakeStatusAsync = stateServerHandshakePickCipherSuite2 + + var preferenceList, supportedList []uint16 + if c.config.PreferServerCipherSuites { + preferenceList = c.config.cipherSuites() + supportedList = hs.clientHello.cipherSuites + + // If the client does not seem to have hardware support for AES-GCM, + // and the application did not specify a cipher suite preference order, + // prefer other AEAD ciphers even if we prioritized AES-GCM ciphers + // by default. + if c.config.CipherSuites == nil && !aesgcmPreferred(hs.clientHello.cipherSuites) { + preferenceList = deprioritizeAES(preferenceList) + } + } else { + preferenceList = hs.clientHello.cipherSuites + supportedList = c.config.cipherSuites() + + // If we don't have hardware support for AES-GCM, prefer other AEAD + // ciphers even if the client prioritized AES-GCM. + if !hasAESGCMHardwareSupport { + preferenceList = deprioritizeAES(preferenceList) + } + } + + hs.suite = selectCipherSuite(preferenceList, supportedList, hs.cipherSuiteOk) + if hs.suite == nil { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: no cipher suite supported by both client and server") + } + c.cipherSuite = hs.suite.id + + for _, id := range hs.clientHello.cipherSuites { + if id == TLS_FALLBACK_SCSV { + // The client is doing a fallback connection. See RFC 7507. + if hs.clientHello.vers < c.config.maxSupportedVersion() { + c.sendAlert(alertInappropriateFallback) + return errors.New("tls: client using inappropriate protocol fallback") + } + break + } + } + + return nil +} + +func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool { + if c.flags&suiteECDHE != 0 { + if !hs.ecdheOk { + return false + } + if c.flags&suiteECSign != 0 { + if !hs.ecSignOk { + return false + } + } else if !hs.rsaSignOk { + return false + } + } else if !hs.rsaDecryptOk { + return false + } + if hs.c.vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { + return false + } + return true +} + +// checkForResumption reports whether we should perform resumption on this connection. +func (hs *serverHandshakeState) checkForResumption() bool { + c := hs.c + + if c.handshakeStatusAsync >= stateServerHandshakeCheckForResumption { + return hs.ok + } + c.handshakeStatusAsync = stateServerHandshakeCheckForResumption + + if c.config.SessionTicketsDisabled { + hs.ok = false + return false + } + + plaintext, usedOldKey := c.decryptTicket(hs.clientHello.sessionTicket) + if plaintext == nil { + hs.ok = false + return false + } + hs.sessionState = &sessionState{usedOldKey: usedOldKey} + ok := hs.sessionState.unmarshal(plaintext) + if !ok { + hs.ok = false + return false + } + + createdAt := time.Unix(int64(hs.sessionState.createdAt), 0) + if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { + hs.ok = false + return false + } + + // Never resume a session for a different TLS version. + if c.vers != hs.sessionState.vers { + hs.ok = false + return false + } + + cipherSuiteOk := false + // Check that the client is still offering the ciphersuite in the session. + for _, id := range hs.clientHello.cipherSuites { + if id == hs.sessionState.cipherSuite { + cipherSuiteOk = true + break + } + } + if !cipherSuiteOk { + hs.ok = false + return false + } + + // Check that we also support the ciphersuite from the session. + hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite}, + c.config.cipherSuites(), hs.cipherSuiteOk) + if hs.suite == nil { + hs.ok = false + return false + } + + sessionHasClientCerts := len(hs.sessionState.certificates) != 0 + needClientCerts := requiresClientCert(c.config.ClientAuth) + if needClientCerts && !sessionHasClientCerts { + hs.ok = false + return false + } + if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { + hs.ok = false + return false + } + hs.ok = true + return true +} + +func (hs *serverHandshakeState) doResumeHandshake() error { + c := hs.c + + if c.handshakeStatusAsync >= stateServerHandshakeDoResumeHandshake { + return nil + } + c.handshakeStatusAsync = stateServerHandshakeDoResumeHandshake + + hs.hello.cipherSuite = hs.suite.id + c.cipherSuite = hs.suite.id + // We echo the client's session ID in the ServerHello to let it know + // that we're doing a resumption. + hs.hello.sessionId = hs.clientHello.sessionId + hs.hello.ticketSupported = hs.sessionState.usedOldKey + hs.finishedHash = newFinishedHash(c.vers, hs.suite) + hs.finishedHash.discardHandshakeBuffer() + hs.finishedHash.Write(hs.clientHello.marshal()) + hs.finishedHash.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + if err := c.processCertsFromClient(Certificate{ + Certificate: hs.sessionState.certificates, + }); err != nil { + return err + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + hs.masterSecret = hs.sessionState.masterSecret + + return nil +} + +func (hs *serverHandshakeState) doFullHandshake() error { + c := hs.c + + var err error + var msg interface{} + var pub crypto.PublicKey // public key for client auth, if any + var certReq *certificateRequestMsg + if c.handshakeStatusAsync < stateServerHandshakeDoFullHandshake2 { + c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2 + + if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 { + hs.hello.ocspStapling = true + } + + hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled + hs.hello.cipherSuite = hs.suite.id + + hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite) + if c.config.ClientAuth == NoClientCert { + // No need to keep a full record of the handshake if client + // certificates won't be used. + hs.finishedHash.discardHandshakeBuffer() + } + + c.buffering = true + hs.finishedHash.Write(hs.clientHello.marshal()) + hs.finishedHash.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + certMsg := new(certificateMsg) + certMsg.certificates = hs.cert.Certificate + hs.finishedHash.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + if hs.hello.ocspStapling { + certStatus := new(certificateStatusMsg) + certStatus.response = hs.cert.OCSPStaple + hs.finishedHash.Write(certStatus.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil { + return err + } + } + + hs.ka = hs.suite.ka(c.vers) + skx, err := hs.ka.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello) + if err != nil { + c.sendAlert(alertHandshakeFailure) + return err + } + if skx != nil { + hs.finishedHash.Write(skx.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil { + return err + } + } + + if c.config.ClientAuth >= RequestClientCert { + // Request a client certificate + certReq = new(certificateRequestMsg) + certReq.certificateTypes = []byte{ + byte(certTypeRSASign), + byte(certTypeECDSASign), + } + if c.vers >= VersionTLS12 { + certReq.hasSignatureAlgorithm = true + certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms + } + + // An empty list of certificateAuthorities signals to + // the client that it may send any certificate in response + // to our request. When we know the CAs we trust, then + // we can send them down, so that the client can choose + // an appropriate certificate to give to us. + if c.config.ClientCAs != nil { + certReq.certificateAuthorities = c.config.ClientCAs.Subjects() + } + hs.finishedHash.Write(certReq.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + return err + } + } + + helloDone := new(serverHelloDoneMsg) + hs.finishedHash.Write(helloDone.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil { + return err + } + + if _, err := c.flush(); err != nil { + return err + } + } + + if c.handshakeStatusAsync < stateServerHandshakeDoFullHandshake2ReadHandshake1 { + msg, err = c.readHandshake() + if err != nil { + if err != errDataNotEnough { + c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake1 + } + return err + } + c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake1 + } + + // If we requested a client certificate, then the client must send a + // certificate message, even if it's empty. + + if c.config.ClientAuth >= RequestClientCert { + if c.handshakeStatusAsync < stateServerHandshakeDoFullHandshake2HandleCertificateMsg { + c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2HandleCertificateMsg + + certMsg, ok := msg.(*certificateMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.finishedHash.Write(certMsg.marshal()) + + if err := c.processCertsFromClient(Certificate{ + Certificate: certMsg.certificates, + }); err != nil { + return err + } + if len(certMsg.certificates) != 0 { + pub = c.peerCertificates[0].PublicKey + } + } + if c.handshakeStatusAsync < stateServerHandshakeDoFullHandshake2ReadHandshake2 { + msg, err = c.readHandshake() + if err != nil { + if err != errDataNotEnough { + c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake2 + } + return err + } + c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake2 + + } + } + + if c.handshakeStatusAsync < stateServerHandshakeDoFullHandshake2HandleVerifyConnection { + c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2HandleVerifyConnection + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + // Get client key exchange + ckx, ok := msg.(*clientKeyExchangeMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(ckx, msg) + } + hs.finishedHash.Write(ckx.marshal()) + + preMasterSecret, err := hs.ka.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) + if err != nil { + c.sendAlert(alertHandshakeFailure) + return err + } + hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random) + if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.clientHello.random, hs.masterSecret); err != nil { + c.sendAlert(alertInternalError) + return err + } + + } + + if c.handshakeStatusAsync >= stateServerHandshakeDoFullHandshake2ReadHandshake3 { + return nil + } + // If we received a client cert in response to our certificate request message, + // the client will send us a certificateVerifyMsg immediately after the + // clientKeyExchangeMsg. This message is a digest of all preceding + // handshake-layer messages that is signed using the private key corresponding + // to the client's certificate. This allows us to verify that the client is in + // possession of the private key of the certificate. + if len(c.peerCertificates) > 0 { + msg, err := c.readHandshake() + if err != nil { + if err != errDataNotEnough { + c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake3 + } + return err + } + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake3 + return unexpectedMessageError(certVerify, msg) + } + + var sigType uint8 + var sigHash crypto.Hash + if c.vers >= VersionTLS12 { + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, certReq.supportedSignatureAlgorithms) { + c.sendAlert(alertIllegalParameter) + c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake3 + return errors.New("tls: client certificate used with invalid signature algorithm") + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) + if err != nil { + c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake3 + return c.sendAlert(alertInternalError) + } + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub) + if err != nil { + c.sendAlert(alertIllegalParameter) + c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake3 + return err + } + } + + signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) + if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake3 + return errors.New("tls: invalid signature by the client certificate: " + err.Error()) + } + + hs.finishedHash.Write(certVerify.marshal()) + } + + hs.finishedHash.discardHandshakeBuffer() + + c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake3 + + return nil +} + +func (hs *serverHandshakeState) establishKeys() error { + c := hs.c + if c.handshakeStatusAsync >= stateServerHandshakeEstablishKeys2 { + return nil + } + c.handshakeStatusAsync = stateServerHandshakeEstablishKeys2 + + clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := + keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) + + var clientCipher, serverCipher interface{} + var clientHash, serverHash hash.Hash + + if hs.suite.aead == nil { + clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */) + clientHash = hs.suite.mac(clientMAC) + serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */) + serverHash = hs.suite.mac(serverMAC) + } else { + clientCipher = hs.suite.aead(clientKey, clientIV) + serverCipher = hs.suite.aead(serverKey, serverIV) + } + + c.in.prepareCipherSpec(c.vers, clientCipher, clientHash) + c.out.prepareCipherSpec(c.vers, serverCipher, serverHash) + return nil +} + +func (hs *serverHandshakeState) readFinished(out []byte) error { + c := hs.c + + if c.handshakeStatusAsync < stateServerHandshakeReadFinishedReadChangeCipherSpec { + if err := c.readChangeCipherSpec(); err != nil { + if err != errDataNotEnough { + c.handshakeStatusAsync = stateServerHandshakeReadFinishedReadChangeCipherSpec + } + return err + } + c.handshakeStatusAsync = stateServerHandshakeReadFinishedReadChangeCipherSpec + } + + if c.handshakeStatusAsync >= stateServerHandshakeReadFinishedDone { + return nil + } + msg, err := c.readHandshake() + if err != nil { + if err != errDataNotEnough { + c.handshakeStatusAsync = stateServerHandshakeReadFinishedDone + } + return err + } + clientFinished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + c.handshakeStatusAsync = stateServerHandshakeReadFinishedDone + return unexpectedMessageError(clientFinished, msg) + } + + verify := hs.finishedHash.clientSum(hs.masterSecret) + if len(verify) != len(clientFinished.verifyData) || + subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { + c.sendAlert(alertHandshakeFailure) + c.handshakeStatusAsync = stateServerHandshakeReadFinishedDone + return errors.New("tls: client's Finished message is incorrect") + } + + hs.finishedHash.Write(clientFinished.marshal()) + copy(out, verify) + + c.handshakeStatusAsync = stateServerHandshakeReadFinishedDone + return nil +} + +func (hs *serverHandshakeState) sendSessionTicket() error { + // ticketSupported is set in a resumption handshake if the + // ticket from the client was encrypted with an old session + // ticket key and thus a refreshed ticket should be sent. + if !hs.hello.ticketSupported { + return nil + } + + c := hs.c + + if c.handshakeStatusAsync >= stateServerHandshakeSendSessionTicket { + return nil + } + c.handshakeStatusAsync = stateServerHandshakeSendSessionTicket + + m := new(newSessionTicketMsg) + + createdAt := uint64(c.config.time().Unix()) + if hs.sessionState != nil { + // If this is re-wrapping an old key, then keep + // the original time it was created. + createdAt = hs.sessionState.createdAt + } + + var certsFromClient [][]byte + for _, cert := range c.peerCertificates { + certsFromClient = append(certsFromClient, cert.Raw) + } + state := sessionState{ + vers: c.vers, + cipherSuite: hs.suite.id, + createdAt: createdAt, + masterSecret: hs.masterSecret, + certificates: certsFromClient, + } + var err error + m.ticket, err = c.encryptTicket(state.marshal()) + if err != nil { + return err + } + + hs.finishedHash.Write(m.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeState) sendSessionTicket2() error { + // ticketSupported is set in a resumption handshake if the + // ticket from the client was encrypted with an old session + // ticket key and thus a refreshed ticket should be sent. + if !hs.hello.ticketSupported { + return nil + } + + c := hs.c + + if c.handshakeStatusAsync >= stateServerHandshakeSendSessionTicket2 { + return nil + } + c.handshakeStatusAsync = stateServerHandshakeSendSessionTicket2 + + m := new(newSessionTicketMsg) + + createdAt := uint64(c.config.time().Unix()) + if hs.sessionState != nil { + // If this is re-wrapping an old key, then keep + // the original time it was created. + createdAt = hs.sessionState.createdAt + } + + var certsFromClient [][]byte + for _, cert := range c.peerCertificates { + certsFromClient = append(certsFromClient, cert.Raw) + } + state := sessionState{ + vers: c.vers, + cipherSuite: hs.suite.id, + createdAt: createdAt, + masterSecret: hs.masterSecret, + certificates: certsFromClient, + } + var err error + m.ticket, err = c.encryptTicket(state.marshal()) + if err != nil { + return err + } + + hs.finishedHash.Write(m.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeState) sendFinished(out []byte) error { + c := hs.c + + if c.handshakeStatusAsync >= stateServerHandshakeSendFinished { + return nil + } + c.handshakeStatusAsync = stateServerHandshakeSendFinished + + if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + return err + } + + finished := new(finishedMsg) + finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) + hs.finishedHash.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + copy(out, finished.verifyData) + + return nil +} + +func (hs *serverHandshakeState) sendFinished2(out []byte) error { + c := hs.c + + if c.handshakeStatusAsync >= stateServerHandshakeSendFinished2 { + return nil + } + c.handshakeStatusAsync = stateServerHandshakeSendFinished2 + + if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + return err + } + + finished := new(finishedMsg) + finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) + hs.finishedHash.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + copy(out, finished.verifyData) + + return nil +} + +// processCertsFromClient takes a chain of client certificates either from a +// Certificates message or from a sessionState and verifies them. It returns +// the public key of the leaf certificate. +func (c *Conn) processCertsFromClient(certificate Certificate) error { + certificates := certificate.Certificate + certs := make([]*x509.Certificate, len(certificates)) + var err error + for i, asn1Data := range certificates { + if certs[i], err = x509.ParseCertificate(asn1Data); err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to parse client certificate: " + err.Error()) + } + } + + if len(certs) == 0 && requiresClientCert(c.config.ClientAuth) { + c.sendAlert(alertBadCertificate) + return errors.New("tls: client didn't provide a certificate") + } + + if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 { + opts := x509.VerifyOptions{ + Roots: c.config.ClientCAs, + CurrentTime: c.config.time(), + Intermediates: x509.NewCertPool(), + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + + chains, err := certs[0].Verify(opts) + if err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to verify client certificate: " + err.Error()) + } + + c.verifiedChains = chains + } + + c.peerCertificates = certs + c.ocspResponse = certificate.OCSPStaple + c.scts = certificate.SignedCertificateTimestamps + + if len(certs) > 0 { + switch certs[0].PublicKey.(type) { + case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey: + default: + c.sendAlert(alertUnsupportedCertificate) + return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey) + } + } + + if c.config.VerifyPeerCertificate != nil { + if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + return nil +} + +func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { + supportedVersions := clientHello.supportedVersions + if len(clientHello.supportedVersions) == 0 { + supportedVersions = supportedVersionsFromMax(clientHello.vers) + } + + return &ClientHelloInfo{ + CipherSuites: clientHello.cipherSuites, + ServerName: clientHello.serverName, + SupportedCurves: clientHello.supportedCurves, + SupportedPoints: clientHello.supportedPoints, + SignatureSchemes: clientHello.supportedSignatureAlgorithms, + SupportedProtos: clientHello.alpnProtocols, + SupportedVersions: supportedVersions, + Conn: c.conn, + config: c.config, + } +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_server_test.go b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_server_test.go new file mode 100644 index 0000000..d6bf9e4 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_server_test.go @@ -0,0 +1,1941 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "crypto" + "crypto/elliptic" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "io" + "net" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" + + "golang.org/x/crypto/curve25519" +) + +func testClientHello(t *testing.T, serverConfig *Config, m handshakeMessage) { + testClientHelloFailure(t, serverConfig, m, "") +} + +func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessage, expectedSubStr string) { + c, s := localPipe(t) + go func() { + cli := Client(c, testConfig) + if ch, ok := m.(*clientHelloMsg); ok { + cli.vers = ch.vers + } + cli.writeRecord(recordTypeHandshake, m.marshal()) + c.Close() + }() + conn := Server(s, serverConfig) + ch, err := conn.readClientHello() + hs := serverHandshakeState{ + c: conn, + clientHello: ch, + } + if err == nil { + err = hs.processClientHello() + } + if err == nil { + err = hs.pickCipherSuite() + } + s.Close() + if len(expectedSubStr) == 0 { + if err != nil && err != io.EOF { + t.Errorf("Got error: %s; expected to succeed", err) + } + } else if err == nil || !strings.Contains(err.Error(), expectedSubStr) { + t.Errorf("Got error: %v; expected to match substring '%s'", err, expectedSubStr) + } +} + +func TestSimpleError(t *testing.T) { + testClientHelloFailure(t, testConfig, &serverHelloDoneMsg{}, "unexpected handshake message") +} + +var badProtocolVersions = []uint16{0x0000, 0x0005, 0x0100, 0x0105, 0x0200, 0x0205, VersionSSL30} + +func TestRejectBadProtocolVersion(t *testing.T) { + config := testConfig.Clone() + config.MinVersion = VersionSSL30 + for _, v := range badProtocolVersions { + testClientHelloFailure(t, config, &clientHelloMsg{ + vers: v, + random: make([]byte, 32), + }, "unsupported versions") + } + testClientHelloFailure(t, config, &clientHelloMsg{ + vers: VersionTLS12, + supportedVersions: badProtocolVersions, + random: make([]byte, 32), + }, "unsupported versions") +} + +func TestNoSuiteOverlap(t *testing.T) { + clientHello := &clientHelloMsg{ + vers: VersionTLS10, + random: make([]byte, 32), + cipherSuites: []uint16{0xff00}, + compressionMethods: []uint8{compressionNone}, + } + testClientHelloFailure(t, testConfig, clientHello, "no cipher suite supported by both client and server") +} + +func TestNoCompressionOverlap(t *testing.T) { + clientHello := &clientHelloMsg{ + vers: VersionTLS10, + random: make([]byte, 32), + cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + compressionMethods: []uint8{0xff}, + } + testClientHelloFailure(t, testConfig, clientHello, "client does not support uncompressed connections") +} + +func TestNoRC4ByDefault(t *testing.T) { + clientHello := &clientHelloMsg{ + vers: VersionTLS10, + random: make([]byte, 32), + cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + compressionMethods: []uint8{compressionNone}, + } + serverConfig := testConfig.Clone() + // Reset the enabled cipher suites to nil in order to test the + // defaults. + serverConfig.CipherSuites = nil + testClientHelloFailure(t, serverConfig, clientHello, "no cipher suite supported by both client and server") +} + +func TestRejectSNIWithTrailingDot(t *testing.T) { + testClientHelloFailure(t, testConfig, &clientHelloMsg{ + vers: VersionTLS12, + random: make([]byte, 32), + serverName: "foo.com.", + }, "unexpected message") +} + +func TestDontSelectECDSAWithRSAKey(t *testing.T) { + // Test that, even when both sides support an ECDSA cipher suite, it + // won't be selected if the server's private key doesn't support it. + clientHello := &clientHelloMsg{ + vers: VersionTLS10, + random: make([]byte, 32), + cipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, + compressionMethods: []uint8{compressionNone}, + supportedCurves: []CurveID{CurveP256}, + supportedPoints: []uint8{pointFormatUncompressed}, + } + serverConfig := testConfig.Clone() + serverConfig.CipherSuites = clientHello.cipherSuites + serverConfig.Certificates = make([]Certificate, 1) + serverConfig.Certificates[0].Certificate = [][]byte{testECDSACertificate} + serverConfig.Certificates[0].PrivateKey = testECDSAPrivateKey + serverConfig.BuildNameToCertificate() + // First test that it *does* work when the server's key is ECDSA. + testClientHello(t, serverConfig, clientHello) + + // Now test that switching to an RSA key causes the expected error (and + // not an internal error about a signing failure). + serverConfig.Certificates = testConfig.Certificates + testClientHelloFailure(t, serverConfig, clientHello, "no cipher suite supported by both client and server") +} + +func TestDontSelectRSAWithECDSAKey(t *testing.T) { + // Test that, even when both sides support an RSA cipher suite, it + // won't be selected if the server's private key doesn't support it. + clientHello := &clientHelloMsg{ + vers: VersionTLS10, + random: make([]byte, 32), + cipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA}, + compressionMethods: []uint8{compressionNone}, + supportedCurves: []CurveID{CurveP256}, + supportedPoints: []uint8{pointFormatUncompressed}, + } + serverConfig := testConfig.Clone() + serverConfig.CipherSuites = clientHello.cipherSuites + // First test that it *does* work when the server's key is RSA. + testClientHello(t, serverConfig, clientHello) + + // Now test that switching to an ECDSA key causes the expected error + // (and not an internal error about a signing failure). + serverConfig.Certificates = make([]Certificate, 1) + serverConfig.Certificates[0].Certificate = [][]byte{testECDSACertificate} + serverConfig.Certificates[0].PrivateKey = testECDSAPrivateKey + serverConfig.BuildNameToCertificate() + testClientHelloFailure(t, serverConfig, clientHello, "no cipher suite supported by both client and server") +} + +func TestRenegotiationExtension(t *testing.T) { + clientHello := &clientHelloMsg{ + vers: VersionTLS12, + compressionMethods: []uint8{compressionNone}, + random: make([]byte, 32), + secureRenegotiationSupported: true, + cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + } + + bufChan := make(chan []byte, 1) + c, s := localPipe(t) + + go func() { + cli := Client(c, testConfig) + cli.vers = clientHello.vers + cli.writeRecord(recordTypeHandshake, clientHello.marshal()) + + buf := make([]byte, 1024) + n, err := c.Read(buf) + if err != nil { + t.Errorf("Server read returned error: %s", err) + return + } + c.Close() + bufChan <- buf[:n] + }() + + Server(s, testConfig).Handshake() + buf := <-bufChan + + if len(buf) < 5+4 { + t.Fatalf("Server returned short message of length %d", len(buf)) + } + // buf contains a TLS record, with a 5 byte record header and a 4 byte + // handshake header. The length of the ServerHello is taken from the + // handshake header. + serverHelloLen := int(buf[6])<<16 | int(buf[7])<<8 | int(buf[8]) + + var serverHello serverHelloMsg + // unmarshal expects to be given the handshake header, but + // serverHelloLen doesn't include it. + if !serverHello.unmarshal(buf[5 : 9+serverHelloLen]) { + t.Fatalf("Failed to parse ServerHello") + } + + if !serverHello.secureRenegotiationSupported { + t.Errorf("Secure renegotiation extension was not echoed.") + } +} + +func TestTLS12OnlyCipherSuites(t *testing.T) { + // Test that a Server doesn't select a TLS 1.2-only cipher suite when + // the client negotiates TLS 1.1. + clientHello := &clientHelloMsg{ + vers: VersionTLS11, + random: make([]byte, 32), + cipherSuites: []uint16{ + // The Server, by default, will use the client's + // preference order. So the GCM cipher suite + // will be selected unless it's excluded because + // of the version in this ClientHello. + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_RSA_WITH_RC4_128_SHA, + }, + compressionMethods: []uint8{compressionNone}, + supportedCurves: []CurveID{CurveP256, CurveP384, CurveP521}, + supportedPoints: []uint8{pointFormatUncompressed}, + } + + c, s := localPipe(t) + replyChan := make(chan interface{}) + go func() { + cli := Client(c, testConfig) + cli.vers = clientHello.vers + cli.writeRecord(recordTypeHandshake, clientHello.marshal()) + reply, err := cli.readHandshake() + c.Close() + if err != nil { + replyChan <- err + } else { + replyChan <- reply + } + }() + config := testConfig.Clone() + config.CipherSuites = clientHello.cipherSuites + Server(s, config).Handshake() + s.Close() + reply := <-replyChan + if err, ok := reply.(error); ok { + t.Fatal(err) + } + serverHello, ok := reply.(*serverHelloMsg) + if !ok { + t.Fatalf("didn't get ServerHello message in reply. Got %v\n", reply) + } + if s := serverHello.cipherSuite; s != TLS_RSA_WITH_RC4_128_SHA { + t.Fatalf("bad cipher suite from server: %x", s) + } +} + +func TestTLSPointFormats(t *testing.T) { + // Test that a Server returns the ec_point_format extension when ECC is + // negotiated, and not returned on RSA handshake. + tests := []struct { + name string + cipherSuites []uint16 + supportedCurves []CurveID + supportedPoints []uint8 + wantSupportedPoints bool + }{ + {"ECC", []uint16{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA}, []CurveID{CurveP256}, []uint8{compressionNone}, true}, + {"RSA", []uint16{TLS_RSA_WITH_AES_256_GCM_SHA384}, nil, nil, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clientHello := &clientHelloMsg{ + vers: VersionTLS12, + random: make([]byte, 32), + cipherSuites: tt.cipherSuites, + compressionMethods: []uint8{compressionNone}, + supportedCurves: tt.supportedCurves, + supportedPoints: tt.supportedPoints, + } + + c, s := localPipe(t) + replyChan := make(chan interface{}) + go func() { + cli := Client(c, testConfig) + cli.vers = clientHello.vers + cli.writeRecord(recordTypeHandshake, clientHello.marshal()) + reply, err := cli.readHandshake() + c.Close() + if err != nil { + replyChan <- err + } else { + replyChan <- reply + } + }() + config := testConfig.Clone() + config.CipherSuites = clientHello.cipherSuites + Server(s, config).Handshake() + s.Close() + reply := <-replyChan + if err, ok := reply.(error); ok { + t.Fatal(err) + } + serverHello, ok := reply.(*serverHelloMsg) + if !ok { + t.Fatalf("didn't get ServerHello message in reply. Got %v\n", reply) + } + if tt.wantSupportedPoints { + if len(serverHello.supportedPoints) < 1 { + t.Fatal("missing ec_point_format extension from server") + } + found := false + for _, p := range serverHello.supportedPoints { + if p == pointFormatUncompressed { + found = true + break + } + } + if !found { + t.Fatal("missing uncompressed format in ec_point_format extension from server") + } + } else { + if len(serverHello.supportedPoints) != 0 { + t.Fatalf("unexcpected ec_point_format extension from server: %v", serverHello.supportedPoints) + } + } + }) + } +} + +func TestAlertForwarding(t *testing.T) { + c, s := localPipe(t) + go func() { + Client(c, testConfig).sendAlert(alertUnknownCA) + c.Close() + }() + + err := Server(s, testConfig).Handshake() + s.Close() + var opErr *net.OpError + if !errors.As(err, &opErr) || opErr.Err != error(alertUnknownCA) { + t.Errorf("Got error: %s; expected: %s", err, error(alertUnknownCA)) + } +} + +func TestClose(t *testing.T) { + c, s := localPipe(t) + go c.Close() + + err := Server(s, testConfig).Handshake() + s.Close() + if err != io.EOF { + t.Errorf("Got error: %s; expected: %s", err, io.EOF) + } +} + +func TestVersion(t *testing.T) { + serverConfig := &Config{ + Certificates: testConfig.Certificates, + MaxVersion: VersionTLS11, + } + clientConfig := &Config{ + InsecureSkipVerify: true, + } + state, _, err := testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("handshake failed: %s", err) + } + if state.Version != VersionTLS11 { + t.Fatalf("Incorrect version %x, should be %x", state.Version, VersionTLS11) + } +} + +func TestCipherSuitePreference(t *testing.T) { + serverConfig := &Config{ + CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_RSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA}, + Certificates: testConfig.Certificates, + MaxVersion: VersionTLS11, + } + clientConfig := &Config{ + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_RC4_128_SHA}, + InsecureSkipVerify: true, + } + state, _, err := testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("handshake failed: %s", err) + } + if state.CipherSuite != TLS_RSA_WITH_AES_128_CBC_SHA { + // By default the server should use the client's preference. + t.Fatalf("Client's preference was not used, got %x", state.CipherSuite) + } + + serverConfig.PreferServerCipherSuites = true + state, _, err = testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("handshake failed: %s", err) + } + if state.CipherSuite != TLS_RSA_WITH_RC4_128_SHA { + t.Fatalf("Server's preference was not used, got %x", state.CipherSuite) + } +} + +func TestSCTHandshake(t *testing.T) { + t.Run("TLSv12", func(t *testing.T) { testSCTHandshake(t, VersionTLS12) }) + t.Run("TLSv13", func(t *testing.T) { testSCTHandshake(t, VersionTLS13) }) +} + +func testSCTHandshake(t *testing.T, version uint16) { + expected := [][]byte{[]byte("certificate"), []byte("transparency")} + serverConfig := &Config{ + Certificates: []Certificate{{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: testRSAPrivateKey, + SignedCertificateTimestamps: expected, + }}, + MaxVersion: version, + } + clientConfig := &Config{ + InsecureSkipVerify: true, + } + _, state, err := testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("handshake failed: %s", err) + } + actual := state.SignedCertificateTimestamps + if len(actual) != len(expected) { + t.Fatalf("got %d scts, want %d", len(actual), len(expected)) + } + for i, sct := range expected { + if !bytes.Equal(sct, actual[i]) { + t.Fatalf("SCT #%d was %x, but expected %x", i, actual[i], sct) + } + } +} + +func TestCrossVersionResume(t *testing.T) { + t.Run("TLSv12", func(t *testing.T) { testCrossVersionResume(t, VersionTLS12) }) + t.Run("TLSv13", func(t *testing.T) { testCrossVersionResume(t, VersionTLS13) }) +} + +func testCrossVersionResume(t *testing.T, version uint16) { + serverConfig := &Config{ + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_CBC_SHA}, + Certificates: testConfig.Certificates, + } + clientConfig := &Config{ + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_CBC_SHA}, + InsecureSkipVerify: true, + ClientSessionCache: NewLRUClientSessionCache(1), + ServerName: "servername", + } + + // Establish a session at TLS 1.1. + clientConfig.MaxVersion = VersionTLS11 + _, _, err := testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("handshake failed: %s", err) + } + + // The client session cache now contains a TLS 1.1 session. + state, _, err := testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("handshake failed: %s", err) + } + if !state.DidResume { + t.Fatalf("handshake did not resume at the same version") + } + + // Test that the server will decline to resume at a lower version. + clientConfig.MaxVersion = VersionTLS10 + state, _, err = testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("handshake failed: %s", err) + } + if state.DidResume { + t.Fatalf("handshake resumed at a lower version") + } + + // The client session cache now contains a TLS 1.0 session. + state, _, err = testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("handshake failed: %s", err) + } + if !state.DidResume { + t.Fatalf("handshake did not resume at the same version") + } + + // Test that the server will decline to resume at a higher version. + clientConfig.MaxVersion = VersionTLS11 + state, _, err = testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("handshake failed: %s", err) + } + if state.DidResume { + t.Fatalf("handshake resumed at a higher version") + } +} + +// Note: see comment in handshake_test.go for details of how the reference +// tests work. + +// serverTest represents a test of the TLS server handshake against a reference +// implementation. +type serverTest struct { + // name is a freeform string identifying the test and the file in which + // the expected results will be stored. + name string + // command, if not empty, contains a series of arguments for the + // command to run for the reference server. + command []string + // expectedPeerCerts contains a list of PEM blocks of expected + // certificates from the client. + expectedPeerCerts []string + // config, if not nil, contains a custom Config to use for this test. + config *Config + // expectHandshakeErrorIncluding, when not empty, contains a string + // that must be a substring of the error resulting from the handshake. + expectHandshakeErrorIncluding string + // validate, if not nil, is a function that will be called with the + // ConnectionState of the resulting connection. It returns false if the + // ConnectionState is unacceptable. + validate func(ConnectionState) error + // wait, if true, prevents this subtest from calling t.Parallel. + // If false, runServerTest* returns immediately. + wait bool +} + +var defaultClientCommand = []string{"openssl", "s_client", "-no_ticket"} + +// connFromCommand starts opens a listening socket and starts the reference +// client to connect to it. It returns a recordingConn that wraps the resulting +// connection. +func (test *serverTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, err error) { + l, err := net.ListenTCP("tcp", &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 0, + }) + if err != nil { + return nil, nil, err + } + defer l.Close() + + port := l.Addr().(*net.TCPAddr).Port + + var command []string + command = append(command, test.command...) + if len(command) == 0 { + command = defaultClientCommand + } + command = append(command, "-connect") + command = append(command, fmt.Sprintf("127.0.0.1:%d", port)) + cmd := exec.Command(command[0], command[1:]...) + cmd.Stdin = nil + var output bytes.Buffer + cmd.Stdout = &output + cmd.Stderr = &output + if err := cmd.Start(); err != nil { + return nil, nil, err + } + + connChan := make(chan interface{}, 1) + go func() { + tcpConn, err := l.Accept() + if err != nil { + connChan <- err + return + } + connChan <- tcpConn + }() + + var tcpConn net.Conn + select { + case connOrError := <-connChan: + if err, ok := connOrError.(error); ok { + return nil, nil, err + } + tcpConn = connOrError.(net.Conn) + case <-time.After(2 * time.Second): + return nil, nil, errors.New("timed out waiting for connection from child process") + } + + record := &recordingConn{ + Conn: tcpConn, + } + + return record, cmd, nil +} + +func (test *serverTest) dataPath() string { + return filepath.Join("testdata", "Server-"+test.name) +} + +func (test *serverTest) loadData() (flows [][]byte, err error) { + in, err := os.Open(test.dataPath()) + if err != nil { + return nil, err + } + defer in.Close() + return parseTestData(in) +} + +func (test *serverTest) run(t *testing.T, write bool) { + var clientConn, serverConn net.Conn + var recordingConn *recordingConn + var childProcess *exec.Cmd + + if write { + var err error + recordingConn, childProcess, err = test.connFromCommand() + if err != nil { + t.Fatalf("Failed to start subcommand: %s", err) + } + serverConn = recordingConn + defer func() { + if t.Failed() { + t.Logf("OpenSSL output:\n\n%s", childProcess.Stdout) + } + }() + } else { + clientConn, serverConn = localPipe(t) + } + config := test.config + if config == nil { + config = testConfig + } + server := Server(serverConn, config) + connStateChan := make(chan ConnectionState, 1) + go func() { + _, err := server.Write([]byte("hello, world\n")) + if len(test.expectHandshakeErrorIncluding) > 0 { + if err == nil { + t.Errorf("Error expected, but no error returned") + } else if s := err.Error(); !strings.Contains(s, test.expectHandshakeErrorIncluding) { + t.Errorf("Error expected containing '%s' but got '%s'", test.expectHandshakeErrorIncluding, s) + } + } else { + if err != nil { + t.Logf("Error from Server.Write: '%s'", err) + } + } + server.Close() + serverConn.Close() + connStateChan <- server.ConnectionState() + }() + + if !write { + flows, err := test.loadData() + if err != nil { + t.Fatalf("%s: failed to load data from %s", test.name, test.dataPath()) + } + for i, b := range flows { + if i%2 == 0 { + if *fast { + clientConn.SetWriteDeadline(time.Now().Add(1 * time.Second)) + } else { + clientConn.SetWriteDeadline(time.Now().Add(1 * time.Minute)) + } + clientConn.Write(b) + continue + } + bb := make([]byte, len(b)) + if *fast { + clientConn.SetReadDeadline(time.Now().Add(1 * time.Second)) + } else { + clientConn.SetReadDeadline(time.Now().Add(1 * time.Minute)) + } + n, err := io.ReadFull(clientConn, bb) + if err != nil { + t.Fatalf("%s #%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", test.name, i+1, err, n, len(bb), bb[:n], b) + } + if !bytes.Equal(b, bb) { + t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b) + } + } + clientConn.Close() + } + + connState := <-connStateChan + peerCerts := connState.PeerCertificates + if len(peerCerts) == len(test.expectedPeerCerts) { + for i, peerCert := range peerCerts { + block, _ := pem.Decode([]byte(test.expectedPeerCerts[i])) + if !bytes.Equal(block.Bytes, peerCert.Raw) { + t.Fatalf("%s: mismatch on peer cert %d", test.name, i+1) + } + } + } else { + t.Fatalf("%s: mismatch on peer list length: %d (wanted) != %d (got)", test.name, len(test.expectedPeerCerts), len(peerCerts)) + } + + if test.validate != nil { + if err := test.validate(connState); err != nil { + t.Fatalf("validate callback returned error: %s", err) + } + } + + if write { + path := test.dataPath() + out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + t.Fatalf("Failed to create output file: %s", err) + } + defer out.Close() + recordingConn.Close() + if len(recordingConn.flows) < 3 { + if len(test.expectHandshakeErrorIncluding) == 0 { + t.Fatalf("Handshake failed") + } + } + recordingConn.WriteTo(out) + t.Logf("Wrote %s\n", path) + childProcess.Wait() + } +} + +func runServerTestForVersion(t *testing.T, template *serverTest, version, option string) { + // Make a deep copy of the template before going parallel. + test := *template + if template.config != nil { + test.config = template.config.Clone() + } + test.name = version + "-" + test.name + if len(test.command) == 0 { + test.command = defaultClientCommand + } + test.command = append([]string(nil), test.command...) + test.command = append(test.command, option) + + runTestAndUpdateIfNeeded(t, version, test.run, test.wait) +} + +func runServerTestTLS10(t *testing.T, template *serverTest) { + runServerTestForVersion(t, template, "TLSv10", "-tls1") +} + +func runServerTestTLS11(t *testing.T, template *serverTest) { + runServerTestForVersion(t, template, "TLSv11", "-tls1_1") +} + +func runServerTestTLS12(t *testing.T, template *serverTest) { + runServerTestForVersion(t, template, "TLSv12", "-tls1_2") +} + +func runServerTestTLS13(t *testing.T, template *serverTest) { + runServerTestForVersion(t, template, "TLSv13", "-tls1_3") +} + +func TestHandshakeServerRSARC4(t *testing.T) { + test := &serverTest{ + name: "RSA-RC4", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "RC4-SHA"}, + } + runServerTestTLS10(t, test) + runServerTestTLS11(t, test) + runServerTestTLS12(t, test) +} + +func TestHandshakeServerRSA3DES(t *testing.T) { + test := &serverTest{ + name: "RSA-3DES", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "DES-CBC3-SHA"}, + } + runServerTestTLS10(t, test) + runServerTestTLS12(t, test) +} + +func TestHandshakeServerRSAAES(t *testing.T) { + test := &serverTest{ + name: "RSA-AES", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA"}, + } + runServerTestTLS10(t, test) + runServerTestTLS12(t, test) +} + +func TestHandshakeServerAESGCM(t *testing.T) { + test := &serverTest{ + name: "RSA-AES-GCM", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "ECDHE-RSA-AES128-GCM-SHA256"}, + } + runServerTestTLS12(t, test) +} + +func TestHandshakeServerAES256GCMSHA384(t *testing.T) { + test := &serverTest{ + name: "RSA-AES256-GCM-SHA384", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "ECDHE-RSA-AES256-GCM-SHA384"}, + } + runServerTestTLS12(t, test) +} + +func TestHandshakeServerAES128SHA256(t *testing.T) { + test := &serverTest{ + name: "AES128-SHA256", + command: []string{"openssl", "s_client", "-no_ticket", "-ciphersuites", "TLS_AES_128_GCM_SHA256"}, + } + runServerTestTLS13(t, test) +} +func TestHandshakeServerAES256SHA384(t *testing.T) { + test := &serverTest{ + name: "AES256-SHA384", + command: []string{"openssl", "s_client", "-no_ticket", "-ciphersuites", "TLS_AES_256_GCM_SHA384"}, + } + runServerTestTLS13(t, test) +} +func TestHandshakeServerCHACHA20SHA256(t *testing.T) { + test := &serverTest{ + name: "CHACHA20-SHA256", + command: []string{"openssl", "s_client", "-no_ticket", "-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256"}, + } + runServerTestTLS13(t, test) +} + +func TestHandshakeServerECDHEECDSAAES(t *testing.T) { + config := testConfig.Clone() + config.Certificates = make([]Certificate, 1) + config.Certificates[0].Certificate = [][]byte{testECDSACertificate} + config.Certificates[0].PrivateKey = testECDSAPrivateKey + config.BuildNameToCertificate() + + test := &serverTest{ + name: "ECDHE-ECDSA-AES", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "ECDHE-ECDSA-AES256-SHA", "-ciphersuites", "TLS_AES_128_GCM_SHA256"}, + config: config, + } + runServerTestTLS10(t, test) + runServerTestTLS12(t, test) + runServerTestTLS13(t, test) +} + +func TestHandshakeServerX25519(t *testing.T) { + config := testConfig.Clone() + config.CurvePreferences = []CurveID{X25519} + + test := &serverTest{ + name: "X25519", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "ECDHE-RSA-CHACHA20-POLY1305", "-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256", "-curves", "X25519"}, + config: config, + } + runServerTestTLS12(t, test) + runServerTestTLS13(t, test) +} + +func TestHandshakeServerP256(t *testing.T) { + config := testConfig.Clone() + config.CurvePreferences = []CurveID{CurveP256} + + test := &serverTest{ + name: "P256", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "ECDHE-RSA-CHACHA20-POLY1305", "-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256", "-curves", "P-256"}, + config: config, + } + runServerTestTLS12(t, test) + runServerTestTLS13(t, test) +} + +func TestHandshakeServerHelloRetryRequest(t *testing.T) { + config := testConfig.Clone() + config.CurvePreferences = []CurveID{CurveP256} + + test := &serverTest{ + name: "HelloRetryRequest", + command: []string{"openssl", "s_client", "-no_ticket", "-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256", "-curves", "X25519:P-256"}, + config: config, + } + runServerTestTLS13(t, test) +} + +func TestHandshakeServerALPN(t *testing.T) { + config := testConfig.Clone() + config.NextProtos = []string{"proto1", "proto2"} + + test := &serverTest{ + name: "ALPN", + // Note that this needs OpenSSL 1.0.2 because that is the first + // version that supports the -alpn flag. + command: []string{"openssl", "s_client", "-alpn", "proto2,proto1", "-cipher", "ECDHE-RSA-CHACHA20-POLY1305", "-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256"}, + config: config, + validate: func(state ConnectionState) error { + // The server's preferences should override the client. + if state.NegotiatedProtocol != "proto1" { + return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol) + } + return nil + }, + } + runServerTestTLS12(t, test) + runServerTestTLS13(t, test) +} + +func TestHandshakeServerALPNNoMatch(t *testing.T) { + config := testConfig.Clone() + config.NextProtos = []string{"proto3"} + + test := &serverTest{ + name: "ALPN-NoMatch", + // Note that this needs OpenSSL 1.0.2 because that is the first + // version that supports the -alpn flag. + command: []string{"openssl", "s_client", "-alpn", "proto2,proto1", "-cipher", "ECDHE-RSA-CHACHA20-POLY1305", "-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256"}, + config: config, + validate: func(state ConnectionState) error { + // Rather than reject the connection, Go doesn't select + // a protocol when there is no overlap. + if state.NegotiatedProtocol != "" { + return fmt.Errorf("Got protocol %q, wanted ''", state.NegotiatedProtocol) + } + return nil + }, + } + runServerTestTLS12(t, test) + runServerTestTLS13(t, test) +} + +// TestHandshakeServerSNI involves a client sending an SNI extension of +// "snitest.com", which happens to match the CN of testSNICertificate. The test +// verifies that the server correctly selects that certificate. +func TestHandshakeServerSNI(t *testing.T) { + test := &serverTest{ + name: "SNI", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-servername", "snitest.com"}, + } + runServerTestTLS12(t, test) +} + +// TestHandshakeServerSNICertForName is similar to TestHandshakeServerSNI, but +// tests the dynamic GetCertificate method +func TestHandshakeServerSNIGetCertificate(t *testing.T) { + config := testConfig.Clone() + + // Replace the NameToCertificate map with a GetCertificate function + nameToCert := config.NameToCertificate + config.NameToCertificate = nil + config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) { + cert := nameToCert[clientHello.ServerName] + return cert, nil + } + test := &serverTest{ + name: "SNI-GetCertificate", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-servername", "snitest.com"}, + config: config, + } + runServerTestTLS12(t, test) +} + +// TestHandshakeServerSNICertForNameNotFound is similar to +// TestHandshakeServerSNICertForName, but tests to make sure that when the +// GetCertificate method doesn't return a cert, we fall back to what's in +// the NameToCertificate map. +func TestHandshakeServerSNIGetCertificateNotFound(t *testing.T) { + config := testConfig.Clone() + + config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) { + return nil, nil + } + test := &serverTest{ + name: "SNI-GetCertificateNotFound", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-servername", "snitest.com"}, + config: config, + } + runServerTestTLS12(t, test) +} + +// TestHandshakeServerSNICertForNameError tests to make sure that errors in +// GetCertificate result in a tls alert. +func TestHandshakeServerSNIGetCertificateError(t *testing.T) { + const errMsg = "TestHandshakeServerSNIGetCertificateError error" + + serverConfig := testConfig.Clone() + serverConfig.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) { + return nil, errors.New(errMsg) + } + + clientHello := &clientHelloMsg{ + vers: VersionTLS10, + random: make([]byte, 32), + cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + compressionMethods: []uint8{compressionNone}, + serverName: "test", + } + testClientHelloFailure(t, serverConfig, clientHello, errMsg) +} + +// TestHandshakeServerEmptyCertificates tests that GetCertificates is called in +// the case that Certificates is empty, even without SNI. +func TestHandshakeServerEmptyCertificates(t *testing.T) { + const errMsg = "TestHandshakeServerEmptyCertificates error" + + serverConfig := testConfig.Clone() + serverConfig.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) { + return nil, errors.New(errMsg) + } + serverConfig.Certificates = nil + + clientHello := &clientHelloMsg{ + vers: VersionTLS10, + random: make([]byte, 32), + cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + compressionMethods: []uint8{compressionNone}, + } + testClientHelloFailure(t, serverConfig, clientHello, errMsg) + + // With an empty Certificates and a nil GetCertificate, the server + // should always return a “no certificates” error. + serverConfig.GetCertificate = nil + + clientHello = &clientHelloMsg{ + vers: VersionTLS10, + random: make([]byte, 32), + cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + compressionMethods: []uint8{compressionNone}, + } + testClientHelloFailure(t, serverConfig, clientHello, "no certificates") +} + +// TestCipherSuiteCertPreferance ensures that we select an RSA ciphersuite with +// an RSA certificate and an ECDSA ciphersuite with an ECDSA certificate. +func TestCipherSuiteCertPreferenceECDSA(t *testing.T) { + config := testConfig.Clone() + config.CipherSuites = []uint16{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA} + config.PreferServerCipherSuites = true + + test := &serverTest{ + name: "CipherSuiteCertPreferenceRSA", + config: config, + } + runServerTestTLS12(t, test) + + config = testConfig.Clone() + config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA} + config.Certificates = []Certificate{ + { + Certificate: [][]byte{testECDSACertificate}, + PrivateKey: testECDSAPrivateKey, + }, + } + config.BuildNameToCertificate() + config.PreferServerCipherSuites = true + + test = &serverTest{ + name: "CipherSuiteCertPreferenceECDSA", + config: config, + } + runServerTestTLS12(t, test) +} + +func TestServerResumption(t *testing.T) { + sessionFilePath := tempFile("") + defer os.Remove(sessionFilePath) + + testIssue := &serverTest{ + name: "IssueTicket", + command: []string{"openssl", "s_client", "-cipher", "AES128-SHA", "-ciphersuites", "TLS_AES_128_GCM_SHA256", "-sess_out", sessionFilePath}, + wait: true, + } + testResume := &serverTest{ + name: "Resume", + command: []string{"openssl", "s_client", "-cipher", "AES128-SHA", "-ciphersuites", "TLS_AES_128_GCM_SHA256", "-sess_in", sessionFilePath}, + validate: func(state ConnectionState) error { + if !state.DidResume { + return errors.New("did not resume") + } + return nil + }, + } + + runServerTestTLS12(t, testIssue) + runServerTestTLS12(t, testResume) + + runServerTestTLS13(t, testIssue) + runServerTestTLS13(t, testResume) + + config := testConfig.Clone() + config.CurvePreferences = []CurveID{CurveP256} + + testResumeHRR := &serverTest{ + name: "Resume-HelloRetryRequest", + command: []string{"openssl", "s_client", "-curves", "X25519:P-256", "-cipher", "AES128-SHA", "-ciphersuites", + "TLS_AES_128_GCM_SHA256", "-sess_in", sessionFilePath}, + config: config, + validate: func(state ConnectionState) error { + if !state.DidResume { + return errors.New("did not resume") + } + return nil + }, + } + + runServerTestTLS13(t, testResumeHRR) +} + +func TestServerResumptionDisabled(t *testing.T) { + sessionFilePath := tempFile("") + defer os.Remove(sessionFilePath) + + config := testConfig.Clone() + + testIssue := &serverTest{ + name: "IssueTicketPreDisable", + command: []string{"openssl", "s_client", "-cipher", "AES128-SHA", "-ciphersuites", "TLS_AES_128_GCM_SHA256", "-sess_out", sessionFilePath}, + config: config, + wait: true, + } + testResume := &serverTest{ + name: "ResumeDisabled", + command: []string{"openssl", "s_client", "-cipher", "AES128-SHA", "-ciphersuites", "TLS_AES_128_GCM_SHA256", "-sess_in", sessionFilePath}, + config: config, + validate: func(state ConnectionState) error { + if state.DidResume { + return errors.New("resumed with SessionTicketsDisabled") + } + return nil + }, + } + + config.SessionTicketsDisabled = false + runServerTestTLS12(t, testIssue) + config.SessionTicketsDisabled = true + runServerTestTLS12(t, testResume) + + config.SessionTicketsDisabled = false + runServerTestTLS13(t, testIssue) + config.SessionTicketsDisabled = true + runServerTestTLS13(t, testResume) +} + +func TestFallbackSCSV(t *testing.T) { + serverConfig := Config{ + Certificates: testConfig.Certificates, + } + test := &serverTest{ + name: "FallbackSCSV", + config: &serverConfig, + // OpenSSL 1.0.1j is needed for the -fallback_scsv option. + command: []string{"openssl", "s_client", "-fallback_scsv"}, + expectHandshakeErrorIncluding: "inappropriate protocol fallback", + } + runServerTestTLS11(t, test) +} + +func TestHandshakeServerExportKeyingMaterial(t *testing.T) { + test := &serverTest{ + name: "ExportKeyingMaterial", + command: []string{"openssl", "s_client", "-cipher", "ECDHE-RSA-CHACHA20-POLY1305", "-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256"}, + config: testConfig.Clone(), + validate: func(state ConnectionState) error { + if km, err := state.ExportKeyingMaterial("test", nil, 42); err != nil { + return fmt.Errorf("ExportKeyingMaterial failed: %v", err) + } else if len(km) != 42 { + return fmt.Errorf("Got %d bytes from ExportKeyingMaterial, wanted %d", len(km), 42) + } + return nil + }, + } + runServerTestTLS10(t, test) + runServerTestTLS12(t, test) + runServerTestTLS13(t, test) +} + +func TestHandshakeServerRSAPKCS1v15(t *testing.T) { + test := &serverTest{ + name: "RSA-RSAPKCS1v15", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "ECDHE-RSA-CHACHA20-POLY1305", "-sigalgs", "rsa_pkcs1_sha256"}, + } + runServerTestTLS12(t, test) +} + +func TestHandshakeServerRSAPSS(t *testing.T) { + // We send rsa_pss_rsae_sha512 first, as the test key won't fit, and we + // verify the server implementation will disregard the client preference in + // that case. See Issue 29793. + test := &serverTest{ + name: "RSA-RSAPSS", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "ECDHE-RSA-CHACHA20-POLY1305", "-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256", "-sigalgs", "rsa_pss_rsae_sha512:rsa_pss_rsae_sha256"}, + } + runServerTestTLS12(t, test) + runServerTestTLS13(t, test) + + test = &serverTest{ + name: "RSA-RSAPSS-TooSmall", + command: []string{"openssl", "s_client", "-no_ticket", "-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256", "-sigalgs", "rsa_pss_rsae_sha512"}, + expectHandshakeErrorIncluding: "peer doesn't support any of the certificate's signature algorithms", + } + runServerTestTLS13(t, test) +} + +func TestHandshakeServerEd25519(t *testing.T) { + config := testConfig.Clone() + config.Certificates = make([]Certificate, 1) + config.Certificates[0].Certificate = [][]byte{testEd25519Certificate} + config.Certificates[0].PrivateKey = testEd25519PrivateKey + config.BuildNameToCertificate() + + test := &serverTest{ + name: "Ed25519", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "ECDHE-ECDSA-CHACHA20-POLY1305", "-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256"}, + config: config, + } + runServerTestTLS12(t, test) + runServerTestTLS13(t, test) +} + +func benchmarkHandshakeServer(b *testing.B, version uint16, cipherSuite uint16, curve CurveID, cert []byte, key crypto.PrivateKey) { + config := testConfig.Clone() + config.CipherSuites = []uint16{cipherSuite} + config.CurvePreferences = []CurveID{curve} + config.Certificates = make([]Certificate, 1) + config.Certificates[0].Certificate = [][]byte{cert} + config.Certificates[0].PrivateKey = key + config.BuildNameToCertificate() + + clientConn, serverConn := localPipe(b) + serverConn = &recordingConn{Conn: serverConn} + go func() { + config := testConfig.Clone() + config.MaxVersion = version + config.CurvePreferences = []CurveID{curve} + client := Client(clientConn, config) + client.Handshake() + }() + server := Server(serverConn, config) + if err := server.Handshake(); err != nil { + b.Fatalf("handshake failed: %v", err) + } + serverConn.Close() + flows := serverConn.(*recordingConn).flows + + feeder := make(chan struct{}) + clientConn, serverConn = localPipe(b) + + go func() { + for range feeder { + for i, f := range flows { + if i%2 == 0 { + clientConn.Write(f) + continue + } + ff := make([]byte, len(f)) + n, err := io.ReadFull(clientConn, ff) + if err != nil { + b.Errorf("#%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", i+1, err, n, len(ff), ff[:n], f) + } + if !bytes.Equal(f, ff) { + b.Errorf("#%d: mismatch on read: got:%x want:%x", i+1, ff, f) + } + } + } + }() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + feeder <- struct{}{} + server := Server(serverConn, config) + if err := server.Handshake(); err != nil { + b.Fatalf("handshake failed: %v", err) + } + } + close(feeder) +} + +func BenchmarkHandshakeServer(b *testing.B) { + b.Run("RSA", func(b *testing.B) { + benchmarkHandshakeServer(b, VersionTLS12, TLS_RSA_WITH_AES_128_GCM_SHA256, + 0, testRSACertificate, testRSAPrivateKey) + }) + b.Run("ECDHE-P256-RSA", func(b *testing.B) { + b.Run("TLSv13", func(b *testing.B) { + benchmarkHandshakeServer(b, VersionTLS13, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + CurveP256, testRSACertificate, testRSAPrivateKey) + }) + b.Run("TLSv12", func(b *testing.B) { + benchmarkHandshakeServer(b, VersionTLS12, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + CurveP256, testRSACertificate, testRSAPrivateKey) + }) + }) + b.Run("ECDHE-P256-ECDSA-P256", func(b *testing.B) { + b.Run("TLSv13", func(b *testing.B) { + benchmarkHandshakeServer(b, VersionTLS13, TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + CurveP256, testP256Certificate, testP256PrivateKey) + }) + b.Run("TLSv12", func(b *testing.B) { + benchmarkHandshakeServer(b, VersionTLS12, TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + CurveP256, testP256Certificate, testP256PrivateKey) + }) + }) + b.Run("ECDHE-X25519-ECDSA-P256", func(b *testing.B) { + b.Run("TLSv13", func(b *testing.B) { + benchmarkHandshakeServer(b, VersionTLS13, TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + X25519, testP256Certificate, testP256PrivateKey) + }) + b.Run("TLSv12", func(b *testing.B) { + benchmarkHandshakeServer(b, VersionTLS12, TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + X25519, testP256Certificate, testP256PrivateKey) + }) + }) + b.Run("ECDHE-P521-ECDSA-P521", func(b *testing.B) { + if testECDSAPrivateKey.PublicKey.Curve != elliptic.P521() { + b.Fatal("test ECDSA key doesn't use curve P-521") + } + b.Run("TLSv13", func(b *testing.B) { + benchmarkHandshakeServer(b, VersionTLS13, TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + CurveP521, testECDSACertificate, testECDSAPrivateKey) + }) + b.Run("TLSv12", func(b *testing.B) { + benchmarkHandshakeServer(b, VersionTLS12, TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + CurveP521, testECDSACertificate, testECDSAPrivateKey) + }) + }) +} + +func TestClientAuth(t *testing.T) { + var certPath, keyPath, ecdsaCertPath, ecdsaKeyPath, ed25519CertPath, ed25519KeyPath string + + if *update { + certPath = tempFile(clientCertificatePEM) + defer os.Remove(certPath) + keyPath = tempFile(clientKeyPEM) + defer os.Remove(keyPath) + ecdsaCertPath = tempFile(clientECDSACertificatePEM) + defer os.Remove(ecdsaCertPath) + ecdsaKeyPath = tempFile(clientECDSAKeyPEM) + defer os.Remove(ecdsaKeyPath) + ed25519CertPath = tempFile(clientEd25519CertificatePEM) + defer os.Remove(ed25519CertPath) + ed25519KeyPath = tempFile(clientEd25519KeyPEM) + defer os.Remove(ed25519KeyPath) + } else { + t.Parallel() + } + + config := testConfig.Clone() + config.ClientAuth = RequestClientCert + + test := &serverTest{ + name: "ClientAuthRequestedNotGiven", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-ciphersuites", "TLS_AES_128_GCM_SHA256"}, + config: config, + } + runServerTestTLS12(t, test) + runServerTestTLS13(t, test) + + test = &serverTest{ + name: "ClientAuthRequestedAndGiven", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-ciphersuites", "TLS_AES_128_GCM_SHA256", + "-cert", certPath, "-key", keyPath, "-client_sigalgs", "rsa_pss_rsae_sha256"}, + config: config, + expectedPeerCerts: []string{clientCertificatePEM}, + } + runServerTestTLS12(t, test) + runServerTestTLS13(t, test) + + test = &serverTest{ + name: "ClientAuthRequestedAndECDSAGiven", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-ciphersuites", "TLS_AES_128_GCM_SHA256", + "-cert", ecdsaCertPath, "-key", ecdsaKeyPath}, + config: config, + expectedPeerCerts: []string{clientECDSACertificatePEM}, + } + runServerTestTLS12(t, test) + runServerTestTLS13(t, test) + + test = &serverTest{ + name: "ClientAuthRequestedAndEd25519Given", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-ciphersuites", "TLS_AES_128_GCM_SHA256", + "-cert", ed25519CertPath, "-key", ed25519KeyPath}, + config: config, + expectedPeerCerts: []string{clientEd25519CertificatePEM}, + } + runServerTestTLS12(t, test) + runServerTestTLS13(t, test) + + test = &serverTest{ + name: "ClientAuthRequestedAndPKCS1v15Given", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", + "-cert", certPath, "-key", keyPath, "-client_sigalgs", "rsa_pkcs1_sha256"}, + config: config, + expectedPeerCerts: []string{clientCertificatePEM}, + } + runServerTestTLS12(t, test) +} + +func TestSNIGivenOnFailure(t *testing.T) { + const expectedServerName = "test.testing" + + clientHello := &clientHelloMsg{ + vers: VersionTLS10, + random: make([]byte, 32), + cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + compressionMethods: []uint8{compressionNone}, + serverName: expectedServerName, + } + + serverConfig := testConfig.Clone() + // Erase the server's cipher suites to ensure the handshake fails. + serverConfig.CipherSuites = nil + + c, s := localPipe(t) + go func() { + cli := Client(c, testConfig) + cli.vers = clientHello.vers + cli.writeRecord(recordTypeHandshake, clientHello.marshal()) + c.Close() + }() + conn := Server(s, serverConfig) + ch, err := conn.readClientHello() + hs := serverHandshakeState{ + c: conn, + clientHello: ch, + } + if err == nil { + err = hs.processClientHello() + } + if err == nil { + err = hs.pickCipherSuite() + } + defer s.Close() + + if err == nil { + t.Error("No error reported from server") + } + + cs := hs.c.ConnectionState() + if cs.HandshakeComplete { + t.Error("Handshake registered as complete") + } + + if cs.ServerName != expectedServerName { + t.Errorf("Expected ServerName of %q, but got %q", expectedServerName, cs.ServerName) + } +} + +var getConfigForClientTests = []struct { + setup func(config *Config) + callback func(clientHello *ClientHelloInfo) (*Config, error) + errorSubstring string + verify func(config *Config) error +}{ + { + nil, + func(clientHello *ClientHelloInfo) (*Config, error) { + return nil, nil + }, + "", + nil, + }, + { + nil, + func(clientHello *ClientHelloInfo) (*Config, error) { + return nil, errors.New("should bubble up") + }, + "should bubble up", + nil, + }, + { + nil, + func(clientHello *ClientHelloInfo) (*Config, error) { + config := testConfig.Clone() + // Setting a maximum version of TLS 1.1 should cause + // the handshake to fail, as the client MinVersion is TLS 1.2. + config.MaxVersion = VersionTLS11 + return config, nil + }, + "client offered only unsupported versions", + nil, + }, + { + func(config *Config) { + for i := range config.SessionTicketKey { + config.SessionTicketKey[i] = byte(i) + } + config.sessionTicketKeys = nil + }, + func(clientHello *ClientHelloInfo) (*Config, error) { + config := testConfig.Clone() + for i := range config.SessionTicketKey { + config.SessionTicketKey[i] = 0 + } + config.sessionTicketKeys = nil + return config, nil + }, + "", + func(config *Config) error { + if config.SessionTicketKey == [32]byte{} { + return fmt.Errorf("expected SessionTicketKey to be set") + } + return nil + }, + }, + { + func(config *Config) { + var dummyKey [32]byte + for i := range dummyKey { + dummyKey[i] = byte(i) + } + + config.SetSessionTicketKeys([][32]byte{dummyKey}) + }, + func(clientHello *ClientHelloInfo) (*Config, error) { + config := testConfig.Clone() + config.sessionTicketKeys = nil + return config, nil + }, + "", + func(config *Config) error { + if config.SessionTicketKey == [32]byte{} { + return fmt.Errorf("expected SessionTicketKey to be set") + } + return nil + }, + }, +} + +func TestGetConfigForClient(t *testing.T) { + serverConfig := testConfig.Clone() + clientConfig := testConfig.Clone() + clientConfig.MinVersion = VersionTLS12 + + for i, test := range getConfigForClientTests { + if test.setup != nil { + test.setup(serverConfig) + } + + var configReturned *Config + serverConfig.GetConfigForClient = func(clientHello *ClientHelloInfo) (*Config, error) { + config, err := test.callback(clientHello) + configReturned = config + return config, err + } + c, s := localPipe(t) + done := make(chan error) + + go func() { + defer s.Close() + done <- Server(s, serverConfig).Handshake() + }() + + clientErr := Client(c, clientConfig).Handshake() + c.Close() + + serverErr := <-done + + if len(test.errorSubstring) == 0 { + if serverErr != nil || clientErr != nil { + t.Errorf("test[%d]: expected no error but got serverErr: %q, clientErr: %q", i, serverErr, clientErr) + } + if test.verify != nil { + if err := test.verify(configReturned); err != nil { + t.Errorf("test[%d]: verify returned error: %v", i, err) + } + } + } else { + if serverErr == nil { + t.Errorf("test[%d]: expected error containing %q but got no error", i, test.errorSubstring) + } else if !strings.Contains(serverErr.Error(), test.errorSubstring) { + t.Errorf("test[%d]: expected error to contain %q but it was %q", i, test.errorSubstring, serverErr) + } + } + } +} + +func TestCloseServerConnectionOnIdleClient(t *testing.T) { + clientConn, serverConn := localPipe(t) + server := Server(serverConn, testConfig.Clone()) + go func() { + clientConn.Write([]byte{'0'}) + server.Close() + }() + server.SetReadDeadline(time.Now().Add(time.Minute)) + err := server.Handshake() + if err != nil { + if err, ok := err.(net.Error); ok && err.Timeout() { + t.Errorf("Expected a closed network connection error but got '%s'", err.Error()) + } + } else { + t.Errorf("Error expected, but no error returned") + } +} + +func TestCloneHash(t *testing.T) { + h1 := crypto.SHA256.New() + h1.Write([]byte("test")) + s1 := h1.Sum(nil) + h2 := cloneHash(h1, crypto.SHA256) + s2 := h2.Sum(nil) + if !bytes.Equal(s1, s2) { + t.Error("cloned hash generated a different sum") + } +} + +func expectError(t *testing.T, err error, sub string) { + if err == nil { + t.Errorf(`expected error %q, got nil`, sub) + } else if !strings.Contains(err.Error(), sub) { + t.Errorf(`expected error %q, got %q`, sub, err) + } +} + +func TestKeyTooSmallForRSAPSS(t *testing.T) { + cert, err := X509KeyPair([]byte(`-----BEGIN CERTIFICATE----- +MIIBcTCCARugAwIBAgIQGjQnkCFlUqaFlt6ixyz/tDANBgkqhkiG9w0BAQsFADAS +MRAwDgYDVQQKEwdBY21lIENvMB4XDTE5MDExODIzMjMyOFoXDTIwMDExODIzMjMy +OFowEjEQMA4GA1UEChMHQWNtZSBDbzBcMA0GCSqGSIb3DQEBAQUAA0sAMEgCQQDd +ez1rFUDwax2HTxbcnFUP9AhcgEGMHVV2nn4VVEWFJB6I8C/Nkx0XyyQlrmFYBzEQ +nIPhKls4T0hFoLvjJnXpAgMBAAGjTTBLMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUE +DDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBYGA1UdEQQPMA2CC2V4YW1wbGUu +Y29tMA0GCSqGSIb3DQEBCwUAA0EAxDuUS+BrrS3c+h+k+fQPOmOScy6yTX9mHw0Q +KbucGamXYEy0URIwOdO0tQ3LHPc1YGvYSPwkDjkjqECs2Vm/AA== +-----END CERTIFICATE-----`), []byte(testingKey(`-----BEGIN RSA TESTING KEY----- +MIIBOgIBAAJBAN17PWsVQPBrHYdPFtycVQ/0CFyAQYwdVXaefhVURYUkHojwL82T +HRfLJCWuYVgHMRCcg+EqWzhPSEWgu+MmdekCAwEAAQJBALjQYNTdXF4CFBbXwUz/ +yt9QFDYT9B5WT/12jeGAe653gtYS6OOi/+eAkGmzg1GlRnw6fOfn+HYNFDORST7z +4j0CIQDn2xz9hVWQEu9ee3vecNT3f60huDGTNoRhtqgweQGX0wIhAPSLj1VcRZEz +nKpbtU22+PbIMSJ+e80fmY9LIPx5N4HTAiAthGSimMR9bloz0EY3GyuUEyqoDgMd +hXxjuno2WesoJQIgemilbcALXpxsLmZLgcQ2KSmaVr7jb5ECx9R+hYKTw1sCIG4s +T+E0J8wlH24pgwQHzy7Ko2qLwn1b5PW8ecrlvP1g +-----END RSA TESTING KEY-----`))) + if err != nil { + t.Fatal(err) + } + + clientConn, serverConn := localPipe(t) + client := Client(clientConn, testConfig) + done := make(chan struct{}) + go func() { + config := testConfig.Clone() + config.Certificates = []Certificate{cert} + config.MinVersion = VersionTLS13 + server := Server(serverConn, config) + err := server.Handshake() + expectError(t, err, "key size too small") + close(done) + }() + err = client.Handshake() + expectError(t, err, "handshake failure") + <-done +} + +func TestMultipleCertificates(t *testing.T) { + clientConfig := testConfig.Clone() + clientConfig.CipherSuites = []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256} + clientConfig.MaxVersion = VersionTLS12 + + serverConfig := testConfig.Clone() + serverConfig.Certificates = []Certificate{{ + Certificate: [][]byte{testECDSACertificate}, + PrivateKey: testECDSAPrivateKey, + }, { + Certificate: [][]byte{testRSACertificate}, + PrivateKey: testRSAPrivateKey, + }} + + _, clientState, err := testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatal(err) + } + if got := clientState.PeerCertificates[0].PublicKeyAlgorithm; got != x509.RSA { + t.Errorf("expected RSA certificate, got %v", got) + } +} + +func TestAESCipherReordering(t *testing.T) { + currentAESSupport := hasAESGCMHardwareSupport + defer func() { hasAESGCMHardwareSupport = currentAESSupport; initDefaultCipherSuites() }() + + tests := []struct { + name string + clientCiphers []uint16 + serverHasAESGCM bool + preferServerCipherSuites bool + serverCiphers []uint16 + expectedCipher uint16 + }{ + { + name: "server has hardware AES, client doesn't (pick ChaCha)", + clientCiphers: []uint16{ + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_RSA_WITH_AES_128_CBC_SHA, + }, + serverHasAESGCM: true, + preferServerCipherSuites: true, + expectedCipher: TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + }, + { + name: "server strongly prefers AES-GCM, client doesn't (pick AES-GCM)", + clientCiphers: []uint16{ + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_RSA_WITH_AES_128_CBC_SHA, + }, + serverHasAESGCM: true, + preferServerCipherSuites: true, + serverCiphers: []uint16{ + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_RSA_WITH_AES_128_CBC_SHA, + }, + expectedCipher: TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + }, + { + name: "client prefers AES-GCM, server doesn't have hardware AES (pick ChaCha)", + clientCiphers: []uint16{ + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_RSA_WITH_AES_128_CBC_SHA, + }, + serverHasAESGCM: false, + expectedCipher: TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + }, + { + name: "client prefers AES-GCM, server has hardware AES (pick AES-GCM)", + clientCiphers: []uint16{ + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_RSA_WITH_AES_128_CBC_SHA, + }, + serverHasAESGCM: true, + expectedCipher: TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + }, + { + name: "client prefers AES-GCM and sends GREASE, server has hardware AES (pick AES-GCM)", + clientCiphers: []uint16{ + 0x0A0A, // GREASE value + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_RSA_WITH_AES_128_CBC_SHA, + }, + serverHasAESGCM: true, + expectedCipher: TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + }, + { + name: "client prefers AES-GCM and doesn't support ChaCha, server doesn't have hardware AES (pick AES-GCM)", + clientCiphers: []uint16{ + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_RSA_WITH_AES_128_CBC_SHA, + }, + serverHasAESGCM: false, + expectedCipher: TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + }, + { + name: "client prefers AES-GCM and AES-CBC over ChaCha, server doesn't have hardware AES (pick AES-GCM)", + clientCiphers: []uint16{ + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_RSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + }, + serverHasAESGCM: false, + expectedCipher: TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + }, + { + name: "client prefers AES-GCM over ChaCha and sends GREASE, server doesn't have hardware AES (pick ChaCha)", + clientCiphers: []uint16{ + 0x0A0A, // GREASE value + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_RSA_WITH_AES_128_CBC_SHA, + }, + serverHasAESGCM: false, + expectedCipher: TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + }, + { + name: "client supports multiple AES-GCM, server doesn't have hardware AES and doesn't support ChaCha (pick corrent AES-GCM)", + clientCiphers: []uint16{ + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + }, + serverHasAESGCM: false, + serverCiphers: []uint16{ + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + }, + expectedCipher: TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + hasAESGCMHardwareSupport = tc.serverHasAESGCM + initDefaultCipherSuites() + hs := &serverHandshakeState{ + c: &Conn{ + config: &Config{ + PreferServerCipherSuites: tc.preferServerCipherSuites, + CipherSuites: tc.serverCiphers, + }, + vers: VersionTLS12, + }, + clientHello: &clientHelloMsg{ + cipherSuites: tc.clientCiphers, + vers: VersionTLS12, + }, + ecdheOk: true, + rsaSignOk: true, + rsaDecryptOk: true, + } + + err := hs.pickCipherSuite() + if err != nil { + t.Errorf("pickCipherSuite failed: %s", err) + } + + if tc.expectedCipher != hs.suite.id { + t.Errorf("unexpected cipher chosen: want %d, got %d", tc.expectedCipher, hs.suite.id) + } + }) + } +} + +func TestAESCipherReordering13(t *testing.T) { + currentAESSupport := hasAESGCMHardwareSupport + defer func() { hasAESGCMHardwareSupport = currentAESSupport; initDefaultCipherSuites() }() + + tests := []struct { + name string + clientCiphers []uint16 + serverHasAESGCM bool + preferServerCipherSuites bool + expectedCipher uint16 + }{ + { + name: "server has hardware AES, client doesn't (pick ChaCha)", + clientCiphers: []uint16{ + TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_128_GCM_SHA256, + }, + serverHasAESGCM: true, + preferServerCipherSuites: true, + expectedCipher: TLS_CHACHA20_POLY1305_SHA256, + }, + { + name: "neither server nor client have hardware AES (pick ChaCha)", + clientCiphers: []uint16{ + TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_128_GCM_SHA256, + }, + serverHasAESGCM: false, + preferServerCipherSuites: true, + expectedCipher: TLS_CHACHA20_POLY1305_SHA256, + }, + { + name: "client prefers AES, server doesn't have hardware, prefer server ciphers (pick ChaCha)", + clientCiphers: []uint16{ + TLS_AES_128_GCM_SHA256, + TLS_CHACHA20_POLY1305_SHA256, + }, + serverHasAESGCM: false, + preferServerCipherSuites: true, + expectedCipher: TLS_CHACHA20_POLY1305_SHA256, + }, + { + name: "client prefers AES and sends GREASE, server doesn't have hardware, prefer server ciphers (pick ChaCha)", + clientCiphers: []uint16{ + 0x0A0A, // GREASE value + TLS_AES_128_GCM_SHA256, + TLS_CHACHA20_POLY1305_SHA256, + }, + serverHasAESGCM: false, + preferServerCipherSuites: true, + expectedCipher: TLS_CHACHA20_POLY1305_SHA256, + }, + { + name: "client prefers AES, server doesn't (pick ChaCha)", + clientCiphers: []uint16{ + TLS_AES_128_GCM_SHA256, + TLS_CHACHA20_POLY1305_SHA256, + }, + serverHasAESGCM: false, + expectedCipher: TLS_CHACHA20_POLY1305_SHA256, + }, + { + name: "client prefers AES, server has hardware AES (pick AES)", + clientCiphers: []uint16{ + TLS_AES_128_GCM_SHA256, + TLS_CHACHA20_POLY1305_SHA256, + }, + serverHasAESGCM: true, + expectedCipher: TLS_AES_128_GCM_SHA256, + }, + { + name: "client prefers AES and sends GREASE, server has hardware AES (pick AES)", + clientCiphers: []uint16{ + 0x0A0A, // GREASE value + TLS_AES_128_GCM_SHA256, + TLS_CHACHA20_POLY1305_SHA256, + }, + serverHasAESGCM: true, + expectedCipher: TLS_AES_128_GCM_SHA256, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + hasAESGCMHardwareSupport = tc.serverHasAESGCM + initDefaultCipherSuites() + hs := &serverHandshakeStateTLS13{ + c: &Conn{ + config: &Config{ + PreferServerCipherSuites: tc.preferServerCipherSuites, + }, + vers: VersionTLS13, + }, + clientHello: &clientHelloMsg{ + cipherSuites: tc.clientCiphers, + supportedVersions: []uint16{VersionTLS13}, + compressionMethods: []uint8{compressionNone}, + keyShares: []keyShare{{group: X25519, data: curve25519.Basepoint}}, + }, + } + + err := hs.processClientHello() + if err != nil { + t.Errorf("pickCipherSuite failed: %s", err) + } + + if tc.expectedCipher != hs.suite.id { + t.Errorf("unexpected cipher chosen: want %d, got %d", tc.expectedCipher, hs.suite.id) + } + }) + } +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_server_tls13.go b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_server_tls13.go new file mode 100644 index 0000000..bf93c8a --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_server_tls13.go @@ -0,0 +1,971 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "crypto" + "crypto/hmac" + "crypto/rsa" + "errors" + "hash" + "io" + "sync/atomic" + "time" +) + +// maxClientPSKIdentities is the number of client PSK identities the server will +// attempt to validate. It will ignore the rest not to let cheap ClientHello +// messages cause too much work in session ticket decryption attempts. +const maxClientPSKIdentities = 5 + +type serverHandshakeStateTLS13 struct { + c *Conn + clientHello *clientHelloMsg + hello *serverHelloMsg + sentDummyCCS bool + usingPSK bool + suite *cipherSuiteTLS13 + cert *Certificate + sigAlg SignatureScheme + earlySecret []byte + sharedKey []byte + handshakeSecret []byte + masterSecret []byte + trafficSecret []byte // client_application_traffic_secret_0 + transcript hash.Hash + clientFinished []byte + + err error +} + +func (hs *serverHandshakeStateTLS13) handshake() error { + c := hs.c + + if c.handshakeStatusAsync >= stateServerHandshake13HandshakeDone { + return nil + } + if hs.err != nil && hs.err != errDataNotEnough { + return hs.err + } + + // For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2. + if err := hs.processClientHello(); err != nil { + hs.err = err + return err + } + + if err := hs.checkForResumption(); err != nil { + hs.err = err + return err + } + + if err := hs.pickCertificate(); err != nil { + hs.err = err + return err + } + c.buffering = true + if err := hs.sendServerParameters(); err != nil { + hs.err = err + return err + } + if err := hs.sendServerCertificate(); err != nil { + hs.err = err + return err + } + if err := hs.sendServerFinished(); err != nil { + hs.err = err + return err + } + // Note that at this point we could start sending application data without + // waiting for the client's second flight, but the application might not + // expect the lack of replay protection of the ClientHello parameters. + if _, err := c.flush(); err != nil { + hs.err = err + return err + } + if err := hs.readClientCertificate(); err != nil { + hs.err = err + return err + } + if err := hs.readClientFinished(); err != nil { + hs.err = err + return err + } + + c.handshakeStatusAsync = stateServerHandshake13HandshakeDone + atomic.StoreUint32(&c.handshakeStatus, 1) + return nil +} + +func (hs *serverHandshakeStateTLS13) processClientHello() error { + c := hs.c + + if c.handshakeStatusAsync >= stateServerHandshake13ProcessClientHello { + return nil + } + c.handshakeStatusAsync = stateServerHandshake13ProcessClientHello + + hs.hello = new(serverHelloMsg) + + // TLS 1.3 froze the ServerHello.legacy_version field, and uses + // supported_versions instead. See RFC 8446, sections 4.1.3 and 4.2.1. + hs.hello.vers = VersionTLS12 + hs.hello.supportedVersion = c.vers + + if len(hs.clientHello.supportedVersions) == 0 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client used the legacy version field to negotiate TLS 1.3") + } + + // Abort if the client is doing a fallback and landing lower than what we + // support. See RFC 7507, which however does not specify the interaction + // with supported_versions. The only difference is that with + // supported_versions a client has a chance to attempt a [TLS 1.2, TLS 1.4] + // handshake in case TLS 1.3 is broken but 1.2 is not. Alas, in that case, + // it will have to drop the TLS_FALLBACK_SCSV protection if it falls back to + // TLS 1.2, because a TLS 1.3 server would abort here. The situation before + // supported_versions was not better because there was just no way to do a + // TLS 1.4 handshake without risking the server selecting TLS 1.3. + for _, id := range hs.clientHello.cipherSuites { + if id == TLS_FALLBACK_SCSV { + // Use c.vers instead of max(supported_versions) because an attacker + // could defeat this by adding an arbitrary high version otherwise. + if c.vers < c.config.maxSupportedVersion() { + c.sendAlert(alertInappropriateFallback) + return errors.New("tls: client using inappropriate protocol fallback") + } + break + } + } + + if len(hs.clientHello.compressionMethods) != 1 || + hs.clientHello.compressionMethods[0] != compressionNone { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: TLS 1.3 client supports illegal compression methods") + } + + hs.hello.random = make([]byte, 32) + if _, err := io.ReadFull(c.config.rand(), hs.hello.random); err != nil { + c.sendAlert(alertInternalError) + return err + } + + if len(hs.clientHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: initial handshake had non-empty renegotiation extension") + } + + if hs.clientHello.earlyData { + // See RFC 8446, Section 4.2.10 for the complicated behavior required + // here. The scenario is that a different server at our address offered + // to accept early data in the past, which we can't handle. For now, all + // 0-RTT enabled session tickets need to expire before a Go server can + // replace a server or join a pool. That's the same requirement that + // applies to mixing or replacing with any TLS 1.2 server. + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: client sent unexpected early data") + } + + hs.hello.sessionId = hs.clientHello.sessionId + hs.hello.compressionMethod = compressionNone + + var preferenceList, supportedList []uint16 + if c.config.PreferServerCipherSuites { + preferenceList = defaultCipherSuitesTLS13() + supportedList = hs.clientHello.cipherSuites + + // If the client does not seem to have hardware support for AES-GCM, + // prefer other AEAD ciphers even if we prioritized AES-GCM ciphers + // by default. + if !aesgcmPreferred(hs.clientHello.cipherSuites) { + preferenceList = deprioritizeAES(preferenceList) + } + } else { + preferenceList = hs.clientHello.cipherSuites + supportedList = defaultCipherSuitesTLS13() + + // If we don't have hardware support for AES-GCM, prefer other AEAD + // ciphers even if the client prioritized AES-GCM. + if !hasAESGCMHardwareSupport { + preferenceList = deprioritizeAES(preferenceList) + } + } + for _, suiteID := range preferenceList { + hs.suite = mutualCipherSuiteTLS13(supportedList, suiteID) + if hs.suite != nil { + break + } + } + if hs.suite == nil { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: no cipher suite supported by both client and server") + } + c.cipherSuite = hs.suite.id + hs.hello.cipherSuite = hs.suite.id + hs.transcript = hs.suite.hash.New() + + // Pick the ECDHE group in server preference order, but give priority to + // groups with a key share, to avoid a HelloRetryRequest round-trip. + var selectedGroup CurveID + var clientKeyShare *keyShare +GroupSelection: + for _, preferredGroup := range c.config.curvePreferences() { + for _, ks := range hs.clientHello.keyShares { + if ks.group == preferredGroup { + selectedGroup = ks.group + clientKeyShare = &ks + break GroupSelection + } + } + if selectedGroup != 0 { + continue + } + for _, group := range hs.clientHello.supportedCurves { + if group == preferredGroup { + selectedGroup = group + break + } + } + } + if selectedGroup == 0 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: no ECDHE curve supported by both client and server") + } + if clientKeyShare == nil { + if err := hs.doHelloRetryRequest(selectedGroup); err != nil { + return err + } + clientKeyShare = &hs.clientHello.keyShares[0] + } + + if _, ok := curveForCurveID(selectedGroup); selectedGroup != X25519 && !ok { + c.sendAlert(alertInternalError) + return errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err := generateECDHEParameters(c.config.rand(), selectedGroup) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + hs.hello.serverShare = keyShare{group: selectedGroup, data: params.PublicKey()} + hs.sharedKey = params.SharedKey(clientKeyShare.data) + if hs.sharedKey == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid client key share") + } + + c.serverName = hs.clientHello.serverName + + return nil +} + +func (hs *serverHandshakeStateTLS13) checkForResumption() error { + c := hs.c + + if c.handshakeStatusAsync >= stateServerHandshake13CheckForResumption { + return nil + } + c.handshakeStatusAsync = stateServerHandshake13CheckForResumption + + if c.config.SessionTicketsDisabled { + return nil + } + + modeOK := false + for _, mode := range hs.clientHello.pskModes { + if mode == pskModeDHE { + modeOK = true + break + } + } + if !modeOK { + return nil + } + + if len(hs.clientHello.pskIdentities) != len(hs.clientHello.pskBinders) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid or missing PSK binders") + } + if len(hs.clientHello.pskIdentities) == 0 { + return nil + } + + for i, identity := range hs.clientHello.pskIdentities { + if i >= maxClientPSKIdentities { + break + } + + plaintext, _ := c.decryptTicket(identity.label) + if plaintext == nil { + continue + } + sessionState := new(sessionStateTLS13) + if ok := sessionState.unmarshal(plaintext); !ok { + continue + } + + createdAt := time.Unix(int64(sessionState.createdAt), 0) + if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { + continue + } + + // We don't check the obfuscated ticket age because it's affected by + // clock skew and it's only a freshness signal useful for shrinking the + // window for replay attacks, which don't affect us as we don't do 0-RTT. + + pskSuite := cipherSuiteTLS13ByID(sessionState.cipherSuite) + if pskSuite == nil || pskSuite.hash != hs.suite.hash { + continue + } + + // PSK connections don't re-establish client certificates, but carry + // them over in the session ticket. Ensure the presence of client certs + // in the ticket is consistent with the configured requirements. + sessionHasClientCerts := len(sessionState.certificate.Certificate) != 0 + needClientCerts := requiresClientCert(c.config.ClientAuth) + if needClientCerts && !sessionHasClientCerts { + continue + } + if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { + continue + } + + psk := hs.suite.expandLabel(sessionState.resumptionSecret, "resumption", + nil, hs.suite.hash.Size()) + hs.earlySecret = hs.suite.extract(psk, nil) + binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil) + // Clone the transcript in case a HelloRetryRequest was recorded. + transcript := cloneHash(hs.transcript, hs.suite.hash) + if transcript == nil { + c.sendAlert(alertInternalError) + return errors.New("tls: internal error: failed to clone hash") + } + transcript.Write(hs.clientHello.marshalWithoutBinders()) + pskBinder := hs.suite.finishedHash(binderKey, transcript) + if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid PSK binder") + } + + c.didResume = true + if err := c.processCertsFromClient(sessionState.certificate); err != nil { + return err + } + + hs.hello.selectedIdentityPresent = true + hs.hello.selectedIdentity = uint16(i) + hs.usingPSK = true + + return nil + } + + return nil +} + +// cloneHash uses the encoding.BinaryMarshaler and encoding.BinaryUnmarshaler +// interfaces implemented by standard library hashes to clone the state of in +// to a new instance of h. It returns nil if the operation fails. +func cloneHash(in hash.Hash, h crypto.Hash) hash.Hash { + // Recreate the interface to avoid importing encoding. + type binaryMarshaler interface { + MarshalBinary() (data []byte, err error) + UnmarshalBinary(data []byte) error + } + marshaler, ok := in.(binaryMarshaler) + if !ok { + return nil + } + state, err := marshaler.MarshalBinary() + if err != nil { + return nil + } + out := h.New() + unmarshaler, ok := out.(binaryMarshaler) + if !ok { + return nil + } + if err := unmarshaler.UnmarshalBinary(state); err != nil { + return nil + } + return out +} + +func (hs *serverHandshakeStateTLS13) pickCertificate() error { + c := hs.c + if c.handshakeStatusAsync >= stateServerHandshake13PickCertificate { + return nil + } + c.handshakeStatusAsync = stateServerHandshake13PickCertificate + + // Only one of PSK and certificates are used at a time. + if hs.usingPSK { + return nil + } + + // signature_algorithms is required in TLS 1.3. See RFC 8446, Section 4.2.3. + if len(hs.clientHello.supportedSignatureAlgorithms) == 0 { + return c.sendAlert(alertMissingExtension) + } + + certificate, err := c.config.getCertificate(clientHelloInfo(c, hs.clientHello)) + if err != nil { + if err == errNoCertificates { + c.sendAlert(alertUnrecognizedName) + } else { + c.sendAlert(alertInternalError) + } + return err + } + hs.sigAlg, err = selectSignatureScheme(c.vers, certificate, hs.clientHello.supportedSignatureAlgorithms) + if err != nil { + // getCertificate returned a certificate that is unsupported or + // incompatible with the client's signature algorithms. + c.sendAlert(alertHandshakeFailure) + return err + } + hs.cert = certificate + + return nil +} + +// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility +// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. +func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error { + if hs.sentDummyCCS { + return nil + } + hs.sentDummyCCS = true + + _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + return err +} + +func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error { + c := hs.c + + // The first ClientHello gets double-hashed into the transcript upon a + // HelloRetryRequest. See RFC 8446, Section 4.4.1. + hs.transcript.Write(hs.clientHello.marshal()) + chHash := hs.transcript.Sum(nil) + hs.transcript.Reset() + hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) + hs.transcript.Write(chHash) + + helloRetryRequest := &serverHelloMsg{ + vers: hs.hello.vers, + random: helloRetryRequestRandom, + sessionId: hs.hello.sessionId, + cipherSuite: hs.hello.cipherSuite, + compressionMethod: hs.hello.compressionMethod, + supportedVersion: hs.hello.supportedVersion, + selectedGroup: selectedGroup, + } + + hs.transcript.Write(helloRetryRequest.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, helloRetryRequest.marshal()); err != nil { + return err + } + + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + clientHello, ok := msg.(*clientHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(clientHello, msg) + } + + if len(clientHello.keyShares) != 1 || clientHello.keyShares[0].group != selectedGroup { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client sent invalid key share in second ClientHello") + } + + if clientHello.earlyData { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client indicated early data in second ClientHello") + } + + if illegalClientHelloChange(clientHello, hs.clientHello) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client illegally modified second ClientHello") + } + + hs.clientHello = clientHello + return nil +} + +// illegalClientHelloChange reports whether the two ClientHello messages are +// different, with the exception of the changes allowed before and after a +// HelloRetryRequest. See RFC 8446, Section 4.1.2. +func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool { + if len(ch.supportedVersions) != len(ch1.supportedVersions) || + len(ch.cipherSuites) != len(ch1.cipherSuites) || + len(ch.supportedCurves) != len(ch1.supportedCurves) || + len(ch.supportedSignatureAlgorithms) != len(ch1.supportedSignatureAlgorithms) || + len(ch.supportedSignatureAlgorithmsCert) != len(ch1.supportedSignatureAlgorithmsCert) || + len(ch.alpnProtocols) != len(ch1.alpnProtocols) { + return true + } + for i := range ch.supportedVersions { + if ch.supportedVersions[i] != ch1.supportedVersions[i] { + return true + } + } + for i := range ch.cipherSuites { + if ch.cipherSuites[i] != ch1.cipherSuites[i] { + return true + } + } + for i := range ch.supportedCurves { + if ch.supportedCurves[i] != ch1.supportedCurves[i] { + return true + } + } + for i := range ch.supportedSignatureAlgorithms { + if ch.supportedSignatureAlgorithms[i] != ch1.supportedSignatureAlgorithms[i] { + return true + } + } + for i := range ch.supportedSignatureAlgorithmsCert { + if ch.supportedSignatureAlgorithmsCert[i] != ch1.supportedSignatureAlgorithmsCert[i] { + return true + } + } + for i := range ch.alpnProtocols { + if ch.alpnProtocols[i] != ch1.alpnProtocols[i] { + return true + } + } + return ch.vers != ch1.vers || + !bytes.Equal(ch.random, ch1.random) || + !bytes.Equal(ch.sessionId, ch1.sessionId) || + !bytes.Equal(ch.compressionMethods, ch1.compressionMethods) || + ch.serverName != ch1.serverName || + ch.ocspStapling != ch1.ocspStapling || + !bytes.Equal(ch.supportedPoints, ch1.supportedPoints) || + ch.ticketSupported != ch1.ticketSupported || + !bytes.Equal(ch.sessionTicket, ch1.sessionTicket) || + ch.secureRenegotiationSupported != ch1.secureRenegotiationSupported || + !bytes.Equal(ch.secureRenegotiation, ch1.secureRenegotiation) || + ch.scts != ch1.scts || + !bytes.Equal(ch.cookie, ch1.cookie) || + !bytes.Equal(ch.pskModes, ch1.pskModes) +} + +func (hs *serverHandshakeStateTLS13) sendServerParameters() error { + c := hs.c + + if c.handshakeStatusAsync >= stateServerHandshake13SendServerParameters { + return nil + } + c.handshakeStatusAsync = stateServerHandshake13SendServerParameters + + hs.transcript.Write(hs.clientHello.marshal()) + hs.transcript.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + + earlySecret := hs.earlySecret + if earlySecret == nil { + earlySecret = hs.suite.extract(nil, nil) + } + hs.handshakeSecret = hs.suite.extract(hs.sharedKey, + hs.suite.deriveSecret(earlySecret, "derived", nil)) + + clientSecret := hs.suite.deriveSecret(hs.handshakeSecret, + clientHandshakeTrafficLabel, hs.transcript) + c.in.setTrafficSecret(hs.suite, clientSecret) + serverSecret := hs.suite.deriveSecret(hs.handshakeSecret, + serverHandshakeTrafficLabel, hs.transcript) + c.out.setTrafficSecret(hs.suite, serverSecret) + + err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.clientHello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + encryptedExtensions := new(encryptedExtensionsMsg) + + if len(hs.clientHello.alpnProtocols) > 0 { + if selectedProto := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); selectedProto != "" { + encryptedExtensions.alpnProtocol = selectedProto + c.clientProtocol = selectedProto + } + } + + hs.transcript.Write(encryptedExtensions.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, encryptedExtensions.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) requestClientCert() bool { + return hs.c.config.ClientAuth >= RequestClientCert && !hs.usingPSK +} + +func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { + c := hs.c + + if c.handshakeStatusAsync >= stateServerHandshake13SendServerCertificate { + return nil + } + c.handshakeStatusAsync = stateServerHandshake13SendServerCertificate + + // Only one of PSK and certificates are used at a time. + if hs.usingPSK { + return nil + } + + if hs.requestClientCert() { + // Request a client certificate + certReq := new(certificateRequestMsgTLS13) + certReq.ocspStapling = true + certReq.scts = true + certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms + if c.config.ClientCAs != nil { + certReq.certificateAuthorities = c.config.ClientCAs.Subjects() + } + + hs.transcript.Write(certReq.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + return err + } + } + + certMsg := new(certificateMsgTLS13) + + certMsg.certificate = *hs.cert + certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0 + certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 + + hs.transcript.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + certVerifyMsg := new(certificateVerifyMsg) + certVerifyMsg.hasSignatureAlgorithm = true + certVerifyMsg.signatureAlgorithm = hs.sigAlg + + sigType, sigHash, err := typeAndHashFromSignatureScheme(hs.sigAlg) + if err != nil { + return c.sendAlert(alertInternalError) + } + + signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + sig, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) + if err != nil { + public := hs.cert.PrivateKey.(crypto.Signer).Public() + if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS && + rsaKey.N.BitLen()/8 < sigHash.Size()*2+2 { // key too small for RSA-PSS + c.sendAlert(alertHandshakeFailure) + } else { + c.sendAlert(alertInternalError) + } + return errors.New("tls: failed to sign handshake: " + err.Error()) + } + certVerifyMsg.signature = sig + + hs.transcript.Write(certVerifyMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) sendServerFinished() error { + c := hs.c + if c.handshakeStatusAsync >= stateServerHandshake13SendServerFinished { + return nil + } + c.handshakeStatusAsync = stateServerHandshake13SendServerFinished + + finished := &finishedMsg{ + verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), + } + + hs.transcript.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + // Derive secrets that take context through the server Finished. + + hs.masterSecret = hs.suite.extract(nil, + hs.suite.deriveSecret(hs.handshakeSecret, "derived", nil)) + + hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, + clientApplicationTrafficLabel, hs.transcript) + serverSecret := hs.suite.deriveSecret(hs.masterSecret, + serverApplicationTrafficLabel, hs.transcript) + c.out.setTrafficSecret(hs.suite, serverSecret) + + err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.clientHello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) + + // If we did not request client certificates, at this point we can + // precompute the client finished and roll the transcript forward to send + // session tickets in our first flight. + if !hs.requestClientCert() { + if err := hs.sendSessionTickets(); err != nil { + return err + } + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) shouldSendSessionTickets() bool { + if hs.c.config.SessionTicketsDisabled { + return false + } + + // Don't send tickets the client wouldn't use. See RFC 8446, Section 4.2.9. + for _, pskMode := range hs.clientHello.pskModes { + if pskMode == pskModeDHE { + return true + } + } + return false +} + +func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { + c := hs.c + + hs.clientFinished = hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) + finishedMsg := &finishedMsg{ + verifyData: hs.clientFinished, + } + hs.transcript.Write(finishedMsg.marshal()) + + if !hs.shouldSendSessionTickets() { + return nil + } + + resumptionSecret := hs.suite.deriveSecret(hs.masterSecret, + resumptionLabel, hs.transcript) + + m := new(newSessionTicketMsgTLS13) + + var certsFromClient [][]byte + for _, cert := range c.peerCertificates { + certsFromClient = append(certsFromClient, cert.Raw) + } + state := sessionStateTLS13{ + cipherSuite: hs.suite.id, + createdAt: uint64(c.config.time().Unix()), + resumptionSecret: resumptionSecret, + certificate: Certificate{ + Certificate: certsFromClient, + OCSPStaple: c.ocspResponse, + SignedCertificateTimestamps: c.scts, + }, + } + var err error + m.label, err = c.encryptTicket(state.marshal()) + if err != nil { + return err + } + m.lifetime = uint32(maxSessionTicketLifetime / time.Second) + + if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) readClientCertificate() error { + c := hs.c + if c.handshakeStatusAsync >= stateServerHandshake13ReadClientCertificate { + return nil + } + + if !hs.requestClientCert() { + // Make sure the connection is still being verified whether or not + // the server requested a client certificate. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate + return nil + } + + // If we requested a client certificate, then the client must send a + // certificate message. If it's empty, no CertificateVerify is sent. + + if c.certMsg == nil { + msg, err := c.readHandshake() + if err != nil { + if err != errDataNotEnough { + c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate + } + return err + } + + certMsg, ok := msg.(*certificateMsgTLS13) + if !ok { + c.sendAlert(alertUnexpectedMessage) + c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate + return unexpectedMessageError(certMsg, msg) + } + c.certMsg = certMsg + + hs.transcript.Write(certMsg.marshal()) + + if err := c.processCertsFromClient(certMsg.certificate); err != nil { + c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate + return err + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate + return err + } + } + } + + i := 0 + if len(c.certMsg.certificate.Certificate) != 0 { + if len(c.certMsgVerified) == 0 { + c.certMsgVerified = make([]bool, len(c.certMsg.certificate.Certificate)) + } + for ; i < len(c.certMsg.certificate.Certificate); i++ { + if c.certMsgVerified[i] { + } else { + msg, err := c.readHandshake() + if err != nil { + if err != errDataNotEnough { + c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate + } + return err + } + + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate + return unexpectedMessageError(certVerify, msg) + } + + // See RFC 8446, Section 4.4.3. + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) { + c.sendAlert(alertIllegalParameter) + c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate + return errors.New("tls: client certificate used with invalid signature algorithm") + } + sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) + if err != nil { + c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate + return c.sendAlert(alertInternalError) + } + if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { + c.sendAlert(alertIllegalParameter) + c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate + return errors.New("tls: client certificate used with invalid signature algorithm") + } + signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) + if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, + sigHash, signed, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate + return errors.New("tls: invalid signature by the client certificate: " + err.Error()) + } + + hs.transcript.Write(certVerify.marshal()) + + c.certMsgVerified[i] = true + } + } + } + + // If we waited until the client certificates to send session tickets, we + // are ready to do it now. + + if i == len(c.certMsg.certificate.Certificate) { + if err := hs.sendSessionTickets(); err != nil { + c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate + return err + } + c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate + return nil + } + + c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate + + return nil +} + +func (hs *serverHandshakeStateTLS13) readClientFinished() error { + c := hs.c + if c.handshakeStatusAsync >= stateServerHandshake13ReadClientFinished { + return nil + } + + msg, err := c.readHandshake() + if err != nil { + if err != errDataNotEnough { + c.handshakeStatusAsync = stateServerHandshake13ReadClientFinished + } + return err + } + + finished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + c.handshakeStatusAsync = stateServerHandshake13ReadClientFinished + return unexpectedMessageError(finished, msg) + } + + if !hmac.Equal(hs.clientFinished, finished.verifyData) { + c.sendAlert(alertDecryptError) + c.handshakeStatusAsync = stateServerHandshake13ReadClientFinished + return errors.New("tls: invalid client finished hash") + } + + c.in.setTrafficSecret(hs.suite, hs.trafficSecret) + c.handshakeStatusAsync = stateServerHandshake13ReadClientFinished + return nil +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_test.go b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_test.go new file mode 100644 index 0000000..d9ff9fe --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_test.go @@ -0,0 +1,535 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bufio" + "crypto/ed25519" + "crypto/x509" + "encoding/hex" + "errors" + "flag" + "fmt" + "io" + "net" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +// TLS reference tests run a connection against a reference implementation +// (OpenSSL) of TLS and record the bytes of the resulting connection. The Go +// code, during a test, is configured with deterministic randomness and so the +// reference test can be reproduced exactly in the future. +// +// In order to save everyone who wishes to run the tests from needing the +// reference implementation installed, the reference connections are saved in +// files in the testdata directory. Thus running the tests involves nothing +// external, but creating and updating them requires the reference +// implementation. +// +// Tests can be updated by running them with the -update flag. This will cause +// the test files for failing tests to be regenerated. Since the reference +// implementation will always generate fresh random numbers, large parts of the +// reference connection will always change. + +var ( + update = flag.Bool("update", false, "update golden files on failure") + fast = flag.Bool("fast", false, "impose a quick, possibly flaky timeout on recorded tests") + keyFile = flag.String("keylog", "", "destination file for KeyLogWriter") +) + +func runTestAndUpdateIfNeeded(t *testing.T, name string, run func(t *testing.T, update bool), wait bool) { + success := t.Run(name, func(t *testing.T) { + if !*update && !wait { + t.Parallel() + } + run(t, false) + }) + + if !success && *update { + t.Run(name+"#update", func(t *testing.T) { + run(t, true) + }) + } +} + +// checkOpenSSLVersion ensures that the version of OpenSSL looks reasonable +// before updating the test data. +func checkOpenSSLVersion() error { + if !*update { + return nil + } + + openssl := exec.Command("openssl", "version") + output, err := openssl.CombinedOutput() + if err != nil { + return err + } + + version := string(output) + if strings.HasPrefix(version, "OpenSSL 1.1.1") { + return nil + } + + println("***********************************************") + println("") + println("You need to build OpenSSL 1.1.1 from source in order") + println("to update the test data.") + println("") + println("Configure it with:") + println("./Configure enable-weak-ssl-ciphers no-shared") + println("and then add the apps/ directory at the front of your PATH.") + println("***********************************************") + + return errors.New("version of OpenSSL does not appear to be suitable for updating test data") +} + +// recordingConn is a net.Conn that records the traffic that passes through it. +// WriteTo can be used to produce output that can be later be loaded with +// ParseTestData. +type recordingConn struct { + net.Conn + sync.Mutex + flows [][]byte + reading bool +} + +func (r *recordingConn) Read(b []byte) (n int, err error) { + if n, err = r.Conn.Read(b); n == 0 { + return + } + b = b[:n] + + r.Lock() + defer r.Unlock() + + if l := len(r.flows); l == 0 || !r.reading { + buf := make([]byte, len(b)) + copy(buf, b) + r.flows = append(r.flows, buf) + } else { + r.flows[l-1] = append(r.flows[l-1], b[:n]...) + } + r.reading = true + return +} + +func (r *recordingConn) Write(b []byte) (n int, err error) { + if n, err = r.Conn.Write(b); n == 0 { + return + } + b = b[:n] + + r.Lock() + defer r.Unlock() + + if l := len(r.flows); l == 0 || r.reading { + buf := make([]byte, len(b)) + copy(buf, b) + r.flows = append(r.flows, buf) + } else { + r.flows[l-1] = append(r.flows[l-1], b[:n]...) + } + r.reading = false + return +} + +// WriteTo writes Go source code to w that contains the recorded traffic. +func (r *recordingConn) WriteTo(w io.Writer) (int64, error) { + // TLS always starts with a client to server flow. + clientToServer := true + var written int64 + for i, flow := range r.flows { + source, dest := "client", "server" + if !clientToServer { + source, dest = dest, source + } + n, err := fmt.Fprintf(w, ">>> Flow %d (%s to %s)\n", i+1, source, dest) + written += int64(n) + if err != nil { + return written, err + } + dumper := hex.Dumper(w) + n, err = dumper.Write(flow) + written += int64(n) + if err != nil { + return written, err + } + err = dumper.Close() + if err != nil { + return written, err + } + clientToServer = !clientToServer + } + return written, nil +} + +func parseTestData(r io.Reader) (flows [][]byte, err error) { + var currentFlow []byte + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + // If the line starts with ">>> " then it marks the beginning + // of a new flow. + if strings.HasPrefix(line, ">>> ") { + if len(currentFlow) > 0 || len(flows) > 0 { + flows = append(flows, currentFlow) + currentFlow = nil + } + continue + } + + // Otherwise the line is a line of hex dump that looks like: + // 00000170 fc f5 06 bf (...) |.....X{&?......!| + // (Some bytes have been omitted from the middle section.) + + if i := strings.IndexByte(line, ' '); i >= 0 { + line = line[i:] + } else { + return nil, errors.New("invalid test data") + } + + if i := strings.IndexByte(line, '|'); i >= 0 { + line = line[:i] + } else { + return nil, errors.New("invalid test data") + } + + hexBytes := strings.Fields(line) + for _, hexByte := range hexBytes { + val, err := strconv.ParseUint(hexByte, 16, 8) + if err != nil { + return nil, errors.New("invalid hex byte in test data: " + err.Error()) + } + currentFlow = append(currentFlow, byte(val)) + } + } + + if len(currentFlow) > 0 { + flows = append(flows, currentFlow) + } + + return flows, nil +} + +// tempFile creates a temp file containing contents and returns its path. +func tempFile(contents string) string { + file, err := os.CreateTemp("", "go-tls-test") + if err != nil { + panic("failed to create temp file: " + err.Error()) + } + path := file.Name() + file.WriteString(contents) + file.Close() + return path +} + +// localListener is set up by TestMain and used by localPipe to create Conn +// pairs like net.Pipe, but connected by an actual buffered TCP connection. +var localListener struct { + mu sync.Mutex + addr net.Addr + ch chan net.Conn +} + +const localFlakes = 0 // change to 1 or 2 to exercise localServer/localPipe handling of mismatches + +func localServer(l net.Listener) { + for n := 0; ; n++ { + c, err := l.Accept() + if err != nil { + return + } + if localFlakes == 1 && n%2 == 0 { + c.Close() + continue + } + localListener.ch <- c + } +} + +var isConnRefused = func(err error) bool { return false } + +func localPipe(t testing.TB) (net.Conn, net.Conn) { + localListener.mu.Lock() + defer localListener.mu.Unlock() + + addr := localListener.addr + + var err error +Dialing: + // We expect a rare mismatch, but probably not 5 in a row. + for i := 0; i < 5; i++ { + tooSlow := time.NewTimer(1 * time.Second) + defer tooSlow.Stop() + var c1 net.Conn + c1, err = net.Dial(addr.Network(), addr.String()) + if err != nil { + if runtime.GOOS == "dragonfly" && (isConnRefused(err) || os.IsTimeout(err)) { + // golang.org/issue/29583: Dragonfly sometimes returns a spurious + // ECONNREFUSED or ETIMEDOUT. + <-tooSlow.C + continue + } + t.Fatalf("localPipe: %v", err) + } + if localFlakes == 2 && i == 0 { + c1.Close() + continue + } + for { + select { + case <-tooSlow.C: + t.Logf("localPipe: timeout waiting for %v", c1.LocalAddr()) + c1.Close() + continue Dialing + + case c2 := <-localListener.ch: + if c2.RemoteAddr().String() == c1.LocalAddr().String() { + return c1, c2 + } + t.Logf("localPipe: unexpected connection: %v != %v", c2.RemoteAddr(), c1.LocalAddr()) + c2.Close() + } + } + } + + t.Fatalf("localPipe: failed to connect: %v", err) + panic("unreachable") +} + +// zeroSource is an io.Reader that returns an unlimited number of zero bytes. +type zeroSource struct{} + +func (zeroSource) Read(b []byte) (n int, err error) { + for i := range b { + b[i] = 0 + } + + return len(b), nil +} + +func allCipherSuites() []uint16 { + ids := make([]uint16, len(cipherSuites)) + for i, suite := range cipherSuites { + ids[i] = suite.id + } + + return ids +} + +var testConfig *Config + +func TestMain(m *testing.M) { + flag.Parse() + os.Exit(runMain(m)) +} + +func runMain(m *testing.M) int { + // TLS 1.3 cipher suites preferences are not configurable and change based + // on the architecture. Force them to the version with AES acceleration for + // test consistency. + once.Do(initDefaultCipherSuites) + varDefaultCipherSuitesTLS13 = []uint16{ + TLS_AES_128_GCM_SHA256, + TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_256_GCM_SHA384, + } + + // Set up localPipe. + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + l, err = net.Listen("tcp6", "[::1]:0") + } + if err != nil { + fmt.Fprintf(os.Stderr, "Failed to open local listener: %v", err) + os.Exit(1) + } + localListener.ch = make(chan net.Conn) + localListener.addr = l.Addr() + defer l.Close() + go localServer(l) + + if err := checkOpenSSLVersion(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v", err) + os.Exit(1) + } + + testConfig = &Config{ + Time: func() time.Time { return time.Unix(0, 0) }, + Rand: zeroSource{}, + Certificates: make([]Certificate, 2), + InsecureSkipVerify: true, + CipherSuites: allCipherSuites(), + } + testConfig.Certificates[0].Certificate = [][]byte{testRSACertificate} + testConfig.Certificates[0].PrivateKey = testRSAPrivateKey + testConfig.Certificates[1].Certificate = [][]byte{testSNICertificate} + testConfig.Certificates[1].PrivateKey = testRSAPrivateKey + testConfig.BuildNameToCertificate() + if *keyFile != "" { + f, err := os.OpenFile(*keyFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + panic("failed to open -keylog file: " + err.Error()) + } + testConfig.KeyLogWriter = f + defer f.Close() + } + + return m.Run() +} + +func testHandshake(t *testing.T, clientConfig, serverConfig *Config) (serverState, clientState ConnectionState, err error) { + const sentinel = "SENTINEL\n" + c, s := localPipe(t) + errChan := make(chan error) + go func() { + cli := Client(c, clientConfig) + err := cli.Handshake() + if err != nil { + errChan <- fmt.Errorf("client: %v", err) + c.Close() + return + } + defer cli.Close() + clientState = cli.ConnectionState() + buf, err := io.ReadAll(cli) + if err != nil { + t.Errorf("failed to call cli.Read: %v", err) + } + if got := string(buf); got != sentinel { + t.Errorf("read %q from TLS connection, but expected %q", got, sentinel) + } + errChan <- nil + }() + server := Server(s, serverConfig) + err = server.Handshake() + if err == nil { + serverState = server.ConnectionState() + if _, err := io.WriteString(server, sentinel); err != nil { + t.Errorf("failed to call server.Write: %v", err) + } + if err := server.Close(); err != nil { + t.Errorf("failed to call server.Close: %v", err) + } + err = <-errChan + } else { + s.Close() + <-errChan + } + return +} + +func fromHex(s string) []byte { + b, _ := hex.DecodeString(s) + return b +} + +var testRSACertificate = fromHex("3082024b308201b4a003020102020900e8f09d3fe25beaa6300d06092a864886f70d01010b0500301f310b3009060355040a1302476f3110300e06035504031307476f20526f6f74301e170d3136303130313030303030305a170d3235303130313030303030305a301a310b3009060355040a1302476f310b300906035504031302476f30819f300d06092a864886f70d010101050003818d0030818902818100db467d932e12270648bc062821ab7ec4b6a25dfe1e5245887a3647a5080d92425bc281c0be97799840fb4f6d14fd2b138bc2a52e67d8d4099ed62238b74a0b74732bc234f1d193e596d9747bf3589f6c613cc0b041d4d92b2b2423775b1c3bbd755dce2054cfa163871d1e24c4f31d1a508baab61443ed97a77562f414c852d70203010001a38193308190300e0603551d0f0101ff0404030205a0301d0603551d250416301406082b0601050507030106082b06010505070302300c0603551d130101ff0402300030190603551d0e041204109f91161f43433e49a6de6db680d79f60301b0603551d230414301280104813494d137e1631bba301d5acab6e7b30190603551d1104123010820e6578616d706c652e676f6c616e67300d06092a864886f70d01010b0500038181009d30cc402b5b50a061cbbae55358e1ed8328a9581aa938a495a1ac315a1a84663d43d32dd90bf297dfd320643892243a00bccf9c7db74020015faad3166109a276fd13c3cce10c5ceeb18782f16c04ed73bbb343778d0c1cf10fa1d8408361c94c722b9daedb4606064df4c1b33ec0d1bd42d4dbfe3d1360845c21d33be9fae7") + +var testRSACertificateIssuer = fromHex("3082021930820182a003020102020900ca5e4e811a965964300d06092a864886f70d01010b0500301f310b3009060355040a1302476f3110300e06035504031307476f20526f6f74301e170d3136303130313030303030305a170d3235303130313030303030305a301f310b3009060355040a1302476f3110300e06035504031307476f20526f6f7430819f300d06092a864886f70d010101050003818d0030818902818100d667b378bb22f34143b6cd2008236abefaf2852adf3ab05e01329e2c14834f5105df3f3073f99dab5442d45ee5f8f57b0111c8cb682fbb719a86944eebfffef3406206d898b8c1b1887797c9c5006547bb8f00e694b7a063f10839f269f2c34fff7a1f4b21fbcd6bfdfb13ac792d1d11f277b5c5b48600992203059f2a8f8cc50203010001a35d305b300e0603551d0f0101ff040403020204301d0603551d250416301406082b0601050507030106082b06010505070302300f0603551d130101ff040530030101ff30190603551d0e041204104813494d137e1631bba301d5acab6e7b300d06092a864886f70d01010b050003818100c1154b4bab5266221f293766ae4138899bd4c5e36b13cee670ceeaa4cbdf4f6679017e2fe649765af545749fe4249418a56bd38a04b81e261f5ce86b8d5c65413156a50d12449554748c59a30c515bc36a59d38bddf51173e899820b282e40aa78c806526fd184fb6b4cf186ec728edffa585440d2b3225325f7ab580e87dd76") + +// testRSAPSSCertificate has signatureAlgorithm rsassaPss, but subjectPublicKeyInfo +// algorithm rsaEncryption, for use with the rsa_pss_rsae_* SignatureSchemes. +// See also TestRSAPSSKeyError. testRSAPSSCertificate is self-signed. +var testRSAPSSCertificate = fromHex("308202583082018da003020102021100f29926eb87ea8a0db9fcc247347c11b0304106092a864886f70d01010a3034a00f300d06096086480165030402010500a11c301a06092a864886f70d010108300d06096086480165030402010500a20302012030123110300e060355040a130741636d6520436f301e170d3137313132333136313631305a170d3138313132333136313631305a30123110300e060355040a130741636d6520436f30819f300d06092a864886f70d010101050003818d0030818902818100db467d932e12270648bc062821ab7ec4b6a25dfe1e5245887a3647a5080d92425bc281c0be97799840fb4f6d14fd2b138bc2a52e67d8d4099ed62238b74a0b74732bc234f1d193e596d9747bf3589f6c613cc0b041d4d92b2b2423775b1c3bbd755dce2054cfa163871d1e24c4f31d1a508baab61443ed97a77562f414c852d70203010001a3463044300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff04023000300f0603551d110408300687047f000001304106092a864886f70d01010a3034a00f300d06096086480165030402010500a11c301a06092a864886f70d010108300d06096086480165030402010500a20302012003818100cdac4ef2ce5f8d79881042707f7cbf1b5a8a00ef19154b40151771006cd41626e5496d56da0c1a139fd84695593cb67f87765e18aa03ea067522dd78d2a589b8c92364e12838ce346c6e067b51f1a7e6f4b37ffab13f1411896679d18e880e0ba09e302ac067efca460288e9538122692297ad8093d4f7dd701424d7700a46a1") + +var testECDSACertificate = fromHex("3082020030820162020900b8bf2d47a0d2ebf4300906072a8648ce3d04013045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464301e170d3132313132323135303633325a170d3232313132303135303633325a3045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c746430819b301006072a8648ce3d020106052b81040023038186000400c4a1edbe98f90b4873367ec316561122f23d53c33b4d213dcd6b75e6f6b0dc9adf26c1bcb287f072327cb3642f1c90bcea6823107efee325c0483a69e0286dd33700ef0462dd0da09c706283d881d36431aa9e9731bd96b068c09b23de76643f1a5c7fe9120e5858b65f70dd9bd8ead5d7f5d5ccb9b69f30665b669a20e227e5bffe3b300906072a8648ce3d040103818c0030818802420188a24febe245c5487d1bacf5ed989dae4770c05e1bb62fbdf1b64db76140d311a2ceee0b7e927eff769dc33b7ea53fcefa10e259ec472d7cacda4e970e15a06fd00242014dfcbe67139c2d050ebd3fa38c25c13313830d9406bbd4377af6ec7ac9862eddd711697f857c56defb31782be4c7780daecbbe9e4e3624317b6a0f399512078f2a") + +var testEd25519Certificate = fromHex("3082012e3081e1a00302010202100f431c425793941de987e4f1ad15005d300506032b657030123110300e060355040a130741636d6520436f301e170d3139303531363231333830315a170d3230303531353231333830315a30123110300e060355040a130741636d6520436f302a300506032b65700321003fe2152ee6e3ef3f4e854a7577a3649eede0bf842ccc92268ffa6f3483aaec8fa34d304b300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff0402300030160603551d11040f300d820b6578616d706c652e636f6d300506032b65700341006344ed9cc4be5324539fd2108d9fe82108909539e50dc155ff2c16b71dfcab7d4dd4e09313d0a942e0b66bfe5d6748d79f50bc6ccd4b03837cf20858cdaccf0c") + +var testSNICertificate = fromHex("0441883421114c81480804c430820237308201a0a003020102020900e8f09d3fe25beaa6300d06092a864886f70d01010b0500301f310b3009060355040a1302476f3110300e06035504031307476f20526f6f74301e170d3136303130313030303030305a170d3235303130313030303030305a3023310b3009060355040a1302476f311430120603550403130b736e69746573742e636f6d30819f300d06092a864886f70d010101050003818d0030818902818100db467d932e12270648bc062821ab7ec4b6a25dfe1e5245887a3647a5080d92425bc281c0be97799840fb4f6d14fd2b138bc2a52e67d8d4099ed62238b74a0b74732bc234f1d193e596d9747bf3589f6c613cc0b041d4d92b2b2423775b1c3bbd755dce2054cfa163871d1e24c4f31d1a508baab61443ed97a77562f414c852d70203010001a3773075300e0603551d0f0101ff0404030205a0301d0603551d250416301406082b0601050507030106082b06010505070302300c0603551d130101ff0402300030190603551d0e041204109f91161f43433e49a6de6db680d79f60301b0603551d230414301280104813494d137e1631bba301d5acab6e7b300d06092a864886f70d01010b0500038181007beeecff0230dbb2e7a334af65430b7116e09f327c3bbf918107fc9c66cb497493207ae9b4dbb045cb63d605ec1b5dd485bb69124d68fa298dc776699b47632fd6d73cab57042acb26f083c4087459bc5a3bb3ca4d878d7fe31016b7bc9a627438666566e3389bfaeebe6becc9a0093ceed18d0f9ac79d56f3a73f18188988ed") + +var testP256Certificate = fromHex("308201693082010ea00302010202105012dc24e1124ade4f3e153326ff27bf300a06082a8648ce3d04030230123110300e060355040a130741636d6520436f301e170d3137303533313232343934375a170d3138303533313232343934375a30123110300e060355040a130741636d6520436f3059301306072a8648ce3d020106082a8648ce3d03010703420004c02c61c9b16283bbcc14956d886d79b358aa614596975f78cece787146abf74c2d5dc578c0992b4f3c631373479ebf3892efe53d21c4f4f1cc9a11c3536b7f75a3463044300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff04023000300f0603551d1104083006820474657374300a06082a8648ce3d0403020349003046022100963712d6226c7b2bef41512d47e1434131aaca3ba585d666c924df71ac0448b3022100f4d05c725064741aef125f243cdbccaa2a5d485927831f221c43023bd5ae471a") + +var testRSAPrivateKey, _ = x509.ParsePKCS1PrivateKey(fromHex("3082025b02010002818100db467d932e12270648bc062821ab7ec4b6a25dfe1e5245887a3647a5080d92425bc281c0be97799840fb4f6d14fd2b138bc2a52e67d8d4099ed62238b74a0b74732bc234f1d193e596d9747bf3589f6c613cc0b041d4d92b2b2423775b1c3bbd755dce2054cfa163871d1e24c4f31d1a508baab61443ed97a77562f414c852d702030100010281800b07fbcf48b50f1388db34b016298b8217f2092a7c9a04f77db6775a3d1279b62ee9951f7e371e9de33f015aea80660760b3951dc589a9f925ed7de13e8f520e1ccbc7498ce78e7fab6d59582c2386cc07ed688212a576ff37833bd5943483b5554d15a0b9b4010ed9bf09f207e7e9805f649240ed6c1256ed75ab7cd56d9671024100fded810da442775f5923debae4ac758390a032a16598d62f059bb2e781a9c2f41bfa015c209f966513fe3bf5a58717cbdb385100de914f88d649b7d15309fa49024100dd10978c623463a1802c52f012cfa72ff5d901f25a2292446552c2568b1840e49a312e127217c2186615aae4fb6602a4f6ebf3f3d160f3b3ad04c592f65ae41f02400c69062ca781841a09de41ed7a6d9f54adc5d693a2c6847949d9e1358555c9ac6a8d9e71653ac77beb2d3abaf7bb1183aa14278956575dbebf525d0482fd72d90240560fe1900ba36dae3022115fd952f2399fb28e2975a1c3e3d0b679660bdcb356cc189d611cfdd6d87cd5aea45aa30a2082e8b51e94c2f3dd5d5c6036a8a615ed0240143993d80ece56f877cb80048335701eb0e608cc0c1ca8c2227b52edf8f1ac99c562f2541b5ce81f0515af1c5b4770dba53383964b4b725ff46fdec3d08907df")) + +var testECDSAPrivateKey, _ = x509.ParseECPrivateKey(fromHex("3081dc0201010442019883e909ad0ac9ea3d33f9eae661f1785206970f8ca9a91672f1eedca7a8ef12bd6561bb246dda5df4b4d5e7e3a92649bc5d83a0bf92972e00e62067d0c7bd99d7a00706052b81040023a18189038186000400c4a1edbe98f90b4873367ec316561122f23d53c33b4d213dcd6b75e6f6b0dc9adf26c1bcb287f072327cb3642f1c90bcea6823107efee325c0483a69e0286dd33700ef0462dd0da09c706283d881d36431aa9e9731bd96b068c09b23de76643f1a5c7fe9120e5858b65f70dd9bd8ead5d7f5d5ccb9b69f30665b669a20e227e5bffe3b")) + +var testP256PrivateKey, _ = x509.ParseECPrivateKey(fromHex("30770201010420012f3b52bc54c36ba3577ad45034e2e8efe1e6999851284cb848725cfe029991a00a06082a8648ce3d030107a14403420004c02c61c9b16283bbcc14956d886d79b358aa614596975f78cece787146abf74c2d5dc578c0992b4f3c631373479ebf3892efe53d21c4f4f1cc9a11c3536b7f75")) + +var testEd25519PrivateKey = ed25519.PrivateKey(fromHex("3a884965e76b3f55e5faf9615458a92354894234de3ec9f684d46d55cebf3dc63fe2152ee6e3ef3f4e854a7577a3649eede0bf842ccc92268ffa6f3483aaec8f")) + +const clientCertificatePEM = ` +-----BEGIN CERTIFICATE----- +MIIB7zCCAVigAwIBAgIQXBnBiWWDVW/cC8m5k5/pvDANBgkqhkiG9w0BAQsFADAS +MRAwDgYDVQQKEwdBY21lIENvMB4XDTE2MDgxNzIxNTIzMVoXDTE3MDgxNzIxNTIz +MVowEjEQMA4GA1UEChMHQWNtZSBDbzCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkC +gYEAum+qhr3Pv5/y71yUYHhv6BPy0ZZvzdkybiI3zkH5yl0prOEn2mGi7oHLEMff +NFiVhuk9GeZcJ3NgyI14AvQdpJgJoxlwaTwlYmYqqyIjxXuFOE8uCXMyp70+m63K +hAfmDzr/d8WdQYUAirab7rCkPy1MTOZCPrtRyN1IVPQMjkcCAwEAAaNGMEQwDgYD +VR0PAQH/BAQDAgWgMBMGA1UdJQQMMAoGCCsGAQUFBwMBMAwGA1UdEwEB/wQCMAAw +DwYDVR0RBAgwBocEfwAAATANBgkqhkiG9w0BAQsFAAOBgQBGq0Si+yhU+Fpn+GKU +8ZqyGJ7ysd4dfm92lam6512oFmyc9wnTN+RLKzZ8Aa1B0jLYw9KT+RBrjpW5LBeK +o0RIvFkTgxYEiKSBXCUNmAysEbEoVr4dzWFihAm/1oDGRY2CLLTYg5vbySK3KhIR +e/oCO8HJ/+rJnahJ05XX1Q7lNQ== +-----END CERTIFICATE-----` + +var clientKeyPEM = testingKey(` +-----BEGIN RSA TESTING KEY----- +MIICXQIBAAKBgQC6b6qGvc+/n/LvXJRgeG/oE/LRlm/N2TJuIjfOQfnKXSms4Sfa +YaLugcsQx980WJWG6T0Z5lwnc2DIjXgC9B2kmAmjGXBpPCViZiqrIiPFe4U4Ty4J +czKnvT6brcqEB+YPOv93xZ1BhQCKtpvusKQ/LUxM5kI+u1HI3UhU9AyORwIDAQAB +AoGAEJZ03q4uuMb7b26WSQsOMeDsftdatT747LGgs3pNRkMJvTb/O7/qJjxoG+Mc +qeSj0TAZXp+PXXc3ikCECAc+R8rVMfWdmp903XgO/qYtmZGCorxAHEmR80SrfMXv +PJnznLQWc8U9nphQErR+tTESg7xWEzmFcPKwnZd1xg8ERYkCQQDTGtrFczlB2b/Z +9TjNMqUlMnTLIk/a/rPE2fLLmAYhK5sHnJdvDURaH2mF4nso0EGtENnTsh6LATnY +dkrxXGm9AkEA4hXHG2q3MnhgK1Z5hjv+Fnqd+8bcbII9WW4flFs15EKoMgS1w/PJ +zbsySaSy5IVS8XeShmT9+3lrleed4sy+UwJBAJOOAbxhfXP5r4+5R6ql66jES75w +jUCVJzJA5ORJrn8g64u2eGK28z/LFQbv9wXgCwfc72R468BdawFSLa/m2EECQGbZ +rWiFla26IVXV0xcD98VWJsTBZMlgPnSOqoMdM1kSEd4fUmlAYI/dFzV1XYSkOmVr +FhdZnklmpVDeu27P4c0CQQCuCOup0FlJSBpWY1TTfun/KMBkBatMz0VMA3d7FKIU +csPezl677Yjo8u1r/KzeI6zLg87Z8E6r6ZWNc9wBSZK6 +-----END RSA TESTING KEY-----`) + +const clientECDSACertificatePEM = ` +-----BEGIN CERTIFICATE----- +MIIB/DCCAV4CCQCaMIRsJjXZFzAJBgcqhkjOPQQBMEUxCzAJBgNVBAYTAkFVMRMw +EQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0 +eSBMdGQwHhcNMTIxMTE0MTMyNTUzWhcNMjIxMTEyMTMyNTUzWjBBMQswCQYDVQQG +EwJBVTEMMAoGA1UECBMDTlNXMRAwDgYDVQQHEwdQeXJtb250MRIwEAYDVQQDEwlK +b2VsIFNpbmcwgZswEAYHKoZIzj0CAQYFK4EEACMDgYYABACVjJF1FMBexFe01MNv +ja5oHt1vzobhfm6ySD6B5U7ixohLZNz1MLvT/2XMW/TdtWo+PtAd3kfDdq0Z9kUs +jLzYHQFMH3CQRnZIi4+DzEpcj0B22uCJ7B0rxE4wdihBsmKo+1vx+U56jb0JuK7q +ixgnTy5w/hOWusPTQBbNZU6sER7m8TAJBgcqhkjOPQQBA4GMADCBiAJCAOAUxGBg +C3JosDJdYUoCdFzCgbkWqD8pyDbHgf9stlvZcPE4O1BIKJTLCRpS8V3ujfK58PDa +2RU6+b0DeoeiIzXsAkIBo9SKeDUcSpoj0gq+KxAxnZxfvuiRs9oa9V2jI/Umi0Vw +jWVim34BmT0Y9hCaOGGbLlfk+syxis7iI6CH8OFnUes= +-----END CERTIFICATE-----` + +var clientECDSAKeyPEM = testingKey(` +-----BEGIN EC PARAMETERS----- +BgUrgQQAIw== +-----END EC PARAMETERS----- +-----BEGIN EC TESTING KEY----- +MIHcAgEBBEIBkJN9X4IqZIguiEVKMqeBUP5xtRsEv4HJEtOpOGLELwO53SD78Ew8 +k+wLWoqizS3NpQyMtrU8JFdWfj+C57UNkOugBwYFK4EEACOhgYkDgYYABACVjJF1 +FMBexFe01MNvja5oHt1vzobhfm6ySD6B5U7ixohLZNz1MLvT/2XMW/TdtWo+PtAd +3kfDdq0Z9kUsjLzYHQFMH3CQRnZIi4+DzEpcj0B22uCJ7B0rxE4wdihBsmKo+1vx ++U56jb0JuK7qixgnTy5w/hOWusPTQBbNZU6sER7m8Q== +-----END EC TESTING KEY-----`) + +const clientEd25519CertificatePEM = ` +-----BEGIN CERTIFICATE----- +MIIBLjCB4aADAgECAhAX0YGTviqMISAQJRXoNCNPMAUGAytlcDASMRAwDgYDVQQK +EwdBY21lIENvMB4XDTE5MDUxNjIxNTQyNloXDTIwMDUxNTIxNTQyNlowEjEQMA4G +A1UEChMHQWNtZSBDbzAqMAUGAytlcAMhAAvgtWC14nkwPb7jHuBQsQTIbcd4bGkv +xRStmmNveRKRo00wSzAOBgNVHQ8BAf8EBAMCBaAwEwYDVR0lBAwwCgYIKwYBBQUH +AwIwDAYDVR0TAQH/BAIwADAWBgNVHREEDzANggtleGFtcGxlLmNvbTAFBgMrZXAD +QQD8GRcqlKUx+inILn9boF2KTjRAOdazENwZ/qAicbP1j6FYDc308YUkv+Y9FN/f +7Q7hF9gRomDQijcjKsJGqjoI +-----END CERTIFICATE-----` + +var clientEd25519KeyPEM = testingKey(` +-----BEGIN TESTING KEY----- +MC4CAQAwBQYDK2VwBCIEINifzf07d9qx3d44e0FSbV4mC/xQxT644RRbpgNpin7I +-----END TESTING KEY-----`) diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_unix_test.go b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_unix_test.go new file mode 100644 index 0000000..7271854 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/handshake_unix_test.go @@ -0,0 +1,18 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris + +package tls + +import ( + "errors" + "syscall" +) + +func init() { + isConnRefused = func(err error) bool { + return errors.Is(err, syscall.ECONNREFUSED) + } +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/key_agreement.go b/vendor/github.com/lesismal/llib/std/crypto/tls/key_agreement.go new file mode 100644 index 0000000..7e6534b --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/key_agreement.go @@ -0,0 +1,334 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "crypto" + "crypto/md5" + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "errors" + "fmt" + "io" +) + +var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message") +var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message") + +// rsaKeyAgreement implements the standard TLS key agreement where the client +// encrypts the pre-master secret to the server's public key. +type rsaKeyAgreement struct{} + +func (ka rsaKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { + return nil, nil +} + +func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { + if len(ckx.ciphertext) < 2 { + return nil, errClientKeyExchange + } + ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1]) + if ciphertextLen != len(ckx.ciphertext)-2 { + return nil, errClientKeyExchange + } + ciphertext := ckx.ciphertext[2:] + + priv, ok := cert.PrivateKey.(crypto.Decrypter) + if !ok { + return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter") + } + // Perform constant time RSA PKCS #1 v1.5 decryption + preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48}) + if err != nil { + return nil, err + } + // We don't check the version number in the premaster secret. For one, + // by checking it, we would leak information about the validity of the + // encrypted pre-master secret. Secondly, it provides only a small + // benefit against a downgrade attack and some implementations send the + // wrong version anyway. See the discussion at the end of section + // 7.4.7.1 of RFC 4346. + return preMasterSecret, nil +} + +func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { + return errors.New("tls: unexpected ServerKeyExchange") +} + +func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { + preMasterSecret := make([]byte, 48) + preMasterSecret[0] = byte(clientHello.vers >> 8) + preMasterSecret[1] = byte(clientHello.vers) + _, err := io.ReadFull(config.rand(), preMasterSecret[2:]) + if err != nil { + return nil, nil, err + } + + encrypted, err := rsa.EncryptPKCS1v15(config.rand(), cert.PublicKey.(*rsa.PublicKey), preMasterSecret) + if err != nil { + return nil, nil, err + } + ckx := new(clientKeyExchangeMsg) + ckx.ciphertext = make([]byte, len(encrypted)+2) + ckx.ciphertext[0] = byte(len(encrypted) >> 8) + ckx.ciphertext[1] = byte(len(encrypted)) + copy(ckx.ciphertext[2:], encrypted) + return preMasterSecret, ckx, nil +} + +// sha1Hash calculates a SHA1 hash over the given byte slices. +func sha1Hash(slices [][]byte) []byte { + hsha1 := sha1.New() + for _, slice := range slices { + hsha1.Write(slice) + } + return hsha1.Sum(nil) +} + +// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the +// concatenation of an MD5 and SHA1 hash. +func md5SHA1Hash(slices [][]byte) []byte { + md5sha1 := make([]byte, md5.Size+sha1.Size) + hmd5 := md5.New() + for _, slice := range slices { + hmd5.Write(slice) + } + copy(md5sha1, hmd5.Sum(nil)) + copy(md5sha1[md5.Size:], sha1Hash(slices)) + return md5sha1 +} + +// hashForServerKeyExchange hashes the given slices and returns their digest +// using the given hash function (for >= TLS 1.2) or using a default based on +// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't +// do pre-hashing, it returns the concatenation of the slices. +func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte { + if sigType == signatureEd25519 { + var signed []byte + for _, slice := range slices { + signed = append(signed, slice...) + } + return signed + } + if version >= VersionTLS12 { + h := hashFunc.New() + for _, slice := range slices { + h.Write(slice) + } + digest := h.Sum(nil) + return digest + } + if sigType == signatureECDSA { + return sha1Hash(slices) + } + return md5SHA1Hash(slices) +} + +// ecdheKeyAgreement implements a TLS key agreement where the server +// generates an ephemeral EC public/private key pair and signs it. The +// pre-master secret is then calculated using ECDH. The signature may +// be ECDSA, Ed25519 or RSA. +type ecdheKeyAgreement struct { + version uint16 + isRSA bool + params ecdheParameters + + // ckx and preMasterSecret are generated in processServerKeyExchange + // and returned in generateClientKeyExchange. + ckx *clientKeyExchangeMsg + preMasterSecret []byte +} + +func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { + var curveID CurveID + for _, c := range clientHello.supportedCurves { + if config.supportsCurve(c) { + curveID = c + break + } + } + + if curveID == 0 { + return nil, errors.New("tls: no supported elliptic curves offered") + } + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + return nil, errors.New("tls: CurvePreferences includes unsupported curve") + } + + params, err := generateECDHEParameters(config.rand(), curveID) + if err != nil { + return nil, err + } + ka.params = params + + // See RFC 4492, Section 5.4. + ecdhePublic := params.PublicKey() + serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic)) + serverECDHEParams[0] = 3 // named curve + serverECDHEParams[1] = byte(curveID >> 8) + serverECDHEParams[2] = byte(curveID) + serverECDHEParams[3] = byte(len(ecdhePublic)) + copy(serverECDHEParams[4:], ecdhePublic) + + priv, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey) + } + + var signatureAlgorithm SignatureScheme + var sigType uint8 + var sigHash crypto.Hash + if ka.version >= VersionTLS12 { + signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms) + if err != nil { + return nil, err + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + if err != nil { + return nil, err + } + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(priv.Public()) + if err != nil { + return nil, err + } + } + if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { + return nil, errors.New("tls: certificate cannot be used with the selected cipher suite") + } + + signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, hello.random, serverECDHEParams) + + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + sig, err := priv.Sign(config.rand(), signed, signOpts) + if err != nil { + return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error()) + } + + skx := new(serverKeyExchangeMsg) + sigAndHashLen := 0 + if ka.version >= VersionTLS12 { + sigAndHashLen = 2 + } + skx.key = make([]byte, len(serverECDHEParams)+sigAndHashLen+2+len(sig)) + copy(skx.key, serverECDHEParams) + k := skx.key[len(serverECDHEParams):] + if ka.version >= VersionTLS12 { + k[0] = byte(signatureAlgorithm >> 8) + k[1] = byte(signatureAlgorithm) + k = k[2:] + } + k[0] = byte(len(sig) >> 8) + k[1] = byte(len(sig)) + copy(k[2:], sig) + + return skx, nil +} + +func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { + if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 { + return nil, errClientKeyExchange + } + + preMasterSecret := ka.params.SharedKey(ckx.ciphertext[1:]) + if preMasterSecret == nil { + return nil, errClientKeyExchange + } + + return preMasterSecret, nil +} + +func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { + if len(skx.key) < 4 { + return errServerKeyExchange + } + if skx.key[0] != 3 { // named curve + return errors.New("tls: server selected unsupported curve") + } + curveID := CurveID(skx.key[1])<<8 | CurveID(skx.key[2]) + + publicLen := int(skx.key[3]) + if publicLen+4 > len(skx.key) { + return errServerKeyExchange + } + serverECDHEParams := skx.key[:4+publicLen] + publicKey := serverECDHEParams[4:] + + sig := skx.key[4+publicLen:] + if len(sig) < 2 { + return errServerKeyExchange + } + + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + return errors.New("tls: server selected unsupported curve") + } + + params, err := generateECDHEParameters(config.rand(), curveID) + if err != nil { + return err + } + ka.params = params + + ka.preMasterSecret = params.SharedKey(publicKey) + if ka.preMasterSecret == nil { + return errServerKeyExchange + } + + ourPublicKey := params.PublicKey() + ka.ckx = new(clientKeyExchangeMsg) + ka.ckx.ciphertext = make([]byte, 1+len(ourPublicKey)) + ka.ckx.ciphertext[0] = byte(len(ourPublicKey)) + copy(ka.ckx.ciphertext[1:], ourPublicKey) + + var sigType uint8 + var sigHash crypto.Hash + if ka.version >= VersionTLS12 { + signatureAlgorithm := SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1]) + sig = sig[2:] + if len(sig) < 2 { + return errServerKeyExchange + } + + if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) { + return errors.New("tls: certificate used with invalid signature algorithm") + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + if err != nil { + return err + } + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey) + if err != nil { + return err + } + } + if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { + return errServerKeyExchange + } + + sigLen := int(sig[0])<<8 | int(sig[1]) + if sigLen+2 != len(sig) { + return errServerKeyExchange + } + sig = sig[2:] + + signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, serverHello.random, serverECDHEParams) + if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil { + return errors.New("tls: invalid signature by the server certificate: " + err.Error()) + } + return nil +} + +func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { + if ka.ckx == nil { + return nil, nil, errors.New("tls: missing ServerKeyExchange message") + } + + return ka.preMasterSecret, ka.ckx, nil +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/key_schedule.go b/vendor/github.com/lesismal/llib/std/crypto/tls/key_schedule.go new file mode 100644 index 0000000..3140169 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/key_schedule.go @@ -0,0 +1,199 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "crypto/elliptic" + "crypto/hmac" + "errors" + "hash" + "io" + "math/big" + + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/hkdf" +) + +// This file contains the functions necessary to compute the TLS 1.3 key +// schedule. See RFC 8446, Section 7. + +const ( + resumptionBinderLabel = "res binder" + clientHandshakeTrafficLabel = "c hs traffic" + serverHandshakeTrafficLabel = "s hs traffic" + clientApplicationTrafficLabel = "c ap traffic" + serverApplicationTrafficLabel = "s ap traffic" + exporterLabel = "exp master" + resumptionLabel = "res master" + trafficUpdateLabel = "traffic upd" +) + +// expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1. +func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte { + var hkdfLabel cryptobyte.Builder + hkdfLabel.AddUint16(uint16(length)) + hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte("tls13 ")) + b.AddBytes([]byte(label)) + }) + hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(context) + }) + out := make([]byte, length) + n, err := hkdf.Expand(c.hash.New, secret, hkdfLabel.BytesOrPanic()).Read(out) + if err != nil || n != length { + panic("tls: HKDF-Expand-Label invocation failed unexpectedly") + } + return out +} + +// deriveSecret implements Derive-Secret from RFC 8446, Section 7.1. +func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte { + if transcript == nil { + transcript = c.hash.New() + } + return c.expandLabel(secret, label, transcript.Sum(nil), c.hash.Size()) +} + +// extract implements HKDF-Extract with the cipher suite hash. +func (c *cipherSuiteTLS13) extract(newSecret, currentSecret []byte) []byte { + if newSecret == nil { + newSecret = make([]byte, c.hash.Size()) + } + return hkdf.Extract(c.hash.New, newSecret, currentSecret) +} + +// nextTrafficSecret generates the next traffic secret, given the current one, +// according to RFC 8446, Section 7.2. +func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte { + return c.expandLabel(trafficSecret, trafficUpdateLabel, nil, c.hash.Size()) +} + +// trafficKey generates traffic keys according to RFC 8446, Section 7.3. +func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) { + key = c.expandLabel(trafficSecret, "key", nil, c.keyLen) + iv = c.expandLabel(trafficSecret, "iv", nil, aeadNonceLength) + return +} + +// finishedHash generates the Finished verify_data or PskBinderEntry according +// to RFC 8446, Section 4.4.4. See sections 4.4 and 4.2.11.2 for the baseKey +// selection. +func (c *cipherSuiteTLS13) finishedHash(baseKey []byte, transcript hash.Hash) []byte { + finishedKey := c.expandLabel(baseKey, "finished", nil, c.hash.Size()) + verifyData := hmac.New(c.hash.New, finishedKey) + verifyData.Write(transcript.Sum(nil)) + return verifyData.Sum(nil) +} + +// exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to +// RFC 8446, Section 7.5. +func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript hash.Hash) func(string, []byte, int) ([]byte, error) { + expMasterSecret := c.deriveSecret(masterSecret, exporterLabel, transcript) + return func(label string, context []byte, length int) ([]byte, error) { + secret := c.deriveSecret(expMasterSecret, label, nil) + h := c.hash.New() + h.Write(context) + return c.expandLabel(secret, "exporter", h.Sum(nil), length), nil + } +} + +// ecdheParameters implements Diffie-Hellman with either NIST curves or X25519, +// according to RFC 8446, Section 4.2.8.2. +type ecdheParameters interface { + CurveID() CurveID + PublicKey() []byte + SharedKey(peerPublicKey []byte) []byte +} + +func generateECDHEParameters(rand io.Reader, curveID CurveID) (ecdheParameters, error) { + if curveID == X25519 { + privateKey := make([]byte, curve25519.ScalarSize) + if _, err := io.ReadFull(rand, privateKey); err != nil { + return nil, err + } + publicKey, err := curve25519.X25519(privateKey, curve25519.Basepoint) + if err != nil { + return nil, err + } + return &x25519Parameters{privateKey: privateKey, publicKey: publicKey}, nil + } + + curve, ok := curveForCurveID(curveID) + if !ok { + return nil, errors.New("tls: internal error: unsupported curve") + } + + p := &nistParameters{curveID: curveID} + var err error + p.privateKey, p.x, p.y, err = elliptic.GenerateKey(curve, rand) + if err != nil { + return nil, err + } + return p, nil +} + +func curveForCurveID(id CurveID) (elliptic.Curve, bool) { + switch id { + case CurveP256: + return elliptic.P256(), true + case CurveP384: + return elliptic.P384(), true + case CurveP521: + return elliptic.P521(), true + default: + return nil, false + } +} + +type nistParameters struct { + privateKey []byte + x, y *big.Int // public key + curveID CurveID +} + +func (p *nistParameters) CurveID() CurveID { + return p.curveID +} + +func (p *nistParameters) PublicKey() []byte { + curve, _ := curveForCurveID(p.curveID) + return elliptic.Marshal(curve, p.x, p.y) +} + +func (p *nistParameters) SharedKey(peerPublicKey []byte) []byte { + curve, _ := curveForCurveID(p.curveID) + // Unmarshal also checks whether the given point is on the curve. + x, y := elliptic.Unmarshal(curve, peerPublicKey) + if x == nil { + return nil + } + + xShared, _ := curve.ScalarMult(x, y, p.privateKey) + sharedKey := make([]byte, (curve.Params().BitSize+7)/8) + return xShared.FillBytes(sharedKey) +} + +type x25519Parameters struct { + privateKey []byte + publicKey []byte +} + +func (p *x25519Parameters) CurveID() CurveID { + return X25519 +} + +func (p *x25519Parameters) PublicKey() []byte { + return p.publicKey[:] +} + +func (p *x25519Parameters) SharedKey(peerPublicKey []byte) []byte { + sharedKey, err := curve25519.X25519(p.privateKey, peerPublicKey) + if err != nil { + return nil + } + return sharedKey +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/key_schedule_test.go b/vendor/github.com/lesismal/llib/std/crypto/tls/key_schedule_test.go new file mode 100644 index 0000000..79ff6a6 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/key_schedule_test.go @@ -0,0 +1,175 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "encoding/hex" + "hash" + "strings" + "testing" + "unicode" +) + +// This file contains tests derived from draft-ietf-tls-tls13-vectors-07. + +func parseVector(v string) []byte { + v = strings.Map(func(c rune) rune { + if unicode.IsSpace(c) { + return -1 + } + return c + }, v) + parts := strings.Split(v, ":") + v = parts[len(parts)-1] + res, err := hex.DecodeString(v) + if err != nil { + panic(err) + } + return res +} + +func TestDeriveSecret(t *testing.T) { + chTranscript := cipherSuitesTLS13[0].hash.New() + chTranscript.Write(parseVector(` + payload (512 octets): 01 00 01 fc 03 03 1b c3 ce b6 bb e3 9c ff + 93 83 55 b5 a5 0a db 6d b2 1b 7a 6a f6 49 d7 b4 bc 41 9d 78 76 + 48 7d 95 00 00 06 13 01 13 03 13 02 01 00 01 cd 00 00 00 0b 00 + 09 00 00 06 73 65 72 76 65 72 ff 01 00 01 00 00 0a 00 14 00 12 + 00 1d 00 17 00 18 00 19 01 00 01 01 01 02 01 03 01 04 00 33 00 + 26 00 24 00 1d 00 20 e4 ff b6 8a c0 5f 8d 96 c9 9d a2 66 98 34 + 6c 6b e1 64 82 ba dd da fe 05 1a 66 b4 f1 8d 66 8f 0b 00 2a 00 + 00 00 2b 00 03 02 03 04 00 0d 00 20 00 1e 04 03 05 03 06 03 02 + 03 08 04 08 05 08 06 04 01 05 01 06 01 02 01 04 02 05 02 06 02 + 02 02 00 2d 00 02 01 01 00 1c 00 02 40 01 00 15 00 57 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 29 00 dd 00 b8 00 b2 2c 03 5d 82 93 59 ee 5f f7 af 4e c9 00 + 00 00 00 26 2a 64 94 dc 48 6d 2c 8a 34 cb 33 fa 90 bf 1b 00 70 + ad 3c 49 88 83 c9 36 7c 09 a2 be 78 5a bc 55 cd 22 60 97 a3 a9 + 82 11 72 83 f8 2a 03 a1 43 ef d3 ff 5d d3 6d 64 e8 61 be 7f d6 + 1d 28 27 db 27 9c ce 14 50 77 d4 54 a3 66 4d 4e 6d a4 d2 9e e0 + 37 25 a6 a4 da fc d0 fc 67 d2 ae a7 05 29 51 3e 3d a2 67 7f a5 + 90 6c 5b 3f 7d 8f 92 f2 28 bd a4 0d da 72 14 70 f9 fb f2 97 b5 + ae a6 17 64 6f ac 5c 03 27 2e 97 07 27 c6 21 a7 91 41 ef 5f 7d + e6 50 5e 5b fb c3 88 e9 33 43 69 40 93 93 4a e4 d3 57 fa d6 aa + cb 00 21 20 3a dd 4f b2 d8 fd f8 22 a0 ca 3c f7 67 8e f5 e8 8d + ae 99 01 41 c5 92 4d 57 bb 6f a3 1b 9e 5f 9d`)) + + type args struct { + secret []byte + label string + transcript hash.Hash + } + tests := []struct { + name string + args args + want []byte + }{ + { + `derive secret for handshake "tls13 derived"`, + args{ + parseVector(`PRK (32 octets): 33 ad 0a 1c 60 7e c0 3b 09 e6 cd 98 93 68 0c e2 + 10 ad f3 00 aa 1f 26 60 e1 b2 2e 10 f1 70 f9 2a`), + "derived", + nil, + }, + parseVector(`expanded (32 octets): 6f 26 15 a1 08 c7 02 c5 67 8f 54 fc 9d ba + b6 97 16 c0 76 18 9c 48 25 0c eb ea c3 57 6c 36 11 ba`), + }, + { + `derive secret "tls13 c e traffic"`, + args{ + parseVector(`PRK (32 octets): 9b 21 88 e9 b2 fc 6d 64 d7 1d c3 29 90 0e 20 bb + 41 91 50 00 f6 78 aa 83 9c bb 79 7c b7 d8 33 2c`), + "c e traffic", + chTranscript, + }, + parseVector(`expanded (32 octets): 3f bb e6 a6 0d eb 66 c3 0a 32 79 5a ba 0e + ff 7e aa 10 10 55 86 e7 be 5c 09 67 8d 63 b6 ca ab 62`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := cipherSuitesTLS13[0] + if got := c.deriveSecret(tt.args.secret, tt.args.label, tt.args.transcript); !bytes.Equal(got, tt.want) { + t.Errorf("cipherSuiteTLS13.deriveSecret() = % x, want % x", got, tt.want) + } + }) + } +} + +func TestTrafficKey(t *testing.T) { + trafficSecret := parseVector( + `PRK (32 octets): b6 7b 7d 69 0c c1 6c 4e 75 e5 42 13 cb 2d 37 b4 + e9 c9 12 bc de d9 10 5d 42 be fd 59 d3 91 ad 38`) + wantKey := parseVector( + `key expanded (16 octets): 3f ce 51 60 09 c2 17 27 d0 f2 e4 e8 6e + e4 03 bc`) + wantIV := parseVector( + `iv expanded (12 octets): 5d 31 3e b2 67 12 76 ee 13 00 0b 30`) + + c := cipherSuitesTLS13[0] + gotKey, gotIV := c.trafficKey(trafficSecret) + if !bytes.Equal(gotKey, wantKey) { + t.Errorf("cipherSuiteTLS13.trafficKey() gotKey = % x, want % x", gotKey, wantKey) + } + if !bytes.Equal(gotIV, wantIV) { + t.Errorf("cipherSuiteTLS13.trafficKey() gotIV = % x, want % x", gotIV, wantIV) + } +} + +func TestExtract(t *testing.T) { + type args struct { + newSecret []byte + currentSecret []byte + } + tests := []struct { + name string + args args + want []byte + }{ + { + `extract secret "early"`, + args{ + nil, + nil, + }, + parseVector(`secret (32 octets): 33 ad 0a 1c 60 7e c0 3b 09 e6 cd 98 93 68 0c + e2 10 ad f3 00 aa 1f 26 60 e1 b2 2e 10 f1 70 f9 2a`), + }, + { + `extract secret "master"`, + args{ + nil, + parseVector(`salt (32 octets): 43 de 77 e0 c7 77 13 85 9a 94 4d b9 db 25 90 b5 + 31 90 a6 5b 3e e2 e4 f1 2d d7 a0 bb 7c e2 54 b4`), + }, + parseVector(`secret (32 octets): 18 df 06 84 3d 13 a0 8b f2 a4 49 84 4c 5f 8a + 47 80 01 bc 4d 4c 62 79 84 d5 a4 1d a8 d0 40 29 19`), + }, + { + `extract secret "handshake"`, + args{ + parseVector(`IKM (32 octets): 8b d4 05 4f b5 5b 9d 63 fd fb ac f9 f0 4b 9f 0d + 35 e6 d6 3f 53 75 63 ef d4 62 72 90 0f 89 49 2d`), + parseVector(`salt (32 octets): 6f 26 15 a1 08 c7 02 c5 67 8f 54 fc 9d ba b6 97 + 16 c0 76 18 9c 48 25 0c eb ea c3 57 6c 36 11 ba`), + }, + parseVector(`secret (32 octets): 1d c8 26 e9 36 06 aa 6f dc 0a ad c1 2f 74 1b + 01 04 6a a6 b9 9f 69 1e d2 21 a9 f0 ca 04 3f be ac`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := cipherSuitesTLS13[0] + if got := c.extract(tt.args.newSecret, tt.args.currentSecret); !bytes.Equal(got, tt.want) { + t.Errorf("cipherSuiteTLS13.extract() = % x, want % x", got, tt.want) + } + }) + } +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/link_test.go b/vendor/github.com/lesismal/llib/std/crypto/tls/link_test.go new file mode 100644 index 0000000..dff5eb5 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/link_test.go @@ -0,0 +1,108 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "os" + "os/exec" + "path/filepath" + "testing" + + "github.com/lesismal/llib/std/internal/testenv" +) + +// Tests that the linker is able to remove references to the Client or Server if unused. +func TestLinkerGC(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + t.Parallel() + goBin := testenv.GoToolPath(t) + testenv.MustHaveGoBuild(t) + + tests := []struct { + name string + program string + want []string + bad []string + }{ + { + name: "empty_import", + program: `package main +import _ "crypto/tls" +func main() {} +`, + bad: []string{ + "tls.(*Conn)", + "type.crypto/tls.clientHandshakeState", + "type.crypto/tls.serverHandshakeState", + }, + }, + { + name: "client_and_server", + program: `package main +import "crypto/tls" +func main() { + tls.Dial("", "", nil) + tls.Server(nil, nil) +} +`, + want: []string{ + "crypto/tls.(*Conn).clientHandshake", + "crypto/tls.(*Conn).serverHandshake", + }, + }, + { + name: "only_client", + program: `package main +import "crypto/tls" +func main() { tls.Dial("", "", nil) } +`, + want: []string{ + "crypto/tls.(*Conn).clientHandshake", + }, + bad: []string{ + "crypto/tls.(*Conn).serverHandshake", + }, + }, + // TODO: add only_server like func main() { tls.Server(nil, nil) } + // That currently brings in the client via Conn.handleRenegotiation. + + } + tmpDir := t.TempDir() + goFile := filepath.Join(tmpDir, "x.go") + exeFile := filepath.Join(tmpDir, "x.exe") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := os.WriteFile(goFile, []byte(tt.program), 0644); err != nil { + t.Fatal(err) + } + os.Remove(exeFile) + cmd := exec.Command(goBin, "build", "-o", "x.exe", "x.go") + cmd.Dir = tmpDir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("compile: %v, %s", err, out) + } + + cmd = exec.Command(goBin, "tool", "nm", "x.exe") + cmd.Dir = tmpDir + nm, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("nm: %v, %s", err, nm) + } + for _, sym := range tt.want { + if !bytes.Contains(nm, []byte(sym)) { + t.Errorf("expected symbol %q not found", sym) + } + } + for _, sym := range tt.bad { + if bytes.Contains(nm, []byte(sym)) { + t.Errorf("unexpected symbol %q found", sym) + } + } + }) + } +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/prf.go b/vendor/github.com/lesismal/llib/std/crypto/tls/prf.go new file mode 100644 index 0000000..13bfa00 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/prf.go @@ -0,0 +1,283 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "crypto" + "crypto/hmac" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "errors" + "fmt" + "hash" +) + +// Split a premaster secret in two as specified in RFC 4346, Section 5. +func splitPreMasterSecret(secret []byte) (s1, s2 []byte) { + s1 = secret[0 : (len(secret)+1)/2] + s2 = secret[len(secret)/2:] + return +} + +// pHash implements the P_hash function, as defined in RFC 4346, Section 5. +func pHash(result, secret, seed []byte, hash func() hash.Hash) { + h := hmac.New(hash, secret) + h.Write(seed) + a := h.Sum(nil) + + j := 0 + for j < len(result) { + h.Reset() + h.Write(a) + h.Write(seed) + b := h.Sum(nil) + copy(result[j:], b) + j += len(b) + + h.Reset() + h.Write(a) + a = h.Sum(nil) + } +} + +// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5. +func prf10(result, secret, label, seed []byte) { + hashSHA1 := sha1.New + hashMD5 := md5.New + + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) + + s1, s2 := splitPreMasterSecret(secret) + pHash(result, s1, labelAndSeed, hashMD5) + result2 := make([]byte, len(result)) + pHash(result2, s2, labelAndSeed, hashSHA1) + + for i, b := range result2 { + result[i] ^= b + } +} + +// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, Section 5. +func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) { + return func(result, secret, label, seed []byte) { + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) + + pHash(result, secret, labelAndSeed, hashFunc) + } +} + +const ( + masterSecretLength = 48 // Length of a master secret in TLS 1.1. + finishedVerifyLength = 12 // Length of verify_data in a Finished message. +) + +var masterSecretLabel = []byte("master secret") +var keyExpansionLabel = []byte("key expansion") +var clientFinishedLabel = []byte("client finished") +var serverFinishedLabel = []byte("server finished") + +func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) { + switch version { + case VersionTLS10, VersionTLS11: + return prf10, crypto.Hash(0) + case VersionTLS12: + if suite.flags&suiteSHA384 != 0 { + return prf12(sha512.New384), crypto.SHA384 + } + return prf12(sha256.New), crypto.SHA256 + default: + panic("unknown version") + } +} + +func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) { + prf, _ := prfAndHashForVersion(version, suite) + return prf +} + +// masterFromPreMasterSecret generates the master secret from the pre-master +// secret. See RFC 5246, Section 8.1. +func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte { + seed := make([]byte, 0, len(clientRandom)+len(serverRandom)) + seed = append(seed, clientRandom...) + seed = append(seed, serverRandom...) + + masterSecret := make([]byte, masterSecretLength) + prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed) + return masterSecret +} + +// keysFromMasterSecret generates the connection keys from the master +// secret, given the lengths of the MAC key, cipher key and IV, as defined in +// RFC 2246, Section 6.3. +func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) { + seed := make([]byte, 0, len(serverRandom)+len(clientRandom)) + seed = append(seed, serverRandom...) + seed = append(seed, clientRandom...) + + n := 2*macLen + 2*keyLen + 2*ivLen + keyMaterial := make([]byte, n) + prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed) + clientMAC = keyMaterial[:macLen] + keyMaterial = keyMaterial[macLen:] + serverMAC = keyMaterial[:macLen] + keyMaterial = keyMaterial[macLen:] + clientKey = keyMaterial[:keyLen] + keyMaterial = keyMaterial[keyLen:] + serverKey = keyMaterial[:keyLen] + keyMaterial = keyMaterial[keyLen:] + clientIV = keyMaterial[:ivLen] + keyMaterial = keyMaterial[ivLen:] + serverIV = keyMaterial[:ivLen] + return +} + +func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash { + var buffer []byte + if version >= VersionTLS12 { + buffer = []byte{} + } + + prf, hash := prfAndHashForVersion(version, cipherSuite) + if hash != 0 { + return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf} + } + + return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf} +} + +// A finishedHash calculates the hash of a set of handshake messages suitable +// for including in a Finished message. +type finishedHash struct { + client hash.Hash + server hash.Hash + + // Prior to TLS 1.2, an additional MD5 hash is required. + clientMD5 hash.Hash + serverMD5 hash.Hash + + // In TLS 1.2, a full buffer is sadly required. + buffer []byte + + version uint16 + prf func(result, secret, label, seed []byte) +} + +func (h *finishedHash) Write(msg []byte) (n int, err error) { + h.client.Write(msg) + h.server.Write(msg) + + if h.version < VersionTLS12 { + h.clientMD5.Write(msg) + h.serverMD5.Write(msg) + } + + if h.buffer != nil { + h.buffer = append(h.buffer, msg...) + } + + return len(msg), nil +} + +func (h finishedHash) Sum() []byte { + if h.version >= VersionTLS12 { + return h.client.Sum(nil) + } + + out := make([]byte, 0, md5.Size+sha1.Size) + out = h.clientMD5.Sum(out) + return h.client.Sum(out) +} + +// clientSum returns the contents of the verify_data member of a client's +// Finished message. +func (h finishedHash) clientSum(masterSecret []byte) []byte { + out := make([]byte, finishedVerifyLength) + h.prf(out, masterSecret, clientFinishedLabel, h.Sum()) + return out +} + +// serverSum returns the contents of the verify_data member of a server's +// Finished message. +func (h finishedHash) serverSum(masterSecret []byte) []byte { + out := make([]byte, finishedVerifyLength) + h.prf(out, masterSecret, serverFinishedLabel, h.Sum()) + return out +} + +// hashForClientCertificate returns the handshake messages so far, pre-hashed if +// necessary, suitable for signing by a TLS client certificate. +func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash, masterSecret []byte) []byte { + if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil { + panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer") + } + + if sigType == signatureEd25519 { + return h.buffer + } + + if h.version >= VersionTLS12 { + hash := hashAlg.New() + hash.Write(h.buffer) + return hash.Sum(nil) + } + + if sigType == signatureECDSA { + return h.server.Sum(nil) + } + + return h.Sum() +} + +// discardHandshakeBuffer is called when there is no more need to +// buffer the entirety of the handshake messages. +func (h *finishedHash) discardHandshakeBuffer() { + h.buffer = nil +} + +// noExportedKeyingMaterial is used as a value of +// ConnectionState.ekm when renegotiation is enabled and thus +// we wish to fail all key-material export requests. +func noExportedKeyingMaterial(label string, context []byte, length int) ([]byte, error) { + return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled") +} + +// ekmFromMasterSecret generates exported keying material as defined in RFC 5705. +func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) { + return func(label string, context []byte, length int) ([]byte, error) { + switch label { + case "client finished", "server finished", "master secret", "key expansion": + // These values are reserved and may not be used. + return nil, fmt.Errorf("crypto/tls: reserved ExportKeyingMaterial label: %s", label) + } + + seedLen := len(serverRandom) + len(clientRandom) + if context != nil { + seedLen += 2 + len(context) + } + seed := make([]byte, 0, seedLen) + + seed = append(seed, clientRandom...) + seed = append(seed, serverRandom...) + + if context != nil { + if len(context) >= 1<<16 { + return nil, fmt.Errorf("crypto/tls: ExportKeyingMaterial context too long") + } + seed = append(seed, byte(len(context)>>8), byte(len(context))) + seed = append(seed, context...) + } + + keyMaterial := make([]byte, length) + prfForVersion(version, suite)(keyMaterial, masterSecret, []byte(label), seed) + return keyMaterial, nil + } +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/prf_test.go b/vendor/github.com/lesismal/llib/std/crypto/tls/prf_test.go new file mode 100644 index 0000000..8233985 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/prf_test.go @@ -0,0 +1,140 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "encoding/hex" + "testing" +) + +type testSplitPreMasterSecretTest struct { + in, out1, out2 string +} + +var testSplitPreMasterSecretTests = []testSplitPreMasterSecretTest{ + {"", "", ""}, + {"00", "00", "00"}, + {"0011", "00", "11"}, + {"001122", "0011", "1122"}, + {"00112233", "0011", "2233"}, +} + +func TestSplitPreMasterSecret(t *testing.T) { + for i, test := range testSplitPreMasterSecretTests { + in, _ := hex.DecodeString(test.in) + out1, out2 := splitPreMasterSecret(in) + s1 := hex.EncodeToString(out1) + s2 := hex.EncodeToString(out2) + if s1 != test.out1 || s2 != test.out2 { + t.Errorf("#%d: got: (%s, %s) want: (%s, %s)", i, s1, s2, test.out1, test.out2) + } + } +} + +type testKeysFromTest struct { + version uint16 + suite *cipherSuite + preMasterSecret string + clientRandom, serverRandom string + masterSecret string + clientMAC, serverMAC string + clientKey, serverKey string + macLen, keyLen int + contextKeyingMaterial, noContextKeyingMaterial string +} + +func TestKeysFromPreMasterSecret(t *testing.T) { + for i, test := range testKeysFromTests { + in, _ := hex.DecodeString(test.preMasterSecret) + clientRandom, _ := hex.DecodeString(test.clientRandom) + serverRandom, _ := hex.DecodeString(test.serverRandom) + + masterSecret := masterFromPreMasterSecret(test.version, test.suite, in, clientRandom, serverRandom) + if s := hex.EncodeToString(masterSecret); s != test.masterSecret { + t.Errorf("#%d: bad master secret %s, want %s", i, s, test.masterSecret) + continue + } + + clientMAC, serverMAC, clientKey, serverKey, _, _ := keysFromMasterSecret(test.version, test.suite, masterSecret, clientRandom, serverRandom, test.macLen, test.keyLen, 0) + clientMACString := hex.EncodeToString(clientMAC) + serverMACString := hex.EncodeToString(serverMAC) + clientKeyString := hex.EncodeToString(clientKey) + serverKeyString := hex.EncodeToString(serverKey) + if clientMACString != test.clientMAC || + serverMACString != test.serverMAC || + clientKeyString != test.clientKey || + serverKeyString != test.serverKey { + t.Errorf("#%d: got: (%s, %s, %s, %s) want: (%s, %s, %s, %s)", i, clientMACString, serverMACString, clientKeyString, serverKeyString, test.clientMAC, test.serverMAC, test.clientKey, test.serverKey) + } + + ekm := ekmFromMasterSecret(test.version, test.suite, masterSecret, clientRandom, serverRandom) + contextKeyingMaterial, err := ekm("label", []byte("context"), 32) + if err != nil { + t.Fatalf("ekmFromMasterSecret failed: %v", err) + } + + noContextKeyingMaterial, err := ekm("label", nil, 32) + if err != nil { + t.Fatalf("ekmFromMasterSecret failed: %v", err) + } + + if hex.EncodeToString(contextKeyingMaterial) != test.contextKeyingMaterial || + hex.EncodeToString(noContextKeyingMaterial) != test.noContextKeyingMaterial { + t.Errorf("#%d: got keying material: (%s, %s) want: (%s, %s)", i, contextKeyingMaterial, noContextKeyingMaterial, test.contextKeyingMaterial, test.noContextKeyingMaterial) + } + } +} + +// These test vectors were generated from GnuTLS using `gnutls-cli --insecure -d 9 ` +var testKeysFromTests = []testKeysFromTest{ + { + VersionTLS10, + cipherSuiteByID(TLS_RSA_WITH_RC4_128_SHA), + "0302cac83ad4b1db3b9ab49ad05957de2a504a634a386fc600889321e1a971f57479466830ac3e6f468e87f5385fa0c5", + "4ae66303755184a3917fcb44880605fcc53baa01912b22ed94473fc69cebd558", + "4ae663020ec16e6bb5130be918cfcafd4d765979a3136a5d50c593446e4e44db", + "3d851bab6e5556e959a16bc36d66cfae32f672bfa9ecdef6096cbb1b23472df1da63dbbd9827606413221d149ed08ceb", + "805aaa19b3d2c0a0759a4b6c9959890e08480119", + "2d22f9fe519c075c16448305ceee209fc24ad109", + "d50b5771244f850cd8117a9ccafe2cf1", + "e076e33206b30507a85c32855acd0919", + 20, + 16, + "4d1bb6fc278c37d27aa6e2a13c2e079095d143272c2aa939da33d88c1c0cec22", + "93fba89599b6321ae538e27c6548ceb8b46821864318f5190d64a375e5d69d41", + }, + { + VersionTLS10, + cipherSuiteByID(TLS_RSA_WITH_RC4_128_SHA), + "03023f7527316bc12cbcd69e4b9e8275d62c028f27e65c745cfcddc7ce01bd3570a111378b63848127f1c36e5f9e4890", + "4ae66364b5ea56b20ce4e25555aed2d7e67f42788dd03f3fee4adae0459ab106", + "4ae66363ab815cbf6a248b87d6b556184e945e9b97fbdf247858b0bdafacfa1c", + "7d64be7c80c59b740200b4b9c26d0baaa1c5ae56705acbcf2307fe62beb4728c19392c83f20483801cce022c77645460", + "97742ed60a0554ca13f04f97ee193177b971e3b0", + "37068751700400e03a8477a5c7eec0813ab9e0dc", + "207cddbc600d2a200abac6502053ee5c", + "df3f94f6e1eacc753b815fe16055cd43", + 20, + 16, + "2c9f8961a72b97cbe76553b5f954caf8294fc6360ef995ac1256fe9516d0ce7f", + "274f19c10291d188857ad8878e2119f5aa437d4da556601cf1337aff23154016", + }, + { + VersionTLS10, + cipherSuiteByID(TLS_RSA_WITH_RC4_128_SHA), + "832d515f1d61eebb2be56ba0ef79879efb9b527504abb386fb4310ed5d0e3b1f220d3bb6b455033a2773e6d8bdf951d278a187482b400d45deb88a5d5a6bb7d6a7a1decc04eb9ef0642876cd4a82d374d3b6ff35f0351dc5d411104de431375355addc39bfb1f6329fb163b0bc298d658338930d07d313cd980a7e3d9196cac1", + "4ae663b2ee389c0de147c509d8f18f5052afc4aaf9699efe8cb05ece883d3a5e", + "4ae664d503fd4cff50cfc1fb8fc606580f87b0fcdac9554ba0e01d785bdf278e", + "1aff2e7a2c4279d0126f57a65a77a8d9d0087cf2733366699bec27eb53d5740705a8574bb1acc2abbe90e44f0dd28d6c", + "3c7647c93c1379a31a609542aa44e7f117a70085", + "0d73102994be74a575a3ead8532590ca32a526d4", + "ac7581b0b6c10d85bbd905ffbf36c65e", + "ff07edde49682b45466bd2e39464b306", + 20, + 16, + "678b0d43f607de35241dc7e9d1a7388a52c35033a1a0336d4d740060a6638fe2", + "f3b4ac743f015ef21d79978297a53da3e579ee047133f38c234d829c0f907dab", + }, +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/ticket.go b/vendor/github.com/lesismal/llib/std/crypto/tls/ticket.go new file mode 100644 index 0000000..6c1d20d --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/ticket.go @@ -0,0 +1,185 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/sha256" + "crypto/subtle" + "errors" + "io" + + "golang.org/x/crypto/cryptobyte" +) + +// sessionState contains the information that is serialized into a session +// ticket in order to later resume a connection. +type sessionState struct { + vers uint16 + cipherSuite uint16 + createdAt uint64 + masterSecret []byte // opaque master_secret<1..2^16-1>; + // struct { opaque certificate<1..2^24-1> } Certificate; + certificates [][]byte // Certificate certificate_list<0..2^24-1>; + + // usedOldKey is true if the ticket from which this session came from + // was encrypted with an older key and thus should be refreshed. + usedOldKey bool +} + +func (m *sessionState) marshal() []byte { + var b cryptobyte.Builder + b.AddUint16(m.vers) + b.AddUint16(m.cipherSuite) + addUint64(&b, m.createdAt) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.masterSecret) + }) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + for _, cert := range m.certificates { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(cert) + }) + } + }) + return b.BytesOrPanic() +} + +func (m *sessionState) unmarshal(data []byte) bool { + *m = sessionState{usedOldKey: m.usedOldKey} + s := cryptobyte.String(data) + if ok := s.ReadUint16(&m.vers) && + s.ReadUint16(&m.cipherSuite) && + readUint64(&s, &m.createdAt) && + readUint16LengthPrefixed(&s, &m.masterSecret) && + len(m.masterSecret) != 0; !ok { + return false + } + var certList cryptobyte.String + if !s.ReadUint24LengthPrefixed(&certList) { + return false + } + for !certList.Empty() { + var cert []byte + if !readUint24LengthPrefixed(&certList, &cert) { + return false + } + m.certificates = append(m.certificates, cert) + } + return s.Empty() +} + +// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first +// version (revision = 0) doesn't carry any of the information needed for 0-RTT +// validation and the nonce is always empty. +type sessionStateTLS13 struct { + // uint8 version = 0x0304; + // uint8 revision = 0; + cipherSuite uint16 + createdAt uint64 + resumptionSecret []byte // opaque resumption_master_secret<1..2^8-1>; + certificate Certificate // CertificateEntry certificate_list<0..2^24-1>; +} + +func (m *sessionStateTLS13) marshal() []byte { + var b cryptobyte.Builder + b.AddUint16(VersionTLS13) + b.AddUint8(0) // revision + b.AddUint16(m.cipherSuite) + addUint64(&b, m.createdAt) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.resumptionSecret) + }) + marshalCertificate(&b, m.certificate) + return b.BytesOrPanic() +} + +func (m *sessionStateTLS13) unmarshal(data []byte) bool { + *m = sessionStateTLS13{} + s := cryptobyte.String(data) + var version uint16 + var revision uint8 + return s.ReadUint16(&version) && + version == VersionTLS13 && + s.ReadUint8(&revision) && + revision == 0 && + s.ReadUint16(&m.cipherSuite) && + readUint64(&s, &m.createdAt) && + readUint8LengthPrefixed(&s, &m.resumptionSecret) && + len(m.resumptionSecret) != 0 && + unmarshalCertificate(&s, &m.certificate) && + s.Empty() +} + +func (c *Conn) encryptTicket(state []byte) ([]byte, error) { + if len(c.ticketKeys) == 0 { + return nil, errors.New("tls: internal error: session ticket keys unavailable") + } + + encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size) + keyName := encrypted[:ticketKeyNameLen] + iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] + macBytes := encrypted[len(encrypted)-sha256.Size:] + + if _, err := io.ReadFull(c.config.rand(), iv); err != nil { + return nil, err + } + key := c.ticketKeys[0] + copy(keyName, key.keyName[:]) + block, err := aes.NewCipher(key.aesKey[:]) + if err != nil { + return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) + } + cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state) + + mac := hmac.New(sha256.New, key.hmacKey[:]) + mac.Write(encrypted[:len(encrypted)-sha256.Size]) + mac.Sum(macBytes[:0]) + + return encrypted, nil +} + +func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) { + if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size { + return nil, false + } + + keyName := encrypted[:ticketKeyNameLen] + iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] + macBytes := encrypted[len(encrypted)-sha256.Size:] + ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size] + + keyIndex := -1 + for i, candidateKey := range c.ticketKeys { + if bytes.Equal(keyName, candidateKey.keyName[:]) { + keyIndex = i + break + } + } + if keyIndex == -1 { + return nil, false + } + key := &c.ticketKeys[keyIndex] + + mac := hmac.New(sha256.New, key.hmacKey[:]) + mac.Write(encrypted[:len(encrypted)-sha256.Size]) + expected := mac.Sum(nil) + + if subtle.ConstantTimeCompare(macBytes, expected) != 1 { + return nil, false + } + + block, err := aes.NewCipher(key.aesKey[:]) + if err != nil { + return nil, false + } + plaintext = make([]byte, len(ciphertext)) + cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext) + + return plaintext, keyIndex > 0 +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/tls.go b/vendor/github.com/lesismal/llib/std/crypto/tls/tls.go new file mode 100644 index 0000000..f15ce61 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/tls.go @@ -0,0 +1,430 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package tls partially implements TLS 1.2, as specified in RFC 5246, +// and TLS 1.3, as specified in RFC 8446. +package tls + +// BUG(agl): The crypto/tls package only implements some countermeasures +// against Lucky13 attacks on CBC-mode encryption, and only on SHA1 +// variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and +// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. + +import ( + "bytes" + "context" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "net" + "os" + "strings" + "time" +) + +// NewConn returns a new TLS server side connection +// using conn as the underlying transport. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func NewConn(conn net.Conn, config *Config, isClient bool, isNonBlock bool, v ...interface{}) *Conn { + c := &Conn{ + conn: conn, + config: config, + isClient: isClient, + isNonBlock: isNonBlock, + } + c.handshakeFn = c.serverHandshake + if isClient { + c.handshakeFn = c.clientHandshake + } + if len(v) > 0 { + if allocator, ok := v[0].(Allocator); ok { + c.allocator = allocator + } + } + if c.allocator == nil { + c.allocator = &NativeAllocator{} + } + return c +} + +// Server returns a new TLS server side connection +// using conn as the underlying transport. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Server(conn net.Conn, config *Config) *Conn { + c := &Conn{ + conn: conn, + config: config, + } + c.handshakeFn = c.serverHandshake + return c +} + +// Client returns a new TLS client side connection +// using conn as the underlying transport. +// The config cannot be nil: users must set either ServerName or +// InsecureSkipVerify in the config. +func Client(conn net.Conn, config *Config) *Conn { + c := &Conn{ + conn: conn, + config: config, + isClient: true, + } + c.handshakeFn = c.clientHandshake + return c +} + +// A listener implements a network listener (net.Listener) for TLS connections. +type listener struct { + net.Listener + config *Config + allocator Allocator +} + +// Accept waits for and returns the next incoming TLS connection. +// The returned connection is of type *Conn. +func (l *listener) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + tlsConn := Server(c, l.config) + tlsConn.allocator = l.allocator + return tlsConn, nil +} + +// NewListener creates a Listener which accepts connections from an inner +// Listener and wraps each connection with Server. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func NewListener(inner net.Listener, config *Config, v ...interface{}) net.Listener { + l := new(listener) + l.Listener = inner + l.config = config + + if len(v) > 0 { + if allocator, ok := v[0].(Allocator); ok { + l.allocator = allocator + } + } + + return l +} + +// Listen creates a TLS listener accepting connections on the +// given network address using net.Listen. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Listen(network, laddr string, config *Config) (net.Listener, error) { + if config == nil || len(config.Certificates) == 0 && + config.GetCertificate == nil && config.GetConfigForClient == nil { + return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config") + } + l, err := net.Listen(network, laddr) + if err != nil { + return nil, err + } + return NewListener(l, config), nil +} + +type timeoutError struct{} + +func (timeoutError) Error() string { return "tls: DialWithDialer timed out" } +func (timeoutError) Timeout() bool { return true } +func (timeoutError) Temporary() bool { return true } + +// DialWithDialer connects to the given network address using dialer.Dial and +// then initiates a TLS handshake, returning the resulting TLS connection. Any +// timeout or deadline given in the dialer apply to connection and TLS +// handshake as a whole. +// +// DialWithDialer interprets a nil configuration as equivalent to the zero +// configuration; see the documentation of Config for the defaults. +func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config, v ...interface{}) (*Conn, error) { + return dial(context.Background(), dialer, network, addr, config, v...) +} + +func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config, v ...interface{}) (*Conn, error) { + // We want the Timeout and Deadline values from dialer to cover the + // whole process: TCP connection and TLS handshake. This means that we + // also need to start our own timers now. + timeout := netDialer.Timeout + + if !netDialer.Deadline.IsZero() { + deadlineTimeout := time.Until(netDialer.Deadline) + if timeout == 0 || deadlineTimeout < timeout { + timeout = deadlineTimeout + } + } + + // hsErrCh is non-nil if we might not wait for Handshake to complete. + var hsErrCh chan error + if timeout != 0 || ctx.Done() != nil { + hsErrCh = make(chan error, 2) + } + if timeout != 0 { + timer := time.AfterFunc(timeout, func() { + hsErrCh <- timeoutError{} + }) + defer timer.Stop() + } + + rawConn, err := netDialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + hostname := addr[:colonPos] + + if config == nil { + config = defaultConfig() + } + // If no ServerName is set, infer the ServerName + // from the hostname we're connecting to. + if config.ServerName == "" { + // Make a copy to avoid polluting argument or default. + c := config.Clone() + c.ServerName = hostname + config = c + } + + conn := Client(rawConn, config) + if len(v) > 0 { + if allocator, ok := v[0].(Allocator); ok { + conn.allocator = allocator + } + } + if conn.allocator == nil { + conn.allocator = &NativeAllocator{} + } + if hsErrCh == nil { + err = conn.Handshake() + } else { + go func() { + hsErrCh <- conn.Handshake() + }() + + select { + case <-ctx.Done(): + err = ctx.Err() + case err = <-hsErrCh: + if err != nil { + // If the error was due to the context + // closing, prefer the context's error, rather + // than some random network teardown error. + if e := ctx.Err(); e != nil { + err = e + } + } + } + } + + if err != nil { + rawConn.Close() + return nil, err + } + + return conn, nil +} + +// Dial connects to the given network address using net.Dial +// and then initiates a TLS handshake, returning the resulting +// TLS connection. +// Dial interprets a nil configuration as equivalent to +// the zero configuration; see the documentation of Config +// for the defaults. +func Dial(network, addr string, config *Config, v ...interface{}) (*Conn, error) { + return DialWithDialer(new(net.Dialer), network, addr, config, v...) +} + +// Dialer dials TLS connections given a configuration and a Dialer for the +// underlying connection. +type Dialer struct { + // NetDialer is the optional dialer to use for the TLS connections' + // underlying TCP connections. + // A nil NetDialer is equivalent to the net.Dialer zero value. + NetDialer *net.Dialer + + // Config is the TLS configuration to use for new connections. + // A nil configuration is equivalent to the zero + // configuration; see the documentation of Config for the + // defaults. + Config *Config +} + +// Dial connects to the given network address and initiates a TLS +// handshake, returning the resulting TLS connection. +// +// The returned Conn, if any, will always be of type *Conn. +func (d *Dialer) Dial(network, addr string) (net.Conn, error) { + return d.DialContext(context.Background(), network, addr) +} + +func (d *Dialer) netDialer() *net.Dialer { + if d.NetDialer != nil { + return d.NetDialer + } + return new(net.Dialer) +} + +// DialContext connects to the given network address and initiates a TLS +// handshake, returning the resulting TLS connection. +// +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. +// +// The returned Conn, if any, will always be of type *Conn. +func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + c, err := dial(ctx, d.netDialer(), network, addr, d.Config) + if err != nil { + // Don't return c (a typed nil) in an interface. + return nil, err + } + return c, nil +} + +// LoadX509KeyPair reads and parses a public/private key pair from a pair +// of files. The files must contain PEM encoded data. The certificate file +// may contain intermediate certificates following the leaf certificate to +// form a certificate chain. On successful return, Certificate.Leaf will +// be nil because the parsed form of the certificate is not retained. +func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) { + certPEMBlock, err := os.ReadFile(certFile) + if err != nil { + return Certificate{}, err + } + keyPEMBlock, err := os.ReadFile(keyFile) + if err != nil { + return Certificate{}, err + } + return X509KeyPair(certPEMBlock, keyPEMBlock) +} + +// X509KeyPair parses a public/private key pair from a pair of +// PEM encoded data. On successful return, Certificate.Leaf will be nil because +// the parsed form of the certificate is not retained. +func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { + fail := func(err error) (Certificate, error) { return Certificate{}, err } + + var cert Certificate + var skippedBlockTypes []string + for { + var certDERBlock *pem.Block + certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) + if certDERBlock == nil { + break + } + if certDERBlock.Type == "CERTIFICATE" { + cert.Certificate = append(cert.Certificate, certDERBlock.Bytes) + } else { + skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type) + } + } + + if len(cert.Certificate) == 0 { + if len(skippedBlockTypes) == 0 { + return fail(errors.New("tls: failed to find any PEM data in certificate input")) + } + if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") { + return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched")) + } + return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) + } + + skippedBlockTypes = skippedBlockTypes[:0] + var keyDERBlock *pem.Block + for { + keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock) + if keyDERBlock == nil { + if len(skippedBlockTypes) == 0 { + return fail(errors.New("tls: failed to find any PEM data in key input")) + } + if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" { + return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key")) + } + return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) + } + if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") { + break + } + skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type) + } + + // We don't need to parse the public key for TLS, but we so do anyway + // to check that it looks sane and matches the private key. + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return fail(err) + } + + cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes) + if err != nil { + return fail(err) + } + + switch pub := x509Cert.PublicKey.(type) { + case *rsa.PublicKey: + priv, ok := cert.PrivateKey.(*rsa.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if pub.N.Cmp(priv.N) != 0 { + return fail(errors.New("tls: private key does not match public key")) + } + case *ecdsa.PublicKey: + priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { + return fail(errors.New("tls: private key does not match public key")) + } + case ed25519.PublicKey: + priv, ok := cert.PrivateKey.(ed25519.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) { + return fail(errors.New("tls: private key does not match public key")) + } + default: + return fail(errors.New("tls: unknown public key algorithm")) + } + + return cert, nil +} + +// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates +// PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys. +// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three. +func parsePrivateKey(der []byte) (crypto.PrivateKey, error) { + if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { + return key, nil + } + if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { + switch key := key.(type) { + case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey: + return key, nil + default: + return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping") + } + } + if key, err := x509.ParseECPrivateKey(der); err == nil { + return key, nil + } + + return nil, errors.New("tls: failed to parse private key") +} diff --git a/vendor/github.com/lesismal/llib/std/crypto/tls/tls_test.go b/vendor/github.com/lesismal/llib/std/crypto/tls/tls_test.go new file mode 100644 index 0000000..169e38e --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/crypto/tls/tls_test.go @@ -0,0 +1,1477 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "context" + "crypto" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "io" + "math" + "net" + "os" + "reflect" + "strings" + "testing" + "time" + + "github.com/lesismal/llib/std/internal/testenv" +) + +var rsaCertPEM = `-----BEGIN CERTIFICATE----- +MIIB0zCCAX2gAwIBAgIJAI/M7BYjwB+uMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTIwOTEyMjE1MjAyWhcNMTUwOTEyMjE1MjAyWjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANLJ +hPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wok/4xIA+ui35/MmNa +rtNuC+BdZ1tMuVCPFZcCAwEAAaNQME4wHQYDVR0OBBYEFJvKs8RfJaXTH08W+SGv +zQyKn0H8MB8GA1UdIwQYMBaAFJvKs8RfJaXTH08W+SGvzQyKn0H8MAwGA1UdEwQF +MAMBAf8wDQYJKoZIhvcNAQEFBQADQQBJlffJHybjDGxRMqaRmDhX0+6v02TUKZsW +r5QuVbpQhH6u+0UgcW0jp9QwpxoPTLTWGXEWBBBurxFwiCBhkQ+V +-----END CERTIFICATE----- +` + +var rsaKeyPEM = testingKey(`-----BEGIN RSA TESTING KEY----- +MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo +k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G +6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N +MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW +SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T +xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi +D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g== +-----END RSA TESTING KEY----- +`) + +// keyPEM is the same as rsaKeyPEM, but declares itself as just +// "PRIVATE KEY", not "RSA PRIVATE KEY". https://golang.org/issue/4477 +var keyPEM = testingKey(`-----BEGIN TESTING KEY----- +MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo +k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G +6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N +MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW +SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T +xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi +D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g== +-----END TESTING KEY----- +`) + +var ecdsaCertPEM = `-----BEGIN CERTIFICATE----- +MIIB/jCCAWICCQDscdUxw16XFDAJBgcqhkjOPQQBMEUxCzAJBgNVBAYTAkFVMRMw +EQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0 +eSBMdGQwHhcNMTIxMTE0MTI0MDQ4WhcNMTUxMTE0MTI0MDQ4WjBFMQswCQYDVQQG +EwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lk +Z2l0cyBQdHkgTHRkMIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQBY9+my9OoeSUR +lDQdV/x8LsOuLilthhiS1Tz4aGDHIPwC1mlvnf7fg5lecYpMCrLLhauAc1UJXcgl +01xoLuzgtAEAgv2P/jgytzRSpUYvgLBt1UA0leLYBy6mQQbrNEuqT3INapKIcUv8 +XxYP0xMEUksLPq6Ca+CRSqTtrd/23uTnapkwCQYHKoZIzj0EAQOBigAwgYYCQXJo +A7Sl2nLVf+4Iu/tAX/IF4MavARKC4PPHK3zfuGfPR3oCCcsAoz3kAzOeijvd0iXb +H5jBImIxPL4WxQNiBTexAkF8D1EtpYuWdlVQ80/h/f4pBcGiXPqX5h2PQSQY7hP1 ++jwM1FGS4fREIOvlBYr/SzzQRtwrvrzGYxDEDbsC0ZGRnA== +-----END CERTIFICATE----- +` + +var ecdsaKeyPEM = testingKey(`-----BEGIN EC PARAMETERS----- +BgUrgQQAIw== +-----END EC PARAMETERS----- +-----BEGIN EC TESTING KEY----- +MIHcAgEBBEIBrsoKp0oqcv6/JovJJDoDVSGWdirrkgCWxrprGlzB9o0X8fV675X0 +NwuBenXFfeZvVcwluO7/Q9wkYoPd/t3jGImgBwYFK4EEACOhgYkDgYYABAFj36bL +06h5JRGUNB1X/Hwuw64uKW2GGJLVPPhoYMcg/ALWaW+d/t+DmV5xikwKssuFq4Bz +VQldyCXTXGgu7OC0AQCC/Y/+ODK3NFKlRi+AsG3VQDSV4tgHLqZBBus0S6pPcg1q +kohxS/xfFg/TEwRSSws+roJr4JFKpO2t3/be5OdqmQ== +-----END EC TESTING KEY----- +`) + +var keyPairTests = []struct { + algo string + cert string + key string +}{ + {"ECDSA", ecdsaCertPEM, ecdsaKeyPEM}, + {"RSA", rsaCertPEM, rsaKeyPEM}, + {"RSA-untyped", rsaCertPEM, keyPEM}, // golang.org/issue/4477 +} + +func TestX509KeyPair(t *testing.T) { + t.Parallel() + var pem []byte + for _, test := range keyPairTests { + pem = []byte(test.cert + test.key) + if _, err := X509KeyPair(pem, pem); err != nil { + t.Errorf("Failed to load %s cert followed by %s key: %s", test.algo, test.algo, err) + } + pem = []byte(test.key + test.cert) + if _, err := X509KeyPair(pem, pem); err != nil { + t.Errorf("Failed to load %s key followed by %s cert: %s", test.algo, test.algo, err) + } + } +} + +func TestX509KeyPairErrors(t *testing.T) { + _, err := X509KeyPair([]byte(rsaKeyPEM), []byte(rsaCertPEM)) + if err == nil { + t.Fatalf("X509KeyPair didn't return an error when arguments were switched") + } + if subStr := "been switched"; !strings.Contains(err.Error(), subStr) { + t.Fatalf("Expected %q in the error when switching arguments to X509KeyPair, but the error was %q", subStr, err) + } + + _, err = X509KeyPair([]byte(rsaCertPEM), []byte(rsaCertPEM)) + if err == nil { + t.Fatalf("X509KeyPair didn't return an error when both arguments were certificates") + } + if subStr := "certificate"; !strings.Contains(err.Error(), subStr) { + t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were certificates, but the error was %q", subStr, err) + } + + const nonsensePEM = ` +-----BEGIN NONSENSE----- +Zm9vZm9vZm9v +-----END NONSENSE----- +` + + _, err = X509KeyPair([]byte(nonsensePEM), []byte(nonsensePEM)) + if err == nil { + t.Fatalf("X509KeyPair didn't return an error when both arguments were nonsense") + } + if subStr := "NONSENSE"; !strings.Contains(err.Error(), subStr) { + t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were nonsense, but the error was %q", subStr, err) + } +} + +func TestX509MixedKeyPair(t *testing.T) { + if _, err := X509KeyPair([]byte(rsaCertPEM), []byte(ecdsaKeyPEM)); err == nil { + t.Error("Load of RSA certificate succeeded with ECDSA private key") + } + if _, err := X509KeyPair([]byte(ecdsaCertPEM), []byte(rsaKeyPEM)); err == nil { + t.Error("Load of ECDSA certificate succeeded with RSA private key") + } +} + +func newLocalListener(t testing.TB) net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + ln, err = net.Listen("tcp6", "[::1]:0") + } + if err != nil { + t.Fatal(err) + } + return ln +} + +func TestDialTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + listener := newLocalListener(t) + + addr := listener.Addr().String() + defer listener.Close() + + complete := make(chan bool) + defer close(complete) + + go func() { + conn, err := listener.Accept() + if err != nil { + t.Error(err) + return + } + <-complete + conn.Close() + }() + + dialer := &net.Dialer{ + Timeout: 10 * time.Millisecond, + } + + var err error + if _, err = DialWithDialer(dialer, "tcp", addr, nil); err == nil { + t.Fatal("DialWithTimeout completed successfully") + } + + if !isTimeoutError(err) { + t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err) + } +} + +func TestDeadlineOnWrite(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + ln := newLocalListener(t) + defer ln.Close() + + srvCh := make(chan *Conn, 1) + + go func() { + sconn, err := ln.Accept() + if err != nil { + srvCh <- nil + return + } + srv := Server(sconn, testConfig.Clone()) + if err := srv.Handshake(); err != nil { + srvCh <- nil + return + } + srvCh <- srv + }() + + clientConfig := testConfig.Clone() + clientConfig.MaxVersion = VersionTLS12 + conn, err := Dial("tcp", ln.Addr().String(), clientConfig) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + srv := <-srvCh + if srv == nil { + t.Error(err) + } + + // Make sure the client/server is setup correctly and is able to do a typical Write/Read + buf := make([]byte, 6) + if _, err := srv.Write([]byte("foobar")); err != nil { + t.Errorf("Write err: %v", err) + } + if n, err := conn.Read(buf); n != 6 || err != nil || string(buf) != "foobar" { + t.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf) + } + + // Set a deadline which should cause Write to timeout + if err = srv.SetDeadline(time.Now()); err != nil { + t.Fatalf("SetDeadline(time.Now()) err: %v", err) + } + if _, err = srv.Write([]byte("should fail")); err == nil { + t.Fatal("Write should have timed out") + } + + // Clear deadline and make sure it still times out + if err = srv.SetDeadline(time.Time{}); err != nil { + t.Fatalf("SetDeadline(time.Time{}) err: %v", err) + } + if _, err = srv.Write([]byte("This connection is permanently broken")); err == nil { + t.Fatal("Write which previously failed should still time out") + } + + // Verify the error + if ne := err.(net.Error); ne.Temporary() != false { + t.Error("Write timed out but incorrectly classified the error as Temporary") + } + if !isTimeoutError(err) { + t.Error("Write timed out but did not classify the error as a Timeout") + } +} + +type readerFunc func([]byte) (int, error) + +func (f readerFunc) Read(b []byte) (int, error) { return f(b) } + +// TestDialer tests that tls.Dialer.DialContext can abort in the middle of a handshake. +// (The other cases are all handled by the existing dial tests in this package, which +// all also flow through the same code shared code paths) +func TestDialer(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + + unblockServer := make(chan struct{}) // close-only + defer close(unblockServer) + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + <-unblockServer + }() + + ctx, cancel := context.WithCancel(context.Background()) + d := Dialer{Config: &Config{ + Rand: readerFunc(func(b []byte) (n int, err error) { + // By the time crypto/tls wants randomness, that means it has a TCP + // connection, so we're past the Dialer's dial and now blocked + // in a handshake. Cancel our context and see if we get unstuck. + // (Our TCP listener above never reads or writes, so the Handshake + // would otherwise be stuck forever) + cancel() + return len(b), nil + }), + ServerName: "foo", + }} + _, err := d.DialContext(ctx, "tcp", ln.Addr().String()) + if err != context.Canceled { + t.Errorf("err = %v; want context.Canceled", err) + } +} + +func isTimeoutError(err error) bool { + if ne, ok := err.(net.Error); ok { + return ne.Timeout() + } + return false +} + +// tests that Conn.Read returns (non-zero, io.EOF) instead of +// (non-zero, nil) when a Close (alertCloseNotify) is sitting right +// behind the application data in the buffer. +func TestConnReadNonzeroAndEOF(t *testing.T) { + // This test is racy: it assumes that after a write to a + // localhost TCP connection, the peer TCP connection can + // immediately read it. Because it's racy, we skip this test + // in short mode, and then retry it several times with an + // increasing sleep in between our final write (via srv.Close + // below) and the following read. + if testing.Short() { + t.Skip("skipping in short mode") + } + var err error + for delay := time.Millisecond; delay <= 64*time.Millisecond; delay *= 2 { + if err = testConnReadNonzeroAndEOF(t, delay); err == nil { + return + } + } + t.Error(err) +} + +func testConnReadNonzeroAndEOF(t *testing.T, delay time.Duration) error { + ln := newLocalListener(t) + defer ln.Close() + + srvCh := make(chan *Conn, 1) + var serr error + go func() { + sconn, err := ln.Accept() + if err != nil { + serr = err + srvCh <- nil + return + } + serverConfig := testConfig.Clone() + srv := Server(sconn, serverConfig) + if err := srv.Handshake(); err != nil { + serr = fmt.Errorf("handshake: %v", err) + srvCh <- nil + return + } + srvCh <- srv + }() + + clientConfig := testConfig.Clone() + // In TLS 1.3, alerts are encrypted and disguised as application data, so + // the opportunistic peek won't work. + clientConfig.MaxVersion = VersionTLS12 + conn, err := Dial("tcp", ln.Addr().String(), clientConfig) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + srv := <-srvCh + if srv == nil { + return serr + } + + buf := make([]byte, 6) + + srv.Write([]byte("foobar")) + n, err := conn.Read(buf) + if n != 6 || err != nil || string(buf) != "foobar" { + return fmt.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf) + } + + srv.Write([]byte("abcdef")) + srv.Close() + time.Sleep(delay) + n, err = conn.Read(buf) + if n != 6 || string(buf) != "abcdef" { + return fmt.Errorf("Read = %d, buf= %q; want 6, abcdef", n, buf) + } + if err != io.EOF { + return fmt.Errorf("Second Read error = %v; want io.EOF", err) + } + return nil +} + +func TestTLSUniqueMatches(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + + serverTLSUniques := make(chan []byte) + parentDone := make(chan struct{}) + childDone := make(chan struct{}) + defer close(parentDone) + go func() { + defer close(childDone) + for i := 0; i < 2; i++ { + sconn, err := ln.Accept() + if err != nil { + t.Error(err) + return + } + serverConfig := testConfig.Clone() + serverConfig.MaxVersion = VersionTLS12 // TLSUnique is not defined in TLS 1.3 + srv := Server(sconn, serverConfig) + if err := srv.Handshake(); err != nil { + t.Error(err) + return + } + select { + case <-parentDone: + return + case serverTLSUniques <- srv.ConnectionState().TLSUnique: + } + } + }() + + clientConfig := testConfig.Clone() + clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) + conn, err := Dial("tcp", ln.Addr().String(), clientConfig) + if err != nil { + t.Fatal(err) + } + + var serverTLSUniquesValue []byte + select { + case <-childDone: + return + case serverTLSUniquesValue = <-serverTLSUniques: + } + + if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) { + t.Error("client and server channel bindings differ") + } + conn.Close() + + conn, err = Dial("tcp", ln.Addr().String(), clientConfig) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + if !conn.ConnectionState().DidResume { + t.Error("second session did not use resumption") + } + + select { + case <-childDone: + return + case serverTLSUniquesValue = <-serverTLSUniques: + } + + if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) { + t.Error("client and server channel bindings differ when session resumption is used") + } +} + +func TestVerifyHostname(t *testing.T) { + testenv.MustHaveExternalNetwork(t) + + c, err := Dial("tcp", "www.google.com:https", nil) + if err != nil { + t.Fatal(err) + } + if err := c.VerifyHostname("www.google.com"); err != nil { + t.Fatalf("verify www.google.com: %v", err) + } + if err := c.VerifyHostname("www.yahoo.com"); err == nil { + t.Fatalf("verify www.yahoo.com succeeded") + } + + c, err = Dial("tcp", "www.google.com:https", &Config{InsecureSkipVerify: true}) + if err != nil { + t.Fatal(err) + } + if err := c.VerifyHostname("www.google.com"); err == nil { + t.Fatalf("verify www.google.com succeeded with InsecureSkipVerify=true") + } +} + +func TestConnCloseBreakingWrite(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + + srvCh := make(chan *Conn, 1) + var serr error + var sconn net.Conn + go func() { + var err error + sconn, err = ln.Accept() + if err != nil { + serr = err + srvCh <- nil + return + } + serverConfig := testConfig.Clone() + srv := Server(sconn, serverConfig) + if err := srv.Handshake(); err != nil { + serr = fmt.Errorf("handshake: %v", err) + srvCh <- nil + return + } + srvCh <- srv + }() + + cconn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer cconn.Close() + + conn := &changeImplConn{ + Conn: cconn, + } + + clientConfig := testConfig.Clone() + tconn := Client(conn, clientConfig) + if err := tconn.Handshake(); err != nil { + t.Fatal(err) + } + + srv := <-srvCh + if srv == nil { + t.Fatal(serr) + } + defer sconn.Close() + + connClosed := make(chan struct{}) + conn.closeFunc = func() error { + close(connClosed) + return nil + } + + inWrite := make(chan bool, 1) + var errConnClosed = errors.New("conn closed for test") + conn.writeFunc = func(p []byte) (n int, err error) { + inWrite <- true + <-connClosed + return 0, errConnClosed + } + + closeReturned := make(chan bool, 1) + go func() { + <-inWrite + tconn.Close() // test that this doesn't block forever. + closeReturned <- true + }() + + _, err = tconn.Write([]byte("foo")) + if err != errConnClosed { + t.Errorf("Write error = %v; want errConnClosed", err) + } + + <-closeReturned + if err := tconn.Close(); err != net.ErrClosed { + t.Errorf("Close error = %v; want net.ErrClosed", err) + } +} + +func TestConnCloseWrite(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + + clientDoneChan := make(chan struct{}) + + serverCloseWrite := func() error { + sconn, err := ln.Accept() + if err != nil { + return fmt.Errorf("accept: %v", err) + } + defer sconn.Close() + + serverConfig := testConfig.Clone() + srv := Server(sconn, serverConfig) + if err := srv.Handshake(); err != nil { + return fmt.Errorf("handshake: %v", err) + } + defer srv.Close() + + data, err := io.ReadAll(srv) + if err != nil { + return err + } + if len(data) > 0 { + return fmt.Errorf("Read data = %q; want nothing", data) + } + + if err := srv.CloseWrite(); err != nil { + return fmt.Errorf("server CloseWrite: %v", err) + } + + // Wait for clientCloseWrite to finish, so we know we + // tested the CloseWrite before we defer the + // sconn.Close above, which would also cause the + // client to unblock like CloseWrite. + <-clientDoneChan + return nil + } + + clientCloseWrite := func() error { + defer close(clientDoneChan) + + clientConfig := testConfig.Clone() + conn, err := Dial("tcp", ln.Addr().String(), clientConfig) + if err != nil { + return err + } + if err := conn.Handshake(); err != nil { + return err + } + defer conn.Close() + + if err := conn.CloseWrite(); err != nil { + return fmt.Errorf("client CloseWrite: %v", err) + } + + if _, err := conn.Write([]byte{0}); err != errShutdown { + return fmt.Errorf("CloseWrite error = %v; want errShutdown", err) + } + + data, err := io.ReadAll(conn) + if err != nil { + return err + } + if len(data) > 0 { + return fmt.Errorf("Read data = %q; want nothing", data) + } + return nil + } + + errChan := make(chan error, 2) + + go func() { errChan <- serverCloseWrite() }() + go func() { errChan <- clientCloseWrite() }() + + for i := 0; i < 2; i++ { + select { + case err := <-errChan: + if err != nil { + t.Fatal(err) + } + case <-time.After(10 * time.Second): + t.Fatal("deadlock") + } + } + + // Also test CloseWrite being called before the handshake is + // finished: + { + ln2 := newLocalListener(t) + defer ln2.Close() + + netConn, err := net.Dial("tcp", ln2.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer netConn.Close() + conn := Client(netConn, testConfig.Clone()) + + if err := conn.CloseWrite(); err != errEarlyCloseWrite { + t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err) + } + } +} + +func TestWarningAlertFlood(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + + server := func() error { + sconn, err := ln.Accept() + if err != nil { + return fmt.Errorf("accept: %v", err) + } + defer sconn.Close() + + serverConfig := testConfig.Clone() + srv := Server(sconn, serverConfig) + if err := srv.Handshake(); err != nil { + return fmt.Errorf("handshake: %v", err) + } + defer srv.Close() + + _, err = io.ReadAll(srv) + if err == nil { + return errors.New("unexpected lack of error from server") + } + const expected = "too many ignored" + if str := err.Error(); !strings.Contains(str, expected) { + return fmt.Errorf("expected error containing %q, but saw: %s", expected, str) + } + + return nil + } + + errChan := make(chan error, 1) + go func() { errChan <- server() }() + + clientConfig := testConfig.Clone() + clientConfig.MaxVersion = VersionTLS12 // there are no warning alerts in TLS 1.3 + conn, err := Dial("tcp", ln.Addr().String(), clientConfig) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + if err := conn.Handshake(); err != nil { + t.Fatal(err) + } + + for i := 0; i < maxUselessRecords+1; i++ { + conn.sendAlert(alertNoRenegotiation) + } + + if err := <-errChan; err != nil { + t.Fatal(err) + } +} + +func TestCloneFuncFields(t *testing.T) { + const expectedCount = 6 + called := 0 + + c1 := Config{ + Time: func() time.Time { + called |= 1 << 0 + return time.Time{} + }, + GetCertificate: func(*ClientHelloInfo) (*Certificate, error) { + called |= 1 << 1 + return nil, nil + }, + GetClientCertificate: func(*CertificateRequestInfo) (*Certificate, error) { + called |= 1 << 2 + return nil, nil + }, + GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { + called |= 1 << 3 + return nil, nil + }, + VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + called |= 1 << 4 + return nil + }, + VerifyConnection: func(ConnectionState) error { + called |= 1 << 5 + return nil + }, + } + + c2 := c1.Clone() + + c2.Time() + c2.GetCertificate(nil) + c2.GetClientCertificate(nil) + c2.GetConfigForClient(nil) + c2.VerifyPeerCertificate(nil, nil) + c2.VerifyConnection(ConnectionState{}) + + if called != (1< len(p) { + allowed = len(p) + } + if wrote < allowed { + n, err := c.Conn.Write(p[wrote:allowed]) + wrote += n + if err != nil { + return wrote, err + } + } + } + return len(p), nil +} + +func latency(b *testing.B, version uint16, bps int, dynamicRecordSizingDisabled bool) { + ln := newLocalListener(b) + defer ln.Close() + + N := b.N + + go func() { + for i := 0; i < N; i++ { + sconn, err := ln.Accept() + if err != nil { + // panic rather than synchronize to avoid benchmark overhead + // (cannot call b.Fatal in goroutine) + panic(fmt.Errorf("accept: %v", err)) + } + serverConfig := testConfig.Clone() + serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled + srv := Server(&slowConn{sconn, bps}, serverConfig) + if err := srv.Handshake(); err != nil { + panic(fmt.Errorf("handshake: %v", err)) + } + io.Copy(srv, srv) + } + }() + + clientConfig := testConfig.Clone() + clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled + clientConfig.MaxVersion = version + + buf := make([]byte, 16384) + peek := make([]byte, 1) + + for i := 0; i < N; i++ { + conn, err := Dial("tcp", ln.Addr().String(), clientConfig) + if err != nil { + b.Fatal(err) + } + // make sure we're connected and previous connection has stopped + if _, err := conn.Write(buf[:1]); err != nil { + b.Fatal(err) + } + if _, err := io.ReadFull(conn, peek); err != nil { + b.Fatal(err) + } + if _, err := conn.Write(buf); err != nil { + b.Fatal(err) + } + if _, err = io.ReadFull(conn, peek); err != nil { + b.Fatal(err) + } + conn.Close() + } +} + +func BenchmarkLatency(b *testing.B) { + for _, mode := range []string{"Max", "Dynamic"} { + for _, kbps := range []int{200, 500, 1000, 2000, 5000} { + name := fmt.Sprintf("%sPacket/%dkbps", mode, kbps) + b.Run(name, func(b *testing.B) { + b.Run("TLSv12", func(b *testing.B) { + latency(b, VersionTLS12, kbps*1000, mode == "Max") + }) + b.Run("TLSv13", func(b *testing.B) { + latency(b, VersionTLS13, kbps*1000, mode == "Max") + }) + }) + } + } +} + +func TestConnectionStateMarshal(t *testing.T) { + cs := &ConnectionState{} + _, err := json.Marshal(cs) + if err != nil { + t.Errorf("json.Marshal failed on ConnectionState: %v", err) + } +} + +func TestConnectionState(t *testing.T) { + issuer, err := x509.ParseCertificate(testRSACertificateIssuer) + if err != nil { + panic(err) + } + rootCAs := x509.NewCertPool() + rootCAs.AddCert(issuer) + + now := func() time.Time { return time.Unix(1476984729, 0) } + + const alpnProtocol = "golang" + const serverName = "example.golang" + var scts = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")} + var ocsp = []byte("dummy ocsp") + + for _, v := range []uint16{VersionTLS12, VersionTLS13} { + var name string + switch v { + case VersionTLS12: + name = "TLSv12" + case VersionTLS13: + name = "TLSv13" + } + t.Run(name, func(t *testing.T) { + config := &Config{ + Time: now, + Rand: zeroSource{}, + Certificates: make([]Certificate, 1), + MaxVersion: v, + RootCAs: rootCAs, + ClientCAs: rootCAs, + ClientAuth: RequireAndVerifyClientCert, + NextProtos: []string{alpnProtocol}, + ServerName: serverName, + } + config.Certificates[0].Certificate = [][]byte{testRSACertificate} + config.Certificates[0].PrivateKey = testRSAPrivateKey + config.Certificates[0].SignedCertificateTimestamps = scts + config.Certificates[0].OCSPStaple = ocsp + + ss, cs, err := testHandshake(t, config, config) + if err != nil { + t.Fatalf("Handshake failed: %v", err) + } + + if ss.Version != v || cs.Version != v { + t.Errorf("Got versions %x (server) and %x (client), expected %x", ss.Version, cs.Version, v) + } + + if !ss.HandshakeComplete || !cs.HandshakeComplete { + t.Errorf("Got HandshakeComplete %v (server) and %v (client), expected true", ss.HandshakeComplete, cs.HandshakeComplete) + } + + if ss.DidResume || cs.DidResume { + t.Errorf("Got DidResume %v (server) and %v (client), expected false", ss.DidResume, cs.DidResume) + } + + if ss.CipherSuite == 0 || cs.CipherSuite == 0 { + t.Errorf("Got invalid cipher suite: %v (server) and %v (client)", ss.CipherSuite, cs.CipherSuite) + } + + if ss.NegotiatedProtocol != alpnProtocol || cs.NegotiatedProtocol != alpnProtocol { + t.Errorf("Got negotiated protocol %q (server) and %q (client), expected %q", ss.NegotiatedProtocol, cs.NegotiatedProtocol, alpnProtocol) + } + + if !cs.NegotiatedProtocolIsMutual { + t.Errorf("Got false NegotiatedProtocolIsMutual on the client side") + } + // NegotiatedProtocolIsMutual on the server side is unspecified. + + if ss.ServerName != serverName { + t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName) + } + if cs.ServerName != serverName { + t.Errorf("Got server name on client connection %q, expected %q", cs.ServerName, serverName) + } + + if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 { + t.Errorf("Got %d (server) and %d (client) peer certificates, expected %d", len(ss.PeerCertificates), len(cs.PeerCertificates), 1) + } + + if len(ss.VerifiedChains) != 1 || len(cs.VerifiedChains) != 1 { + t.Errorf("Got %d (server) and %d (client) verified chains, expected %d", len(ss.VerifiedChains), len(cs.VerifiedChains), 1) + } else if len(ss.VerifiedChains[0]) != 2 || len(cs.VerifiedChains[0]) != 2 { + t.Errorf("Got %d (server) and %d (client) long verified chain, expected %d", len(ss.VerifiedChains[0]), len(cs.VerifiedChains[0]), 2) + } + + if len(cs.SignedCertificateTimestamps) != 2 { + t.Errorf("Got %d SCTs, expected %d", len(cs.SignedCertificateTimestamps), 2) + } + if !bytes.Equal(cs.OCSPResponse, ocsp) { + t.Errorf("Got OCSPs %x, expected %x", cs.OCSPResponse, ocsp) + } + // Only TLS 1.3 supports OCSP and SCTs on client certs. + if v == VersionTLS13 { + if len(ss.SignedCertificateTimestamps) != 2 { + t.Errorf("Got %d client SCTs, expected %d", len(ss.SignedCertificateTimestamps), 2) + } + if !bytes.Equal(ss.OCSPResponse, ocsp) { + t.Errorf("Got client OCSPs %x, expected %x", ss.OCSPResponse, ocsp) + } + } + + if v == VersionTLS13 { + if ss.TLSUnique != nil || cs.TLSUnique != nil { + t.Errorf("Got TLSUnique %x (server) and %x (client), expected nil in TLS 1.3", ss.TLSUnique, cs.TLSUnique) + } + } else { + if ss.TLSUnique == nil || cs.TLSUnique == nil { + t.Errorf("Got TLSUnique %x (server) and %x (client), expected non-nil", ss.TLSUnique, cs.TLSUnique) + } + } + }) + } +} + +// Issue 28744: Ensure that we don't modify memory +// that Config doesn't own such as Certificates. +func TestBuildNameToCertificate_doesntModifyCertificates(t *testing.T) { + c0 := Certificate{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: testRSAPrivateKey, + } + c1 := Certificate{ + Certificate: [][]byte{testSNICertificate}, + PrivateKey: testRSAPrivateKey, + } + config := testConfig.Clone() + config.Certificates = []Certificate{c0, c1} + + config.BuildNameToCertificate() + got := config.Certificates + want := []Certificate{c0, c1} + if !reflect.DeepEqual(got, want) { + t.Fatalf("Certificates were mutated by BuildNameToCertificate\nGot: %#v\nWant: %#v\n", got, want) + } +} + +func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } + +func TestClientHelloInfo_SupportsCertificate(t *testing.T) { + rsaCert := &Certificate{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: testRSAPrivateKey, + } + pkcs1Cert := &Certificate{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: testRSAPrivateKey, + SupportedSignatureAlgorithms: []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256}, + } + ecdsaCert := &Certificate{ + // ECDSA P-256 certificate + Certificate: [][]byte{testP256Certificate}, + PrivateKey: testP256PrivateKey, + } + ed25519Cert := &Certificate{ + Certificate: [][]byte{testEd25519Certificate}, + PrivateKey: testEd25519PrivateKey, + } + + tests := []struct { + c *Certificate + chi *ClientHelloInfo + wantErr string + }{ + {rsaCert, &ClientHelloInfo{ + ServerName: "example.golang", + SignatureSchemes: []SignatureScheme{PSSWithSHA256}, + SupportedVersions: []uint16{VersionTLS13}, + }, ""}, + {ecdsaCert, &ClientHelloInfo{ + SignatureSchemes: []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256}, + SupportedVersions: []uint16{VersionTLS13, VersionTLS12}, + }, ""}, + {rsaCert, &ClientHelloInfo{ + ServerName: "example.com", + SignatureSchemes: []SignatureScheme{PSSWithSHA256}, + SupportedVersions: []uint16{VersionTLS13}, + }, "not valid for requested server name"}, + {ecdsaCert, &ClientHelloInfo{ + SignatureSchemes: []SignatureScheme{ECDSAWithP384AndSHA384}, + SupportedVersions: []uint16{VersionTLS13}, + }, "signature algorithms"}, + {pkcs1Cert, &ClientHelloInfo{ + SignatureSchemes: []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256}, + SupportedVersions: []uint16{VersionTLS13}, + }, "signature algorithms"}, + + {rsaCert, &ClientHelloInfo{ + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, + SignatureSchemes: []SignatureScheme{PKCS1WithSHA1}, + SupportedVersions: []uint16{VersionTLS13, VersionTLS12}, + }, "signature algorithms"}, + {rsaCert, &ClientHelloInfo{ + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, + SignatureSchemes: []SignatureScheme{PKCS1WithSHA1}, + SupportedVersions: []uint16{VersionTLS13, VersionTLS12}, + config: &Config{ + MaxVersion: VersionTLS12, + }, + }, ""}, // Check that mutual version selection works. + + {ecdsaCert, &ClientHelloInfo{ + CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + SupportedCurves: []CurveID{CurveP256}, + SupportedPoints: []uint8{pointFormatUncompressed}, + SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, + SupportedVersions: []uint16{VersionTLS12}, + }, ""}, + {ecdsaCert, &ClientHelloInfo{ + CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + SupportedCurves: []CurveID{CurveP256}, + SupportedPoints: []uint8{pointFormatUncompressed}, + SignatureSchemes: []SignatureScheme{ECDSAWithP384AndSHA384}, + SupportedVersions: []uint16{VersionTLS12}, + }, ""}, // TLS 1.2 does not restrict curves based on the SignatureScheme. + {ecdsaCert, &ClientHelloInfo{ + CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + SupportedCurves: []CurveID{CurveP256}, + SupportedPoints: []uint8{pointFormatUncompressed}, + SignatureSchemes: nil, + SupportedVersions: []uint16{VersionTLS12}, + }, ""}, // TLS 1.2 comes with default signature schemes. + {ecdsaCert, &ClientHelloInfo{ + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, + SupportedCurves: []CurveID{CurveP256}, + SupportedPoints: []uint8{pointFormatUncompressed}, + SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, + SupportedVersions: []uint16{VersionTLS12}, + }, "cipher suite"}, + {ecdsaCert, &ClientHelloInfo{ + CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + SupportedCurves: []CurveID{CurveP256}, + SupportedPoints: []uint8{pointFormatUncompressed}, + SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, + SupportedVersions: []uint16{VersionTLS12}, + config: &Config{ + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, + }, + }, "cipher suite"}, + {ecdsaCert, &ClientHelloInfo{ + CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + SupportedCurves: []CurveID{CurveP384}, + SupportedPoints: []uint8{pointFormatUncompressed}, + SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, + SupportedVersions: []uint16{VersionTLS12}, + }, "certificate curve"}, + {ecdsaCert, &ClientHelloInfo{ + CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + SupportedCurves: []CurveID{CurveP256}, + SupportedPoints: []uint8{1}, + SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, + SupportedVersions: []uint16{VersionTLS12}, + }, "doesn't support ECDHE"}, + {ecdsaCert, &ClientHelloInfo{ + CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + SupportedCurves: []CurveID{CurveP256}, + SupportedPoints: []uint8{pointFormatUncompressed}, + SignatureSchemes: []SignatureScheme{PSSWithSHA256}, + SupportedVersions: []uint16{VersionTLS12}, + }, "signature algorithms"}, + + {ed25519Cert, &ClientHelloInfo{ + CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + SupportedCurves: []CurveID{CurveP256}, // only relevant for ECDHE support + SupportedPoints: []uint8{pointFormatUncompressed}, + SignatureSchemes: []SignatureScheme{Ed25519}, + SupportedVersions: []uint16{VersionTLS12}, + }, ""}, + {ed25519Cert, &ClientHelloInfo{ + CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + SupportedCurves: []CurveID{CurveP256}, // only relevant for ECDHE support + SupportedPoints: []uint8{pointFormatUncompressed}, + SignatureSchemes: []SignatureScheme{Ed25519}, + SupportedVersions: []uint16{VersionTLS10}, + }, "doesn't support Ed25519"}, + {ed25519Cert, &ClientHelloInfo{ + CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + SupportedCurves: []CurveID{}, + SupportedPoints: []uint8{pointFormatUncompressed}, + SignatureSchemes: []SignatureScheme{Ed25519}, + SupportedVersions: []uint16{VersionTLS12}, + }, "doesn't support ECDHE"}, + + {rsaCert, &ClientHelloInfo{ + CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA}, + SupportedCurves: []CurveID{CurveP256}, // only relevant for ECDHE support + SupportedPoints: []uint8{pointFormatUncompressed}, + SupportedVersions: []uint16{VersionTLS10}, + }, ""}, + {rsaCert, &ClientHelloInfo{ + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, + SupportedVersions: []uint16{VersionTLS12}, + }, ""}, // static RSA fallback + } + for i, tt := range tests { + err := tt.chi.SupportsCertificate(tt.c) + switch { + case tt.wantErr == "" && err != nil: + t.Errorf("%d: unexpected error: %v", i, err) + case tt.wantErr != "" && err == nil: + t.Errorf("%d: unexpected success", i) + case tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr): + t.Errorf("%d: got error %q, expected %q", i, err, tt.wantErr) + } + } +} + +func TestCipherSuites(t *testing.T) { + var lastID uint16 + for _, c := range CipherSuites() { + if lastID > c.ID { + t.Errorf("CipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID) + } else { + lastID = c.ID + } + + if c.Insecure { + t.Errorf("%#04x: Insecure CipherSuite returned by CipherSuites()", c.ID) + } + } + lastID = 0 + for _, c := range InsecureCipherSuites() { + if lastID > c.ID { + t.Errorf("InsecureCipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID) + } else { + lastID = c.ID + } + + if !c.Insecure { + t.Errorf("%#04x: not Insecure CipherSuite returned by InsecureCipherSuites()", c.ID) + } + } + + cipherSuiteByID := func(id uint16) *CipherSuite { + for _, c := range CipherSuites() { + if c.ID == id { + return c + } + } + for _, c := range InsecureCipherSuites() { + if c.ID == id { + return c + } + } + return nil + } + + for _, c := range cipherSuites { + cc := cipherSuiteByID(c.id) + if cc == nil { + t.Errorf("%#04x: no CipherSuite entry", c.id) + continue + } + + if defaultOff := c.flags&suiteDefaultOff != 0; defaultOff != cc.Insecure { + t.Errorf("%#04x: Insecure %v, expected %v", c.id, cc.Insecure, defaultOff) + } + if tls12Only := c.flags&suiteTLS12 != 0; tls12Only && len(cc.SupportedVersions) != 1 { + t.Errorf("%#04x: suite is TLS 1.2 only, but SupportedVersions is %v", c.id, cc.SupportedVersions) + } else if !tls12Only && len(cc.SupportedVersions) != 3 { + t.Errorf("%#04x: suite TLS 1.0-1.2, but SupportedVersions is %v", c.id, cc.SupportedVersions) + } + + if got := CipherSuiteName(c.id); got != cc.Name { + t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name) + } + } + for _, c := range cipherSuitesTLS13 { + cc := cipherSuiteByID(c.id) + if cc == nil { + t.Errorf("%#04x: no CipherSuite entry", c.id) + continue + } + + if cc.Insecure { + t.Errorf("%#04x: Insecure %v, expected false", c.id, cc.Insecure) + } + if len(cc.SupportedVersions) != 1 || cc.SupportedVersions[0] != VersionTLS13 { + t.Errorf("%#04x: suite is TLS 1.3 only, but SupportedVersions is %v", c.id, cc.SupportedVersions) + } + + if got := CipherSuiteName(c.id); got != cc.Name { + t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name) + } + } + + if got := CipherSuiteName(0xabc); got != "0x0ABC" { + t.Errorf("unexpected fallback CipherSuiteName: got %q, expected 0x0ABC", got) + } +} + +type brokenSigner struct{ crypto.Signer } + +func (s brokenSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { + // Replace opts with opts.HashFunc(), so rsa.PSSOptions are discarded. + return s.Signer.Sign(rand, digest, opts.HashFunc()) +} + +// TestPKCS1OnlyCert uses a client certificate with a broken crypto.Signer that +// always makes PKCS #1 v1.5 signatures, so can't be used with RSA-PSS. +func TestPKCS1OnlyCert(t *testing.T) { + clientConfig := testConfig.Clone() + clientConfig.Certificates = []Certificate{{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: brokenSigner{testRSAPrivateKey}, + }} + serverConfig := testConfig.Clone() + serverConfig.MaxVersion = VersionTLS12 // TLS 1.3 doesn't support PKCS #1 v1.5 + serverConfig.ClientAuth = RequireAnyClientCert + + // If RSA-PSS is selected, the handshake should fail. + if _, _, err := testHandshake(t, clientConfig, serverConfig); err == nil { + t.Fatal("expected broken certificate to cause connection to fail") + } + + clientConfig.Certificates[0].SupportedSignatureAlgorithms = + []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256} + + // But if the certificate restricts supported algorithms, RSA-PSS should not + // be selected, and the handshake should succeed. + if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil { + t.Error(err) + } +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu.go new file mode 100644 index 0000000..dab5d06 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu.go @@ -0,0 +1,226 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package cpu implements processor feature detection +// used by the Go standard library. +package cpu + +// DebugOptions is set to true by the runtime if the OS supports reading +// GODEBUG early in runtime startup. +// This should not be changed after it is initialized. +var DebugOptions bool + +// CacheLinePad is used to pad structs to avoid false sharing. +type CacheLinePad struct{ _ [CacheLinePadSize]byte } + +// CacheLineSize is the CPU's assumed cache line size. +// There is currently no runtime detection of the real cache line size +// so we use the constant per GOARCH CacheLinePadSize as an approximation. +var CacheLineSize uintptr = CacheLinePadSize + +// The booleans in X86 contain the correspondingly named cpuid feature bit. +// HasAVX and HasAVX2 are only set if the OS does support XMM and YMM registers +// in addition to the cpuid feature bit being set. +// The struct is padded to avoid false sharing. +var X86 struct { + _ CacheLinePad + HasAES bool + HasADX bool + HasAVX bool + HasAVX2 bool + HasBMI1 bool + HasBMI2 bool + HasERMS bool + HasFMA bool + HasOSXSAVE bool + HasPCLMULQDQ bool + HasPOPCNT bool + HasSSE2 bool + HasSSE3 bool + HasSSSE3 bool + HasSSE41 bool + HasSSE42 bool + _ CacheLinePad +} + +// The booleans in ARM contain the correspondingly named cpu feature bit. +// The struct is padded to avoid false sharing. +var ARM struct { + _ CacheLinePad + HasVFPv4 bool + HasIDIVA bool + _ CacheLinePad +} + +// The booleans in ARM64 contain the correspondingly named cpu feature bit. +// The struct is padded to avoid false sharing. +var ARM64 struct { + _ CacheLinePad + HasAES bool + HasPMULL bool + HasSHA1 bool + HasSHA2 bool + HasCRC32 bool + HasATOMICS bool + HasCPUID bool + IsNeoverseN1 bool + IsZeus bool + _ CacheLinePad +} + +var MIPS64X struct { + _ CacheLinePad + HasMSA bool // MIPS SIMD architecture + _ CacheLinePad +} + +// For ppc64(le), it is safe to check only for ISA level starting on ISA v3.00, +// since there are no optional categories. There are some exceptions that also +// require kernel support to work (darn, scv), so there are feature bits for +// those as well. The minimum processor requirement is POWER8 (ISA 2.07). +// The struct is padded to avoid false sharing. +var PPC64 struct { + _ CacheLinePad + HasDARN bool // Hardware random number generator (requires kernel enablement) + HasSCV bool // Syscall vectored (requires kernel enablement) + IsPOWER8 bool // ISA v2.07 (POWER8) + IsPOWER9 bool // ISA v3.00 (POWER9) + _ CacheLinePad +} + +var S390X struct { + _ CacheLinePad + HasZARCH bool // z architecture mode is active [mandatory] + HasSTFLE bool // store facility list extended [mandatory] + HasLDISP bool // long (20-bit) displacements [mandatory] + HasEIMM bool // 32-bit immediates [mandatory] + HasDFP bool // decimal floating point + HasETF3EH bool // ETF-3 enhanced + HasMSA bool // message security assist (CPACF) + HasAES bool // KM-AES{128,192,256} functions + HasAESCBC bool // KMC-AES{128,192,256} functions + HasAESCTR bool // KMCTR-AES{128,192,256} functions + HasAESGCM bool // KMA-GCM-AES{128,192,256} functions + HasGHASH bool // KIMD-GHASH function + HasSHA1 bool // K{I,L}MD-SHA-1 functions + HasSHA256 bool // K{I,L}MD-SHA-256 functions + HasSHA512 bool // K{I,L}MD-SHA-512 functions + HasSHA3 bool // K{I,L}MD-SHA3-{224,256,384,512} and K{I,L}MD-SHAKE-{128,256} functions + HasVX bool // vector facility. Note: the runtime sets this when it processes auxv records. + HasVXE bool // vector-enhancements facility 1 + HasKDSA bool // elliptic curve functions + HasECDSA bool // NIST curves + HasEDDSA bool // Edwards curves + _ CacheLinePad +} + +// Initialize examines the processor and sets the relevant variables above. +// This is called by the runtime package early in program initialization, +// before normal init functions are run. env is set by runtime if the OS supports +// cpu feature options in GODEBUG. +func Initialize(env string) { + doinit() + processOptions(env) +} + +// options contains the cpu debug options that can be used in GODEBUG. +// Options are arch dependent and are added by the arch specific doinit functions. +// Features that are mandatory for the specific GOARCH should not be added to options +// (e.g. SSE2 on amd64). +var options []option + +// Option names should be lower case. e.g. avx instead of AVX. +type option struct { + Name string + Feature *bool + Specified bool // whether feature value was specified in GODEBUG + Enable bool // whether feature should be enabled + Required bool // whether feature is mandatory and can not be disabled +} + +// processOptions enables or disables CPU feature values based on the parsed env string. +// The env string is expected to be of the form cpu.feature1=value1,cpu.feature2=value2... +// where feature names is one of the architecture specific list stored in the +// cpu packages options variable and values are either 'on' or 'off'. +// If env contains cpu.all=off then all cpu features referenced through the options +// variable are disabled. Other feature names and values result in warning messages. +func processOptions(env string) { +field: + for env != "" { + field := "" + i := indexByte(env, ',') + if i < 0 { + field, env = env, "" + } else { + field, env = env[:i], env[i+1:] + } + if len(field) < 4 || field[:4] != "cpu." { + continue + } + i = indexByte(field, '=') + if i < 0 { + print("GODEBUG: no value specified for \"", field, "\"\n") + continue + } + key, value := field[4:i], field[i+1:] // e.g. "SSE2", "on" + + var enable bool + switch value { + case "on": + enable = true + case "off": + enable = false + default: + print("GODEBUG: value \"", value, "\" not supported for cpu option \"", key, "\"\n") + continue field + } + + if key == "all" { + for i := range options { + options[i].Specified = true + options[i].Enable = enable || options[i].Required + } + continue field + } + + for i := range options { + if options[i].Name == key { + options[i].Specified = true + options[i].Enable = enable + continue field + } + } + + print("GODEBUG: unknown cpu feature \"", key, "\"\n") + } + + for _, o := range options { + if !o.Specified { + continue + } + + if o.Enable && !*o.Feature { + print("GODEBUG: can not enable \"", o.Name, "\", missing CPU support\n") + continue + } + + if !o.Enable && o.Required { + print("GODEBUG: can not disable \"", o.Name, "\", required CPU feature\n") + continue + } + + *o.Feature = o.Enable + } +} + +// indexByte returns the index of the first instance of c in s, +// or -1 if c is not present in s. +func indexByte(s string, c byte) int { + for i := 0; i < len(s); i++ { + if s[i] == c { + return i + } + } + return -1 +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu.s b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu.s new file mode 100644 index 0000000..3c770c1 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu.s @@ -0,0 +1,6 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This assembly file exists to allow internal/cpu to call +// non-exported runtime functions that use "go:linkname". \ No newline at end of file diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_386.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_386.go new file mode 100644 index 0000000..561c81f --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_386.go @@ -0,0 +1,7 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const GOARCH = "386" diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_amd64.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_amd64.go new file mode 100644 index 0000000..9b00153 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_amd64.go @@ -0,0 +1,7 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const GOARCH = "amd64" diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm.go new file mode 100644 index 0000000..b624526 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm.go @@ -0,0 +1,34 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const CacheLinePadSize = 32 + +// arm doesn't have a 'cpuid' equivalent, so we rely on HWCAP/HWCAP2. +// These are initialized by archauxv() and should not be changed after they are +// initialized. +var HWCap uint +var HWCap2 uint + +// HWCAP/HWCAP2 bits. These are exposed by Linux and FreeBSD. +const ( + hwcap_VFPv4 = 1 << 16 + hwcap_IDIVA = 1 << 17 +) + +func doinit() { + options = []option{ + {Name: "vfpv4", Feature: &ARM.HasVFPv4}, + {Name: "idiva", Feature: &ARM.HasIDIVA}, + } + + // HWCAP feature bits + ARM.HasVFPv4 = isSet(HWCap, hwcap_VFPv4) + ARM.HasIDIVA = isSet(HWCap, hwcap_IDIVA) +} + +func isSet(hwc uint, value uint) bool { + return hwc&value != 0 +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64.go new file mode 100644 index 0000000..f64d9e4 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64.go @@ -0,0 +1,28 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const CacheLinePadSize = 64 + +func doinit() { + options = []option{ + {Name: "aes", Feature: &ARM64.HasAES}, + {Name: "pmull", Feature: &ARM64.HasPMULL}, + {Name: "sha1", Feature: &ARM64.HasSHA1}, + {Name: "sha2", Feature: &ARM64.HasSHA2}, + {Name: "crc32", Feature: &ARM64.HasCRC32}, + {Name: "atomics", Feature: &ARM64.HasATOMICS}, + {Name: "cpuid", Feature: &ARM64.HasCPUID}, + {Name: "isNeoverseN1", Feature: &ARM64.IsNeoverseN1}, + {Name: "isZeus", Feature: &ARM64.IsZeus}, + } + + // arm64 uses different ways to detect CPU features at runtime depending on the operating system. + osInit() +} + +func getisar0() uint64 + +func getMIDR() uint64 diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64.s b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64.s new file mode 100644 index 0000000..d6e7f44 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64.s @@ -0,0 +1,18 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "textflag.h" + +// func getisar0() uint64 +TEXT ·getisar0(SB),NOSPLIT,$0 + // get Instruction Set Attributes 0 into R0 + MRS ID_AA64ISAR0_EL1, R0 + MOVD R0, ret+0(FP) + RET + +// func getMIDR() uint64 +TEXT ·getMIDR(SB), NOSPLIT, $0-8 + MRS MIDR_EL1, R0 + MOVD R0, ret+0(FP) + RET diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_android.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_android.go new file mode 100644 index 0000000..3c9e57c --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_android.go @@ -0,0 +1,11 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build arm64 + +package cpu + +func osInit() { + hwcapInit("android") +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_darwin.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_darwin.go new file mode 100644 index 0000000..e094b97 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_darwin.go @@ -0,0 +1,34 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build arm64 +// +build darwin +// +build !ios + +package cpu + +func osInit() { + ARM64.HasATOMICS = sysctlEnabled([]byte("hw.optional.armv8_1_atomics\x00")) + ARM64.HasCRC32 = sysctlEnabled([]byte("hw.optional.armv8_crc32\x00")) + + // There are no hw.optional sysctl values for the below features on Mac OS 11.0 + // to detect their supported state dynamically. Assume the CPU features that + // Apple Silicon M1 supports to be available as a minimal set of features + // to all Go programs running on darwin/arm64. + ARM64.HasAES = true + ARM64.HasPMULL = true + ARM64.HasSHA1 = true + ARM64.HasSHA2 = true +} + +//go:noescape +func getsysctlbyname(name []byte) (int32, int32) + +func sysctlEnabled(name []byte) bool { + ret, value := getsysctlbyname(name) + if ret < 0 { + return false + } + return value > 0 +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_freebsd.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_freebsd.go new file mode 100644 index 0000000..9de2005 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_freebsd.go @@ -0,0 +1,45 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build arm64 + +package cpu + +func osInit() { + // Retrieve info from system register ID_AA64ISAR0_EL1. + isar0 := getisar0() + + // ID_AA64ISAR0_EL1 + switch extractBits(isar0, 4, 7) { + case 1: + ARM64.HasAES = true + case 2: + ARM64.HasAES = true + ARM64.HasPMULL = true + } + + switch extractBits(isar0, 8, 11) { + case 1: + ARM64.HasSHA1 = true + } + + switch extractBits(isar0, 12, 15) { + case 1, 2: + ARM64.HasSHA2 = true + } + + switch extractBits(isar0, 16, 19) { + case 1: + ARM64.HasCRC32 = true + } + + switch extractBits(isar0, 20, 23) { + case 2: + ARM64.HasATOMICS = true + } +} + +func extractBits(data uint64, start, end uint) uint { + return (uint)(data>>start) & ((1 << (end - start + 1)) - 1) +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_hwcap.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_hwcap.go new file mode 100644 index 0000000..fdaf43e --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_hwcap.go @@ -0,0 +1,63 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build arm64 +// +build linux + +package cpu + +// HWCap may be initialized by archauxv and +// should not be changed after it was initialized. +var HWCap uint + +// HWCAP bits. These are exposed by Linux. +const ( + hwcap_AES = 1 << 3 + hwcap_PMULL = 1 << 4 + hwcap_SHA1 = 1 << 5 + hwcap_SHA2 = 1 << 6 + hwcap_CRC32 = 1 << 7 + hwcap_ATOMICS = 1 << 8 + hwcap_CPUID = 1 << 11 +) + +func hwcapInit(os string) { + // HWCap was populated by the runtime from the auxiliary vector. + // Use HWCap information since reading aarch64 system registers + // is not supported in user space on older linux kernels. + ARM64.HasAES = isSet(HWCap, hwcap_AES) + ARM64.HasPMULL = isSet(HWCap, hwcap_PMULL) + ARM64.HasSHA1 = isSet(HWCap, hwcap_SHA1) + ARM64.HasSHA2 = isSet(HWCap, hwcap_SHA2) + ARM64.HasCRC32 = isSet(HWCap, hwcap_CRC32) + ARM64.HasCPUID = isSet(HWCap, hwcap_CPUID) + + // The Samsung S9+ kernel reports support for atomics, but not all cores + // actually support them, resulting in SIGILL. See issue #28431. + // TODO(elias.naur): Only disable the optimization on bad chipsets on android. + ARM64.HasATOMICS = isSet(HWCap, hwcap_ATOMICS) && os != "android" + + // Check to see if executing on a NeoverseN1 and in order to do that, + // check the AUXV for the CPUID bit. The getMIDR function executes an + // instruction which would normally be an illegal instruction, but it's + // trapped by the kernel, the value sanitized and then returned. Without + // the CPUID bit the kernel will not trap the instruction and the process + // will be terminated with SIGILL. + if ARM64.HasCPUID { + midr := getMIDR() + part_num := uint16((midr >> 4) & 0xfff) + implementor := byte((midr >> 24) & 0xff) + + if implementor == 'A' && part_num == 0xd0c { + ARM64.IsNeoverseN1 = true + } + if implementor == 'A' && part_num == 0xd40 { + ARM64.IsZeus = true + } + } +} + +func isSet(hwc uint, value uint) bool { + return hwc&value != 0 +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_linux.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_linux.go new file mode 100644 index 0000000..2f7411f --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_linux.go @@ -0,0 +1,13 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build arm64 +// +build linux +// +build !android + +package cpu + +func osInit() { + hwcapInit("linux") +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_other.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_other.go new file mode 100644 index 0000000..f191db2 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_arm64_other.go @@ -0,0 +1,17 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build arm64 +// +build !linux +// +build !freebsd +// +build !android +// +build !darwin ios + +package cpu + +func osInit() { + // Other operating systems do not support reading HWCap from auxiliary vector, + // reading privileged aarch64 system registers or sysctl in user space to detect + // CPU features at runtime. +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_mips.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_mips.go new file mode 100644 index 0000000..14a9c97 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_mips.go @@ -0,0 +1,10 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const CacheLinePadSize = 32 + +func doinit() { +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_mips64x.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_mips64x.go new file mode 100644 index 0000000..0c4794a --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_mips64x.go @@ -0,0 +1,32 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build mips64 mips64le + +package cpu + +const CacheLinePadSize = 32 + +// This is initialized by archauxv and should not be changed after it is +// initialized. +var HWCap uint + +// HWCAP bits. These are exposed by the Linux kernel 5.4. +const ( + // CPU features + hwcap_MIPS_MSA = 1 << 1 +) + +func doinit() { + options = []option{ + {Name: "msa", Feature: &MIPS64X.HasMSA}, + } + + // HWCAP feature bits + MIPS64X.HasMSA = isSet(HWCap, hwcap_MIPS_MSA) +} + +func isSet(hwc uint, value uint) bool { + return hwc&value != 0 +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_mipsle.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_mipsle.go new file mode 100644 index 0000000..14a9c97 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_mipsle.go @@ -0,0 +1,10 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const CacheLinePadSize = 32 + +func doinit() { +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_no_name.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_no_name.go new file mode 100644 index 0000000..ce1c37a --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_no_name.go @@ -0,0 +1,19 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !386 +// +build !amd64 + +package cpu + +// Name returns the CPU name given by the vendor +// if it can be read directly from memory or by CPU instructions. +// If the CPU name can not be determined an empty string is returned. +// +// Implementations that use the Operating System (e.g. sysctl or /sys/) +// to gather CPU information for display should be placed in internal/sysinfo. +func Name() string { + // "A CPU has no name". + return "" +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_ppc64x.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_ppc64x.go new file mode 100644 index 0000000..beb1765 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_ppc64x.go @@ -0,0 +1,23 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build ppc64 ppc64le + +package cpu + +const CacheLinePadSize = 128 + +func doinit() { + options = []option{ + {Name: "darn", Feature: &PPC64.HasDARN}, + {Name: "scv", Feature: &PPC64.HasSCV}, + {Name: "power9", Feature: &PPC64.IsPOWER9}, + } + + osinit() +} + +func isSet(hwc uint, value uint) bool { + return hwc&value != 0 +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_ppc64x_aix.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_ppc64x_aix.go new file mode 100644 index 0000000..b840b82 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_ppc64x_aix.go @@ -0,0 +1,21 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build ppc64 ppc64le + +package cpu + +const ( + // getsystemcfg constants + _SC_IMPL = 2 + _IMPL_POWER9 = 0x20000 +) + +func osinit() { + impl := getsystemcfg(_SC_IMPL) + PPC64.IsPOWER9 = isSet(impl, _IMPL_POWER9) +} + +// getsystemcfg is defined in runtime/os2_aix.go +func getsystemcfg(label uint) uint diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_ppc64x_linux.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_ppc64x_linux.go new file mode 100644 index 0000000..73b1914 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_ppc64x_linux.go @@ -0,0 +1,29 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build ppc64 ppc64le + +package cpu + +// ppc64 doesn't have a 'cpuid' equivalent, so we rely on HWCAP/HWCAP2. +// These are initialized by archauxv and should not be changed after they are +// initialized. +var HWCap uint +var HWCap2 uint + +// HWCAP bits. These are exposed by Linux. +const ( + // ISA Level + hwcap2_ARCH_3_00 = 0x00800000 + + // CPU features + hwcap2_DARN = 0x00200000 + hwcap2_SCV = 0x00100000 +) + +func osinit() { + PPC64.IsPOWER9 = isSet(HWCap2, hwcap2_ARCH_3_00) + PPC64.HasDARN = isSet(HWCap2, hwcap2_DARN) + PPC64.HasSCV = isSet(HWCap2, hwcap2_SCV) +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_riscv64.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_riscv64.go new file mode 100644 index 0000000..54b8c33 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_riscv64.go @@ -0,0 +1,10 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const CacheLinePadSize = 32 + +func doinit() { +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_s390x.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_s390x.go new file mode 100644 index 0000000..45d8ed2 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_s390x.go @@ -0,0 +1,205 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const CacheLinePadSize = 256 + +var HWCap uint + +// bitIsSet reports whether the bit at index is set. The bit index +// is in big endian order, so bit index 0 is the leftmost bit. +func bitIsSet(bits []uint64, index uint) bool { + return bits[index/64]&((1<<63)>>(index%64)) != 0 +} + +// function is the function code for the named function. +type function uint8 + +const ( + // KM{,A,C,CTR} function codes + aes128 function = 18 // AES-128 + aes192 function = 19 // AES-192 + aes256 function = 20 // AES-256 + + // K{I,L}MD function codes + sha1 function = 1 // SHA-1 + sha256 function = 2 // SHA-256 + sha512 function = 3 // SHA-512 + sha3_224 function = 32 // SHA3-224 + sha3_256 function = 33 // SHA3-256 + sha3_384 function = 34 // SHA3-384 + sha3_512 function = 35 // SHA3-512 + shake128 function = 36 // SHAKE-128 + shake256 function = 37 // SHAKE-256 + + // KLMD function codes + ghash function = 65 // GHASH +) + +const ( + // KDSA function codes + ecdsaVerifyP256 function = 1 // NIST P256 + ecdsaVerifyP384 function = 2 // NIST P384 + ecdsaVerifyP521 function = 3 // NIST P521 + ecdsaSignP256 function = 9 // NIST P256 + ecdsaSignP384 function = 10 // NIST P384 + ecdsaSignP521 function = 11 // NIST P521 + eddsaVerifyEd25519 function = 32 // Curve25519 + eddsaVerifyEd448 function = 36 // Curve448 + eddsaSignEd25519 function = 40 // Curve25519 + eddsaSignEd448 function = 44 // Curve448 +) + +// queryResult contains the result of a Query function +// call. Bits are numbered in big endian order so the +// leftmost bit (the MSB) is at index 0. +type queryResult struct { + bits [2]uint64 +} + +// Has reports whether the given functions are present. +func (q *queryResult) Has(fns ...function) bool { + if len(fns) == 0 { + panic("no function codes provided") + } + for _, f := range fns { + if !bitIsSet(q.bits[:], uint(f)) { + return false + } + } + return true +} + +// facility is a bit index for the named facility. +type facility uint8 + +const ( + // mandatory facilities + zarch facility = 1 // z architecture mode is active + stflef facility = 7 // store-facility-list-extended + ldisp facility = 18 // long-displacement + eimm facility = 21 // extended-immediate + + // miscellaneous facilities + dfp facility = 42 // decimal-floating-point + etf3eh facility = 30 // extended-translation 3 enhancement + + // cryptography facilities + msa facility = 17 // message-security-assist + msa3 facility = 76 // message-security-assist extension 3 + msa4 facility = 77 // message-security-assist extension 4 + msa5 facility = 57 // message-security-assist extension 5 + msa8 facility = 146 // message-security-assist extension 8 + msa9 facility = 155 // message-security-assist extension 9 + + // vector facilities + vxe facility = 135 // vector-enhancements 1 + + // Note: vx requires kernel support + // and so must be fetched from HWCAP. + + hwcap_VX = 1 << 11 // vector facility +) + +// facilityList contains the result of an STFLE call. +// Bits are numbered in big endian order so the +// leftmost bit (the MSB) is at index 0. +type facilityList struct { + bits [4]uint64 +} + +// Has reports whether the given facilities are present. +func (s *facilityList) Has(fs ...facility) bool { + if len(fs) == 0 { + panic("no facility bits provided") + } + for _, f := range fs { + if !bitIsSet(s.bits[:], uint(f)) { + return false + } + } + return true +} + +// The following feature detection functions are defined in cpu_s390x.s. +// They are likely to be expensive to call so the results should be cached. +func stfle() facilityList +func kmQuery() queryResult +func kmcQuery() queryResult +func kmctrQuery() queryResult +func kmaQuery() queryResult +func kimdQuery() queryResult +func klmdQuery() queryResult +func kdsaQuery() queryResult + +func doinit() { + options = []option{ + {Name: "zarch", Feature: &S390X.HasZARCH}, + {Name: "stfle", Feature: &S390X.HasSTFLE}, + {Name: "ldisp", Feature: &S390X.HasLDISP}, + {Name: "msa", Feature: &S390X.HasMSA}, + {Name: "eimm", Feature: &S390X.HasEIMM}, + {Name: "dfp", Feature: &S390X.HasDFP}, + {Name: "etf3eh", Feature: &S390X.HasETF3EH}, + {Name: "vx", Feature: &S390X.HasVX}, + {Name: "vxe", Feature: &S390X.HasVXE}, + {Name: "kdsa", Feature: &S390X.HasKDSA}, + } + + aes := []function{aes128, aes192, aes256} + facilities := stfle() + + S390X.HasZARCH = facilities.Has(zarch) + S390X.HasSTFLE = facilities.Has(stflef) + S390X.HasLDISP = facilities.Has(ldisp) + S390X.HasEIMM = facilities.Has(eimm) + S390X.HasDFP = facilities.Has(dfp) + S390X.HasETF3EH = facilities.Has(etf3eh) + S390X.HasMSA = facilities.Has(msa) + + if S390X.HasMSA { + // cipher message + km, kmc := kmQuery(), kmcQuery() + S390X.HasAES = km.Has(aes...) + S390X.HasAESCBC = kmc.Has(aes...) + if facilities.Has(msa4) { + kmctr := kmctrQuery() + S390X.HasAESCTR = kmctr.Has(aes...) + } + if facilities.Has(msa8) { + kma := kmaQuery() + S390X.HasAESGCM = kma.Has(aes...) + } + + // compute message digest + kimd := kimdQuery() // intermediate (no padding) + klmd := klmdQuery() // last (padding) + S390X.HasSHA1 = kimd.Has(sha1) && klmd.Has(sha1) + S390X.HasSHA256 = kimd.Has(sha256) && klmd.Has(sha256) + S390X.HasSHA512 = kimd.Has(sha512) && klmd.Has(sha512) + S390X.HasGHASH = kimd.Has(ghash) // KLMD-GHASH does not exist + sha3 := []function{ + sha3_224, sha3_256, sha3_384, sha3_512, + shake128, shake256, + } + S390X.HasSHA3 = kimd.Has(sha3...) && klmd.Has(sha3...) + S390X.HasKDSA = facilities.Has(msa9) // elliptic curves + if S390X.HasKDSA { + kdsa := kdsaQuery() + S390X.HasECDSA = kdsa.Has(ecdsaVerifyP256, ecdsaSignP256, ecdsaVerifyP384, ecdsaSignP384, ecdsaVerifyP521, ecdsaSignP521) + S390X.HasEDDSA = kdsa.Has(eddsaVerifyEd25519, eddsaSignEd25519, eddsaVerifyEd448, eddsaSignEd448) + } + } + + S390X.HasVX = isSet(HWCap, hwcap_VX) + + if S390X.HasVX { + S390X.HasVXE = facilities.Has(vxe) + } +} + +func isSet(hwc uint, value uint) bool { + return hwc&value != 0 +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_s390x.s b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_s390x.s new file mode 100644 index 0000000..a1243aa --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_s390x.s @@ -0,0 +1,63 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "textflag.h" + +// func stfle() facilityList +TEXT ·stfle(SB), NOSPLIT|NOFRAME, $0-32 + MOVD $ret+0(FP), R1 + MOVD $3, R0 // last doubleword index to store + XC $32, (R1), (R1) // clear 4 doublewords (32 bytes) + WORD $0xb2b01000 // store facility list extended (STFLE) + RET + +// func kmQuery() queryResult +TEXT ·kmQuery(SB), NOSPLIT|NOFRAME, $0-16 + MOVD $0, R0 // set function code to 0 (KM-Query) + MOVD $ret+0(FP), R1 // address of 16-byte return value + WORD $0xB92E0024 // cipher message (KM) + RET + +// func kmcQuery() queryResult +TEXT ·kmcQuery(SB), NOSPLIT|NOFRAME, $0-16 + MOVD $0, R0 // set function code to 0 (KMC-Query) + MOVD $ret+0(FP), R1 // address of 16-byte return value + WORD $0xB92F0024 // cipher message with chaining (KMC) + RET + +// func kmctrQuery() queryResult +TEXT ·kmctrQuery(SB), NOSPLIT|NOFRAME, $0-16 + MOVD $0, R0 // set function code to 0 (KMCTR-Query) + MOVD $ret+0(FP), R1 // address of 16-byte return value + WORD $0xB92D4024 // cipher message with counter (KMCTR) + RET + +// func kmaQuery() queryResult +TEXT ·kmaQuery(SB), NOSPLIT|NOFRAME, $0-16 + MOVD $0, R0 // set function code to 0 (KMA-Query) + MOVD $ret+0(FP), R1 // address of 16-byte return value + WORD $0xb9296024 // cipher message with authentication (KMA) + RET + +// func kimdQuery() queryResult +TEXT ·kimdQuery(SB), NOSPLIT|NOFRAME, $0-16 + MOVD $0, R0 // set function code to 0 (KIMD-Query) + MOVD $ret+0(FP), R1 // address of 16-byte return value + WORD $0xB93E0024 // compute intermediate message digest (KIMD) + RET + +// func klmdQuery() queryResult +TEXT ·klmdQuery(SB), NOSPLIT|NOFRAME, $0-16 + MOVD $0, R0 // set function code to 0 (KLMD-Query) + MOVD $ret+0(FP), R1 // address of 16-byte return value + WORD $0xB93F0024 // compute last message digest (KLMD) + RET + +// func kdsaQuery() queryResult +TEXT ·kdsaQuery(SB), NOSPLIT|NOFRAME, $0-16 + MOVD $0, R0 // set function code to 0 (KLMD-Query) + MOVD $ret+0(FP), R1 // address of 16-byte return value + WORD $0xB93A0008 // compute digital signature authentication + RET + diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_s390x_test.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_s390x_test.go new file mode 100644 index 0000000..ad86858 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_s390x_test.go @@ -0,0 +1,63 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu_test + +import ( + "errors" + . "internal/cpu" + "os" + "regexp" + "testing" +) + +func getFeatureList() ([]string, error) { + cpuinfo, err := os.ReadFile("/proc/cpuinfo") + if err != nil { + return nil, err + } + r := regexp.MustCompile("features\\s*:\\s*(.*)") + b := r.FindSubmatch(cpuinfo) + if len(b) < 2 { + return nil, errors.New("no feature list in /proc/cpuinfo") + } + return regexp.MustCompile("\\s+").Split(string(b[1]), -1), nil +} + +func TestS390XAgainstCPUInfo(t *testing.T) { + // mapping of linux feature strings to S390X fields + mapping := make(map[string]*bool) + for _, option := range Options { + mapping[option.Name] = option.Feature + } + + // these must be true on the machines Go supports + mandatory := make(map[string]bool) + mandatory["zarch"] = false + mandatory["eimm"] = false + mandatory["ldisp"] = false + mandatory["stfle"] = false + + features, err := getFeatureList() + if err != nil { + t.Error(err) + } + for _, feature := range features { + if _, ok := mandatory[feature]; ok { + mandatory[feature] = true + } + if flag, ok := mapping[feature]; ok { + if !*flag { + t.Errorf("feature '%v' not detected", feature) + } + } else { + t.Logf("no entry for '%v'", feature) + } + } + for k, v := range mandatory { + if !v { + t.Errorf("mandatory feature '%v' not detected", k) + } + } +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_test.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_test.go new file mode 100644 index 0000000..2de7365 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_test.go @@ -0,0 +1,83 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu_test + +import ( + . "internal/cpu" + "internal/testenv" + "os" + "os/exec" + "runtime" + "strings" + "testing" +) + +func TestMinimalFeatures(t *testing.T) { + // TODO: maybe do MustSupportFeatureDectection(t) ? + if runtime.GOARCH == "arm64" { + switch runtime.GOOS { + case "linux", "android", "darwin": + default: + t.Skipf("%s/%s is not supported", runtime.GOOS, runtime.GOARCH) + } + } + + for _, o := range Options { + if o.Required && !*o.Feature { + t.Errorf("%v expected true, got false", o.Name) + } + } +} + +func MustHaveDebugOptionsSupport(t *testing.T) { + if !DebugOptions { + t.Skipf("skipping test: cpu feature options not supported by OS") + } +} + +func MustSupportFeatureDectection(t *testing.T) { + // TODO: add platforms that do not have CPU feature detection support. +} + +func runDebugOptionsTest(t *testing.T, test string, options string) { + MustHaveDebugOptionsSupport(t) + + testenv.MustHaveExec(t) + + env := "GODEBUG=" + options + + cmd := exec.Command(os.Args[0], "-test.run="+test) + cmd.Env = append(cmd.Env, env) + + output, err := cmd.CombinedOutput() + lines := strings.Fields(string(output)) + lastline := lines[len(lines)-1] + + got := strings.TrimSpace(lastline) + want := "PASS" + if err != nil || got != want { + t.Fatalf("%s with %s: want %s, got %v", test, env, want, got) + } +} + +func TestDisableAllCapabilities(t *testing.T) { + MustSupportFeatureDectection(t) + runDebugOptionsTest(t, "TestAllCapabilitiesDisabled", "cpu.all=off") +} + +func TestAllCapabilitiesDisabled(t *testing.T) { + MustHaveDebugOptionsSupport(t) + + if os.Getenv("GODEBUG") != "cpu.all=off" { + t.Skipf("skipping test: GODEBUG=cpu.all=off not set") + } + + for _, o := range Options { + want := o.Required + if got := *o.Feature; got != want { + t.Errorf("%v: expected %v, got %v", o.Name, want, got) + } + } +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_wasm.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_wasm.go new file mode 100644 index 0000000..2310ad6 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_wasm.go @@ -0,0 +1,10 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const CacheLinePadSize = 64 + +func doinit() { +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_x86.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_x86.go new file mode 100644 index 0000000..ba6bf69 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_x86.go @@ -0,0 +1,163 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build 386 amd64 + +package cpu + +const CacheLinePadSize = 64 + +// cpuid is implemented in cpu_x86.s. +func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32) + +// xgetbv with ecx = 0 is implemented in cpu_x86.s. +func xgetbv() (eax, edx uint32) + +const ( + // edx bits + cpuid_SSE2 = 1 << 26 + + // ecx bits + cpuid_SSE3 = 1 << 0 + cpuid_PCLMULQDQ = 1 << 1 + cpuid_SSSE3 = 1 << 9 + cpuid_FMA = 1 << 12 + cpuid_SSE41 = 1 << 19 + cpuid_SSE42 = 1 << 20 + cpuid_POPCNT = 1 << 23 + cpuid_AES = 1 << 25 + cpuid_OSXSAVE = 1 << 27 + cpuid_AVX = 1 << 28 + + // ebx bits + cpuid_BMI1 = 1 << 3 + cpuid_AVX2 = 1 << 5 + cpuid_BMI2 = 1 << 8 + cpuid_ERMS = 1 << 9 + cpuid_ADX = 1 << 19 +) + +var maxExtendedFunctionInformation uint32 + +func doinit() { + options = []option{ + {Name: "adx", Feature: &X86.HasADX}, + {Name: "aes", Feature: &X86.HasAES}, + {Name: "avx", Feature: &X86.HasAVX}, + {Name: "avx2", Feature: &X86.HasAVX2}, + {Name: "bmi1", Feature: &X86.HasBMI1}, + {Name: "bmi2", Feature: &X86.HasBMI2}, + {Name: "erms", Feature: &X86.HasERMS}, + {Name: "fma", Feature: &X86.HasFMA}, + {Name: "pclmulqdq", Feature: &X86.HasPCLMULQDQ}, + {Name: "popcnt", Feature: &X86.HasPOPCNT}, + {Name: "sse3", Feature: &X86.HasSSE3}, + {Name: "sse41", Feature: &X86.HasSSE41}, + {Name: "sse42", Feature: &X86.HasSSE42}, + {Name: "ssse3", Feature: &X86.HasSSSE3}, + + // These capabilities should always be enabled on amd64: + {Name: "sse2", Feature: &X86.HasSSE2, Required: GOARCH == "amd64"}, + } + + maxID, _, _, _ := cpuid(0, 0) + + if maxID < 1 { + return + } + + maxExtendedFunctionInformation, _, _, _ = cpuid(0x80000000, 0) + + _, _, ecx1, edx1 := cpuid(1, 0) + X86.HasSSE2 = isSet(edx1, cpuid_SSE2) + + X86.HasSSE3 = isSet(ecx1, cpuid_SSE3) + X86.HasPCLMULQDQ = isSet(ecx1, cpuid_PCLMULQDQ) + X86.HasSSSE3 = isSet(ecx1, cpuid_SSSE3) + X86.HasSSE41 = isSet(ecx1, cpuid_SSE41) + X86.HasSSE42 = isSet(ecx1, cpuid_SSE42) + X86.HasPOPCNT = isSet(ecx1, cpuid_POPCNT) + X86.HasAES = isSet(ecx1, cpuid_AES) + + // OSXSAVE can be false when using older Operating Systems + // or when explicitly disabled on newer Operating Systems by + // e.g. setting the xsavedisable boot option on Windows 10. + X86.HasOSXSAVE = isSet(ecx1, cpuid_OSXSAVE) + + // The FMA instruction set extension only has VEX prefixed instructions. + // VEX prefixed instructions require OSXSAVE to be enabled. + // See Intel 64 and IA-32 Architecture Software Developer’s Manual Volume 2 + // Section 2.4 "AVX and SSE Instruction Exception Specification" + X86.HasFMA = isSet(ecx1, cpuid_FMA) && X86.HasOSXSAVE + + osSupportsAVX := false + // For XGETBV, OSXSAVE bit is required and sufficient. + if X86.HasOSXSAVE { + eax, _ := xgetbv() + // Check if XMM and YMM registers have OS support. + osSupportsAVX = isSet(eax, 1<<1) && isSet(eax, 1<<2) + } + + X86.HasAVX = isSet(ecx1, cpuid_AVX) && osSupportsAVX + + if maxID < 7 { + return + } + + _, ebx7, _, _ := cpuid(7, 0) + X86.HasBMI1 = isSet(ebx7, cpuid_BMI1) + X86.HasAVX2 = isSet(ebx7, cpuid_AVX2) && osSupportsAVX + X86.HasBMI2 = isSet(ebx7, cpuid_BMI2) + X86.HasERMS = isSet(ebx7, cpuid_ERMS) + X86.HasADX = isSet(ebx7, cpuid_ADX) +} + +func isSet(hwc uint32, value uint32) bool { + return hwc&value != 0 +} + +// Name returns the CPU name given by the vendor. +// If the CPU name can not be determined an +// empty string is returned. +func Name() string { + if maxExtendedFunctionInformation < 0x80000004 { + return "" + } + + data := make([]byte, 0, 3*4*4) + + var eax, ebx, ecx, edx uint32 + eax, ebx, ecx, edx = cpuid(0x80000002, 0) + data = appendBytes(data, eax, ebx, ecx, edx) + eax, ebx, ecx, edx = cpuid(0x80000003, 0) + data = appendBytes(data, eax, ebx, ecx, edx) + eax, ebx, ecx, edx = cpuid(0x80000004, 0) + data = appendBytes(data, eax, ebx, ecx, edx) + + // Trim leading spaces. + for len(data) > 0 && data[0] == ' ' { + data = data[1:] + } + + // Trim tail after and including the first null byte. + for i, c := range data { + if c == '\x00' { + data = data[:i] + break + } + } + + return string(data) +} + +func appendBytes(b []byte, args ...uint32) []byte { + for _, arg := range args { + b = append(b, + byte((arg >> 0)), + byte((arg >> 8)), + byte((arg >> 16)), + byte((arg >> 24))) + } + return b +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_x86.s b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_x86.s new file mode 100644 index 0000000..93c712d --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_x86.s @@ -0,0 +1,26 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build 386 amd64 + +#include "textflag.h" + +// func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32) +TEXT ·cpuid(SB), NOSPLIT, $0-24 + MOVL eaxArg+0(FP), AX + MOVL ecxArg+4(FP), CX + CPUID + MOVL AX, eax+8(FP) + MOVL BX, ebx+12(FP) + MOVL CX, ecx+16(FP) + MOVL DX, edx+20(FP) + RET + +// func xgetbv() (eax, edx uint32) +TEXT ·xgetbv(SB),NOSPLIT,$0-8 + MOVL $0, CX + XGETBV + MOVL AX, eax+0(FP) + MOVL DX, edx+4(FP) + RET diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_x86_test.go b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_x86_test.go new file mode 100644 index 0000000..61db93b --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/cpu_x86_test.go @@ -0,0 +1,54 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build 386 amd64 + +package cpu_test + +import ( + . "internal/cpu" + "os" + "runtime" + "testing" +) + +func TestX86ifAVX2hasAVX(t *testing.T) { + if X86.HasAVX2 && !X86.HasAVX { + t.Fatalf("HasAVX expected true when HasAVX2 is true, got false") + } +} + +func TestDisableSSE2(t *testing.T) { + runDebugOptionsTest(t, "TestSSE2DebugOption", "cpu.sse2=off") +} + +func TestSSE2DebugOption(t *testing.T) { + MustHaveDebugOptionsSupport(t) + + if os.Getenv("GODEBUG") != "cpu.sse2=off" { + t.Skipf("skipping test: GODEBUG=cpu.sse2=off not set") + } + + want := runtime.GOARCH != "386" // SSE2 can only be disabled on 386. + if got := X86.HasSSE2; got != want { + t.Errorf("X86.HasSSE2 on %s expected %v, got %v", runtime.GOARCH, want, got) + } +} + +func TestDisableSSE3(t *testing.T) { + runDebugOptionsTest(t, "TestSSE3DebugOption", "cpu.sse3=off") +} + +func TestSSE3DebugOption(t *testing.T) { + MustHaveDebugOptionsSupport(t) + + if os.Getenv("GODEBUG") != "cpu.sse3=off" { + t.Skipf("skipping test: GODEBUG=cpu.sse3=off not set") + } + + want := false + if got := X86.HasSSE3; got != want { + t.Errorf("X86.HasSSE3 expected %v, got %v", want, got) + } +} diff --git a/vendor/github.com/lesismal/llib/std/internal/cpu/export_test.go b/vendor/github.com/lesismal/llib/std/internal/cpu/export_test.go new file mode 100644 index 0000000..91bfc1b --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/cpu/export_test.go @@ -0,0 +1,9 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +var ( + Options = options +) diff --git a/vendor/github.com/lesismal/llib/std/internal/nettrace/nettrace.go b/vendor/github.com/lesismal/llib/std/internal/nettrace/nettrace.go new file mode 100755 index 0000000..de3254d --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/nettrace/nettrace.go @@ -0,0 +1,45 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package nettrace contains internal hooks for tracing activity in +// the net package. This package is purely internal for use by the +// net/http/httptrace package and has no stable API exposed to end +// users. +package nettrace + +// TraceKey is a context.Context Value key. Its associated value should +// be a *Trace struct. +type TraceKey struct{} + +// LookupIPAltResolverKey is a context.Context Value key used by tests to +// specify an alternate resolver func. +// It is not exposed to outsider users. (But see issue 12503) +// The value should be the same type as lookupIP: +// func lookupIP(ctx context.Context, host string) ([]IPAddr, error) +type LookupIPAltResolverKey struct{} + +// Trace contains a set of hooks for tracing events within +// the net package. Any specific hook may be nil. +type Trace struct { + // DNSStart is called with the hostname of a DNS lookup + // before it begins. + DNSStart func(name string) + + // DNSDone is called after a DNS lookup completes (or fails). + // The coalesced parameter is whether singleflight de-dupped + // the call. The addrs are of type net.IPAddr but can't + // actually be for circular dependency reasons. + DNSDone func(netIPs []interface{}, coalesced bool, err error) + + // ConnectStart is called before a Dial, excluding Dials made + // during DNS lookups. In the case of DualStack (Happy Eyeballs) + // dialing, this may be called multiple times, from multiple + // goroutines. + ConnectStart func(network, addr string) + + // ConnectStart is called after a Dial with the results, excluding + // Dials made during DNS lookups. It may also be called multiple + // times, like ConnectStart. + ConnectDone func(network, addr string, err error) +} diff --git a/vendor/github.com/lesismal/llib/std/internal/testenv/testenv.go b/vendor/github.com/lesismal/llib/std/internal/testenv/testenv.go new file mode 100644 index 0000000..c902b14 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/testenv/testenv.go @@ -0,0 +1,308 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package testenv provides information about what functionality +// is available in different testing environments run by the Go team. +// +// It is an internal package because these details are specific +// to the Go team's test setup (on build.golang.org) and not +// fundamental to tests in general. +package testenv + +import ( + "errors" + "flag" + "internal/cfg" + "os" + "os/exec" + "path/filepath" + "runtime" + "strconv" + "strings" + "sync" + "testing" +) + +// Builder reports the name of the builder running this test +// (for example, "linux-amd64" or "windows-386-gce"). +// If the test is not running on the build infrastructure, +// Builder returns the empty string. +func Builder() string { + return os.Getenv("GO_BUILDER_NAME") +} + +// HasGoBuild reports whether the current system can build programs with ``go build'' +// and then run them with os.StartProcess or exec.Command. +func HasGoBuild() bool { + if os.Getenv("GO_GCFLAGS") != "" { + // It's too much work to require every caller of the go command + // to pass along "-gcflags="+os.Getenv("GO_GCFLAGS"). + // For now, if $GO_GCFLAGS is set, report that we simply can't + // run go build. + return false + } + switch runtime.GOOS { + case "android", "js", "ios": + return false + } + return true +} + +// MustHaveGoBuild checks that the current system can build programs with ``go build'' +// and then run them with os.StartProcess or exec.Command. +// If not, MustHaveGoBuild calls t.Skip with an explanation. +func MustHaveGoBuild(t testing.TB) { + if os.Getenv("GO_GCFLAGS") != "" { + t.Skipf("skipping test: 'go build' not compatible with setting $GO_GCFLAGS") + } + if !HasGoBuild() { + t.Skipf("skipping test: 'go build' not available on %s/%s", runtime.GOOS, runtime.GOARCH) + } +} + +// HasGoRun reports whether the current system can run programs with ``go run.'' +func HasGoRun() bool { + // For now, having go run and having go build are the same. + return HasGoBuild() +} + +// MustHaveGoRun checks that the current system can run programs with ``go run.'' +// If not, MustHaveGoRun calls t.Skip with an explanation. +func MustHaveGoRun(t testing.TB) { + if !HasGoRun() { + t.Skipf("skipping test: 'go run' not available on %s/%s", runtime.GOOS, runtime.GOARCH) + } +} + +// GoToolPath reports the path to the Go tool. +// It is a convenience wrapper around GoTool. +// If the tool is unavailable GoToolPath calls t.Skip. +// If the tool should be available and isn't, GoToolPath calls t.Fatal. +func GoToolPath(t testing.TB) string { + MustHaveGoBuild(t) + path, err := GoTool() + if err != nil { + t.Fatal(err) + } + // Add all environment variables that affect the Go command to test metadata. + // Cached test results will be invalidate when these variables change. + // See golang.org/issue/32285. + for _, envVar := range strings.Fields(cfg.KnownEnv) { + os.Getenv(envVar) + } + return path +} + +// GoTool reports the path to the Go tool. +func GoTool() (string, error) { + if !HasGoBuild() { + return "", errors.New("platform cannot run go tool") + } + var exeSuffix string + if runtime.GOOS == "windows" { + exeSuffix = ".exe" + } + path := filepath.Join(runtime.GOROOT(), "bin", "go"+exeSuffix) + if _, err := os.Stat(path); err == nil { + return path, nil + } + goBin, err := exec.LookPath("go" + exeSuffix) + if err != nil { + return "", errors.New("cannot find go tool: " + err.Error()) + } + return goBin, nil +} + +// HasExec reports whether the current system can start new processes +// using os.StartProcess or (more commonly) exec.Command. +func HasExec() bool { + switch runtime.GOOS { + case "js", "ios": + return false + } + return true +} + +// HasSrc reports whether the entire source tree is available under GOROOT. +func HasSrc() bool { + switch runtime.GOOS { + case "ios": + return false + } + return true +} + +// MustHaveExec checks that the current system can start new processes +// using os.StartProcess or (more commonly) exec.Command. +// If not, MustHaveExec calls t.Skip with an explanation. +func MustHaveExec(t testing.TB) { + if !HasExec() { + t.Skipf("skipping test: cannot exec subprocess on %s/%s", runtime.GOOS, runtime.GOARCH) + } +} + +var execPaths sync.Map // path -> error + +// MustHaveExecPath checks that the current system can start the named executable +// using os.StartProcess or (more commonly) exec.Command. +// If not, MustHaveExecPath calls t.Skip with an explanation. +func MustHaveExecPath(t testing.TB, path string) { + MustHaveExec(t) + + err, found := execPaths.Load(path) + if !found { + _, err = exec.LookPath(path) + err, _ = execPaths.LoadOrStore(path, err) + } + if err != nil { + t.Skipf("skipping test: %s: %s", path, err) + } +} + +// HasExternalNetwork reports whether the current system can use +// external (non-localhost) networks. +func HasExternalNetwork() bool { + return !testing.Short() && runtime.GOOS != "js" +} + +// MustHaveExternalNetwork checks that the current system can use +// external (non-localhost) networks. +// If not, MustHaveExternalNetwork calls t.Skip with an explanation. +func MustHaveExternalNetwork(t testing.TB) { + if runtime.GOOS == "js" { + t.Skipf("skipping test: no external network on %s", runtime.GOOS) + } + if testing.Short() { + t.Skipf("skipping test: no external network in -short mode") + } +} + +var haveCGO bool + +// HasCGO reports whether the current system can use cgo. +func HasCGO() bool { + return haveCGO +} + +// MustHaveCGO calls t.Skip if cgo is not available. +func MustHaveCGO(t testing.TB) { + if !haveCGO { + t.Skipf("skipping test: no cgo") + } +} + +// CanInternalLink reports whether the current system can link programs with +// internal linking. +// (This is the opposite of cmd/internal/sys.MustLinkExternal. Keep them in sync.) +func CanInternalLink() bool { + switch runtime.GOOS { + case "android": + if runtime.GOARCH != "arm64" { + return false + } + case "ios": + if runtime.GOARCH == "arm64" { + return false + } + } + return true +} + +// MustInternalLink checks that the current system can link programs with internal +// linking. +// If not, MustInternalLink calls t.Skip with an explanation. +func MustInternalLink(t testing.TB) { + if !CanInternalLink() { + t.Skipf("skipping test: internal linking on %s/%s is not supported", runtime.GOOS, runtime.GOARCH) + } +} + +// HasSymlink reports whether the current system can use os.Symlink. +func HasSymlink() bool { + ok, _ := hasSymlink() + return ok +} + +// MustHaveSymlink reports whether the current system can use os.Symlink. +// If not, MustHaveSymlink calls t.Skip with an explanation. +func MustHaveSymlink(t testing.TB) { + ok, reason := hasSymlink() + if !ok { + t.Skipf("skipping test: cannot make symlinks on %s/%s%s", runtime.GOOS, runtime.GOARCH, reason) + } +} + +// HasLink reports whether the current system can use os.Link. +func HasLink() bool { + // From Android release M (Marshmallow), hard linking files is blocked + // and an attempt to call link() on a file will return EACCES. + // - https://code.google.com/p/android-developer-preview/issues/detail?id=3150 + return runtime.GOOS != "plan9" && runtime.GOOS != "android" +} + +// MustHaveLink reports whether the current system can use os.Link. +// If not, MustHaveLink calls t.Skip with an explanation. +func MustHaveLink(t testing.TB) { + if !HasLink() { + t.Skipf("skipping test: hardlinks are not supported on %s/%s", runtime.GOOS, runtime.GOARCH) + } +} + +var flaky = flag.Bool("flaky", false, "run known-flaky tests too") + +func SkipFlaky(t testing.TB, issue int) { + t.Helper() + if !*flaky { + t.Skipf("skipping known flaky test without the -flaky flag; see golang.org/issue/%d", issue) + } +} + +func SkipFlakyNet(t testing.TB) { + t.Helper() + if v, _ := strconv.ParseBool(os.Getenv("GO_BUILDER_FLAKY_NET")); v { + t.Skip("skipping test on builder known to have frequent network failures") + } +} + +// CleanCmdEnv will fill cmd.Env with the environment, excluding certain +// variables that could modify the behavior of the Go tools such as +// GODEBUG and GOTRACEBACK. +func CleanCmdEnv(cmd *exec.Cmd) *exec.Cmd { + if cmd.Env != nil { + panic("environment already set") + } + for _, env := range os.Environ() { + // Exclude GODEBUG from the environment to prevent its output + // from breaking tests that are trying to parse other command output. + if strings.HasPrefix(env, "GODEBUG=") { + continue + } + // Exclude GOTRACEBACK for the same reason. + if strings.HasPrefix(env, "GOTRACEBACK=") { + continue + } + cmd.Env = append(cmd.Env, env) + } + return cmd +} + +// CPUIsSlow reports whether the CPU running the test is suspected to be slow. +func CPUIsSlow() bool { + switch runtime.GOARCH { + case "arm", "mips", "mipsle", "mips64", "mips64le": + return true + } + return false +} + +// SkipIfShortAndSlow skips t if -short is set and the CPU running the test is +// suspected to be slow. +// +// (This is useful for CPU-intensive tests that otherwise complete quickly.) +func SkipIfShortAndSlow(t testing.TB) { + if testing.Short() && CPUIsSlow() { + t.Helper() + t.Skipf("skipping test in -short mode on %s", runtime.GOARCH) + } +} diff --git a/vendor/github.com/lesismal/llib/std/internal/testenv/testenv_cgo.go b/vendor/github.com/lesismal/llib/std/internal/testenv/testenv_cgo.go new file mode 100644 index 0000000..e3d4d16 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/testenv/testenv_cgo.go @@ -0,0 +1,11 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build cgo + +package testenv + +func init() { + haveCGO = true +} diff --git a/vendor/github.com/lesismal/llib/std/internal/testenv/testenv_notwin.go b/vendor/github.com/lesismal/llib/std/internal/testenv/testenv_notwin.go new file mode 100644 index 0000000..ccb5d55 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/testenv/testenv_notwin.go @@ -0,0 +1,20 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !windows + +package testenv + +import ( + "runtime" +) + +func hasSymlink() (ok bool, reason string) { + switch runtime.GOOS { + case "android", "plan9": + return false, "" + } + + return true, "" +} diff --git a/vendor/github.com/lesismal/llib/std/internal/testenv/testenv_windows.go b/vendor/github.com/lesismal/llib/std/internal/testenv/testenv_windows.go new file mode 100644 index 0000000..4802b13 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/internal/testenv/testenv_windows.go @@ -0,0 +1,47 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package testenv + +import ( + "os" + "path/filepath" + "sync" + "syscall" +) + +var symlinkOnce sync.Once +var winSymlinkErr error + +func initWinHasSymlink() { + tmpdir, err := os.MkdirTemp("", "symtest") + if err != nil { + panic("failed to create temp directory: " + err.Error()) + } + defer os.RemoveAll(tmpdir) + + err = os.Symlink("target", filepath.Join(tmpdir, "symlink")) + if err != nil { + err = err.(*os.LinkError).Err + switch err { + case syscall.EWINDOWS, syscall.ERROR_PRIVILEGE_NOT_HELD: + winSymlinkErr = err + } + } +} + +func hasSymlink() (ok bool, reason string) { + symlinkOnce.Do(initWinHasSymlink) + + switch winSymlinkErr { + case nil: + return true, "" + case syscall.EWINDOWS: + return false, ": symlinks are not supported on your version of Windows" + case syscall.ERROR_PRIVILEGE_NOT_HELD: + return false, ": you don't have enough privileges to create symlinks" + } + + return false, "" +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/alpn_test.go b/vendor/github.com/lesismal/llib/std/net/http/alpn_test.go new file mode 100644 index 0000000..ae345dc --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/alpn_test.go @@ -0,0 +1,132 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "bufio" + "bytes" + "crypto/tls" + "crypto/x509" + "fmt" + "github.com/lesismal/llib/std/net/http/httptest" + "io" + . "net/http" + "strings" + "testing" +) + +func TestNextProtoUpgrade(t *testing.T) { + setParallel(t) + defer afterTest(t) + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "path=%s,proto=", r.URL.Path) + if r.TLS != nil { + w.Write([]byte(r.TLS.NegotiatedProtocol)) + } + if r.RemoteAddr == "" { + t.Error("request with no RemoteAddr") + } + if r.Body == nil { + t.Errorf("request with nil Body") + } + })) + ts.TLS = &tls.Config{ + NextProtos: []string{"unhandled-proto", "tls-0.9"}, + } + ts.Config.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){ + "tls-0.9": handleTLSProtocol09, + } + ts.StartTLS() + defer ts.Close() + + // Normal request, without NPN. + { + c := ts.Client() + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if want := "path=/,proto="; string(body) != want { + t.Errorf("plain request = %q; want %q", body, want) + } + } + + // Request to an advertised but unhandled NPN protocol. + // Server will hang up. + { + certPool := x509.NewCertPool() + certPool.AddCert(ts.Certificate()) + tr := &Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: certPool, + NextProtos: []string{"unhandled-proto"}, + }, + } + defer tr.CloseIdleConnections() + c := &Client{ + Transport: tr, + } + res, err := c.Get(ts.URL) + if err == nil { + defer res.Body.Close() + var buf bytes.Buffer + res.Write(&buf) + t.Errorf("expected error on unhandled-proto request; got: %s", buf.Bytes()) + } + } + + // Request using the "tls-0.9" protocol, which we register here. + // It is HTTP/0.9 over TLS. + { + c := ts.Client() + tlsConfig := c.Transport.(*Transport).TLSClientConfig + tlsConfig.NextProtos = []string{"tls-0.9"} + conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig) + if err != nil { + t.Fatal(err) + } + conn.Write([]byte("GET /foo\n")) + body, err := io.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if want := "path=/foo,proto=tls-0.9"; string(body) != want { + t.Errorf("plain request = %q; want %q", body, want) + } + } +} + +// handleTLSProtocol09 implements the HTTP/0.9 protocol over TLS, for the +// TestNextProtoUpgrade test. +func handleTLSProtocol09(srv *Server, conn *tls.Conn, h Handler) { + br := bufio.NewReader(conn) + line, err := br.ReadString('\n') + if err != nil { + return + } + line = strings.TrimSpace(line) + path := strings.TrimPrefix(line, "GET ") + if path == line { + return + } + req, _ := NewRequest("GET", path, nil) + req.Proto = "HTTP/0.9" + req.ProtoMajor = 0 + req.ProtoMinor = 9 + rw := &http09Writer{conn, make(Header)} + h.ServeHTTP(rw, req) +} + +type http09Writer struct { + io.Writer + h Header +} + +func (w http09Writer) Header() Header { return w.h } +func (w http09Writer) WriteHeader(int) {} // no headers diff --git a/vendor/github.com/lesismal/llib/std/net/http/cgi/child.go b/vendor/github.com/lesismal/llib/std/net/http/cgi/child.go new file mode 100644 index 0000000..0114da3 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/cgi/child.go @@ -0,0 +1,220 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements CGI from the perspective of a child +// process. + +package cgi + +import ( + "bufio" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "strconv" + "strings" +) + +// Request returns the HTTP request as represented in the current +// environment. This assumes the current program is being run +// by a web server in a CGI environment. +// The returned Request's Body is populated, if applicable. +func Request() (*http.Request, error) { + r, err := RequestFromMap(envMap(os.Environ())) + if err != nil { + return nil, err + } + if r.ContentLength > 0 { + r.Body = io.NopCloser(io.LimitReader(os.Stdin, r.ContentLength)) + } + return r, nil +} + +func envMap(env []string) map[string]string { + m := make(map[string]string) + for _, kv := range env { + if idx := strings.Index(kv, "="); idx != -1 { + m[kv[:idx]] = kv[idx+1:] + } + } + return m +} + +// RequestFromMap creates an http.Request from CGI variables. +// The returned Request's Body field is not populated. +func RequestFromMap(params map[string]string) (*http.Request, error) { + r := new(http.Request) + r.Method = params["REQUEST_METHOD"] + if r.Method == "" { + return nil, errors.New("cgi: no REQUEST_METHOD in environment") + } + + r.Proto = params["SERVER_PROTOCOL"] + var ok bool + r.ProtoMajor, r.ProtoMinor, ok = http.ParseHTTPVersion(r.Proto) + if !ok { + return nil, errors.New("cgi: invalid SERVER_PROTOCOL version") + } + + r.Close = true + r.Trailer = http.Header{} + r.Header = http.Header{} + + r.Host = params["HTTP_HOST"] + + if lenstr := params["CONTENT_LENGTH"]; lenstr != "" { + clen, err := strconv.ParseInt(lenstr, 10, 64) + if err != nil { + return nil, errors.New("cgi: bad CONTENT_LENGTH in environment: " + lenstr) + } + r.ContentLength = clen + } + + if ct := params["CONTENT_TYPE"]; ct != "" { + r.Header.Set("Content-Type", ct) + } + + // Copy "HTTP_FOO_BAR" variables to "Foo-Bar" Headers + for k, v := range params { + if !strings.HasPrefix(k, "HTTP_") || k == "HTTP_HOST" { + continue + } + r.Header.Add(strings.ReplaceAll(k[5:], "_", "-"), v) + } + + uriStr := params["REQUEST_URI"] + if uriStr == "" { + // Fallback to SCRIPT_NAME, PATH_INFO and QUERY_STRING. + uriStr = params["SCRIPT_NAME"] + params["PATH_INFO"] + s := params["QUERY_STRING"] + if s != "" { + uriStr += "?" + s + } + } + + // There's apparently a de-facto standard for this. + // https://web.archive.org/web/20170105004655/http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636 + if s := params["HTTPS"]; s == "on" || s == "ON" || s == "1" { + r.TLS = &tls.ConnectionState{HandshakeComplete: true} + } + + if r.Host != "" { + // Hostname is provided, so we can reasonably construct a URL. + rawurl := r.Host + uriStr + if r.TLS == nil { + rawurl = "http://" + rawurl + } else { + rawurl = "https://" + rawurl + } + url, err := url.Parse(rawurl) + if err != nil { + return nil, errors.New("cgi: failed to parse host and REQUEST_URI into a URL: " + rawurl) + } + r.URL = url + } + // Fallback logic if we don't have a Host header or the URL + // failed to parse + if r.URL == nil { + url, err := url.Parse(uriStr) + if err != nil { + return nil, errors.New("cgi: failed to parse REQUEST_URI into a URL: " + uriStr) + } + r.URL = url + } + + // Request.RemoteAddr has its port set by Go's standard http + // server, so we do here too. + remotePort, _ := strconv.Atoi(params["REMOTE_PORT"]) // zero if unset or invalid + r.RemoteAddr = net.JoinHostPort(params["REMOTE_ADDR"], strconv.Itoa(remotePort)) + + return r, nil +} + +// Serve executes the provided Handler on the currently active CGI +// request, if any. If there's no current CGI environment +// an error is returned. The provided handler may be nil to use +// http.DefaultServeMux. +func Serve(handler http.Handler) error { + req, err := Request() + if err != nil { + return err + } + if req.Body == nil { + req.Body = http.NoBody + } + if handler == nil { + handler = http.DefaultServeMux + } + rw := &response{ + req: req, + header: make(http.Header), + bufw: bufio.NewWriter(os.Stdout), + } + handler.ServeHTTP(rw, req) + rw.Write(nil) // make sure a response is sent + if err = rw.bufw.Flush(); err != nil { + return err + } + return nil +} + +type response struct { + req *http.Request + header http.Header + code int + wroteHeader bool + wroteCGIHeader bool + bufw *bufio.Writer +} + +func (r *response) Flush() { + r.bufw.Flush() +} + +func (r *response) Header() http.Header { + return r.header +} + +func (r *response) Write(p []byte) (n int, err error) { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + if !r.wroteCGIHeader { + r.writeCGIHeader(p) + } + return r.bufw.Write(p) +} + +func (r *response) WriteHeader(code int) { + if r.wroteHeader { + // Note: explicitly using Stderr, as Stdout is our HTTP output. + fmt.Fprintf(os.Stderr, "CGI attempted to write header twice on request for %s", r.req.URL) + return + } + r.wroteHeader = true + r.code = code +} + +// writeCGIHeader finalizes the header sent to the client and writes it to the output. +// p is not written by writeHeader, but is the first chunk of the body +// that will be written. It is sniffed for a Content-Type if none is +// set explicitly. +func (r *response) writeCGIHeader(p []byte) { + if r.wroteCGIHeader { + return + } + r.wroteCGIHeader = true + fmt.Fprintf(r.bufw, "Status: %d %s\r\n", r.code, http.StatusText(r.code)) + if _, hasType := r.header["Content-Type"]; !hasType { + r.header.Set("Content-Type", http.DetectContentType(p)) + } + r.header.Write(r.bufw) + r.bufw.WriteString("\r\n") + r.bufw.Flush() +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/cgi/child_test.go b/vendor/github.com/lesismal/llib/std/net/http/cgi/child_test.go new file mode 100644 index 0000000..f476973 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/cgi/child_test.go @@ -0,0 +1,208 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests for CGI (the child process perspective) + +package cgi + +import ( + "bufio" + "bytes" + "github.com/lesismal/llib/std/net/http/httptest" + "net/http" + "strings" + "testing" +) + +func TestRequest(t *testing.T) { + env := map[string]string{ + "SERVER_PROTOCOL": "HTTP/1.1", + "REQUEST_METHOD": "GET", + "HTTP_HOST": "example.com", + "HTTP_REFERER": "elsewhere", + "HTTP_USER_AGENT": "goclient", + "HTTP_FOO_BAR": "baz", + "REQUEST_URI": "/path?a=b", + "CONTENT_LENGTH": "123", + "CONTENT_TYPE": "text/xml", + "REMOTE_ADDR": "5.6.7.8", + "REMOTE_PORT": "54321", + } + req, err := RequestFromMap(env) + if err != nil { + t.Fatalf("RequestFromMap: %v", err) + } + if g, e := req.UserAgent(), "goclient"; e != g { + t.Errorf("expected UserAgent %q; got %q", e, g) + } + if g, e := req.Method, "GET"; e != g { + t.Errorf("expected Method %q; got %q", e, g) + } + if g, e := req.Header.Get("Content-Type"), "text/xml"; e != g { + t.Errorf("expected Content-Type %q; got %q", e, g) + } + if g, e := req.ContentLength, int64(123); e != g { + t.Errorf("expected ContentLength %d; got %d", e, g) + } + if g, e := req.Referer(), "elsewhere"; e != g { + t.Errorf("expected Referer %q; got %q", e, g) + } + if req.Header == nil { + t.Fatalf("unexpected nil Header") + } + if g, e := req.Header.Get("Foo-Bar"), "baz"; e != g { + t.Errorf("expected Foo-Bar %q; got %q", e, g) + } + if g, e := req.URL.String(), "http://example.com/path?a=b"; e != g { + t.Errorf("expected URL %q; got %q", e, g) + } + if g, e := req.FormValue("a"), "b"; e != g { + t.Errorf("expected FormValue(a) %q; got %q", e, g) + } + if req.Trailer == nil { + t.Errorf("unexpected nil Trailer") + } + if req.TLS != nil { + t.Errorf("expected nil TLS") + } + if e, g := "5.6.7.8:54321", req.RemoteAddr; e != g { + t.Errorf("RemoteAddr: got %q; want %q", g, e) + } +} + +func TestRequestWithTLS(t *testing.T) { + env := map[string]string{ + "SERVER_PROTOCOL": "HTTP/1.1", + "REQUEST_METHOD": "GET", + "HTTP_HOST": "example.com", + "HTTP_REFERER": "elsewhere", + "REQUEST_URI": "/path?a=b", + "CONTENT_TYPE": "text/xml", + "HTTPS": "1", + "REMOTE_ADDR": "5.6.7.8", + } + req, err := RequestFromMap(env) + if err != nil { + t.Fatalf("RequestFromMap: %v", err) + } + if g, e := req.URL.String(), "https://example.com/path?a=b"; e != g { + t.Errorf("expected URL %q; got %q", e, g) + } + if req.TLS == nil { + t.Errorf("expected non-nil TLS") + } +} + +func TestRequestWithoutHost(t *testing.T) { + env := map[string]string{ + "SERVER_PROTOCOL": "HTTP/1.1", + "HTTP_HOST": "", + "REQUEST_METHOD": "GET", + "REQUEST_URI": "/path?a=b", + "CONTENT_LENGTH": "123", + } + req, err := RequestFromMap(env) + if err != nil { + t.Fatalf("RequestFromMap: %v", err) + } + if req.URL == nil { + t.Fatalf("unexpected nil URL") + } + if g, e := req.URL.String(), "/path?a=b"; e != g { + t.Errorf("URL = %q; want %q", g, e) + } +} + +func TestRequestWithoutRequestURI(t *testing.T) { + env := map[string]string{ + "SERVER_PROTOCOL": "HTTP/1.1", + "HTTP_HOST": "example.com", + "REQUEST_METHOD": "GET", + "SCRIPT_NAME": "/dir/scriptname", + "PATH_INFO": "/p1/p2", + "QUERY_STRING": "a=1&b=2", + "CONTENT_LENGTH": "123", + } + req, err := RequestFromMap(env) + if err != nil { + t.Fatalf("RequestFromMap: %v", err) + } + if req.URL == nil { + t.Fatalf("unexpected nil URL") + } + if g, e := req.URL.String(), "http://example.com/dir/scriptname/p1/p2?a=1&b=2"; e != g { + t.Errorf("URL = %q; want %q", g, e) + } +} + +func TestRequestWithoutRemotePort(t *testing.T) { + env := map[string]string{ + "SERVER_PROTOCOL": "HTTP/1.1", + "HTTP_HOST": "example.com", + "REQUEST_METHOD": "GET", + "REQUEST_URI": "/path?a=b", + "CONTENT_LENGTH": "123", + "REMOTE_ADDR": "5.6.7.8", + } + req, err := RequestFromMap(env) + if err != nil { + t.Fatalf("RequestFromMap: %v", err) + } + if e, g := "5.6.7.8:0", req.RemoteAddr; e != g { + t.Errorf("RemoteAddr: got %q; want %q", g, e) + } +} + +func TestResponse(t *testing.T) { + var tests = []struct { + name string + body string + wantCT string + }{ + { + name: "no body", + wantCT: "text/plain; charset=utf-8", + }, + { + name: "html", + body: "test pageThis is a body", + wantCT: "text/html; charset=utf-8", + }, + { + name: "text", + body: strings.Repeat("gopher", 86), + wantCT: "text/plain; charset=utf-8", + }, + { + name: "jpg", + body: "\xFF\xD8\xFF" + strings.Repeat("B", 1024), + wantCT: "image/jpeg", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + resp := response{ + req: httptest.NewRequest("GET", "/", nil), + header: http.Header{}, + bufw: bufio.NewWriter(&buf), + } + n, err := resp.Write([]byte(tt.body)) + if err != nil { + t.Errorf("Write: unexpected %v", err) + } + if want := len(tt.body); n != want { + t.Errorf("reported short Write: got %v want %v", n, want) + } + resp.writeCGIHeader(nil) + resp.Flush() + if got := resp.Header().Get("Content-Type"); got != tt.wantCT { + t.Errorf("wrong content-type: got %q, want %q", got, tt.wantCT) + } + if !bytes.HasSuffix(buf.Bytes(), []byte(tt.body)) { + t.Errorf("body was not correctly written") + } + }) + } +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/cgi/host.go b/vendor/github.com/lesismal/llib/std/net/http/cgi/host.go new file mode 100644 index 0000000..eff67ca --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/cgi/host.go @@ -0,0 +1,408 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements the host side of CGI (being the webserver +// parent process). + +// Package cgi implements CGI (Common Gateway Interface) as specified +// in RFC 3875. +// +// Note that using CGI means starting a new process to handle each +// request, which is typically less efficient than using a +// long-running server. This package is intended primarily for +// compatibility with existing systems. +package cgi + +import ( + "bufio" + "fmt" + "io" + "log" + "net" + "net/http" + "net/textproto" + "os" + "os/exec" + "path/filepath" + "regexp" + "runtime" + "strconv" + "strings" + + "golang.org/x/net/http/httpguts" +) + +var trailingPort = regexp.MustCompile(`:([0-9]+)$`) + +var osDefaultInheritEnv = func() []string { + switch runtime.GOOS { + case "darwin", "ios": + return []string{"DYLD_LIBRARY_PATH"} + case "linux", "freebsd", "netbsd", "openbsd": + return []string{"LD_LIBRARY_PATH"} + case "hpux": + return []string{"LD_LIBRARY_PATH", "SHLIB_PATH"} + case "irix": + return []string{"LD_LIBRARY_PATH", "LD_LIBRARYN32_PATH", "LD_LIBRARY64_PATH"} + case "illumos", "solaris": + return []string{"LD_LIBRARY_PATH", "LD_LIBRARY_PATH_32", "LD_LIBRARY_PATH_64"} + case "windows": + return []string{"SystemRoot", "COMSPEC", "PATHEXT", "WINDIR"} + } + return nil +}() + +// Handler runs an executable in a subprocess with a CGI environment. +type Handler struct { + Path string // path to the CGI executable + Root string // root URI prefix of handler or empty for "/" + + // Dir specifies the CGI executable's working directory. + // If Dir is empty, the base directory of Path is used. + // If Path has no base directory, the current working + // directory is used. + Dir string + + Env []string // extra environment variables to set, if any, as "key=value" + InheritEnv []string // environment variables to inherit from host, as "key" + Logger *log.Logger // optional log for errors or nil to use log.Print + Args []string // optional arguments to pass to child process + Stderr io.Writer // optional stderr for the child process; nil means os.Stderr + + // PathLocationHandler specifies the root http Handler that + // should handle internal redirects when the CGI process + // returns a Location header value starting with a "/", as + // specified in RFC 3875 § 6.3.2. This will likely be + // http.DefaultServeMux. + // + // If nil, a CGI response with a local URI path is instead sent + // back to the client and not redirected internally. + PathLocationHandler http.Handler +} + +func (h *Handler) stderr() io.Writer { + if h.Stderr != nil { + return h.Stderr + } + return os.Stderr +} + +// removeLeadingDuplicates remove leading duplicate in environments. +// It's possible to override environment like following. +// cgi.Handler{ +// ... +// Env: []string{"SCRIPT_FILENAME=foo.php"}, +// } +func removeLeadingDuplicates(env []string) (ret []string) { + for i, e := range env { + found := false + if eq := strings.IndexByte(e, '='); eq != -1 { + keq := e[:eq+1] // "key=" + for _, e2 := range env[i+1:] { + if strings.HasPrefix(e2, keq) { + found = true + break + } + } + } + if !found { + ret = append(ret, e) + } + } + return +} + +func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + root := h.Root + if root == "" { + root = "/" + } + + if len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" { + rw.WriteHeader(http.StatusBadRequest) + rw.Write([]byte("Chunked request bodies are not supported by CGI.")) + return + } + + pathInfo := req.URL.Path + if root != "/" && strings.HasPrefix(pathInfo, root) { + pathInfo = pathInfo[len(root):] + } + + port := "80" + if matches := trailingPort.FindStringSubmatch(req.Host); len(matches) != 0 { + port = matches[1] + } + + env := []string{ + "SERVER_SOFTWARE=go", + "SERVER_NAME=" + req.Host, + "SERVER_PROTOCOL=HTTP/1.1", + "HTTP_HOST=" + req.Host, + "GATEWAY_INTERFACE=CGI/1.1", + "REQUEST_METHOD=" + req.Method, + "QUERY_STRING=" + req.URL.RawQuery, + "REQUEST_URI=" + req.URL.RequestURI(), + "PATH_INFO=" + pathInfo, + "SCRIPT_NAME=" + root, + "SCRIPT_FILENAME=" + h.Path, + "SERVER_PORT=" + port, + } + + if remoteIP, remotePort, err := net.SplitHostPort(req.RemoteAddr); err == nil { + env = append(env, "REMOTE_ADDR="+remoteIP, "REMOTE_HOST="+remoteIP, "REMOTE_PORT="+remotePort) + } else { + // could not parse ip:port, let's use whole RemoteAddr and leave REMOTE_PORT undefined + env = append(env, "REMOTE_ADDR="+req.RemoteAddr, "REMOTE_HOST="+req.RemoteAddr) + } + + if req.TLS != nil { + env = append(env, "HTTPS=on") + } + + for k, v := range req.Header { + k = strings.Map(upperCaseAndUnderscore, k) + if k == "PROXY" { + // See Issue 16405 + continue + } + joinStr := ", " + if k == "COOKIE" { + joinStr = "; " + } + env = append(env, "HTTP_"+k+"="+strings.Join(v, joinStr)) + } + + if req.ContentLength > 0 { + env = append(env, fmt.Sprintf("CONTENT_LENGTH=%d", req.ContentLength)) + } + if ctype := req.Header.Get("Content-Type"); ctype != "" { + env = append(env, "CONTENT_TYPE="+ctype) + } + + envPath := os.Getenv("PATH") + if envPath == "" { + envPath = "/bin:/usr/bin:/usr/ucb:/usr/bsd:/usr/local/bin" + } + env = append(env, "PATH="+envPath) + + for _, e := range h.InheritEnv { + if v := os.Getenv(e); v != "" { + env = append(env, e+"="+v) + } + } + + for _, e := range osDefaultInheritEnv { + if v := os.Getenv(e); v != "" { + env = append(env, e+"="+v) + } + } + + if h.Env != nil { + env = append(env, h.Env...) + } + + env = removeLeadingDuplicates(env) + + var cwd, path string + if h.Dir != "" { + path = h.Path + cwd = h.Dir + } else { + cwd, path = filepath.Split(h.Path) + } + if cwd == "" { + cwd = "." + } + + internalError := func(err error) { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("CGI error: %v", err) + } + + cmd := &exec.Cmd{ + Path: path, + Args: append([]string{h.Path}, h.Args...), + Dir: cwd, + Env: env, + Stderr: h.stderr(), + } + if req.ContentLength != 0 { + cmd.Stdin = req.Body + } + stdoutRead, err := cmd.StdoutPipe() + if err != nil { + internalError(err) + return + } + + err = cmd.Start() + if err != nil { + internalError(err) + return + } + if hook := testHookStartProcess; hook != nil { + hook(cmd.Process) + } + defer cmd.Wait() + defer stdoutRead.Close() + + linebody := bufio.NewReaderSize(stdoutRead, 1024) + headers := make(http.Header) + statusCode := 0 + headerLines := 0 + sawBlankLine := false + for { + line, isPrefix, err := linebody.ReadLine() + if isPrefix { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: long header line from subprocess.") + return + } + if err == io.EOF { + break + } + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: error reading headers: %v", err) + return + } + if len(line) == 0 { + sawBlankLine = true + break + } + headerLines++ + parts := strings.SplitN(string(line), ":", 2) + if len(parts) < 2 { + h.printf("cgi: bogus header line: %s", string(line)) + continue + } + header, val := parts[0], parts[1] + if !httpguts.ValidHeaderFieldName(header) { + h.printf("cgi: invalid header name: %q", header) + continue + } + val = textproto.TrimString(val) + switch { + case header == "Status": + if len(val) < 3 { + h.printf("cgi: bogus status (short): %q", val) + return + } + code, err := strconv.Atoi(val[0:3]) + if err != nil { + h.printf("cgi: bogus status: %q", val) + h.printf("cgi: line was %q", line) + return + } + statusCode = code + default: + headers.Add(header, val) + } + } + if headerLines == 0 || !sawBlankLine { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: no headers") + return + } + + if loc := headers.Get("Location"); loc != "" { + if strings.HasPrefix(loc, "/") && h.PathLocationHandler != nil { + h.handleInternalRedirect(rw, req, loc) + return + } + if statusCode == 0 { + statusCode = http.StatusFound + } + } + + if statusCode == 0 && headers.Get("Content-Type") == "" { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: missing required Content-Type in headers") + return + } + + if statusCode == 0 { + statusCode = http.StatusOK + } + + // Copy headers to rw's headers, after we've decided not to + // go into handleInternalRedirect, which won't want its rw + // headers to have been touched. + for k, vv := range headers { + for _, v := range vv { + rw.Header().Add(k, v) + } + } + + rw.WriteHeader(statusCode) + + _, err = io.Copy(rw, linebody) + if err != nil { + h.printf("cgi: copy error: %v", err) + // And kill the child CGI process so we don't hang on + // the deferred cmd.Wait above if the error was just + // the client (rw) going away. If it was a read error + // (because the child died itself), then the extra + // kill of an already-dead process is harmless (the PID + // won't be reused until the Wait above). + cmd.Process.Kill() + } +} + +func (h *Handler) printf(format string, v ...interface{}) { + if h.Logger != nil { + h.Logger.Printf(format, v...) + } else { + log.Printf(format, v...) + } +} + +func (h *Handler) handleInternalRedirect(rw http.ResponseWriter, req *http.Request, path string) { + url, err := req.URL.Parse(path) + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: error resolving local URI path %q: %v", path, err) + return + } + // TODO: RFC 3875 isn't clear if only GET is supported, but it + // suggests so: "Note that any message-body attached to the + // request (such as for a POST request) may not be available + // to the resource that is the target of the redirect." We + // should do some tests against Apache to see how it handles + // POST, HEAD, etc. Does the internal redirect get the same + // method or just GET? What about incoming headers? + // (e.g. Cookies) Which headers, if any, are copied into the + // second request? + newReq := &http.Request{ + Method: "GET", + URL: url, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: url.Host, + RemoteAddr: req.RemoteAddr, + TLS: req.TLS, + } + h.PathLocationHandler.ServeHTTP(rw, newReq) +} + +func upperCaseAndUnderscore(r rune) rune { + switch { + case r >= 'a' && r <= 'z': + return r - ('a' - 'A') + case r == '-': + return '_' + case r == '=': + // Maybe not part of the CGI 'spec' but would mess up + // the environment in any case, as Go represents the + // environment as a slice of "key=value" strings. + return '_' + } + // TODO: other transformations in spec or practice? + return r +} + +var testHookStartProcess func(*os.Process) // nil except for some tests diff --git a/vendor/github.com/lesismal/llib/std/net/http/cgi/host_test.go b/vendor/github.com/lesismal/llib/std/net/http/cgi/host_test.go new file mode 100644 index 0000000..4fa00be --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/cgi/host_test.go @@ -0,0 +1,578 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests for package cgi + +package cgi + +import ( + "bufio" + "bytes" + "fmt" + "github.com/lesismal/llib/std/net/http/httptest" + "io" + "net" + "net/http" + "os" + "os/exec" + "path/filepath" + "reflect" + "runtime" + "strconv" + "strings" + "testing" + "time" +) + +func newRequest(httpreq string) *http.Request { + buf := bufio.NewReader(strings.NewReader(httpreq)) + req, err := http.ReadRequest(buf) + if err != nil { + panic("cgi: bogus http request in test: " + httpreq) + } + req.RemoteAddr = "1.2.3.4:1234" + return req +} + +func runCgiTest(t *testing.T, h *Handler, + httpreq string, + expectedMap map[string]string, checks ...func(reqInfo map[string]string)) *httptest.ResponseRecorder { + rw := httptest.NewRecorder() + req := newRequest(httpreq) + h.ServeHTTP(rw, req) + runResponseChecks(t, rw, expectedMap, checks...) + return rw +} + +func runResponseChecks(t *testing.T, rw *httptest.ResponseRecorder, + expectedMap map[string]string, checks ...func(reqInfo map[string]string)) { + // Make a map to hold the test map that the CGI returns. + m := make(map[string]string) + m["_body"] = rw.Body.String() + linesRead := 0 +readlines: + for { + line, err := rw.Body.ReadString('\n') + switch { + case err == io.EOF: + break readlines + case err != nil: + t.Fatalf("unexpected error reading from CGI: %v", err) + } + linesRead++ + trimmedLine := strings.TrimRight(line, "\r\n") + split := strings.SplitN(trimmedLine, "=", 2) + if len(split) != 2 { + t.Fatalf("Unexpected %d parts from invalid line number %v: %q; existing map=%v", + len(split), linesRead, line, m) + } + m[split[0]] = split[1] + } + + for key, expected := range expectedMap { + got := m[key] + if key == "cwd" { + // For Windows. golang.org/issue/4645. + fi1, _ := os.Stat(got) + fi2, _ := os.Stat(expected) + if os.SameFile(fi1, fi2) { + got = expected + } + } + if got != expected { + t.Errorf("for key %q got %q; expected %q", key, got, expected) + } + } + for _, check := range checks { + check(m) + } +} + +var cgiTested, cgiWorks bool + +func check(t *testing.T) { + if !cgiTested { + cgiTested = true + cgiWorks = exec.Command("./testdata/test.cgi").Run() == nil + } + if !cgiWorks { + // No Perl on Windows, needed by test.cgi + // TODO: make the child process be Go, not Perl. + t.Skip("Skipping test: test.cgi failed.") + } +} + +func TestCGIBasicGet(t *testing.T) { + check(t) + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{ + "test": "Hello CGI", + "param-a": "b", + "param-foo": "bar", + "env-GATEWAY_INTERFACE": "CGI/1.1", + "env-HTTP_HOST": "example.com", + "env-PATH_INFO": "", + "env-QUERY_STRING": "foo=bar&a=b", + "env-REMOTE_ADDR": "1.2.3.4", + "env-REMOTE_HOST": "1.2.3.4", + "env-REMOTE_PORT": "1234", + "env-REQUEST_METHOD": "GET", + "env-REQUEST_URI": "/test.cgi?foo=bar&a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_NAME": "/test.cgi", + "env-SERVER_NAME": "example.com", + "env-SERVER_PORT": "80", + "env-SERVER_SOFTWARE": "go", + } + replay := runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) + + if expected, got := "text/html", replay.Header().Get("Content-Type"); got != expected { + t.Errorf("got a Content-Type of %q; expected %q", got, expected) + } + if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected { + t.Errorf("got a X-Test-Header of %q; expected %q", got, expected) + } +} + +func TestCGIEnvIPv6(t *testing.T) { + check(t) + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{ + "test": "Hello CGI", + "param-a": "b", + "param-foo": "bar", + "env-GATEWAY_INTERFACE": "CGI/1.1", + "env-HTTP_HOST": "example.com", + "env-PATH_INFO": "", + "env-QUERY_STRING": "foo=bar&a=b", + "env-REMOTE_ADDR": "2000::3000", + "env-REMOTE_HOST": "2000::3000", + "env-REMOTE_PORT": "12345", + "env-REQUEST_METHOD": "GET", + "env-REQUEST_URI": "/test.cgi?foo=bar&a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_NAME": "/test.cgi", + "env-SERVER_NAME": "example.com", + "env-SERVER_PORT": "80", + "env-SERVER_SOFTWARE": "go", + } + + rw := httptest.NewRecorder() + req := newRequest("GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n") + req.RemoteAddr = "[2000::3000]:12345" + h.ServeHTTP(rw, req) + runResponseChecks(t, rw, expectedMap) +} + +func TestCGIBasicGetAbsPath(t *testing.T) { + check(t) + pwd, err := os.Getwd() + if err != nil { + t.Fatalf("getwd error: %v", err) + } + h := &Handler{ + Path: pwd + "/testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{ + "env-REQUEST_URI": "/test.cgi?foo=bar&a=b", + "env-SCRIPT_FILENAME": pwd + "/testdata/test.cgi", + "env-SCRIPT_NAME": "/test.cgi", + } + runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestPathInfo(t *testing.T) { + check(t) + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{ + "param-a": "b", + "env-PATH_INFO": "/extrapath", + "env-QUERY_STRING": "a=b", + "env-REQUEST_URI": "/test.cgi/extrapath?a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_NAME": "/test.cgi", + } + runCgiTest(t, h, "GET /test.cgi/extrapath?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestPathInfoDirRoot(t *testing.T) { + check(t) + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/myscript/", + } + expectedMap := map[string]string{ + "env-PATH_INFO": "bar", + "env-QUERY_STRING": "a=b", + "env-REQUEST_URI": "/myscript/bar?a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_NAME": "/myscript/", + } + runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestDupHeaders(t *testing.T) { + check(t) + h := &Handler{ + Path: "testdata/test.cgi", + } + expectedMap := map[string]string{ + "env-REQUEST_URI": "/myscript/bar?a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-HTTP_COOKIE": "nom=NOM; yum=YUM", + "env-HTTP_X_FOO": "val1, val2", + } + runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\n"+ + "Cookie: nom=NOM\n"+ + "Cookie: yum=YUM\n"+ + "X-Foo: val1\n"+ + "X-Foo: val2\n"+ + "Host: example.com\n\n", + expectedMap) +} + +// Issue 16405: CGI+http.Transport differing uses of HTTP_PROXY. +// Verify we don't set the HTTP_PROXY environment variable. +// Hope nobody was depending on it. It's not a known header, though. +func TestDropProxyHeader(t *testing.T) { + check(t) + h := &Handler{ + Path: "testdata/test.cgi", + } + expectedMap := map[string]string{ + "env-REQUEST_URI": "/myscript/bar?a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-HTTP_X_FOO": "a", + } + runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\n"+ + "X-Foo: a\n"+ + "Proxy: should_be_stripped\n"+ + "Host: example.com\n\n", + expectedMap, + func(reqInfo map[string]string) { + if v, ok := reqInfo["env-HTTP_PROXY"]; ok { + t.Errorf("HTTP_PROXY = %q; should be absent", v) + } + }) +} + +func TestPathInfoNoRoot(t *testing.T) { + check(t) + h := &Handler{ + Path: "testdata/test.cgi", + Root: "", + } + expectedMap := map[string]string{ + "env-PATH_INFO": "/bar", + "env-QUERY_STRING": "a=b", + "env-REQUEST_URI": "/bar?a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_NAME": "/", + } + runCgiTest(t, h, "GET /bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestCGIBasicPost(t *testing.T) { + check(t) + postReq := `POST /test.cgi?a=b HTTP/1.0 +Host: example.com +Content-Type: application/x-www-form-urlencoded +Content-Length: 15 + +postfoo=postbar` + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{ + "test": "Hello CGI", + "param-postfoo": "postbar", + "env-REQUEST_METHOD": "POST", + "env-CONTENT_LENGTH": "15", + "env-REQUEST_URI": "/test.cgi?a=b", + } + runCgiTest(t, h, postReq, expectedMap) +} + +func chunk(s string) string { + return fmt.Sprintf("%x\r\n%s\r\n", len(s), s) +} + +// The CGI spec doesn't allow chunked requests. +func TestCGIPostChunked(t *testing.T) { + check(t) + postReq := `POST /test.cgi?a=b HTTP/1.1 +Host: example.com +Content-Type: application/x-www-form-urlencoded +Transfer-Encoding: chunked + +` + chunk("postfoo") + chunk("=") + chunk("postbar") + chunk("") + + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{} + resp := runCgiTest(t, h, postReq, expectedMap) + if got, expected := resp.Code, http.StatusBadRequest; got != expected { + t.Fatalf("Expected %v response code from chunked request body; got %d", + expected, got) + } +} + +func TestRedirect(t *testing.T) { + check(t) + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + rec := runCgiTest(t, h, "GET /test.cgi?loc=http://foo.com/ HTTP/1.0\nHost: example.com\n\n", nil) + if e, g := 302, rec.Code; e != g { + t.Errorf("expected status code %d; got %d", e, g) + } + if e, g := "http://foo.com/", rec.Header().Get("Location"); e != g { + t.Errorf("expected Location header of %q; got %q", e, g) + } +} + +func TestInternalRedirect(t *testing.T) { + check(t) + baseHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + fmt.Fprintf(rw, "basepath=%s\n", req.URL.Path) + fmt.Fprintf(rw, "remoteaddr=%s\n", req.RemoteAddr) + }) + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + PathLocationHandler: baseHandler, + } + expectedMap := map[string]string{ + "basepath": "/foo", + "remoteaddr": "1.2.3.4:1234", + } + runCgiTest(t, h, "GET /test.cgi?loc=/foo HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +// TestCopyError tests that we kill the process if there's an error copying +// its output. (for example, from the client having gone away) +func TestCopyError(t *testing.T) { + check(t) + if runtime.GOOS == "windows" { + t.Skipf("skipping test on %q", runtime.GOOS) + } + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + ts := httptest.NewServer(h) + defer ts.Close() + + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + req, _ := http.NewRequest("GET", "http://example.com/test.cgi?bigresponse=1", nil) + err = req.Write(conn) + if err != nil { + t.Fatalf("Write: %v", err) + } + + res, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + t.Fatalf("ReadResponse: %v", err) + } + + pidstr := res.Header.Get("X-CGI-Pid") + if pidstr == "" { + t.Fatalf("expected an X-CGI-Pid header in response") + } + pid, err := strconv.Atoi(pidstr) + if err != nil { + t.Fatalf("invalid X-CGI-Pid value") + } + + var buf [5000]byte + n, err := io.ReadFull(res.Body, buf[:]) + if err != nil { + t.Fatalf("ReadFull: %d bytes, %v", n, err) + } + + childRunning := func() bool { + return isProcessRunning(pid) + } + + if !childRunning() { + t.Fatalf("pre-conn.Close, expected child to be running") + } + conn.Close() + + tries := 0 + for tries < 25 && childRunning() { + time.Sleep(50 * time.Millisecond * time.Duration(tries)) + tries++ + } + if childRunning() { + t.Fatalf("post-conn.Close, expected child to be gone") + } +} + +func TestDirUnix(t *testing.T) { + check(t) + if runtime.GOOS == "windows" { + t.Skipf("skipping test on %q", runtime.GOOS) + } + cwd, _ := os.Getwd() + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + Dir: cwd, + } + expectedMap := map[string]string{ + "cwd": cwd, + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) + + cwd, _ = os.Getwd() + cwd = filepath.Join(cwd, "testdata") + h = &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap = map[string]string{ + "cwd": cwd, + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func findPerl(t *testing.T) string { + t.Helper() + perl, err := exec.LookPath("perl") + if err != nil { + t.Skip("Skipping test: perl not found.") + } + perl, _ = filepath.Abs(perl) + + cmd := exec.Command(perl, "-e", "print 123") + cmd.Env = []string{"PATH=/garbage"} + out, err := cmd.Output() + if err != nil || string(out) != "123" { + t.Skipf("Skipping test: %s is not functional", perl) + } + return perl +} + +func TestDirWindows(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("Skipping windows specific test.") + } + + cgifile, _ := filepath.Abs("testdata/test.cgi") + + perl := findPerl(t) + + cwd, _ := os.Getwd() + h := &Handler{ + Path: perl, + Root: "/test.cgi", + Dir: cwd, + Args: []string{cgifile}, + Env: []string{"SCRIPT_FILENAME=" + cgifile}, + } + expectedMap := map[string]string{ + "cwd": cwd, + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) + + // If not specify Dir on windows, working directory should be + // base directory of perl. + cwd, _ = filepath.Split(perl) + if cwd != "" && cwd[len(cwd)-1] == filepath.Separator { + cwd = cwd[:len(cwd)-1] + } + h = &Handler{ + Path: perl, + Root: "/test.cgi", + Args: []string{cgifile}, + Env: []string{"SCRIPT_FILENAME=" + cgifile}, + } + expectedMap = map[string]string{ + "cwd": cwd, + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestEnvOverride(t *testing.T) { + check(t) + cgifile, _ := filepath.Abs("testdata/test.cgi") + + perl := findPerl(t) + + cwd, _ := os.Getwd() + h := &Handler{ + Path: perl, + Root: "/test.cgi", + Dir: cwd, + Args: []string{cgifile}, + Env: []string{ + "SCRIPT_FILENAME=" + cgifile, + "REQUEST_URI=/foo/bar", + "PATH=/wibble"}, + } + expectedMap := map[string]string{ + "cwd": cwd, + "env-SCRIPT_FILENAME": cgifile, + "env-REQUEST_URI": "/foo/bar", + "env-PATH": "/wibble", + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestHandlerStderr(t *testing.T) { + check(t) + var stderr bytes.Buffer + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + Stderr: &stderr, + } + + rw := httptest.NewRecorder() + req := newRequest("GET /test.cgi?writestderr=1 HTTP/1.0\nHost: example.com\n\n") + h.ServeHTTP(rw, req) + if got, want := stderr.String(), "Hello, stderr!\n"; got != want { + t.Errorf("Stderr = %q; want %q", got, want) + } +} + +func TestRemoveLeadingDuplicates(t *testing.T) { + tests := []struct { + env []string + want []string + }{ + { + env: []string{"a=b", "b=c", "a=b2"}, + want: []string{"b=c", "a=b2"}, + }, + { + env: []string{"a=b", "b=c", "d", "e=f"}, + want: []string{"a=b", "b=c", "d", "e=f"}, + }, + } + for _, tt := range tests { + got := removeLeadingDuplicates(tt.env) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("removeLeadingDuplicates(%q) = %q; want %q", tt.env, got, tt.want) + } + } +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/cgi/integration_test.go b/vendor/github.com/lesismal/llib/std/net/http/cgi/integration_test.go new file mode 100644 index 0000000..aa8e21c --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/cgi/integration_test.go @@ -0,0 +1,295 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests a Go CGI program running under a Go CGI host process. +// Further, the two programs are the same binary, just checking +// their environment to figure out what mode to run in. + +package cgi + +import ( + "bytes" + "errors" + "fmt" + "github.com/lesismal/llib/std/net/http/httptest" + "internal/testenv" + "io" + "net/http" + "net/url" + "os" + "strings" + "testing" + "time" +) + +// This test is a CGI host (testing host.go) that runs its own binary +// as a child process testing the other half of CGI (child.go). +func TestHostingOurselves(t *testing.T) { + testenv.MustHaveExec(t) + + h := &Handler{ + Path: os.Args[0], + Root: "/test.go", + Args: []string{"-test.run=TestBeChildCGIProcess"}, + } + expectedMap := map[string]string{ + "test": "Hello CGI-in-CGI", + "param-a": "b", + "param-foo": "bar", + "env-GATEWAY_INTERFACE": "CGI/1.1", + "env-HTTP_HOST": "example.com", + "env-PATH_INFO": "", + "env-QUERY_STRING": "foo=bar&a=b", + "env-REMOTE_ADDR": "1.2.3.4", + "env-REMOTE_HOST": "1.2.3.4", + "env-REMOTE_PORT": "1234", + "env-REQUEST_METHOD": "GET", + "env-REQUEST_URI": "/test.go?foo=bar&a=b", + "env-SCRIPT_FILENAME": os.Args[0], + "env-SCRIPT_NAME": "/test.go", + "env-SERVER_NAME": "example.com", + "env-SERVER_PORT": "80", + "env-SERVER_SOFTWARE": "go", + } + replay := runCgiTest(t, h, "GET /test.go?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) + + if expected, got := "text/plain; charset=utf-8", replay.Header().Get("Content-Type"); got != expected { + t.Errorf("got a Content-Type of %q; expected %q", got, expected) + } + if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected { + t.Errorf("got a X-Test-Header of %q; expected %q", got, expected) + } +} + +type customWriterRecorder struct { + w io.Writer + *httptest.ResponseRecorder +} + +func (r *customWriterRecorder) Write(p []byte) (n int, err error) { + return r.w.Write(p) +} + +type limitWriter struct { + w io.Writer + n int +} + +func (w *limitWriter) Write(p []byte) (n int, err error) { + if len(p) > w.n { + p = p[:w.n] + } + if len(p) > 0 { + n, err = w.w.Write(p) + w.n -= n + } + if w.n == 0 { + err = errors.New("past write limit") + } + return +} + +// If there's an error copying the child's output to the parent, test +// that we kill the child. +func TestKillChildAfterCopyError(t *testing.T) { + testenv.MustHaveExec(t) + + defer func() { testHookStartProcess = nil }() + proc := make(chan *os.Process, 1) + testHookStartProcess = func(p *os.Process) { + proc <- p + } + + h := &Handler{ + Path: os.Args[0], + Root: "/test.go", + Args: []string{"-test.run=TestBeChildCGIProcess"}, + } + req, _ := http.NewRequest("GET", "http://example.com/test.cgi?write-forever=1", nil) + rec := httptest.NewRecorder() + var out bytes.Buffer + const writeLen = 50 << 10 + rw := &customWriterRecorder{&limitWriter{&out, writeLen}, rec} + + donec := make(chan bool, 1) + go func() { + h.ServeHTTP(rw, req) + donec <- true + }() + + select { + case <-donec: + if out.Len() != writeLen || out.Bytes()[0] != 'a' { + t.Errorf("unexpected output: %q", out.Bytes()) + } + case <-time.After(5 * time.Second): + t.Errorf("timeout. ServeHTTP hung and didn't kill the child process?") + select { + case p := <-proc: + p.Kill() + t.Logf("killed process") + default: + t.Logf("didn't kill process") + } + } +} + +// Test that a child handler writing only headers works. +// golang.org/issue/7196 +func TestChildOnlyHeaders(t *testing.T) { + testenv.MustHaveExec(t) + + h := &Handler{ + Path: os.Args[0], + Root: "/test.go", + Args: []string{"-test.run=TestBeChildCGIProcess"}, + } + expectedMap := map[string]string{ + "_body": "", + } + replay := runCgiTest(t, h, "GET /test.go?no-body=1 HTTP/1.0\nHost: example.com\n\n", expectedMap) + if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected { + t.Errorf("got a X-Test-Header of %q; expected %q", got, expected) + } +} + +// Test that a child handler does not receive a nil Request Body. +// golang.org/issue/39190 +func TestNilRequestBody(t *testing.T) { + testenv.MustHaveExec(t) + + h := &Handler{ + Path: os.Args[0], + Root: "/test.go", + Args: []string{"-test.run=TestBeChildCGIProcess"}, + } + expectedMap := map[string]string{ + "nil-request-body": "false", + } + _ = runCgiTest(t, h, "POST /test.go?nil-request-body=1 HTTP/1.0\nHost: example.com\n\n", expectedMap) + _ = runCgiTest(t, h, "POST /test.go?nil-request-body=1 HTTP/1.0\nHost: example.com\nContent-Length: 0\n\n", expectedMap) +} + +func TestChildContentType(t *testing.T) { + testenv.MustHaveExec(t) + + h := &Handler{ + Path: os.Args[0], + Root: "/test.go", + Args: []string{"-test.run=TestBeChildCGIProcess"}, + } + var tests = []struct { + name string + body string + wantCT string + }{ + { + name: "no body", + wantCT: "text/plain; charset=utf-8", + }, + { + name: "html", + body: "test pageThis is a body", + wantCT: "text/html; charset=utf-8", + }, + { + name: "text", + body: strings.Repeat("gopher", 86), + wantCT: "text/plain; charset=utf-8", + }, + { + name: "jpg", + body: "\xFF\xD8\xFF" + strings.Repeat("B", 1024), + wantCT: "image/jpeg", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expectedMap := map[string]string{"_body": tt.body} + req := fmt.Sprintf("GET /test.go?exact-body=%s HTTP/1.0\nHost: example.com\n\n", url.QueryEscape(tt.body)) + replay := runCgiTest(t, h, req, expectedMap) + if got := replay.Header().Get("Content-Type"); got != tt.wantCT { + t.Errorf("got a Content-Type of %q; expected it to start with %q", got, tt.wantCT) + } + }) + } +} + +// golang.org/issue/7198 +func Test500WithNoHeaders(t *testing.T) { want500Test(t, "/immediate-disconnect") } +func Test500WithNoContentType(t *testing.T) { want500Test(t, "/no-content-type") } +func Test500WithEmptyHeaders(t *testing.T) { want500Test(t, "/empty-headers") } + +func want500Test(t *testing.T, path string) { + h := &Handler{ + Path: os.Args[0], + Root: "/test.go", + Args: []string{"-test.run=TestBeChildCGIProcess"}, + } + expectedMap := map[string]string{ + "_body": "", + } + replay := runCgiTest(t, h, "GET "+path+" HTTP/1.0\nHost: example.com\n\n", expectedMap) + if replay.Code != 500 { + t.Errorf("Got code %d; want 500", replay.Code) + } +} + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +// Note: not actually a test. +func TestBeChildCGIProcess(t *testing.T) { + if os.Getenv("REQUEST_METHOD") == "" { + // Not in a CGI environment; skipping test. + return + } + switch os.Getenv("REQUEST_URI") { + case "/immediate-disconnect": + os.Exit(0) + case "/no-content-type": + fmt.Printf("Content-Length: 6\n\nHello\n") + os.Exit(0) + case "/empty-headers": + fmt.Printf("\nHello") + os.Exit(0) + } + Serve(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.FormValue("nil-request-body") == "1" { + fmt.Fprintf(rw, "nil-request-body=%v\n", req.Body == nil) + return + } + rw.Header().Set("X-Test-Header", "X-Test-Value") + req.ParseForm() + if req.FormValue("no-body") == "1" { + return + } + if eb, ok := req.Form["exact-body"]; ok { + io.WriteString(rw, eb[0]) + return + } + if req.FormValue("write-forever") == "1" { + io.Copy(rw, neverEnding('a')) + for { + time.Sleep(5 * time.Second) // hang forever, until killed + } + } + fmt.Fprintf(rw, "test=Hello CGI-in-CGI\n") + for k, vv := range req.Form { + for _, v := range vv { + fmt.Fprintf(rw, "param-%s=%s\n", k, v) + } + } + for _, kv := range os.Environ() { + fmt.Fprintf(rw, "env-%s\n", kv) + } + })) + os.Exit(0) +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/cgi/plan9_test.go b/vendor/github.com/lesismal/llib/std/net/http/cgi/plan9_test.go new file mode 100644 index 0000000..cc20fe0 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/cgi/plan9_test.go @@ -0,0 +1,17 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build plan9 + +package cgi + +import ( + "os" + "strconv" +) + +func isProcessRunning(pid int) bool { + _, err := os.Stat("/proc/" + strconv.Itoa(pid)) + return err == nil +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/cgi/posix_test.go b/vendor/github.com/lesismal/llib/std/net/http/cgi/posix_test.go new file mode 100644 index 0000000..9396ce0 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/cgi/posix_test.go @@ -0,0 +1,20 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !plan9 + +package cgi + +import ( + "os" + "syscall" +) + +func isProcessRunning(pid int) bool { + p, err := os.FindProcess(pid) + if err != nil { + return false + } + return p.Signal(syscall.Signal(0)) == nil +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/cgi/testdata/test.cgi b/vendor/github.com/lesismal/llib/std/net/http/cgi/testdata/test.cgi new file mode 100644 index 0000000..667fce2 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/cgi/testdata/test.cgi @@ -0,0 +1,95 @@ +#!/usr/bin/perl +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. +# +# Test script run as a child process under cgi_test.go + +use strict; +use Cwd; + +binmode STDOUT; + +my $q = MiniCGI->new; +my $params = $q->Vars; + +if ($params->{"loc"}) { + print "Location: $params->{loc}\r\n\r\n"; + exit(0); +} + +print "Content-Type: text/html\r\n"; +print "X-CGI-Pid: $$\r\n"; +print "X-Test-Header: X-Test-Value\r\n"; +print "\r\n"; + +if ($params->{"writestderr"}) { + print STDERR "Hello, stderr!\n"; +} + +if ($params->{"bigresponse"}) { + # 17 MB, for OS X: golang.org/issue/4958 + for (1..(17 * 1024)) { + print "A" x 1024, "\r\n"; + } + exit 0; +} + +print "test=Hello CGI\r\n"; + +foreach my $k (sort keys %$params) { + print "param-$k=$params->{$k}\r\n"; +} + +foreach my $k (sort keys %ENV) { + my $clean_env = $ENV{$k}; + $clean_env =~ s/[\n\r]//g; + print "env-$k=$clean_env\r\n"; +} + +# NOTE: msys perl returns /c/go/src/... not C:\go\.... +my $dir = getcwd(); +if ($^O eq 'MSWin32' || $^O eq 'msys' || $^O eq 'cygwin') { + if ($dir =~ /^.:/) { + $dir =~ s!/!\\!g; + } else { + my $cmd = $ENV{'COMSPEC'} || 'c:\\windows\\system32\\cmd.exe'; + $cmd =~ s!\\!/!g; + $dir = `$cmd /c cd`; + chomp $dir; + } +} +print "cwd=$dir\r\n"; + +# A minimal version of CGI.pm, for people without the perl-modules +# package installed. (CGI.pm used to be part of the Perl core, but +# some distros now bundle perl-base and perl-modules separately...) +package MiniCGI; + +sub new { + my $class = shift; + return bless {}, $class; +} + +sub Vars { + my $self = shift; + my $pairs; + if ($ENV{CONTENT_LENGTH}) { + $pairs = do { local $/; }; + } else { + $pairs = $ENV{QUERY_STRING}; + } + my $vars = {}; + foreach my $kv (split(/&/, $pairs)) { + my ($k, $v) = split(/=/, $kv, 2); + $vars->{_urldecode($k)} = _urldecode($v); + } + return $vars; +} + +sub _urldecode { + my $v = shift; + $v =~ tr/+/ /; + $v =~ s/%([a-fA-F0-9][a-fA-F0-9])/pack("C", hex($1))/eg; + return $v; +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/client.go b/vendor/github.com/lesismal/llib/std/net/http/client.go new file mode 100644 index 0000000..88e2028 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/client.go @@ -0,0 +1,1009 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP client. See RFC 7230 through 7235. +// +// This is the high-level Client interface. +// The low-level implementation is in transport.go. + +package http + +import ( + "context" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "io" + "log" + "net/url" + "reflect" + "sort" + "strings" + "sync" + "time" +) + +// A Client is an HTTP client. Its zero value (DefaultClient) is a +// usable client that uses DefaultTransport. +// +// The Client's Transport typically has internal state (cached TCP +// connections), so Clients should be reused instead of created as +// needed. Clients are safe for concurrent use by multiple goroutines. +// +// A Client is higher-level than a RoundTripper (such as Transport) +// and additionally handles HTTP details such as cookies and +// redirects. +// +// When following redirects, the Client will forward all headers set on the +// initial Request except: +// +// • when forwarding sensitive headers like "Authorization", +// "WWW-Authenticate", and "Cookie" to untrusted targets. +// These headers will be ignored when following a redirect to a domain +// that is not a subdomain match or exact match of the initial domain. +// For example, a redirect from "foo.com" to either "foo.com" or "sub.foo.com" +// will forward the sensitive headers, but a redirect to "bar.com" will not. +// +// • when forwarding the "Cookie" header with a non-nil cookie Jar. +// Since each redirect may mutate the state of the cookie jar, +// a redirect may possibly alter a cookie set in the initial request. +// When forwarding the "Cookie" header, any mutated cookies will be omitted, +// with the expectation that the Jar will insert those mutated cookies +// with the updated values (assuming the origin matches). +// If Jar is nil, the initial cookies are forwarded without change. +// +type Client struct { + // Transport specifies the mechanism by which individual + // HTTP requests are made. + // If nil, DefaultTransport is used. + Transport RoundTripper + + // CheckRedirect specifies the policy for handling redirects. + // If CheckRedirect is not nil, the client calls it before + // following an HTTP redirect. The arguments req and via are + // the upcoming request and the requests made already, oldest + // first. If CheckRedirect returns an error, the Client's Get + // method returns both the previous Response (with its Body + // closed) and CheckRedirect's error (wrapped in a url.Error) + // instead of issuing the Request req. + // As a special case, if CheckRedirect returns ErrUseLastResponse, + // then the most recent response is returned with its body + // unclosed, along with a nil error. + // + // If CheckRedirect is nil, the Client uses its default policy, + // which is to stop after 10 consecutive requests. + CheckRedirect func(req *Request, via []*Request) error + + // Jar specifies the cookie jar. + // + // The Jar is used to insert relevant cookies into every + // outbound Request and is updated with the cookie values + // of every inbound Response. The Jar is consulted for every + // redirect that the Client follows. + // + // If Jar is nil, cookies are only sent if they are explicitly + // set on the Request. + Jar CookieJar + + // Timeout specifies a time limit for requests made by this + // Client. The timeout includes connection time, any + // redirects, and reading the response body. The timer remains + // running after Get, Head, Post, or Do return and will + // interrupt reading of the Response.Body. + // + // A Timeout of zero means no timeout. + // + // The Client cancels requests to the underlying Transport + // as if the Request's Context ended. + // + // For compatibility, the Client will also use the deprecated + // CancelRequest method on Transport if found. New + // RoundTripper implementations should use the Request's Context + // for cancellation instead of implementing CancelRequest. + Timeout time.Duration +} + +// DefaultClient is the default Client and is used by Get, Head, and Post. +var DefaultClient = &Client{} + +// RoundTripper is an interface representing the ability to execute a +// single HTTP transaction, obtaining the Response for a given Request. +// +// A RoundTripper must be safe for concurrent use by multiple +// goroutines. +type RoundTripper interface { + // RoundTrip executes a single HTTP transaction, returning + // a Response for the provided Request. + // + // RoundTrip should not attempt to interpret the response. In + // particular, RoundTrip must return err == nil if it obtained + // a response, regardless of the response's HTTP status code. + // A non-nil err should be reserved for failure to obtain a + // response. Similarly, RoundTrip should not attempt to + // handle higher-level protocol details such as redirects, + // authentication, or cookies. + // + // RoundTrip should not modify the request, except for + // consuming and closing the Request's Body. RoundTrip may + // read fields of the request in a separate goroutine. Callers + // should not mutate or reuse the request until the Response's + // Body has been closed. + // + // RoundTrip must always close the body, including on errors, + // but depending on the implementation may do so in a separate + // goroutine even after RoundTrip returns. This means that + // callers wanting to reuse the body for subsequent requests + // must arrange to wait for the Close call before doing so. + // + // The Request's URL and Header fields must be initialized. + RoundTrip(*Request) (*Response, error) +} + +// refererForURL returns a referer without any authentication info or +// an empty string if lastReq scheme is https and newReq scheme is http. +func refererForURL(lastReq, newReq *url.URL) string { + // https://tools.ietf.org/html/rfc7231#section-5.5.2 + // "Clients SHOULD NOT include a Referer header field in a + // (non-secure) HTTP request if the referring page was + // transferred with a secure protocol." + if lastReq.Scheme == "https" && newReq.Scheme == "http" { + return "" + } + referer := lastReq.String() + if lastReq.User != nil { + // This is not very efficient, but is the best we can + // do without: + // - introducing a new method on URL + // - creating a race condition + // - copying the URL struct manually, which would cause + // maintenance problems down the line + auth := lastReq.User.String() + "@" + referer = strings.Replace(referer, auth, "", 1) + } + return referer +} + +// didTimeout is non-nil only if err != nil. +func (c *Client) send(req *Request, deadline time.Time) (resp *Response, didTimeout func() bool, err error) { + if c.Jar != nil { + for _, cookie := range c.Jar.Cookies(req.URL) { + req.AddCookie(cookie) + } + } + resp, didTimeout, err = send(req, c.transport(), deadline) + if err != nil { + return nil, didTimeout, err + } + if c.Jar != nil { + if rc := resp.Cookies(); len(rc) > 0 { + c.Jar.SetCookies(req.URL, rc) + } + } + return resp, nil, nil +} + +func (c *Client) deadline() time.Time { + if c.Timeout > 0 { + return time.Now().Add(c.Timeout) + } + return time.Time{} +} + +func (c *Client) transport() RoundTripper { + if c.Transport != nil { + return c.Transport + } + return DefaultTransport +} + +// send issues an HTTP request. +// Caller should close resp.Body when done reading from it. +func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, didTimeout func() bool, err error) { + req := ireq // req is either the original request, or a modified fork + + if rt == nil { + req.closeBody() + return nil, alwaysFalse, errors.New("http: no Client.Transport or DefaultTransport") + } + + if req.URL == nil { + req.closeBody() + return nil, alwaysFalse, errors.New("http: nil Request.URL") + } + + if req.RequestURI != "" { + req.closeBody() + return nil, alwaysFalse, errors.New("http: Request.RequestURI can't be set in client requests") + } + + // forkReq forks req into a shallow clone of ireq the first + // time it's called. + forkReq := func() { + if ireq == req { + req = new(Request) + *req = *ireq // shallow clone + } + } + + // Most the callers of send (Get, Post, et al) don't need + // Headers, leaving it uninitialized. We guarantee to the + // Transport that this has been initialized, though. + if req.Header == nil { + forkReq() + req.Header = make(Header) + } + + if u := req.URL.User; u != nil && req.Header.Get("Authorization") == "" { + username := u.Username() + password, _ := u.Password() + forkReq() + req.Header = cloneOrMakeHeader(ireq.Header) + req.Header.Set("Authorization", "Basic "+basicAuth(username, password)) + } + + if !deadline.IsZero() { + forkReq() + } + stopTimer, didTimeout := setRequestCancel(req, rt, deadline) + + resp, err = rt.RoundTrip(req) + if err != nil { + stopTimer() + if resp != nil { + log.Printf("RoundTripper returned a response & error; ignoring response") + } + if tlsErr, ok := err.(tls.RecordHeaderError); ok { + // If we get a bad TLS record header, check to see if the + // response looks like HTTP and give a more helpful error. + // See golang.org/issue/11111. + if string(tlsErr.RecordHeader[:]) == "HTTP/" { + err = errors.New("http: server gave HTTP response to HTTPS client") + } + } + return nil, didTimeout, err + } + if resp == nil { + return nil, didTimeout, fmt.Errorf("http: RoundTripper implementation (%T) returned a nil *Response with a nil error", rt) + } + if resp.Body == nil { + // The documentation on the Body field says “The http Client and Transport + // guarantee that Body is always non-nil, even on responses without a body + // or responses with a zero-length body.” Unfortunately, we didn't document + // that same constraint for arbitrary RoundTripper implementations, and + // RoundTripper implementations in the wild (mostly in tests) assume that + // they can use a nil Body to mean an empty one (similar to Request.Body). + // (See https://golang.org/issue/38095.) + // + // If the ContentLength allows the Body to be empty, fill in an empty one + // here to ensure that it is non-nil. + if resp.ContentLength > 0 && req.Method != "HEAD" { + return nil, didTimeout, fmt.Errorf("http: RoundTripper implementation (%T) returned a *Response with content length %d but a nil Body", rt, resp.ContentLength) + } + resp.Body = io.NopCloser(strings.NewReader("")) + } + if !deadline.IsZero() { + resp.Body = &cancelTimerBody{ + stop: stopTimer, + rc: resp.Body, + reqDidTimeout: didTimeout, + } + } + return resp, nil, nil +} + +// timeBeforeContextDeadline reports whether the non-zero Time t is +// before ctx's deadline, if any. If ctx does not have a deadline, it +// always reports true (the deadline is considered infinite). +func timeBeforeContextDeadline(t time.Time, ctx context.Context) bool { + d, ok := ctx.Deadline() + if !ok { + return true + } + return t.Before(d) +} + +// knownRoundTripperImpl reports whether rt is a RoundTripper that's +// maintained by the Go team and known to implement the latest +// optional semantics (notably contexts). The Request is used +// to check whether this particular request is using an alternate protocol, +// in which case we need to check the RoundTripper for that protocol. +func knownRoundTripperImpl(rt RoundTripper, req *Request) bool { + switch t := rt.(type) { + case *Transport: + if altRT := t.alternateRoundTripper(req); altRT != nil { + return knownRoundTripperImpl(altRT, req) + } + return true + case *http2Transport, http2noDialH2RoundTripper: + return true + } + // There's a very minor chance of a false positive with this. + // Instead of detecting our golang.org/x/net/http2.Transport, + // it might detect a Transport type in a different http2 + // package. But I know of none, and the only problem would be + // some temporarily leaked goroutines if the transport didn't + // support contexts. So this is a good enough heuristic: + if reflect.TypeOf(rt).String() == "*http2.Transport" { + return true + } + return false +} + +// setRequestCancel sets req.Cancel and adds a deadline context to req +// if deadline is non-zero. The RoundTripper's type is used to +// determine whether the legacy CancelRequest behavior should be used. +// +// As background, there are three ways to cancel a request: +// First was Transport.CancelRequest. (deprecated) +// Second was Request.Cancel. +// Third was Request.Context. +// This function populates the second and third, and uses the first if it really needs to. +func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), didTimeout func() bool) { + if deadline.IsZero() { + return nop, alwaysFalse + } + knownTransport := knownRoundTripperImpl(rt, req) + oldCtx := req.Context() + + if req.Cancel == nil && knownTransport { + // If they already had a Request.Context that's + // expiring sooner, do nothing: + if !timeBeforeContextDeadline(deadline, oldCtx) { + return nop, alwaysFalse + } + + var cancelCtx func() + req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline) + return cancelCtx, func() bool { return time.Now().After(deadline) } + } + initialReqCancel := req.Cancel // the user's original Request.Cancel, if any + + var cancelCtx func() + if oldCtx := req.Context(); timeBeforeContextDeadline(deadline, oldCtx) { + req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline) + } + + cancel := make(chan struct{}) + req.Cancel = cancel + + doCancel := func() { + // The second way in the func comment above: + close(cancel) + // The first way, used only for RoundTripper + // implementations written before Go 1.5 or Go 1.6. + type canceler interface{ CancelRequest(*Request) } + if v, ok := rt.(canceler); ok { + v.CancelRequest(req) + } + } + + stopTimerCh := make(chan struct{}) + var once sync.Once + stopTimer = func() { + once.Do(func() { + close(stopTimerCh) + if cancelCtx != nil { + cancelCtx() + } + }) + } + + timer := time.NewTimer(time.Until(deadline)) + var timedOut atomicBool + + go func() { + select { + case <-initialReqCancel: + doCancel() + timer.Stop() + case <-timer.C: + timedOut.setTrue() + doCancel() + case <-stopTimerCh: + timer.Stop() + } + }() + + return stopTimer, timedOut.isSet +} + +// See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt +// "To receive authorization, the client sends the userid and password, +// separated by a single colon (":") character, within a base64 +// encoded string in the credentials." +// It is not meant to be urlencoded. +func basicAuth(username, password string) string { + auth := username + ":" + password + return base64.StdEncoding.EncodeToString([]byte(auth)) +} + +// Get issues a GET to the specified URL. If the response is one of +// the following redirect codes, Get follows the redirect, up to a +// maximum of 10 redirects: +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) +// 308 (Permanent Redirect) +// +// An error is returned if there were too many redirects or if there +// was an HTTP protocol error. A non-2xx response doesn't cause an +// error. Any returned error will be of type *url.Error. The url.Error +// value's Timeout method will report true if request timed out or was +// canceled. +// +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. +// +// Get is a wrapper around DefaultClient.Get. +// +// To make a request with custom headers, use NewRequest and +// DefaultClient.Do. +func Get(url string) (resp *Response, err error) { + return DefaultClient.Get(url) +} + +// Get issues a GET to the specified URL. If the response is one of the +// following redirect codes, Get follows the redirect after calling the +// Client's CheckRedirect function: +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) +// 308 (Permanent Redirect) +// +// An error is returned if the Client's CheckRedirect function fails +// or if there was an HTTP protocol error. A non-2xx response doesn't +// cause an error. Any returned error will be of type *url.Error. The +// url.Error value's Timeout method will report true if the request +// timed out. +// +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. +// +// To make a request with custom headers, use NewRequest and Client.Do. +func (c *Client) Get(url string) (resp *Response, err error) { + req, err := NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + return c.Do(req) +} + +func alwaysFalse() bool { return false } + +// ErrUseLastResponse can be returned by Client.CheckRedirect hooks to +// control how redirects are processed. If returned, the next request +// is not sent and the most recent response is returned with its body +// unclosed. +var ErrUseLastResponse = errors.New("net/http: use last response") + +// checkRedirect calls either the user's configured CheckRedirect +// function, or the default. +func (c *Client) checkRedirect(req *Request, via []*Request) error { + fn := c.CheckRedirect + if fn == nil { + fn = defaultCheckRedirect + } + return fn(req, via) +} + +// redirectBehavior describes what should happen when the +// client encounters a 3xx status code from the server +func redirectBehavior(reqMethod string, resp *Response, ireq *Request) (redirectMethod string, shouldRedirect, includeBody bool) { + switch resp.StatusCode { + case 301, 302, 303: + redirectMethod = reqMethod + shouldRedirect = true + includeBody = false + + // RFC 2616 allowed automatic redirection only with GET and + // HEAD requests. RFC 7231 lifts this restriction, but we still + // restrict other methods to GET to maintain compatibility. + // See Issue 18570. + if reqMethod != "GET" && reqMethod != "HEAD" { + redirectMethod = "GET" + } + case 307, 308: + redirectMethod = reqMethod + shouldRedirect = true + includeBody = true + + // Treat 307 and 308 specially, since they're new in + // Go 1.8, and they also require re-sending the request body. + if resp.Header.Get("Location") == "" { + // 308s have been observed in the wild being served + // without Location headers. Since Go 1.7 and earlier + // didn't follow these codes, just stop here instead + // of returning an error. + // See Issue 17773. + shouldRedirect = false + break + } + if ireq.GetBody == nil && ireq.outgoingLength() != 0 { + // We had a request body, and 307/308 require + // re-sending it, but GetBody is not defined. So just + // return this response to the user instead of an + // error, like we did in Go 1.7 and earlier. + shouldRedirect = false + } + } + return redirectMethod, shouldRedirect, includeBody +} + +// urlErrorOp returns the (*url.Error).Op value to use for the +// provided (*Request).Method value. +func urlErrorOp(method string) string { + if method == "" { + return "Get" + } + return method[:1] + strings.ToLower(method[1:]) +} + +// Do sends an HTTP request and returns an HTTP response, following +// policy (such as redirects, cookies, auth) as configured on the +// client. +// +// An error is returned if caused by client policy (such as +// CheckRedirect), or failure to speak HTTP (such as a network +// connectivity problem). A non-2xx status code doesn't cause an +// error. +// +// If the returned error is nil, the Response will contain a non-nil +// Body which the user is expected to close. If the Body is not both +// read to EOF and closed, the Client's underlying RoundTripper +// (typically Transport) may not be able to re-use a persistent TCP +// connection to the server for a subsequent "keep-alive" request. +// +// The request Body, if non-nil, will be closed by the underlying +// Transport, even on errors. +// +// On error, any Response can be ignored. A non-nil Response with a +// non-nil error only occurs when CheckRedirect fails, and even then +// the returned Response.Body is already closed. +// +// Generally Get, Post, or PostForm will be used instead of Do. +// +// If the server replies with a redirect, the Client first uses the +// CheckRedirect function to determine whether the redirect should be +// followed. If permitted, a 301, 302, or 303 redirect causes +// subsequent requests to use HTTP method GET +// (or HEAD if the original request was HEAD), with no body. +// A 307 or 308 redirect preserves the original HTTP method and body, +// provided that the Request.GetBody function is defined. +// The NewRequest function automatically sets GetBody for common +// standard library body types. +// +// Any returned error will be of type *url.Error. The url.Error +// value's Timeout method will report true if request timed out or was +// canceled. +func (c *Client) Do(req *Request) (*Response, error) { + return c.do(req) +} + +var testHookClientDoResult func(retres *Response, reterr error) + +func (c *Client) do(req *Request) (retres *Response, reterr error) { + if testHookClientDoResult != nil { + defer func() { testHookClientDoResult(retres, reterr) }() + } + if req.URL == nil { + req.closeBody() + return nil, &url.Error{ + Op: urlErrorOp(req.Method), + Err: errors.New("http: nil Request.URL"), + } + } + + var ( + deadline = c.deadline() + reqs []*Request + resp *Response + copyHeaders = c.makeHeadersCopier(req) + reqBodyClosed = false // have we closed the current req.Body? + + // Redirect behavior: + redirectMethod string + includeBody bool + ) + uerr := func(err error) error { + // the body may have been closed already by c.send() + if !reqBodyClosed { + req.closeBody() + } + var urlStr string + if resp != nil && resp.Request != nil { + urlStr = stripPassword(resp.Request.URL) + } else { + urlStr = stripPassword(req.URL) + } + return &url.Error{ + Op: urlErrorOp(reqs[0].Method), + URL: urlStr, + Err: err, + } + } + for { + // For all but the first request, create the next + // request hop and replace req. + if len(reqs) > 0 { + loc := resp.Header.Get("Location") + if loc == "" { + resp.closeBody() + return nil, uerr(fmt.Errorf("%d response missing Location header", resp.StatusCode)) + } + u, err := req.URL.Parse(loc) + if err != nil { + resp.closeBody() + return nil, uerr(fmt.Errorf("failed to parse Location header %q: %v", loc, err)) + } + host := "" + if req.Host != "" && req.Host != req.URL.Host { + // If the caller specified a custom Host header and the + // redirect location is relative, preserve the Host header + // through the redirect. See issue #22233. + if u, _ := url.Parse(loc); u != nil && !u.IsAbs() { + host = req.Host + } + } + ireq := reqs[0] + req = &Request{ + Method: redirectMethod, + Response: resp, + URL: u, + Header: make(Header), + Host: host, + Cancel: ireq.Cancel, + ctx: ireq.ctx, + } + if includeBody && ireq.GetBody != nil { + req.Body, err = ireq.GetBody() + if err != nil { + resp.closeBody() + return nil, uerr(err) + } + req.ContentLength = ireq.ContentLength + } + + // Copy original headers before setting the Referer, + // in case the user set Referer on their first request. + // If they really want to override, they can do it in + // their CheckRedirect func. + copyHeaders(req) + + // Add the Referer header from the most recent + // request URL to the new one, if it's not https->http: + if ref := refererForURL(reqs[len(reqs)-1].URL, req.URL); ref != "" { + req.Header.Set("Referer", ref) + } + err = c.checkRedirect(req, reqs) + + // Sentinel error to let users select the + // previous response, without closing its + // body. See Issue 10069. + if err == ErrUseLastResponse { + return resp, nil + } + + // Close the previous response's body. But + // read at least some of the body so if it's + // small the underlying TCP connection will be + // re-used. No need to check for errors: if it + // fails, the Transport won't reuse it anyway. + const maxBodySlurpSize = 2 << 10 + if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize { + io.CopyN(io.Discard, resp.Body, maxBodySlurpSize) + } + resp.Body.Close() + + if err != nil { + // Special case for Go 1 compatibility: return both the response + // and an error if the CheckRedirect function failed. + // See https://golang.org/issue/3795 + // The resp.Body has already been closed. + ue := uerr(err) + ue.(*url.Error).URL = loc + return resp, ue + } + } + + reqs = append(reqs, req) + var err error + var didTimeout func() bool + if resp, didTimeout, err = c.send(req, deadline); err != nil { + // c.send() always closes req.Body + reqBodyClosed = true + if !deadline.IsZero() && didTimeout() { + err = &httpError{ + // TODO: early in cycle: s/Client.Timeout exceeded/timeout or context cancellation/ + err: err.Error() + " (Client.Timeout exceeded while awaiting headers)", + timeout: true, + } + } + return nil, uerr(err) + } + + var shouldRedirect bool + redirectMethod, shouldRedirect, includeBody = redirectBehavior(req.Method, resp, reqs[0]) + if !shouldRedirect { + return resp, nil + } + + req.closeBody() + } +} + +// makeHeadersCopier makes a function that copies headers from the +// initial Request, ireq. For every redirect, this function must be called +// so that it can copy headers into the upcoming Request. +func (c *Client) makeHeadersCopier(ireq *Request) func(*Request) { + // The headers to copy are from the very initial request. + // We use a closured callback to keep a reference to these original headers. + var ( + ireqhdr = cloneOrMakeHeader(ireq.Header) + icookies map[string][]*Cookie + ) + if c.Jar != nil && ireq.Header.Get("Cookie") != "" { + icookies = make(map[string][]*Cookie) + for _, c := range ireq.Cookies() { + icookies[c.Name] = append(icookies[c.Name], c) + } + } + + preq := ireq // The previous request + return func(req *Request) { + // If Jar is present and there was some initial cookies provided + // via the request header, then we may need to alter the initial + // cookies as we follow redirects since each redirect may end up + // modifying a pre-existing cookie. + // + // Since cookies already set in the request header do not contain + // information about the original domain and path, the logic below + // assumes any new set cookies override the original cookie + // regardless of domain or path. + // + // See https://golang.org/issue/17494 + if c.Jar != nil && icookies != nil { + var changed bool + resp := req.Response // The response that caused the upcoming redirect + for _, c := range resp.Cookies() { + if _, ok := icookies[c.Name]; ok { + delete(icookies, c.Name) + changed = true + } + } + if changed { + ireqhdr.Del("Cookie") + var ss []string + for _, cs := range icookies { + for _, c := range cs { + ss = append(ss, c.Name+"="+c.Value) + } + } + sort.Strings(ss) // Ensure deterministic headers + ireqhdr.Set("Cookie", strings.Join(ss, "; ")) + } + } + + // Copy the initial request's Header values + // (at least the safe ones). + for k, vv := range ireqhdr { + if shouldCopyHeaderOnRedirect(k, preq.URL, req.URL) { + req.Header[k] = vv + } + } + + preq = req // Update previous Request with the current request + } +} + +func defaultCheckRedirect(req *Request, via []*Request) error { + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } + return nil +} + +// Post issues a POST to the specified URL. +// +// Caller should close resp.Body when done reading from it. +// +// If the provided body is an io.Closer, it is closed after the +// request. +// +// Post is a wrapper around DefaultClient.Post. +// +// To set custom headers, use NewRequest and DefaultClient.Do. +// +// See the Client.Do method documentation for details on how redirects +// are handled. +func Post(url, contentType string, body io.Reader) (resp *Response, err error) { + return DefaultClient.Post(url, contentType, body) +} + +// Post issues a POST to the specified URL. +// +// Caller should close resp.Body when done reading from it. +// +// If the provided body is an io.Closer, it is closed after the +// request. +// +// To set custom headers, use NewRequest and Client.Do. +// +// See the Client.Do method documentation for details on how redirects +// are handled. +func (c *Client) Post(url, contentType string, body io.Reader) (resp *Response, err error) { + req, err := NewRequest("POST", url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", contentType) + return c.Do(req) +} + +// PostForm issues a POST to the specified URL, with data's keys and +// values URL-encoded as the request body. +// +// The Content-Type header is set to application/x-www-form-urlencoded. +// To set other headers, use NewRequest and DefaultClient.Do. +// +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. +// +// PostForm is a wrapper around DefaultClient.PostForm. +// +// See the Client.Do method documentation for details on how redirects +// are handled. +func PostForm(url string, data url.Values) (resp *Response, err error) { + return DefaultClient.PostForm(url, data) +} + +// PostForm issues a POST to the specified URL, +// with data's keys and values URL-encoded as the request body. +// +// The Content-Type header is set to application/x-www-form-urlencoded. +// To set other headers, use NewRequest and Client.Do. +// +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. +// +// See the Client.Do method documentation for details on how redirects +// are handled. +func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) { + return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) +} + +// Head issues a HEAD to the specified URL. If the response is one of +// the following redirect codes, Head follows the redirect, up to a +// maximum of 10 redirects: +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) +// 308 (Permanent Redirect) +// +// Head is a wrapper around DefaultClient.Head +func Head(url string) (resp *Response, err error) { + return DefaultClient.Head(url) +} + +// Head issues a HEAD to the specified URL. If the response is one of the +// following redirect codes, Head follows the redirect after calling the +// Client's CheckRedirect function: +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) +// 308 (Permanent Redirect) +func (c *Client) Head(url string) (resp *Response, err error) { + req, err := NewRequest("HEAD", url, nil) + if err != nil { + return nil, err + } + return c.Do(req) +} + +// CloseIdleConnections closes any connections on its Transport which +// were previously connected from previous requests but are now +// sitting idle in a "keep-alive" state. It does not interrupt any +// connections currently in use. +// +// If the Client's Transport does not have a CloseIdleConnections method +// then this method does nothing. +func (c *Client) CloseIdleConnections() { + type closeIdler interface { + CloseIdleConnections() + } + if tr, ok := c.transport().(closeIdler); ok { + tr.CloseIdleConnections() + } +} + +// cancelTimerBody is an io.ReadCloser that wraps rc with two features: +// 1) on Read error or close, the stop func is called. +// 2) On Read failure, if reqDidTimeout is true, the error is wrapped and +// marked as net.Error that hit its timeout. +type cancelTimerBody struct { + stop func() // stops the time.Timer waiting to cancel the request + rc io.ReadCloser + reqDidTimeout func() bool +} + +func (b *cancelTimerBody) Read(p []byte) (n int, err error) { + n, err = b.rc.Read(p) + if err == nil { + return n, nil + } + b.stop() + if err == io.EOF { + return n, err + } + if b.reqDidTimeout() { + err = &httpError{ + err: err.Error() + " (Client.Timeout or context cancellation while reading body)", + timeout: true, + } + } + return n, err +} + +func (b *cancelTimerBody) Close() error { + err := b.rc.Close() + b.stop() + return err +} + +func shouldCopyHeaderOnRedirect(headerKey string, initial, dest *url.URL) bool { + switch CanonicalHeaderKey(headerKey) { + case "Authorization", "Www-Authenticate", "Cookie", "Cookie2": + // Permit sending auth/cookie headers from "foo.com" + // to "sub.foo.com". + + // Note that we don't send all cookies to subdomains + // automatically. This function is only used for + // Cookies set explicitly on the initial outgoing + // client request. Cookies automatically added via the + // CookieJar mechanism continue to follow each + // cookie's scope as set by Set-Cookie. But for + // outgoing requests with the Cookie header set + // directly, we don't know their scope, so we assume + // it's for *.domain.com. + + ihost := canonicalAddr(initial) + dhost := canonicalAddr(dest) + return isDomainOrSubdomain(dhost, ihost) + } + // All other headers are copied: + return true +} + +// isDomainOrSubdomain reports whether sub is a subdomain (or exact +// match) of the parent domain. +// +// Both domains must already be in canonical form. +func isDomainOrSubdomain(sub, parent string) bool { + if sub == parent { + return true + } + // If sub is "foo.example.com" and parent is "example.com", + // that means sub must end in "."+parent. + // Do it without allocating. + if !strings.HasSuffix(sub, parent) { + return false + } + return sub[len(sub)-len(parent)-1] == '.' +} + +func stripPassword(u *url.URL) string { + _, passSet := u.User.Password() + if passSet { + return strings.Replace(u.String(), u.User.String()+"@", u.User.Username()+":***@", 1) + } + return u.String() +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/client_test.go b/vendor/github.com/lesismal/llib/std/net/http/client_test.go new file mode 100644 index 0000000..a7eae40 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/client_test.go @@ -0,0 +1,2084 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests for client.go + +package http_test + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "github.com/lesismal/llib/std/net/http/cookiejar" + "github.com/lesismal/llib/std/net/http/httptest" + "io" + "log" + "net" + . "net/http" + "net/url" + "reflect" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Last-Modified", "sometime") + fmt.Fprintf(w, "User-agent: go\nDisallow: /something/") +}) + +// pedanticReadAll works like io.ReadAll but additionally +// verifies that r obeys the documented io.Reader contract. +func pedanticReadAll(r io.Reader) (b []byte, err error) { + var bufa [64]byte + buf := bufa[:] + for { + n, err := r.Read(buf) + if n == 0 && err == nil { + return nil, fmt.Errorf("Read: n=0 with err=nil") + } + b = append(b, buf[:n]...) + if err == io.EOF { + n, err := r.Read(buf) + if n != 0 || err != io.EOF { + return nil, fmt.Errorf("Read: n=%d err=%#v after EOF", n, err) + } + return b, nil + } + if err != nil { + return b, err + } + } +} + +type chanWriter chan string + +func (w chanWriter) Write(p []byte) (n int, err error) { + w <- string(p) + return len(p), nil +} + +func TestClient(t *testing.T) { + setParallel(t) + defer afterTest(t) + ts := httptest.NewServer(robotsTxtHandler) + defer ts.Close() + + c := ts.Client() + r, err := c.Get(ts.URL) + var b []byte + if err == nil { + b, err = pedanticReadAll(r.Body) + r.Body.Close() + } + if err != nil { + t.Error(err) + } else if s := string(b); !strings.HasPrefix(s, "User-agent:") { + t.Errorf("Incorrect page body (did not begin with User-agent): %q", s) + } +} + +func TestClientHead_h1(t *testing.T) { testClientHead(t, h1Mode) } +func TestClientHead_h2(t *testing.T) { testClientHead(t, h2Mode) } + +func testClientHead(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, robotsTxtHandler) + defer cst.close() + + r, err := cst.c.Head(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + if _, ok := r.Header["Last-Modified"]; !ok { + t.Error("Last-Modified header not found.") + } +} + +type recordingTransport struct { + req *Request +} + +func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err error) { + t.req = req + return nil, errors.New("dummy impl") +} + +func TestGetRequestFormat(t *testing.T) { + setParallel(t) + defer afterTest(t) + tr := &recordingTransport{} + client := &Client{Transport: tr} + url := "http://dummy.faketld/" + client.Get(url) // Note: doesn't hit network + if tr.req.Method != "GET" { + t.Errorf("expected method %q; got %q", "GET", tr.req.Method) + } + if tr.req.URL.String() != url { + t.Errorf("expected URL %q; got %q", url, tr.req.URL.String()) + } + if tr.req.Header == nil { + t.Errorf("expected non-nil request Header") + } +} + +func TestPostRequestFormat(t *testing.T) { + defer afterTest(t) + tr := &recordingTransport{} + client := &Client{Transport: tr} + + url := "http://dummy.faketld/" + json := `{"key":"value"}` + b := strings.NewReader(json) + client.Post(url, "application/json", b) // Note: doesn't hit network + + if tr.req.Method != "POST" { + t.Errorf("got method %q, want %q", tr.req.Method, "POST") + } + if tr.req.URL.String() != url { + t.Errorf("got URL %q, want %q", tr.req.URL.String(), url) + } + if tr.req.Header == nil { + t.Fatalf("expected non-nil request Header") + } + if tr.req.Close { + t.Error("got Close true, want false") + } + if g, e := tr.req.ContentLength, int64(len(json)); g != e { + t.Errorf("got ContentLength %d, want %d", g, e) + } +} + +func TestPostFormRequestFormat(t *testing.T) { + defer afterTest(t) + tr := &recordingTransport{} + client := &Client{Transport: tr} + + urlStr := "http://dummy.faketld/" + form := make(url.Values) + form.Set("foo", "bar") + form.Add("foo", "bar2") + form.Set("bar", "baz") + client.PostForm(urlStr, form) // Note: doesn't hit network + + if tr.req.Method != "POST" { + t.Errorf("got method %q, want %q", tr.req.Method, "POST") + } + if tr.req.URL.String() != urlStr { + t.Errorf("got URL %q, want %q", tr.req.URL.String(), urlStr) + } + if tr.req.Header == nil { + t.Fatalf("expected non-nil request Header") + } + if g, e := tr.req.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; g != e { + t.Errorf("got Content-Type %q, want %q", g, e) + } + if tr.req.Close { + t.Error("got Close true, want false") + } + // Depending on map iteration, body can be either of these. + expectedBody := "foo=bar&foo=bar2&bar=baz" + expectedBody1 := "bar=baz&foo=bar&foo=bar2" + if g, e := tr.req.ContentLength, int64(len(expectedBody)); g != e { + t.Errorf("got ContentLength %d, want %d", g, e) + } + bodyb, err := io.ReadAll(tr.req.Body) + if err != nil { + t.Fatalf("ReadAll on req.Body: %v", err) + } + if g := string(bodyb); g != expectedBody && g != expectedBody1 { + t.Errorf("got body %q, want %q or %q", g, expectedBody, expectedBody1) + } +} + +func TestClientRedirects(t *testing.T) { + setParallel(t) + defer afterTest(t) + var ts *httptest.Server + ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + n, _ := strconv.Atoi(r.FormValue("n")) + // Test Referer header. (7 is arbitrary position to test at) + if n == 7 { + if g, e := r.Referer(), ts.URL+"/?n=6"; e != g { + t.Errorf("on request ?n=7, expected referer of %q; got %q", e, g) + } + } + if n < 15 { + Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusTemporaryRedirect) + return + } + fmt.Fprintf(w, "n=%d", n) + })) + defer ts.Close() + + c := ts.Client() + _, err := c.Get(ts.URL) + if e, g := `Get "/?n=10": stopped after 10 redirects`, fmt.Sprintf("%v", err); e != g { + t.Errorf("with default client Get, expected error %q, got %q", e, g) + } + + // HEAD request should also have the ability to follow redirects. + _, err = c.Head(ts.URL) + if e, g := `Head "/?n=10": stopped after 10 redirects`, fmt.Sprintf("%v", err); e != g { + t.Errorf("with default client Head, expected error %q, got %q", e, g) + } + + // Do should also follow redirects. + greq, _ := NewRequest("GET", ts.URL, nil) + _, err = c.Do(greq) + if e, g := `Get "/?n=10": stopped after 10 redirects`, fmt.Sprintf("%v", err); e != g { + t.Errorf("with default client Do, expected error %q, got %q", e, g) + } + + // Requests with an empty Method should also redirect (Issue 12705) + greq.Method = "" + _, err = c.Do(greq) + if e, g := `Get "/?n=10": stopped after 10 redirects`, fmt.Sprintf("%v", err); e != g { + t.Errorf("with default client Do and empty Method, expected error %q, got %q", e, g) + } + + var checkErr error + var lastVia []*Request + var lastReq *Request + c.CheckRedirect = func(req *Request, via []*Request) error { + lastReq = req + lastVia = via + return checkErr + } + res, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("Get error: %v", err) + } + res.Body.Close() + finalUrl := res.Request.URL.String() + if e, g := "", fmt.Sprintf("%v", err); e != g { + t.Errorf("with custom client, expected error %q, got %q", e, g) + } + if !strings.HasSuffix(finalUrl, "/?n=15") { + t.Errorf("expected final url to end in /?n=15; got url %q", finalUrl) + } + if e, g := 15, len(lastVia); e != g { + t.Errorf("expected lastVia to have contained %d elements; got %d", e, g) + } + + // Test that Request.Cancel is propagated between requests (Issue 14053) + creq, _ := NewRequest("HEAD", ts.URL, nil) + cancel := make(chan struct{}) + creq.Cancel = cancel + if _, err := c.Do(creq); err != nil { + t.Fatal(err) + } + if lastReq == nil { + t.Fatal("didn't see redirect") + } + if lastReq.Cancel != cancel { + t.Errorf("expected lastReq to have the cancel channel set on the initial req") + } + + checkErr = errors.New("no redirects allowed") + res, err = c.Get(ts.URL) + if urlError, ok := err.(*url.Error); !ok || urlError.Err != checkErr { + t.Errorf("with redirects forbidden, expected a *url.Error with our 'no redirects allowed' error inside; got %#v (%q)", err, err) + } + if res == nil { + t.Fatalf("Expected a non-nil Response on CheckRedirect failure (https://golang.org/issue/3795)") + } + res.Body.Close() + if res.Header.Get("Location") == "" { + t.Errorf("no Location header in Response") + } +} + +// Tests that Client redirects' contexts are derived from the original request's context. +func TestClientRedirectContext(t *testing.T) { + setParallel(t) + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + Redirect(w, r, "/", StatusTemporaryRedirect) + })) + defer ts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + c := ts.Client() + c.CheckRedirect = func(req *Request, via []*Request) error { + cancel() + select { + case <-req.Context().Done(): + return nil + case <-time.After(5 * time.Second): + return errors.New("redirected request's context never expired after root request canceled") + } + } + req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil) + _, err := c.Do(req) + ue, ok := err.(*url.Error) + if !ok { + t.Fatalf("got error %T; want *url.Error", err) + } + if ue.Err != context.Canceled { + t.Errorf("url.Error.Err = %v; want %v", ue.Err, context.Canceled) + } +} + +type redirectTest struct { + suffix string + want int // response code + redirectBody string +} + +func TestPostRedirects(t *testing.T) { + postRedirectTests := []redirectTest{ + {"/", 200, "first"}, + {"/?code=301&next=302", 200, "c301"}, + {"/?code=302&next=302", 200, "c302"}, + {"/?code=303&next=301", 200, "c303wc301"}, // Issue 9348 + {"/?code=304", 304, "c304"}, + {"/?code=305", 305, "c305"}, + {"/?code=307&next=303,308,302", 200, "c307"}, + {"/?code=308&next=302,301", 200, "c308"}, + {"/?code=404", 404, "c404"}, + } + + wantSegments := []string{ + `POST / "first"`, + `POST /?code=301&next=302 "c301"`, + `GET /?code=302 ""`, + `GET / ""`, + `POST /?code=302&next=302 "c302"`, + `GET /?code=302 ""`, + `GET / ""`, + `POST /?code=303&next=301 "c303wc301"`, + `GET /?code=301 ""`, + `GET / ""`, + `POST /?code=304 "c304"`, + `POST /?code=305 "c305"`, + `POST /?code=307&next=303,308,302 "c307"`, + `POST /?code=303&next=308,302 "c307"`, + `GET /?code=308&next=302 ""`, + `GET /?code=302 "c307"`, + `GET / ""`, + `POST /?code=308&next=302,301 "c308"`, + `POST /?code=302&next=301 "c308"`, + `GET /?code=301 ""`, + `GET / ""`, + `POST /?code=404 "c404"`, + } + want := strings.Join(wantSegments, "\n") + testRedirectsByMethod(t, "POST", postRedirectTests, want) +} + +func TestDeleteRedirects(t *testing.T) { + deleteRedirectTests := []redirectTest{ + {"/", 200, "first"}, + {"/?code=301&next=302,308", 200, "c301"}, + {"/?code=302&next=302", 200, "c302"}, + {"/?code=303", 200, "c303"}, + {"/?code=307&next=301,308,303,302,304", 304, "c307"}, + {"/?code=308&next=307", 200, "c308"}, + {"/?code=404", 404, "c404"}, + } + + wantSegments := []string{ + `DELETE / "first"`, + `DELETE /?code=301&next=302,308 "c301"`, + `GET /?code=302&next=308 ""`, + `GET /?code=308 ""`, + `GET / "c301"`, + `DELETE /?code=302&next=302 "c302"`, + `GET /?code=302 ""`, + `GET / ""`, + `DELETE /?code=303 "c303"`, + `GET / ""`, + `DELETE /?code=307&next=301,308,303,302,304 "c307"`, + `DELETE /?code=301&next=308,303,302,304 "c307"`, + `GET /?code=308&next=303,302,304 ""`, + `GET /?code=303&next=302,304 "c307"`, + `GET /?code=302&next=304 ""`, + `GET /?code=304 ""`, + `DELETE /?code=308&next=307 "c308"`, + `DELETE /?code=307 "c308"`, + `DELETE / "c308"`, + `DELETE /?code=404 "c404"`, + } + want := strings.Join(wantSegments, "\n") + testRedirectsByMethod(t, "DELETE", deleteRedirectTests, want) +} + +func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, want string) { + defer afterTest(t) + var log struct { + sync.Mutex + bytes.Buffer + } + var ts *httptest.Server + ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + log.Lock() + slurp, _ := io.ReadAll(r.Body) + fmt.Fprintf(&log.Buffer, "%s %s %q", r.Method, r.RequestURI, slurp) + if cl := r.Header.Get("Content-Length"); r.Method == "GET" && len(slurp) == 0 && (r.ContentLength != 0 || cl != "") { + fmt.Fprintf(&log.Buffer, " (but with body=%T, content-length = %v, %q)", r.Body, r.ContentLength, cl) + } + log.WriteByte('\n') + log.Unlock() + urlQuery := r.URL.Query() + if v := urlQuery.Get("code"); v != "" { + location := ts.URL + if final := urlQuery.Get("next"); final != "" { + splits := strings.Split(final, ",") + first, rest := splits[0], splits[1:] + location = fmt.Sprintf("%s?code=%s", location, first) + if len(rest) > 0 { + location = fmt.Sprintf("%s&next=%s", location, strings.Join(rest, ",")) + } + } + code, _ := strconv.Atoi(v) + if code/100 == 3 { + w.Header().Set("Location", location) + } + w.WriteHeader(code) + } + })) + defer ts.Close() + + c := ts.Client() + for _, tt := range table { + content := tt.redirectBody + req, _ := NewRequest(method, ts.URL+tt.suffix, strings.NewReader(content)) + req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(strings.NewReader(content)), nil } + res, err := c.Do(req) + + if err != nil { + t.Fatal(err) + } + if res.StatusCode != tt.want { + t.Errorf("POST %s: status code = %d; want %d", tt.suffix, res.StatusCode, tt.want) + } + } + log.Lock() + got := log.String() + log.Unlock() + + got = strings.TrimSpace(got) + want = strings.TrimSpace(want) + + if got != want { + got, want, lines := removeCommonLines(got, want) + t.Errorf("Log differs after %d common lines.\n\nGot:\n%s\n\nWant:\n%s\n", lines, got, want) + } +} + +func removeCommonLines(a, b string) (asuffix, bsuffix string, commonLines int) { + for { + nl := strings.IndexByte(a, '\n') + if nl < 0 { + return a, b, commonLines + } + line := a[:nl+1] + if !strings.HasPrefix(b, line) { + return a, b, commonLines + } + commonLines++ + a = a[len(line):] + b = b[len(line):] + } +} + +func TestClientRedirectUseResponse(t *testing.T) { + setParallel(t) + defer afterTest(t) + const body = "Hello, world." + var ts *httptest.Server + ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if strings.Contains(r.URL.Path, "/other") { + io.WriteString(w, "wrong body") + } else { + w.Header().Set("Location", ts.URL+"/other") + w.WriteHeader(StatusFound) + io.WriteString(w, body) + } + })) + defer ts.Close() + + c := ts.Client() + c.CheckRedirect = func(req *Request, via []*Request) error { + if req.Response == nil { + t.Error("expected non-nil Request.Response") + } + return ErrUseLastResponse + } + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != StatusFound { + t.Errorf("status = %d; want %d", res.StatusCode, StatusFound) + } + defer res.Body.Close() + slurp, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(slurp) != body { + t.Errorf("body = %q; want %q", slurp, body) + } +} + +// Issue 17773: don't follow a 308 (or 307) if the response doesn't +// have a Location header. +func TestClientRedirect308NoLocation(t *testing.T) { + setParallel(t) + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Foo", "Bar") + w.WriteHeader(308) + })) + defer ts.Close() + c := ts.Client() + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != 308 { + t.Errorf("status = %d; want %d", res.StatusCode, 308) + } + if got := res.Header.Get("Foo"); got != "Bar" { + t.Errorf("Foo header = %q; want Bar", got) + } +} + +// Don't follow a 307/308 if we can't resent the request body. +func TestClientRedirect308NoGetBody(t *testing.T) { + setParallel(t) + defer afterTest(t) + const fakeURL = "https://localhost:1234/" // won't be hit + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Location", fakeURL) + w.WriteHeader(308) + })) + defer ts.Close() + req, err := NewRequest("POST", ts.URL, strings.NewReader("some body")) + if err != nil { + t.Fatal(err) + } + c := ts.Client() + req.GetBody = nil // so it can't rewind. + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != 308 { + t.Errorf("status = %d; want %d", res.StatusCode, 308) + } + if got := res.Header.Get("Location"); got != fakeURL { + t.Errorf("Location header = %q; want %q", got, fakeURL) + } +} + +var expectedCookies = []*Cookie{ + {Name: "ChocolateChip", Value: "tasty"}, + {Name: "First", Value: "Hit"}, + {Name: "Second", Value: "Hit"}, +} + +var echoCookiesRedirectHandler = HandlerFunc(func(w ResponseWriter, r *Request) { + for _, cookie := range r.Cookies() { + SetCookie(w, cookie) + } + if r.URL.Path == "/" { + SetCookie(w, expectedCookies[1]) + Redirect(w, r, "/second", StatusMovedPermanently) + } else { + SetCookie(w, expectedCookies[2]) + w.Write([]byte("hello")) + } +}) + +func TestClientSendsCookieFromJar(t *testing.T) { + defer afterTest(t) + tr := &recordingTransport{} + client := &Client{Transport: tr} + client.Jar = &TestJar{perURL: make(map[string][]*Cookie)} + us := "http://dummy.faketld/" + u, _ := url.Parse(us) + client.Jar.SetCookies(u, expectedCookies) + + client.Get(us) // Note: doesn't hit network + matchReturnedCookies(t, expectedCookies, tr.req.Cookies()) + + client.Head(us) // Note: doesn't hit network + matchReturnedCookies(t, expectedCookies, tr.req.Cookies()) + + client.Post(us, "text/plain", strings.NewReader("body")) // Note: doesn't hit network + matchReturnedCookies(t, expectedCookies, tr.req.Cookies()) + + client.PostForm(us, url.Values{}) // Note: doesn't hit network + matchReturnedCookies(t, expectedCookies, tr.req.Cookies()) + + req, _ := NewRequest("GET", us, nil) + client.Do(req) // Note: doesn't hit network + matchReturnedCookies(t, expectedCookies, tr.req.Cookies()) + + req, _ = NewRequest("POST", us, nil) + client.Do(req) // Note: doesn't hit network + matchReturnedCookies(t, expectedCookies, tr.req.Cookies()) +} + +// Just enough correctness for our redirect tests. Uses the URL.Host as the +// scope of all cookies. +type TestJar struct { + m sync.Mutex + perURL map[string][]*Cookie +} + +func (j *TestJar) SetCookies(u *url.URL, cookies []*Cookie) { + j.m.Lock() + defer j.m.Unlock() + if j.perURL == nil { + j.perURL = make(map[string][]*Cookie) + } + j.perURL[u.Host] = cookies +} + +func (j *TestJar) Cookies(u *url.URL) []*Cookie { + j.m.Lock() + defer j.m.Unlock() + return j.perURL[u.Host] +} + +func TestRedirectCookiesJar(t *testing.T) { + setParallel(t) + defer afterTest(t) + var ts *httptest.Server + ts = httptest.NewServer(echoCookiesRedirectHandler) + defer ts.Close() + c := ts.Client() + c.Jar = new(TestJar) + u, _ := url.Parse(ts.URL) + c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]}) + resp, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + resp.Body.Close() + matchReturnedCookies(t, expectedCookies, resp.Cookies()) +} + +func matchReturnedCookies(t *testing.T, expected, given []*Cookie) { + if len(given) != len(expected) { + t.Logf("Received cookies: %v", given) + t.Errorf("Expected %d cookies, got %d", len(expected), len(given)) + } + for _, ec := range expected { + foundC := false + for _, c := range given { + if ec.Name == c.Name && ec.Value == c.Value { + foundC = true + break + } + } + if !foundC { + t.Errorf("Missing cookie %v", ec) + } + } +} + +func TestJarCalls(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + pathSuffix := r.RequestURI[1:] + if r.RequestURI == "/nosetcookie" { + return // don't set cookies for this path + } + SetCookie(w, &Cookie{Name: "name" + pathSuffix, Value: "val" + pathSuffix}) + if r.RequestURI == "/" { + Redirect(w, r, "http://secondhost.fake/secondpath", 302) + } + })) + defer ts.Close() + jar := new(RecordingJar) + c := ts.Client() + c.Jar = jar + c.Transport.(*Transport).Dial = func(_ string, _ string) (net.Conn, error) { + return net.Dial("tcp", ts.Listener.Addr().String()) + } + _, err := c.Get("http://firsthost.fake/") + if err != nil { + t.Fatal(err) + } + _, err = c.Get("http://firsthost.fake/nosetcookie") + if err != nil { + t.Fatal(err) + } + got := jar.log.String() + want := `Cookies("http://firsthost.fake/") +SetCookie("http://firsthost.fake/", [name=val]) +Cookies("http://secondhost.fake/secondpath") +SetCookie("http://secondhost.fake/secondpath", [namesecondpath=valsecondpath]) +Cookies("http://firsthost.fake/nosetcookie") +` + if got != want { + t.Errorf("Got Jar calls:\n%s\nWant:\n%s", got, want) + } +} + +// RecordingJar keeps a log of calls made to it, without +// tracking any cookies. +type RecordingJar struct { + mu sync.Mutex + log bytes.Buffer +} + +func (j *RecordingJar) SetCookies(u *url.URL, cookies []*Cookie) { + j.logf("SetCookie(%q, %v)\n", u, cookies) +} + +func (j *RecordingJar) Cookies(u *url.URL) []*Cookie { + j.logf("Cookies(%q)\n", u) + return nil +} + +func (j *RecordingJar) logf(format string, args ...interface{}) { + j.mu.Lock() + defer j.mu.Unlock() + fmt.Fprintf(&j.log, format, args...) +} + +func TestStreamingGet_h1(t *testing.T) { testStreamingGet(t, h1Mode) } +func TestStreamingGet_h2(t *testing.T) { testStreamingGet(t, h2Mode) } + +func testStreamingGet(t *testing.T, h2 bool) { + defer afterTest(t) + say := make(chan string) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.(Flusher).Flush() + for str := range say { + w.Write([]byte(str)) + w.(Flusher).Flush() + } + })) + defer cst.close() + + c := cst.c + res, err := c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + var buf [10]byte + for _, str := range []string{"i", "am", "also", "known", "as", "comet"} { + say <- str + n, err := io.ReadFull(res.Body, buf[0:len(str)]) + if err != nil { + t.Fatalf("ReadFull on %q: %v", str, err) + } + if n != len(str) { + t.Fatalf("Receiving %q, only read %d bytes", str, n) + } + got := string(buf[0:n]) + if got != str { + t.Fatalf("Expected %q, got %q", str, got) + } + } + close(say) + _, err = io.ReadFull(res.Body, buf[0:1]) + if err != io.EOF { + t.Fatalf("at end expected EOF, got %v", err) + } +} + +type writeCountingConn struct { + net.Conn + count *int +} + +func (c *writeCountingConn) Write(p []byte) (int, error) { + *c.count++ + return c.Conn.Write(p) +} + +// TestClientWrites verifies that client requests are buffered and we +// don't send a TCP packet per line of the http request + body. +func TestClientWrites(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + })) + defer ts.Close() + + writes := 0 + dialer := func(netz string, addr string) (net.Conn, error) { + c, err := net.Dial(netz, addr) + if err == nil { + c = &writeCountingConn{c, &writes} + } + return c, err + } + c := ts.Client() + c.Transport.(*Transport).Dial = dialer + + _, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + if writes != 1 { + t.Errorf("Get request did %d Write calls, want 1", writes) + } + + writes = 0 + _, err = c.PostForm(ts.URL, url.Values{"foo": {"bar"}}) + if err != nil { + t.Fatal(err) + } + if writes != 1 { + t.Errorf("Post request did %d Write calls, want 1", writes) + } +} + +func TestClientInsecureTransport(t *testing.T) { + setParallel(t) + defer afterTest(t) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write([]byte("Hello")) + })) + errc := make(chanWriter, 10) // but only expecting 1 + ts.Config.ErrorLog = log.New(errc, "", 0) + defer ts.Close() + + // TODO(bradfitz): add tests for skipping hostname checks too? + // would require a new cert for testing, and probably + // redundant with these tests. + c := ts.Client() + for _, insecure := range []bool{true, false} { + c.Transport.(*Transport).TLSClientConfig = &tls.Config{ + InsecureSkipVerify: insecure, + } + res, err := c.Get(ts.URL) + if (err == nil) != insecure { + t.Errorf("insecure=%v: got unexpected err=%v", insecure, err) + } + if res != nil { + res.Body.Close() + } + } + + select { + case v := <-errc: + if !strings.Contains(v, "TLS handshake error") { + t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v) + } + case <-time.After(5 * time.Second): + t.Errorf("timeout waiting for logged error") + } + +} + +func TestClientErrorWithRequestURI(t *testing.T) { + defer afterTest(t) + req, _ := NewRequest("GET", "http://localhost:1234/", nil) + req.RequestURI = "/this/field/is/illegal/and/should/error/" + _, err := DefaultClient.Do(req) + if err == nil { + t.Fatalf("expected an error") + } + if !strings.Contains(err.Error(), "RequestURI") { + t.Errorf("wanted error mentioning RequestURI; got error: %v", err) + } +} + +func TestClientWithCorrectTLSServerName(t *testing.T) { + defer afterTest(t) + + const serverName = "example.com" + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.TLS.ServerName != serverName { + t.Errorf("expected client to set ServerName %q, got: %q", serverName, r.TLS.ServerName) + } + })) + defer ts.Close() + + c := ts.Client() + c.Transport.(*Transport).TLSClientConfig.ServerName = serverName + if _, err := c.Get(ts.URL); err != nil { + t.Fatalf("expected successful TLS connection, got error: %v", err) + } +} + +func TestClientWithIncorrectTLSServerName(t *testing.T) { + defer afterTest(t) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + defer ts.Close() + errc := make(chanWriter, 10) // but only expecting 1 + ts.Config.ErrorLog = log.New(errc, "", 0) + + c := ts.Client() + c.Transport.(*Transport).TLSClientConfig.ServerName = "badserver" + _, err := c.Get(ts.URL) + if err == nil { + t.Fatalf("expected an error") + } + if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") { + t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err) + } + select { + case v := <-errc: + if !strings.Contains(v, "TLS handshake error") { + t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v) + } + case <-time.After(5 * time.Second): + t.Errorf("timeout waiting for logged error") + } +} + +// Test for golang.org/issue/5829; the Transport should respect TLSClientConfig.ServerName +// when not empty. +// +// tls.Config.ServerName (non-empty, set to "example.com") takes +// precedence over "some-other-host.tld" which previously incorrectly +// took precedence. We don't actually connect to (or even resolve) +// "some-other-host.tld", though, because of the Transport.Dial hook. +// +// The httptest.Server has a cert with "example.com" as its name. +func TestTransportUsesTLSConfigServerName(t *testing.T) { + defer afterTest(t) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write([]byte("Hello")) + })) + defer ts.Close() + + c := ts.Client() + tr := c.Transport.(*Transport) + tr.TLSClientConfig.ServerName = "example.com" // one of httptest's Server cert names + tr.Dial = func(netw, addr string) (net.Conn, error) { + return net.Dial(netw, ts.Listener.Addr().String()) + } + res, err := c.Get("https://some-other-host.tld/") + if err != nil { + t.Fatal(err) + } + res.Body.Close() +} + +func TestResponseSetsTLSConnectionState(t *testing.T) { + defer afterTest(t) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write([]byte("Hello")) + })) + defer ts.Close() + + c := ts.Client() + tr := c.Transport.(*Transport) + tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA} + tr.TLSClientConfig.MaxVersion = tls.VersionTLS12 // to get to pick the cipher suite + tr.Dial = func(netw, addr string) (net.Conn, error) { + return net.Dial(netw, ts.Listener.Addr().String()) + } + res, err := c.Get("https://example.com/") + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.TLS == nil { + t.Fatal("Response didn't set TLS Connection State.") + } + if got, want := res.TLS.CipherSuite, tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA; got != want { + t.Errorf("TLS Cipher Suite = %d; want %d", got, want) + } +} + +// Check that an HTTPS client can interpret a particular TLS error +// to determine that the server is speaking HTTP. +// See golang.org/issue/11111. +func TestHTTPSClientDetectsHTTPServer(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + ts.Config.ErrorLog = quietLog + defer ts.Close() + + _, err := Get(strings.Replace(ts.URL, "http", "https", 1)) + if got := err.Error(); !strings.Contains(got, "HTTP response to HTTPS client") { + t.Fatalf("error = %q; want error indicating HTTP response to HTTPS request", got) + } +} + +// Verify Response.ContentLength is populated. https://golang.org/issue/4126 +func TestClientHeadContentLength_h1(t *testing.T) { + testClientHeadContentLength(t, h1Mode) +} + +func TestClientHeadContentLength_h2(t *testing.T) { + testClientHeadContentLength(t, h2Mode) +} + +func testClientHeadContentLength(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + if v := r.FormValue("cl"); v != "" { + w.Header().Set("Content-Length", v) + } + })) + defer cst.close() + tests := []struct { + suffix string + want int64 + }{ + {"/?cl=1234", 1234}, + {"/?cl=0", 0}, + {"", -1}, + } + for _, tt := range tests { + req, _ := NewRequest("HEAD", cst.ts.URL+tt.suffix, nil) + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + if res.ContentLength != tt.want { + t.Errorf("Content-Length = %d; want %d", res.ContentLength, tt.want) + } + bs, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if len(bs) != 0 { + t.Errorf("Unexpected content: %q", bs) + } + } +} + +func TestEmptyPasswordAuth(t *testing.T) { + setParallel(t) + defer afterTest(t) + gopher := "gopher" + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + auth := r.Header.Get("Authorization") + if strings.HasPrefix(auth, "Basic ") { + encoded := auth[6:] + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + t.Fatal(err) + } + expected := gopher + ":" + s := string(decoded) + if expected != s { + t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected) + } + } else { + t.Errorf("Invalid auth %q", auth) + } + })) + defer ts.Close() + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + req.URL.User = url.User(gopher) + c := ts.Client() + resp, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() +} + +func TestBasicAuth(t *testing.T) { + defer afterTest(t) + tr := &recordingTransport{} + client := &Client{Transport: tr} + + url := "http://My%20User:My%20Pass@dummy.faketld/" + expected := "My User:My Pass" + client.Get(url) + + if tr.req.Method != "GET" { + t.Errorf("got method %q, want %q", tr.req.Method, "GET") + } + if tr.req.URL.String() != url { + t.Errorf("got URL %q, want %q", tr.req.URL.String(), url) + } + if tr.req.Header == nil { + t.Fatalf("expected non-nil request Header") + } + auth := tr.req.Header.Get("Authorization") + if strings.HasPrefix(auth, "Basic ") { + encoded := auth[6:] + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + t.Fatal(err) + } + s := string(decoded) + if expected != s { + t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected) + } + } else { + t.Errorf("Invalid auth %q", auth) + } +} + +func TestBasicAuthHeadersPreserved(t *testing.T) { + defer afterTest(t) + tr := &recordingTransport{} + client := &Client{Transport: tr} + + // If Authorization header is provided, username in URL should not override it + url := "http://My%20User@dummy.faketld/" + req, err := NewRequest("GET", url, nil) + if err != nil { + t.Fatal(err) + } + req.SetBasicAuth("My User", "My Pass") + expected := "My User:My Pass" + client.Do(req) + + if tr.req.Method != "GET" { + t.Errorf("got method %q, want %q", tr.req.Method, "GET") + } + if tr.req.URL.String() != url { + t.Errorf("got URL %q, want %q", tr.req.URL.String(), url) + } + if tr.req.Header == nil { + t.Fatalf("expected non-nil request Header") + } + auth := tr.req.Header.Get("Authorization") + if strings.HasPrefix(auth, "Basic ") { + encoded := auth[6:] + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + t.Fatal(err) + } + s := string(decoded) + if expected != s { + t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected) + } + } else { + t.Errorf("Invalid auth %q", auth) + } + +} + +func TestStripPasswordFromError(t *testing.T) { + client := &Client{Transport: &recordingTransport{}} + testCases := []struct { + desc string + in string + out string + }{ + { + desc: "Strip password from error message", + in: "http://user:password@dummy.faketld/", + out: `Get "http://user:***@dummy.faketld/": dummy impl`, + }, + { + desc: "Don't Strip password from domain name", + in: "http://user:password@password.faketld/", + out: `Get "http://user:***@password.faketld/": dummy impl`, + }, + { + desc: "Don't Strip password from path", + in: "http://user:password@dummy.faketld/password", + out: `Get "http://user:***@dummy.faketld/password": dummy impl`, + }, + { + desc: "Strip escaped password", + in: "http://user:pa%2Fssword@dummy.faketld/", + out: `Get "http://user:***@dummy.faketld/": dummy impl`, + }, + } + for _, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + _, err := client.Get(tC.in) + if err.Error() != tC.out { + t.Errorf("Unexpected output for %q: expected %q, actual %q", + tC.in, tC.out, err.Error()) + } + }) + } +} + +func TestClientTimeout_h1(t *testing.T) { testClientTimeout(t, h1Mode) } +func TestClientTimeout_h2(t *testing.T) { testClientTimeout(t, h2Mode) } + +func testClientTimeout(t *testing.T, h2 bool) { + setParallel(t) + defer afterTest(t) + testDone := make(chan struct{}) // closed in defer below + + sawRoot := make(chan bool, 1) + sawSlow := make(chan bool, 1) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + if r.URL.Path == "/" { + sawRoot <- true + Redirect(w, r, "/slow", StatusFound) + return + } + if r.URL.Path == "/slow" { + sawSlow <- true + w.Write([]byte("Hello")) + w.(Flusher).Flush() + <-testDone + return + } + })) + defer cst.close() + defer close(testDone) // before cst.close, to unblock /slow handler + + // 200ms should be long enough to get a normal request (the / + // handler), but not so long that it makes the test slow. + const timeout = 200 * time.Millisecond + cst.c.Timeout = timeout + + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + if strings.Contains(err.Error(), "Client.Timeout") { + t.Skipf("host too slow to get fast resource in %v", timeout) + } + t.Fatal(err) + } + + select { + case <-sawRoot: + // good. + default: + t.Fatal("handler never got / request") + } + + select { + case <-sawSlow: + // good. + default: + t.Fatal("handler never got /slow request") + } + + errc := make(chan error, 1) + go func() { + _, err := io.ReadAll(res.Body) + errc <- err + res.Body.Close() + }() + + const failTime = 5 * time.Second + select { + case err := <-errc: + if err == nil { + t.Fatal("expected error from ReadAll") + } + ne, ok := err.(net.Error) + if !ok { + t.Errorf("error value from ReadAll was %T; expected some net.Error", err) + } else if !ne.Timeout() { + t.Errorf("net.Error.Timeout = false; want true") + } + if got := ne.Error(); !strings.Contains(got, "(Client.Timeout") { + t.Errorf("error string = %q; missing timeout substring", got) + } + case <-time.After(failTime): + t.Errorf("timeout after %v waiting for timeout of %v", failTime, timeout) + } +} + +func TestClientTimeout_Headers_h1(t *testing.T) { testClientTimeout_Headers(t, h1Mode) } +func TestClientTimeout_Headers_h2(t *testing.T) { testClientTimeout_Headers(t, h2Mode) } + +// Client.Timeout firing before getting to the body +func testClientTimeout_Headers(t *testing.T, h2 bool) { + setParallel(t) + defer afterTest(t) + donec := make(chan bool, 1) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + <-donec + }), optQuietLog) + defer cst.close() + // Note that we use a channel send here and not a close. + // The race detector doesn't know that we're waiting for a timeout + // and thinks that the waitgroup inside httptest.Server is added to concurrently + // with us closing it. If we timed out immediately, we could close the testserver + // before we entered the handler. We're not timing out immediately and there's + // no way we would be done before we entered the handler, but the race detector + // doesn't know this, so synchronize explicitly. + defer func() { donec <- true }() + + cst.c.Timeout = 5 * time.Millisecond + res, err := cst.c.Get(cst.ts.URL) + if err == nil { + res.Body.Close() + t.Fatal("got response from Get; expected error") + } + if _, ok := err.(*url.Error); !ok { + t.Fatalf("Got error of type %T; want *url.Error", err) + } + ne, ok := err.(net.Error) + if !ok { + t.Fatalf("Got error of type %T; want some net.Error", err) + } + if !ne.Timeout() { + t.Error("net.Error.Timeout = false; want true") + } + if got := ne.Error(); !strings.Contains(got, "Client.Timeout exceeded") { + t.Errorf("error string = %q; missing timeout substring", got) + } +} + +// Issue 16094: if Client.Timeout is set but not hit, a Timeout error shouldn't be +// returned. +func TestClientTimeoutCancel(t *testing.T) { + setParallel(t) + defer afterTest(t) + + testDone := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + + cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.(Flusher).Flush() + <-testDone + })) + defer cst.close() + defer close(testDone) + + cst.c.Timeout = 1 * time.Hour + req, _ := NewRequest("GET", cst.ts.URL, nil) + req.Cancel = ctx.Done() + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + cancel() + _, err = io.Copy(io.Discard, res.Body) + if err != ExportErrRequestCanceled { + t.Fatalf("error = %v; want errRequestCanceled", err) + } +} + +func TestClientRedirectEatsBody_h1(t *testing.T) { testClientRedirectEatsBody(t, h1Mode) } +func TestClientRedirectEatsBody_h2(t *testing.T) { testClientRedirectEatsBody(t, h2Mode) } +func testClientRedirectEatsBody(t *testing.T, h2 bool) { + setParallel(t) + defer afterTest(t) + saw := make(chan string, 2) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + saw <- r.RemoteAddr + if r.URL.Path == "/" { + Redirect(w, r, "/foo", StatusFound) // which includes a body + } + })) + defer cst.close() + + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + _, err = io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatal(err) + } + + var first string + select { + case first = <-saw: + default: + t.Fatal("server didn't see a request") + } + + var second string + select { + case second = <-saw: + default: + t.Fatal("server didn't see a second request") + } + + if first != second { + t.Fatal("server saw different client ports before & after the redirect") + } +} + +// eofReaderFunc is an io.Reader that runs itself, and then returns io.EOF. +type eofReaderFunc func() + +func (f eofReaderFunc) Read(p []byte) (n int, err error) { + f() + return 0, io.EOF +} + +func TestReferer(t *testing.T) { + tests := []struct { + lastReq, newReq string // from -> to URLs + want string + }{ + // don't send user: + {"http://gopher@test.com", "http://link.com", "http://test.com"}, + {"https://gopher@test.com", "https://link.com", "https://test.com"}, + + // don't send a user and password: + {"http://gopher:go@test.com", "http://link.com", "http://test.com"}, + {"https://gopher:go@test.com", "https://link.com", "https://test.com"}, + + // nothing to do: + {"http://test.com", "http://link.com", "http://test.com"}, + {"https://test.com", "https://link.com", "https://test.com"}, + + // https to http doesn't send a referer: + {"https://test.com", "http://link.com", ""}, + {"https://gopher:go@test.com", "http://link.com", ""}, + } + for _, tt := range tests { + l, err := url.Parse(tt.lastReq) + if err != nil { + t.Fatal(err) + } + n, err := url.Parse(tt.newReq) + if err != nil { + t.Fatal(err) + } + r := ExportRefererForURL(l, n) + if r != tt.want { + t.Errorf("refererForURL(%q, %q) = %q; want %q", tt.lastReq, tt.newReq, r, tt.want) + } + } +} + +// issue15577Tripper returns a Response with a redirect response +// header and doesn't populate its Response.Request field. +type issue15577Tripper struct{} + +func (issue15577Tripper) RoundTrip(*Request) (*Response, error) { + resp := &Response{ + StatusCode: 303, + Header: map[string][]string{"Location": {"http://www.example.com/"}}, + Body: io.NopCloser(strings.NewReader("")), + } + return resp, nil +} + +// Issue 15577: don't assume the roundtripper's response populates its Request field. +func TestClientRedirectResponseWithoutRequest(t *testing.T) { + c := &Client{ + CheckRedirect: func(*Request, []*Request) error { return fmt.Errorf("no redirects!") }, + Transport: issue15577Tripper{}, + } + // Check that this doesn't crash: + c.Get("http://dummy.tld") +} + +// Issue 4800: copy (some) headers when Client follows a redirect. +func TestClientCopyHeadersOnRedirect(t *testing.T) { + const ( + ua = "some-agent/1.2" + xfoo = "foo-val" + ) + var ts2URL string + ts1 := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + want := Header{ + "User-Agent": []string{ua}, + "X-Foo": []string{xfoo}, + "Referer": []string{ts2URL}, + "Accept-Encoding": []string{"gzip"}, + } + if !reflect.DeepEqual(r.Header, want) { + t.Errorf("Request.Header = %#v; want %#v", r.Header, want) + } + if t.Failed() { + w.Header().Set("Result", "got errors") + } else { + w.Header().Set("Result", "ok") + } + })) + defer ts1.Close() + ts2 := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + Redirect(w, r, ts1.URL, StatusFound) + })) + defer ts2.Close() + ts2URL = ts2.URL + + c := ts1.Client() + c.CheckRedirect = func(r *Request, via []*Request) error { + want := Header{ + "User-Agent": []string{ua}, + "X-Foo": []string{xfoo}, + "Referer": []string{ts2URL}, + } + if !reflect.DeepEqual(r.Header, want) { + t.Errorf("CheckRedirect Request.Header = %#v; want %#v", r.Header, want) + } + return nil + } + + req, _ := NewRequest("GET", ts2.URL, nil) + req.Header.Add("User-Agent", ua) + req.Header.Add("X-Foo", xfoo) + req.Header.Add("Cookie", "foo=bar") + req.Header.Add("Authorization", "secretpassword") + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + t.Fatal(res.Status) + } + if got := res.Header.Get("Result"); got != "ok" { + t.Errorf("result = %q; want ok", got) + } +} + +// Issue 22233: copy host when Client follows a relative redirect. +func TestClientCopyHostOnRedirect(t *testing.T) { + // Virtual hostname: should not receive any request. + virtual := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + t.Errorf("Virtual host received request %v", r.URL) + w.WriteHeader(403) + io.WriteString(w, "should not see this response") + })) + defer virtual.Close() + virtualHost := strings.TrimPrefix(virtual.URL, "http://") + t.Logf("Virtual host is %v", virtualHost) + + // Actual hostname: should not receive any request. + const wantBody = "response body" + var tsURL string + var tsHost string + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + switch r.URL.Path { + case "/": + // Relative redirect. + if r.Host != virtualHost { + t.Errorf("Serving /: Request.Host = %#v; want %#v", r.Host, virtualHost) + w.WriteHeader(404) + return + } + w.Header().Set("Location", "/hop") + w.WriteHeader(302) + case "/hop": + // Absolute redirect. + if r.Host != virtualHost { + t.Errorf("Serving /hop: Request.Host = %#v; want %#v", r.Host, virtualHost) + w.WriteHeader(404) + return + } + w.Header().Set("Location", tsURL+"/final") + w.WriteHeader(302) + case "/final": + if r.Host != tsHost { + t.Errorf("Serving /final: Request.Host = %#v; want %#v", r.Host, tsHost) + w.WriteHeader(404) + return + } + w.WriteHeader(200) + io.WriteString(w, wantBody) + default: + t.Errorf("Serving unexpected path %q", r.URL.Path) + w.WriteHeader(404) + } + })) + defer ts.Close() + tsURL = ts.URL + tsHost = strings.TrimPrefix(ts.URL, "http://") + t.Logf("Server host is %v", tsHost) + + c := ts.Client() + req, _ := NewRequest("GET", ts.URL, nil) + req.Host = virtualHost + resp, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Fatal(resp.Status) + } + if got, err := io.ReadAll(resp.Body); err != nil || string(got) != wantBody { + t.Errorf("body = %q; want %q", got, wantBody) + } +} + +// Issue 17494: cookies should be altered when Client follows redirects. +func TestClientAltersCookiesOnRedirect(t *testing.T) { + cookieMap := func(cs []*Cookie) map[string][]string { + m := make(map[string][]string) + for _, c := range cs { + m[c.Name] = append(m[c.Name], c.Value) + } + return m + } + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + var want map[string][]string + got := cookieMap(r.Cookies()) + + c, _ := r.Cookie("Cycle") + switch c.Value { + case "0": + want = map[string][]string{ + "Cookie1": {"OldValue1a", "OldValue1b"}, + "Cookie2": {"OldValue2"}, + "Cookie3": {"OldValue3a", "OldValue3b"}, + "Cookie4": {"OldValue4"}, + "Cycle": {"0"}, + } + SetCookie(w, &Cookie{Name: "Cycle", Value: "1", Path: "/"}) + SetCookie(w, &Cookie{Name: "Cookie2", Path: "/", MaxAge: -1}) // Delete cookie from Header + Redirect(w, r, "/", StatusFound) + case "1": + want = map[string][]string{ + "Cookie1": {"OldValue1a", "OldValue1b"}, + "Cookie3": {"OldValue3a", "OldValue3b"}, + "Cookie4": {"OldValue4"}, + "Cycle": {"1"}, + } + SetCookie(w, &Cookie{Name: "Cycle", Value: "2", Path: "/"}) + SetCookie(w, &Cookie{Name: "Cookie3", Value: "NewValue3", Path: "/"}) // Modify cookie in Header + SetCookie(w, &Cookie{Name: "Cookie4", Value: "NewValue4", Path: "/"}) // Modify cookie in Jar + Redirect(w, r, "/", StatusFound) + case "2": + want = map[string][]string{ + "Cookie1": {"OldValue1a", "OldValue1b"}, + "Cookie3": {"NewValue3"}, + "Cookie4": {"NewValue4"}, + "Cycle": {"2"}, + } + SetCookie(w, &Cookie{Name: "Cycle", Value: "3", Path: "/"}) + SetCookie(w, &Cookie{Name: "Cookie5", Value: "NewValue5", Path: "/"}) // Insert cookie into Jar + Redirect(w, r, "/", StatusFound) + case "3": + want = map[string][]string{ + "Cookie1": {"OldValue1a", "OldValue1b"}, + "Cookie3": {"NewValue3"}, + "Cookie4": {"NewValue4"}, + "Cookie5": {"NewValue5"}, + "Cycle": {"3"}, + } + // Don't redirect to ensure the loop ends. + default: + t.Errorf("unexpected redirect cycle") + return + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("redirect %s, Cookie = %v, want %v", c.Value, got, want) + } + })) + defer ts.Close() + + jar, _ := cookiejar.New(nil) + c := ts.Client() + c.Jar = jar + + u, _ := url.Parse(ts.URL) + req, _ := NewRequest("GET", ts.URL, nil) + req.AddCookie(&Cookie{Name: "Cookie1", Value: "OldValue1a"}) + req.AddCookie(&Cookie{Name: "Cookie1", Value: "OldValue1b"}) + req.AddCookie(&Cookie{Name: "Cookie2", Value: "OldValue2"}) + req.AddCookie(&Cookie{Name: "Cookie3", Value: "OldValue3a"}) + req.AddCookie(&Cookie{Name: "Cookie3", Value: "OldValue3b"}) + jar.SetCookies(u, []*Cookie{{Name: "Cookie4", Value: "OldValue4", Path: "/"}}) + jar.SetCookies(u, []*Cookie{{Name: "Cycle", Value: "0", Path: "/"}}) + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + t.Fatal(res.Status) + } +} + +// Part of Issue 4800 +func TestShouldCopyHeaderOnRedirect(t *testing.T) { + tests := []struct { + header string + initialURL string + destURL string + want bool + }{ + {"User-Agent", "http://foo.com/", "http://bar.com/", true}, + {"X-Foo", "http://foo.com/", "http://bar.com/", true}, + + // Sensitive headers: + {"cookie", "http://foo.com/", "http://bar.com/", false}, + {"cookie2", "http://foo.com/", "http://bar.com/", false}, + {"authorization", "http://foo.com/", "http://bar.com/", false}, + {"www-authenticate", "http://foo.com/", "http://bar.com/", false}, + + // But subdomains should work: + {"www-authenticate", "http://foo.com/", "http://foo.com/", true}, + {"www-authenticate", "http://foo.com/", "http://sub.foo.com/", true}, + {"www-authenticate", "http://foo.com/", "http://notfoo.com/", false}, + {"www-authenticate", "http://foo.com/", "https://foo.com/", false}, + {"www-authenticate", "http://foo.com:80/", "http://foo.com/", true}, + {"www-authenticate", "http://foo.com:80/", "http://sub.foo.com/", true}, + {"www-authenticate", "http://foo.com:443/", "https://foo.com/", true}, + {"www-authenticate", "http://foo.com:443/", "https://sub.foo.com/", true}, + {"www-authenticate", "http://foo.com:1234/", "http://foo.com/", false}, + } + for i, tt := range tests { + u0, err := url.Parse(tt.initialURL) + if err != nil { + t.Errorf("%d. initial URL %q parse error: %v", i, tt.initialURL, err) + continue + } + u1, err := url.Parse(tt.destURL) + if err != nil { + t.Errorf("%d. dest URL %q parse error: %v", i, tt.destURL, err) + continue + } + got := Export_shouldCopyHeaderOnRedirect(tt.header, u0, u1) + if got != tt.want { + t.Errorf("%d. shouldCopyHeaderOnRedirect(%q, %q => %q) = %v; want %v", + i, tt.header, tt.initialURL, tt.destURL, got, tt.want) + } + } +} + +func TestClientRedirectTypes(t *testing.T) { + setParallel(t) + defer afterTest(t) + + tests := [...]struct { + method string + serverStatus int + wantMethod string // desired subsequent client method + }{ + 0: {method: "POST", serverStatus: 301, wantMethod: "GET"}, + 1: {method: "POST", serverStatus: 302, wantMethod: "GET"}, + 2: {method: "POST", serverStatus: 303, wantMethod: "GET"}, + 3: {method: "POST", serverStatus: 307, wantMethod: "POST"}, + 4: {method: "POST", serverStatus: 308, wantMethod: "POST"}, + + 5: {method: "HEAD", serverStatus: 301, wantMethod: "HEAD"}, + 6: {method: "HEAD", serverStatus: 302, wantMethod: "HEAD"}, + 7: {method: "HEAD", serverStatus: 303, wantMethod: "HEAD"}, + 8: {method: "HEAD", serverStatus: 307, wantMethod: "HEAD"}, + 9: {method: "HEAD", serverStatus: 308, wantMethod: "HEAD"}, + + 10: {method: "GET", serverStatus: 301, wantMethod: "GET"}, + 11: {method: "GET", serverStatus: 302, wantMethod: "GET"}, + 12: {method: "GET", serverStatus: 303, wantMethod: "GET"}, + 13: {method: "GET", serverStatus: 307, wantMethod: "GET"}, + 14: {method: "GET", serverStatus: 308, wantMethod: "GET"}, + + 15: {method: "DELETE", serverStatus: 301, wantMethod: "GET"}, + 16: {method: "DELETE", serverStatus: 302, wantMethod: "GET"}, + 17: {method: "DELETE", serverStatus: 303, wantMethod: "GET"}, + 18: {method: "DELETE", serverStatus: 307, wantMethod: "DELETE"}, + 19: {method: "DELETE", serverStatus: 308, wantMethod: "DELETE"}, + + 20: {method: "PUT", serverStatus: 301, wantMethod: "GET"}, + 21: {method: "PUT", serverStatus: 302, wantMethod: "GET"}, + 22: {method: "PUT", serverStatus: 303, wantMethod: "GET"}, + 23: {method: "PUT", serverStatus: 307, wantMethod: "PUT"}, + 24: {method: "PUT", serverStatus: 308, wantMethod: "PUT"}, + + 25: {method: "MADEUPMETHOD", serverStatus: 301, wantMethod: "GET"}, + 26: {method: "MADEUPMETHOD", serverStatus: 302, wantMethod: "GET"}, + 27: {method: "MADEUPMETHOD", serverStatus: 303, wantMethod: "GET"}, + 28: {method: "MADEUPMETHOD", serverStatus: 307, wantMethod: "MADEUPMETHOD"}, + 29: {method: "MADEUPMETHOD", serverStatus: 308, wantMethod: "MADEUPMETHOD"}, + } + + handlerc := make(chan HandlerFunc, 1) + + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + h := <-handlerc + h(rw, req) + })) + defer ts.Close() + + c := ts.Client() + for i, tt := range tests { + handlerc <- func(w ResponseWriter, r *Request) { + w.Header().Set("Location", ts.URL) + w.WriteHeader(tt.serverStatus) + } + + req, err := NewRequest(tt.method, ts.URL, nil) + if err != nil { + t.Errorf("#%d: NewRequest: %v", i, err) + continue + } + + c.CheckRedirect = func(req *Request, via []*Request) error { + if got, want := req.Method, tt.wantMethod; got != want { + return fmt.Errorf("#%d: got next method %q; want %q", i, got, want) + } + handlerc <- func(rw ResponseWriter, req *Request) { + // TODO: Check that the body is valid when we do 307 and 308 support + } + return nil + } + + res, err := c.Do(req) + if err != nil { + t.Errorf("#%d: Response: %v", i, err) + continue + } + + res.Body.Close() + } +} + +// issue18239Body is an io.ReadCloser for TestTransportBodyReadError. +// Its Read returns readErr and increments *readCalls atomically. +// Its Close returns nil and increments *closeCalls atomically. +type issue18239Body struct { + readCalls *int32 + closeCalls *int32 + readErr error +} + +func (b issue18239Body) Read([]byte) (int, error) { + atomic.AddInt32(b.readCalls, 1) + return 0, b.readErr +} + +func (b issue18239Body) Close() error { + atomic.AddInt32(b.closeCalls, 1) + return nil +} + +// Issue 18239: make sure the Transport doesn't retry requests with bodies +// if Request.GetBody is not defined. +func TestTransportBodyReadError(t *testing.T) { + setParallel(t) + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.URL.Path == "/ping" { + return + } + buf := make([]byte, 1) + n, err := r.Body.Read(buf) + w.Header().Set("X-Body-Read", fmt.Sprintf("%v, %v", n, err)) + })) + defer ts.Close() + c := ts.Client() + tr := c.Transport.(*Transport) + + // Do one initial successful request to create an idle TCP connection + // for the subsequent request to reuse. (The Transport only retries + // requests on reused connections.) + res, err := c.Get(ts.URL + "/ping") + if err != nil { + t.Fatal(err) + } + res.Body.Close() + + var readCallsAtomic int32 + var closeCallsAtomic int32 // atomic + someErr := errors.New("some body read error") + body := issue18239Body{&readCallsAtomic, &closeCallsAtomic, someErr} + + req, err := NewRequest("POST", ts.URL, body) + if err != nil { + t.Fatal(err) + } + req = req.WithT(t) + _, err = tr.RoundTrip(req) + if err != someErr { + t.Errorf("Got error: %v; want Request.Body read error: %v", err, someErr) + } + + // And verify that our Body wasn't used multiple times, which + // would indicate retries. (as it buggily was during part of + // Go 1.8's dev cycle) + readCalls := atomic.LoadInt32(&readCallsAtomic) + closeCalls := atomic.LoadInt32(&closeCallsAtomic) + if readCalls != 1 { + t.Errorf("read calls = %d; want 1", readCalls) + } + if closeCalls != 1 { + t.Errorf("close calls = %d; want 1", closeCalls) + } +} + +type roundTripperWithoutCloseIdle struct{} + +func (roundTripperWithoutCloseIdle) RoundTrip(*Request) (*Response, error) { panic("unused") } + +type roundTripperWithCloseIdle func() // underlying func is CloseIdleConnections func + +func (roundTripperWithCloseIdle) RoundTrip(*Request) (*Response, error) { panic("unused") } +func (f roundTripperWithCloseIdle) CloseIdleConnections() { f() } + +func TestClientCloseIdleConnections(t *testing.T) { + c := &Client{Transport: roundTripperWithoutCloseIdle{}} + c.CloseIdleConnections() // verify we don't crash at least + + closed := false + var tr RoundTripper = roundTripperWithCloseIdle(func() { + closed = true + }) + c = &Client{Transport: tr} + c.CloseIdleConnections() + if !closed { + t.Error("not closed") + } +} + +func TestClientPropagatesTimeoutToContext(t *testing.T) { + errDial := errors.New("not actually dialing") + c := &Client{ + Timeout: 5 * time.Second, + Transport: &Transport{ + DialContext: func(ctx context.Context, netw, addr string) (net.Conn, error) { + deadline, ok := ctx.Deadline() + if !ok { + t.Error("no deadline") + } else { + t.Logf("deadline in %v", deadline.Sub(time.Now()).Round(time.Second/10)) + } + return nil, errDial + }, + }, + } + c.Get("https://example.tld/") +} + +func TestClientDoCanceledVsTimeout_h1(t *testing.T) { + testClientDoCanceledVsTimeout(t, h1Mode) +} + +func TestClientDoCanceledVsTimeout_h2(t *testing.T) { + testClientDoCanceledVsTimeout(t, h2Mode) +} + +// Issue 33545: lock-in the behavior promised by Client.Do's +// docs about request cancelation vs timing out. +func testClientDoCanceledVsTimeout(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write([]byte("Hello, World!")) + })) + defer cst.close() + + cases := []string{"timeout", "canceled"} + + for _, name := range cases { + t.Run(name, func(t *testing.T) { + var ctx context.Context + var cancel func() + if name == "timeout" { + ctx, cancel = context.WithTimeout(context.Background(), -time.Nanosecond) + } else { + ctx, cancel = context.WithCancel(context.Background()) + cancel() + } + defer cancel() + + req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil) + _, err := cst.c.Do(req) + if err == nil { + t.Fatal("Unexpectedly got a nil error") + } + + ue := err.(*url.Error) + + var wantIsTimeout bool + var wantErr error = context.Canceled + if name == "timeout" { + wantErr = context.DeadlineExceeded + wantIsTimeout = true + } + if g, w := ue.Timeout(), wantIsTimeout; g != w { + t.Fatalf("url.Timeout() = %t, want %t", g, w) + } + if g, w := ue.Err, wantErr; g != w { + t.Errorf("url.Error.Err = %v; want %v", g, w) + } + }) + } +} + +type nilBodyRoundTripper struct{} + +func (nilBodyRoundTripper) RoundTrip(req *Request) (*Response, error) { + return &Response{ + StatusCode: StatusOK, + Status: StatusText(StatusOK), + Body: nil, + Request: req, + }, nil +} + +func TestClientPopulatesNilResponseBody(t *testing.T) { + c := &Client{Transport: nilBodyRoundTripper{}} + + resp, err := c.Get("http://localhost/anything") + if err != nil { + t.Fatalf("Client.Get rejected Response with nil Body: %v", err) + } + + if resp.Body == nil { + t.Fatalf("Client failed to provide a non-nil Body as documented") + } + defer func() { + if err := resp.Body.Close(); err != nil { + t.Fatalf("error from Close on substitute Response.Body: %v", err) + } + }() + + if b, err := io.ReadAll(resp.Body); err != nil { + t.Errorf("read error from substitute Response.Body: %v", err) + } else if len(b) != 0 { + t.Errorf("substitute Response.Body was unexpectedly non-empty: %q", b) + } +} + +// Issue 40382: Client calls Close multiple times on Request.Body. +func TestClientCallsCloseOnlyOnce(t *testing.T) { + setParallel(t) + defer afterTest(t) + cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(StatusNoContent) + })) + defer cst.close() + + // Issue occurred non-deterministically: needed to occur after a successful + // write (into TCP buffer) but before end of body. + for i := 0; i < 50 && !t.Failed(); i++ { + body := &issue40382Body{t: t, n: 300000} + req, err := NewRequest(MethodPost, cst.ts.URL, body) + if err != nil { + t.Fatal(err) + } + resp, err := cst.tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + } +} + +// issue40382Body is an io.ReadCloser for TestClientCallsCloseOnlyOnce. +// Its Read reads n bytes before returning io.EOF. +// Its Close returns nil but fails the test if called more than once. +type issue40382Body struct { + t *testing.T + n int + closeCallsAtomic int32 +} + +func (b *issue40382Body) Read(p []byte) (int, error) { + switch { + case b.n == 0: + return 0, io.EOF + case b.n < len(p): + p = p[:b.n] + fallthrough + default: + for i := range p { + p[i] = 'x' + } + b.n -= len(p) + return len(p), nil + } +} + +func (b *issue40382Body) Close() error { + if atomic.AddInt32(&b.closeCallsAtomic, 1) == 2 { + b.t.Error("Body closed more than once") + } + return nil +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/clientserver_test.go b/vendor/github.com/lesismal/llib/std/net/http/clientserver_test.go new file mode 100644 index 0000000..115bcf7 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/clientserver_test.go @@ -0,0 +1,1584 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests that use both the client & server, in both HTTP/1 and HTTP/2 mode. + +package http_test + +import ( + "bytes" + "compress/gzip" + "crypto/rand" + "crypto/sha1" + "crypto/tls" + "fmt" + "github.com/lesismal/llib/std/net/http/httptest" + "github.com/lesismal/llib/std/net/http/httputil" + "hash" + "io" + "log" + "net" + . "net/http" + "net/url" + "os" + "reflect" + "runtime" + "sort" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +type clientServerTest struct { + t *testing.T + h2 bool + h Handler + ts *httptest.Server + tr *Transport + c *Client +} + +func (t *clientServerTest) close() { + t.tr.CloseIdleConnections() + t.ts.Close() +} + +func (t *clientServerTest) getURL(u string) string { + res, err := t.c.Get(u) + if err != nil { + t.t.Fatal(err) + } + defer res.Body.Close() + slurp, err := io.ReadAll(res.Body) + if err != nil { + t.t.Fatal(err) + } + return string(slurp) +} + +func (t *clientServerTest) scheme() string { + if t.h2 { + return "https" + } + return "http" +} + +const ( + h1Mode = false + h2Mode = true +) + +var optQuietLog = func(ts *httptest.Server) { + ts.Config.ErrorLog = quietLog +} + +func optWithServerLog(lg *log.Logger) func(*httptest.Server) { + return func(ts *httptest.Server) { + ts.Config.ErrorLog = lg + } +} + +func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...interface{}) *clientServerTest { + if h2 { + CondSkipHTTP2(t) + } + cst := &clientServerTest{ + t: t, + h2: h2, + h: h, + tr: &Transport{}, + } + cst.c = &Client{Transport: cst.tr} + cst.ts = httptest.NewUnstartedServer(h) + + for _, opt := range opts { + switch opt := opt.(type) { + case func(*Transport): + opt(cst.tr) + case func(*httptest.Server): + opt(cst.ts) + default: + t.Fatalf("unhandled option type %T", opt) + } + } + + if !h2 { + cst.ts.Start() + return cst + } + ExportHttp2ConfigureServer(cst.ts.Config, nil) + cst.ts.TLS = cst.ts.Config.TLSConfig + cst.ts.StartTLS() + + cst.tr.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + if err := ExportHttp2ConfigureTransport(cst.tr); err != nil { + t.Fatal(err) + } + return cst +} + +// Testing the newClientServerTest helper itself. +func TestNewClientServerTest(t *testing.T) { + var got struct { + sync.Mutex + log []string + } + h := HandlerFunc(func(w ResponseWriter, r *Request) { + got.Lock() + defer got.Unlock() + got.log = append(got.log, r.Proto) + }) + for _, v := range [2]bool{false, true} { + cst := newClientServerTest(t, v, h) + if _, err := cst.c.Head(cst.ts.URL); err != nil { + t.Fatal(err) + } + cst.close() + } + got.Lock() // no need to unlock + if want := []string{"HTTP/1.1", "HTTP/2.0"}; !reflect.DeepEqual(got.log, want) { + t.Errorf("got %q; want %q", got.log, want) + } +} + +func TestChunkedResponseHeaders_h1(t *testing.T) { testChunkedResponseHeaders(t, h1Mode) } +func TestChunkedResponseHeaders_h2(t *testing.T) { testChunkedResponseHeaders(t, h2Mode) } + +func testChunkedResponseHeaders(t *testing.T, h2 bool) { + defer afterTest(t) + log.SetOutput(io.Discard) // is noisy otherwise + defer log.SetOutput(os.Stderr) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted + w.(Flusher).Flush() + fmt.Fprintf(w, "I am a chunked response.") + })) + defer cst.close() + + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatalf("Get error: %v", err) + } + defer res.Body.Close() + if g, e := res.ContentLength, int64(-1); g != e { + t.Errorf("expected ContentLength of %d; got %d", e, g) + } + wantTE := []string{"chunked"} + if h2 { + wantTE = nil + } + if !reflect.DeepEqual(res.TransferEncoding, wantTE) { + t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE) + } + if got, haveCL := res.Header["Content-Length"]; haveCL { + t.Errorf("Unexpected Content-Length: %q", got) + } +} + +type reqFunc func(c *Client, url string) (*Response, error) + +// h12Compare is a test that compares HTTP/1 and HTTP/2 behavior +// against each other. +type h12Compare struct { + Handler func(ResponseWriter, *Request) // required + ReqFunc reqFunc // optional + CheckResponse func(proto string, res *Response) // optional + EarlyCheckResponse func(proto string, res *Response) // optional; pre-normalize + Opts []interface{} +} + +func (tt h12Compare) reqFunc() reqFunc { + if tt.ReqFunc == nil { + return (*Client).Get + } + return tt.ReqFunc +} + +func (tt h12Compare) run(t *testing.T) { + setParallel(t) + cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...) + defer cst1.close() + cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...) + defer cst2.close() + + res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL) + if err != nil { + t.Errorf("HTTP/1 request: %v", err) + return + } + res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL) + if err != nil { + t.Errorf("HTTP/2 request: %v", err) + return + } + + if fn := tt.EarlyCheckResponse; fn != nil { + fn("HTTP/1.1", res1) + fn("HTTP/2.0", res2) + } + + tt.normalizeRes(t, res1, "HTTP/1.1") + tt.normalizeRes(t, res2, "HTTP/2.0") + res1body, res2body := res1.Body, res2.Body + + eres1 := mostlyCopy(res1) + eres2 := mostlyCopy(res2) + if !reflect.DeepEqual(eres1, eres2) { + t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v", + cst1.ts.URL, eres1, cst2.ts.URL, eres2) + } + if !reflect.DeepEqual(res1body, res2body) { + t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body) + } + if fn := tt.CheckResponse; fn != nil { + res1.Body, res2.Body = res1body, res2body + fn("HTTP/1.1", res1) + fn("HTTP/2.0", res2) + } +} + +func mostlyCopy(r *Response) *Response { + c := *r + c.Body = nil + c.TransferEncoding = nil + c.TLS = nil + c.Request = nil + return &c +} + +type slurpResult struct { + io.ReadCloser + body []byte + err error +} + +func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) } + +func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) { + if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" { + res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0 + } else { + t.Errorf("got %q response; want %q", res.Proto, wantProto) + } + slurp, err := io.ReadAll(res.Body) + + res.Body.Close() + res.Body = slurpResult{ + ReadCloser: io.NopCloser(bytes.NewReader(slurp)), + body: slurp, + err: err, + } + for i, v := range res.Header["Date"] { + res.Header["Date"][i] = strings.Repeat("x", len(v)) + } + if res.Request == nil { + t.Errorf("for %s, no request", wantProto) + } + if (res.TLS != nil) != (wantProto == "HTTP/2.0") { + t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil) + } +} + +// Issue 13532 +func TestH12_HeadContentLengthNoBody(t *testing.T) { + h12Compare{ + ReqFunc: (*Client).Head, + Handler: func(w ResponseWriter, r *Request) { + }, + }.run(t) +} + +func TestH12_HeadContentLengthSmallBody(t *testing.T) { + h12Compare{ + ReqFunc: (*Client).Head, + Handler: func(w ResponseWriter, r *Request) { + io.WriteString(w, "small") + }, + }.run(t) +} + +func TestH12_HeadContentLengthLargeBody(t *testing.T) { + h12Compare{ + ReqFunc: (*Client).Head, + Handler: func(w ResponseWriter, r *Request) { + chunk := strings.Repeat("x", 512<<10) + for i := 0; i < 10; i++ { + io.WriteString(w, chunk) + } + }, + }.run(t) +} + +func TestH12_200NoBody(t *testing.T) { + h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t) +} + +func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) } +func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) } +func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) } + +func testH12_noBody(t *testing.T, status int) { + h12Compare{Handler: func(w ResponseWriter, r *Request) { + w.WriteHeader(status) + }}.run(t) +} + +func TestH12_SmallBody(t *testing.T) { + h12Compare{Handler: func(w ResponseWriter, r *Request) { + io.WriteString(w, "small body") + }}.run(t) +} + +func TestH12_ExplicitContentLength(t *testing.T) { + h12Compare{Handler: func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "3") + io.WriteString(w, "foo") + }}.run(t) +} + +func TestH12_FlushBeforeBody(t *testing.T) { + h12Compare{Handler: func(w ResponseWriter, r *Request) { + w.(Flusher).Flush() + io.WriteString(w, "foo") + }}.run(t) +} + +func TestH12_FlushMidBody(t *testing.T) { + h12Compare{Handler: func(w ResponseWriter, r *Request) { + io.WriteString(w, "foo") + w.(Flusher).Flush() + io.WriteString(w, "bar") + }}.run(t) +} + +func TestH12_Head_ExplicitLen(t *testing.T) { + h12Compare{ + ReqFunc: (*Client).Head, + Handler: func(w ResponseWriter, r *Request) { + if r.Method != "HEAD" { + t.Errorf("unexpected method %q", r.Method) + } + w.Header().Set("Content-Length", "1235") + }, + }.run(t) +} + +func TestH12_Head_ImplicitLen(t *testing.T) { + h12Compare{ + ReqFunc: (*Client).Head, + Handler: func(w ResponseWriter, r *Request) { + if r.Method != "HEAD" { + t.Errorf("unexpected method %q", r.Method) + } + io.WriteString(w, "foo") + }, + }.run(t) +} + +func TestH12_HandlerWritesTooLittle(t *testing.T) { + h12Compare{ + Handler: func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "3") + io.WriteString(w, "12") // one byte short + }, + CheckResponse: func(proto string, res *Response) { + sr, ok := res.Body.(slurpResult) + if !ok { + t.Errorf("%s body is %T; want slurpResult", proto, res.Body) + return + } + if sr.err != io.ErrUnexpectedEOF { + t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err) + } + if string(sr.body) != "12" { + t.Errorf("%s body = %q; want %q", proto, sr.body, "12") + } + }, + }.run(t) +} + +// Tests that the HTTP/1 and HTTP/2 servers prevent handlers from +// writing more than they declared. This test does not test whether +// the transport deals with too much data, though, since the server +// doesn't make it possible to send bogus data. For those tests, see +// transport_test.go (for HTTP/1) or x/net/http2/transport_test.go +// (for HTTP/2). +func TestH12_HandlerWritesTooMuch(t *testing.T) { + h12Compare{ + Handler: func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "3") + w.(Flusher).Flush() + io.WriteString(w, "123") + w.(Flusher).Flush() + n, err := io.WriteString(w, "x") // too many + if n > 0 || err == nil { + t.Errorf("for proto %q, final write = %v, %v; want 0, some error", r.Proto, n, err) + } + }, + }.run(t) +} + +// Verify that both our HTTP/1 and HTTP/2 request and auto-decompress gzip. +// Some hosts send gzip even if you don't ask for it; see golang.org/issue/13298 +func TestH12_AutoGzip(t *testing.T) { + h12Compare{ + Handler: func(w ResponseWriter, r *Request) { + if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" { + t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae) + } + w.Header().Set("Content-Encoding", "gzip") + gz := gzip.NewWriter(w) + io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.") + gz.Close() + }, + }.run(t) +} + +func TestH12_AutoGzip_Disabled(t *testing.T) { + h12Compare{ + Opts: []interface{}{ + func(tr *Transport) { tr.DisableCompression = true }, + }, + Handler: func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"]) + if ae := r.Header.Get("Accept-Encoding"); ae != "" { + t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae) + } + }, + }.run(t) +} + +// Test304Responses verifies that 304s don't declare that they're +// chunking in their response headers and aren't allowed to produce +// output. +func Test304Responses_h1(t *testing.T) { test304Responses(t, h1Mode) } +func Test304Responses_h2(t *testing.T) { test304Responses(t, h2Mode) } + +func test304Responses(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(StatusNotModified) + _, err := w.Write([]byte("illegal body")) + if err != ErrBodyNotAllowed { + t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err) + } + })) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + if len(res.TransferEncoding) > 0 { + t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) + } + body, err := io.ReadAll(res.Body) + if err != nil { + t.Error(err) + } + if len(body) > 0 { + t.Errorf("got unexpected body %q", string(body)) + } +} + +func TestH12_ServerEmptyContentLength(t *testing.T) { + h12Compare{ + Handler: func(w ResponseWriter, r *Request) { + w.Header()["Content-Type"] = []string{""} + io.WriteString(w, "hi") + }, + }.run(t) +} + +func TestH12_RequestContentLength_Known_NonZero(t *testing.T) { + h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4) +} + +func TestH12_RequestContentLength_Known_Zero(t *testing.T) { + h12requestContentLength(t, func() io.Reader { return nil }, 0) +} + +func TestH12_RequestContentLength_Unknown(t *testing.T) { + h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1) +} + +func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) { + h12Compare{ + Handler: func(w ResponseWriter, r *Request) { + w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength)) + fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength) + }, + ReqFunc: func(c *Client, url string) (*Response, error) { + return c.Post(url, "text/plain", bodyfn()) + }, + CheckResponse: func(proto string, res *Response) { + if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want { + t.Errorf("Proto %q got length %q; want %q", proto, got, want) + } + }, + }.run(t) +} + +// Tests that closing the Request.Cancel channel also while still +// reading the response body. Issue 13159. +func TestCancelRequestMidBody_h1(t *testing.T) { testCancelRequestMidBody(t, h1Mode) } +func TestCancelRequestMidBody_h2(t *testing.T) { testCancelRequestMidBody(t, h2Mode) } +func testCancelRequestMidBody(t *testing.T, h2 bool) { + defer afterTest(t) + unblock := make(chan bool) + didFlush := make(chan bool, 1) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + io.WriteString(w, "Hello") + w.(Flusher).Flush() + didFlush <- true + <-unblock + io.WriteString(w, ", world.") + })) + defer cst.close() + defer close(unblock) + + req, _ := NewRequest("GET", cst.ts.URL, nil) + cancel := make(chan struct{}) + req.Cancel = cancel + + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + <-didFlush + + // Read a bit before we cancel. (Issue 13626) + // We should have "Hello" at least sitting there. + firstRead := make([]byte, 10) + n, err := res.Body.Read(firstRead) + if err != nil { + t.Fatal(err) + } + firstRead = firstRead[:n] + + close(cancel) + + rest, err := io.ReadAll(res.Body) + all := string(firstRead) + string(rest) + if all != "Hello" { + t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest) + } + if err != ExportErrRequestCanceled { + t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled) + } +} + +// Tests that clients can send trailers to a server and that the server can read them. +func TestTrailersClientToServer_h1(t *testing.T) { testTrailersClientToServer(t, h1Mode) } +func TestTrailersClientToServer_h2(t *testing.T) { testTrailersClientToServer(t, h2Mode) } + +func testTrailersClientToServer(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + var decl []string + for k := range r.Trailer { + decl = append(decl, k) + } + sort.Strings(decl) + + slurp, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("Server reading request body: %v", err) + } + if string(slurp) != "foo" { + t.Errorf("Server read request body %q; want foo", slurp) + } + if r.Trailer == nil { + io.WriteString(w, "nil Trailer") + } else { + fmt.Fprintf(w, "decl: %v, vals: %s, %s", + decl, + r.Trailer.Get("Client-Trailer-A"), + r.Trailer.Get("Client-Trailer-B")) + } + })) + defer cst.close() + + var req *Request + req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader( + eofReaderFunc(func() { + req.Trailer["Client-Trailer-A"] = []string{"valuea"} + }), + strings.NewReader("foo"), + eofReaderFunc(func() { + req.Trailer["Client-Trailer-B"] = []string{"valueb"} + }), + )) + req.Trailer = Header{ + "Client-Trailer-A": nil, // to be set later + "Client-Trailer-B": nil, // to be set later + } + req.ContentLength = -1 + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil { + t.Error(err) + } +} + +// Tests that servers send trailers to a client and that the client can read them. +func TestTrailersServerToClient_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, false) } +func TestTrailersServerToClient_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, false) } +func TestTrailersServerToClient_Flush_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, true) } +func TestTrailersServerToClient_Flush_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, true) } + +func testTrailersServerToClient(t *testing.T, h2, flush bool) { + defer afterTest(t) + const body = "Some body" + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B") + w.Header().Add("Trailer", "Server-Trailer-C") + + io.WriteString(w, body) + if flush { + w.(Flusher).Flush() + } + + // How handlers set Trailers: declare it ahead of time + // with the Trailer header, and then mutate the + // Header() of those values later, after the response + // has been written (we wrote to w above). + w.Header().Set("Server-Trailer-A", "valuea") + w.Header().Set("Server-Trailer-C", "valuec") // skipping B + w.Header().Set("Server-Trailer-NotDeclared", "should be omitted") + })) + defer cst.close() + + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + + wantHeader := Header{ + "Content-Type": {"text/plain; charset=utf-8"}, + } + wantLen := -1 + if h2 && !flush { + // In HTTP/1.1, any use of trailers forces HTTP/1.1 + // chunking and a flush at the first write. That's + // unnecessary with HTTP/2's framing, so the server + // is able to calculate the length while still sending + // trailers afterwards. + wantLen = len(body) + wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)} + } + if res.ContentLength != int64(wantLen) { + t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen) + } + + delete(res.Header, "Date") // irrelevant for test + if !reflect.DeepEqual(res.Header, wantHeader) { + t.Errorf("Header = %v; want %v", res.Header, wantHeader) + } + + if got, want := res.Trailer, (Header{ + "Server-Trailer-A": nil, + "Server-Trailer-B": nil, + "Server-Trailer-C": nil, + }); !reflect.DeepEqual(got, want) { + t.Errorf("Trailer before body read = %v; want %v", got, want) + } + + if err := wantBody(res, nil, body); err != nil { + t.Fatal(err) + } + + if got, want := res.Trailer, (Header{ + "Server-Trailer-A": {"valuea"}, + "Server-Trailer-B": nil, + "Server-Trailer-C": {"valuec"}, + }); !reflect.DeepEqual(got, want) { + t.Errorf("Trailer after body read = %v; want %v", got, want) + } +} + +// Don't allow a Body.Read after Body.Close. Issue 13648. +func TestResponseBodyReadAfterClose_h1(t *testing.T) { testResponseBodyReadAfterClose(t, h1Mode) } +func TestResponseBodyReadAfterClose_h2(t *testing.T) { testResponseBodyReadAfterClose(t, h2Mode) } + +func testResponseBodyReadAfterClose(t *testing.T, h2 bool) { + defer afterTest(t) + const body = "Some body" + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + io.WriteString(w, body) + })) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + data, err := io.ReadAll(res.Body) + if len(data) != 0 || err == nil { + t.Fatalf("ReadAll returned %q, %v; want error", data, err) + } +} + +func TestConcurrentReadWriteReqBody_h1(t *testing.T) { testConcurrentReadWriteReqBody(t, h1Mode) } +func TestConcurrentReadWriteReqBody_h2(t *testing.T) { testConcurrentReadWriteReqBody(t, h2Mode) } +func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { + defer afterTest(t) + const reqBody = "some request body" + const resBody = "some response body" + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + var wg sync.WaitGroup + wg.Add(2) + didRead := make(chan bool, 1) + // Read in one goroutine. + go func() { + defer wg.Done() + data, err := io.ReadAll(r.Body) + if string(data) != reqBody { + t.Errorf("Handler read %q; want %q", data, reqBody) + } + if err != nil { + t.Errorf("Handler Read: %v", err) + } + didRead <- true + }() + // Write in another goroutine. + go func() { + defer wg.Done() + if !h2 { + // our HTTP/1 implementation intentionally + // doesn't permit writes during read (mostly + // due to it being undefined); if that is ever + // relaxed, change this. + <-didRead + } + io.WriteString(w, resBody) + }() + wg.Wait() + })) + defer cst.close() + req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody)) + req.Header.Add("Expect", "100-continue") // just to complicate things + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + data, err := io.ReadAll(res.Body) + defer res.Body.Close() + if err != nil { + t.Fatal(err) + } + if string(data) != resBody { + t.Errorf("read %q; want %q", data, resBody) + } +} + +func TestConnectRequest_h1(t *testing.T) { testConnectRequest(t, h1Mode) } +func TestConnectRequest_h2(t *testing.T) { testConnectRequest(t, h2Mode) } +func testConnectRequest(t *testing.T, h2 bool) { + defer afterTest(t) + gotc := make(chan *Request, 1) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + gotc <- r + })) + defer cst.close() + + u, err := url.Parse(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + + tests := []struct { + req *Request + want string + }{ + { + req: &Request{ + Method: "CONNECT", + Header: Header{}, + URL: u, + }, + want: u.Host, + }, + { + req: &Request{ + Method: "CONNECT", + Header: Header{}, + URL: u, + Host: "example.com:123", + }, + want: "example.com:123", + }, + } + + for i, tt := range tests { + res, err := cst.c.Do(tt.req) + if err != nil { + t.Errorf("%d. RoundTrip = %v", i, err) + continue + } + res.Body.Close() + req := <-gotc + if req.Method != "CONNECT" { + t.Errorf("method = %q; want CONNECT", req.Method) + } + if req.Host != tt.want { + t.Errorf("Host = %q; want %q", req.Host, tt.want) + } + if req.URL.Host != tt.want { + t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want) + } + } +} + +func TestTransportUserAgent_h1(t *testing.T) { testTransportUserAgent(t, h1Mode) } +func TestTransportUserAgent_h2(t *testing.T) { testTransportUserAgent(t, h2Mode) } +func testTransportUserAgent(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "%q", r.Header["User-Agent"]) + })) + defer cst.close() + + either := func(a, b string) string { + if h2 { + return b + } + return a + } + + tests := []struct { + setup func(*Request) + want string + }{ + { + func(r *Request) {}, + either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`), + }, + { + func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") }, + `["foo/1.2.3"]`, + }, + { + func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} }, + `["single"]`, + }, + { + func(r *Request) { r.Header.Set("User-Agent", "") }, + `[]`, + }, + { + func(r *Request) { r.Header["User-Agent"] = nil }, + `[]`, + }, + } + for i, tt := range tests { + req, _ := NewRequest("GET", cst.ts.URL, nil) + tt.setup(req) + res, err := cst.c.Do(req) + if err != nil { + t.Errorf("%d. RoundTrip = %v", i, err) + continue + } + slurp, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("%d. read body = %v", i, err) + continue + } + if string(slurp) != tt.want { + t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want) + } + } +} + +func TestStarRequestFoo_h1(t *testing.T) { testStarRequest(t, "FOO", h1Mode) } +func TestStarRequestFoo_h2(t *testing.T) { testStarRequest(t, "FOO", h2Mode) } +func TestStarRequestOptions_h1(t *testing.T) { testStarRequest(t, "OPTIONS", h1Mode) } +func TestStarRequestOptions_h2(t *testing.T) { testStarRequest(t, "OPTIONS", h2Mode) } +func testStarRequest(t *testing.T, method string, h2 bool) { + defer afterTest(t) + gotc := make(chan *Request, 1) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("foo", "bar") + gotc <- r + w.(Flusher).Flush() + })) + defer cst.close() + + u, err := url.Parse(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + u.Path = "*" + + req := &Request{ + Method: method, + Header: Header{}, + URL: u, + } + + res, err := cst.c.Do(req) + if err != nil { + t.Fatalf("RoundTrip = %v", err) + } + res.Body.Close() + + wantFoo := "bar" + wantLen := int64(-1) + if method == "OPTIONS" { + wantFoo = "" + wantLen = 0 + } + if res.StatusCode != 200 { + t.Errorf("status code = %v; want %d", res.Status, 200) + } + if res.ContentLength != wantLen { + t.Errorf("content length = %v; want %d", res.ContentLength, wantLen) + } + if got := res.Header.Get("foo"); got != wantFoo { + t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo) + } + select { + case req = <-gotc: + default: + req = nil + } + if req == nil { + if method != "OPTIONS" { + t.Fatalf("handler never got request") + } + return + } + if req.Method != method { + t.Errorf("method = %q; want %q", req.Method, method) + } + if req.URL.Path != "*" { + t.Errorf("URL.Path = %q; want *", req.URL.Path) + } + if req.RequestURI != "*" { + t.Errorf("RequestURI = %q; want *", req.RequestURI) + } +} + +// Issue 13957 +func TestTransportDiscardsUnneededConns(t *testing.T) { + setParallel(t) + defer afterTest(t) + cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "Hello, %v", r.RemoteAddr) + })) + defer cst.close() + + var numOpen, numClose int32 // atomic + + tlsConfig := &tls.Config{InsecureSkipVerify: true} + tr := &Transport{ + TLSClientConfig: tlsConfig, + DialTLS: func(_, addr string) (net.Conn, error) { + time.Sleep(10 * time.Millisecond) + rc, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + atomic.AddInt32(&numOpen, 1) + c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }} + return tls.Client(c, tlsConfig), nil + }, + } + if err := ExportHttp2ConfigureTransport(tr); err != nil { + t.Fatal(err) + } + defer tr.CloseIdleConnections() + + c := &Client{Transport: tr} + + const N = 10 + gotBody := make(chan string, N) + var wg sync.WaitGroup + for i := 0; i < N; i++ { + wg.Add(1) + go func() { + defer wg.Done() + resp, err := c.Get(cst.ts.URL) + if err != nil { + // Try to work around spurious connection reset on loaded system. + // See golang.org/issue/33585 and golang.org/issue/36797. + time.Sleep(10 * time.Millisecond) + resp, err = c.Get(cst.ts.URL) + if err != nil { + t.Errorf("Get: %v", err) + return + } + } + defer resp.Body.Close() + slurp, err := io.ReadAll(resp.Body) + if err != nil { + t.Error(err) + } + gotBody <- string(slurp) + }() + } + wg.Wait() + close(gotBody) + + var last string + for got := range gotBody { + if last == "" { + last = got + continue + } + if got != last { + t.Errorf("Response body changed: %q -> %q", last, got) + } + } + + var open, close int32 + for i := 0; i < 150; i++ { + open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose) + if open < 1 { + t.Fatalf("open = %d; want at least", open) + } + if close == open-1 { + // Success + return + } + time.Sleep(10 * time.Millisecond) + } + t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1) +} + +// tests that Transport doesn't retain a pointer to the provided request. +func TestTransportGCRequest_Body_h1(t *testing.T) { testTransportGCRequest(t, h1Mode, true) } +func TestTransportGCRequest_Body_h2(t *testing.T) { testTransportGCRequest(t, h2Mode, true) } +func TestTransportGCRequest_NoBody_h1(t *testing.T) { testTransportGCRequest(t, h1Mode, false) } +func TestTransportGCRequest_NoBody_h2(t *testing.T) { testTransportGCRequest(t, h2Mode, false) } +func testTransportGCRequest(t *testing.T, h2, body bool) { + setParallel(t) + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + io.ReadAll(r.Body) + if body { + io.WriteString(w, "Hello.") + } + })) + defer cst.close() + + didGC := make(chan struct{}) + (func() { + body := strings.NewReader("some body") + req, _ := NewRequest("POST", cst.ts.URL, body) + runtime.SetFinalizer(req, func(*Request) { close(didGC) }) + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + if _, err := io.ReadAll(res.Body); err != nil { + t.Fatal(err) + } + if err := res.Body.Close(); err != nil { + t.Fatal(err) + } + })() + timeout := time.NewTimer(5 * time.Second) + defer timeout.Stop() + for { + select { + case <-didGC: + return + case <-time.After(100 * time.Millisecond): + runtime.GC() + case <-timeout.C: + t.Fatal("never saw GC of request") + } + } +} + +func TestTransportRejectsInvalidHeaders_h1(t *testing.T) { + testTransportRejectsInvalidHeaders(t, h1Mode) +} +func TestTransportRejectsInvalidHeaders_h2(t *testing.T) { + testTransportRejectsInvalidHeaders(t, h2Mode) +} +func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) { + setParallel(t) + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "Handler saw headers: %q", r.Header) + }), optQuietLog) + defer cst.close() + cst.tr.DisableKeepAlives = true + + tests := []struct { + key, val string + ok bool + }{ + {"Foo", "capital-key", true}, // verify h2 allows capital keys + {"Foo", "foo\x00bar", false}, // \x00 byte in value not allowed + {"Foo", "two\nlines", false}, // \n byte in value not allowed + {"bogus\nkey", "v", false}, // \n byte also not allowed in key + {"A space", "v", false}, // spaces in keys not allowed + {"имя", "v", false}, // key must be ascii + {"name", "валю", true}, // value may be non-ascii + {"", "v", false}, // key must be non-empty + {"k", "", true}, // value may be empty + } + for _, tt := range tests { + dialedc := make(chan bool, 1) + cst.tr.Dial = func(netw, addr string) (net.Conn, error) { + dialedc <- true + return net.Dial(netw, addr) + } + req, _ := NewRequest("GET", cst.ts.URL, nil) + req.Header[tt.key] = []string{tt.val} + res, err := cst.c.Do(req) + var body []byte + if err == nil { + body, _ = io.ReadAll(res.Body) + res.Body.Close() + } + var dialed bool + select { + case <-dialedc: + dialed = true + default: + } + + if !tt.ok && dialed { + t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body) + } else if (err == nil) != tt.ok { + t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok) + } + } +} + +func TestInterruptWithPanic_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, "boom") } +func TestInterruptWithPanic_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, "boom") } +func TestInterruptWithPanic_nil_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, nil) } +func TestInterruptWithPanic_nil_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, nil) } +func TestInterruptWithPanic_ErrAbortHandler_h1(t *testing.T) { + testInterruptWithPanic(t, h1Mode, ErrAbortHandler) +} +func TestInterruptWithPanic_ErrAbortHandler_h2(t *testing.T) { + testInterruptWithPanic(t, h2Mode, ErrAbortHandler) +} +func testInterruptWithPanic(t *testing.T, h2 bool, panicValue interface{}) { + setParallel(t) + const msg = "hello" + defer afterTest(t) + + testDone := make(chan struct{}) + defer close(testDone) + + var errorLog lockedBytesBuffer + gotHeaders := make(chan bool, 1) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + io.WriteString(w, msg) + w.(Flusher).Flush() + + select { + case <-gotHeaders: + case <-testDone: + } + panic(panicValue) + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(&errorLog, "", 0) + }) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + gotHeaders <- true + defer res.Body.Close() + slurp, err := io.ReadAll(res.Body) + if string(slurp) != msg { + t.Errorf("client read %q; want %q", slurp, msg) + } + if err == nil { + t.Errorf("client read all successfully; want some error") + } + logOutput := func() string { + errorLog.Lock() + defer errorLog.Unlock() + return errorLog.String() + } + wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler + + if err := waitErrCondition(5*time.Second, 10*time.Millisecond, func() error { + gotLog := logOutput() + if !wantStackLogged { + if gotLog == "" { + return nil + } + return fmt.Errorf("want no log output; got: %s", gotLog) + } + if gotLog == "" { + return fmt.Errorf("wanted a stack trace logged; got nothing") + } + if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 { + return fmt.Errorf("output doesn't look like a panic stack trace. Got: %s", gotLog) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +type lockedBytesBuffer struct { + sync.Mutex + bytes.Buffer +} + +func (b *lockedBytesBuffer) Write(p []byte) (int, error) { + b.Lock() + defer b.Unlock() + return b.Buffer.Write(p) +} + +// Issue 15366 +func TestH12_AutoGzipWithDumpResponse(t *testing.T) { + h12Compare{ + Handler: func(w ResponseWriter, r *Request) { + h := w.Header() + h.Set("Content-Encoding", "gzip") + h.Set("Content-Length", "23") + io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00") + }, + EarlyCheckResponse: func(proto string, res *Response) { + if !res.Uncompressed { + t.Errorf("%s: expected Uncompressed to be set", proto) + } + dump, err := httputil.DumpResponse(res, true) + if err != nil { + t.Errorf("%s: DumpResponse: %v", proto, err) + return + } + if strings.Contains(string(dump), "Connection: close") { + t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump) + } + if !strings.Contains(string(dump), "FOO") { + t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump) + } + }, + }.run(t) +} + +// Issue 14607 +func TestCloseIdleConnections_h1(t *testing.T) { testCloseIdleConnections(t, h1Mode) } +func TestCloseIdleConnections_h2(t *testing.T) { testCloseIdleConnections(t, h2Mode) } +func testCloseIdleConnections(t *testing.T, h2 bool) { + setParallel(t) + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("X-Addr", r.RemoteAddr) + })) + defer cst.close() + get := func() string { + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + v := res.Header.Get("X-Addr") + if v == "" { + t.Fatal("didn't get X-Addr") + } + return v + } + a1 := get() + cst.tr.CloseIdleConnections() + a2 := get() + if a1 == a2 { + t.Errorf("didn't close connection") + } +} + +type noteCloseConn struct { + net.Conn + closeFunc func() +} + +func (x noteCloseConn) Close() error { + x.closeFunc() + return x.Conn.Close() +} + +type testErrorReader struct{ t *testing.T } + +func (r testErrorReader) Read(p []byte) (n int, err error) { + r.t.Error("unexpected Read call") + return 0, io.EOF +} + +func TestNoSniffExpectRequestBody_h1(t *testing.T) { testNoSniffExpectRequestBody(t, h1Mode) } +func TestNoSniffExpectRequestBody_h2(t *testing.T) { testNoSniffExpectRequestBody(t, h2Mode) } + +func testNoSniffExpectRequestBody(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(StatusUnauthorized) + })) + defer cst.close() + + // Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it. + cst.tr.ExpectContinueTimeout = 10 * time.Second + + req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t}) + if err != nil { + t.Fatal(err) + } + req.ContentLength = 0 // so transport is tempted to sniff it + req.Header.Set("Expect", "100-continue") + res, err := cst.tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != StatusUnauthorized { + t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized) + } +} + +func TestServerUndeclaredTrailers_h1(t *testing.T) { testServerUndeclaredTrailers(t, h1Mode) } +func TestServerUndeclaredTrailers_h2(t *testing.T) { testServerUndeclaredTrailers(t, h2Mode) } +func testServerUndeclaredTrailers(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Foo", "Bar") + w.Header().Set("Trailer:Foo", "Baz") + w.(Flusher).Flush() + w.Header().Add("Trailer:Foo", "Baz2") + w.Header().Set("Trailer:Bar", "Quux") + })) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + if _, err := io.Copy(io.Discard, res.Body); err != nil { + t.Fatal(err) + } + res.Body.Close() + delete(res.Header, "Date") + delete(res.Header, "Content-Type") + + if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) { + t.Errorf("Header = %#v; want %#v", res.Header, want) + } + if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) { + t.Errorf("Trailer = %#v; want %#v", res.Trailer, want) + } +} + +func TestBadResponseAfterReadingBody(t *testing.T) { + defer afterTest(t) + cst := newClientServerTest(t, false, HandlerFunc(func(w ResponseWriter, r *Request) { + _, err := io.Copy(io.Discard, r.Body) + if err != nil { + t.Fatal(err) + } + c, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Fatal(err) + } + defer c.Close() + fmt.Fprintln(c, "some bogus crap") + })) + defer cst.close() + + closes := 0 + res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) + if err == nil { + res.Body.Close() + t.Fatal("expected an error to be returned from Post") + } + if closes != 1 { + t.Errorf("closes = %d; want 1", closes) + } +} + +func TestWriteHeader0_h1(t *testing.T) { testWriteHeader0(t, h1Mode) } +func TestWriteHeader0_h2(t *testing.T) { testWriteHeader0(t, h2Mode) } +func testWriteHeader0(t *testing.T, h2 bool) { + defer afterTest(t) + gotpanic := make(chan bool, 1) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + defer close(gotpanic) + defer func() { + if e := recover(); e != nil { + got := fmt.Sprintf("%T, %v", e, e) + want := "string, invalid WriteHeader code 0" + if got != want { + t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want) + } + gotpanic <- true + + // Set an explicit 503. This also tests that the WriteHeader call panics + // before it recorded that an explicit value was set and that bogus + // value wasn't stuck. + w.WriteHeader(503) + } + }() + w.WriteHeader(0) + })) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 503 { + t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status) + } + if !<-gotpanic { + t.Error("expected panic in handler") + } +} + +// Issue 23010: don't be super strict checking WriteHeader's code if +// it's not even valid to call WriteHeader then anyway. +func TestWriteHeaderNoCodeCheck_h1(t *testing.T) { testWriteHeaderAfterWrite(t, h1Mode, false) } +func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { testWriteHeaderAfterWrite(t, h1Mode, true) } +func TestWriteHeaderNoCodeCheck_h2(t *testing.T) { testWriteHeaderAfterWrite(t, h2Mode, false) } +func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { + setParallel(t) + defer afterTest(t) + + var errorLog lockedBytesBuffer + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + if hijack { + conn, _, _ := w.(Hijacker).Hijack() + defer conn.Close() + conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo")) + w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010 + conn.Write([]byte("bar")) + return + } + io.WriteString(w, "foo") + w.(Flusher).Flush() + w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010 + io.WriteString(w, "bar") + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(&errorLog, "", 0) + }) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if got, want := string(body), "foobar"; got != want { + t.Errorf("got = %q; want %q", got, want) + } + + // Also check the stderr output: + if h2 { + // TODO: also emit this log message for HTTP/2? + // We historically haven't, so don't check. + return + } + gotLog := strings.TrimSpace(errorLog.String()) + wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:" + if hijack { + wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:" + } + if !strings.HasPrefix(gotLog, wantLog) { + t.Errorf("stderr output = %q; want %q", gotLog, wantLog) + } +} + +func TestBidiStreamReverseProxy(t *testing.T) { + setParallel(t) + defer afterTest(t) + backend := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + if _, err := io.Copy(w, r.Body); err != nil { + log.Printf("bidi backend copy: %v", err) + } + })) + defer backend.close() + + backURL, err := url.Parse(backend.ts.URL) + if err != nil { + t.Fatal(err) + } + rp := httputil.NewSingleHostReverseProxy(backURL) + rp.Transport = backend.tr + proxy := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + rp.ServeHTTP(w, r) + })) + defer proxy.close() + + bodyRes := make(chan interface{}, 1) // error or hash.Hash + pr, pw := io.Pipe() + req, _ := NewRequest("PUT", proxy.ts.URL, pr) + const size = 4 << 20 + go func() { + h := sha1.New() + _, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size) + go pw.Close() + if err != nil { + bodyRes <- err + } else { + bodyRes <- h + } + }() + res, err := backend.c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + hgot := sha1.New() + n, err := io.Copy(hgot, res.Body) + if err != nil { + t.Fatal(err) + } + if n != size { + t.Fatalf("got %d bytes; want %d", n, size) + } + select { + case v := <-bodyRes: + switch v := v.(type) { + default: + t.Fatalf("body copy: %v", err) + case hash.Hash: + if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) { + t.Errorf("written bytes didn't match received bytes") + } + } + case <-time.After(10 * time.Second): + t.Fatal("timeout") + } + +} + +// Always use HTTP/1.1 for WebSocket upgrades. +func TestH12_WebSocketUpgrade(t *testing.T) { + h12Compare{ + Handler: func(w ResponseWriter, r *Request) { + h := w.Header() + h.Set("Foo", "bar") + }, + ReqFunc: func(c *Client, url string) (*Response, error) { + req, _ := NewRequest("GET", url, nil) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "WebSocket") + return c.Do(req) + }, + EarlyCheckResponse: func(proto string, res *Response) { + if res.Proto != "HTTP/1.1" { + t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto) + } + res.Proto = "HTTP/IGNORE" // skip later checks that Proto must be 1.1 vs 2.0 + }, + }.run(t) +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/clone.go b/vendor/github.com/lesismal/llib/std/net/http/clone.go new file mode 100644 index 0000000..3a3375b --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/clone.go @@ -0,0 +1,74 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "mime/multipart" + "net/textproto" + "net/url" +) + +func cloneURLValues(v url.Values) url.Values { + if v == nil { + return nil + } + // http.Header and url.Values have the same representation, so temporarily + // treat it like http.Header, which does have a clone: + return url.Values(Header(v).Clone()) +} + +func cloneURL(u *url.URL) *url.URL { + if u == nil { + return nil + } + u2 := new(url.URL) + *u2 = *u + if u.User != nil { + u2.User = new(url.Userinfo) + *u2.User = *u.User + } + return u2 +} + +func cloneMultipartForm(f *multipart.Form) *multipart.Form { + if f == nil { + return nil + } + f2 := &multipart.Form{ + Value: (map[string][]string)(Header(f.Value).Clone()), + } + if f.File != nil { + m := make(map[string][]*multipart.FileHeader) + for k, vv := range f.File { + vv2 := make([]*multipart.FileHeader, len(vv)) + for i, v := range vv { + vv2[i] = cloneMultipartFileHeader(v) + } + m[k] = vv2 + } + f2.File = m + } + return f2 +} + +func cloneMultipartFileHeader(fh *multipart.FileHeader) *multipart.FileHeader { + if fh == nil { + return nil + } + fh2 := new(multipart.FileHeader) + *fh2 = *fh + fh2.Header = textproto.MIMEHeader(Header(fh.Header).Clone()) + return fh2 +} + +// cloneOrMakeHeader invokes Header.Clone but if the +// result is nil, it'll instead make and return a non-nil Header. +func cloneOrMakeHeader(hdr Header) Header { + clone := hdr.Clone() + if clone == nil { + clone = make(Header) + } + return clone +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/cookie.go b/vendor/github.com/lesismal/llib/std/net/http/cookie.go new file mode 100644 index 0000000..141bc94 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/cookie.go @@ -0,0 +1,433 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "log" + "net" + "net/textproto" + "strconv" + "strings" + "time" +) + +// A Cookie represents an HTTP cookie as sent in the Set-Cookie header of an +// HTTP response or the Cookie header of an HTTP request. +// +// See https://tools.ietf.org/html/rfc6265 for details. +type Cookie struct { + Name string + Value string + + Path string // optional + Domain string // optional + Expires time.Time // optional + RawExpires string // for reading cookies only + + // MaxAge=0 means no 'Max-Age' attribute specified. + // MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0' + // MaxAge>0 means Max-Age attribute present and given in seconds + MaxAge int + Secure bool + HttpOnly bool + SameSite SameSite + Raw string + Unparsed []string // Raw text of unparsed attribute-value pairs +} + +// SameSite allows a server to define a cookie attribute making it impossible for +// the browser to send this cookie along with cross-site requests. The main +// goal is to mitigate the risk of cross-origin information leakage, and provide +// some protection against cross-site request forgery attacks. +// +// See https://tools.ietf.org/html/draft-ietf-httpbis-cookie-same-site-00 for details. +type SameSite int + +const ( + SameSiteDefaultMode SameSite = iota + 1 + SameSiteLaxMode + SameSiteStrictMode + SameSiteNoneMode +) + +// readSetCookies parses all "Set-Cookie" values from +// the header h and returns the successfully parsed Cookies. +func readSetCookies(h Header) []*Cookie { + cookieCount := len(h["Set-Cookie"]) + if cookieCount == 0 { + return []*Cookie{} + } + cookies := make([]*Cookie, 0, cookieCount) + for _, line := range h["Set-Cookie"] { + parts := strings.Split(textproto.TrimString(line), ";") + if len(parts) == 1 && parts[0] == "" { + continue + } + parts[0] = textproto.TrimString(parts[0]) + j := strings.Index(parts[0], "=") + if j < 0 { + continue + } + name, value := parts[0][:j], parts[0][j+1:] + if !isCookieNameValid(name) { + continue + } + value, ok := parseCookieValue(value, true) + if !ok { + continue + } + c := &Cookie{ + Name: name, + Value: value, + Raw: line, + } + for i := 1; i < len(parts); i++ { + parts[i] = textproto.TrimString(parts[i]) + if len(parts[i]) == 0 { + continue + } + + attr, val := parts[i], "" + if j := strings.Index(attr, "="); j >= 0 { + attr, val = attr[:j], attr[j+1:] + } + lowerAttr := strings.ToLower(attr) + val, ok = parseCookieValue(val, false) + if !ok { + c.Unparsed = append(c.Unparsed, parts[i]) + continue + } + switch lowerAttr { + case "samesite": + lowerVal := strings.ToLower(val) + switch lowerVal { + case "lax": + c.SameSite = SameSiteLaxMode + case "strict": + c.SameSite = SameSiteStrictMode + case "none": + c.SameSite = SameSiteNoneMode + default: + c.SameSite = SameSiteDefaultMode + } + continue + case "secure": + c.Secure = true + continue + case "httponly": + c.HttpOnly = true + continue + case "domain": + c.Domain = val + continue + case "max-age": + secs, err := strconv.Atoi(val) + if err != nil || secs != 0 && val[0] == '0' { + break + } + if secs <= 0 { + secs = -1 + } + c.MaxAge = secs + continue + case "expires": + c.RawExpires = val + exptime, err := time.Parse(time.RFC1123, val) + if err != nil { + exptime, err = time.Parse("Mon, 02-Jan-2006 15:04:05 MST", val) + if err != nil { + c.Expires = time.Time{} + break + } + } + c.Expires = exptime.UTC() + continue + case "path": + c.Path = val + continue + } + c.Unparsed = append(c.Unparsed, parts[i]) + } + cookies = append(cookies, c) + } + return cookies +} + +// SetCookie adds a Set-Cookie header to the provided ResponseWriter's headers. +// The provided cookie must have a valid Name. Invalid cookies may be +// silently dropped. +func SetCookie(w ResponseWriter, cookie *Cookie) { + if v := cookie.String(); v != "" { + w.Header().Add("Set-Cookie", v) + } +} + +// String returns the serialization of the cookie for use in a Cookie +// header (if only Name and Value are set) or a Set-Cookie response +// header (if other fields are set). +// If c is nil or c.Name is invalid, the empty string is returned. +func (c *Cookie) String() string { + if c == nil || !isCookieNameValid(c.Name) { + return "" + } + // extraCookieLength derived from typical length of cookie attributes + // see RFC 6265 Sec 4.1. + const extraCookieLength = 110 + var b strings.Builder + b.Grow(len(c.Name) + len(c.Value) + len(c.Domain) + len(c.Path) + extraCookieLength) + b.WriteString(c.Name) + b.WriteRune('=') + b.WriteString(sanitizeCookieValue(c.Value)) + + if len(c.Path) > 0 { + b.WriteString("; Path=") + b.WriteString(sanitizeCookiePath(c.Path)) + } + if len(c.Domain) > 0 { + if validCookieDomain(c.Domain) { + // A c.Domain containing illegal characters is not + // sanitized but simply dropped which turns the cookie + // into a host-only cookie. A leading dot is okay + // but won't be sent. + d := c.Domain + if d[0] == '.' { + d = d[1:] + } + b.WriteString("; Domain=") + b.WriteString(d) + } else { + log.Printf("net/http: invalid Cookie.Domain %q; dropping domain attribute", c.Domain) + } + } + var buf [len(TimeFormat)]byte + if validCookieExpires(c.Expires) { + b.WriteString("; Expires=") + b.Write(c.Expires.UTC().AppendFormat(buf[:0], TimeFormat)) + } + if c.MaxAge > 0 { + b.WriteString("; Max-Age=") + b.Write(strconv.AppendInt(buf[:0], int64(c.MaxAge), 10)) + } else if c.MaxAge < 0 { + b.WriteString("; Max-Age=0") + } + if c.HttpOnly { + b.WriteString("; HttpOnly") + } + if c.Secure { + b.WriteString("; Secure") + } + switch c.SameSite { + case SameSiteDefaultMode: + // Skip, default mode is obtained by not emitting the attribute. + case SameSiteNoneMode: + b.WriteString("; SameSite=None") + case SameSiteLaxMode: + b.WriteString("; SameSite=Lax") + case SameSiteStrictMode: + b.WriteString("; SameSite=Strict") + } + return b.String() +} + +// readCookies parses all "Cookie" values from the header h and +// returns the successfully parsed Cookies. +// +// if filter isn't empty, only cookies of that name are returned +func readCookies(h Header, filter string) []*Cookie { + lines := h["Cookie"] + if len(lines) == 0 { + return []*Cookie{} + } + + cookies := make([]*Cookie, 0, len(lines)+strings.Count(lines[0], ";")) + for _, line := range lines { + line = textproto.TrimString(line) + + var part string + for len(line) > 0 { // continue since we have rest + if splitIndex := strings.Index(line, ";"); splitIndex > 0 { + part, line = line[:splitIndex], line[splitIndex+1:] + } else { + part, line = line, "" + } + part = textproto.TrimString(part) + if len(part) == 0 { + continue + } + name, val := part, "" + if j := strings.Index(part, "="); j >= 0 { + name, val = name[:j], name[j+1:] + } + if !isCookieNameValid(name) { + continue + } + if filter != "" && filter != name { + continue + } + val, ok := parseCookieValue(val, true) + if !ok { + continue + } + cookies = append(cookies, &Cookie{Name: name, Value: val}) + } + } + return cookies +} + +// validCookieDomain reports whether v is a valid cookie domain-value. +func validCookieDomain(v string) bool { + if isCookieDomainName(v) { + return true + } + if net.ParseIP(v) != nil && !strings.Contains(v, ":") { + return true + } + return false +} + +// validCookieExpires reports whether v is a valid cookie expires-value. +func validCookieExpires(t time.Time) bool { + // IETF RFC 6265 Section 5.1.1.5, the year must not be less than 1601 + return t.Year() >= 1601 +} + +// isCookieDomainName reports whether s is a valid domain name or a valid +// domain name with a leading dot '.'. It is almost a direct copy of +// package net's isDomainName. +func isCookieDomainName(s string) bool { + if len(s) == 0 { + return false + } + if len(s) > 255 { + return false + } + + if s[0] == '.' { + // A cookie a domain attribute may start with a leading dot. + s = s[1:] + } + last := byte('.') + ok := false // Ok once we've seen a letter. + partlen := 0 + for i := 0; i < len(s); i++ { + c := s[i] + switch { + default: + return false + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z': + // No '_' allowed here (in contrast to package net). + ok = true + partlen++ + case '0' <= c && c <= '9': + // fine + partlen++ + case c == '-': + // Byte before dash cannot be dot. + if last == '.' { + return false + } + partlen++ + case c == '.': + // Byte before dot cannot be dot, dash. + if last == '.' || last == '-' { + return false + } + if partlen > 63 || partlen == 0 { + return false + } + partlen = 0 + } + last = c + } + if last == '-' || partlen > 63 { + return false + } + + return ok +} + +var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-") + +func sanitizeCookieName(n string) string { + return cookieNameSanitizer.Replace(n) +} + +// sanitizeCookieValue produces a suitable cookie-value from v. +// https://tools.ietf.org/html/rfc6265#section-4.1.1 +// cookie-value = *cookie-octet / ( DQUOTE *cookie-octet DQUOTE ) +// cookie-octet = %x21 / %x23-2B / %x2D-3A / %x3C-5B / %x5D-7E +// ; US-ASCII characters excluding CTLs, +// ; whitespace DQUOTE, comma, semicolon, +// ; and backslash +// We loosen this as spaces and commas are common in cookie values +// but we produce a quoted cookie-value if and only if v contains +// commas or spaces. +// See https://golang.org/issue/7243 for the discussion. +func sanitizeCookieValue(v string) string { + v = sanitizeOrWarn("Cookie.Value", validCookieValueByte, v) + if len(v) == 0 { + return v + } + if strings.IndexByte(v, ' ') >= 0 || strings.IndexByte(v, ',') >= 0 { + return `"` + v + `"` + } + return v +} + +func validCookieValueByte(b byte) bool { + return 0x20 <= b && b < 0x7f && b != '"' && b != ';' && b != '\\' +} + +// path-av = "Path=" path-value +// path-value = +func sanitizeCookiePath(v string) string { + return sanitizeOrWarn("Cookie.Path", validCookiePathByte, v) +} + +func validCookiePathByte(b byte) bool { + return 0x20 <= b && b < 0x7f && b != ';' +} + +func sanitizeOrWarn(fieldName string, valid func(byte) bool, v string) string { + ok := true + for i := 0; i < len(v); i++ { + if valid(v[i]) { + continue + } + log.Printf("net/http: invalid byte %q in %s; dropping invalid bytes", v[i], fieldName) + ok = false + break + } + if ok { + return v + } + buf := make([]byte, 0, len(v)) + for i := 0; i < len(v); i++ { + if b := v[i]; valid(b) { + buf = append(buf, b) + } + } + return string(buf) +} + +func parseCookieValue(raw string, allowDoubleQuote bool) (string, bool) { + // Strip the quotes, if present. + if allowDoubleQuote && len(raw) > 1 && raw[0] == '"' && raw[len(raw)-1] == '"' { + raw = raw[1 : len(raw)-1] + } + for i := 0; i < len(raw); i++ { + if !validCookieValueByte(raw[i]) { + return "", false + } + } + return raw, true +} + +func isCookieNameValid(raw string) bool { + if raw == "" { + return false + } + return strings.IndexFunc(raw, isNotToken) < 0 +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/cookie_test.go b/vendor/github.com/lesismal/llib/std/net/http/cookie_test.go new file mode 100644 index 0000000..959713a --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/cookie_test.go @@ -0,0 +1,619 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "encoding/json" + "fmt" + "log" + "os" + "reflect" + "strings" + "testing" + "time" +) + +var writeSetCookiesTests = []struct { + Cookie *Cookie + Raw string +}{ + { + &Cookie{Name: "cookie-1", Value: "v$1"}, + "cookie-1=v$1", + }, + { + &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600}, + "cookie-2=two; Max-Age=3600", + }, + { + &Cookie{Name: "cookie-3", Value: "three", Domain: ".example.com"}, + "cookie-3=three; Domain=example.com", + }, + { + &Cookie{Name: "cookie-4", Value: "four", Path: "/restricted/"}, + "cookie-4=four; Path=/restricted/", + }, + { + &Cookie{Name: "cookie-5", Value: "five", Domain: "wrong;bad.abc"}, + "cookie-5=five", + }, + { + &Cookie{Name: "cookie-6", Value: "six", Domain: "bad-.abc"}, + "cookie-6=six", + }, + { + &Cookie{Name: "cookie-7", Value: "seven", Domain: "127.0.0.1"}, + "cookie-7=seven; Domain=127.0.0.1", + }, + { + &Cookie{Name: "cookie-8", Value: "eight", Domain: "::1"}, + "cookie-8=eight", + }, + { + &Cookie{Name: "cookie-9", Value: "expiring", Expires: time.Unix(1257894000, 0)}, + "cookie-9=expiring; Expires=Tue, 10 Nov 2009 23:00:00 GMT", + }, + // According to IETF 6265 Section 5.1.1.5, the year cannot be less than 1601 + { + &Cookie{Name: "cookie-10", Value: "expiring-1601", Expires: time.Date(1601, 1, 1, 1, 1, 1, 1, time.UTC)}, + "cookie-10=expiring-1601; Expires=Mon, 01 Jan 1601 01:01:01 GMT", + }, + { + &Cookie{Name: "cookie-11", Value: "invalid-expiry", Expires: time.Date(1600, 1, 1, 1, 1, 1, 1, time.UTC)}, + "cookie-11=invalid-expiry", + }, + { + &Cookie{Name: "cookie-12", Value: "samesite-default", SameSite: SameSiteDefaultMode}, + "cookie-12=samesite-default", + }, + { + &Cookie{Name: "cookie-13", Value: "samesite-lax", SameSite: SameSiteLaxMode}, + "cookie-13=samesite-lax; SameSite=Lax", + }, + { + &Cookie{Name: "cookie-14", Value: "samesite-strict", SameSite: SameSiteStrictMode}, + "cookie-14=samesite-strict; SameSite=Strict", + }, + { + &Cookie{Name: "cookie-15", Value: "samesite-none", SameSite: SameSiteNoneMode}, + "cookie-15=samesite-none; SameSite=None", + }, + // The "special" cookies have values containing commas or spaces which + // are disallowed by RFC 6265 but are common in the wild. + { + &Cookie{Name: "special-1", Value: "a z"}, + `special-1="a z"`, + }, + { + &Cookie{Name: "special-2", Value: " z"}, + `special-2=" z"`, + }, + { + &Cookie{Name: "special-3", Value: "a "}, + `special-3="a "`, + }, + { + &Cookie{Name: "special-4", Value: " "}, + `special-4=" "`, + }, + { + &Cookie{Name: "special-5", Value: "a,z"}, + `special-5="a,z"`, + }, + { + &Cookie{Name: "special-6", Value: ",z"}, + `special-6=",z"`, + }, + { + &Cookie{Name: "special-7", Value: "a,"}, + `special-7="a,"`, + }, + { + &Cookie{Name: "special-8", Value: ","}, + `special-8=","`, + }, + { + &Cookie{Name: "empty-value", Value: ""}, + `empty-value=`, + }, + { + nil, + ``, + }, + { + &Cookie{Name: ""}, + ``, + }, + { + &Cookie{Name: "\t"}, + ``, + }, + { + &Cookie{Name: "\r"}, + ``, + }, + { + &Cookie{Name: "a\nb", Value: "v"}, + ``, + }, + { + &Cookie{Name: "a\nb", Value: "v"}, + ``, + }, + { + &Cookie{Name: "a\rb", Value: "v"}, + ``, + }, +} + +func TestWriteSetCookies(t *testing.T) { + defer log.SetOutput(os.Stderr) + var logbuf bytes.Buffer + log.SetOutput(&logbuf) + + for i, tt := range writeSetCookiesTests { + if g, e := tt.Cookie.String(), tt.Raw; g != e { + t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, e, g) + continue + } + } + + if got, sub := logbuf.String(), "dropping domain attribute"; !strings.Contains(got, sub) { + t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got) + } +} + +type headerOnlyResponseWriter Header + +func (ho headerOnlyResponseWriter) Header() Header { + return Header(ho) +} + +func (ho headerOnlyResponseWriter) Write([]byte) (int, error) { + panic("NOIMPL") +} + +func (ho headerOnlyResponseWriter) WriteHeader(int) { + panic("NOIMPL") +} + +func TestSetCookie(t *testing.T) { + m := make(Header) + SetCookie(headerOnlyResponseWriter(m), &Cookie{Name: "cookie-1", Value: "one", Path: "/restricted/"}) + SetCookie(headerOnlyResponseWriter(m), &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600}) + if l := len(m["Set-Cookie"]); l != 2 { + t.Fatalf("expected %d cookies, got %d", 2, l) + } + if g, e := m["Set-Cookie"][0], "cookie-1=one; Path=/restricted/"; g != e { + t.Errorf("cookie #1: want %q, got %q", e, g) + } + if g, e := m["Set-Cookie"][1], "cookie-2=two; Max-Age=3600"; g != e { + t.Errorf("cookie #2: want %q, got %q", e, g) + } +} + +var addCookieTests = []struct { + Cookies []*Cookie + Raw string +}{ + { + []*Cookie{}, + "", + }, + { + []*Cookie{{Name: "cookie-1", Value: "v$1"}}, + "cookie-1=v$1", + }, + { + []*Cookie{ + {Name: "cookie-1", Value: "v$1"}, + {Name: "cookie-2", Value: "v$2"}, + {Name: "cookie-3", Value: "v$3"}, + }, + "cookie-1=v$1; cookie-2=v$2; cookie-3=v$3", + }, +} + +func TestAddCookie(t *testing.T) { + for i, tt := range addCookieTests { + req, _ := NewRequest("GET", "http://example.com/", nil) + for _, c := range tt.Cookies { + req.AddCookie(c) + } + if g := req.Header.Get("Cookie"); g != tt.Raw { + t.Errorf("Test %d:\nwant: %s\n got: %s\n", i, tt.Raw, g) + continue + } + } +} + +var readSetCookiesTests = []struct { + Header Header + Cookies []*Cookie +}{ + { + Header{"Set-Cookie": {"Cookie-1=v$1"}}, + []*Cookie{{Name: "Cookie-1", Value: "v$1", Raw: "Cookie-1=v$1"}}, + }, + { + Header{"Set-Cookie": {"NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly"}}, + []*Cookie{{ + Name: "NID", + Value: "99=YsDT5i3E-CXax-", + Path: "/", + Domain: ".google.ch", + HttpOnly: true, + Expires: time.Date(2011, 11, 23, 1, 5, 3, 0, time.UTC), + RawExpires: "Wed, 23-Nov-2011 01:05:03 GMT", + Raw: "NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly", + }}, + }, + { + Header{"Set-Cookie": {".ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly"}}, + []*Cookie{{ + Name: ".ASPXAUTH", + Value: "7E3AA", + Path: "/", + Expires: time.Date(2012, 3, 7, 14, 25, 6, 0, time.UTC), + RawExpires: "Wed, 07-Mar-2012 14:25:06 GMT", + HttpOnly: true, + Raw: ".ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly", + }}, + }, + { + Header{"Set-Cookie": {"ASP.NET_SessionId=foo; path=/; HttpOnly"}}, + []*Cookie{{ + Name: "ASP.NET_SessionId", + Value: "foo", + Path: "/", + HttpOnly: true, + Raw: "ASP.NET_SessionId=foo; path=/; HttpOnly", + }}, + }, + { + Header{"Set-Cookie": {"samesitedefault=foo; SameSite"}}, + []*Cookie{{ + Name: "samesitedefault", + Value: "foo", + SameSite: SameSiteDefaultMode, + Raw: "samesitedefault=foo; SameSite", + }}, + }, + { + Header{"Set-Cookie": {"samesiteinvalidisdefault=foo; SameSite=invalid"}}, + []*Cookie{{ + Name: "samesiteinvalidisdefault", + Value: "foo", + SameSite: SameSiteDefaultMode, + Raw: "samesiteinvalidisdefault=foo; SameSite=invalid", + }}, + }, + { + Header{"Set-Cookie": {"samesitelax=foo; SameSite=Lax"}}, + []*Cookie{{ + Name: "samesitelax", + Value: "foo", + SameSite: SameSiteLaxMode, + Raw: "samesitelax=foo; SameSite=Lax", + }}, + }, + { + Header{"Set-Cookie": {"samesitestrict=foo; SameSite=Strict"}}, + []*Cookie{{ + Name: "samesitestrict", + Value: "foo", + SameSite: SameSiteStrictMode, + Raw: "samesitestrict=foo; SameSite=Strict", + }}, + }, + { + Header{"Set-Cookie": {"samesitenone=foo; SameSite=None"}}, + []*Cookie{{ + Name: "samesitenone", + Value: "foo", + SameSite: SameSiteNoneMode, + Raw: "samesitenone=foo; SameSite=None", + }}, + }, + // Make sure we can properly read back the Set-Cookie headers we create + // for values containing spaces or commas: + { + Header{"Set-Cookie": {`special-1=a z`}}, + []*Cookie{{Name: "special-1", Value: "a z", Raw: `special-1=a z`}}, + }, + { + Header{"Set-Cookie": {`special-2=" z"`}}, + []*Cookie{{Name: "special-2", Value: " z", Raw: `special-2=" z"`}}, + }, + { + Header{"Set-Cookie": {`special-3="a "`}}, + []*Cookie{{Name: "special-3", Value: "a ", Raw: `special-3="a "`}}, + }, + { + Header{"Set-Cookie": {`special-4=" "`}}, + []*Cookie{{Name: "special-4", Value: " ", Raw: `special-4=" "`}}, + }, + { + Header{"Set-Cookie": {`special-5=a,z`}}, + []*Cookie{{Name: "special-5", Value: "a,z", Raw: `special-5=a,z`}}, + }, + { + Header{"Set-Cookie": {`special-6=",z"`}}, + []*Cookie{{Name: "special-6", Value: ",z", Raw: `special-6=",z"`}}, + }, + { + Header{"Set-Cookie": {`special-7=a,`}}, + []*Cookie{{Name: "special-7", Value: "a,", Raw: `special-7=a,`}}, + }, + { + Header{"Set-Cookie": {`special-8=","`}}, + []*Cookie{{Name: "special-8", Value: ",", Raw: `special-8=","`}}, + }, + + // TODO(bradfitz): users have reported seeing this in the + // wild, but do browsers handle it? RFC 6265 just says "don't + // do that" (section 3) and then never mentions header folding + // again. + // Header{"Set-Cookie": {"ASP.NET_SessionId=foo; path=/; HttpOnly, .ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly"}}, +} + +func toJSON(v interface{}) string { + b, err := json.Marshal(v) + if err != nil { + return fmt.Sprintf("%#v", v) + } + return string(b) +} + +func TestReadSetCookies(t *testing.T) { + for i, tt := range readSetCookiesTests { + for n := 0; n < 2; n++ { // to verify readSetCookies doesn't mutate its input + c := readSetCookies(tt.Header) + if !reflect.DeepEqual(c, tt.Cookies) { + t.Errorf("#%d readSetCookies: have\n%s\nwant\n%s\n", i, toJSON(c), toJSON(tt.Cookies)) + continue + } + } + } +} + +var readCookiesTests = []struct { + Header Header + Filter string + Cookies []*Cookie +}{ + { + Header{"Cookie": {"Cookie-1=v$1", "c2=v2"}}, + "", + []*Cookie{ + {Name: "Cookie-1", Value: "v$1"}, + {Name: "c2", Value: "v2"}, + }, + }, + { + Header{"Cookie": {"Cookie-1=v$1", "c2=v2"}}, + "c2", + []*Cookie{ + {Name: "c2", Value: "v2"}, + }, + }, + { + Header{"Cookie": {"Cookie-1=v$1; c2=v2"}}, + "", + []*Cookie{ + {Name: "Cookie-1", Value: "v$1"}, + {Name: "c2", Value: "v2"}, + }, + }, + { + Header{"Cookie": {"Cookie-1=v$1; c2=v2"}}, + "c2", + []*Cookie{ + {Name: "c2", Value: "v2"}, + }, + }, + { + Header{"Cookie": {`Cookie-1="v$1"; c2="v2"`}}, + "", + []*Cookie{ + {Name: "Cookie-1", Value: "v$1"}, + {Name: "c2", Value: "v2"}, + }, + }, + { + Header{"Cookie": {`Cookie-1="v$1"; c2=v2;`}}, + "", + []*Cookie{ + {Name: "Cookie-1", Value: "v$1"}, + {Name: "c2", Value: "v2"}, + }, + }, + { + Header{"Cookie": {``}}, + "", + []*Cookie{}, + }, +} + +func TestReadCookies(t *testing.T) { + for i, tt := range readCookiesTests { + for n := 0; n < 2; n++ { // to verify readCookies doesn't mutate its input + c := readCookies(tt.Header, tt.Filter) + if !reflect.DeepEqual(c, tt.Cookies) { + t.Errorf("#%d readCookies:\nhave: %s\nwant: %s\n", i, toJSON(c), toJSON(tt.Cookies)) + continue + } + } + } +} + +func TestSetCookieDoubleQuotes(t *testing.T) { + res := &Response{Header: Header{}} + res.Header.Add("Set-Cookie", `quoted0=none; max-age=30`) + res.Header.Add("Set-Cookie", `quoted1="cookieValue"; max-age=31`) + res.Header.Add("Set-Cookie", `quoted2=cookieAV; max-age="32"`) + res.Header.Add("Set-Cookie", `quoted3="both"; max-age="33"`) + got := res.Cookies() + want := []*Cookie{ + {Name: "quoted0", Value: "none", MaxAge: 30}, + {Name: "quoted1", Value: "cookieValue", MaxAge: 31}, + {Name: "quoted2", Value: "cookieAV"}, + {Name: "quoted3", Value: "both"}, + } + if len(got) != len(want) { + t.Fatalf("got %d cookies, want %d", len(got), len(want)) + } + for i, w := range want { + g := got[i] + if g.Name != w.Name || g.Value != w.Value || g.MaxAge != w.MaxAge { + t.Errorf("cookie #%d:\ngot %v\nwant %v", i, g, w) + } + } +} + +func TestCookieSanitizeValue(t *testing.T) { + defer log.SetOutput(os.Stderr) + var logbuf bytes.Buffer + log.SetOutput(&logbuf) + + tests := []struct { + in, want string + }{ + {"foo", "foo"}, + {"foo;bar", "foobar"}, + {"foo\\bar", "foobar"}, + {"foo\"bar", "foobar"}, + {"\x00\x7e\x7f\x80", "\x7e"}, + {`"withquotes"`, "withquotes"}, + {"a z", `"a z"`}, + {" z", `" z"`}, + {"a ", `"a "`}, + {"a,z", `"a,z"`}, + {",z", `",z"`}, + {"a,", `"a,"`}, + } + for _, tt := range tests { + if got := sanitizeCookieValue(tt.in); got != tt.want { + t.Errorf("sanitizeCookieValue(%q) = %q; want %q", tt.in, got, tt.want) + } + } + + if got, sub := logbuf.String(), "dropping invalid bytes"; !strings.Contains(got, sub) { + t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got) + } +} + +func TestCookieSanitizePath(t *testing.T) { + defer log.SetOutput(os.Stderr) + var logbuf bytes.Buffer + log.SetOutput(&logbuf) + + tests := []struct { + in, want string + }{ + {"/path", "/path"}, + {"/path with space/", "/path with space/"}, + {"/just;no;semicolon\x00orstuff/", "/justnosemicolonorstuff/"}, + } + for _, tt := range tests { + if got := sanitizeCookiePath(tt.in); got != tt.want { + t.Errorf("sanitizeCookiePath(%q) = %q; want %q", tt.in, got, tt.want) + } + } + + if got, sub := logbuf.String(), "dropping invalid bytes"; !strings.Contains(got, sub) { + t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got) + } +} + +func BenchmarkCookieString(b *testing.B) { + const wantCookieString = `cookie-9=i3e01nf61b6t23bvfmplnanol3; Path=/restricted/; Domain=example.com; Expires=Tue, 10 Nov 2009 23:00:00 GMT; Max-Age=3600` + c := &Cookie{ + Name: "cookie-9", + Value: "i3e01nf61b6t23bvfmplnanol3", + Expires: time.Unix(1257894000, 0), + Path: "/restricted/", + Domain: ".example.com", + MaxAge: 3600, + } + var benchmarkCookieString string + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkCookieString = c.String() + } + if have, want := benchmarkCookieString, wantCookieString; have != want { + b.Fatalf("Have: %v Want: %v", have, want) + } +} + +func BenchmarkReadSetCookies(b *testing.B) { + header := Header{ + "Set-Cookie": { + "NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly", + ".ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly", + }, + } + wantCookies := []*Cookie{ + { + Name: "NID", + Value: "99=YsDT5i3E-CXax-", + Path: "/", + Domain: ".google.ch", + HttpOnly: true, + Expires: time.Date(2011, 11, 23, 1, 5, 3, 0, time.UTC), + RawExpires: "Wed, 23-Nov-2011 01:05:03 GMT", + Raw: "NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly", + }, + { + Name: ".ASPXAUTH", + Value: "7E3AA", + Path: "/", + Expires: time.Date(2012, 3, 7, 14, 25, 6, 0, time.UTC), + RawExpires: "Wed, 07-Mar-2012 14:25:06 GMT", + HttpOnly: true, + Raw: ".ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly", + }, + } + var c []*Cookie + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + c = readSetCookies(header) + } + if !reflect.DeepEqual(c, wantCookies) { + b.Fatalf("readSetCookies:\nhave: %s\nwant: %s\n", toJSON(c), toJSON(wantCookies)) + } +} + +func BenchmarkReadCookies(b *testing.B) { + header := Header{ + "Cookie": { + `de=; client_region=0; rpld1=0:hispeed.ch|20:che|21:zh|22:zurich|23:47.36|24:8.53|; rpld0=1:08|; backplane-channel=newspaper.com:1471; devicetype=0; osfam=0; rplmct=2; s_pers=%20s_vmonthnum%3D1472680800496%2526vn%253D1%7C1472680800496%3B%20s_nr%3D1471686767664-New%7C1474278767664%3B%20s_lv%3D1471686767669%7C1566294767669%3B%20s_lv_s%3DFirst%2520Visit%7C1471688567669%3B%20s_monthinvisit%3Dtrue%7C1471688567677%3B%20gvp_p5%3Dsports%253Ablog%253Aearly-lead%2520-%2520184693%2520-%252020160820%2520-%2520u-s%7C1471688567681%3B%20gvp_p51%3Dwp%2520-%2520sports%7C1471688567684%3B; s_sess=%20s_wp_ep%3Dhomepage%3B%20s._ref%3Dhttps%253A%252F%252Fwww.google.ch%252F%3B%20s_cc%3Dtrue%3B%20s_ppvl%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_ppv%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-s-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_dslv%3DFirst%2520Visit%3B%20s_sq%3Dwpninewspapercom%253D%252526pid%25253Dsports%2525253Ablog%2525253Aearly-lead%25252520-%25252520184693%25252520-%2525252020160820%25252520-%25252520u-s%252526pidt%25253D1%252526oid%25253Dhttps%2525253A%2525252F%2525252Fwww.newspaper.com%2525252F%2525253Fnid%2525253Dmenu_nav_homepage%252526ot%25253DA%3B`, + }, + } + wantCookies := []*Cookie{ + {Name: "de", Value: ""}, + {Name: "client_region", Value: "0"}, + {Name: "rpld1", Value: "0:hispeed.ch|20:che|21:zh|22:zurich|23:47.36|24:8.53|"}, + {Name: "rpld0", Value: "1:08|"}, + {Name: "backplane-channel", Value: "newspaper.com:1471"}, + {Name: "devicetype", Value: "0"}, + {Name: "osfam", Value: "0"}, + {Name: "rplmct", Value: "2"}, + {Name: "s_pers", Value: "%20s_vmonthnum%3D1472680800496%2526vn%253D1%7C1472680800496%3B%20s_nr%3D1471686767664-New%7C1474278767664%3B%20s_lv%3D1471686767669%7C1566294767669%3B%20s_lv_s%3DFirst%2520Visit%7C1471688567669%3B%20s_monthinvisit%3Dtrue%7C1471688567677%3B%20gvp_p5%3Dsports%253Ablog%253Aearly-lead%2520-%2520184693%2520-%252020160820%2520-%2520u-s%7C1471688567681%3B%20gvp_p51%3Dwp%2520-%2520sports%7C1471688567684%3B"}, + {Name: "s_sess", Value: "%20s_wp_ep%3Dhomepage%3B%20s._ref%3Dhttps%253A%252F%252Fwww.google.ch%252F%3B%20s_cc%3Dtrue%3B%20s_ppvl%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_ppv%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-s-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_dslv%3DFirst%2520Visit%3B%20s_sq%3Dwpninewspapercom%253D%252526pid%25253Dsports%2525253Ablog%2525253Aearly-lead%25252520-%25252520184693%25252520-%2525252020160820%25252520-%25252520u-s%252526pidt%25253D1%252526oid%25253Dhttps%2525253A%2525252F%2525252Fwww.newspaper.com%2525252F%2525253Fnid%2525253Dmenu_nav_homepage%252526ot%25253DA%3B"}, + } + var c []*Cookie + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + c = readCookies(header, "") + } + if !reflect.DeepEqual(c, wantCookies) { + b.Fatalf("readCookies:\nhave: %s\nwant: %s\n", toJSON(c), toJSON(wantCookies)) + } +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/cookiejar/dummy_publicsuffix_test.go b/vendor/github.com/lesismal/llib/std/net/http/cookiejar/dummy_publicsuffix_test.go new file mode 100644 index 0000000..639d4eb --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/cookiejar/dummy_publicsuffix_test.go @@ -0,0 +1,21 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar_test + +import "github.com/lesismal/llib/std/net/http/cookiejar" + +type dummypsl struct { + List cookiejar.PublicSuffixList +} + +func (dummypsl) PublicSuffix(domain string) string { + return domain +} + +func (dummypsl) String() string { + return "dummy" +} + +var publicsuffix = dummypsl{} diff --git a/vendor/github.com/lesismal/llib/std/net/http/cookiejar/example_test.go b/vendor/github.com/lesismal/llib/std/net/http/cookiejar/example_test.go new file mode 100644 index 0000000..d5b8151 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/cookiejar/example_test.go @@ -0,0 +1,65 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar_test + +import ( + "fmt" + "github.com/lesismal/llib/std/net/http/cookiejar" + "github.com/lesismal/llib/std/net/http/httptest" + "log" + "net/http" + "net/url" +) + +func ExampleNew() { + // Start a server to give us cookies. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if cookie, err := r.Cookie("Flavor"); err != nil { + http.SetCookie(w, &http.Cookie{Name: "Flavor", Value: "Chocolate Chip"}) + } else { + cookie.Value = "Oatmeal Raisin" + http.SetCookie(w, cookie) + } + })) + defer ts.Close() + + u, err := url.Parse(ts.URL) + if err != nil { + log.Fatal(err) + } + + // All users of cookiejar should import "golang.org/x/net/publicsuffix" + jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) + if err != nil { + log.Fatal(err) + } + + client := &http.Client{ + Jar: jar, + } + + if _, err = client.Get(u.String()); err != nil { + log.Fatal(err) + } + + fmt.Println("After 1st request:") + for _, cookie := range jar.Cookies(u) { + fmt.Printf(" %s: %s\n", cookie.Name, cookie.Value) + } + + if _, err = client.Get(u.String()); err != nil { + log.Fatal(err) + } + + fmt.Println("After 2nd request:") + for _, cookie := range jar.Cookies(u) { + fmt.Printf(" %s: %s\n", cookie.Name, cookie.Value) + } + // Output: + // After 1st request: + // Flavor: Chocolate Chip + // After 2nd request: + // Flavor: Oatmeal Raisin +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/cookiejar/jar.go b/vendor/github.com/lesismal/llib/std/net/http/cookiejar/jar.go new file mode 100644 index 0000000..9f19917 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/cookiejar/jar.go @@ -0,0 +1,503 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package cookiejar implements an in-memory RFC 6265-compliant http.CookieJar. +package cookiejar + +import ( + "errors" + "fmt" + "net" + "net/http" + "net/url" + "sort" + "strings" + "sync" + "time" +) + +// PublicSuffixList provides the public suffix of a domain. For example: +// - the public suffix of "example.com" is "com", +// - the public suffix of "foo1.foo2.foo3.co.uk" is "co.uk", and +// - the public suffix of "bar.pvt.k12.ma.us" is "pvt.k12.ma.us". +// +// Implementations of PublicSuffixList must be safe for concurrent use by +// multiple goroutines. +// +// An implementation that always returns "" is valid and may be useful for +// testing but it is not secure: it means that the HTTP server for foo.com can +// set a cookie for bar.com. +// +// A public suffix list implementation is in the package +// golang.org/x/net/publicsuffix. +type PublicSuffixList interface { + // PublicSuffix returns the public suffix of domain. + // + // TODO: specify which of the caller and callee is responsible for IP + // addresses, for leading and trailing dots, for case sensitivity, and + // for IDN/Punycode. + PublicSuffix(domain string) string + + // String returns a description of the source of this public suffix + // list. The description will typically contain something like a time + // stamp or version number. + String() string +} + +// Options are the options for creating a new Jar. +type Options struct { + // PublicSuffixList is the public suffix list that determines whether + // an HTTP server can set a cookie for a domain. + // + // A nil value is valid and may be useful for testing but it is not + // secure: it means that the HTTP server for foo.co.uk can set a cookie + // for bar.co.uk. + PublicSuffixList PublicSuffixList +} + +// Jar implements the http.CookieJar interface from the net/http package. +type Jar struct { + psList PublicSuffixList + + // mu locks the remaining fields. + mu sync.Mutex + + // entries is a set of entries, keyed by their eTLD+1 and subkeyed by + // their name/domain/path. + entries map[string]map[string]entry + + // nextSeqNum is the next sequence number assigned to a new cookie + // created SetCookies. + nextSeqNum uint64 +} + +// New returns a new cookie jar. A nil *Options is equivalent to a zero +// Options. +func New(o *Options) (*Jar, error) { + jar := &Jar{ + entries: make(map[string]map[string]entry), + } + if o != nil { + jar.psList = o.PublicSuffixList + } + return jar, nil +} + +// entry is the internal representation of a cookie. +// +// This struct type is not used outside of this package per se, but the exported +// fields are those of RFC 6265. +type entry struct { + Name string + Value string + Domain string + Path string + SameSite string + Secure bool + HttpOnly bool + Persistent bool + HostOnly bool + Expires time.Time + Creation time.Time + LastAccess time.Time + + // seqNum is a sequence number so that Cookies returns cookies in a + // deterministic order, even for cookies that have equal Path length and + // equal Creation time. This simplifies testing. + seqNum uint64 +} + +// id returns the domain;path;name triple of e as an id. +func (e *entry) id() string { + return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name) +} + +// shouldSend determines whether e's cookie qualifies to be included in a +// request to host/path. It is the caller's responsibility to check if the +// cookie is expired. +func (e *entry) shouldSend(https bool, host, path string) bool { + return e.domainMatch(host) && e.pathMatch(path) && (https || !e.Secure) +} + +// domainMatch implements "domain-match" of RFC 6265 section 5.1.3. +func (e *entry) domainMatch(host string) bool { + if e.Domain == host { + return true + } + return !e.HostOnly && hasDotSuffix(host, e.Domain) +} + +// pathMatch implements "path-match" according to RFC 6265 section 5.1.4. +func (e *entry) pathMatch(requestPath string) bool { + if requestPath == e.Path { + return true + } + if strings.HasPrefix(requestPath, e.Path) { + if e.Path[len(e.Path)-1] == '/' { + return true // The "/any/" matches "/any/path" case. + } else if requestPath[len(e.Path)] == '/' { + return true // The "/any" matches "/any/path" case. + } + } + return false +} + +// hasDotSuffix reports whether s ends in "."+suffix. +func hasDotSuffix(s, suffix string) bool { + return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix +} + +// Cookies implements the Cookies method of the http.CookieJar interface. +// +// It returns an empty slice if the URL's scheme is not HTTP or HTTPS. +func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) { + return j.cookies(u, time.Now()) +} + +// cookies is like Cookies but takes the current time as a parameter. +func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) { + if u.Scheme != "http" && u.Scheme != "https" { + return cookies + } + host, err := canonicalHost(u.Host) + if err != nil { + return cookies + } + key := jarKey(host, j.psList) + + j.mu.Lock() + defer j.mu.Unlock() + + submap := j.entries[key] + if submap == nil { + return cookies + } + + https := u.Scheme == "https" + path := u.Path + if path == "" { + path = "/" + } + + modified := false + var selected []entry + for id, e := range submap { + if e.Persistent && !e.Expires.After(now) { + delete(submap, id) + modified = true + continue + } + if !e.shouldSend(https, host, path) { + continue + } + e.LastAccess = now + submap[id] = e + selected = append(selected, e) + modified = true + } + if modified { + if len(submap) == 0 { + delete(j.entries, key) + } else { + j.entries[key] = submap + } + } + + // sort according to RFC 6265 section 5.4 point 2: by longest + // path and then by earliest creation time. + sort.Slice(selected, func(i, j int) bool { + s := selected + if len(s[i].Path) != len(s[j].Path) { + return len(s[i].Path) > len(s[j].Path) + } + if !s[i].Creation.Equal(s[j].Creation) { + return s[i].Creation.Before(s[j].Creation) + } + return s[i].seqNum < s[j].seqNum + }) + for _, e := range selected { + cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value}) + } + + return cookies +} + +// SetCookies implements the SetCookies method of the http.CookieJar interface. +// +// It does nothing if the URL's scheme is not HTTP or HTTPS. +func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) { + j.setCookies(u, cookies, time.Now()) +} + +// setCookies is like SetCookies but takes the current time as parameter. +func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) { + if len(cookies) == 0 { + return + } + if u.Scheme != "http" && u.Scheme != "https" { + return + } + host, err := canonicalHost(u.Host) + if err != nil { + return + } + key := jarKey(host, j.psList) + defPath := defaultPath(u.Path) + + j.mu.Lock() + defer j.mu.Unlock() + + submap := j.entries[key] + + modified := false + for _, cookie := range cookies { + e, remove, err := j.newEntry(cookie, now, defPath, host) + if err != nil { + continue + } + id := e.id() + if remove { + if submap != nil { + if _, ok := submap[id]; ok { + delete(submap, id) + modified = true + } + } + continue + } + if submap == nil { + submap = make(map[string]entry) + } + + if old, ok := submap[id]; ok { + e.Creation = old.Creation + e.seqNum = old.seqNum + } else { + e.Creation = now + e.seqNum = j.nextSeqNum + j.nextSeqNum++ + } + e.LastAccess = now + submap[id] = e + modified = true + } + + if modified { + if len(submap) == 0 { + delete(j.entries, key) + } else { + j.entries[key] = submap + } + } +} + +// canonicalHost strips port from host if present and returns the canonicalized +// host name. +func canonicalHost(host string) (string, error) { + var err error + host = strings.ToLower(host) + if hasPort(host) { + host, _, err = net.SplitHostPort(host) + if err != nil { + return "", err + } + } + if strings.HasSuffix(host, ".") { + // Strip trailing dot from fully qualified domain names. + host = host[:len(host)-1] + } + return toASCII(host) +} + +// hasPort reports whether host contains a port number. host may be a host +// name, an IPv4 or an IPv6 address. +func hasPort(host string) bool { + colons := strings.Count(host, ":") + if colons == 0 { + return false + } + if colons == 1 { + return true + } + return host[0] == '[' && strings.Contains(host, "]:") +} + +// jarKey returns the key to use for a jar. +func jarKey(host string, psl PublicSuffixList) string { + if isIP(host) { + return host + } + + var i int + if psl == nil { + i = strings.LastIndex(host, ".") + if i <= 0 { + return host + } + } else { + suffix := psl.PublicSuffix(host) + if suffix == host { + return host + } + i = len(host) - len(suffix) + if i <= 0 || host[i-1] != '.' { + // The provided public suffix list psl is broken. + // Storing cookies under host is a safe stopgap. + return host + } + // Only len(suffix) is used to determine the jar key from + // here on, so it is okay if psl.PublicSuffix("www.buggy.psl") + // returns "com" as the jar key is generated from host. + } + prevDot := strings.LastIndex(host[:i-1], ".") + return host[prevDot+1:] +} + +// isIP reports whether host is an IP address. +func isIP(host string) bool { + return net.ParseIP(host) != nil +} + +// defaultPath returns the directory part of an URL's path according to +// RFC 6265 section 5.1.4. +func defaultPath(path string) string { + if len(path) == 0 || path[0] != '/' { + return "/" // Path is empty or malformed. + } + + i := strings.LastIndex(path, "/") // Path starts with "/", so i != -1. + if i == 0 { + return "/" // Path has the form "/abc". + } + return path[:i] // Path is either of form "/abc/xyz" or "/abc/xyz/". +} + +// newEntry creates an entry from a http.Cookie c. now is the current time and +// is compared to c.Expires to determine deletion of c. defPath and host are the +// default-path and the canonical host name of the URL c was received from. +// +// remove records whether the jar should delete this cookie, as it has already +// expired with respect to now. In this case, e may be incomplete, but it will +// be valid to call e.id (which depends on e's Name, Domain and Path). +// +// A malformed c.Domain will result in an error. +func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) { + e.Name = c.Name + + if c.Path == "" || c.Path[0] != '/' { + e.Path = defPath + } else { + e.Path = c.Path + } + + e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain) + if err != nil { + return e, false, err + } + + // MaxAge takes precedence over Expires. + if c.MaxAge < 0 { + return e, true, nil + } else if c.MaxAge > 0 { + e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second) + e.Persistent = true + } else { + if c.Expires.IsZero() { + e.Expires = endOfTime + e.Persistent = false + } else { + if !c.Expires.After(now) { + return e, true, nil + } + e.Expires = c.Expires + e.Persistent = true + } + } + + e.Value = c.Value + e.Secure = c.Secure + e.HttpOnly = c.HttpOnly + + switch c.SameSite { + case http.SameSiteDefaultMode: + e.SameSite = "SameSite" + case http.SameSiteStrictMode: + e.SameSite = "SameSite=Strict" + case http.SameSiteLaxMode: + e.SameSite = "SameSite=Lax" + } + + return e, false, nil +} + +var ( + errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute") + errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute") + errNoHostname = errors.New("cookiejar: no host name available (IP only)") +) + +// endOfTime is the time when session (non-persistent) cookies expire. +// This instant is representable in most date/time formats (not just +// Go's time.Time) and should be far enough in the future. +var endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC) + +// domainAndType determines the cookie's domain and hostOnly attribute. +func (j *Jar) domainAndType(host, domain string) (string, bool, error) { + if domain == "" { + // No domain attribute in the SetCookie header indicates a + // host cookie. + return host, true, nil + } + + if isIP(host) { + // According to RFC 6265 domain-matching includes not being + // an IP address. + // TODO: This might be relaxed as in common browsers. + return "", false, errNoHostname + } + + // From here on: If the cookie is valid, it is a domain cookie (with + // the one exception of a public suffix below). + // See RFC 6265 section 5.2.3. + if domain[0] == '.' { + domain = domain[1:] + } + + if len(domain) == 0 || domain[0] == '.' { + // Received either "Domain=." or "Domain=..some.thing", + // both are illegal. + return "", false, errMalformedDomain + } + domain = strings.ToLower(domain) + + if domain[len(domain)-1] == '.' { + // We received stuff like "Domain=www.example.com.". + // Browsers do handle such stuff (actually differently) but + // RFC 6265 seems to be clear here (e.g. section 4.1.2.3) in + // requiring a reject. 4.1.2.3 is not normative, but + // "Domain Matching" (5.1.3) and "Canonicalized Host Names" + // (5.1.2) are. + return "", false, errMalformedDomain + } + + // See RFC 6265 section 5.3 #5. + if j.psList != nil { + if ps := j.psList.PublicSuffix(domain); ps != "" && !hasDotSuffix(domain, ps) { + if host == domain { + // This is the one exception in which a cookie + // with a domain attribute is a host cookie. + return host, true, nil + } + return "", false, errIllegalDomain + } + } + + // The domain must domain-match host: www.mycompany.com cannot + // set cookies for .ourcompetitors.com. + if host != domain && !hasDotSuffix(host, domain) { + return "", false, errIllegalDomain + } + + return domain, false, nil +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/cookiejar/jar_test.go b/vendor/github.com/lesismal/llib/std/net/http/cookiejar/jar_test.go new file mode 100644 index 0000000..47fb1ab --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/cookiejar/jar_test.go @@ -0,0 +1,1322 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar + +import ( + "fmt" + "net/http" + "net/url" + "sort" + "strings" + "testing" + "time" +) + +// tNow is the synthetic current time used as now during testing. +var tNow = time.Date(2013, 1, 1, 12, 0, 0, 0, time.UTC) + +// testPSL implements PublicSuffixList with just two rules: "co.uk" +// and the default rule "*". +// The implementation has two intentional bugs: +// PublicSuffix("www.buggy.psl") == "xy" +// PublicSuffix("www2.buggy.psl") == "com" +type testPSL struct{} + +func (testPSL) String() string { + return "testPSL" +} +func (testPSL) PublicSuffix(d string) string { + if d == "co.uk" || strings.HasSuffix(d, ".co.uk") { + return "co.uk" + } + if d == "www.buggy.psl" { + return "xy" + } + if d == "www2.buggy.psl" { + return "com" + } + return d[strings.LastIndex(d, ".")+1:] +} + +// newTestJar creates an empty Jar with testPSL as the public suffix list. +func newTestJar() *Jar { + jar, err := New(&Options{PublicSuffixList: testPSL{}}) + if err != nil { + panic(err) + } + return jar +} + +var hasDotSuffixTests = [...]struct { + s, suffix string +}{ + {"", ""}, + {"", "."}, + {"", "x"}, + {".", ""}, + {".", "."}, + {".", ".."}, + {".", "x"}, + {".", "x."}, + {".", ".x"}, + {".", ".x."}, + {"x", ""}, + {"x", "."}, + {"x", ".."}, + {"x", "x"}, + {"x", "x."}, + {"x", ".x"}, + {"x", ".x."}, + {".x", ""}, + {".x", "."}, + {".x", ".."}, + {".x", "x"}, + {".x", "x."}, + {".x", ".x"}, + {".x", ".x."}, + {"x.", ""}, + {"x.", "."}, + {"x.", ".."}, + {"x.", "x"}, + {"x.", "x."}, + {"x.", ".x"}, + {"x.", ".x."}, + {"com", ""}, + {"com", "m"}, + {"com", "om"}, + {"com", "com"}, + {"com", ".com"}, + {"com", "x.com"}, + {"com", "xcom"}, + {"com", "xorg"}, + {"com", "org"}, + {"com", "rg"}, + {"foo.com", ""}, + {"foo.com", "m"}, + {"foo.com", "om"}, + {"foo.com", "com"}, + {"foo.com", ".com"}, + {"foo.com", "o.com"}, + {"foo.com", "oo.com"}, + {"foo.com", "foo.com"}, + {"foo.com", ".foo.com"}, + {"foo.com", "x.foo.com"}, + {"foo.com", "xfoo.com"}, + {"foo.com", "xfoo.org"}, + {"foo.com", "foo.org"}, + {"foo.com", "oo.org"}, + {"foo.com", "o.org"}, + {"foo.com", ".org"}, + {"foo.com", "org"}, + {"foo.com", "rg"}, +} + +func TestHasDotSuffix(t *testing.T) { + for _, tc := range hasDotSuffixTests { + got := hasDotSuffix(tc.s, tc.suffix) + want := strings.HasSuffix(tc.s, "."+tc.suffix) + if got != want { + t.Errorf("s=%q, suffix=%q: got %v, want %v", tc.s, tc.suffix, got, want) + } + } +} + +var canonicalHostTests = map[string]string{ + "www.example.com": "www.example.com", + "WWW.EXAMPLE.COM": "www.example.com", + "wWw.eXAmple.CoM": "www.example.com", + "www.example.com:80": "www.example.com", + "192.168.0.10": "192.168.0.10", + "192.168.0.5:8080": "192.168.0.5", + "2001:4860:0:2001::68": "2001:4860:0:2001::68", + "[2001:4860:0:::68]:8080": "2001:4860:0:::68", + "www.bücher.de": "www.xn--bcher-kva.de", + "www.example.com.": "www.example.com", + // TODO: Fix canonicalHost so that all of the following malformed + // domain names trigger an error. (This list is not exhaustive, e.g. + // malformed internationalized domain names are missing.) + ".": "", + "..": ".", + "...": "..", + ".net": ".net", + ".net.": ".net", + "a..": "a.", + "b.a..": "b.a.", + "weird.stuff...": "weird.stuff..", + "[bad.unmatched.bracket:": "error", +} + +func TestCanonicalHost(t *testing.T) { + for h, want := range canonicalHostTests { + got, err := canonicalHost(h) + if want == "error" { + if err == nil { + t.Errorf("%q: got %q and nil error, want non-nil", h, got) + } + continue + } + if err != nil { + t.Errorf("%q: %v", h, err) + continue + } + if got != want { + t.Errorf("%q: got %q, want %q", h, got, want) + continue + } + } +} + +var hasPortTests = map[string]bool{ + "www.example.com": false, + "www.example.com:80": true, + "127.0.0.1": false, + "127.0.0.1:8080": true, + "2001:4860:0:2001::68": false, + "[2001::0:::68]:80": true, +} + +func TestHasPort(t *testing.T) { + for host, want := range hasPortTests { + if got := hasPort(host); got != want { + t.Errorf("%q: got %t, want %t", host, got, want) + } + } +} + +var jarKeyTests = map[string]string{ + "foo.www.example.com": "example.com", + "www.example.com": "example.com", + "example.com": "example.com", + "com": "com", + "foo.www.bbc.co.uk": "bbc.co.uk", + "www.bbc.co.uk": "bbc.co.uk", + "bbc.co.uk": "bbc.co.uk", + "co.uk": "co.uk", + "uk": "uk", + "192.168.0.5": "192.168.0.5", + "www.buggy.psl": "www.buggy.psl", + "www2.buggy.psl": "buggy.psl", + // The following are actual outputs of canonicalHost for + // malformed inputs to canonicalHost (see above). + "": "", + ".": ".", + "..": ".", + ".net": ".net", + "a.": "a.", + "b.a.": "a.", + "weird.stuff..": ".", +} + +func TestJarKey(t *testing.T) { + for host, want := range jarKeyTests { + if got := jarKey(host, testPSL{}); got != want { + t.Errorf("%q: got %q, want %q", host, got, want) + } + } +} + +var jarKeyNilPSLTests = map[string]string{ + "foo.www.example.com": "example.com", + "www.example.com": "example.com", + "example.com": "example.com", + "com": "com", + "foo.www.bbc.co.uk": "co.uk", + "www.bbc.co.uk": "co.uk", + "bbc.co.uk": "co.uk", + "co.uk": "co.uk", + "uk": "uk", + "192.168.0.5": "192.168.0.5", + // The following are actual outputs of canonicalHost for + // malformed inputs to canonicalHost. + "": "", + ".": ".", + "..": "..", + ".net": ".net", + "a.": "a.", + "b.a.": "a.", + "weird.stuff..": "stuff..", +} + +func TestJarKeyNilPSL(t *testing.T) { + for host, want := range jarKeyNilPSLTests { + if got := jarKey(host, nil); got != want { + t.Errorf("%q: got %q, want %q", host, got, want) + } + } +} + +var isIPTests = map[string]bool{ + "127.0.0.1": true, + "1.2.3.4": true, + "2001:4860:0:2001::68": true, + "example.com": false, + "1.1.1.300": false, + "www.foo.bar.net": false, + "123.foo.bar.net": false, +} + +func TestIsIP(t *testing.T) { + for host, want := range isIPTests { + if got := isIP(host); got != want { + t.Errorf("%q: got %t, want %t", host, got, want) + } + } +} + +var defaultPathTests = map[string]string{ + "/": "/", + "/abc": "/", + "/abc/": "/abc", + "/abc/xyz": "/abc", + "/abc/xyz/": "/abc/xyz", + "/a/b/c.html": "/a/b", + "": "/", + "strange": "/", + "//": "/", + "/a//b": "/a/", + "/a/./b": "/a/.", + "/a/../b": "/a/..", +} + +func TestDefaultPath(t *testing.T) { + for path, want := range defaultPathTests { + if got := defaultPath(path); got != want { + t.Errorf("%q: got %q, want %q", path, got, want) + } + } +} + +var domainAndTypeTests = [...]struct { + host string // host Set-Cookie header was received from + domain string // domain attribute in Set-Cookie header + wantDomain string // expected domain of cookie + wantHostOnly bool // expected host-cookie flag + wantErr error // expected error +}{ + {"www.example.com", "", "www.example.com", true, nil}, + {"127.0.0.1", "", "127.0.0.1", true, nil}, + {"2001:4860:0:2001::68", "", "2001:4860:0:2001::68", true, nil}, + {"www.example.com", "example.com", "example.com", false, nil}, + {"www.example.com", ".example.com", "example.com", false, nil}, + {"www.example.com", "www.example.com", "www.example.com", false, nil}, + {"www.example.com", ".www.example.com", "www.example.com", false, nil}, + {"foo.sso.example.com", "sso.example.com", "sso.example.com", false, nil}, + {"bar.co.uk", "bar.co.uk", "bar.co.uk", false, nil}, + {"foo.bar.co.uk", ".bar.co.uk", "bar.co.uk", false, nil}, + {"127.0.0.1", "127.0.0.1", "", false, errNoHostname}, + {"2001:4860:0:2001::68", "2001:4860:0:2001::68", "2001:4860:0:2001::68", false, errNoHostname}, + {"www.example.com", ".", "", false, errMalformedDomain}, + {"www.example.com", "..", "", false, errMalformedDomain}, + {"www.example.com", "other.com", "", false, errIllegalDomain}, + {"www.example.com", "com", "", false, errIllegalDomain}, + {"www.example.com", ".com", "", false, errIllegalDomain}, + {"foo.bar.co.uk", ".co.uk", "", false, errIllegalDomain}, + {"127.www.0.0.1", "127.0.0.1", "", false, errIllegalDomain}, + {"com", "", "com", true, nil}, + {"com", "com", "com", true, nil}, + {"com", ".com", "com", true, nil}, + {"co.uk", "", "co.uk", true, nil}, + {"co.uk", "co.uk", "co.uk", true, nil}, + {"co.uk", ".co.uk", "co.uk", true, nil}, +} + +func TestDomainAndType(t *testing.T) { + jar := newTestJar() + for _, tc := range domainAndTypeTests { + domain, hostOnly, err := jar.domainAndType(tc.host, tc.domain) + if err != tc.wantErr { + t.Errorf("%q/%q: got %q error, want %q", + tc.host, tc.domain, err, tc.wantErr) + continue + } + if err != nil { + continue + } + if domain != tc.wantDomain || hostOnly != tc.wantHostOnly { + t.Errorf("%q/%q: got %q/%t want %q/%t", + tc.host, tc.domain, domain, hostOnly, + tc.wantDomain, tc.wantHostOnly) + } + } +} + +// expiresIn creates an expires attribute delta seconds from tNow. +func expiresIn(delta int) string { + t := tNow.Add(time.Duration(delta) * time.Second) + return "expires=" + t.Format(time.RFC1123) +} + +// mustParseURL parses s to an URL and panics on error. +func mustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil || u.Scheme == "" || u.Host == "" { + panic(fmt.Sprintf("Unable to parse URL %s.", s)) + } + return u +} + +// jarTest encapsulates the following actions on a jar: +// 1. Perform SetCookies with fromURL and the cookies from setCookies. +// (Done at time tNow + 0 ms.) +// 2. Check that the entries in the jar matches content. +// (Done at time tNow + 1001 ms.) +// 3. For each query in tests: Check that Cookies with toURL yields the +// cookies in want. +// (Query n done at tNow + (n+2)*1001 ms.) +type jarTest struct { + description string // The description of what this test is supposed to test + fromURL string // The full URL of the request from which Set-Cookie headers where received + setCookies []string // All the cookies received from fromURL + content string // The whole (non-expired) content of the jar + queries []query // Queries to test the Jar.Cookies method +} + +// query contains one test of the cookies returned from Jar.Cookies. +type query struct { + toURL string // the URL in the Cookies call + want string // the expected list of cookies (order matters) +} + +// run runs the jarTest. +func (test jarTest) run(t *testing.T, jar *Jar) { + now := tNow + + // Populate jar with cookies. + setCookies := make([]*http.Cookie, len(test.setCookies)) + for i, cs := range test.setCookies { + cookies := (&http.Response{Header: http.Header{"Set-Cookie": {cs}}}).Cookies() + if len(cookies) != 1 { + panic(fmt.Sprintf("Wrong cookie line %q: %#v", cs, cookies)) + } + setCookies[i] = cookies[0] + } + jar.setCookies(mustParseURL(test.fromURL), setCookies, now) + now = now.Add(1001 * time.Millisecond) + + // Serialize non-expired entries in the form "name1=val1 name2=val2". + var cs []string + for _, submap := range jar.entries { + for _, cookie := range submap { + if !cookie.Expires.After(now) { + continue + } + cs = append(cs, cookie.Name+"="+cookie.Value) + } + } + sort.Strings(cs) + got := strings.Join(cs, " ") + + // Make sure jar content matches our expectations. + if got != test.content { + t.Errorf("Test %q Content\ngot %q\nwant %q", + test.description, got, test.content) + } + + // Test different calls to Cookies. + for i, query := range test.queries { + now = now.Add(1001 * time.Millisecond) + var s []string + for _, c := range jar.cookies(mustParseURL(query.toURL), now) { + s = append(s, c.Name+"="+c.Value) + } + if got := strings.Join(s, " "); got != query.want { + t.Errorf("Test %q #%d\ngot %q\nwant %q", test.description, i, got, query.want) + } + } +} + +// basicsTests contains fundamental tests. Each jarTest has to be performed on +// a fresh, empty Jar. +var basicsTests = [...]jarTest{ + { + "Retrieval of a plain host cookie.", + "http://www.host.test/", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", "A=a"}, + {"http://www.host.test/", "A=a"}, + {"http://www.host.test/some/path", "A=a"}, + {"https://www.host.test", "A=a"}, + {"https://www.host.test/", "A=a"}, + {"https://www.host.test/some/path", "A=a"}, + {"ftp://www.host.test", ""}, + {"ftp://www.host.test/", ""}, + {"ftp://www.host.test/some/path", ""}, + {"http://www.other.org", ""}, + {"http://sibling.host.test", ""}, + {"http://deep.www.host.test", ""}, + }, + }, + { + "Secure cookies are not returned to http.", + "http://www.host.test/", + []string{"A=a; secure"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some/path", ""}, + {"https://www.host.test", "A=a"}, + {"https://www.host.test/", "A=a"}, + {"https://www.host.test/some/path", "A=a"}, + }, + }, + { + "Explicit path.", + "http://www.host.test/", + []string{"A=a; path=/some/path"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some", ""}, + {"http://www.host.test/some/", ""}, + {"http://www.host.test/some/path", "A=a"}, + {"http://www.host.test/some/paths", ""}, + {"http://www.host.test/some/path/foo", "A=a"}, + {"http://www.host.test/some/path/foo/", "A=a"}, + }, + }, + { + "Implicit path #1: path is a directory.", + "http://www.host.test/some/path/", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some", ""}, + {"http://www.host.test/some/", ""}, + {"http://www.host.test/some/path", "A=a"}, + {"http://www.host.test/some/paths", ""}, + {"http://www.host.test/some/path/foo", "A=a"}, + {"http://www.host.test/some/path/foo/", "A=a"}, + }, + }, + { + "Implicit path #2: path is not a directory.", + "http://www.host.test/some/path/index.html", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some", ""}, + {"http://www.host.test/some/", ""}, + {"http://www.host.test/some/path", "A=a"}, + {"http://www.host.test/some/paths", ""}, + {"http://www.host.test/some/path/foo", "A=a"}, + {"http://www.host.test/some/path/foo/", "A=a"}, + }, + }, + { + "Implicit path #3: no path in URL at all.", + "http://www.host.test", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", "A=a"}, + {"http://www.host.test/", "A=a"}, + {"http://www.host.test/some/path", "A=a"}, + }, + }, + { + "Cookies are sorted by path length.", + "http://www.host.test/", + []string{ + "A=a; path=/foo/bar", + "B=b; path=/foo/bar/baz/qux", + "C=c; path=/foo/bar/baz", + "D=d; path=/foo"}, + "A=a B=b C=c D=d", + []query{ + {"http://www.host.test/foo/bar/baz/qux", "B=b C=c A=a D=d"}, + {"http://www.host.test/foo/bar/baz/", "C=c A=a D=d"}, + {"http://www.host.test/foo/bar", "A=a D=d"}, + }, + }, + { + "Creation time determines sorting on same length paths.", + "http://www.host.test/", + []string{ + "A=a; path=/foo/bar", + "X=x; path=/foo/bar", + "Y=y; path=/foo/bar/baz/qux", + "B=b; path=/foo/bar/baz/qux", + "C=c; path=/foo/bar/baz", + "W=w; path=/foo/bar/baz", + "Z=z; path=/foo", + "D=d; path=/foo"}, + "A=a B=b C=c D=d W=w X=x Y=y Z=z", + []query{ + {"http://www.host.test/foo/bar/baz/qux", "Y=y B=b C=c W=w A=a X=x Z=z D=d"}, + {"http://www.host.test/foo/bar/baz/", "C=c W=w A=a X=x Z=z D=d"}, + {"http://www.host.test/foo/bar", "A=a X=x Z=z D=d"}, + }, + }, + { + "Sorting of same-name cookies.", + "http://www.host.test/", + []string{ + "A=1; path=/", + "A=2; path=/path", + "A=3; path=/quux", + "A=4; path=/path/foo", + "A=5; domain=.host.test; path=/path", + "A=6; domain=.host.test; path=/quux", + "A=7; domain=.host.test; path=/path/foo", + }, + "A=1 A=2 A=3 A=4 A=5 A=6 A=7", + []query{ + {"http://www.host.test/path", "A=2 A=5 A=1"}, + {"http://www.host.test/path/foo", "A=4 A=7 A=2 A=5 A=1"}, + }, + }, + { + "Disallow domain cookie on public suffix.", + "http://www.bbc.co.uk", + []string{ + "a=1", + "b=2; domain=co.uk", + }, + "a=1", + []query{{"http://www.bbc.co.uk", "a=1"}}, + }, + { + "Host cookie on IP.", + "http://192.168.0.10", + []string{"a=1"}, + "a=1", + []query{{"http://192.168.0.10", "a=1"}}, + }, + { + "Port is ignored #1.", + "http://www.host.test/", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://www.host.test:8080/", "a=1"}, + }, + }, + { + "Port is ignored #2.", + "http://www.host.test:8080/", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://www.host.test:8080/", "a=1"}, + {"http://www.host.test:1234/", "a=1"}, + }, + }, +} + +func TestBasics(t *testing.T) { + for _, test := range basicsTests { + jar := newTestJar() + test.run(t, jar) + } +} + +// updateAndDeleteTests contains jarTests which must be performed on the same +// Jar. +var updateAndDeleteTests = [...]jarTest{ + { + "Set initial cookies.", + "http://www.host.test", + []string{ + "a=1", + "b=2; secure", + "c=3; httponly", + "d=4; secure; httponly"}, + "a=1 b=2 c=3 d=4", + []query{ + {"http://www.host.test", "a=1 c=3"}, + {"https://www.host.test", "a=1 b=2 c=3 d=4"}, + }, + }, + { + "Update value via http.", + "http://www.host.test", + []string{ + "a=w", + "b=x; secure", + "c=y; httponly", + "d=z; secure; httponly"}, + "a=w b=x c=y d=z", + []query{ + {"http://www.host.test", "a=w c=y"}, + {"https://www.host.test", "a=w b=x c=y d=z"}, + }, + }, + { + "Clear Secure flag from a http.", + "http://www.host.test/", + []string{ + "b=xx", + "d=zz; httponly"}, + "a=w b=xx c=y d=zz", + []query{{"http://www.host.test", "a=w b=xx c=y d=zz"}}, + }, + { + "Delete all.", + "http://www.host.test/", + []string{ + "a=1; max-Age=-1", // delete via MaxAge + "b=2; " + expiresIn(-10), // delete via Expires + "c=2; max-age=-1; " + expiresIn(-10), // delete via both + "d=4; max-age=-1; " + expiresIn(10)}, // MaxAge takes precedence + "", + []query{{"http://www.host.test", ""}}, + }, + { + "Refill #1.", + "http://www.host.test", + []string{ + "A=1", + "A=2; path=/foo", + "A=3; domain=.host.test", + "A=4; path=/foo; domain=.host.test"}, + "A=1 A=2 A=3 A=4", + []query{{"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}}, + }, + { + "Refill #2.", + "http://www.google.com", + []string{ + "A=6", + "A=7; path=/foo", + "A=8; domain=.google.com", + "A=9; path=/foo; domain=.google.com"}, + "A=1 A=2 A=3 A=4 A=6 A=7 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}, + {"http://www.google.com/foo", "A=7 A=9 A=6 A=8"}, + }, + }, + { + "Delete A7.", + "http://www.google.com", + []string{"A=; path=/foo; max-age=-1"}, + "A=1 A=2 A=3 A=4 A=6 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}, + {"http://www.google.com/foo", "A=9 A=6 A=8"}, + }, + }, + { + "Delete A4.", + "http://www.host.test", + []string{"A=; path=/foo; domain=host.test; max-age=-1"}, + "A=1 A=2 A=3 A=6 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1 A=3"}, + {"http://www.google.com/foo", "A=9 A=6 A=8"}, + }, + }, + { + "Delete A6.", + "http://www.google.com", + []string{"A=; max-age=-1"}, + "A=1 A=2 A=3 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1 A=3"}, + {"http://www.google.com/foo", "A=9 A=8"}, + }, + }, + { + "Delete A3.", + "http://www.host.test", + []string{"A=; domain=host.test; max-age=-1"}, + "A=1 A=2 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1"}, + {"http://www.google.com/foo", "A=9 A=8"}, + }, + }, + { + "No cross-domain delete.", + "http://www.host.test", + []string{ + "A=; domain=google.com; max-age=-1", + "A=; path=/foo; domain=google.com; max-age=-1"}, + "A=1 A=2 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1"}, + {"http://www.google.com/foo", "A=9 A=8"}, + }, + }, + { + "Delete A8 and A9.", + "http://www.google.com", + []string{ + "A=; domain=google.com; max-age=-1", + "A=; path=/foo; domain=google.com; max-age=-1"}, + "A=1 A=2", + []query{ + {"http://www.host.test/foo", "A=2 A=1"}, + {"http://www.google.com/foo", ""}, + }, + }, +} + +func TestUpdateAndDelete(t *testing.T) { + jar := newTestJar() + for _, test := range updateAndDeleteTests { + test.run(t, jar) + } +} + +func TestExpiration(t *testing.T) { + jar := newTestJar() + jarTest{ + "Expiration.", + "http://www.host.test", + []string{ + "a=1", + "b=2; max-age=3", + "c=3; " + expiresIn(3), + "d=4; max-age=5", + "e=5; " + expiresIn(5), + "f=6; max-age=100", + }, + "a=1 b=2 c=3 d=4 e=5 f=6", // executed at t0 + 1001 ms + []query{ + {"http://www.host.test", "a=1 b=2 c=3 d=4 e=5 f=6"}, // t0 + 2002 ms + {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 3003 ms + {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 4004 ms + {"http://www.host.test", "a=1 f=6"}, // t0 + 5005 ms + {"http://www.host.test", "a=1 f=6"}, // t0 + 6006 ms + }, + }.run(t, jar) +} + +// +// Tests derived from Chromium's cookie_store_unittest.h. +// + +// See http://src.chromium.org/viewvc/chrome/trunk/src/net/cookies/cookie_store_unittest.h?revision=159685&content-type=text/plain +// Some of the original tests are in a bad condition (e.g. +// DomainWithTrailingDotTest) or are not RFC 6265 conforming (e.g. +// TestNonDottedAndTLD #1 and #6) and have not been ported. + +// chromiumBasicsTests contains fundamental tests. Each jarTest has to be +// performed on a fresh, empty Jar. +var chromiumBasicsTests = [...]jarTest{ + { + "DomainWithTrailingDotTest.", + "http://www.google.com/", + []string{ + "a=1; domain=.www.google.com.", + "b=2; domain=.www.google.com.."}, + "", + []query{ + {"http://www.google.com", ""}, + }, + }, + { + "ValidSubdomainTest #1.", + "http://a.b.c.d.com", + []string{ + "a=1; domain=.a.b.c.d.com", + "b=2; domain=.b.c.d.com", + "c=3; domain=.c.d.com", + "d=4; domain=.d.com"}, + "a=1 b=2 c=3 d=4", + []query{ + {"http://a.b.c.d.com", "a=1 b=2 c=3 d=4"}, + {"http://b.c.d.com", "b=2 c=3 d=4"}, + {"http://c.d.com", "c=3 d=4"}, + {"http://d.com", "d=4"}, + }, + }, + { + "ValidSubdomainTest #2.", + "http://a.b.c.d.com", + []string{ + "a=1; domain=.a.b.c.d.com", + "b=2; domain=.b.c.d.com", + "c=3; domain=.c.d.com", + "d=4; domain=.d.com", + "X=bcd; domain=.b.c.d.com", + "X=cd; domain=.c.d.com"}, + "X=bcd X=cd a=1 b=2 c=3 d=4", + []query{ + {"http://b.c.d.com", "b=2 c=3 d=4 X=bcd X=cd"}, + {"http://c.d.com", "c=3 d=4 X=cd"}, + }, + }, + { + "InvalidDomainTest #1.", + "http://foo.bar.com", + []string{ + "a=1; domain=.yo.foo.bar.com", + "b=2; domain=.foo.com", + "c=3; domain=.bar.foo.com", + "d=4; domain=.foo.bar.com.net", + "e=5; domain=ar.com", + "f=6; domain=.", + "g=7; domain=/", + "h=8; domain=http://foo.bar.com", + "i=9; domain=..foo.bar.com", + "j=10; domain=..bar.com", + "k=11; domain=.foo.bar.com?blah", + "l=12; domain=.foo.bar.com/blah", + "m=12; domain=.foo.bar.com:80", + "n=14; domain=.foo.bar.com:", + "o=15; domain=.foo.bar.com#sup", + }, + "", // Jar is empty. + []query{{"http://foo.bar.com", ""}}, + }, + { + "InvalidDomainTest #2.", + "http://foo.com.com", + []string{"a=1; domain=.foo.com.com.com"}, + "", + []query{{"http://foo.bar.com", ""}}, + }, + { + "DomainWithoutLeadingDotTest #1.", + "http://manage.hosted.filefront.com", + []string{"a=1; domain=filefront.com"}, + "a=1", + []query{{"http://www.filefront.com", "a=1"}}, + }, + { + "DomainWithoutLeadingDotTest #2.", + "http://www.google.com", + []string{"a=1; domain=www.google.com"}, + "a=1", + []query{ + {"http://www.google.com", "a=1"}, + {"http://sub.www.google.com", "a=1"}, + {"http://something-else.com", ""}, + }, + }, + { + "CaseInsensitiveDomainTest.", + "http://www.google.com", + []string{ + "a=1; domain=.GOOGLE.COM", + "b=2; domain=.www.gOOgLE.coM"}, + "a=1 b=2", + []query{{"http://www.google.com", "a=1 b=2"}}, + }, + { + "TestIpAddress #1.", + "http://1.2.3.4/foo", + []string{"a=1; path=/"}, + "a=1", + []query{{"http://1.2.3.4/foo", "a=1"}}, + }, + { + "TestIpAddress #2.", + "http://1.2.3.4/foo", + []string{ + "a=1; domain=.1.2.3.4", + "b=2; domain=.3.4"}, + "", + []query{{"http://1.2.3.4/foo", ""}}, + }, + { + "TestIpAddress #3.", + "http://1.2.3.4/foo", + []string{"a=1; domain=1.2.3.4"}, + "", + []query{{"http://1.2.3.4/foo", ""}}, + }, + { + "TestNonDottedAndTLD #2.", + "http://com./index.html", + []string{"a=1"}, + "a=1", + []query{ + {"http://com./index.html", "a=1"}, + {"http://no-cookies.com./index.html", ""}, + }, + }, + { + "TestNonDottedAndTLD #3.", + "http://a.b", + []string{ + "a=1; domain=.b", + "b=2; domain=b"}, + "", + []query{{"http://bar.foo", ""}}, + }, + { + "TestNonDottedAndTLD #4.", + "http://google.com", + []string{ + "a=1; domain=.com", + "b=2; domain=com"}, + "", + []query{{"http://google.com", ""}}, + }, + { + "TestNonDottedAndTLD #5.", + "http://google.co.uk", + []string{ + "a=1; domain=.co.uk", + "b=2; domain=.uk"}, + "", + []query{ + {"http://google.co.uk", ""}, + {"http://else.co.com", ""}, + {"http://else.uk", ""}, + }, + }, + { + "TestHostEndsWithDot.", + "http://www.google.com", + []string{ + "a=1", + "b=2; domain=.www.google.com."}, + "a=1", + []query{{"http://www.google.com", "a=1"}}, + }, + { + "PathTest", + "http://www.google.izzle", + []string{"a=1; path=/wee"}, + "a=1", + []query{ + {"http://www.google.izzle/wee", "a=1"}, + {"http://www.google.izzle/wee/", "a=1"}, + {"http://www.google.izzle/wee/war", "a=1"}, + {"http://www.google.izzle/wee/war/more/more", "a=1"}, + {"http://www.google.izzle/weehee", ""}, + {"http://www.google.izzle/", ""}, + }, + }, +} + +func TestChromiumBasics(t *testing.T) { + for _, test := range chromiumBasicsTests { + jar := newTestJar() + test.run(t, jar) + } +} + +// chromiumDomainTests contains jarTests which must be executed all on the +// same Jar. +var chromiumDomainTests = [...]jarTest{ + { + "Fill #1.", + "http://www.google.izzle", + []string{"A=B"}, + "A=B", + []query{{"http://www.google.izzle", "A=B"}}, + }, + { + "Fill #2.", + "http://www.google.izzle", + []string{"C=D; domain=.google.izzle"}, + "A=B C=D", + []query{{"http://www.google.izzle", "A=B C=D"}}, + }, + { + "Verify A is a host cookie and not accessible from subdomain.", + "http://unused.nil", + []string{}, + "A=B C=D", + []query{{"http://foo.www.google.izzle", "C=D"}}, + }, + { + "Verify domain cookies are found on proper domain.", + "http://www.google.izzle", + []string{"E=F; domain=.www.google.izzle"}, + "A=B C=D E=F", + []query{{"http://www.google.izzle", "A=B C=D E=F"}}, + }, + { + "Leading dots in domain attributes are optional.", + "http://www.google.izzle", + []string{"G=H; domain=www.google.izzle"}, + "A=B C=D E=F G=H", + []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}}, + }, + { + "Verify domain enforcement works #1.", + "http://www.google.izzle", + []string{"K=L; domain=.bar.www.google.izzle"}, + "A=B C=D E=F G=H", + []query{{"http://bar.www.google.izzle", "C=D E=F G=H"}}, + }, + { + "Verify domain enforcement works #2.", + "http://unused.nil", + []string{}, + "A=B C=D E=F G=H", + []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}}, + }, +} + +func TestChromiumDomain(t *testing.T) { + jar := newTestJar() + for _, test := range chromiumDomainTests { + test.run(t, jar) + } + +} + +// chromiumDeletionTests must be performed all on the same Jar. +var chromiumDeletionTests = [...]jarTest{ + { + "Create session cookie a1.", + "http://www.google.com", + []string{"a=1"}, + "a=1", + []query{{"http://www.google.com", "a=1"}}, + }, + { + "Delete sc a1 via MaxAge.", + "http://www.google.com", + []string{"a=1; max-age=-1"}, + "", + []query{{"http://www.google.com", ""}}, + }, + { + "Create session cookie b2.", + "http://www.google.com", + []string{"b=2"}, + "b=2", + []query{{"http://www.google.com", "b=2"}}, + }, + { + "Delete sc b2 via Expires.", + "http://www.google.com", + []string{"b=2; " + expiresIn(-10)}, + "", + []query{{"http://www.google.com", ""}}, + }, + { + "Create persistent cookie c3.", + "http://www.google.com", + []string{"c=3; max-age=3600"}, + "c=3", + []query{{"http://www.google.com", "c=3"}}, + }, + { + "Delete pc c3 via MaxAge.", + "http://www.google.com", + []string{"c=3; max-age=-1"}, + "", + []query{{"http://www.google.com", ""}}, + }, + { + "Create persistent cookie d4.", + "http://www.google.com", + []string{"d=4; max-age=3600"}, + "d=4", + []query{{"http://www.google.com", "d=4"}}, + }, + { + "Delete pc d4 via Expires.", + "http://www.google.com", + []string{"d=4; " + expiresIn(-10)}, + "", + []query{{"http://www.google.com", ""}}, + }, +} + +func TestChromiumDeletion(t *testing.T) { + jar := newTestJar() + for _, test := range chromiumDeletionTests { + test.run(t, jar) + } +} + +// domainHandlingTests tests and documents the rules for domain handling. +// Each test must be performed on an empty new Jar. +var domainHandlingTests = [...]jarTest{ + { + "Host cookie", + "http://www.host.test", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://host.test", ""}, + {"http://bar.host.test", ""}, + {"http://foo.www.host.test", ""}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie #1", + "http://www.host.test", + []string{"a=1; domain=host.test"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://host.test", "a=1"}, + {"http://bar.host.test", "a=1"}, + {"http://foo.www.host.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie #2", + "http://www.host.test", + []string{"a=1; domain=.host.test"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://host.test", "a=1"}, + {"http://bar.host.test", "a=1"}, + {"http://foo.www.host.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Host cookie on IDNA domain #1", + "http://www.bücher.test", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", ""}, + {"http://xn--bcher-kva.test", ""}, + {"http://bar.bücher.test", ""}, + {"http://bar.xn--bcher-kva.test", ""}, + {"http://foo.www.bücher.test", ""}, + {"http://foo.www.xn--bcher-kva.test", ""}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Host cookie on IDNA domain #2", + "http://www.xn--bcher-kva.test", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", ""}, + {"http://xn--bcher-kva.test", ""}, + {"http://bar.bücher.test", ""}, + {"http://bar.xn--bcher-kva.test", ""}, + {"http://foo.www.bücher.test", ""}, + {"http://foo.www.xn--bcher-kva.test", ""}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie on IDNA domain #1", + "http://www.bücher.test", + []string{"a=1; domain=xn--bcher-kva.test"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", "a=1"}, + {"http://xn--bcher-kva.test", "a=1"}, + {"http://bar.bücher.test", "a=1"}, + {"http://bar.xn--bcher-kva.test", "a=1"}, + {"http://foo.www.bücher.test", "a=1"}, + {"http://foo.www.xn--bcher-kva.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie on IDNA domain #2", + "http://www.xn--bcher-kva.test", + []string{"a=1; domain=xn--bcher-kva.test"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", "a=1"}, + {"http://xn--bcher-kva.test", "a=1"}, + {"http://bar.bücher.test", "a=1"}, + {"http://bar.xn--bcher-kva.test", "a=1"}, + {"http://foo.www.bücher.test", "a=1"}, + {"http://foo.www.xn--bcher-kva.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Host cookie on TLD.", + "http://com", + []string{"a=1"}, + "a=1", + []query{ + {"http://com", "a=1"}, + {"http://any.com", ""}, + {"http://any.test", ""}, + }, + }, + { + "Domain cookie on TLD becomes a host cookie.", + "http://com", + []string{"a=1; domain=com"}, + "a=1", + []query{ + {"http://com", "a=1"}, + {"http://any.com", ""}, + {"http://any.test", ""}, + }, + }, + { + "Host cookie on public suffix.", + "http://co.uk", + []string{"a=1"}, + "a=1", + []query{ + {"http://co.uk", "a=1"}, + {"http://uk", ""}, + {"http://some.co.uk", ""}, + {"http://foo.some.co.uk", ""}, + {"http://any.uk", ""}, + }, + }, + { + "Domain cookie on public suffix is ignored.", + "http://some.co.uk", + []string{"a=1; domain=co.uk"}, + "", + []query{ + {"http://co.uk", ""}, + {"http://uk", ""}, + {"http://some.co.uk", ""}, + {"http://foo.some.co.uk", ""}, + {"http://any.uk", ""}, + }, + }, +} + +func TestDomainHandling(t *testing.T) { + for _, test := range domainHandlingTests { + jar := newTestJar() + test.run(t, jar) + } +} + +func TestIssue19384(t *testing.T) { + cookies := []*http.Cookie{{Name: "name", Value: "value"}} + for _, host := range []string{"", ".", "..", "..."} { + jar, _ := New(nil) + u := &url.URL{Scheme: "http", Host: host, Path: "/"} + if got := jar.Cookies(u); len(got) != 0 { + t.Errorf("host %q, got %v", host, got) + } + jar.SetCookies(u, cookies) + if got := jar.Cookies(u); len(got) != 1 || got[0].Value != "value" { + t.Errorf("host %q, got %v", host, got) + } + } +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/cookiejar/punycode.go b/vendor/github.com/lesismal/llib/std/net/http/cookiejar/punycode.go new file mode 100644 index 0000000..a9cc666 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/cookiejar/punycode.go @@ -0,0 +1,159 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar + +// This file implements the Punycode algorithm from RFC 3492. + +import ( + "fmt" + "strings" + "unicode/utf8" +) + +// These parameter values are specified in section 5. +// +// All computation is done with int32s, so that overflow behavior is identical +// regardless of whether int is 32-bit or 64-bit. +const ( + base int32 = 36 + damp int32 = 700 + initialBias int32 = 72 + initialN int32 = 128 + skew int32 = 38 + tmax int32 = 26 + tmin int32 = 1 +) + +// encode encodes a string as specified in section 6.3 and prepends prefix to +// the result. +// +// The "while h < length(input)" line in the specification becomes "for +// remaining != 0" in the Go code, because len(s) in Go is in bytes, not runes. +func encode(prefix, s string) (string, error) { + output := make([]byte, len(prefix), len(prefix)+1+2*len(s)) + copy(output, prefix) + delta, n, bias := int32(0), initialN, initialBias + b, remaining := int32(0), int32(0) + for _, r := range s { + if r < utf8.RuneSelf { + b++ + output = append(output, byte(r)) + } else { + remaining++ + } + } + h := b + if b > 0 { + output = append(output, '-') + } + for remaining != 0 { + m := int32(0x7fffffff) + for _, r := range s { + if m > r && r >= n { + m = r + } + } + delta += (m - n) * (h + 1) + if delta < 0 { + return "", fmt.Errorf("cookiejar: invalid label %q", s) + } + n = m + for _, r := range s { + if r < n { + delta++ + if delta < 0 { + return "", fmt.Errorf("cookiejar: invalid label %q", s) + } + continue + } + if r > n { + continue + } + q := delta + for k := base; ; k += base { + t := k - bias + if t < tmin { + t = tmin + } else if t > tmax { + t = tmax + } + if q < t { + break + } + output = append(output, encodeDigit(t+(q-t)%(base-t))) + q = (q - t) / (base - t) + } + output = append(output, encodeDigit(q)) + bias = adapt(delta, h+1, h == b) + delta = 0 + h++ + remaining-- + } + delta++ + n++ + } + return string(output), nil +} + +func encodeDigit(digit int32) byte { + switch { + case 0 <= digit && digit < 26: + return byte(digit + 'a') + case 26 <= digit && digit < 36: + return byte(digit + ('0' - 26)) + } + panic("cookiejar: internal error in punycode encoding") +} + +// adapt is the bias adaptation function specified in section 6.1. +func adapt(delta, numPoints int32, firstTime bool) int32 { + if firstTime { + delta /= damp + } else { + delta /= 2 + } + delta += delta / numPoints + k := int32(0) + for delta > ((base-tmin)*tmax)/2 { + delta /= base - tmin + k += base + } + return k + (base-tmin+1)*delta/(delta+skew) +} + +// Strictly speaking, the remaining code below deals with IDNA (RFC 5890 and +// friends) and not Punycode (RFC 3492) per se. + +// acePrefix is the ASCII Compatible Encoding prefix. +const acePrefix = "xn--" + +// toASCII converts a domain or domain label to its ASCII form. For example, +// toASCII("bücher.example.com") is "xn--bcher-kva.example.com", and +// toASCII("golang") is "golang". +func toASCII(s string) (string, error) { + if ascii(s) { + return s, nil + } + labels := strings.Split(s, ".") + for i, label := range labels { + if !ascii(label) { + a, err := encode(acePrefix, label) + if err != nil { + return "", err + } + labels[i] = a + } + } + return strings.Join(labels, "."), nil +} + +func ascii(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] >= utf8.RuneSelf { + return false + } + } + return true +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/cookiejar/punycode_test.go b/vendor/github.com/lesismal/llib/std/net/http/cookiejar/punycode_test.go new file mode 100644 index 0000000..0301de1 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/cookiejar/punycode_test.go @@ -0,0 +1,161 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar + +import ( + "testing" +) + +var punycodeTestCases = [...]struct { + s, encoded string +}{ + {"", ""}, + {"-", "--"}, + {"-a", "-a-"}, + {"-a-", "-a--"}, + {"a", "a-"}, + {"a-", "a--"}, + {"a-b", "a-b-"}, + {"books", "books-"}, + {"bücher", "bcher-kva"}, + {"Hello世界", "Hello-ck1hg65u"}, + {"ü", "tda"}, + {"üý", "tdac"}, + + // The test cases below come from RFC 3492 section 7.1 with Errata 3026. + { + // (A) Arabic (Egyptian). + "\u0644\u064A\u0647\u0645\u0627\u0628\u062A\u0643\u0644" + + "\u0645\u0648\u0634\u0639\u0631\u0628\u064A\u061F", + "egbpdaj6bu4bxfgehfvwxn", + }, + { + // (B) Chinese (simplified). + "\u4ED6\u4EEC\u4E3A\u4EC0\u4E48\u4E0D\u8BF4\u4E2D\u6587", + "ihqwcrb4cv8a8dqg056pqjye", + }, + { + // (C) Chinese (traditional). + "\u4ED6\u5011\u7232\u4EC0\u9EBD\u4E0D\u8AAA\u4E2D\u6587", + "ihqwctvzc91f659drss3x8bo0yb", + }, + { + // (D) Czech. + "\u0050\u0072\u006F\u010D\u0070\u0072\u006F\u0073\u0074" + + "\u011B\u006E\u0065\u006D\u006C\u0075\u0076\u00ED\u010D" + + "\u0065\u0073\u006B\u0079", + "Proprostnemluvesky-uyb24dma41a", + }, + { + // (E) Hebrew. + "\u05DC\u05DE\u05D4\u05D4\u05DD\u05E4\u05E9\u05D5\u05D8" + + "\u05DC\u05D0\u05DE\u05D3\u05D1\u05E8\u05D9\u05DD\u05E2" + + "\u05D1\u05E8\u05D9\u05EA", + "4dbcagdahymbxekheh6e0a7fei0b", + }, + { + // (F) Hindi (Devanagari). + "\u092F\u0939\u0932\u094B\u0917\u0939\u093F\u0928\u094D" + + "\u0926\u0940\u0915\u094D\u092F\u094B\u0902\u0928\u0939" + + "\u0940\u0902\u092C\u094B\u0932\u0938\u0915\u0924\u0947" + + "\u0939\u0948\u0902", + "i1baa7eci9glrd9b2ae1bj0hfcgg6iyaf8o0a1dig0cd", + }, + { + // (G) Japanese (kanji and hiragana). + "\u306A\u305C\u307F\u3093\u306A\u65E5\u672C\u8A9E\u3092" + + "\u8A71\u3057\u3066\u304F\u308C\u306A\u3044\u306E\u304B", + "n8jok5ay5dzabd5bym9f0cm5685rrjetr6pdxa", + }, + { + // (H) Korean (Hangul syllables). + "\uC138\uACC4\uC758\uBAA8\uB4E0\uC0AC\uB78C\uB4E4\uC774" + + "\uD55C\uAD6D\uC5B4\uB97C\uC774\uD574\uD55C\uB2E4\uBA74" + + "\uC5BC\uB9C8\uB098\uC88B\uC744\uAE4C", + "989aomsvi5e83db1d2a355cv1e0vak1dwrv93d5xbh15a0dt30a5j" + + "psd879ccm6fea98c", + }, + { + // (I) Russian (Cyrillic). + "\u043F\u043E\u0447\u0435\u043C\u0443\u0436\u0435\u043E" + + "\u043D\u0438\u043D\u0435\u0433\u043E\u0432\u043E\u0440" + + "\u044F\u0442\u043F\u043E\u0440\u0443\u0441\u0441\u043A" + + "\u0438", + "b1abfaaepdrnnbgefbadotcwatmq2g4l", + }, + { + // (J) Spanish. + "\u0050\u006F\u0072\u0071\u0075\u00E9\u006E\u006F\u0070" + + "\u0075\u0065\u0064\u0065\u006E\u0073\u0069\u006D\u0070" + + "\u006C\u0065\u006D\u0065\u006E\u0074\u0065\u0068\u0061" + + "\u0062\u006C\u0061\u0072\u0065\u006E\u0045\u0073\u0070" + + "\u0061\u00F1\u006F\u006C", + "PorqunopuedensimplementehablarenEspaol-fmd56a", + }, + { + // (K) Vietnamese. + "\u0054\u1EA1\u0069\u0073\u0061\u006F\u0068\u1ECD\u006B" + + "\u0068\u00F4\u006E\u0067\u0074\u0068\u1EC3\u0063\u0068" + + "\u1EC9\u006E\u00F3\u0069\u0074\u0069\u1EBF\u006E\u0067" + + "\u0056\u0069\u1EC7\u0074", + "TisaohkhngthchnitingVit-kjcr8268qyxafd2f1b9g", + }, + { + // (L) 3B. + "\u0033\u5E74\u0042\u7D44\u91D1\u516B\u5148\u751F", + "3B-ww4c5e180e575a65lsy2b", + }, + { + // (M) -with-SUPER-MONKEYS. + "\u5B89\u5BA4\u5948\u7F8E\u6075\u002D\u0077\u0069\u0074" + + "\u0068\u002D\u0053\u0055\u0050\u0045\u0052\u002D\u004D" + + "\u004F\u004E\u004B\u0045\u0059\u0053", + "-with-SUPER-MONKEYS-pc58ag80a8qai00g7n9n", + }, + { + // (N) Hello-Another-Way-. + "\u0048\u0065\u006C\u006C\u006F\u002D\u0041\u006E\u006F" + + "\u0074\u0068\u0065\u0072\u002D\u0057\u0061\u0079\u002D" + + "\u305D\u308C\u305E\u308C\u306E\u5834\u6240", + "Hello-Another-Way--fc4qua05auwb3674vfr0b", + }, + { + // (O) 2. + "\u3072\u3068\u3064\u5C4B\u6839\u306E\u4E0B\u0032", + "2-u9tlzr9756bt3uc0v", + }, + { + // (P) MajiKoi5 + "\u004D\u0061\u006A\u0069\u3067\u004B\u006F\u0069\u3059" + + "\u308B\u0035\u79D2\u524D", + "MajiKoi5-783gue6qz075azm5e", + }, + { + // (Q) de + "\u30D1\u30D5\u30A3\u30FC\u0064\u0065\u30EB\u30F3\u30D0", + "de-jg4avhby1noc0d", + }, + { + // (R) + "\u305D\u306E\u30B9\u30D4\u30FC\u30C9\u3067", + "d9juau41awczczp", + }, + { + // (S) -> $1.00 <- + "\u002D\u003E\u0020\u0024\u0031\u002E\u0030\u0030\u0020" + + "\u003C\u002D", + "-> $1.00 <--", + }, +} + +func TestPunycode(t *testing.T) { + for _, tc := range punycodeTestCases { + if got, err := encode("", tc.s); err != nil { + t.Errorf(`encode("", %q): %v`, tc.s, err) + } else if got != tc.encoded { + t.Errorf(`encode("", %q): got %q, want %q`, tc.s, got, tc.encoded) + } + } +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/doc.go b/vendor/github.com/lesismal/llib/std/net/http/doc.go new file mode 100644 index 0000000..ae9b708 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/doc.go @@ -0,0 +1,107 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package http provides HTTP client and server implementations. + +Get, Head, Post, and PostForm make HTTP (or HTTPS) requests: + + resp, err := http.Get("http://example.com/") + ... + resp, err := http.Post("http://example.com/upload", "image/jpeg", &buf) + ... + resp, err := http.PostForm("http://example.com/form", + url.Values{"key": {"Value"}, "id": {"123"}}) + +The client must close the response body when finished with it: + + resp, err := http.Get("http://example.com/") + if err != nil { + // handle error + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + // ... + +For control over HTTP client headers, redirect policy, and other +settings, create a Client: + + client := &http.Client{ + CheckRedirect: redirectPolicyFunc, + } + + resp, err := client.Get("http://example.com") + // ... + + req, err := http.NewRequest("GET", "http://example.com", nil) + // ... + req.Header.Add("If-None-Match", `W/"wyzzy"`) + resp, err := client.Do(req) + // ... + +For control over proxies, TLS configuration, keep-alives, +compression, and other settings, create a Transport: + + tr := &http.Transport{ + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Second, + DisableCompression: true, + } + client := &http.Client{Transport: tr} + resp, err := client.Get("https://example.com") + +Clients and Transports are safe for concurrent use by multiple +goroutines and for efficiency should only be created once and re-used. + +ListenAndServe starts an HTTP server with a given address and handler. +The handler is usually nil, which means to use DefaultServeMux. +Handle and HandleFunc add handlers to DefaultServeMux: + + http.Handle("/foo", fooHandler) + + http.HandleFunc("/bar", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) + }) + + log.Fatal(http.ListenAndServe(":8080", nil)) + +More control over the server's behavior is available by creating a +custom Server: + + s := &http.Server{ + Addr: ":8080", + Handler: myHandler, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + MaxHeaderBytes: 1 << 20, + } + log.Fatal(s.ListenAndServe()) + +Starting with Go 1.6, the http package has transparent support for the +HTTP/2 protocol when using HTTPS. Programs that must disable HTTP/2 +can do so by setting Transport.TLSNextProto (for clients) or +Server.TLSNextProto (for servers) to a non-nil, empty +map. Alternatively, the following GODEBUG environment variables are +currently supported: + + GODEBUG=http2client=0 # disable HTTP/2 client support + GODEBUG=http2server=0 # disable HTTP/2 server support + GODEBUG=http2debug=1 # enable verbose HTTP/2 debug logs + GODEBUG=http2debug=2 # ... even more verbose, with frame dumps + +The GODEBUG variables are not covered by Go's API compatibility +promise. Please report any issues before disabling HTTP/2 +support: https://golang.org/s/http2bug + +The http package's Transport and Server both automatically enable +HTTP/2 support for simple configurations. To enable HTTP/2 for more +complex configurations, to use lower-level HTTP/2 features, or to use +a newer version of Go's http2 package, import "golang.org/x/net/http2" +directly and use its ConfigureTransport and/or ConfigureServer +functions. Manually configuring HTTP/2 via the golang.org/x/net/http2 +package takes precedence over the net/http package's built-in HTTP/2 +support. + +*/ +package http diff --git a/vendor/github.com/lesismal/llib/std/net/http/example_filesystem_test.go b/vendor/github.com/lesismal/llib/std/net/http/example_filesystem_test.go new file mode 100644 index 0000000..0e81458 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/example_filesystem_test.go @@ -0,0 +1,71 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "io/fs" + "log" + "net/http" + "strings" +) + +// containsDotFile reports whether name contains a path element starting with a period. +// The name is assumed to be a delimited by forward slashes, as guaranteed +// by the http.FileSystem interface. +func containsDotFile(name string) bool { + parts := strings.Split(name, "/") + for _, part := range parts { + if strings.HasPrefix(part, ".") { + return true + } + } + return false +} + +// dotFileHidingFile is the http.File use in dotFileHidingFileSystem. +// It is used to wrap the Readdir method of http.File so that we can +// remove files and directories that start with a period from its output. +type dotFileHidingFile struct { + http.File +} + +// Readdir is a wrapper around the Readdir method of the embedded File +// that filters out all files that start with a period in their name. +func (f dotFileHidingFile) Readdir(n int) (fis []fs.FileInfo, err error) { + files, err := f.File.Readdir(n) + for _, file := range files { // Filters out the dot files + if !strings.HasPrefix(file.Name(), ".") { + fis = append(fis, file) + } + } + return +} + +// dotFileHidingFileSystem is an http.FileSystem that hides +// hidden "dot files" from being served. +type dotFileHidingFileSystem struct { + http.FileSystem +} + +// Open is a wrapper around the Open method of the embedded FileSystem +// that serves a 403 permission error when name has a file or directory +// with whose name starts with a period in its path. +func (fsys dotFileHidingFileSystem) Open(name string) (http.File, error) { + if containsDotFile(name) { // If dot file, return 403 response + return nil, fs.ErrPermission + } + + file, err := fsys.FileSystem.Open(name) + if err != nil { + return nil, err + } + return dotFileHidingFile{file}, err +} + +func ExampleFileServer_dotFileHiding() { + fsys := dotFileHidingFileSystem{http.Dir(".")} + http.Handle("/", http.FileServer(fsys)) + log.Fatal(http.ListenAndServe(":8080", nil)) +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/example_handle_test.go b/vendor/github.com/lesismal/llib/std/net/http/example_handle_test.go new file mode 100644 index 0000000..10a62f6 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/example_handle_test.go @@ -0,0 +1,29 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "fmt" + "log" + "net/http" + "sync" +) + +type countHandler struct { + mu sync.Mutex // guards n + n int +} + +func (h *countHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.mu.Lock() + defer h.mu.Unlock() + h.n++ + fmt.Fprintf(w, "count is %d\n", h.n) +} + +func ExampleHandle() { + http.Handle("/count", new(countHandler)) + log.Fatal(http.ListenAndServe(":8080", nil)) +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/example_test.go b/vendor/github.com/lesismal/llib/std/net/http/example_test.go new file mode 100644 index 0000000..c677d52 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/example_test.go @@ -0,0 +1,192 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "context" + "fmt" + "io" + "log" + "net/http" + "os" + "os/signal" +) + +func ExampleHijacker() { + http.HandleFunc("/hijack", func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError) + return + } + conn, bufrw, err := hj.Hijack() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + // Don't forget to close the connection: + defer conn.Close() + bufrw.WriteString("Now we're speaking raw TCP. Say hi: ") + bufrw.Flush() + s, err := bufrw.ReadString('\n') + if err != nil { + log.Printf("error reading string: %v", err) + return + } + fmt.Fprintf(bufrw, "You said: %q\nBye.\n", s) + bufrw.Flush() + }) +} + +func ExampleGet() { + res, err := http.Get("http://www.google.com/robots.txt") + if err != nil { + log.Fatal(err) + } + robots, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + log.Fatal(err) + } + fmt.Printf("%s", robots) +} + +func ExampleFileServer() { + // Simple static webserver: + log.Fatal(http.ListenAndServe(":8080", http.FileServer(http.Dir("/usr/share/doc")))) +} + +func ExampleFileServer_stripPrefix() { + // To serve a directory on disk (/tmp) under an alternate URL + // path (/tmpfiles/), use StripPrefix to modify the request + // URL's path before the FileServer sees it: + http.Handle("/tmpfiles/", http.StripPrefix("/tmpfiles/", http.FileServer(http.Dir("/tmp")))) +} + +func ExampleStripPrefix() { + // To serve a directory on disk (/tmp) under an alternate URL + // path (/tmpfiles/), use StripPrefix to modify the request + // URL's path before the FileServer sees it: + http.Handle("/tmpfiles/", http.StripPrefix("/tmpfiles/", http.FileServer(http.Dir("/tmp")))) +} + +type apiHandler struct{} + +func (apiHandler) ServeHTTP(http.ResponseWriter, *http.Request) {} + +func ExampleServeMux_Handle() { + mux := http.NewServeMux() + mux.Handle("/api/", apiHandler{}) + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + // The "/" pattern matches everything, so we need to check + // that we're at the root here. + if req.URL.Path != "/" { + http.NotFound(w, req) + return + } + fmt.Fprintf(w, "Welcome to the home page!") + }) +} + +// HTTP Trailers are a set of key/value pairs like headers that come +// after the HTTP response, instead of before. +func ExampleResponseWriter_trailers() { + mux := http.NewServeMux() + mux.HandleFunc("/sendstrailers", func(w http.ResponseWriter, req *http.Request) { + // Before any call to WriteHeader or Write, declare + // the trailers you will set during the HTTP + // response. These three headers are actually sent in + // the trailer. + w.Header().Set("Trailer", "AtEnd1, AtEnd2") + w.Header().Add("Trailer", "AtEnd3") + + w.Header().Set("Content-Type", "text/plain; charset=utf-8") // normal header + w.WriteHeader(http.StatusOK) + + w.Header().Set("AtEnd1", "value 1") + io.WriteString(w, "This HTTP response has both headers before this text and trailers at the end.\n") + w.Header().Set("AtEnd2", "value 2") + w.Header().Set("AtEnd3", "value 3") // These will appear as trailers. + }) +} + +func ExampleServer_Shutdown() { + var srv http.Server + + idleConnsClosed := make(chan struct{}) + go func() { + sigint := make(chan os.Signal, 1) + signal.Notify(sigint, os.Interrupt) + <-sigint + + // We received an interrupt signal, shut down. + if err := srv.Shutdown(context.Background()); err != nil { + // Error from closing listeners, or context timeout: + log.Printf("HTTP server Shutdown: %v", err) + } + close(idleConnsClosed) + }() + + if err := srv.ListenAndServe(); err != http.ErrServerClosed { + // Error starting or closing listener: + log.Fatalf("HTTP server ListenAndServe: %v", err) + } + + <-idleConnsClosed +} + +func ExampleListenAndServeTLS() { + http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + io.WriteString(w, "Hello, TLS!\n") + }) + + // One can use generate_cert.go in crypto/tls to generate cert.pem and key.pem. + log.Printf("About to listen on 8443. Go to https://127.0.0.1:8443/") + err := http.ListenAndServeTLS(":8443", "cert.pem", "key.pem", nil) + log.Fatal(err) +} + +func ExampleListenAndServe() { + // Hello world, the web server + + helloHandler := func(w http.ResponseWriter, req *http.Request) { + io.WriteString(w, "Hello, world!\n") + } + + http.HandleFunc("/hello", helloHandler) + log.Fatal(http.ListenAndServe(":8080", nil)) +} + +func ExampleHandleFunc() { + h1 := func(w http.ResponseWriter, _ *http.Request) { + io.WriteString(w, "Hello from a HandleFunc #1!\n") + } + h2 := func(w http.ResponseWriter, _ *http.Request) { + io.WriteString(w, "Hello from a HandleFunc #2!\n") + } + + http.HandleFunc("/", h1) + http.HandleFunc("/endpoint", h2) + + log.Fatal(http.ListenAndServe(":8080", nil)) +} + +func newPeopleHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "This is the people handler.") + }) +} + +func ExampleNotFoundHandler() { + mux := http.NewServeMux() + + // Create sample handler to returns 404 + mux.Handle("/resources", http.NotFoundHandler()) + + // Create sample handler that returns 200 + mux.Handle("/resources/people/", newPeopleHandler()) + + log.Fatal(http.ListenAndServe(":8080", mux)) +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/export_test.go b/vendor/github.com/lesismal/llib/std/net/http/export_test.go new file mode 100644 index 0000000..096a6d3 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/export_test.go @@ -0,0 +1,313 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Bridge package to expose http internals to tests in the http_test +// package. + +package http + +import ( + "context" + "fmt" + "net" + "net/url" + "sort" + "sync" + "testing" + "time" +) + +var ( + DefaultUserAgent = defaultUserAgent + NewLoggingConn = newLoggingConn + ExportAppendTime = appendTime + ExportRefererForURL = refererForURL + ExportServerNewConn = (*Server).newConn + ExportCloseWriteAndWait = (*conn).closeWriteAndWait + ExportErrRequestCanceled = errRequestCanceled + ExportErrRequestCanceledConn = errRequestCanceledConn + ExportErrServerClosedIdle = errServerClosedIdle + ExportServeFile = serveFile + ExportScanETag = scanETag + ExportHttp2ConfigureServer = http2ConfigureServer + Export_shouldCopyHeaderOnRedirect = shouldCopyHeaderOnRedirect + Export_writeStatusLine = writeStatusLine + Export_is408Message = is408Message +) + +const MaxWriteWaitBeforeConnReuse = maxWriteWaitBeforeConnReuse + +func init() { + // We only want to pay for this cost during testing. + // When not under test, these values are always nil + // and never assigned to. + testHookMu = new(sync.Mutex) + + testHookClientDoResult = func(res *Response, err error) { + if err != nil { + if _, ok := err.(*url.Error); !ok { + panic(fmt.Sprintf("unexpected Client.Do error of type %T; want *url.Error", err)) + } + } else { + if res == nil { + panic("Client.Do returned nil, nil") + } + if res.Body == nil { + panic("Client.Do returned nil res.Body and no error") + } + } + } +} + +func CondSkipHTTP2(t *testing.T) { + if omitBundledHTTP2 { + t.Skip("skipping HTTP/2 test when nethttpomithttp2 build tag in use") + } +} + +var ( + SetEnterRoundTripHook = hookSetter(&testHookEnterRoundTrip) + SetRoundTripRetried = hookSetter(&testHookRoundTripRetried) +) + +func SetReadLoopBeforeNextReadHook(f func()) { + testHookMu.Lock() + defer testHookMu.Unlock() + unnilTestHook(&f) + testHookReadLoopBeforeNextRead = f +} + +// SetPendingDialHooks sets the hooks that run before and after handling +// pending dials. +func SetPendingDialHooks(before, after func()) { + unnilTestHook(&before) + unnilTestHook(&after) + testHookPrePendingDial, testHookPostPendingDial = before, after +} + +func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn } + +func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + <-ch + cancel() + }() + return &timeoutHandler{ + handler: handler, + testContext: ctx, + // (no body) + } +} + +func ResetCachedEnvironment() { + resetProxyConfig() +} + +func (t *Transport) NumPendingRequestsForTesting() int { + t.reqMu.Lock() + defer t.reqMu.Unlock() + return len(t.reqCanceler) +} + +func (t *Transport) IdleConnKeysForTesting() (keys []string) { + keys = make([]string, 0) + t.idleMu.Lock() + defer t.idleMu.Unlock() + for key := range t.idleConn { + keys = append(keys, key.String()) + } + sort.Strings(keys) + return +} + +func (t *Transport) IdleConnKeyCountForTesting() int { + t.idleMu.Lock() + defer t.idleMu.Unlock() + return len(t.idleConn) +} + +func (t *Transport) IdleConnStrsForTesting() []string { + var ret []string + t.idleMu.Lock() + defer t.idleMu.Unlock() + for _, conns := range t.idleConn { + for _, pc := range conns { + ret = append(ret, pc.conn.LocalAddr().String()+"/"+pc.conn.RemoteAddr().String()) + } + } + sort.Strings(ret) + return ret +} + +func (t *Transport) IdleConnStrsForTesting_h2() []string { + var ret []string + noDialPool := t.h2transport.(*http2Transport).ConnPool.(http2noDialClientConnPool) + pool := noDialPool.http2clientConnPool + + pool.mu.Lock() + defer pool.mu.Unlock() + + for k, cc := range pool.conns { + for range cc { + ret = append(ret, k) + } + } + + sort.Strings(ret) + return ret +} + +func (t *Transport) IdleConnCountForTesting(scheme, addr string) int { + t.idleMu.Lock() + defer t.idleMu.Unlock() + key := connectMethodKey{"", scheme, addr, false} + cacheKey := key.String() + for k, conns := range t.idleConn { + if k.String() == cacheKey { + return len(conns) + } + } + return 0 +} + +func (t *Transport) IdleConnWaitMapSizeForTesting() int { + t.idleMu.Lock() + defer t.idleMu.Unlock() + return len(t.idleConnWait) +} + +func (t *Transport) IsIdleForTesting() bool { + t.idleMu.Lock() + defer t.idleMu.Unlock() + return t.closeIdle +} + +func (t *Transport) QueueForIdleConnForTesting() { + t.queueForIdleConn(nil) +} + +// PutIdleTestConn reports whether it was able to insert a fresh +// persistConn for scheme, addr into the idle connection pool. +func (t *Transport) PutIdleTestConn(scheme, addr string) bool { + c, _ := net.Pipe() + key := connectMethodKey{"", scheme, addr, false} + + if t.MaxConnsPerHost > 0 { + // Transport is tracking conns-per-host. + // Increment connection count to account + // for new persistConn created below. + t.connsPerHostMu.Lock() + if t.connsPerHost == nil { + t.connsPerHost = make(map[connectMethodKey]int) + } + t.connsPerHost[key]++ + t.connsPerHostMu.Unlock() + } + + return t.tryPutIdleConn(&persistConn{ + t: t, + conn: c, // dummy + closech: make(chan struct{}), // so it can be closed + cacheKey: key, + }) == nil +} + +// PutIdleTestConnH2 reports whether it was able to insert a fresh +// HTTP/2 persistConn for scheme, addr into the idle connection pool. +func (t *Transport) PutIdleTestConnH2(scheme, addr string, alt RoundTripper) bool { + key := connectMethodKey{"", scheme, addr, false} + + if t.MaxConnsPerHost > 0 { + // Transport is tracking conns-per-host. + // Increment connection count to account + // for new persistConn created below. + t.connsPerHostMu.Lock() + if t.connsPerHost == nil { + t.connsPerHost = make(map[connectMethodKey]int) + } + t.connsPerHost[key]++ + t.connsPerHostMu.Unlock() + } + + return t.tryPutIdleConn(&persistConn{ + t: t, + alt: alt, + cacheKey: key, + }) == nil +} + +// All test hooks must be non-nil so they can be called directly, +// but the tests use nil to mean hook disabled. +func unnilTestHook(f *func()) { + if *f == nil { + *f = nop + } +} + +func hookSetter(dst *func()) func(func()) { + return func(fn func()) { + unnilTestHook(&fn) + *dst = fn + } +} + +func ExportHttp2ConfigureTransport(t *Transport) error { + t2, err := http2configureTransports(t) + if err != nil { + return err + } + t.h2transport = t2 + return nil +} + +func (s *Server) ExportAllConnsIdle() bool { + s.mu.Lock() + defer s.mu.Unlock() + for c := range s.activeConn { + st, unixSec := c.getState() + if unixSec == 0 || st != StateIdle { + return false + } + } + return true +} + +func (s *Server) ExportAllConnsByState() map[ConnState]int { + states := map[ConnState]int{} + s.mu.Lock() + defer s.mu.Unlock() + for c := range s.activeConn { + st, _ := c.getState() + states[st] += 1 + } + return states +} + +func (r *Request) WithT(t *testing.T) *Request { + return r.WithContext(context.WithValue(r.Context(), tLogKey{}, t.Logf)) +} + +func ExportSetH2GoawayTimeout(d time.Duration) (restore func()) { + old := http2goAwayTimeout + http2goAwayTimeout = d + return func() { http2goAwayTimeout = old } +} + +func (r *Request) ExportIsReplayable() bool { return r.isReplayable() } + +// ExportCloseTransportConnsAbruptly closes all idle connections from +// tr in an abrupt way, just reaching into the underlying Conns and +// closing them, without telling the Transport or its persistConns +// that it's doing so. This is to simulate the server closing connections +// on the Transport. +func ExportCloseTransportConnsAbruptly(tr *Transport) { + tr.idleMu.Lock() + for _, pcs := range tr.idleConn { + for _, pc := range pcs { + pc.conn.Close() + } + } + tr.idleMu.Unlock() +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/fcgi/child.go b/vendor/github.com/lesismal/llib/std/net/http/fcgi/child.go new file mode 100644 index 0000000..5bcf5ad --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/fcgi/child.go @@ -0,0 +1,405 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fcgi + +// This file implements FastCGI from the perspective of a child process. + +import ( + "context" + "errors" + "fmt" + "github.com/lesismal/llib/std/net/http/cgi" + "io" + "net" + "net/http" + "os" + "strings" + "sync" + "time" +) + +// request holds the state for an in-progress request. As soon as it's complete, +// it's converted to an http.Request. +type request struct { + pw *io.PipeWriter + reqId uint16 + params map[string]string + buf [1024]byte + rawParams []byte + keepConn bool +} + +// envVarsContextKey uniquely identifies a mapping of CGI +// environment variables to their values in a request context +type envVarsContextKey struct{} + +func newRequest(reqId uint16, flags uint8) *request { + r := &request{ + reqId: reqId, + params: map[string]string{}, + keepConn: flags&flagKeepConn != 0, + } + r.rawParams = r.buf[:0] + return r +} + +// parseParams reads an encoded []byte into Params. +func (r *request) parseParams() { + text := r.rawParams + r.rawParams = nil + for len(text) > 0 { + keyLen, n := readSize(text) + if n == 0 { + return + } + text = text[n:] + valLen, n := readSize(text) + if n == 0 { + return + } + text = text[n:] + if int(keyLen)+int(valLen) > len(text) { + return + } + key := readString(text, keyLen) + text = text[keyLen:] + val := readString(text, valLen) + text = text[valLen:] + r.params[key] = val + } +} + +// response implements http.ResponseWriter. +type response struct { + req *request + header http.Header + code int + wroteHeader bool + wroteCGIHeader bool + w *bufWriter +} + +func newResponse(c *child, req *request) *response { + return &response{ + req: req, + header: http.Header{}, + w: newWriter(c.conn, typeStdout, req.reqId), + } +} + +func (r *response) Header() http.Header { + return r.header +} + +func (r *response) Write(p []byte) (n int, err error) { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + if !r.wroteCGIHeader { + r.writeCGIHeader(p) + } + return r.w.Write(p) +} + +func (r *response) WriteHeader(code int) { + if r.wroteHeader { + return + } + r.wroteHeader = true + r.code = code + if code == http.StatusNotModified { + // Must not have body. + r.header.Del("Content-Type") + r.header.Del("Content-Length") + r.header.Del("Transfer-Encoding") + } + if r.header.Get("Date") == "" { + r.header.Set("Date", time.Now().UTC().Format(http.TimeFormat)) + } +} + +// writeCGIHeader finalizes the header sent to the client and writes it to the output. +// p is not written by writeHeader, but is the first chunk of the body +// that will be written. It is sniffed for a Content-Type if none is +// set explicitly. +func (r *response) writeCGIHeader(p []byte) { + if r.wroteCGIHeader { + return + } + r.wroteCGIHeader = true + fmt.Fprintf(r.w, "Status: %d %s\r\n", r.code, http.StatusText(r.code)) + if _, hasType := r.header["Content-Type"]; r.code != http.StatusNotModified && !hasType { + r.header.Set("Content-Type", http.DetectContentType(p)) + } + r.header.Write(r.w) + r.w.WriteString("\r\n") + r.w.Flush() +} + +func (r *response) Flush() { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + r.w.Flush() +} + +func (r *response) Close() error { + r.Flush() + return r.w.Close() +} + +type child struct { + conn *conn + handler http.Handler + + mu sync.Mutex // protects requests: + requests map[uint16]*request // keyed by request ID +} + +func newChild(rwc io.ReadWriteCloser, handler http.Handler) *child { + return &child{ + conn: newConn(rwc), + handler: handler, + requests: make(map[uint16]*request), + } +} + +func (c *child) serve() { + defer c.conn.Close() + defer c.cleanUp() + var rec record + for { + if err := rec.read(c.conn.rwc); err != nil { + return + } + if err := c.handleRecord(&rec); err != nil { + return + } + } +} + +var errCloseConn = errors.New("fcgi: connection should be closed") + +var emptyBody = io.NopCloser(strings.NewReader("")) + +// ErrRequestAborted is returned by Read when a handler attempts to read the +// body of a request that has been aborted by the web server. +var ErrRequestAborted = errors.New("fcgi: request aborted by web server") + +// ErrConnClosed is returned by Read when a handler attempts to read the body of +// a request after the connection to the web server has been closed. +var ErrConnClosed = errors.New("fcgi: connection to web server closed") + +func (c *child) handleRecord(rec *record) error { + c.mu.Lock() + req, ok := c.requests[rec.h.Id] + c.mu.Unlock() + if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues { + // The spec says to ignore unknown request IDs. + return nil + } + + switch rec.h.Type { + case typeBeginRequest: + if req != nil { + // The server is trying to begin a request with the same ID + // as an in-progress request. This is an error. + return errors.New("fcgi: received ID that is already in-flight") + } + + var br beginRequest + if err := br.read(rec.content()); err != nil { + return err + } + if br.role != roleResponder { + c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole) + return nil + } + req = newRequest(rec.h.Id, br.flags) + c.mu.Lock() + c.requests[rec.h.Id] = req + c.mu.Unlock() + return nil + case typeParams: + // NOTE(eds): Technically a key-value pair can straddle the boundary + // between two packets. We buffer until we've received all parameters. + if len(rec.content()) > 0 { + req.rawParams = append(req.rawParams, rec.content()...) + return nil + } + req.parseParams() + return nil + case typeStdin: + content := rec.content() + if req.pw == nil { + var body io.ReadCloser + if len(content) > 0 { + // body could be an io.LimitReader, but it shouldn't matter + // as long as both sides are behaving. + body, req.pw = io.Pipe() + } else { + body = emptyBody + } + go c.serveRequest(req, body) + } + if len(content) > 0 { + // TODO(eds): This blocks until the handler reads from the pipe. + // If the handler takes a long time, it might be a problem. + req.pw.Write(content) + } else if req.pw != nil { + req.pw.Close() + } + return nil + case typeGetValues: + values := map[string]string{"FCGI_MPXS_CONNS": "1"} + c.conn.writePairs(typeGetValuesResult, 0, values) + return nil + case typeData: + // If the filter role is implemented, read the data stream here. + return nil + case typeAbortRequest: + c.mu.Lock() + delete(c.requests, rec.h.Id) + c.mu.Unlock() + c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete) + if req.pw != nil { + req.pw.CloseWithError(ErrRequestAborted) + } + if !req.keepConn { + // connection will close upon return + return errCloseConn + } + return nil + default: + b := make([]byte, 8) + b[0] = byte(rec.h.Type) + c.conn.writeRecord(typeUnknownType, 0, b) + return nil + } +} + +// filterOutUsedEnvVars returns a new map of env vars without the +// variables in the given envVars map that are read for creating each http.Request +func filterOutUsedEnvVars(envVars map[string]string) map[string]string { + withoutUsedEnvVars := make(map[string]string) + for k, v := range envVars { + if addFastCGIEnvToContext(k) { + withoutUsedEnvVars[k] = v + } + } + return withoutUsedEnvVars +} + +func (c *child) serveRequest(req *request, body io.ReadCloser) { + r := newResponse(c, req) + httpReq, err := cgi.RequestFromMap(req.params) + if err != nil { + // there was an error reading the request + r.WriteHeader(http.StatusInternalServerError) + c.conn.writeRecord(typeStderr, req.reqId, []byte(err.Error())) + } else { + httpReq.Body = body + withoutUsedEnvVars := filterOutUsedEnvVars(req.params) + envVarCtx := context.WithValue(httpReq.Context(), envVarsContextKey{}, withoutUsedEnvVars) + httpReq = httpReq.WithContext(envVarCtx) + c.handler.ServeHTTP(r, httpReq) + } + // Make sure we serve something even if nothing was written to r + r.Write(nil) + r.Close() + c.mu.Lock() + delete(c.requests, req.reqId) + c.mu.Unlock() + c.conn.writeEndRequest(req.reqId, 0, statusRequestComplete) + + // Consume the entire body, so the host isn't still writing to + // us when we close the socket below in the !keepConn case, + // otherwise we'd send a RST. (golang.org/issue/4183) + // TODO(bradfitz): also bound this copy in time. Or send + // some sort of abort request to the host, so the host + // can properly cut off the client sending all the data. + // For now just bound it a little and + io.CopyN(io.Discard, body, 100<<20) + body.Close() + + if !req.keepConn { + c.conn.Close() + } +} + +func (c *child) cleanUp() { + c.mu.Lock() + defer c.mu.Unlock() + for _, req := range c.requests { + if req.pw != nil { + // race with call to Close in c.serveRequest doesn't matter because + // Pipe(Reader|Writer).Close are idempotent + req.pw.CloseWithError(ErrConnClosed) + } + } +} + +// Serve accepts incoming FastCGI connections on the listener l, creating a new +// goroutine for each. The goroutine reads requests and then calls handler +// to reply to them. +// If l is nil, Serve accepts connections from os.Stdin. +// If handler is nil, http.DefaultServeMux is used. +func Serve(l net.Listener, handler http.Handler) error { + if l == nil { + var err error + l, err = net.FileListener(os.Stdin) + if err != nil { + return err + } + defer l.Close() + } + if handler == nil { + handler = http.DefaultServeMux + } + for { + rw, err := l.Accept() + if err != nil { + return err + } + c := newChild(rw, handler) + go c.serve() + } +} + +// ProcessEnv returns FastCGI environment variables associated with the request r +// for which no effort was made to be included in the request itself - the data +// is hidden in the request's context. As an example, if REMOTE_USER is set for a +// request, it will not be found anywhere in r, but it will be included in +// ProcessEnv's response (via r's context). +func ProcessEnv(r *http.Request) map[string]string { + env, _ := r.Context().Value(envVarsContextKey{}).(map[string]string) + return env +} + +// addFastCGIEnvToContext reports whether to include the FastCGI environment variable s +// in the http.Request.Context, accessible via ProcessEnv. +func addFastCGIEnvToContext(s string) bool { + // Exclude things supported by net/http natively: + switch s { + case "CONTENT_LENGTH", "CONTENT_TYPE", "HTTPS", + "PATH_INFO", "QUERY_STRING", "REMOTE_ADDR", + "REMOTE_HOST", "REMOTE_PORT", "REQUEST_METHOD", + "REQUEST_URI", "SCRIPT_NAME", "SERVER_PROTOCOL": + return false + } + if strings.HasPrefix(s, "HTTP_") { + return false + } + // Explicitly include FastCGI-specific things. + // This list is redundant with the default "return true" below. + // Consider this documentation of the sorts of things we expect + // to maybe see. + switch s { + case "REMOTE_USER": + return true + } + // Unknown, so include it to be safe. + return true +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/fcgi/fcgi.go b/vendor/github.com/lesismal/llib/std/net/http/fcgi/fcgi.go new file mode 100644 index 0000000..fb822f8 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/fcgi/fcgi.go @@ -0,0 +1,270 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package fcgi implements the FastCGI protocol. +// +// See https://fast-cgi.github.io/ for an unofficial mirror of the +// original documentation. +// +// Currently only the responder role is supported. +package fcgi + +// This file defines the raw protocol and some utilities used by the child and +// the host. + +import ( + "bufio" + "bytes" + "encoding/binary" + "errors" + "io" + "sync" +) + +// recType is a record type, as defined by +// https://web.archive.org/web/20150420080736/http://www.fastcgi.com/drupal/node/6?q=node/22#S8 +type recType uint8 + +const ( + typeBeginRequest recType = 1 + typeAbortRequest recType = 2 + typeEndRequest recType = 3 + typeParams recType = 4 + typeStdin recType = 5 + typeStdout recType = 6 + typeStderr recType = 7 + typeData recType = 8 + typeGetValues recType = 9 + typeGetValuesResult recType = 10 + typeUnknownType recType = 11 +) + +// keep the connection between web-server and responder open after request +const flagKeepConn = 1 + +const ( + maxWrite = 65535 // maximum record body + maxPad = 255 +) + +const ( + roleResponder = iota + 1 // only Responders are implemented. + roleAuthorizer + roleFilter +) + +const ( + statusRequestComplete = iota + statusCantMultiplex + statusOverloaded + statusUnknownRole +) + +type header struct { + Version uint8 + Type recType + Id uint16 + ContentLength uint16 + PaddingLength uint8 + Reserved uint8 +} + +type beginRequest struct { + role uint16 + flags uint8 + reserved [5]uint8 +} + +func (br *beginRequest) read(content []byte) error { + if len(content) != 8 { + return errors.New("fcgi: invalid begin request record") + } + br.role = binary.BigEndian.Uint16(content) + br.flags = content[2] + return nil +} + +// for padding so we don't have to allocate all the time +// not synchronized because we don't care what the contents are +var pad [maxPad]byte + +func (h *header) init(recType recType, reqId uint16, contentLength int) { + h.Version = 1 + h.Type = recType + h.Id = reqId + h.ContentLength = uint16(contentLength) + h.PaddingLength = uint8(-contentLength & 7) +} + +// conn sends records over rwc +type conn struct { + mutex sync.Mutex + rwc io.ReadWriteCloser + + // to avoid allocations + buf bytes.Buffer + h header +} + +func newConn(rwc io.ReadWriteCloser) *conn { + return &conn{rwc: rwc} +} + +func (c *conn) Close() error { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.rwc.Close() +} + +type record struct { + h header + buf [maxWrite + maxPad]byte +} + +func (rec *record) read(r io.Reader) (err error) { + if err = binary.Read(r, binary.BigEndian, &rec.h); err != nil { + return err + } + if rec.h.Version != 1 { + return errors.New("fcgi: invalid header version") + } + n := int(rec.h.ContentLength) + int(rec.h.PaddingLength) + if _, err = io.ReadFull(r, rec.buf[:n]); err != nil { + return err + } + return nil +} + +func (r *record) content() []byte { + return r.buf[:r.h.ContentLength] +} + +// writeRecord writes and sends a single record. +func (c *conn) writeRecord(recType recType, reqId uint16, b []byte) error { + c.mutex.Lock() + defer c.mutex.Unlock() + c.buf.Reset() + c.h.init(recType, reqId, len(b)) + if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil { + return err + } + if _, err := c.buf.Write(b); err != nil { + return err + } + if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil { + return err + } + _, err := c.rwc.Write(c.buf.Bytes()) + return err +} + +func (c *conn) writeEndRequest(reqId uint16, appStatus int, protocolStatus uint8) error { + b := make([]byte, 8) + binary.BigEndian.PutUint32(b, uint32(appStatus)) + b[4] = protocolStatus + return c.writeRecord(typeEndRequest, reqId, b) +} + +func (c *conn) writePairs(recType recType, reqId uint16, pairs map[string]string) error { + w := newWriter(c, recType, reqId) + b := make([]byte, 8) + for k, v := range pairs { + n := encodeSize(b, uint32(len(k))) + n += encodeSize(b[n:], uint32(len(v))) + if _, err := w.Write(b[:n]); err != nil { + return err + } + if _, err := w.WriteString(k); err != nil { + return err + } + if _, err := w.WriteString(v); err != nil { + return err + } + } + w.Close() + return nil +} + +func readSize(s []byte) (uint32, int) { + if len(s) == 0 { + return 0, 0 + } + size, n := uint32(s[0]), 1 + if size&(1<<7) != 0 { + if len(s) < 4 { + return 0, 0 + } + n = 4 + size = binary.BigEndian.Uint32(s) + size &^= 1 << 31 + } + return size, n +} + +func readString(s []byte, size uint32) string { + if size > uint32(len(s)) { + return "" + } + return string(s[:size]) +} + +func encodeSize(b []byte, size uint32) int { + if size > 127 { + size |= 1 << 31 + binary.BigEndian.PutUint32(b, size) + return 4 + } + b[0] = byte(size) + return 1 +} + +// bufWriter encapsulates bufio.Writer but also closes the underlying stream when +// Closed. +type bufWriter struct { + closer io.Closer + *bufio.Writer +} + +func (w *bufWriter) Close() error { + if err := w.Writer.Flush(); err != nil { + w.closer.Close() + return err + } + return w.closer.Close() +} + +func newWriter(c *conn, recType recType, reqId uint16) *bufWriter { + s := &streamWriter{c: c, recType: recType, reqId: reqId} + w := bufio.NewWriterSize(s, maxWrite) + return &bufWriter{s, w} +} + +// streamWriter abstracts out the separation of a stream into discrete records. +// It only writes maxWrite bytes at a time. +type streamWriter struct { + c *conn + recType recType + reqId uint16 +} + +func (w *streamWriter) Write(p []byte) (int, error) { + nn := 0 + for len(p) > 0 { + n := len(p) + if n > maxWrite { + n = maxWrite + } + if err := w.c.writeRecord(w.recType, w.reqId, p[:n]); err != nil { + return nn, err + } + nn += n + p = p[n:] + } + return nn, nil +} + +func (w *streamWriter) Close() error { + // send empty record to close the stream + return w.c.writeRecord(w.recType, w.reqId, nil) +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/fcgi/fcgi_test.go b/vendor/github.com/lesismal/llib/std/net/http/fcgi/fcgi_test.go new file mode 100644 index 0000000..b58111d --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/fcgi/fcgi_test.go @@ -0,0 +1,401 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fcgi + +import ( + "bytes" + "errors" + "io" + "net/http" + "strings" + "testing" +) + +var sizeTests = []struct { + size uint32 + bytes []byte +}{ + {0, []byte{0x00}}, + {127, []byte{0x7F}}, + {128, []byte{0x80, 0x00, 0x00, 0x80}}, + {1000, []byte{0x80, 0x00, 0x03, 0xE8}}, + {33554431, []byte{0x81, 0xFF, 0xFF, 0xFF}}, +} + +func TestSize(t *testing.T) { + b := make([]byte, 4) + for i, test := range sizeTests { + n := encodeSize(b, test.size) + if !bytes.Equal(b[:n], test.bytes) { + t.Errorf("%d expected %x, encoded %x", i, test.bytes, b) + } + size, n := readSize(test.bytes) + if size != test.size { + t.Errorf("%d expected %d, read %d", i, test.size, size) + } + if len(test.bytes) != n { + t.Errorf("%d did not consume all the bytes", i) + } + } +} + +var streamTests = []struct { + desc string + recType recType + reqId uint16 + content []byte + raw []byte +}{ + {"single record", typeStdout, 1, nil, + []byte{1, byte(typeStdout), 0, 1, 0, 0, 0, 0}, + }, + // this data will have to be split into two records + {"two records", typeStdin, 300, make([]byte, 66000), + bytes.Join([][]byte{ + // header for the first record + {1, byte(typeStdin), 0x01, 0x2C, 0xFF, 0xFF, 1, 0}, + make([]byte, 65536), + // header for the second + {1, byte(typeStdin), 0x01, 0x2C, 0x01, 0xD1, 7, 0}, + make([]byte, 472), + // header for the empty record + {1, byte(typeStdin), 0x01, 0x2C, 0, 0, 0, 0}, + }, + nil), + }, +} + +type nilCloser struct { + io.ReadWriter +} + +func (c *nilCloser) Close() error { return nil } + +func TestStreams(t *testing.T) { + var rec record +outer: + for _, test := range streamTests { + buf := bytes.NewBuffer(test.raw) + var content []byte + for buf.Len() > 0 { + if err := rec.read(buf); err != nil { + t.Errorf("%s: error reading record: %v", test.desc, err) + continue outer + } + content = append(content, rec.content()...) + } + if rec.h.Type != test.recType { + t.Errorf("%s: got type %d expected %d", test.desc, rec.h.Type, test.recType) + continue + } + if rec.h.Id != test.reqId { + t.Errorf("%s: got request ID %d expected %d", test.desc, rec.h.Id, test.reqId) + continue + } + if !bytes.Equal(content, test.content) { + t.Errorf("%s: read wrong content", test.desc) + continue + } + buf.Reset() + c := newConn(&nilCloser{buf}) + w := newWriter(c, test.recType, test.reqId) + if _, err := w.Write(test.content); err != nil { + t.Errorf("%s: error writing record: %v", test.desc, err) + continue + } + if err := w.Close(); err != nil { + t.Errorf("%s: error closing stream: %v", test.desc, err) + continue + } + if !bytes.Equal(buf.Bytes(), test.raw) { + t.Errorf("%s: wrote wrong content", test.desc) + } + } +} + +type writeOnlyConn struct { + buf []byte +} + +func (c *writeOnlyConn) Write(p []byte) (int, error) { + c.buf = append(c.buf, p...) + return len(p), nil +} + +func (c *writeOnlyConn) Read(p []byte) (int, error) { + return 0, errors.New("conn is write-only") +} + +func (c *writeOnlyConn) Close() error { + return nil +} + +func TestGetValues(t *testing.T) { + var rec record + rec.h.Type = typeGetValues + + wc := new(writeOnlyConn) + c := newChild(wc, nil) + err := c.handleRecord(&rec) + if err != nil { + t.Fatalf("handleRecord: %v", err) + } + + const want = "\x01\n\x00\x00\x00\x12\x06\x00" + + "\x0f\x01FCGI_MPXS_CONNS1" + + "\x00\x00\x00\x00\x00\x00\x01\n\x00\x00\x00\x00\x00\x00" + if got := string(wc.buf); got != want { + t.Errorf(" got: %q\nwant: %q\n", got, want) + } +} + +func nameValuePair11(nameData, valueData string) []byte { + return bytes.Join( + [][]byte{ + {byte(len(nameData)), byte(len(valueData))}, + []byte(nameData), + []byte(valueData), + }, + nil, + ) +} + +func makeRecord( + recordType recType, + requestId uint16, + contentData []byte, +) []byte { + requestIdB1 := byte(requestId >> 8) + requestIdB0 := byte(requestId) + + contentLength := len(contentData) + contentLengthB1 := byte(contentLength >> 8) + contentLengthB0 := byte(contentLength) + return bytes.Join([][]byte{ + {1, byte(recordType), requestIdB1, requestIdB0, contentLengthB1, + contentLengthB0, 0, 0}, + contentData, + }, + nil) +} + +// a series of FastCGI records that start a request and begin sending the +// request body +var streamBeginTypeStdin = bytes.Join([][]byte{ + // set up request 1 + makeRecord(typeBeginRequest, 1, + []byte{0, byte(roleResponder), 0, 0, 0, 0, 0, 0}), + // add required parameters to request 1 + makeRecord(typeParams, 1, nameValuePair11("REQUEST_METHOD", "GET")), + makeRecord(typeParams, 1, nameValuePair11("SERVER_PROTOCOL", "HTTP/1.1")), + makeRecord(typeParams, 1, nil), + // begin sending body of request 1 + makeRecord(typeStdin, 1, []byte("0123456789abcdef")), +}, + nil) + +var cleanUpTests = []struct { + input []byte + err error +}{ + // confirm that child.handleRecord closes req.pw after aborting req + { + bytes.Join([][]byte{ + streamBeginTypeStdin, + makeRecord(typeAbortRequest, 1, nil), + }, + nil), + ErrRequestAborted, + }, + // confirm that child.serve closes all pipes after error reading record + { + bytes.Join([][]byte{ + streamBeginTypeStdin, + nil, + }, + nil), + ErrConnClosed, + }, +} + +type nopWriteCloser struct { + io.Reader +} + +func (nopWriteCloser) Write(buf []byte) (int, error) { + return len(buf), nil +} + +func (nopWriteCloser) Close() error { + return nil +} + +// Test that child.serve closes the bodies of aborted requests and closes the +// bodies of all requests before returning. Causes deadlock if either condition +// isn't met. See issue 6934. +func TestChildServeCleansUp(t *testing.T) { + for _, tt := range cleanUpTests { + input := make([]byte, len(tt.input)) + copy(input, tt.input) + rc := nopWriteCloser{bytes.NewReader(input)} + done := make(chan bool) + c := newChild(rc, http.HandlerFunc(func( + w http.ResponseWriter, + r *http.Request, + ) { + // block on reading body of request + _, err := io.Copy(io.Discard, r.Body) + if err != tt.err { + t.Errorf("Expected %#v, got %#v", tt.err, err) + } + // not reached if body of request isn't closed + done <- true + })) + go c.serve() + // wait for body of request to be closed or all goroutines to block + <-done + } +} + +type rwNopCloser struct { + io.Reader + io.Writer +} + +func (rwNopCloser) Close() error { + return nil +} + +// Verifies it doesn't crash. Issue 11824. +func TestMalformedParams(t *testing.T) { + input := []byte{ + // beginRequest, requestId=1, contentLength=8, role=1, keepConn=1 + 1, 1, 0, 1, 0, 8, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, + // params, requestId=1, contentLength=10, k1Len=50, v1Len=50 (malformed, wrong length) + 1, 4, 0, 1, 0, 10, 0, 0, 50, 50, 3, 4, 5, 6, 7, 8, 9, 10, + // end of params + 1, 4, 0, 1, 0, 0, 0, 0, + } + rw := rwNopCloser{bytes.NewReader(input), io.Discard} + c := newChild(rw, http.DefaultServeMux) + c.serve() +} + +// a series of FastCGI records that start and end a request +var streamFullRequestStdin = bytes.Join([][]byte{ + // set up request + makeRecord(typeBeginRequest, 1, + []byte{0, byte(roleResponder), 0, 0, 0, 0, 0, 0}), + // add required parameters + makeRecord(typeParams, 1, nameValuePair11("REQUEST_METHOD", "GET")), + makeRecord(typeParams, 1, nameValuePair11("SERVER_PROTOCOL", "HTTP/1.1")), + // set optional parameters + makeRecord(typeParams, 1, nameValuePair11("REMOTE_USER", "jane.doe")), + makeRecord(typeParams, 1, nameValuePair11("QUERY_STRING", "/foo/bar")), + makeRecord(typeParams, 1, nil), + // begin sending body of request + makeRecord(typeStdin, 1, []byte("0123456789abcdef")), + // end request + makeRecord(typeEndRequest, 1, nil), +}, + nil) + +var envVarTests = []struct { + input []byte + envVar string + expectedVal string + expectedFilteredOut bool +}{ + { + streamFullRequestStdin, + "REMOTE_USER", + "jane.doe", + false, + }, + { + streamFullRequestStdin, + "QUERY_STRING", + "", + true, + }, +} + +// Test that environment variables set for a request can be +// read by a handler. Ensures that variables not set will not +// be exposed to a handler. +func TestChildServeReadsEnvVars(t *testing.T) { + for _, tt := range envVarTests { + input := make([]byte, len(tt.input)) + copy(input, tt.input) + rc := nopWriteCloser{bytes.NewReader(input)} + done := make(chan bool) + c := newChild(rc, http.HandlerFunc(func( + w http.ResponseWriter, + r *http.Request, + ) { + env := ProcessEnv(r) + if _, ok := env[tt.envVar]; ok && tt.expectedFilteredOut { + t.Errorf("Expected environment variable %s to not be set, but set to %s", + tt.envVar, env[tt.envVar]) + } else if env[tt.envVar] != tt.expectedVal { + t.Errorf("Expected %s, got %s", tt.expectedVal, env[tt.envVar]) + } + done <- true + })) + go c.serve() + <-done + } +} + +func TestResponseWriterSniffsContentType(t *testing.T) { + var tests = []struct { + name string + body string + wantCT string + }{ + { + name: "no body", + wantCT: "text/plain; charset=utf-8", + }, + { + name: "html", + body: "test pageThis is a body", + wantCT: "text/html; charset=utf-8", + }, + { + name: "text", + body: strings.Repeat("gopher", 86), + wantCT: "text/plain; charset=utf-8", + }, + { + name: "jpg", + body: "\xFF\xD8\xFF" + strings.Repeat("B", 1024), + wantCT: "image/jpeg", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + input := make([]byte, len(streamFullRequestStdin)) + copy(input, streamFullRequestStdin) + rc := nopWriteCloser{bytes.NewReader(input)} + done := make(chan bool) + var resp *response + c := newChild(rc, http.HandlerFunc(func( + w http.ResponseWriter, + r *http.Request, + ) { + io.WriteString(w, tt.body) + resp = w.(*response) + done <- true + })) + defer c.cleanUp() + go c.serve() + <-done + if got := resp.Header().Get("Content-Type"); got != tt.wantCT { + t.Errorf("got a Content-Type of %q; expected it to start with %q", got, tt.wantCT) + } + }) + } +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/filetransport.go b/vendor/github.com/lesismal/llib/std/net/http/filetransport.go new file mode 100644 index 0000000..32126d7 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/filetransport.go @@ -0,0 +1,123 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "fmt" + "io" +) + +// fileTransport implements RoundTripper for the 'file' protocol. +type fileTransport struct { + fh fileHandler +} + +// NewFileTransport returns a new RoundTripper, serving the provided +// FileSystem. The returned RoundTripper ignores the URL host in its +// incoming requests, as well as most other properties of the +// request. +// +// The typical use case for NewFileTransport is to register the "file" +// protocol with a Transport, as in: +// +// t := &http.Transport{} +// t.RegisterProtocol("file", http.NewFileTransport(http.Dir("/"))) +// c := &http.Client{Transport: t} +// res, err := c.Get("file:///etc/passwd") +// ... +func NewFileTransport(fs FileSystem) RoundTripper { + return fileTransport{fileHandler{fs}} +} + +func (t fileTransport) RoundTrip(req *Request) (resp *Response, err error) { + // We start ServeHTTP in a goroutine, which may take a long + // time if the file is large. The newPopulateResponseWriter + // call returns a channel which either ServeHTTP or finish() + // sends our *Response on, once the *Response itself has been + // populated (even if the body itself is still being + // written to the res.Body, a pipe) + rw, resc := newPopulateResponseWriter() + go func() { + t.fh.ServeHTTP(rw, req) + rw.finish() + }() + return <-resc, nil +} + +func newPopulateResponseWriter() (*populateResponse, <-chan *Response) { + pr, pw := io.Pipe() + rw := &populateResponse{ + ch: make(chan *Response), + pw: pw, + res: &Response{ + Proto: "HTTP/1.0", + ProtoMajor: 1, + Header: make(Header), + Close: true, + Body: pr, + }, + } + return rw, rw.ch +} + +// populateResponse is a ResponseWriter that populates the *Response +// in res, and writes its body to a pipe connected to the response +// body. Once writes begin or finish() is called, the response is sent +// on ch. +type populateResponse struct { + res *Response + ch chan *Response + wroteHeader bool + hasContent bool + sentResponse bool + pw *io.PipeWriter +} + +func (pr *populateResponse) finish() { + if !pr.wroteHeader { + pr.WriteHeader(500) + } + if !pr.sentResponse { + pr.sendResponse() + } + pr.pw.Close() +} + +func (pr *populateResponse) sendResponse() { + if pr.sentResponse { + return + } + pr.sentResponse = true + + if pr.hasContent { + pr.res.ContentLength = -1 + } + pr.ch <- pr.res +} + +func (pr *populateResponse) Header() Header { + return pr.res.Header +} + +func (pr *populateResponse) WriteHeader(code int) { + if pr.wroteHeader { + return + } + pr.wroteHeader = true + + pr.res.StatusCode = code + pr.res.Status = fmt.Sprintf("%d %s", code, StatusText(code)) +} + +func (pr *populateResponse) Write(p []byte) (n int, err error) { + if !pr.wroteHeader { + pr.WriteHeader(StatusOK) + } + pr.hasContent = true + if !pr.sentResponse { + pr.sendResponse() + } + return pr.pw.Write(p) +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/filetransport_test.go b/vendor/github.com/lesismal/llib/std/net/http/filetransport_test.go new file mode 100644 index 0000000..b58888d --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/filetransport_test.go @@ -0,0 +1,66 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "io" + "os" + "path/filepath" + "testing" +) + +func checker(t *testing.T) func(string, error) { + return func(call string, err error) { + if err == nil { + return + } + t.Fatalf("%s: %v", call, err) + } +} + +func TestFileTransport(t *testing.T) { + check := checker(t) + + dname, err := os.MkdirTemp("", "") + check("TempDir", err) + fname := filepath.Join(dname, "foo.txt") + err = os.WriteFile(fname, []byte("Bar"), 0644) + check("WriteFile", err) + defer os.Remove(dname) + defer os.Remove(fname) + + tr := &Transport{} + tr.RegisterProtocol("file", NewFileTransport(Dir(dname))) + c := &Client{Transport: tr} + + fooURLs := []string{"file:///foo.txt", "file://../foo.txt"} + for _, urlstr := range fooURLs { + res, err := c.Get(urlstr) + check("Get "+urlstr, err) + if res.StatusCode != 200 { + t.Errorf("for %s, StatusCode = %d, want 200", urlstr, res.StatusCode) + } + if res.ContentLength != -1 { + t.Errorf("for %s, ContentLength = %d, want -1", urlstr, res.ContentLength) + } + if res.Body == nil { + t.Fatalf("for %s, nil Body", urlstr) + } + slurp, err := io.ReadAll(res.Body) + res.Body.Close() + check("ReadAll "+urlstr, err) + if string(slurp) != "Bar" { + t.Errorf("for %s, got content %q, want %q", urlstr, string(slurp), "Bar") + } + } + + const badURL = "file://../no-exist.txt" + res, err := c.Get(badURL) + check("Get "+badURL, err) + if res.StatusCode != 404 { + t.Errorf("for %s, StatusCode = %d, want 404", badURL, res.StatusCode) + } + res.Body.Close() +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/fs.go b/vendor/github.com/lesismal/llib/std/net/http/fs.go new file mode 100644 index 0000000..a28ae85 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/fs.go @@ -0,0 +1,970 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP file system request handler + +package http + +import ( + "errors" + "fmt" + "io" + "io/fs" + "mime" + "mime/multipart" + "net/textproto" + "net/url" + "os" + "path" + "path/filepath" + "sort" + "strconv" + "strings" + "time" +) + +// A Dir implements FileSystem using the native file system restricted to a +// specific directory tree. +// +// While the FileSystem.Open method takes '/'-separated paths, a Dir's string +// value is a filename on the native file system, not a URL, so it is separated +// by filepath.Separator, which isn't necessarily '/'. +// +// Note that Dir could expose sensitive files and directories. Dir will follow +// symlinks pointing out of the directory tree, which can be especially dangerous +// if serving from a directory in which users are able to create arbitrary symlinks. +// Dir will also allow access to files and directories starting with a period, +// which could expose sensitive directories like .git or sensitive files like +// .htpasswd. To exclude files with a leading period, remove the files/directories +// from the server or create a custom FileSystem implementation. +// +// An empty Dir is treated as ".". +type Dir string + +// mapDirOpenError maps the provided non-nil error from opening name +// to a possibly better non-nil error. In particular, it turns OS-specific errors +// about opening files in non-directories into fs.ErrNotExist. See Issue 18984. +func mapDirOpenError(originalErr error, name string) error { + if os.IsNotExist(originalErr) || os.IsPermission(originalErr) { + return originalErr + } + + parts := strings.Split(name, string(filepath.Separator)) + for i := range parts { + if parts[i] == "" { + continue + } + fi, err := os.Stat(strings.Join(parts[:i+1], string(filepath.Separator))) + if err != nil { + return originalErr + } + if !fi.IsDir() { + return fs.ErrNotExist + } + } + return originalErr +} + +// Open implements FileSystem using os.Open, opening files for reading rooted +// and relative to the directory d. +func (d Dir) Open(name string) (File, error) { + if filepath.Separator != '/' && strings.ContainsRune(name, filepath.Separator) { + return nil, errors.New("http: invalid character in file path") + } + dir := string(d) + if dir == "" { + dir = "." + } + fullName := filepath.Join(dir, filepath.FromSlash(path.Clean("/"+name))) + f, err := os.Open(fullName) + if err != nil { + return nil, mapDirOpenError(err, fullName) + } + return f, nil +} + +// A FileSystem implements access to a collection of named files. +// The elements in a file path are separated by slash ('/', U+002F) +// characters, regardless of host operating system convention. +// See the FileServer function to convert a FileSystem to a Handler. +// +// This interface predates the fs.FS interface, which can be used instead: +// the FS adapter function converts an fs.FS to a FileSystem. +type FileSystem interface { + Open(name string) (File, error) +} + +// A File is returned by a FileSystem's Open method and can be +// served by the FileServer implementation. +// +// The methods should behave the same as those on an *os.File. +type File interface { + io.Closer + io.Reader + io.Seeker + Readdir(count int) ([]fs.FileInfo, error) + Stat() (fs.FileInfo, error) +} + +type anyDirs interface { + len() int + name(i int) string + isDir(i int) bool +} + +type fileInfoDirs []fs.FileInfo + +func (d fileInfoDirs) len() int { return len(d) } +func (d fileInfoDirs) isDir(i int) bool { return d[i].IsDir() } +func (d fileInfoDirs) name(i int) string { return d[i].Name() } + +type dirEntryDirs []fs.DirEntry + +func (d dirEntryDirs) len() int { return len(d) } +func (d dirEntryDirs) isDir(i int) bool { return d[i].IsDir() } +func (d dirEntryDirs) name(i int) string { return d[i].Name() } + +func dirList(w ResponseWriter, r *Request, f File) { + // Prefer to use ReadDir instead of Readdir, + // because the former doesn't require calling + // Stat on every entry of a directory on Unix. + var dirs anyDirs + var err error + if d, ok := f.(fs.ReadDirFile); ok { + var list dirEntryDirs + list, err = d.ReadDir(-1) + dirs = list + } else { + var list fileInfoDirs + list, err = f.Readdir(-1) + dirs = list + } + + if err != nil { + logf(r, "http: error reading directory: %v", err) + Error(w, "Error reading directory", StatusInternalServerError) + return + } + sort.Slice(dirs, func(i, j int) bool { return dirs.name(i) < dirs.name(j) }) + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + fmt.Fprintf(w, "
\n")
+	for i, n := 0, dirs.len(); i < n; i++ {
+		name := dirs.name(i)
+		if dirs.isDir(i) {
+			name += "/"
+		}
+		// name may contain '?' or '#', which must be escaped to remain
+		// part of the URL path, and not indicate the start of a query
+		// string or fragment.
+		url := url.URL{Path: name}
+		fmt.Fprintf(w, "%s\n", url.String(), htmlReplacer.Replace(name))
+	}
+	fmt.Fprintf(w, "
\n") +} + +// ServeContent replies to the request using the content in the +// provided ReadSeeker. The main benefit of ServeContent over io.Copy +// is that it handles Range requests properly, sets the MIME type, and +// handles If-Match, If-Unmodified-Since, If-None-Match, If-Modified-Since, +// and If-Range requests. +// +// If the response's Content-Type header is not set, ServeContent +// first tries to deduce the type from name's file extension and, +// if that fails, falls back to reading the first block of the content +// and passing it to DetectContentType. +// The name is otherwise unused; in particular it can be empty and is +// never sent in the response. +// +// If modtime is not the zero time or Unix epoch, ServeContent +// includes it in a Last-Modified header in the response. If the +// request includes an If-Modified-Since header, ServeContent uses +// modtime to decide whether the content needs to be sent at all. +// +// The content's Seek method must work: ServeContent uses +// a seek to the end of the content to determine its size. +// +// If the caller has set w's ETag header formatted per RFC 7232, section 2.3, +// ServeContent uses it to handle requests using If-Match, If-None-Match, or If-Range. +// +// Note that *os.File implements the io.ReadSeeker interface. +func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time, content io.ReadSeeker) { + sizeFunc := func() (int64, error) { + size, err := content.Seek(0, io.SeekEnd) + if err != nil { + return 0, errSeeker + } + _, err = content.Seek(0, io.SeekStart) + if err != nil { + return 0, errSeeker + } + return size, nil + } + serveContent(w, req, name, modtime, sizeFunc, content) +} + +// errSeeker is returned by ServeContent's sizeFunc when the content +// doesn't seek properly. The underlying Seeker's error text isn't +// included in the sizeFunc reply so it's not sent over HTTP to end +// users. +var errSeeker = errors.New("seeker can't seek") + +// errNoOverlap is returned by serveContent's parseRange if first-byte-pos of +// all of the byte-range-spec values is greater than the content size. +var errNoOverlap = errors.New("invalid range: failed to overlap") + +// if name is empty, filename is unknown. (used for mime type, before sniffing) +// if modtime.IsZero(), modtime is unknown. +// content must be seeked to the beginning of the file. +// The sizeFunc is called at most once. Its error, if any, is sent in the HTTP response. +func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, sizeFunc func() (int64, error), content io.ReadSeeker) { + setLastModified(w, modtime) + done, rangeReq := checkPreconditions(w, r, modtime) + if done { + return + } + + code := StatusOK + + // If Content-Type isn't set, use the file's extension to find it, but + // if the Content-Type is unset explicitly, do not sniff the type. + ctypes, haveType := w.Header()["Content-Type"] + var ctype string + if !haveType { + ctype = mime.TypeByExtension(filepath.Ext(name)) + if ctype == "" { + // read a chunk to decide between utf-8 text and binary + var buf [sniffLen]byte + n, _ := io.ReadFull(content, buf[:]) + ctype = DetectContentType(buf[:n]) + _, err := content.Seek(0, io.SeekStart) // rewind to output whole file + if err != nil { + Error(w, "seeker can't seek", StatusInternalServerError) + return + } + } + w.Header().Set("Content-Type", ctype) + } else if len(ctypes) > 0 { + ctype = ctypes[0] + } + + size, err := sizeFunc() + if err != nil { + Error(w, err.Error(), StatusInternalServerError) + return + } + + // handle Content-Range header. + sendSize := size + var sendContent io.Reader = content + if size >= 0 { + ranges, err := parseRange(rangeReq, size) + if err != nil { + if err == errNoOverlap { + w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", size)) + } + Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) + return + } + if sumRangesSize(ranges) > size { + // The total number of bytes in all the ranges + // is larger than the size of the file by + // itself, so this is probably an attack, or a + // dumb client. Ignore the range request. + ranges = nil + } + switch { + case len(ranges) == 1: + // RFC 7233, Section 4.1: + // "If a single part is being transferred, the server + // generating the 206 response MUST generate a + // Content-Range header field, describing what range + // of the selected representation is enclosed, and a + // payload consisting of the range. + // ... + // A server MUST NOT generate a multipart response to + // a request for a single range, since a client that + // does not request multiple parts might not support + // multipart responses." + ra := ranges[0] + if _, err := content.Seek(ra.start, io.SeekStart); err != nil { + Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) + return + } + sendSize = ra.length + code = StatusPartialContent + w.Header().Set("Content-Range", ra.contentRange(size)) + case len(ranges) > 1: + sendSize = rangesMIMESize(ranges, ctype, size) + code = StatusPartialContent + + pr, pw := io.Pipe() + mw := multipart.NewWriter(pw) + w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary()) + sendContent = pr + defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish. + go func() { + for _, ra := range ranges { + part, err := mw.CreatePart(ra.mimeHeader(ctype, size)) + if err != nil { + pw.CloseWithError(err) + return + } + if _, err := content.Seek(ra.start, io.SeekStart); err != nil { + pw.CloseWithError(err) + return + } + if _, err := io.CopyN(part, content, ra.length); err != nil { + pw.CloseWithError(err) + return + } + } + mw.Close() + pw.Close() + }() + } + + w.Header().Set("Accept-Ranges", "bytes") + if w.Header().Get("Content-Encoding") == "" { + w.Header().Set("Content-Length", strconv.FormatInt(sendSize, 10)) + } + } + + w.WriteHeader(code) + + if r.Method != "HEAD" { + io.CopyN(w, sendContent, sendSize) + } +} + +// scanETag determines if a syntactically valid ETag is present at s. If so, +// the ETag and remaining text after consuming ETag is returned. Otherwise, +// it returns "", "". +func scanETag(s string) (etag string, remain string) { + s = textproto.TrimString(s) + start := 0 + if strings.HasPrefix(s, "W/") { + start = 2 + } + if len(s[start:]) < 2 || s[start] != '"' { + return "", "" + } + // ETag is either W/"text" or "text". + // See RFC 7232 2.3. + for i := start + 1; i < len(s); i++ { + c := s[i] + switch { + // Character values allowed in ETags. + case c == 0x21 || c >= 0x23 && c <= 0x7E || c >= 0x80: + case c == '"': + return s[:i+1], s[i+1:] + default: + return "", "" + } + } + return "", "" +} + +// etagStrongMatch reports whether a and b match using strong ETag comparison. +// Assumes a and b are valid ETags. +func etagStrongMatch(a, b string) bool { + return a == b && a != "" && a[0] == '"' +} + +// etagWeakMatch reports whether a and b match using weak ETag comparison. +// Assumes a and b are valid ETags. +func etagWeakMatch(a, b string) bool { + return strings.TrimPrefix(a, "W/") == strings.TrimPrefix(b, "W/") +} + +// condResult is the result of an HTTP request precondition check. +// See https://tools.ietf.org/html/rfc7232 section 3. +type condResult int + +const ( + condNone condResult = iota + condTrue + condFalse +) + +func checkIfMatch(w ResponseWriter, r *Request) condResult { + im := r.Header.Get("If-Match") + if im == "" { + return condNone + } + for { + im = textproto.TrimString(im) + if len(im) == 0 { + break + } + if im[0] == ',' { + im = im[1:] + continue + } + if im[0] == '*' { + return condTrue + } + etag, remain := scanETag(im) + if etag == "" { + break + } + if etagStrongMatch(etag, w.Header().get("Etag")) { + return condTrue + } + im = remain + } + + return condFalse +} + +func checkIfUnmodifiedSince(r *Request, modtime time.Time) condResult { + ius := r.Header.Get("If-Unmodified-Since") + if ius == "" || isZeroTime(modtime) { + return condNone + } + t, err := ParseTime(ius) + if err != nil { + return condNone + } + + // The Last-Modified header truncates sub-second precision so + // the modtime needs to be truncated too. + modtime = modtime.Truncate(time.Second) + if modtime.Before(t) || modtime.Equal(t) { + return condTrue + } + return condFalse +} + +func checkIfNoneMatch(w ResponseWriter, r *Request) condResult { + inm := r.Header.get("If-None-Match") + if inm == "" { + return condNone + } + buf := inm + for { + buf = textproto.TrimString(buf) + if len(buf) == 0 { + break + } + if buf[0] == ',' { + buf = buf[1:] + continue + } + if buf[0] == '*' { + return condFalse + } + etag, remain := scanETag(buf) + if etag == "" { + break + } + if etagWeakMatch(etag, w.Header().get("Etag")) { + return condFalse + } + buf = remain + } + return condTrue +} + +func checkIfModifiedSince(r *Request, modtime time.Time) condResult { + if r.Method != "GET" && r.Method != "HEAD" { + return condNone + } + ims := r.Header.Get("If-Modified-Since") + if ims == "" || isZeroTime(modtime) { + return condNone + } + t, err := ParseTime(ims) + if err != nil { + return condNone + } + // The Last-Modified header truncates sub-second precision so + // the modtime needs to be truncated too. + modtime = modtime.Truncate(time.Second) + if modtime.Before(t) || modtime.Equal(t) { + return condFalse + } + return condTrue +} + +func checkIfRange(w ResponseWriter, r *Request, modtime time.Time) condResult { + if r.Method != "GET" && r.Method != "HEAD" { + return condNone + } + ir := r.Header.get("If-Range") + if ir == "" { + return condNone + } + etag, _ := scanETag(ir) + if etag != "" { + if etagStrongMatch(etag, w.Header().Get("Etag")) { + return condTrue + } else { + return condFalse + } + } + // The If-Range value is typically the ETag value, but it may also be + // the modtime date. See golang.org/issue/8367. + if modtime.IsZero() { + return condFalse + } + t, err := ParseTime(ir) + if err != nil { + return condFalse + } + if t.Unix() == modtime.Unix() { + return condTrue + } + return condFalse +} + +var unixEpochTime = time.Unix(0, 0) + +// isZeroTime reports whether t is obviously unspecified (either zero or Unix()=0). +func isZeroTime(t time.Time) bool { + return t.IsZero() || t.Equal(unixEpochTime) +} + +func setLastModified(w ResponseWriter, modtime time.Time) { + if !isZeroTime(modtime) { + w.Header().Set("Last-Modified", modtime.UTC().Format(TimeFormat)) + } +} + +func writeNotModified(w ResponseWriter) { + // RFC 7232 section 4.1: + // a sender SHOULD NOT generate representation metadata other than the + // above listed fields unless said metadata exists for the purpose of + // guiding cache updates (e.g., Last-Modified might be useful if the + // response does not have an ETag field). + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + if h.Get("Etag") != "" { + delete(h, "Last-Modified") + } + w.WriteHeader(StatusNotModified) +} + +// checkPreconditions evaluates request preconditions and reports whether a precondition +// resulted in sending StatusNotModified or StatusPreconditionFailed. +func checkPreconditions(w ResponseWriter, r *Request, modtime time.Time) (done bool, rangeHeader string) { + // This function carefully follows RFC 7232 section 6. + ch := checkIfMatch(w, r) + if ch == condNone { + ch = checkIfUnmodifiedSince(r, modtime) + } + if ch == condFalse { + w.WriteHeader(StatusPreconditionFailed) + return true, "" + } + switch checkIfNoneMatch(w, r) { + case condFalse: + if r.Method == "GET" || r.Method == "HEAD" { + writeNotModified(w) + return true, "" + } else { + w.WriteHeader(StatusPreconditionFailed) + return true, "" + } + case condNone: + if checkIfModifiedSince(r, modtime) == condFalse { + writeNotModified(w) + return true, "" + } + } + + rangeHeader = r.Header.get("Range") + if rangeHeader != "" && checkIfRange(w, r, modtime) == condFalse { + rangeHeader = "" + } + return false, rangeHeader +} + +// name is '/'-separated, not filepath.Separator. +func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirect bool) { + const indexPage = "/index.html" + + // redirect .../index.html to .../ + // can't use Redirect() because that would make the path absolute, + // which would be a problem running under StripPrefix + if strings.HasSuffix(r.URL.Path, indexPage) { + localRedirect(w, r, "./") + return + } + + f, err := fs.Open(name) + if err != nil { + msg, code := toHTTPError(err) + Error(w, msg, code) + return + } + defer f.Close() + + d, err := f.Stat() + if err != nil { + msg, code := toHTTPError(err) + Error(w, msg, code) + return + } + + if redirect { + // redirect to canonical path: / at end of directory url + // r.URL.Path always begins with / + url := r.URL.Path + if d.IsDir() { + if url[len(url)-1] != '/' { + localRedirect(w, r, path.Base(url)+"/") + return + } + } else { + if url[len(url)-1] == '/' { + localRedirect(w, r, "../"+path.Base(url)) + return + } + } + } + + if d.IsDir() { + url := r.URL.Path + // redirect if the directory name doesn't end in a slash + if url == "" || url[len(url)-1] != '/' { + localRedirect(w, r, path.Base(url)+"/") + return + } + + // use contents of index.html for directory, if present + index := strings.TrimSuffix(name, "/") + indexPage + ff, err := fs.Open(index) + if err == nil { + defer ff.Close() + dd, err := ff.Stat() + if err == nil { + name = index + d = dd + f = ff + } + } + } + + // Still a directory? (we didn't find an index.html file) + if d.IsDir() { + if checkIfModifiedSince(r, d.ModTime()) == condFalse { + writeNotModified(w) + return + } + setLastModified(w, d.ModTime()) + dirList(w, r, f) + return + } + + // serveContent will check modification time + sizeFunc := func() (int64, error) { return d.Size(), nil } + serveContent(w, r, d.Name(), d.ModTime(), sizeFunc, f) +} + +// toHTTPError returns a non-specific HTTP error message and status code +// for a given non-nil error value. It's important that toHTTPError does not +// actually return err.Error(), since msg and httpStatus are returned to users, +// and historically Go's ServeContent always returned just "404 Not Found" for +// all errors. We don't want to start leaking information in error messages. +func toHTTPError(err error) (msg string, httpStatus int) { + if os.IsNotExist(err) { + return "404 page not found", StatusNotFound + } + if os.IsPermission(err) { + return "403 Forbidden", StatusForbidden + } + // Default: + return "500 Internal Server Error", StatusInternalServerError +} + +// localRedirect gives a Moved Permanently response. +// It does not convert relative paths to absolute paths like Redirect does. +func localRedirect(w ResponseWriter, r *Request, newPath string) { + if q := r.URL.RawQuery; q != "" { + newPath += "?" + q + } + w.Header().Set("Location", newPath) + w.WriteHeader(StatusMovedPermanently) +} + +// ServeFile replies to the request with the contents of the named +// file or directory. +// +// If the provided file or directory name is a relative path, it is +// interpreted relative to the current directory and may ascend to +// parent directories. If the provided name is constructed from user +// input, it should be sanitized before calling ServeFile. +// +// As a precaution, ServeFile will reject requests where r.URL.Path +// contains a ".." path element; this protects against callers who +// might unsafely use filepath.Join on r.URL.Path without sanitizing +// it and then use that filepath.Join result as the name argument. +// +// As another special case, ServeFile redirects any request where r.URL.Path +// ends in "/index.html" to the same path, without the final +// "index.html". To avoid such redirects either modify the path or +// use ServeContent. +// +// Outside of those two special cases, ServeFile does not use +// r.URL.Path for selecting the file or directory to serve; only the +// file or directory provided in the name argument is used. +func ServeFile(w ResponseWriter, r *Request, name string) { + if containsDotDot(r.URL.Path) { + // Too many programs use r.URL.Path to construct the argument to + // serveFile. Reject the request under the assumption that happened + // here and ".." may not be wanted. + // Note that name might not contain "..", for example if code (still + // incorrectly) used filepath.Join(myDir, r.URL.Path). + Error(w, "invalid URL path", StatusBadRequest) + return + } + dir, file := filepath.Split(name) + serveFile(w, r, Dir(dir), file, false) +} + +func containsDotDot(v string) bool { + if !strings.Contains(v, "..") { + return false + } + for _, ent := range strings.FieldsFunc(v, isSlashRune) { + if ent == ".." { + return true + } + } + return false +} + +func isSlashRune(r rune) bool { return r == '/' || r == '\\' } + +type fileHandler struct { + root FileSystem +} + +type ioFS struct { + fsys fs.FS +} + +type ioFile struct { + file fs.File +} + +func (f ioFS) Open(name string) (File, error) { + if name == "/" { + name = "." + } else { + name = strings.TrimPrefix(name, "/") + } + file, err := f.fsys.Open(name) + if err != nil { + return nil, err + } + return ioFile{file}, nil +} + +func (f ioFile) Close() error { return f.file.Close() } +func (f ioFile) Read(b []byte) (int, error) { return f.file.Read(b) } +func (f ioFile) Stat() (fs.FileInfo, error) { return f.file.Stat() } + +var errMissingSeek = errors.New("io.File missing Seek method") +var errMissingReadDir = errors.New("io.File directory missing ReadDir method") + +func (f ioFile) Seek(offset int64, whence int) (int64, error) { + s, ok := f.file.(io.Seeker) + if !ok { + return 0, errMissingSeek + } + return s.Seek(offset, whence) +} + +func (f ioFile) ReadDir(count int) ([]fs.DirEntry, error) { + d, ok := f.file.(fs.ReadDirFile) + if !ok { + return nil, errMissingReadDir + } + return d.ReadDir(count) +} + +func (f ioFile) Readdir(count int) ([]fs.FileInfo, error) { + d, ok := f.file.(fs.ReadDirFile) + if !ok { + return nil, errMissingReadDir + } + var list []fs.FileInfo + for { + dirs, err := d.ReadDir(count - len(list)) + for _, dir := range dirs { + info, err := dir.Info() + if err != nil { + // Pretend it doesn't exist, like (*os.File).Readdir does. + continue + } + list = append(list, info) + } + if err != nil { + return list, err + } + if count < 0 || len(list) >= count { + break + } + } + return list, nil +} + +// FS converts fsys to a FileSystem implementation, +// for use with FileServer and NewFileTransport. +func FS(fsys fs.FS) FileSystem { + return ioFS{fsys} +} + +// FileServer returns a handler that serves HTTP requests +// with the contents of the file system rooted at root. +// +// As a special case, the returned file server redirects any request +// ending in "/index.html" to the same path, without the final +// "index.html". +// +// To use the operating system's file system implementation, +// use http.Dir: +// +// http.Handle("/", http.FileServer(http.Dir("/tmp"))) +// +// To use an fs.FS implementation, use http.FS to convert it: +// +// http.Handle("/", http.FileServer(http.FS(fsys))) +// +func FileServer(root FileSystem) Handler { + return &fileHandler{root} +} + +func (f *fileHandler) ServeHTTP(w ResponseWriter, r *Request) { + upath := r.URL.Path + if !strings.HasPrefix(upath, "/") { + upath = "/" + upath + r.URL.Path = upath + } + serveFile(w, r, f.root, path.Clean(upath), true) +} + +// httpRange specifies the byte range to be sent to the client. +type httpRange struct { + start, length int64 +} + +func (r httpRange) contentRange(size int64) string { + return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size) +} + +func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader { + return textproto.MIMEHeader{ + "Content-Range": {r.contentRange(size)}, + "Content-Type": {contentType}, + } +} + +// parseRange parses a Range header string as per RFC 7233. +// errNoOverlap is returned if none of the ranges overlap. +func parseRange(s string, size int64) ([]httpRange, error) { + if s == "" { + return nil, nil // header not present + } + const b = "bytes=" + if !strings.HasPrefix(s, b) { + return nil, errors.New("invalid range") + } + var ranges []httpRange + noOverlap := false + for _, ra := range strings.Split(s[len(b):], ",") { + ra = textproto.TrimString(ra) + if ra == "" { + continue + } + i := strings.Index(ra, "-") + if i < 0 { + return nil, errors.New("invalid range") + } + start, end := textproto.TrimString(ra[:i]), textproto.TrimString(ra[i+1:]) + var r httpRange + if start == "" { + // If no start is specified, end specifies the + // range start relative to the end of the file, + // and we are dealing with + // which has to be a non-negative integer as per + // RFC 7233 Section 2.1 "Byte-Ranges". + if end == "" || end[0] == '-' { + return nil, errors.New("invalid range") + } + i, err := strconv.ParseInt(end, 10, 64) + if i < 0 || err != nil { + return nil, errors.New("invalid range") + } + if i > size { + i = size + } + r.start = size - i + r.length = size - r.start + } else { + i, err := strconv.ParseInt(start, 10, 64) + if err != nil || i < 0 { + return nil, errors.New("invalid range") + } + if i >= size { + // If the range begins after the size of the content, + // then it does not overlap. + noOverlap = true + continue + } + r.start = i + if end == "" { + // If no end is specified, range extends to end of the file. + r.length = size - r.start + } else { + i, err := strconv.ParseInt(end, 10, 64) + if err != nil || r.start > i { + return nil, errors.New("invalid range") + } + if i >= size { + i = size - 1 + } + r.length = i - r.start + 1 + } + } + ranges = append(ranges, r) + } + if noOverlap && len(ranges) == 0 { + // The specified ranges did not overlap with the content. + return nil, errNoOverlap + } + return ranges, nil +} + +// countingWriter counts how many bytes have been written to it. +type countingWriter int64 + +func (w *countingWriter) Write(p []byte) (n int, err error) { + *w += countingWriter(len(p)) + return len(p), nil +} + +// rangesMIMESize returns the number of bytes it takes to encode the +// provided ranges as a multipart response. +func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) { + var w countingWriter + mw := multipart.NewWriter(&w) + for _, ra := range ranges { + mw.CreatePart(ra.mimeHeader(contentType, contentSize)) + encSize += ra.length + } + mw.Close() + encSize += int64(w) + return +} + +func sumRangesSize(ranges []httpRange) (size int64) { + for _, ra := range ranges { + size += ra.length + } + return +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/fs_test.go b/vendor/github.com/lesismal/llib/std/net/http/fs_test.go new file mode 100644 index 0000000..52f234f --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/fs_test.go @@ -0,0 +1,1414 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "github.com/lesismal/llib/std/net/http/httptest" + "io" + "io/fs" + "io/ioutil" + "mime" + "mime/multipart" + "net" + . "net/http" + "net/url" + "os" + "os/exec" + "path" + "path/filepath" + "reflect" + "regexp" + "runtime" + "strings" + "testing" + "time" +) + +const ( + testFile = "testdata/file" + testFileLen = 11 +) + +type wantRange struct { + start, end int64 // range [start,end) +} + +var ServeFileRangeTests = []struct { + r string + code int + ranges []wantRange +}{ + {r: "", code: StatusOK}, + {r: "bytes=0-4", code: StatusPartialContent, ranges: []wantRange{{0, 5}}}, + {r: "bytes=2-", code: StatusPartialContent, ranges: []wantRange{{2, testFileLen}}}, + {r: "bytes=-5", code: StatusPartialContent, ranges: []wantRange{{testFileLen - 5, testFileLen}}}, + {r: "bytes=3-7", code: StatusPartialContent, ranges: []wantRange{{3, 8}}}, + {r: "bytes=0-0,-2", code: StatusPartialContent, ranges: []wantRange{{0, 1}, {testFileLen - 2, testFileLen}}}, + {r: "bytes=0-1,5-8", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, 9}}}, + {r: "bytes=0-1,5-", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, testFileLen}}}, + {r: "bytes=5-1000", code: StatusPartialContent, ranges: []wantRange{{5, testFileLen}}}, + {r: "bytes=0-,1-,2-,3-,4-", code: StatusOK}, // ignore wasteful range request + {r: "bytes=0-9", code: StatusPartialContent, ranges: []wantRange{{0, testFileLen - 1}}}, + {r: "bytes=0-10", code: StatusPartialContent, ranges: []wantRange{{0, testFileLen}}}, + {r: "bytes=0-11", code: StatusPartialContent, ranges: []wantRange{{0, testFileLen}}}, + {r: "bytes=10-11", code: StatusPartialContent, ranges: []wantRange{{testFileLen - 1, testFileLen}}}, + {r: "bytes=10-", code: StatusPartialContent, ranges: []wantRange{{testFileLen - 1, testFileLen}}}, + {r: "bytes=11-", code: StatusRequestedRangeNotSatisfiable}, + {r: "bytes=11-12", code: StatusRequestedRangeNotSatisfiable}, + {r: "bytes=12-12", code: StatusRequestedRangeNotSatisfiable}, + {r: "bytes=11-100", code: StatusRequestedRangeNotSatisfiable}, + {r: "bytes=12-100", code: StatusRequestedRangeNotSatisfiable}, + {r: "bytes=100-", code: StatusRequestedRangeNotSatisfiable}, + {r: "bytes=100-1000", code: StatusRequestedRangeNotSatisfiable}, +} + +func TestServeFile(t *testing.T) { + setParallel(t) + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ServeFile(w, r, "testdata/file") + })) + defer ts.Close() + c := ts.Client() + + var err error + + file, err := os.ReadFile(testFile) + if err != nil { + t.Fatal("reading file:", err) + } + + // set up the Request (re-used for all tests) + var req Request + req.Header = make(Header) + if req.URL, err = url.Parse(ts.URL); err != nil { + t.Fatal("ParseURL:", err) + } + req.Method = "GET" + + // straight GET + _, body := getBody(t, "straight get", req, c) + if !bytes.Equal(body, file) { + t.Fatalf("body mismatch: got %q, want %q", body, file) + } + + // Range tests +Cases: + for _, rt := range ServeFileRangeTests { + if rt.r != "" { + req.Header.Set("Range", rt.r) + } + resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req, c) + if resp.StatusCode != rt.code { + t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, resp.StatusCode, rt.code) + } + if rt.code == StatusRequestedRangeNotSatisfiable { + continue + } + wantContentRange := "" + if len(rt.ranges) == 1 { + rng := rt.ranges[0] + wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen) + } + cr := resp.Header.Get("Content-Range") + if cr != wantContentRange { + t.Errorf("range=%q: Content-Range = %q, want %q", rt.r, cr, wantContentRange) + } + ct := resp.Header.Get("Content-Type") + if len(rt.ranges) == 1 { + rng := rt.ranges[0] + wantBody := file[rng.start:rng.end] + if !bytes.Equal(body, wantBody) { + t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody) + } + if strings.HasPrefix(ct, "multipart/byteranges") { + t.Errorf("range=%q content-type = %q; unexpected multipart/byteranges", rt.r, ct) + } + } + if len(rt.ranges) > 1 { + typ, params, err := mime.ParseMediaType(ct) + if err != nil { + t.Errorf("range=%q content-type = %q; %v", rt.r, ct, err) + continue + } + if typ != "multipart/byteranges" { + t.Errorf("range=%q content-type = %q; want multipart/byteranges", rt.r, typ) + continue + } + if params["boundary"] == "" { + t.Errorf("range=%q content-type = %q; lacks boundary", rt.r, ct) + continue + } + if g, w := resp.ContentLength, int64(len(body)); g != w { + t.Errorf("range=%q Content-Length = %d; want %d", rt.r, g, w) + continue + } + mr := multipart.NewReader(bytes.NewReader(body), params["boundary"]) + for ri, rng := range rt.ranges { + part, err := mr.NextPart() + if err != nil { + t.Errorf("range=%q, reading part index %d: %v", rt.r, ri, err) + continue Cases + } + wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen) + if g, w := part.Header.Get("Content-Range"), wantContentRange; g != w { + t.Errorf("range=%q: part Content-Range = %q; want %q", rt.r, g, w) + } + body, err := io.ReadAll(part) + if err != nil { + t.Errorf("range=%q, reading part index %d body: %v", rt.r, ri, err) + continue Cases + } + wantBody := file[rng.start:rng.end] + if !bytes.Equal(body, wantBody) { + t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody) + } + } + _, err = mr.NextPart() + if err != io.EOF { + t.Errorf("range=%q; expected final error io.EOF; got %v", rt.r, err) + } + } + } +} + +func TestServeFile_DotDot(t *testing.T) { + tests := []struct { + req string + wantStatus int + }{ + {"/testdata/file", 200}, + {"/../file", 400}, + {"/..", 400}, + {"/../", 400}, + {"/../foo", 400}, + {"/..\\foo", 400}, + {"/file/a", 200}, + {"/file/a..", 200}, + {"/file/a/..", 400}, + {"/file/a\\..", 400}, + } + for _, tt := range tests { + req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + tt.req + " HTTP/1.1\r\nHost: foo\r\n\r\n"))) + if err != nil { + t.Errorf("bad request %q: %v", tt.req, err) + continue + } + rec := httptest.NewRecorder() + ServeFile(rec, req, "testdata/file") + if rec.Code != tt.wantStatus { + t.Errorf("for request %q, status = %d; want %d", tt.req, rec.Code, tt.wantStatus) + } + } +} + +// Tests that this doesn't panic. (Issue 30165) +func TestServeFileDirPanicEmptyPath(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + req.URL.Path = "" + ServeFile(rec, req, "testdata") + res := rec.Result() + if res.StatusCode != 301 { + t.Errorf("code = %v; want 301", res.Status) + } +} + +var fsRedirectTestData = []struct { + original, redirect string +}{ + {"/test/index.html", "/test/"}, + {"/test/testdata", "/test/testdata/"}, + {"/test/testdata/file/", "/test/testdata/file"}, +} + +func TestFSRedirect(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(StripPrefix("/test", FileServer(Dir(".")))) + defer ts.Close() + + for _, data := range fsRedirectTestData { + res, err := Get(ts.URL + data.original) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if g, e := res.Request.URL.Path, data.redirect; g != e { + t.Errorf("redirect from %s: got %s, want %s", data.original, g, e) + } + } +} + +type testFileSystem struct { + open func(name string) (File, error) +} + +func (fs *testFileSystem) Open(name string) (File, error) { + return fs.open(name) +} + +func TestFileServerCleans(t *testing.T) { + defer afterTest(t) + ch := make(chan string, 1) + fs := FileServer(&testFileSystem{func(name string) (File, error) { + ch <- name + return nil, errors.New("file does not exist") + }}) + tests := []struct { + reqPath, openArg string + }{ + {"/foo.txt", "/foo.txt"}, + {"//foo.txt", "/foo.txt"}, + {"/../foo.txt", "/foo.txt"}, + } + req, _ := NewRequest("GET", "http://example.com", nil) + for n, test := range tests { + rec := httptest.NewRecorder() + req.URL.Path = test.reqPath + fs.ServeHTTP(rec, req) + if got := <-ch; got != test.openArg { + t.Errorf("test %d: got %q, want %q", n, got, test.openArg) + } + } +} + +func TestFileServerEscapesNames(t *testing.T) { + defer afterTest(t) + const dirListPrefix = "
\n"
+	const dirListSuffix = "\n
\n" + tests := []struct { + name, escaped string + }{ + {`simple_name`, `simple_name`}, + {`"'<>&`, `"'<>&`}, + {`?foo=bar#baz`, `?foo=bar#baz`}, + {`?foo`, `<combo>?foo`}, + {`foo:bar`, `foo:bar`}, + } + + // We put each test file in its own directory in the fakeFS so we can look at it in isolation. + fs := make(fakeFS) + for i, test := range tests { + testFile := &fakeFileInfo{basename: test.name} + fs[fmt.Sprintf("/%d", i)] = &fakeFileInfo{ + dir: true, + modtime: time.Unix(1000000000, 0).UTC(), + ents: []*fakeFileInfo{testFile}, + } + fs[fmt.Sprintf("/%d/%s", i, test.name)] = testFile + } + + ts := httptest.NewServer(FileServer(&fs)) + defer ts.Close() + for i, test := range tests { + url := fmt.Sprintf("%s/%d", ts.URL, i) + res, err := Get(url) + if err != nil { + t.Fatalf("test %q: Get: %v", test.name, err) + } + b, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("test %q: read Body: %v", test.name, err) + } + s := string(b) + if !strings.HasPrefix(s, dirListPrefix) || !strings.HasSuffix(s, dirListSuffix) { + t.Errorf("test %q: listing dir, full output is %q, want prefix %q and suffix %q", test.name, s, dirListPrefix, dirListSuffix) + } + if trimmed := strings.TrimSuffix(strings.TrimPrefix(s, dirListPrefix), dirListSuffix); trimmed != test.escaped { + t.Errorf("test %q: listing dir, filename escaped to %q, want %q", test.name, trimmed, test.escaped) + } + res.Body.Close() + } +} + +func TestFileServerSortsNames(t *testing.T) { + defer afterTest(t) + const contents = "I am a fake file" + dirMod := time.Unix(123, 0).UTC() + fileMod := time.Unix(1000000000, 0).UTC() + fs := fakeFS{ + "/": &fakeFileInfo{ + dir: true, + modtime: dirMod, + ents: []*fakeFileInfo{ + { + basename: "b", + modtime: fileMod, + contents: contents, + }, + { + basename: "a", + modtime: fileMod, + contents: contents, + }, + }, + }, + } + + ts := httptest.NewServer(FileServer(&fs)) + defer ts.Close() + + res, err := Get(ts.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + + b, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("read Body: %v", err) + } + s := string(b) + if !strings.Contains(s, "a\nb") { + t.Errorf("output appears to be unsorted:\n%s", s) + } +} + +func mustRemoveAll(dir string) { + err := os.RemoveAll(dir) + if err != nil { + panic(err) + } +} + +func TestFileServerImplicitLeadingSlash(t *testing.T) { + defer afterTest(t) + tempDir, err := os.MkdirTemp("", "") + if err != nil { + t.Fatalf("TempDir: %v", err) + } + defer mustRemoveAll(tempDir) + if err := os.WriteFile(filepath.Join(tempDir, "foo.txt"), []byte("Hello world"), 0644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + ts := httptest.NewServer(StripPrefix("/bar/", FileServer(Dir(tempDir)))) + defer ts.Close() + get := func(suffix string) string { + res, err := Get(ts.URL + suffix) + if err != nil { + t.Fatalf("Get %s: %v", suffix, err) + } + b, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("ReadAll %s: %v", suffix, err) + } + res.Body.Close() + return string(b) + } + if s := get("/bar/"); !strings.Contains(s, ">foo.txt<") { + t.Logf("expected a directory listing with foo.txt, got %q", s) + } + if s := get("/bar/foo.txt"); s != "Hello world" { + t.Logf("expected %q, got %q", "Hello world", s) + } +} + +func TestDirJoin(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping test on windows") + } + wfi, err := os.Stat("/etc/hosts") + if err != nil { + t.Skip("skipping test; no /etc/hosts file") + } + test := func(d Dir, name string) { + f, err := d.Open(name) + if err != nil { + t.Fatalf("open of %s: %v", name, err) + } + defer f.Close() + gfi, err := f.Stat() + if err != nil { + t.Fatalf("stat of %s: %v", name, err) + } + if !os.SameFile(gfi, wfi) { + t.Errorf("%s got different file", name) + } + } + test(Dir("/etc/"), "/hosts") + test(Dir("/etc/"), "hosts") + test(Dir("/etc/"), "../../../../hosts") + test(Dir("/etc"), "/hosts") + test(Dir("/etc"), "hosts") + test(Dir("/etc"), "../../../../hosts") + + // Not really directories, but since we use this trick in + // ServeFile, test it: + test(Dir("/etc/hosts"), "") + test(Dir("/etc/hosts"), "/") + test(Dir("/etc/hosts"), "../") +} + +func TestEmptyDirOpenCWD(t *testing.T) { + test := func(d Dir) { + name := "fs_test.go" + f, err := d.Open(name) + if err != nil { + t.Fatalf("open of %s: %v", name, err) + } + defer f.Close() + } + test(Dir("")) + test(Dir(".")) + test(Dir("./")) +} + +func TestServeFileContentType(t *testing.T) { + defer afterTest(t) + const ctype = "icecream/chocolate" + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + switch r.FormValue("override") { + case "1": + w.Header().Set("Content-Type", ctype) + case "2": + // Explicitly inhibit sniffing. + w.Header()["Content-Type"] = []string{} + } + ServeFile(w, r, "testdata/file") + })) + defer ts.Close() + get := func(override string, want []string) { + resp, err := Get(ts.URL + "?override=" + override) + if err != nil { + t.Fatal(err) + } + if h := resp.Header["Content-Type"]; !reflect.DeepEqual(h, want) { + t.Errorf("Content-Type mismatch: got %v, want %v", h, want) + } + resp.Body.Close() + } + get("0", []string{"text/plain; charset=utf-8"}) + get("1", []string{ctype}) + get("2", nil) +} + +func TestServeFileMimeType(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ServeFile(w, r, "testdata/style.css") + })) + defer ts.Close() + resp, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + want := "text/css; charset=utf-8" + if h := resp.Header.Get("Content-Type"); h != want { + t.Errorf("Content-Type mismatch: got %q, want %q", h, want) + } +} + +func TestServeFileFromCWD(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ServeFile(w, r, "fs_test.go") + })) + defer ts.Close() + r, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + r.Body.Close() + if r.StatusCode != 200 { + t.Fatalf("expected 200 OK, got %s", r.Status) + } +} + +// Issue 13996 +func TestServeDirWithoutTrailingSlash(t *testing.T) { + e := "/testdata/" + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ServeFile(w, r, ".") + })) + defer ts.Close() + r, err := Get(ts.URL + "/testdata") + if err != nil { + t.Fatal(err) + } + r.Body.Close() + if g := r.Request.URL.Path; g != e { + t.Errorf("got %s, want %s", g, e) + } +} + +// Tests that ServeFile doesn't add a Content-Length if a Content-Encoding is +// specified. +func TestServeFileWithContentEncoding_h1(t *testing.T) { testServeFileWithContentEncoding(t, h1Mode) } +func TestServeFileWithContentEncoding_h2(t *testing.T) { testServeFileWithContentEncoding(t, h2Mode) } +func testServeFileWithContentEncoding(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Encoding", "foo") + ServeFile(w, r, "testdata/file") + + // Because the testdata is so small, it would fit in + // both the h1 and h2 Server's write buffers. For h1, + // sendfile is used, though, forcing a header flush at + // the io.Copy. http2 doesn't do a header flush so + // buffers all 11 bytes and then adds its own + // Content-Length. To prevent the Server's + // Content-Length and test ServeFile only, flush here. + w.(Flusher).Flush() + })) + defer cst.close() + resp, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if g, e := resp.ContentLength, int64(-1); g != e { + t.Errorf("Content-Length mismatch: got %d, want %d", g, e) + } +} + +func TestServeIndexHtml(t *testing.T) { + defer afterTest(t) + + for i := 0; i < 2; i++ { + var h Handler + var name string + switch i { + case 0: + h = FileServer(Dir(".")) + name = "Dir" + case 1: + h = FileServer(FS(os.DirFS("."))) + name = "DirFS" + } + t.Run(name, func(t *testing.T) { + const want = "index.html says hello\n" + ts := httptest.NewServer(h) + defer ts.Close() + + for _, path := range []string{"/testdata/", "/testdata/index.html"} { + res, err := Get(ts.URL + path) + if err != nil { + t.Fatal(err) + } + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal("reading Body:", err) + } + if s := string(b); s != want { + t.Errorf("for path %q got %q, want %q", path, s, want) + } + res.Body.Close() + } + }) + } +} + +func TestServeIndexHtmlFS(t *testing.T) { + defer afterTest(t) + const want = "index.html says hello\n" + ts := httptest.NewServer(FileServer(Dir("."))) + defer ts.Close() + + for _, path := range []string{"/testdata/", "/testdata/index.html"} { + res, err := Get(ts.URL + path) + if err != nil { + t.Fatal(err) + } + b, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal("reading Body:", err) + } + if s := string(b); s != want { + t.Errorf("for path %q got %q, want %q", path, s, want) + } + res.Body.Close() + } +} + +func TestFileServerZeroByte(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(FileServer(Dir("."))) + defer ts.Close() + + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer c.Close() + _, err = fmt.Fprintf(c, "GET /..\x00 HTTP/1.0\r\n\r\n") + if err != nil { + t.Fatal(err) + } + var got bytes.Buffer + bufr := bufio.NewReader(io.TeeReader(c, &got)) + res, err := ReadResponse(bufr, nil) + if err != nil { + t.Fatal("ReadResponse: ", err) + } + if res.StatusCode == 200 { + t.Errorf("got status 200; want an error. Body is:\n%s", got.Bytes()) + } +} + +type fakeFileInfo struct { + dir bool + basename string + modtime time.Time + ents []*fakeFileInfo + contents string + err error +} + +func (f *fakeFileInfo) Name() string { return f.basename } +func (f *fakeFileInfo) Sys() interface{} { return nil } +func (f *fakeFileInfo) ModTime() time.Time { return f.modtime } +func (f *fakeFileInfo) IsDir() bool { return f.dir } +func (f *fakeFileInfo) Size() int64 { return int64(len(f.contents)) } +func (f *fakeFileInfo) Mode() fs.FileMode { + if f.dir { + return 0755 | fs.ModeDir + } + return 0644 +} + +type fakeFile struct { + io.ReadSeeker + fi *fakeFileInfo + path string // as opened + entpos int +} + +func (f *fakeFile) Close() error { return nil } +func (f *fakeFile) Stat() (fs.FileInfo, error) { return f.fi, nil } +func (f *fakeFile) Readdir(count int) ([]fs.FileInfo, error) { + if !f.fi.dir { + return nil, fs.ErrInvalid + } + var fis []fs.FileInfo + + limit := f.entpos + count + if count <= 0 || limit > len(f.fi.ents) { + limit = len(f.fi.ents) + } + for ; f.entpos < limit; f.entpos++ { + fis = append(fis, f.fi.ents[f.entpos]) + } + + if len(fis) == 0 && count > 0 { + return fis, io.EOF + } else { + return fis, nil + } +} + +type fakeFS map[string]*fakeFileInfo + +func (fsys fakeFS) Open(name string) (File, error) { + name = path.Clean(name) + f, ok := fsys[name] + if !ok { + return nil, fs.ErrNotExist + } + if f.err != nil { + return nil, f.err + } + return &fakeFile{ReadSeeker: strings.NewReader(f.contents), fi: f, path: name}, nil +} + +func TestDirectoryIfNotModified(t *testing.T) { + defer afterTest(t) + const indexContents = "I am a fake index.html file" + fileMod := time.Unix(1000000000, 0).UTC() + fileModStr := fileMod.Format(TimeFormat) + dirMod := time.Unix(123, 0).UTC() + indexFile := &fakeFileInfo{ + basename: "index.html", + modtime: fileMod, + contents: indexContents, + } + fs := fakeFS{ + "/": &fakeFileInfo{ + dir: true, + modtime: dirMod, + ents: []*fakeFileInfo{indexFile}, + }, + "/index.html": indexFile, + } + + ts := httptest.NewServer(FileServer(fs)) + defer ts.Close() + + res, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + b, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(b) != indexContents { + t.Fatalf("Got body %q; want %q", b, indexContents) + } + res.Body.Close() + + lastMod := res.Header.Get("Last-Modified") + if lastMod != fileModStr { + t.Fatalf("initial Last-Modified = %q; want %q", lastMod, fileModStr) + } + + req, _ := NewRequest("GET", ts.URL, nil) + req.Header.Set("If-Modified-Since", lastMod) + + c := ts.Client() + res, err = c.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 304 { + t.Fatalf("Code after If-Modified-Since request = %v; want 304", res.StatusCode) + } + res.Body.Close() + + // Advance the index.html file's modtime, but not the directory's. + indexFile.modtime = indexFile.modtime.Add(1 * time.Hour) + + res, err = c.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 200 { + t.Fatalf("Code after second If-Modified-Since request = %v; want 200; res is %#v", res.StatusCode, res) + } + res.Body.Close() +} + +func mustStat(t *testing.T, fileName string) fs.FileInfo { + fi, err := os.Stat(fileName) + if err != nil { + t.Fatal(err) + } + return fi +} + +func TestServeContent(t *testing.T) { + defer afterTest(t) + type serveParam struct { + name string + modtime time.Time + content io.ReadSeeker + contentType string + etag string + } + servec := make(chan serveParam, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + p := <-servec + if p.etag != "" { + w.Header().Set("ETag", p.etag) + } + if p.contentType != "" { + w.Header().Set("Content-Type", p.contentType) + } + ServeContent(w, r, p.name, p.modtime, p.content) + })) + defer ts.Close() + + type testCase struct { + // One of file or content must be set: + file string + content io.ReadSeeker + + modtime time.Time + serveETag string // optional + serveContentType string // optional + reqHeader map[string]string + wantLastMod string + wantContentType string + wantContentRange string + wantStatus int + } + htmlModTime := mustStat(t, "testdata/index.html").ModTime() + tests := map[string]testCase{ + "no_last_modified": { + file: "testdata/style.css", + wantContentType: "text/css; charset=utf-8", + wantStatus: 200, + }, + "with_last_modified": { + file: "testdata/index.html", + wantContentType: "text/html; charset=utf-8", + modtime: htmlModTime, + wantLastMod: htmlModTime.UTC().Format(TimeFormat), + wantStatus: 200, + }, + "not_modified_modtime": { + file: "testdata/style.css", + serveETag: `"foo"`, // Last-Modified sent only when no ETag + modtime: htmlModTime, + reqHeader: map[string]string{ + "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat), + }, + wantStatus: 304, + }, + "not_modified_modtime_with_contenttype": { + file: "testdata/style.css", + serveContentType: "text/css", // explicit content type + serveETag: `"foo"`, // Last-Modified sent only when no ETag + modtime: htmlModTime, + reqHeader: map[string]string{ + "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat), + }, + wantStatus: 304, + }, + "not_modified_etag": { + file: "testdata/style.css", + serveETag: `"foo"`, + reqHeader: map[string]string{ + "If-None-Match": `"foo"`, + }, + wantStatus: 304, + }, + "not_modified_etag_no_seek": { + content: panicOnSeek{nil}, // should never be called + serveETag: `W/"foo"`, // If-None-Match uses weak ETag comparison + reqHeader: map[string]string{ + "If-None-Match": `"baz", W/"foo"`, + }, + wantStatus: 304, + }, + "if_none_match_mismatch": { + file: "testdata/style.css", + serveETag: `"foo"`, + reqHeader: map[string]string{ + "If-None-Match": `"Foo"`, + }, + wantStatus: 200, + wantContentType: "text/css; charset=utf-8", + }, + "if_none_match_malformed": { + file: "testdata/style.css", + serveETag: `"foo"`, + reqHeader: map[string]string{ + "If-None-Match": `,`, + }, + wantStatus: 200, + wantContentType: "text/css; charset=utf-8", + }, + "range_good": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "Range": "bytes=0-4", + }, + wantStatus: StatusPartialContent, + wantContentType: "text/css; charset=utf-8", + wantContentRange: "bytes 0-4/8", + }, + "range_match": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "Range": "bytes=0-4", + "If-Range": `"A"`, + }, + wantStatus: StatusPartialContent, + wantContentType: "text/css; charset=utf-8", + wantContentRange: "bytes 0-4/8", + }, + "range_match_weak_etag": { + file: "testdata/style.css", + serveETag: `W/"A"`, + reqHeader: map[string]string{ + "Range": "bytes=0-4", + "If-Range": `W/"A"`, + }, + wantStatus: 200, + wantContentType: "text/css; charset=utf-8", + }, + "range_no_overlap": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "Range": "bytes=10-20", + }, + wantStatus: StatusRequestedRangeNotSatisfiable, + wantContentType: "text/plain; charset=utf-8", + wantContentRange: "bytes */8", + }, + // An If-Range resource for entity "A", but entity "B" is now current. + // The Range request should be ignored. + "range_no_match": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "Range": "bytes=0-4", + "If-Range": `"B"`, + }, + wantStatus: 200, + wantContentType: "text/css; charset=utf-8", + }, + "range_with_modtime": { + file: "testdata/style.css", + modtime: time.Date(2014, 6, 25, 17, 12, 18, 0 /* nanos */, time.UTC), + reqHeader: map[string]string{ + "Range": "bytes=0-4", + "If-Range": "Wed, 25 Jun 2014 17:12:18 GMT", + }, + wantStatus: StatusPartialContent, + wantContentType: "text/css; charset=utf-8", + wantContentRange: "bytes 0-4/8", + wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT", + }, + "range_with_modtime_mismatch": { + file: "testdata/style.css", + modtime: time.Date(2014, 6, 25, 17, 12, 18, 0 /* nanos */, time.UTC), + reqHeader: map[string]string{ + "Range": "bytes=0-4", + "If-Range": "Wed, 25 Jun 2014 17:12:19 GMT", + }, + wantStatus: StatusOK, + wantContentType: "text/css; charset=utf-8", + wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT", + }, + "range_with_modtime_nanos": { + file: "testdata/style.css", + modtime: time.Date(2014, 6, 25, 17, 12, 18, 123 /* nanos */, time.UTC), + reqHeader: map[string]string{ + "Range": "bytes=0-4", + "If-Range": "Wed, 25 Jun 2014 17:12:18 GMT", + }, + wantStatus: StatusPartialContent, + wantContentType: "text/css; charset=utf-8", + wantContentRange: "bytes 0-4/8", + wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT", + }, + "unix_zero_modtime": { + content: strings.NewReader("foo"), + modtime: time.Unix(0, 0), + wantStatus: StatusOK, + wantContentType: "text/html; charset=utf-8", + }, + "ifmatch_matches": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "If-Match": `"Z", "A"`, + }, + wantStatus: 200, + wantContentType: "text/css; charset=utf-8", + }, + "ifmatch_star": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "If-Match": `*`, + }, + wantStatus: 200, + wantContentType: "text/css; charset=utf-8", + }, + "ifmatch_failed": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "If-Match": `"B"`, + }, + wantStatus: 412, + }, + "ifmatch_fails_on_weak_etag": { + file: "testdata/style.css", + serveETag: `W/"A"`, + reqHeader: map[string]string{ + "If-Match": `W/"A"`, + }, + wantStatus: 412, + }, + "if_unmodified_since_true": { + file: "testdata/style.css", + modtime: htmlModTime, + reqHeader: map[string]string{ + "If-Unmodified-Since": htmlModTime.UTC().Format(TimeFormat), + }, + wantStatus: 200, + wantContentType: "text/css; charset=utf-8", + wantLastMod: htmlModTime.UTC().Format(TimeFormat), + }, + "if_unmodified_since_false": { + file: "testdata/style.css", + modtime: htmlModTime, + reqHeader: map[string]string{ + "If-Unmodified-Since": htmlModTime.Add(-2 * time.Second).UTC().Format(TimeFormat), + }, + wantStatus: 412, + wantLastMod: htmlModTime.UTC().Format(TimeFormat), + }, + } + for testName, tt := range tests { + var content io.ReadSeeker + if tt.file != "" { + f, err := os.Open(tt.file) + if err != nil { + t.Fatalf("test %q: %v", testName, err) + } + defer f.Close() + content = f + } else { + content = tt.content + } + for _, method := range []string{"GET", "HEAD"} { + //restore content in case it is consumed by previous method + if content, ok := content.(*strings.Reader); ok { + content.Seek(0, io.SeekStart) + } + + servec <- serveParam{ + name: filepath.Base(tt.file), + content: content, + modtime: tt.modtime, + etag: tt.serveETag, + contentType: tt.serveContentType, + } + req, err := NewRequest(method, ts.URL, nil) + if err != nil { + t.Fatal(err) + } + for k, v := range tt.reqHeader { + req.Header.Set(k, v) + } + + c := ts.Client() + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + io.Copy(io.Discard, res.Body) + res.Body.Close() + if res.StatusCode != tt.wantStatus { + t.Errorf("test %q using %q: got status = %d; want %d", testName, method, res.StatusCode, tt.wantStatus) + } + if g, e := res.Header.Get("Content-Type"), tt.wantContentType; g != e { + t.Errorf("test %q using %q: got content-type = %q, want %q", testName, method, g, e) + } + if g, e := res.Header.Get("Content-Range"), tt.wantContentRange; g != e { + t.Errorf("test %q using %q: got content-range = %q, want %q", testName, method, g, e) + } + if g, e := res.Header.Get("Last-Modified"), tt.wantLastMod; g != e { + t.Errorf("test %q using %q: got last-modified = %q, want %q", testName, method, g, e) + } + } + } +} + +// Issue 12991 +func TestServerFileStatError(t *testing.T) { + rec := httptest.NewRecorder() + r, _ := NewRequest("GET", "http://foo/", nil) + redirect := false + name := "file.txt" + fs := issue12991FS{} + ExportServeFile(rec, r, fs, name, redirect) + if body := rec.Body.String(); !strings.Contains(body, "403") || !strings.Contains(body, "Forbidden") { + t.Errorf("wanted 403 forbidden message; got: %s", body) + } +} + +type issue12991FS struct{} + +func (issue12991FS) Open(string) (File, error) { return issue12991File{}, nil } + +type issue12991File struct{ File } + +func (issue12991File) Stat() (fs.FileInfo, error) { return nil, fs.ErrPermission } +func (issue12991File) Close() error { return nil } + +func TestServeContentErrorMessages(t *testing.T) { + defer afterTest(t) + fs := fakeFS{ + "/500": &fakeFileInfo{ + err: errors.New("random error"), + }, + "/403": &fakeFileInfo{ + err: &fs.PathError{Err: fs.ErrPermission}, + }, + } + ts := httptest.NewServer(FileServer(fs)) + defer ts.Close() + c := ts.Client() + for _, code := range []int{403, 404, 500} { + res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, code)) + if err != nil { + t.Errorf("Error fetching /%d: %v", code, err) + continue + } + if res.StatusCode != code { + t.Errorf("For /%d, status code = %d; want %d", code, res.StatusCode, code) + } + res.Body.Close() + } +} + +// verifies that sendfile is being used on Linux +func TestLinuxSendfile(t *testing.T) { + setParallel(t) + defer afterTest(t) + if runtime.GOOS != "linux" { + t.Skip("skipping; linux-only test") + } + if _, err := exec.LookPath("strace"); err != nil { + t.Skip("skipping; strace not found in path") + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + lnf, err := ln.(*net.TCPListener).File() + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + // Attempt to run strace, and skip on failure - this test requires SYS_PTRACE. + if err := exec.Command("strace", "-f", "-q", os.Args[0], "-test.run=^$").Run(); err != nil { + t.Skipf("skipping; failed to run strace: %v", err) + } + + filename := fmt.Sprintf("1kb-%d", os.Getpid()) + filepath := path.Join(os.TempDir(), filename) + + if err := os.WriteFile(filepath, bytes.Repeat([]byte{'a'}, 1<<10), 0755); err != nil { + t.Fatal(err) + } + defer os.Remove(filepath) + + var buf bytes.Buffer + child := exec.Command("strace", "-f", "-q", os.Args[0], "-test.run=TestLinuxSendfileChild") + child.ExtraFiles = append(child.ExtraFiles, lnf) + child.Env = append([]string{"GO_WANT_HELPER_PROCESS=1"}, os.Environ()...) + child.Stdout = &buf + child.Stderr = &buf + if err := child.Start(); err != nil { + t.Skipf("skipping; failed to start straced child: %v", err) + } + + res, err := Get(fmt.Sprintf("http://%s/%s", ln.Addr(), filename)) + if err != nil { + t.Fatalf("http client error: %v", err) + } + _, err = io.Copy(io.Discard, res.Body) + if err != nil { + t.Fatalf("client body read error: %v", err) + } + res.Body.Close() + + // Force child to exit cleanly. + Post(fmt.Sprintf("http://%s/quit", ln.Addr()), "", nil) + child.Wait() + + rx := regexp.MustCompile(`\b(n64:)?sendfile(64)?\(`) + out := buf.String() + if !rx.MatchString(out) { + t.Errorf("no sendfile system call found in:\n%s", out) + } +} + +func getBody(t *testing.T, testName string, req Request, client *Client) (*Response, []byte) { + r, err := client.Do(&req) + if err != nil { + t.Fatalf("%s: for URL %q, send error: %v", testName, req.URL.String(), err) + } + b, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("%s: for URL %q, reading body: %v", testName, req.URL.String(), err) + } + return r, b +} + +// TestLinuxSendfileChild isn't a real test. It's used as a helper process +// for TestLinuxSendfile. +func TestLinuxSendfileChild(*testing.T) { + if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" { + return + } + defer os.Exit(0) + fd3 := os.NewFile(3, "ephemeral-port-listener") + ln, err := net.FileListener(fd3) + if err != nil { + panic(err) + } + mux := NewServeMux() + mux.Handle("/", FileServer(Dir(os.TempDir()))) + mux.HandleFunc("/quit", func(ResponseWriter, *Request) { + os.Exit(0) + }) + s := &Server{Handler: mux} + err = s.Serve(ln) + if err != nil { + panic(err) + } +} + +// Issue 18984: tests that requests for paths beyond files return not-found errors +func TestFileServerNotDirError(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(FileServer(Dir("testdata"))) + defer ts.Close() + + res, err := Get(ts.URL + "/index.html/not-a-file") + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != 404 { + t.Errorf("StatusCode = %v; want 404", res.StatusCode) + } + + test := func(name string, dir Dir) { + t.Run(name, func(t *testing.T) { + _, err = dir.Open("/index.html/not-a-file") + if err == nil { + t.Fatal("err == nil; want != nil") + } + if !os.IsNotExist(err) { + t.Errorf("err = %v; os.IsNotExist(err) = %v; want true", err, os.IsNotExist(err)) + } + + _, err = dir.Open("/index.html/not-a-dir/not-a-file") + if err == nil { + t.Fatal("err == nil; want != nil") + } + if !os.IsNotExist(err) { + t.Errorf("err = %v; os.IsNotExist(err) = %v; want true", err, os.IsNotExist(err)) + } + }) + } + + absPath, err := filepath.Abs("testdata") + if err != nil { + t.Fatal("get abs path:", err) + } + + test("RelativePath", Dir("testdata")) + test("AbsolutePath", Dir(absPath)) +} + +func TestFileServerCleanPath(t *testing.T) { + tests := []struct { + path string + wantCode int + wantOpen []string + }{ + {"/", 200, []string{"/", "/index.html"}}, + {"/dir", 301, []string{"/dir"}}, + {"/dir/", 200, []string{"/dir", "/dir/index.html"}}, + } + for _, tt := range tests { + var log []string + rr := httptest.NewRecorder() + req, _ := NewRequest("GET", "http://foo.localhost"+tt.path, nil) + FileServer(fileServerCleanPathDir{&log}).ServeHTTP(rr, req) + if !reflect.DeepEqual(log, tt.wantOpen) { + t.Logf("For %s: Opens = %q; want %q", tt.path, log, tt.wantOpen) + } + if rr.Code != tt.wantCode { + t.Logf("For %s: Response code = %d; want %d", tt.path, rr.Code, tt.wantCode) + } + } +} + +type fileServerCleanPathDir struct { + log *[]string +} + +func (d fileServerCleanPathDir) Open(path string) (File, error) { + *(d.log) = append(*(d.log), path) + if path == "/" || path == "/dir" || path == "/dir/" { + // Just return back something that's a directory. + return Dir(".").Open(".") + } + return nil, fs.ErrNotExist +} + +type panicOnSeek struct{ io.ReadSeeker } + +func Test_scanETag(t *testing.T) { + tests := []struct { + in string + wantETag string + wantRemain string + }{ + {`W/"etag-1"`, `W/"etag-1"`, ""}, + {`"etag-2"`, `"etag-2"`, ""}, + {`"etag-1", "etag-2"`, `"etag-1"`, `, "etag-2"`}, + {"", "", ""}, + {"W/", "", ""}, + {`W/"truc`, "", ""}, + {`w/"case-sensitive"`, "", ""}, + {`"spaced etag"`, "", ""}, + } + for _, test := range tests { + etag, remain := ExportScanETag(test.in) + if etag != test.wantETag || remain != test.wantRemain { + t.Errorf("scanETag(%q)=%q %q, want %q %q", test.in, etag, remain, test.wantETag, test.wantRemain) + } + } +} + +// Issue 40940: Ensure that we only accept non-negative suffix-lengths +// in "Range": "bytes=-N", and should reject "bytes=--2". +func TestServeFileRejectsInvalidSuffixLengths_h1(t *testing.T) { + testServeFileRejectsInvalidSuffixLengths(t, h1Mode) +} +func TestServeFileRejectsInvalidSuffixLengths_h2(t *testing.T) { + testServeFileRejectsInvalidSuffixLengths(t, h2Mode) +} + +func testServeFileRejectsInvalidSuffixLengths(t *testing.T, h2 bool) { + defer afterTest(t) + cst := httptest.NewUnstartedServer(FileServer(Dir("testdata"))) + cst.EnableHTTP2 = h2 + cst.StartTLS() + defer cst.Close() + + tests := []struct { + r string + wantCode int + wantBody string + }{ + {"bytes=--6", 416, "invalid range\n"}, + {"bytes=--0", 416, "invalid range\n"}, + {"bytes=---0", 416, "invalid range\n"}, + {"bytes=-6", 206, "hello\n"}, + {"bytes=6-", 206, "html says hello\n"}, + {"bytes=-6-", 416, "invalid range\n"}, + {"bytes=-0", 206, ""}, + {"bytes=", 200, "index.html says hello\n"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.r, func(t *testing.T) { + req, err := NewRequest("GET", cst.URL+"/index.html", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Range", tt.r) + res, err := cst.Client().Do(req) + if err != nil { + t.Fatal(err) + } + if g, w := res.StatusCode, tt.wantCode; g != w { + t.Errorf("StatusCode mismatch: got %d want %d", g, w) + } + slurp, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatal(err) + } + if g, w := string(slurp), tt.wantBody; g != w { + t.Fatalf("Content mismatch:\nGot: %q\nWant: %q", g, w) + } + }) + } +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/h2_bundle.go b/vendor/github.com/lesismal/llib/std/net/http/h2_bundle.go new file mode 100644 index 0000000..6c067aa --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/h2_bundle.go @@ -0,0 +1,10371 @@ +// +build !nethttpomithttp2 + +// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT. +// $ bundle -o=h2_bundle.go -prefix=http2 -tags=!nethttpomithttp2 golang.org/x/net/http2 + +// Package http2 implements the HTTP/2 protocol. +// +// This package is low-level and intended to be used directly by very +// few people. Most users will use it indirectly through the automatic +// use by the net/http package (from Go 1.6 and later). +// For use in earlier Go versions see ConfigureServer. (Transport support +// requires Go 1.6 or later) +// +// See https://http2.github.io/ for more information on HTTP/2. +// +// See https://http2.golang.org/ for a test server running this code. +// + +package http + +import ( + "bufio" + "bytes" + "compress/gzip" + "context" + "crypto/rand" + "crypto/tls" + "encoding/binary" + "errors" + "fmt" + "github.com/lesismal/llib/std/net/http/httptrace" + "io" + "io/ioutil" + "log" + "math" + mathrand "math/rand" + "net" + "net/textproto" + "net/url" + "os" + "reflect" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "golang.org/x/net/http/httpguts" + "golang.org/x/net/http2/hpack" + "golang.org/x/net/idna" +) + +// A list of the possible cipher suite ids. Taken from +// https://www.iana.org/assignments/tls-parameters/tls-parameters.txt + +const ( + http2cipher_TLS_NULL_WITH_NULL_NULL uint16 = 0x0000 + http2cipher_TLS_RSA_WITH_NULL_MD5 uint16 = 0x0001 + http2cipher_TLS_RSA_WITH_NULL_SHA uint16 = 0x0002 + http2cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0003 + http2cipher_TLS_RSA_WITH_RC4_128_MD5 uint16 = 0x0004 + http2cipher_TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 + http2cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x0006 + http2cipher_TLS_RSA_WITH_IDEA_CBC_SHA uint16 = 0x0007 + http2cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0008 + http2cipher_TLS_RSA_WITH_DES_CBC_SHA uint16 = 0x0009 + http2cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000A + http2cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000B + http2cipher_TLS_DH_DSS_WITH_DES_CBC_SHA uint16 = 0x000C + http2cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x000D + http2cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000E + http2cipher_TLS_DH_RSA_WITH_DES_CBC_SHA uint16 = 0x000F + http2cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0010 + http2cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0011 + http2cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA uint16 = 0x0012 + http2cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x0013 + http2cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0014 + http2cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA uint16 = 0x0015 + http2cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0016 + http2cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0017 + http2cipher_TLS_DH_anon_WITH_RC4_128_MD5 uint16 = 0x0018 + http2cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0019 + http2cipher_TLS_DH_anon_WITH_DES_CBC_SHA uint16 = 0x001A + http2cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0x001B + // Reserved uint16 = 0x001C-1D + http2cipher_TLS_KRB5_WITH_DES_CBC_SHA uint16 = 0x001E + http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA uint16 = 0x001F + http2cipher_TLS_KRB5_WITH_RC4_128_SHA uint16 = 0x0020 + http2cipher_TLS_KRB5_WITH_IDEA_CBC_SHA uint16 = 0x0021 + http2cipher_TLS_KRB5_WITH_DES_CBC_MD5 uint16 = 0x0022 + http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5 uint16 = 0x0023 + http2cipher_TLS_KRB5_WITH_RC4_128_MD5 uint16 = 0x0024 + http2cipher_TLS_KRB5_WITH_IDEA_CBC_MD5 uint16 = 0x0025 + http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA uint16 = 0x0026 + http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA uint16 = 0x0027 + http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA uint16 = 0x0028 + http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5 uint16 = 0x0029 + http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x002A + http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5 uint16 = 0x002B + http2cipher_TLS_PSK_WITH_NULL_SHA uint16 = 0x002C + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA uint16 = 0x002D + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA uint16 = 0x002E + http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002F + http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0030 + http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0031 + http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0032 + http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0033 + http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA uint16 = 0x0034 + http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 + http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0036 + http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0037 + http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0038 + http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0039 + http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA uint16 = 0x003A + http2cipher_TLS_RSA_WITH_NULL_SHA256 uint16 = 0x003B + http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003C + http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x003D + http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x003E + http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003F + http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x0040 + http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0041 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0042 + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0043 + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0044 + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0045 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0046 + // Reserved uint16 = 0x0047-4F + // Reserved uint16 = 0x0050-58 + // Reserved uint16 = 0x0059-5C + // Unassigned uint16 = 0x005D-5F + // Reserved uint16 = 0x0060-66 + http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x0067 + http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x0068 + http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x0069 + http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x006A + http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x006B + http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256 uint16 = 0x006C + http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256 uint16 = 0x006D + // Unassigned uint16 = 0x006E-83 + http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0084 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0085 + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0086 + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0087 + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0088 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0089 + http2cipher_TLS_PSK_WITH_RC4_128_SHA uint16 = 0x008A + http2cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008B + http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA uint16 = 0x008C + http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA uint16 = 0x008D + http2cipher_TLS_DHE_PSK_WITH_RC4_128_SHA uint16 = 0x008E + http2cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008F + http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0090 + http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0091 + http2cipher_TLS_RSA_PSK_WITH_RC4_128_SHA uint16 = 0x0092 + http2cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x0093 + http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0094 + http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0095 + http2cipher_TLS_RSA_WITH_SEED_CBC_SHA uint16 = 0x0096 + http2cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA uint16 = 0x0097 + http2cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA uint16 = 0x0098 + http2cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA uint16 = 0x0099 + http2cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA uint16 = 0x009A + http2cipher_TLS_DH_anon_WITH_SEED_CBC_SHA uint16 = 0x009B + http2cipher_TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009C + http2cipher_TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009D + http2cipher_TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009E + http2cipher_TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009F + http2cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x00A0 + http2cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x00A1 + http2cipher_TLS_DHE_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A2 + http2cipher_TLS_DHE_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A3 + http2cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A4 + http2cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A5 + http2cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256 uint16 = 0x00A6 + http2cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384 uint16 = 0x00A7 + http2cipher_TLS_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00A8 + http2cipher_TLS_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00A9 + http2cipher_TLS_DHE_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AA + http2cipher_TLS_DHE_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AB + http2cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AC + http2cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AD + http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00AE + http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00AF + http2cipher_TLS_PSK_WITH_NULL_SHA256 uint16 = 0x00B0 + http2cipher_TLS_PSK_WITH_NULL_SHA384 uint16 = 0x00B1 + http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B2 + http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B3 + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA256 uint16 = 0x00B4 + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA384 uint16 = 0x00B5 + http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B6 + http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B7 + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA256 uint16 = 0x00B8 + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA384 uint16 = 0x00B9 + http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BA + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BB + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BC + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BD + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BE + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BF + http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C0 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C1 + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C2 + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C3 + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C4 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C5 + // Unassigned uint16 = 0x00C6-FE + http2cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV uint16 = 0x00FF + // Unassigned uint16 = 0x01-55,* + http2cipher_TLS_FALLBACK_SCSV uint16 = 0x5600 + // Unassigned uint16 = 0x5601 - 0xC000 + http2cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA uint16 = 0xC001 + http2cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA uint16 = 0xC002 + http2cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC003 + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC004 + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC005 + http2cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA uint16 = 0xC006 + http2cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xC007 + http2cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC008 + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC009 + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC00A + http2cipher_TLS_ECDH_RSA_WITH_NULL_SHA uint16 = 0xC00B + http2cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA uint16 = 0xC00C + http2cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC00D + http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC00E + http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC00F + http2cipher_TLS_ECDHE_RSA_WITH_NULL_SHA uint16 = 0xC010 + http2cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xC011 + http2cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC012 + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC013 + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC014 + http2cipher_TLS_ECDH_anon_WITH_NULL_SHA uint16 = 0xC015 + http2cipher_TLS_ECDH_anon_WITH_RC4_128_SHA uint16 = 0xC016 + http2cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0xC017 + http2cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA uint16 = 0xC018 + http2cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA uint16 = 0xC019 + http2cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01A + http2cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01B + http2cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01C + http2cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA uint16 = 0xC01D + http2cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC01E + http2cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA uint16 = 0xC01F + http2cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA uint16 = 0xC020 + http2cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC021 + http2cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA uint16 = 0xC022 + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC023 + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC024 + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC025 + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC026 + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC027 + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC028 + http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC029 + http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC02A + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02B + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC02C + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02D + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC02E + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02F + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC030 + http2cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC031 + http2cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC032 + http2cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA uint16 = 0xC033 + http2cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0xC034 + http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0xC035 + http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0xC036 + http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0xC037 + http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0xC038 + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA uint16 = 0xC039 + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256 uint16 = 0xC03A + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384 uint16 = 0xC03B + http2cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03C + http2cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03D + http2cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03E + http2cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03F + http2cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC040 + http2cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC041 + http2cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC042 + http2cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC043 + http2cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC044 + http2cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC045 + http2cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC046 + http2cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC047 + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC048 + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC049 + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04A + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04B + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04C + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04D + http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04E + http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04F + http2cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC050 + http2cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC051 + http2cipher_TLS_DHE_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC052 + http2cipher_TLS_DHE_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC053 + http2cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC054 + http2cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC055 + http2cipher_TLS_DHE_DSS_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC056 + http2cipher_TLS_DHE_DSS_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC057 + http2cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC058 + http2cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC059 + http2cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05A + http2cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05B + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05C + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05D + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05E + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05F + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC060 + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC061 + http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC062 + http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC063 + http2cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC064 + http2cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC065 + http2cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC066 + http2cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC067 + http2cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC068 + http2cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC069 + http2cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06A + http2cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06B + http2cipher_TLS_DHE_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06C + http2cipher_TLS_DHE_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06D + http2cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06E + http2cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06F + http2cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC070 + http2cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC071 + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC072 + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC073 + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC074 + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC075 + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC076 + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC077 + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC078 + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC079 + http2cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07A + http2cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07B + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07C + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07D + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07E + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07F + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC080 + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC081 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC082 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC083 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC084 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC085 + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC086 + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC087 + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC088 + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC089 + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08A + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08B + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08C + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08D + http2cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08E + http2cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08F + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC090 + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC091 + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC092 + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC093 + http2cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC094 + http2cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC095 + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC096 + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC097 + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC098 + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC099 + http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC09A + http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC09B + http2cipher_TLS_RSA_WITH_AES_128_CCM uint16 = 0xC09C + http2cipher_TLS_RSA_WITH_AES_256_CCM uint16 = 0xC09D + http2cipher_TLS_DHE_RSA_WITH_AES_128_CCM uint16 = 0xC09E + http2cipher_TLS_DHE_RSA_WITH_AES_256_CCM uint16 = 0xC09F + http2cipher_TLS_RSA_WITH_AES_128_CCM_8 uint16 = 0xC0A0 + http2cipher_TLS_RSA_WITH_AES_256_CCM_8 uint16 = 0xC0A1 + http2cipher_TLS_DHE_RSA_WITH_AES_128_CCM_8 uint16 = 0xC0A2 + http2cipher_TLS_DHE_RSA_WITH_AES_256_CCM_8 uint16 = 0xC0A3 + http2cipher_TLS_PSK_WITH_AES_128_CCM uint16 = 0xC0A4 + http2cipher_TLS_PSK_WITH_AES_256_CCM uint16 = 0xC0A5 + http2cipher_TLS_DHE_PSK_WITH_AES_128_CCM uint16 = 0xC0A6 + http2cipher_TLS_DHE_PSK_WITH_AES_256_CCM uint16 = 0xC0A7 + http2cipher_TLS_PSK_WITH_AES_128_CCM_8 uint16 = 0xC0A8 + http2cipher_TLS_PSK_WITH_AES_256_CCM_8 uint16 = 0xC0A9 + http2cipher_TLS_PSK_DHE_WITH_AES_128_CCM_8 uint16 = 0xC0AA + http2cipher_TLS_PSK_DHE_WITH_AES_256_CCM_8 uint16 = 0xC0AB + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CCM uint16 = 0xC0AC + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CCM uint16 = 0xC0AD + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 uint16 = 0xC0AE + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CCM_8 uint16 = 0xC0AF + // Unassigned uint16 = 0xC0B0-FF + // Unassigned uint16 = 0xC1-CB,* + // Unassigned uint16 = 0xCC00-A7 + http2cipher_TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCA8 + http2cipher_TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCA9 + http2cipher_TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAA + http2cipher_TLS_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAB + http2cipher_TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAC + http2cipher_TLS_DHE_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAD + http2cipher_TLS_RSA_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAE +) + +// isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec. +// References: +// https://tools.ietf.org/html/rfc7540#appendix-A +// Reject cipher suites from Appendix A. +// "This list includes those cipher suites that do not +// offer an ephemeral key exchange and those that are +// based on the TLS null, stream or block cipher type" +func http2isBadCipher(cipher uint16) bool { + switch cipher { + case http2cipher_TLS_NULL_WITH_NULL_NULL, + http2cipher_TLS_RSA_WITH_NULL_MD5, + http2cipher_TLS_RSA_WITH_NULL_SHA, + http2cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5, + http2cipher_TLS_RSA_WITH_RC4_128_MD5, + http2cipher_TLS_RSA_WITH_RC4_128_SHA, + http2cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5, + http2cipher_TLS_RSA_WITH_IDEA_CBC_SHA, + http2cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_RSA_WITH_DES_CBC_SHA, + http2cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_DES_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_DES_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5, + http2cipher_TLS_DH_anon_WITH_RC4_128_MD5, + http2cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_DES_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_KRB5_WITH_DES_CBC_SHA, + http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_KRB5_WITH_RC4_128_SHA, + http2cipher_TLS_KRB5_WITH_IDEA_CBC_SHA, + http2cipher_TLS_KRB5_WITH_DES_CBC_MD5, + http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5, + http2cipher_TLS_KRB5_WITH_RC4_128_MD5, + http2cipher_TLS_KRB5_WITH_IDEA_CBC_MD5, + http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA, + http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA, + http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA, + http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5, + http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5, + http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5, + http2cipher_TLS_PSK_WITH_NULL_SHA, + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA, + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA, + http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA, + http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA, + http2cipher_TLS_RSA_WITH_NULL_SHA256, + http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_PSK_WITH_RC4_128_SHA, + http2cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA, + http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DHE_PSK_WITH_RC4_128_SHA, + http2cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA, + http2cipher_TLS_RSA_PSK_WITH_RC4_128_SHA, + http2cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA, + http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA, + http2cipher_TLS_RSA_WITH_SEED_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_SEED_CBC_SHA, + http2cipher_TLS_RSA_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_RSA_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_PSK_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_PSK_WITH_NULL_SHA256, + http2cipher_TLS_PSK_WITH_NULL_SHA384, + http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA256, + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA384, + http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA256, + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA384, + http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV, + http2cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA, + http2cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA, + http2cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDH_RSA_WITH_NULL_SHA, + http2cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA, + http2cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_NULL_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDH_anon_WITH_NULL_SHA, + http2cipher_TLS_ECDH_anon_WITH_RC4_128_SHA, + http2cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA, + http2cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA, + http2cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256, + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384, + http2cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_RSA_WITH_AES_128_CCM, + http2cipher_TLS_RSA_WITH_AES_256_CCM, + http2cipher_TLS_RSA_WITH_AES_128_CCM_8, + http2cipher_TLS_RSA_WITH_AES_256_CCM_8, + http2cipher_TLS_PSK_WITH_AES_128_CCM, + http2cipher_TLS_PSK_WITH_AES_256_CCM, + http2cipher_TLS_PSK_WITH_AES_128_CCM_8, + http2cipher_TLS_PSK_WITH_AES_256_CCM_8: + return true + default: + return false + } +} + +// ClientConnPool manages a pool of HTTP/2 client connections. +type http2ClientConnPool interface { + GetClientConn(req *Request, addr string) (*http2ClientConn, error) + MarkDead(*http2ClientConn) +} + +// clientConnPoolIdleCloser is the interface implemented by ClientConnPool +// implementations which can close their idle connections. +type http2clientConnPoolIdleCloser interface { + http2ClientConnPool + closeIdleConnections() +} + +var ( + _ http2clientConnPoolIdleCloser = (*http2clientConnPool)(nil) + _ http2clientConnPoolIdleCloser = http2noDialClientConnPool{} +) + +// TODO: use singleflight for dialing and addConnCalls? +type http2clientConnPool struct { + t *http2Transport + + mu sync.Mutex // TODO: maybe switch to RWMutex + // TODO: add support for sharing conns based on cert names + // (e.g. share conn for googleapis.com and appspot.com) + conns map[string][]*http2ClientConn // key is host:port + dialing map[string]*http2dialCall // currently in-flight dials + keys map[*http2ClientConn][]string + addConnCalls map[string]*http2addConnCall // in-flight addConnIfNeede calls +} + +func (p *http2clientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) { + return p.getClientConn(req, addr, http2dialOnMiss) +} + +const ( + http2dialOnMiss = true + http2noDialOnMiss = false +) + +// shouldTraceGetConn reports whether getClientConn should call any +// ClientTrace.GetConn hook associated with the http.Request. +// +// This complexity is needed to avoid double calls of the GetConn hook +// during the back-and-forth between net/http and x/net/http2 (when the +// net/http.Transport is upgraded to also speak http2), as well as support +// the case where x/net/http2 is being used directly. +func (p *http2clientConnPool) shouldTraceGetConn(st http2clientConnIdleState) bool { + // If our Transport wasn't made via ConfigureTransport, always + // trace the GetConn hook if provided, because that means the + // http2 package is being used directly and it's the one + // dialing, as opposed to net/http. + if _, ok := p.t.ConnPool.(http2noDialClientConnPool); !ok { + return true + } + // Otherwise, only use the GetConn hook if this connection has + // been used previously for other requests. For fresh + // connections, the net/http package does the dialing. + return !st.freshConn +} + +func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) { + if http2isConnectionCloseRequest(req) && dialOnMiss { + // It gets its own connection. + http2traceGetConn(req, addr) + const singleUse = true + cc, err := p.t.dialClientConn(addr, singleUse) + if err != nil { + return nil, err + } + return cc, nil + } + p.mu.Lock() + for _, cc := range p.conns[addr] { + if st := cc.idleState(); st.canTakeNewRequest { + if p.shouldTraceGetConn(st) { + http2traceGetConn(req, addr) + } + p.mu.Unlock() + return cc, nil + } + } + if !dialOnMiss { + p.mu.Unlock() + return nil, http2ErrNoCachedConn + } + http2traceGetConn(req, addr) + call := p.getStartDialLocked(addr) + p.mu.Unlock() + <-call.done + return call.res, call.err +} + +// dialCall is an in-flight Transport dial call to a host. +type http2dialCall struct { + _ http2incomparable + p *http2clientConnPool + done chan struct{} // closed when done + res *http2ClientConn // valid after done is closed + err error // valid after done is closed +} + +// requires p.mu is held. +func (p *http2clientConnPool) getStartDialLocked(addr string) *http2dialCall { + if call, ok := p.dialing[addr]; ok { + // A dial is already in-flight. Don't start another. + return call + } + call := &http2dialCall{p: p, done: make(chan struct{})} + if p.dialing == nil { + p.dialing = make(map[string]*http2dialCall) + } + p.dialing[addr] = call + go call.dial(addr) + return call +} + +// run in its own goroutine. +func (c *http2dialCall) dial(addr string) { + const singleUse = false // shared conn + c.res, c.err = c.p.t.dialClientConn(addr, singleUse) + close(c.done) + + c.p.mu.Lock() + delete(c.p.dialing, addr) + if c.err == nil { + c.p.addConnLocked(addr, c.res) + } + c.p.mu.Unlock() +} + +// addConnIfNeeded makes a NewClientConn out of c if a connection for key doesn't +// already exist. It coalesces concurrent calls with the same key. +// This is used by the http1 Transport code when it creates a new connection. Because +// the http1 Transport doesn't de-dup TCP dials to outbound hosts (because it doesn't know +// the protocol), it can get into a situation where it has multiple TLS connections. +// This code decides which ones live or die. +// The return value used is whether c was used. +// c is never closed. +func (p *http2clientConnPool) addConnIfNeeded(key string, t *http2Transport, c *tls.Conn) (used bool, err error) { + p.mu.Lock() + for _, cc := range p.conns[key] { + if cc.CanTakeNewRequest() { + p.mu.Unlock() + return false, nil + } + } + call, dup := p.addConnCalls[key] + if !dup { + if p.addConnCalls == nil { + p.addConnCalls = make(map[string]*http2addConnCall) + } + call = &http2addConnCall{ + p: p, + done: make(chan struct{}), + } + p.addConnCalls[key] = call + go call.run(t, key, c) + } + p.mu.Unlock() + + <-call.done + if call.err != nil { + return false, call.err + } + return !dup, nil +} + +type http2addConnCall struct { + _ http2incomparable + p *http2clientConnPool + done chan struct{} // closed when done + err error +} + +func (c *http2addConnCall) run(t *http2Transport, key string, tc *tls.Conn) { + cc, err := t.NewClientConn(tc) + + p := c.p + p.mu.Lock() + if err != nil { + c.err = err + } else { + p.addConnLocked(key, cc) + } + delete(p.addConnCalls, key) + p.mu.Unlock() + close(c.done) +} + +// p.mu must be held +func (p *http2clientConnPool) addConnLocked(key string, cc *http2ClientConn) { + for _, v := range p.conns[key] { + if v == cc { + return + } + } + if p.conns == nil { + p.conns = make(map[string][]*http2ClientConn) + } + if p.keys == nil { + p.keys = make(map[*http2ClientConn][]string) + } + p.conns[key] = append(p.conns[key], cc) + p.keys[cc] = append(p.keys[cc], key) +} + +func (p *http2clientConnPool) MarkDead(cc *http2ClientConn) { + p.mu.Lock() + defer p.mu.Unlock() + for _, key := range p.keys[cc] { + vv, ok := p.conns[key] + if !ok { + continue + } + newList := http2filterOutClientConn(vv, cc) + if len(newList) > 0 { + p.conns[key] = newList + } else { + delete(p.conns, key) + } + } + delete(p.keys, cc) +} + +func (p *http2clientConnPool) closeIdleConnections() { + p.mu.Lock() + defer p.mu.Unlock() + // TODO: don't close a cc if it was just added to the pool + // milliseconds ago and has never been used. There's currently + // a small race window with the HTTP/1 Transport's integration + // where it can add an idle conn just before using it, and + // somebody else can concurrently call CloseIdleConns and + // break some caller's RoundTrip. + for _, vv := range p.conns { + for _, cc := range vv { + cc.closeIfIdle() + } + } +} + +func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) []*http2ClientConn { + out := in[:0] + for _, v := range in { + if v != exclude { + out = append(out, v) + } + } + // If we filtered it out, zero out the last item to prevent + // the GC from seeing it. + if len(in) != len(out) { + in[len(in)-1] = nil + } + return out +} + +// noDialClientConnPool is an implementation of http2.ClientConnPool +// which never dials. We let the HTTP/1.1 client dial and use its TLS +// connection instead. +type http2noDialClientConnPool struct{ *http2clientConnPool } + +func (p http2noDialClientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) { + return p.getClientConn(req, addr, http2noDialOnMiss) +} + +// Buffer chunks are allocated from a pool to reduce pressure on GC. +// The maximum wasted space per dataBuffer is 2x the largest size class, +// which happens when the dataBuffer has multiple chunks and there is +// one unread byte in both the first and last chunks. We use a few size +// classes to minimize overheads for servers that typically receive very +// small request bodies. +// +// TODO: Benchmark to determine if the pools are necessary. The GC may have +// improved enough that we can instead allocate chunks like this: +// make([]byte, max(16<<10, expectedBytesRemaining)) +var ( + http2dataChunkSizeClasses = []int{ + 1 << 10, + 2 << 10, + 4 << 10, + 8 << 10, + 16 << 10, + } + http2dataChunkPools = [...]sync.Pool{ + {New: func() interface{} { return make([]byte, 1<<10) }}, + {New: func() interface{} { return make([]byte, 2<<10) }}, + {New: func() interface{} { return make([]byte, 4<<10) }}, + {New: func() interface{} { return make([]byte, 8<<10) }}, + {New: func() interface{} { return make([]byte, 16<<10) }}, + } +) + +func http2getDataBufferChunk(size int64) []byte { + i := 0 + for ; i < len(http2dataChunkSizeClasses)-1; i++ { + if size <= int64(http2dataChunkSizeClasses[i]) { + break + } + } + return http2dataChunkPools[i].Get().([]byte) +} + +func http2putDataBufferChunk(p []byte) { + for i, n := range http2dataChunkSizeClasses { + if len(p) == n { + http2dataChunkPools[i].Put(p) + return + } + } + panic(fmt.Sprintf("unexpected buffer len=%v", len(p))) +} + +// dataBuffer is an io.ReadWriter backed by a list of data chunks. +// Each dataBuffer is used to read DATA frames on a single stream. +// The buffer is divided into chunks so the server can limit the +// total memory used by a single connection without limiting the +// request body size on any single stream. +type http2dataBuffer struct { + chunks [][]byte + r int // next byte to read is chunks[0][r] + w int // next byte to write is chunks[len(chunks)-1][w] + size int // total buffered bytes + expected int64 // we expect at least this many bytes in future Write calls (ignored if <= 0) +} + +var http2errReadEmpty = errors.New("read from empty dataBuffer") + +// Read copies bytes from the buffer into p. +// It is an error to read when no data is available. +func (b *http2dataBuffer) Read(p []byte) (int, error) { + if b.size == 0 { + return 0, http2errReadEmpty + } + var ntotal int + for len(p) > 0 && b.size > 0 { + readFrom := b.bytesFromFirstChunk() + n := copy(p, readFrom) + p = p[n:] + ntotal += n + b.r += n + b.size -= n + // If the first chunk has been consumed, advance to the next chunk. + if b.r == len(b.chunks[0]) { + http2putDataBufferChunk(b.chunks[0]) + end := len(b.chunks) - 1 + copy(b.chunks[:end], b.chunks[1:]) + b.chunks[end] = nil + b.chunks = b.chunks[:end] + b.r = 0 + } + } + return ntotal, nil +} + +func (b *http2dataBuffer) bytesFromFirstChunk() []byte { + if len(b.chunks) == 1 { + return b.chunks[0][b.r:b.w] + } + return b.chunks[0][b.r:] +} + +// Len returns the number of bytes of the unread portion of the buffer. +func (b *http2dataBuffer) Len() int { + return b.size +} + +// Write appends p to the buffer. +func (b *http2dataBuffer) Write(p []byte) (int, error) { + ntotal := len(p) + for len(p) > 0 { + // If the last chunk is empty, allocate a new chunk. Try to allocate + // enough to fully copy p plus any additional bytes we expect to + // receive. However, this may allocate less than len(p). + want := int64(len(p)) + if b.expected > want { + want = b.expected + } + chunk := b.lastChunkOrAlloc(want) + n := copy(chunk[b.w:], p) + p = p[n:] + b.w += n + b.size += n + b.expected -= int64(n) + } + return ntotal, nil +} + +func (b *http2dataBuffer) lastChunkOrAlloc(want int64) []byte { + if len(b.chunks) != 0 { + last := b.chunks[len(b.chunks)-1] + if b.w < len(last) { + return last + } + } + chunk := http2getDataBufferChunk(want) + b.chunks = append(b.chunks, chunk) + b.w = 0 + return chunk +} + +// An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec. +type http2ErrCode uint32 + +const ( + http2ErrCodeNo http2ErrCode = 0x0 + http2ErrCodeProtocol http2ErrCode = 0x1 + http2ErrCodeInternal http2ErrCode = 0x2 + http2ErrCodeFlowControl http2ErrCode = 0x3 + http2ErrCodeSettingsTimeout http2ErrCode = 0x4 + http2ErrCodeStreamClosed http2ErrCode = 0x5 + http2ErrCodeFrameSize http2ErrCode = 0x6 + http2ErrCodeRefusedStream http2ErrCode = 0x7 + http2ErrCodeCancel http2ErrCode = 0x8 + http2ErrCodeCompression http2ErrCode = 0x9 + http2ErrCodeConnect http2ErrCode = 0xa + http2ErrCodeEnhanceYourCalm http2ErrCode = 0xb + http2ErrCodeInadequateSecurity http2ErrCode = 0xc + http2ErrCodeHTTP11Required http2ErrCode = 0xd +) + +var http2errCodeName = map[http2ErrCode]string{ + http2ErrCodeNo: "NO_ERROR", + http2ErrCodeProtocol: "PROTOCOL_ERROR", + http2ErrCodeInternal: "INTERNAL_ERROR", + http2ErrCodeFlowControl: "FLOW_CONTROL_ERROR", + http2ErrCodeSettingsTimeout: "SETTINGS_TIMEOUT", + http2ErrCodeStreamClosed: "STREAM_CLOSED", + http2ErrCodeFrameSize: "FRAME_SIZE_ERROR", + http2ErrCodeRefusedStream: "REFUSED_STREAM", + http2ErrCodeCancel: "CANCEL", + http2ErrCodeCompression: "COMPRESSION_ERROR", + http2ErrCodeConnect: "CONNECT_ERROR", + http2ErrCodeEnhanceYourCalm: "ENHANCE_YOUR_CALM", + http2ErrCodeInadequateSecurity: "INADEQUATE_SECURITY", + http2ErrCodeHTTP11Required: "HTTP_1_1_REQUIRED", +} + +func (e http2ErrCode) String() string { + if s, ok := http2errCodeName[e]; ok { + return s + } + return fmt.Sprintf("unknown error code 0x%x", uint32(e)) +} + +// ConnectionError is an error that results in the termination of the +// entire connection. +type http2ConnectionError http2ErrCode + +func (e http2ConnectionError) Error() string { + return fmt.Sprintf("connection error: %s", http2ErrCode(e)) +} + +// StreamError is an error that only affects one stream within an +// HTTP/2 connection. +type http2StreamError struct { + StreamID uint32 + Code http2ErrCode + Cause error // optional additional detail +} + +func http2streamError(id uint32, code http2ErrCode) http2StreamError { + return http2StreamError{StreamID: id, Code: code} +} + +func (e http2StreamError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("stream error: stream ID %d; %v; %v", e.StreamID, e.Code, e.Cause) + } + return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code) +} + +// 6.9.1 The Flow Control Window +// "If a sender receives a WINDOW_UPDATE that causes a flow control +// window to exceed this maximum it MUST terminate either the stream +// or the connection, as appropriate. For streams, [...]; for the +// connection, a GOAWAY frame with a FLOW_CONTROL_ERROR code." +type http2goAwayFlowError struct{} + +func (http2goAwayFlowError) Error() string { return "connection exceeded flow control window size" } + +// connError represents an HTTP/2 ConnectionError error code, along +// with a string (for debugging) explaining why. +// +// Errors of this type are only returned by the frame parser functions +// and converted into ConnectionError(Code), after stashing away +// the Reason into the Framer's errDetail field, accessible via +// the (*Framer).ErrorDetail method. +type http2connError struct { + Code http2ErrCode // the ConnectionError error code + Reason string // additional reason +} + +func (e http2connError) Error() string { + return fmt.Sprintf("http2: connection error: %v: %v", e.Code, e.Reason) +} + +type http2pseudoHeaderError string + +func (e http2pseudoHeaderError) Error() string { + return fmt.Sprintf("invalid pseudo-header %q", string(e)) +} + +type http2duplicatePseudoHeaderError string + +func (e http2duplicatePseudoHeaderError) Error() string { + return fmt.Sprintf("duplicate pseudo-header %q", string(e)) +} + +type http2headerFieldNameError string + +func (e http2headerFieldNameError) Error() string { + return fmt.Sprintf("invalid header field name %q", string(e)) +} + +type http2headerFieldValueError string + +func (e http2headerFieldValueError) Error() string { + return fmt.Sprintf("invalid header field value %q", string(e)) +} + +var ( + http2errMixPseudoHeaderTypes = errors.New("mix of request and response pseudo headers") + http2errPseudoAfterRegular = errors.New("pseudo header field after regular") +) + +// flow is the flow control window's size. +type http2flow struct { + _ http2incomparable + + // n is the number of DATA bytes we're allowed to send. + // A flow is kept both on a conn and a per-stream. + n int32 + + // conn points to the shared connection-level flow that is + // shared by all streams on that conn. It is nil for the flow + // that's on the conn directly. + conn *http2flow +} + +func (f *http2flow) setConnFlow(cf *http2flow) { f.conn = cf } + +func (f *http2flow) available() int32 { + n := f.n + if f.conn != nil && f.conn.n < n { + n = f.conn.n + } + return n +} + +func (f *http2flow) take(n int32) { + if n > f.available() { + panic("internal error: took too much") + } + f.n -= n + if f.conn != nil { + f.conn.n -= n + } +} + +// add adds n bytes (positive or negative) to the flow control window. +// It returns false if the sum would exceed 2^31-1. +func (f *http2flow) add(n int32) bool { + sum := f.n + n + if (sum > n) == (f.n > 0) { + f.n = sum + return true + } + return false +} + +const http2frameHeaderLen = 9 + +var http2padZeros = make([]byte, 255) // zeros for padding + +// A FrameType is a registered frame type as defined in +// http://http2.github.io/http2-spec/#rfc.section.11.2 +type http2FrameType uint8 + +const ( + http2FrameData http2FrameType = 0x0 + http2FrameHeaders http2FrameType = 0x1 + http2FramePriority http2FrameType = 0x2 + http2FrameRSTStream http2FrameType = 0x3 + http2FrameSettings http2FrameType = 0x4 + http2FramePushPromise http2FrameType = 0x5 + http2FramePing http2FrameType = 0x6 + http2FrameGoAway http2FrameType = 0x7 + http2FrameWindowUpdate http2FrameType = 0x8 + http2FrameContinuation http2FrameType = 0x9 +) + +var http2frameName = map[http2FrameType]string{ + http2FrameData: "DATA", + http2FrameHeaders: "HEADERS", + http2FramePriority: "PRIORITY", + http2FrameRSTStream: "RST_STREAM", + http2FrameSettings: "SETTINGS", + http2FramePushPromise: "PUSH_PROMISE", + http2FramePing: "PING", + http2FrameGoAway: "GOAWAY", + http2FrameWindowUpdate: "WINDOW_UPDATE", + http2FrameContinuation: "CONTINUATION", +} + +func (t http2FrameType) String() string { + if s, ok := http2frameName[t]; ok { + return s + } + return fmt.Sprintf("UNKNOWN_FRAME_TYPE_%d", uint8(t)) +} + +// Flags is a bitmask of HTTP/2 flags. +// The meaning of flags varies depending on the frame type. +type http2Flags uint8 + +// Has reports whether f contains all (0 or more) flags in v. +func (f http2Flags) Has(v http2Flags) bool { + return (f & v) == v +} + +// Frame-specific FrameHeader flag bits. +const ( + // Data Frame + http2FlagDataEndStream http2Flags = 0x1 + http2FlagDataPadded http2Flags = 0x8 + + // Headers Frame + http2FlagHeadersEndStream http2Flags = 0x1 + http2FlagHeadersEndHeaders http2Flags = 0x4 + http2FlagHeadersPadded http2Flags = 0x8 + http2FlagHeadersPriority http2Flags = 0x20 + + // Settings Frame + http2FlagSettingsAck http2Flags = 0x1 + + // Ping Frame + http2FlagPingAck http2Flags = 0x1 + + // Continuation Frame + http2FlagContinuationEndHeaders http2Flags = 0x4 + + http2FlagPushPromiseEndHeaders http2Flags = 0x4 + http2FlagPushPromisePadded http2Flags = 0x8 +) + +var http2flagName = map[http2FrameType]map[http2Flags]string{ + http2FrameData: { + http2FlagDataEndStream: "END_STREAM", + http2FlagDataPadded: "PADDED", + }, + http2FrameHeaders: { + http2FlagHeadersEndStream: "END_STREAM", + http2FlagHeadersEndHeaders: "END_HEADERS", + http2FlagHeadersPadded: "PADDED", + http2FlagHeadersPriority: "PRIORITY", + }, + http2FrameSettings: { + http2FlagSettingsAck: "ACK", + }, + http2FramePing: { + http2FlagPingAck: "ACK", + }, + http2FrameContinuation: { + http2FlagContinuationEndHeaders: "END_HEADERS", + }, + http2FramePushPromise: { + http2FlagPushPromiseEndHeaders: "END_HEADERS", + http2FlagPushPromisePadded: "PADDED", + }, +} + +// a frameParser parses a frame given its FrameHeader and payload +// bytes. The length of payload will always equal fh.Length (which +// might be 0). +type http2frameParser func(fc *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error) + +var http2frameParsers = map[http2FrameType]http2frameParser{ + http2FrameData: http2parseDataFrame, + http2FrameHeaders: http2parseHeadersFrame, + http2FramePriority: http2parsePriorityFrame, + http2FrameRSTStream: http2parseRSTStreamFrame, + http2FrameSettings: http2parseSettingsFrame, + http2FramePushPromise: http2parsePushPromise, + http2FramePing: http2parsePingFrame, + http2FrameGoAway: http2parseGoAwayFrame, + http2FrameWindowUpdate: http2parseWindowUpdateFrame, + http2FrameContinuation: http2parseContinuationFrame, +} + +func http2typeFrameParser(t http2FrameType) http2frameParser { + if f := http2frameParsers[t]; f != nil { + return f + } + return http2parseUnknownFrame +} + +// A FrameHeader is the 9 byte header of all HTTP/2 frames. +// +// See http://http2.github.io/http2-spec/#FrameHeader +type http2FrameHeader struct { + valid bool // caller can access []byte fields in the Frame + + // Type is the 1 byte frame type. There are ten standard frame + // types, but extension frame types may be written by WriteRawFrame + // and will be returned by ReadFrame (as UnknownFrame). + Type http2FrameType + + // Flags are the 1 byte of 8 potential bit flags per frame. + // They are specific to the frame type. + Flags http2Flags + + // Length is the length of the frame, not including the 9 byte header. + // The maximum size is one byte less than 16MB (uint24), but only + // frames up to 16KB are allowed without peer agreement. + Length uint32 + + // StreamID is which stream this frame is for. Certain frames + // are not stream-specific, in which case this field is 0. + StreamID uint32 +} + +// Header returns h. It exists so FrameHeaders can be embedded in other +// specific frame types and implement the Frame interface. +func (h http2FrameHeader) Header() http2FrameHeader { return h } + +func (h http2FrameHeader) String() string { + var buf bytes.Buffer + buf.WriteString("[FrameHeader ") + h.writeDebug(&buf) + buf.WriteByte(']') + return buf.String() +} + +func (h http2FrameHeader) writeDebug(buf *bytes.Buffer) { + buf.WriteString(h.Type.String()) + if h.Flags != 0 { + buf.WriteString(" flags=") + set := 0 + for i := uint8(0); i < 8; i++ { + if h.Flags&(1< 1 { + buf.WriteByte('|') + } + name := http2flagName[h.Type][http2Flags(1<>24), + byte(streamID>>16), + byte(streamID>>8), + byte(streamID)) +} + +func (f *http2Framer) endWrite() error { + // Now that we know the final size, fill in the FrameHeader in + // the space previously reserved for it. Abuse append. + length := len(f.wbuf) - http2frameHeaderLen + if length >= (1 << 24) { + return http2ErrFrameTooLarge + } + _ = append(f.wbuf[:0], + byte(length>>16), + byte(length>>8), + byte(length)) + if f.logWrites { + f.logWrite() + } + + n, err := f.w.Write(f.wbuf) + if err == nil && n != len(f.wbuf) { + err = io.ErrShortWrite + } + return err +} + +func (f *http2Framer) logWrite() { + if f.debugFramer == nil { + f.debugFramerBuf = new(bytes.Buffer) + f.debugFramer = http2NewFramer(nil, f.debugFramerBuf) + f.debugFramer.logReads = false // we log it ourselves, saying "wrote" below + // Let us read anything, even if we accidentally wrote it + // in the wrong order: + f.debugFramer.AllowIllegalReads = true + } + f.debugFramerBuf.Write(f.wbuf) + fr, err := f.debugFramer.ReadFrame() + if err != nil { + f.debugWriteLoggerf("http2: Framer %p: failed to decode just-written frame", f) + return + } + f.debugWriteLoggerf("http2: Framer %p: wrote %v", f, http2summarizeFrame(fr)) +} + +func (f *http2Framer) writeByte(v byte) { f.wbuf = append(f.wbuf, v) } + +func (f *http2Framer) writeBytes(v []byte) { f.wbuf = append(f.wbuf, v...) } + +func (f *http2Framer) writeUint16(v uint16) { f.wbuf = append(f.wbuf, byte(v>>8), byte(v)) } + +func (f *http2Framer) writeUint32(v uint32) { + f.wbuf = append(f.wbuf, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) +} + +const ( + http2minMaxFrameSize = 1 << 14 + http2maxFrameSize = 1<<24 - 1 +) + +// SetReuseFrames allows the Framer to reuse Frames. +// If called on a Framer, Frames returned by calls to ReadFrame are only +// valid until the next call to ReadFrame. +func (fr *http2Framer) SetReuseFrames() { + if fr.frameCache != nil { + return + } + fr.frameCache = &http2frameCache{} +} + +type http2frameCache struct { + dataFrame http2DataFrame +} + +func (fc *http2frameCache) getDataFrame() *http2DataFrame { + if fc == nil { + return &http2DataFrame{} + } + return &fc.dataFrame +} + +// NewFramer returns a Framer that writes frames to w and reads them from r. +func http2NewFramer(w io.Writer, r io.Reader) *http2Framer { + fr := &http2Framer{ + w: w, + r: r, + logReads: http2logFrameReads, + logWrites: http2logFrameWrites, + debugReadLoggerf: log.Printf, + debugWriteLoggerf: log.Printf, + } + fr.getReadBuf = func(size uint32) []byte { + if cap(fr.readBuf) >= int(size) { + return fr.readBuf[:size] + } + fr.readBuf = make([]byte, size) + return fr.readBuf + } + fr.SetMaxReadFrameSize(http2maxFrameSize) + return fr +} + +// SetMaxReadFrameSize sets the maximum size of a frame +// that will be read by a subsequent call to ReadFrame. +// It is the caller's responsibility to advertise this +// limit with a SETTINGS frame. +func (fr *http2Framer) SetMaxReadFrameSize(v uint32) { + if v > http2maxFrameSize { + v = http2maxFrameSize + } + fr.maxReadSize = v +} + +// ErrorDetail returns a more detailed error of the last error +// returned by Framer.ReadFrame. For instance, if ReadFrame +// returns a StreamError with code PROTOCOL_ERROR, ErrorDetail +// will say exactly what was invalid. ErrorDetail is not guaranteed +// to return a non-nil value and like the rest of the http2 package, +// its return value is not protected by an API compatibility promise. +// ErrorDetail is reset after the next call to ReadFrame. +func (fr *http2Framer) ErrorDetail() error { + return fr.errDetail +} + +// ErrFrameTooLarge is returned from Framer.ReadFrame when the peer +// sends a frame that is larger than declared with SetMaxReadFrameSize. +var http2ErrFrameTooLarge = errors.New("http2: frame too large") + +// terminalReadFrameError reports whether err is an unrecoverable +// error from ReadFrame and no other frames should be read. +func http2terminalReadFrameError(err error) bool { + if _, ok := err.(http2StreamError); ok { + return false + } + return err != nil +} + +// ReadFrame reads a single frame. The returned Frame is only valid +// until the next call to ReadFrame. +// +// If the frame is larger than previously set with SetMaxReadFrameSize, the +// returned error is ErrFrameTooLarge. Other errors may be of type +// ConnectionError, StreamError, or anything else from the underlying +// reader. +func (fr *http2Framer) ReadFrame() (http2Frame, error) { + fr.errDetail = nil + if fr.lastFrame != nil { + fr.lastFrame.invalidate() + } + fh, err := http2readFrameHeader(fr.headerBuf[:], fr.r) + if err != nil { + return nil, err + } + if fh.Length > fr.maxReadSize { + return nil, http2ErrFrameTooLarge + } + payload := fr.getReadBuf(fh.Length) + if _, err := io.ReadFull(fr.r, payload); err != nil { + return nil, err + } + f, err := http2typeFrameParser(fh.Type)(fr.frameCache, fh, payload) + if err != nil { + if ce, ok := err.(http2connError); ok { + return nil, fr.connError(ce.Code, ce.Reason) + } + return nil, err + } + if err := fr.checkFrameOrder(f); err != nil { + return nil, err + } + if fr.logReads { + fr.debugReadLoggerf("http2: Framer %p: read %v", fr, http2summarizeFrame(f)) + } + if fh.Type == http2FrameHeaders && fr.ReadMetaHeaders != nil { + return fr.readMetaFrame(f.(*http2HeadersFrame)) + } + return f, nil +} + +// connError returns ConnectionError(code) but first +// stashes away a public reason to the caller can optionally relay it +// to the peer before hanging up on them. This might help others debug +// their implementations. +func (fr *http2Framer) connError(code http2ErrCode, reason string) error { + fr.errDetail = errors.New(reason) + return http2ConnectionError(code) +} + +// checkFrameOrder reports an error if f is an invalid frame to return +// next from ReadFrame. Mostly it checks whether HEADERS and +// CONTINUATION frames are contiguous. +func (fr *http2Framer) checkFrameOrder(f http2Frame) error { + last := fr.lastFrame + fr.lastFrame = f + if fr.AllowIllegalReads { + return nil + } + + fh := f.Header() + if fr.lastHeaderStream != 0 { + if fh.Type != http2FrameContinuation { + return fr.connError(http2ErrCodeProtocol, + fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d", + fh.Type, fh.StreamID, + last.Header().Type, fr.lastHeaderStream)) + } + if fh.StreamID != fr.lastHeaderStream { + return fr.connError(http2ErrCodeProtocol, + fmt.Sprintf("got CONTINUATION for stream %d; expected stream %d", + fh.StreamID, fr.lastHeaderStream)) + } + } else if fh.Type == http2FrameContinuation { + return fr.connError(http2ErrCodeProtocol, fmt.Sprintf("unexpected CONTINUATION for stream %d", fh.StreamID)) + } + + switch fh.Type { + case http2FrameHeaders, http2FrameContinuation: + if fh.Flags.Has(http2FlagHeadersEndHeaders) { + fr.lastHeaderStream = 0 + } else { + fr.lastHeaderStream = fh.StreamID + } + } + + return nil +} + +// A DataFrame conveys arbitrary, variable-length sequences of octets +// associated with a stream. +// See http://http2.github.io/http2-spec/#rfc.section.6.1 +type http2DataFrame struct { + http2FrameHeader + data []byte +} + +func (f *http2DataFrame) StreamEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagDataEndStream) +} + +// Data returns the frame's data octets, not including any padding +// size byte or padding suffix bytes. +// The caller must not retain the returned memory past the next +// call to ReadFrame. +func (f *http2DataFrame) Data() []byte { + f.checkValid() + return f.data +} + +func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error) { + if fh.StreamID == 0 { + // DATA frames MUST be associated with a stream. If a + // DATA frame is received whose stream identifier + // field is 0x0, the recipient MUST respond with a + // connection error (Section 5.4.1) of type + // PROTOCOL_ERROR. + return nil, http2connError{http2ErrCodeProtocol, "DATA frame with stream ID 0"} + } + f := fc.getDataFrame() + f.http2FrameHeader = fh + + var padSize byte + if fh.Flags.Has(http2FlagDataPadded) { + var err error + payload, padSize, err = http2readByte(payload) + if err != nil { + return nil, err + } + } + if int(padSize) > len(payload) { + // If the length of the padding is greater than the + // length of the frame payload, the recipient MUST + // treat this as a connection error. + // Filed: https://github.com/http2/http2-spec/issues/610 + return nil, http2connError{http2ErrCodeProtocol, "pad size larger than data payload"} + } + f.data = payload[:len(payload)-int(padSize)] + return f, nil +} + +var ( + http2errStreamID = errors.New("invalid stream ID") + http2errDepStreamID = errors.New("invalid dependent stream ID") + http2errPadLength = errors.New("pad length too large") + http2errPadBytes = errors.New("padding bytes must all be zeros unless AllowIllegalWrites is enabled") +) + +func http2validStreamIDOrZero(streamID uint32) bool { + return streamID&(1<<31) == 0 +} + +func http2validStreamID(streamID uint32) bool { + return streamID != 0 && streamID&(1<<31) == 0 +} + +// WriteData writes a DATA frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility not to violate the maximum frame size +// and to not call other Write methods concurrently. +func (f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) error { + return f.WriteDataPadded(streamID, endStream, data, nil) +} + +// WriteDataPadded writes a DATA frame with optional padding. +// +// If pad is nil, the padding bit is not sent. +// The length of pad must not exceed 255 bytes. +// The bytes of pad must all be zero, unless f.AllowIllegalWrites is set. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility not to violate the maximum frame size +// and to not call other Write methods concurrently. +func (f *http2Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error { + if !http2validStreamID(streamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + if len(pad) > 0 { + if len(pad) > 255 { + return http2errPadLength + } + if !f.AllowIllegalWrites { + for _, b := range pad { + if b != 0 { + // "Padding octets MUST be set to zero when sending." + return http2errPadBytes + } + } + } + } + var flags http2Flags + if endStream { + flags |= http2FlagDataEndStream + } + if pad != nil { + flags |= http2FlagDataPadded + } + f.startWrite(http2FrameData, flags, streamID) + if pad != nil { + f.wbuf = append(f.wbuf, byte(len(pad))) + } + f.wbuf = append(f.wbuf, data...) + f.wbuf = append(f.wbuf, pad...) + return f.endWrite() +} + +// A SettingsFrame conveys configuration parameters that affect how +// endpoints communicate, such as preferences and constraints on peer +// behavior. +// +// See http://http2.github.io/http2-spec/#SETTINGS +type http2SettingsFrame struct { + http2FrameHeader + p []byte +} + +func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { + if fh.Flags.Has(http2FlagSettingsAck) && fh.Length > 0 { + // When this (ACK 0x1) bit is set, the payload of the + // SETTINGS frame MUST be empty. Receipt of a + // SETTINGS frame with the ACK flag set and a length + // field value other than 0 MUST be treated as a + // connection error (Section 5.4.1) of type + // FRAME_SIZE_ERROR. + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + if fh.StreamID != 0 { + // SETTINGS frames always apply to a connection, + // never a single stream. The stream identifier for a + // SETTINGS frame MUST be zero (0x0). If an endpoint + // receives a SETTINGS frame whose stream identifier + // field is anything other than 0x0, the endpoint MUST + // respond with a connection error (Section 5.4.1) of + // type PROTOCOL_ERROR. + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + if len(p)%6 != 0 { + // Expecting even number of 6 byte settings. + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + f := &http2SettingsFrame{http2FrameHeader: fh, p: p} + if v, ok := f.Value(http2SettingInitialWindowSize); ok && v > (1<<31)-1 { + // Values above the maximum flow control window size of 2^31 - 1 MUST + // be treated as a connection error (Section 5.4.1) of type + // FLOW_CONTROL_ERROR. + return nil, http2ConnectionError(http2ErrCodeFlowControl) + } + return f, nil +} + +func (f *http2SettingsFrame) IsAck() bool { + return f.http2FrameHeader.Flags.Has(http2FlagSettingsAck) +} + +func (f *http2SettingsFrame) Value(id http2SettingID) (v uint32, ok bool) { + f.checkValid() + for i := 0; i < f.NumSettings(); i++ { + if s := f.Setting(i); s.ID == id { + return s.Val, true + } + } + return 0, false +} + +// Setting returns the setting from the frame at the given 0-based index. +// The index must be >= 0 and less than f.NumSettings(). +func (f *http2SettingsFrame) Setting(i int) http2Setting { + buf := f.p + return http2Setting{ + ID: http2SettingID(binary.BigEndian.Uint16(buf[i*6 : i*6+2])), + Val: binary.BigEndian.Uint32(buf[i*6+2 : i*6+6]), + } +} + +func (f *http2SettingsFrame) NumSettings() int { return len(f.p) / 6 } + +// HasDuplicates reports whether f contains any duplicate setting IDs. +func (f *http2SettingsFrame) HasDuplicates() bool { + num := f.NumSettings() + if num == 0 { + return false + } + // If it's small enough (the common case), just do the n^2 + // thing and avoid a map allocation. + if num < 10 { + for i := 0; i < num; i++ { + idi := f.Setting(i).ID + for j := i + 1; j < num; j++ { + idj := f.Setting(j).ID + if idi == idj { + return true + } + } + } + return false + } + seen := map[http2SettingID]bool{} + for i := 0; i < num; i++ { + id := f.Setting(i).ID + if seen[id] { + return true + } + seen[id] = true + } + return false +} + +// ForeachSetting runs fn for each setting. +// It stops and returns the first error. +func (f *http2SettingsFrame) ForeachSetting(fn func(http2Setting) error) error { + f.checkValid() + for i := 0; i < f.NumSettings(); i++ { + if err := fn(f.Setting(i)); err != nil { + return err + } + } + return nil +} + +// WriteSettings writes a SETTINGS frame with zero or more settings +// specified and the ACK bit not set. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WriteSettings(settings ...http2Setting) error { + f.startWrite(http2FrameSettings, 0, 0) + for _, s := range settings { + f.writeUint16(uint16(s.ID)) + f.writeUint32(s.Val) + } + return f.endWrite() +} + +// WriteSettingsAck writes an empty SETTINGS frame with the ACK bit set. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WriteSettingsAck() error { + f.startWrite(http2FrameSettings, http2FlagSettingsAck, 0) + return f.endWrite() +} + +// A PingFrame is a mechanism for measuring a minimal round trip time +// from the sender, as well as determining whether an idle connection +// is still functional. +// See http://http2.github.io/http2-spec/#rfc.section.6.7 +type http2PingFrame struct { + http2FrameHeader + Data [8]byte +} + +func (f *http2PingFrame) IsAck() bool { return f.Flags.Has(http2FlagPingAck) } + +func http2parsePingFrame(_ *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error) { + if len(payload) != 8 { + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + if fh.StreamID != 0 { + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + f := &http2PingFrame{http2FrameHeader: fh} + copy(f.Data[:], payload) + return f, nil +} + +func (f *http2Framer) WritePing(ack bool, data [8]byte) error { + var flags http2Flags + if ack { + flags = http2FlagPingAck + } + f.startWrite(http2FramePing, flags, 0) + f.writeBytes(data[:]) + return f.endWrite() +} + +// A GoAwayFrame informs the remote peer to stop creating streams on this connection. +// See http://http2.github.io/http2-spec/#rfc.section.6.8 +type http2GoAwayFrame struct { + http2FrameHeader + LastStreamID uint32 + ErrCode http2ErrCode + debugData []byte +} + +// DebugData returns any debug data in the GOAWAY frame. Its contents +// are not defined. +// The caller must not retain the returned memory past the next +// call to ReadFrame. +func (f *http2GoAwayFrame) DebugData() []byte { + f.checkValid() + return f.debugData +} + +func http2parseGoAwayFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { + if fh.StreamID != 0 { + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + if len(p) < 8 { + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + return &http2GoAwayFrame{ + http2FrameHeader: fh, + LastStreamID: binary.BigEndian.Uint32(p[:4]) & (1<<31 - 1), + ErrCode: http2ErrCode(binary.BigEndian.Uint32(p[4:8])), + debugData: p[8:], + }, nil +} + +func (f *http2Framer) WriteGoAway(maxStreamID uint32, code http2ErrCode, debugData []byte) error { + f.startWrite(http2FrameGoAway, 0, 0) + f.writeUint32(maxStreamID & (1<<31 - 1)) + f.writeUint32(uint32(code)) + f.writeBytes(debugData) + return f.endWrite() +} + +// An UnknownFrame is the frame type returned when the frame type is unknown +// or no specific frame type parser exists. +type http2UnknownFrame struct { + http2FrameHeader + p []byte +} + +// Payload returns the frame's payload (after the header). It is not +// valid to call this method after a subsequent call to +// Framer.ReadFrame, nor is it valid to retain the returned slice. +// The memory is owned by the Framer and is invalidated when the next +// frame is read. +func (f *http2UnknownFrame) Payload() []byte { + f.checkValid() + return f.p +} + +func http2parseUnknownFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { + return &http2UnknownFrame{fh, p}, nil +} + +// A WindowUpdateFrame is used to implement flow control. +// See http://http2.github.io/http2-spec/#rfc.section.6.9 +type http2WindowUpdateFrame struct { + http2FrameHeader + Increment uint32 // never read with high bit set +} + +func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { + if len(p) != 4 { + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + inc := binary.BigEndian.Uint32(p[:4]) & 0x7fffffff // mask off high reserved bit + if inc == 0 { + // A receiver MUST treat the receipt of a + // WINDOW_UPDATE frame with an flow control window + // increment of 0 as a stream error (Section 5.4.2) of + // type PROTOCOL_ERROR; errors on the connection flow + // control window MUST be treated as a connection + // error (Section 5.4.1). + if fh.StreamID == 0 { + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol) + } + return &http2WindowUpdateFrame{ + http2FrameHeader: fh, + Increment: inc, + }, nil +} + +// WriteWindowUpdate writes a WINDOW_UPDATE frame. +// The increment value must be between 1 and 2,147,483,647, inclusive. +// If the Stream ID is zero, the window update applies to the +// connection as a whole. +func (f *http2Framer) WriteWindowUpdate(streamID, incr uint32) error { + // "The legal range for the increment to the flow control window is 1 to 2^31-1 (2,147,483,647) octets." + if (incr < 1 || incr > 2147483647) && !f.AllowIllegalWrites { + return errors.New("illegal window increment value") + } + f.startWrite(http2FrameWindowUpdate, 0, streamID) + f.writeUint32(incr) + return f.endWrite() +} + +// A HeadersFrame is used to open a stream and additionally carries a +// header block fragment. +type http2HeadersFrame struct { + http2FrameHeader + + // Priority is set if FlagHeadersPriority is set in the FrameHeader. + Priority http2PriorityParam + + headerFragBuf []byte // not owned +} + +func (f *http2HeadersFrame) HeaderBlockFragment() []byte { + f.checkValid() + return f.headerFragBuf +} + +func (f *http2HeadersFrame) HeadersEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagHeadersEndHeaders) +} + +func (f *http2HeadersFrame) StreamEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagHeadersEndStream) +} + +func (f *http2HeadersFrame) HasPriority() bool { + return f.http2FrameHeader.Flags.Has(http2FlagHeadersPriority) +} + +func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (_ http2Frame, err error) { + hf := &http2HeadersFrame{ + http2FrameHeader: fh, + } + if fh.StreamID == 0 { + // HEADERS frames MUST be associated with a stream. If a HEADERS frame + // is received whose stream identifier field is 0x0, the recipient MUST + // respond with a connection error (Section 5.4.1) of type + // PROTOCOL_ERROR. + return nil, http2connError{http2ErrCodeProtocol, "HEADERS frame with stream ID 0"} + } + var padLength uint8 + if fh.Flags.Has(http2FlagHeadersPadded) { + if p, padLength, err = http2readByte(p); err != nil { + return + } + } + if fh.Flags.Has(http2FlagHeadersPriority) { + var v uint32 + p, v, err = http2readUint32(p) + if err != nil { + return nil, err + } + hf.Priority.StreamDep = v & 0x7fffffff + hf.Priority.Exclusive = (v != hf.Priority.StreamDep) // high bit was set + p, hf.Priority.Weight, err = http2readByte(p) + if err != nil { + return nil, err + } + } + if len(p)-int(padLength) <= 0 { + return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol) + } + hf.headerFragBuf = p[:len(p)-int(padLength)] + return hf, nil +} + +// HeadersFrameParam are the parameters for writing a HEADERS frame. +type http2HeadersFrameParam struct { + // StreamID is the required Stream ID to initiate. + StreamID uint32 + // BlockFragment is part (or all) of a Header Block. + BlockFragment []byte + + // EndStream indicates that the header block is the last that + // the endpoint will send for the identified stream. Setting + // this flag causes the stream to enter one of "half closed" + // states. + EndStream bool + + // EndHeaders indicates that this frame contains an entire + // header block and is not followed by any + // CONTINUATION frames. + EndHeaders bool + + // PadLength is the optional number of bytes of zeros to add + // to this frame. + PadLength uint8 + + // Priority, if non-zero, includes stream priority information + // in the HEADER frame. + Priority http2PriorityParam +} + +// WriteHeaders writes a single HEADERS frame. +// +// This is a low-level header writing method. Encoding headers and +// splitting them into any necessary CONTINUATION frames is handled +// elsewhere. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WriteHeaders(p http2HeadersFrameParam) error { + if !http2validStreamID(p.StreamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + var flags http2Flags + if p.PadLength != 0 { + flags |= http2FlagHeadersPadded + } + if p.EndStream { + flags |= http2FlagHeadersEndStream + } + if p.EndHeaders { + flags |= http2FlagHeadersEndHeaders + } + if !p.Priority.IsZero() { + flags |= http2FlagHeadersPriority + } + f.startWrite(http2FrameHeaders, flags, p.StreamID) + if p.PadLength != 0 { + f.writeByte(p.PadLength) + } + if !p.Priority.IsZero() { + v := p.Priority.StreamDep + if !http2validStreamIDOrZero(v) && !f.AllowIllegalWrites { + return http2errDepStreamID + } + if p.Priority.Exclusive { + v |= 1 << 31 + } + f.writeUint32(v) + f.writeByte(p.Priority.Weight) + } + f.wbuf = append(f.wbuf, p.BlockFragment...) + f.wbuf = append(f.wbuf, http2padZeros[:p.PadLength]...) + return f.endWrite() +} + +// A PriorityFrame specifies the sender-advised priority of a stream. +// See http://http2.github.io/http2-spec/#rfc.section.6.3 +type http2PriorityFrame struct { + http2FrameHeader + http2PriorityParam +} + +// PriorityParam are the stream prioritzation parameters. +type http2PriorityParam struct { + // StreamDep is a 31-bit stream identifier for the + // stream that this stream depends on. Zero means no + // dependency. + StreamDep uint32 + + // Exclusive is whether the dependency is exclusive. + Exclusive bool + + // Weight is the stream's zero-indexed weight. It should be + // set together with StreamDep, or neither should be set. Per + // the spec, "Add one to the value to obtain a weight between + // 1 and 256." + Weight uint8 +} + +func (p http2PriorityParam) IsZero() bool { + return p == http2PriorityParam{} +} + +func http2parsePriorityFrame(_ *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error) { + if fh.StreamID == 0 { + return nil, http2connError{http2ErrCodeProtocol, "PRIORITY frame with stream ID 0"} + } + if len(payload) != 5 { + return nil, http2connError{http2ErrCodeFrameSize, fmt.Sprintf("PRIORITY frame payload size was %d; want 5", len(payload))} + } + v := binary.BigEndian.Uint32(payload[:4]) + streamID := v & 0x7fffffff // mask off high bit + return &http2PriorityFrame{ + http2FrameHeader: fh, + http2PriorityParam: http2PriorityParam{ + Weight: payload[4], + StreamDep: streamID, + Exclusive: streamID != v, // was high bit set? + }, + }, nil +} + +// WritePriority writes a PRIORITY frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WritePriority(streamID uint32, p http2PriorityParam) error { + if !http2validStreamID(streamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + if !http2validStreamIDOrZero(p.StreamDep) { + return http2errDepStreamID + } + f.startWrite(http2FramePriority, 0, streamID) + v := p.StreamDep + if p.Exclusive { + v |= 1 << 31 + } + f.writeUint32(v) + f.writeByte(p.Weight) + return f.endWrite() +} + +// A RSTStreamFrame allows for abnormal termination of a stream. +// See http://http2.github.io/http2-spec/#rfc.section.6.4 +type http2RSTStreamFrame struct { + http2FrameHeader + ErrCode http2ErrCode +} + +func http2parseRSTStreamFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { + if len(p) != 4 { + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + if fh.StreamID == 0 { + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + return &http2RSTStreamFrame{fh, http2ErrCode(binary.BigEndian.Uint32(p[:4]))}, nil +} + +// WriteRSTStream writes a RST_STREAM frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WriteRSTStream(streamID uint32, code http2ErrCode) error { + if !http2validStreamID(streamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + f.startWrite(http2FrameRSTStream, 0, streamID) + f.writeUint32(uint32(code)) + return f.endWrite() +} + +// A ContinuationFrame is used to continue a sequence of header block fragments. +// See http://http2.github.io/http2-spec/#rfc.section.6.10 +type http2ContinuationFrame struct { + http2FrameHeader + headerFragBuf []byte +} + +func http2parseContinuationFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { + if fh.StreamID == 0 { + return nil, http2connError{http2ErrCodeProtocol, "CONTINUATION frame with stream ID 0"} + } + return &http2ContinuationFrame{fh, p}, nil +} + +func (f *http2ContinuationFrame) HeaderBlockFragment() []byte { + f.checkValid() + return f.headerFragBuf +} + +func (f *http2ContinuationFrame) HeadersEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagContinuationEndHeaders) +} + +// WriteContinuation writes a CONTINUATION frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WriteContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) error { + if !http2validStreamID(streamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + var flags http2Flags + if endHeaders { + flags |= http2FlagContinuationEndHeaders + } + f.startWrite(http2FrameContinuation, flags, streamID) + f.wbuf = append(f.wbuf, headerBlockFragment...) + return f.endWrite() +} + +// A PushPromiseFrame is used to initiate a server stream. +// See http://http2.github.io/http2-spec/#rfc.section.6.6 +type http2PushPromiseFrame struct { + http2FrameHeader + PromiseID uint32 + headerFragBuf []byte // not owned +} + +func (f *http2PushPromiseFrame) HeaderBlockFragment() []byte { + f.checkValid() + return f.headerFragBuf +} + +func (f *http2PushPromiseFrame) HeadersEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagPushPromiseEndHeaders) +} + +func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, p []byte) (_ http2Frame, err error) { + pp := &http2PushPromiseFrame{ + http2FrameHeader: fh, + } + if pp.StreamID == 0 { + // PUSH_PROMISE frames MUST be associated with an existing, + // peer-initiated stream. The stream identifier of a + // PUSH_PROMISE frame indicates the stream it is associated + // with. If the stream identifier field specifies the value + // 0x0, a recipient MUST respond with a connection error + // (Section 5.4.1) of type PROTOCOL_ERROR. + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + // The PUSH_PROMISE frame includes optional padding. + // Padding fields and flags are identical to those defined for DATA frames + var padLength uint8 + if fh.Flags.Has(http2FlagPushPromisePadded) { + if p, padLength, err = http2readByte(p); err != nil { + return + } + } + + p, pp.PromiseID, err = http2readUint32(p) + if err != nil { + return + } + pp.PromiseID = pp.PromiseID & (1<<31 - 1) + + if int(padLength) > len(p) { + // like the DATA frame, error out if padding is longer than the body. + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + pp.headerFragBuf = p[:len(p)-int(padLength)] + return pp, nil +} + +// PushPromiseParam are the parameters for writing a PUSH_PROMISE frame. +type http2PushPromiseParam struct { + // StreamID is the required Stream ID to initiate. + StreamID uint32 + + // PromiseID is the required Stream ID which this + // Push Promises + PromiseID uint32 + + // BlockFragment is part (or all) of a Header Block. + BlockFragment []byte + + // EndHeaders indicates that this frame contains an entire + // header block and is not followed by any + // CONTINUATION frames. + EndHeaders bool + + // PadLength is the optional number of bytes of zeros to add + // to this frame. + PadLength uint8 +} + +// WritePushPromise writes a single PushPromise Frame. +// +// As with Header Frames, This is the low level call for writing +// individual frames. Continuation frames are handled elsewhere. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WritePushPromise(p http2PushPromiseParam) error { + if !http2validStreamID(p.StreamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + var flags http2Flags + if p.PadLength != 0 { + flags |= http2FlagPushPromisePadded + } + if p.EndHeaders { + flags |= http2FlagPushPromiseEndHeaders + } + f.startWrite(http2FramePushPromise, flags, p.StreamID) + if p.PadLength != 0 { + f.writeByte(p.PadLength) + } + if !http2validStreamID(p.PromiseID) && !f.AllowIllegalWrites { + return http2errStreamID + } + f.writeUint32(p.PromiseID) + f.wbuf = append(f.wbuf, p.BlockFragment...) + f.wbuf = append(f.wbuf, http2padZeros[:p.PadLength]...) + return f.endWrite() +} + +// WriteRawFrame writes a raw frame. This can be used to write +// extension frames unknown to this package. +func (f *http2Framer) WriteRawFrame(t http2FrameType, flags http2Flags, streamID uint32, payload []byte) error { + f.startWrite(t, flags, streamID) + f.writeBytes(payload) + return f.endWrite() +} + +func http2readByte(p []byte) (remain []byte, b byte, err error) { + if len(p) == 0 { + return nil, 0, io.ErrUnexpectedEOF + } + return p[1:], p[0], nil +} + +func http2readUint32(p []byte) (remain []byte, v uint32, err error) { + if len(p) < 4 { + return nil, 0, io.ErrUnexpectedEOF + } + return p[4:], binary.BigEndian.Uint32(p[:4]), nil +} + +type http2streamEnder interface { + StreamEnded() bool +} + +type http2headersEnder interface { + HeadersEnded() bool +} + +type http2headersOrContinuation interface { + http2headersEnder + HeaderBlockFragment() []byte +} + +// A MetaHeadersFrame is the representation of one HEADERS frame and +// zero or more contiguous CONTINUATION frames and the decoding of +// their HPACK-encoded contents. +// +// This type of frame does not appear on the wire and is only returned +// by the Framer when Framer.ReadMetaHeaders is set. +type http2MetaHeadersFrame struct { + *http2HeadersFrame + + // Fields are the fields contained in the HEADERS and + // CONTINUATION frames. The underlying slice is owned by the + // Framer and must not be retained after the next call to + // ReadFrame. + // + // Fields are guaranteed to be in the correct http2 order and + // not have unknown pseudo header fields or invalid header + // field names or values. Required pseudo header fields may be + // missing, however. Use the MetaHeadersFrame.Pseudo accessor + // method access pseudo headers. + Fields []hpack.HeaderField + + // Truncated is whether the max header list size limit was hit + // and Fields is incomplete. The hpack decoder state is still + // valid, however. + Truncated bool +} + +// PseudoValue returns the given pseudo header field's value. +// The provided pseudo field should not contain the leading colon. +func (mh *http2MetaHeadersFrame) PseudoValue(pseudo string) string { + for _, hf := range mh.Fields { + if !hf.IsPseudo() { + return "" + } + if hf.Name[1:] == pseudo { + return hf.Value + } + } + return "" +} + +// RegularFields returns the regular (non-pseudo) header fields of mh. +// The caller does not own the returned slice. +func (mh *http2MetaHeadersFrame) RegularFields() []hpack.HeaderField { + for i, hf := range mh.Fields { + if !hf.IsPseudo() { + return mh.Fields[i:] + } + } + return nil +} + +// PseudoFields returns the pseudo header fields of mh. +// The caller does not own the returned slice. +func (mh *http2MetaHeadersFrame) PseudoFields() []hpack.HeaderField { + for i, hf := range mh.Fields { + if !hf.IsPseudo() { + return mh.Fields[:i] + } + } + return mh.Fields +} + +func (mh *http2MetaHeadersFrame) checkPseudos() error { + var isRequest, isResponse bool + pf := mh.PseudoFields() + for i, hf := range pf { + switch hf.Name { + case ":method", ":path", ":scheme", ":authority": + isRequest = true + case ":status": + isResponse = true + default: + return http2pseudoHeaderError(hf.Name) + } + // Check for duplicates. + // This would be a bad algorithm, but N is 4. + // And this doesn't allocate. + for _, hf2 := range pf[:i] { + if hf.Name == hf2.Name { + return http2duplicatePseudoHeaderError(hf.Name) + } + } + } + if isRequest && isResponse { + return http2errMixPseudoHeaderTypes + } + return nil +} + +func (fr *http2Framer) maxHeaderStringLen() int { + v := fr.maxHeaderListSize() + if uint32(int(v)) == v { + return int(v) + } + // They had a crazy big number for MaxHeaderBytes anyway, + // so give them unlimited header lengths: + return 0 +} + +// readMetaFrame returns 0 or more CONTINUATION frames from fr and +// merge them into the provided hf and returns a MetaHeadersFrame +// with the decoded hpack values. +func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFrame, error) { + if fr.AllowIllegalReads { + return nil, errors.New("illegal use of AllowIllegalReads with ReadMetaHeaders") + } + mh := &http2MetaHeadersFrame{ + http2HeadersFrame: hf, + } + var remainSize = fr.maxHeaderListSize() + var sawRegular bool + + var invalid error // pseudo header field errors + hdec := fr.ReadMetaHeaders + hdec.SetEmitEnabled(true) + hdec.SetMaxStringLength(fr.maxHeaderStringLen()) + hdec.SetEmitFunc(func(hf hpack.HeaderField) { + if http2VerboseLogs && fr.logReads { + fr.debugReadLoggerf("http2: decoded hpack field %+v", hf) + } + if !httpguts.ValidHeaderFieldValue(hf.Value) { + invalid = http2headerFieldValueError(hf.Value) + } + isPseudo := strings.HasPrefix(hf.Name, ":") + if isPseudo { + if sawRegular { + invalid = http2errPseudoAfterRegular + } + } else { + sawRegular = true + if !http2validWireHeaderFieldName(hf.Name) { + invalid = http2headerFieldNameError(hf.Name) + } + } + + if invalid != nil { + hdec.SetEmitEnabled(false) + return + } + + size := hf.Size() + if size > remainSize { + hdec.SetEmitEnabled(false) + mh.Truncated = true + return + } + remainSize -= size + + mh.Fields = append(mh.Fields, hf) + }) + // Lose reference to MetaHeadersFrame: + defer hdec.SetEmitFunc(func(hf hpack.HeaderField) {}) + + var hc http2headersOrContinuation = hf + for { + frag := hc.HeaderBlockFragment() + if _, err := hdec.Write(frag); err != nil { + return nil, http2ConnectionError(http2ErrCodeCompression) + } + + if hc.HeadersEnded() { + break + } + if f, err := fr.ReadFrame(); err != nil { + return nil, err + } else { + hc = f.(*http2ContinuationFrame) // guaranteed by checkFrameOrder + } + } + + mh.http2HeadersFrame.headerFragBuf = nil + mh.http2HeadersFrame.invalidate() + + if err := hdec.Close(); err != nil { + return nil, http2ConnectionError(http2ErrCodeCompression) + } + if invalid != nil { + fr.errDetail = invalid + if http2VerboseLogs { + log.Printf("http2: invalid header: %v", invalid) + } + return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol, invalid} + } + if err := mh.checkPseudos(); err != nil { + fr.errDetail = err + if http2VerboseLogs { + log.Printf("http2: invalid pseudo headers: %v", err) + } + return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol, err} + } + return mh, nil +} + +func http2summarizeFrame(f http2Frame) string { + var buf bytes.Buffer + f.Header().writeDebug(&buf) + switch f := f.(type) { + case *http2SettingsFrame: + n := 0 + f.ForeachSetting(func(s http2Setting) error { + n++ + if n == 1 { + buf.WriteString(", settings:") + } + fmt.Fprintf(&buf, " %v=%v,", s.ID, s.Val) + return nil + }) + if n > 0 { + buf.Truncate(buf.Len() - 1) // remove trailing comma + } + case *http2DataFrame: + data := f.Data() + const max = 256 + if len(data) > max { + data = data[:max] + } + fmt.Fprintf(&buf, " data=%q", data) + if len(f.Data()) > max { + fmt.Fprintf(&buf, " (%d bytes omitted)", len(f.Data())-max) + } + case *http2WindowUpdateFrame: + if f.StreamID == 0 { + buf.WriteString(" (conn)") + } + fmt.Fprintf(&buf, " incr=%v", f.Increment) + case *http2PingFrame: + fmt.Fprintf(&buf, " ping=%q", f.Data[:]) + case *http2GoAwayFrame: + fmt.Fprintf(&buf, " LastStreamID=%v ErrCode=%v Debug=%q", + f.LastStreamID, f.ErrCode, f.debugData) + case *http2RSTStreamFrame: + fmt.Fprintf(&buf, " ErrCode=%v", f.ErrCode) + } + return buf.String() +} + +func http2traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool { + return trace != nil && trace.WroteHeaderField != nil +} + +func http2traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) { + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField(k, []string{v}) + } +} + +func http2traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error { + if trace != nil { + return trace.Got1xxResponse + } + return nil +} + +var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" + +type http2goroutineLock uint64 + +func http2newGoroutineLock() http2goroutineLock { + if !http2DebugGoroutines { + return 0 + } + return http2goroutineLock(http2curGoroutineID()) +} + +func (g http2goroutineLock) check() { + if !http2DebugGoroutines { + return + } + if http2curGoroutineID() != uint64(g) { + panic("running on the wrong goroutine") + } +} + +func (g http2goroutineLock) checkNotOn() { + if !http2DebugGoroutines { + return + } + if http2curGoroutineID() == uint64(g) { + panic("running on the wrong goroutine") + } +} + +var http2goroutineSpace = []byte("goroutine ") + +func http2curGoroutineID() uint64 { + bp := http2littleBuf.Get().(*[]byte) + defer http2littleBuf.Put(bp) + b := *bp + b = b[:runtime.Stack(b, false)] + // Parse the 4707 out of "goroutine 4707 [" + b = bytes.TrimPrefix(b, http2goroutineSpace) + i := bytes.IndexByte(b, ' ') + if i < 0 { + panic(fmt.Sprintf("No space found in %q", b)) + } + b = b[:i] + n, err := http2parseUintBytes(b, 10, 64) + if err != nil { + panic(fmt.Sprintf("Failed to parse goroutine ID out of %q: %v", b, err)) + } + return n +} + +var http2littleBuf = sync.Pool{ + New: func() interface{} { + buf := make([]byte, 64) + return &buf + }, +} + +// parseUintBytes is like strconv.ParseUint, but using a []byte. +func http2parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) { + var cutoff, maxVal uint64 + + if bitSize == 0 { + bitSize = int(strconv.IntSize) + } + + s0 := s + switch { + case len(s) < 1: + err = strconv.ErrSyntax + goto Error + + case 2 <= base && base <= 36: + // valid base; nothing to do + + case base == 0: + // Look for octal, hex prefix. + switch { + case s[0] == '0' && len(s) > 1 && (s[1] == 'x' || s[1] == 'X'): + base = 16 + s = s[2:] + if len(s) < 1 { + err = strconv.ErrSyntax + goto Error + } + case s[0] == '0': + base = 8 + default: + base = 10 + } + + default: + err = errors.New("invalid base " + strconv.Itoa(base)) + goto Error + } + + n = 0 + cutoff = http2cutoff64(base) + maxVal = 1<= base { + n = 0 + err = strconv.ErrSyntax + goto Error + } + + if n >= cutoff { + // n*base overflows + n = 1<<64 - 1 + err = strconv.ErrRange + goto Error + } + n *= uint64(base) + + n1 := n + uint64(v) + if n1 < n || n1 > maxVal { + // n+v overflows + n = 1<<64 - 1 + err = strconv.ErrRange + goto Error + } + n = n1 + } + + return n, nil + +Error: + return n, &strconv.NumError{Func: "ParseUint", Num: string(s0), Err: err} +} + +// Return the first number n such that n*base >= 1<<64. +func http2cutoff64(base int) uint64 { + if base < 2 { + return 0 + } + return (1<<64-1)/uint64(base) + 1 +} + +var ( + http2commonBuildOnce sync.Once + http2commonLowerHeader map[string]string // Go-Canonical-Case -> lower-case + http2commonCanonHeader map[string]string // lower-case -> Go-Canonical-Case +) + +func http2buildCommonHeaderMapsOnce() { + http2commonBuildOnce.Do(http2buildCommonHeaderMaps) +} + +func http2buildCommonHeaderMaps() { + common := []string{ + "accept", + "accept-charset", + "accept-encoding", + "accept-language", + "accept-ranges", + "age", + "access-control-allow-origin", + "allow", + "authorization", + "cache-control", + "content-disposition", + "content-encoding", + "content-language", + "content-length", + "content-location", + "content-range", + "content-type", + "cookie", + "date", + "etag", + "expect", + "expires", + "from", + "host", + "if-match", + "if-modified-since", + "if-none-match", + "if-unmodified-since", + "last-modified", + "link", + "location", + "max-forwards", + "proxy-authenticate", + "proxy-authorization", + "range", + "referer", + "refresh", + "retry-after", + "server", + "set-cookie", + "strict-transport-security", + "trailer", + "transfer-encoding", + "user-agent", + "vary", + "via", + "www-authenticate", + } + http2commonLowerHeader = make(map[string]string, len(common)) + http2commonCanonHeader = make(map[string]string, len(common)) + for _, v := range common { + chk := CanonicalHeaderKey(v) + http2commonLowerHeader[chk] = v + http2commonCanonHeader[v] = chk + } +} + +func http2lowerHeader(v string) string { + http2buildCommonHeaderMapsOnce() + if s, ok := http2commonLowerHeader[v]; ok { + return s + } + return strings.ToLower(v) +} + +var ( + http2VerboseLogs bool + http2logFrameWrites bool + http2logFrameReads bool + http2inTests bool +) + +func init() { + e := os.Getenv("GODEBUG") + if strings.Contains(e, "http2debug=1") { + http2VerboseLogs = true + } + if strings.Contains(e, "http2debug=2") { + http2VerboseLogs = true + http2logFrameWrites = true + http2logFrameReads = true + } +} + +const ( + // ClientPreface is the string that must be sent by new + // connections from clients. + http2ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + + // SETTINGS_MAX_FRAME_SIZE default + // http://http2.github.io/http2-spec/#rfc.section.6.5.2 + http2initialMaxFrameSize = 16384 + + // NextProtoTLS is the NPN/ALPN protocol negotiated during + // HTTP/2's TLS setup. + http2NextProtoTLS = "h2" + + // http://http2.github.io/http2-spec/#SettingValues + http2initialHeaderTableSize = 4096 + + http2initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size + + http2defaultMaxReadFrameSize = 1 << 20 +) + +var ( + http2clientPreface = []byte(http2ClientPreface) +) + +type http2streamState int + +// HTTP/2 stream states. +// +// See http://tools.ietf.org/html/rfc7540#section-5.1. +// +// For simplicity, the server code merges "reserved (local)" into +// "half-closed (remote)". This is one less state transition to track. +// The only downside is that we send PUSH_PROMISEs slightly less +// liberally than allowable. More discussion here: +// https://lists.w3.org/Archives/Public/ietf-http-wg/2016JulSep/0599.html +// +// "reserved (remote)" is omitted since the client code does not +// support server push. +const ( + http2stateIdle http2streamState = iota + http2stateOpen + http2stateHalfClosedLocal + http2stateHalfClosedRemote + http2stateClosed +) + +var http2stateName = [...]string{ + http2stateIdle: "Idle", + http2stateOpen: "Open", + http2stateHalfClosedLocal: "HalfClosedLocal", + http2stateHalfClosedRemote: "HalfClosedRemote", + http2stateClosed: "Closed", +} + +func (st http2streamState) String() string { + return http2stateName[st] +} + +// Setting is a setting parameter: which setting it is, and its value. +type http2Setting struct { + // ID is which setting is being set. + // See http://http2.github.io/http2-spec/#SettingValues + ID http2SettingID + + // Val is the value. + Val uint32 +} + +func (s http2Setting) String() string { + return fmt.Sprintf("[%v = %d]", s.ID, s.Val) +} + +// Valid reports whether the setting is valid. +func (s http2Setting) Valid() error { + // Limits and error codes from 6.5.2 Defined SETTINGS Parameters + switch s.ID { + case http2SettingEnablePush: + if s.Val != 1 && s.Val != 0 { + return http2ConnectionError(http2ErrCodeProtocol) + } + case http2SettingInitialWindowSize: + if s.Val > 1<<31-1 { + return http2ConnectionError(http2ErrCodeFlowControl) + } + case http2SettingMaxFrameSize: + if s.Val < 16384 || s.Val > 1<<24-1 { + return http2ConnectionError(http2ErrCodeProtocol) + } + } + return nil +} + +// A SettingID is an HTTP/2 setting as defined in +// http://http2.github.io/http2-spec/#iana-settings +type http2SettingID uint16 + +const ( + http2SettingHeaderTableSize http2SettingID = 0x1 + http2SettingEnablePush http2SettingID = 0x2 + http2SettingMaxConcurrentStreams http2SettingID = 0x3 + http2SettingInitialWindowSize http2SettingID = 0x4 + http2SettingMaxFrameSize http2SettingID = 0x5 + http2SettingMaxHeaderListSize http2SettingID = 0x6 +) + +var http2settingName = map[http2SettingID]string{ + http2SettingHeaderTableSize: "HEADER_TABLE_SIZE", + http2SettingEnablePush: "ENABLE_PUSH", + http2SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS", + http2SettingInitialWindowSize: "INITIAL_WINDOW_SIZE", + http2SettingMaxFrameSize: "MAX_FRAME_SIZE", + http2SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE", +} + +func (s http2SettingID) String() string { + if v, ok := http2settingName[s]; ok { + return v + } + return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s)) +} + +// validWireHeaderFieldName reports whether v is a valid header field +// name (key). See httpguts.ValidHeaderName for the base rules. +// +// Further, http2 says: +// "Just as in HTTP/1.x, header field names are strings of ASCII +// characters that are compared in a case-insensitive +// fashion. However, header field names MUST be converted to +// lowercase prior to their encoding in HTTP/2. " +func http2validWireHeaderFieldName(v string) bool { + if len(v) == 0 { + return false + } + for _, r := range v { + if !httpguts.IsTokenRune(r) { + return false + } + if 'A' <= r && r <= 'Z' { + return false + } + } + return true +} + +func http2httpCodeString(code int) string { + switch code { + case 200: + return "200" + case 404: + return "404" + } + return strconv.Itoa(code) +} + +// from pkg io +type http2stringWriter interface { + WriteString(s string) (n int, err error) +} + +// A gate lets two goroutines coordinate their activities. +type http2gate chan struct{} + +func (g http2gate) Done() { g <- struct{}{} } + +func (g http2gate) Wait() { <-g } + +// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed). +type http2closeWaiter chan struct{} + +// Init makes a closeWaiter usable. +// It exists because so a closeWaiter value can be placed inside a +// larger struct and have the Mutex and Cond's memory in the same +// allocation. +func (cw *http2closeWaiter) Init() { + *cw = make(chan struct{}) +} + +// Close marks the closeWaiter as closed and unblocks any waiters. +func (cw http2closeWaiter) Close() { + close(cw) +} + +// Wait waits for the closeWaiter to become closed. +func (cw http2closeWaiter) Wait() { + <-cw +} + +// bufferedWriter is a buffered writer that writes to w. +// Its buffered writer is lazily allocated as needed, to minimize +// idle memory usage with many connections. +type http2bufferedWriter struct { + _ http2incomparable + w io.Writer // immutable + bw *bufio.Writer // non-nil when data is buffered +} + +func http2newBufferedWriter(w io.Writer) *http2bufferedWriter { + return &http2bufferedWriter{w: w} +} + +// bufWriterPoolBufferSize is the size of bufio.Writer's +// buffers created using bufWriterPool. +// +// TODO: pick a less arbitrary value? this is a bit under +// (3 x typical 1500 byte MTU) at least. Other than that, +// not much thought went into it. +const http2bufWriterPoolBufferSize = 4 << 10 + +var http2bufWriterPool = sync.Pool{ + New: func() interface{} { + return bufio.NewWriterSize(nil, http2bufWriterPoolBufferSize) + }, +} + +func (w *http2bufferedWriter) Available() int { + if w.bw == nil { + return http2bufWriterPoolBufferSize + } + return w.bw.Available() +} + +func (w *http2bufferedWriter) Write(p []byte) (n int, err error) { + if w.bw == nil { + bw := http2bufWriterPool.Get().(*bufio.Writer) + bw.Reset(w.w) + w.bw = bw + } + return w.bw.Write(p) +} + +func (w *http2bufferedWriter) Flush() error { + bw := w.bw + if bw == nil { + return nil + } + err := bw.Flush() + bw.Reset(nil) + http2bufWriterPool.Put(bw) + w.bw = nil + return err +} + +func http2mustUint31(v int32) uint32 { + if v < 0 || v > 2147483647 { + panic("out of range") + } + return uint32(v) +} + +// bodyAllowedForStatus reports whether a given response status code +// permits a body. See RFC 7230, section 3.3. +func http2bodyAllowedForStatus(status int) bool { + switch { + case status >= 100 && status <= 199: + return false + case status == 204: + return false + case status == 304: + return false + } + return true +} + +type http2httpError struct { + _ http2incomparable + msg string + timeout bool +} + +func (e *http2httpError) Error() string { return e.msg } + +func (e *http2httpError) Timeout() bool { return e.timeout } + +func (e *http2httpError) Temporary() bool { return true } + +var http2errTimeout error = &http2httpError{msg: "http2: timeout awaiting response headers", timeout: true} + +type http2connectionStater interface { + ConnectionState() tls.ConnectionState +} + +var http2sorterPool = sync.Pool{New: func() interface{} { return new(http2sorter) }} + +type http2sorter struct { + v []string // owned by sorter +} + +func (s *http2sorter) Len() int { return len(s.v) } + +func (s *http2sorter) Swap(i, j int) { s.v[i], s.v[j] = s.v[j], s.v[i] } + +func (s *http2sorter) Less(i, j int) bool { return s.v[i] < s.v[j] } + +// Keys returns the sorted keys of h. +// +// The returned slice is only valid until s used again or returned to +// its pool. +func (s *http2sorter) Keys(h Header) []string { + keys := s.v[:0] + for k := range h { + keys = append(keys, k) + } + s.v = keys + sort.Sort(s) + return keys +} + +func (s *http2sorter) SortStrings(ss []string) { + // Our sorter works on s.v, which sorter owns, so + // stash it away while we sort the user's buffer. + save := s.v + s.v = ss + sort.Sort(s) + s.v = save +} + +// validPseudoPath reports whether v is a valid :path pseudo-header +// value. It must be either: +// +// *) a non-empty string starting with '/' +// *) the string '*', for OPTIONS requests. +// +// For now this is only used a quick check for deciding when to clean +// up Opaque URLs before sending requests from the Transport. +// See golang.org/issue/16847 +// +// We used to enforce that the path also didn't start with "//", but +// Google's GFE accepts such paths and Chrome sends them, so ignore +// that part of the spec. See golang.org/issue/19103. +func http2validPseudoPath(v string) bool { + return (len(v) > 0 && v[0] == '/') || v == "*" +} + +// incomparable is a zero-width, non-comparable type. Adding it to a struct +// makes that struct also non-comparable, and generally doesn't add +// any size (as long as it's first). +type http2incomparable [0]func() + +// pipe is a goroutine-safe io.Reader/io.Writer pair. It's like +// io.Pipe except there are no PipeReader/PipeWriter halves, and the +// underlying buffer is an interface. (io.Pipe is always unbuffered) +type http2pipe struct { + mu sync.Mutex + c sync.Cond // c.L lazily initialized to &p.mu + b http2pipeBuffer // nil when done reading + unread int // bytes unread when done + err error // read error once empty. non-nil means closed. + breakErr error // immediate read error (caller doesn't see rest of b) + donec chan struct{} // closed on error + readFn func() // optional code to run in Read before error +} + +type http2pipeBuffer interface { + Len() int + io.Writer + io.Reader +} + +func (p *http2pipe) Len() int { + p.mu.Lock() + defer p.mu.Unlock() + if p.b == nil { + return p.unread + } + return p.b.Len() +} + +// Read waits until data is available and copies bytes +// from the buffer into p. +func (p *http2pipe) Read(d []byte) (n int, err error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.c.L == nil { + p.c.L = &p.mu + } + for { + if p.breakErr != nil { + return 0, p.breakErr + } + if p.b != nil && p.b.Len() > 0 { + return p.b.Read(d) + } + if p.err != nil { + if p.readFn != nil { + p.readFn() // e.g. copy trailers + p.readFn = nil // not sticky like p.err + } + p.b = nil + return 0, p.err + } + p.c.Wait() + } +} + +var http2errClosedPipeWrite = errors.New("write on closed buffer") + +// Write copies bytes from p into the buffer and wakes a reader. +// It is an error to write more data than the buffer can hold. +func (p *http2pipe) Write(d []byte) (n int, err error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.c.L == nil { + p.c.L = &p.mu + } + defer p.c.Signal() + if p.err != nil { + return 0, http2errClosedPipeWrite + } + if p.breakErr != nil { + p.unread += len(d) + return len(d), nil // discard when there is no reader + } + return p.b.Write(d) +} + +// CloseWithError causes the next Read (waking up a current blocked +// Read if needed) to return the provided err after all data has been +// read. +// +// The error must be non-nil. +func (p *http2pipe) CloseWithError(err error) { p.closeWithError(&p.err, err, nil) } + +// BreakWithError causes the next Read (waking up a current blocked +// Read if needed) to return the provided err immediately, without +// waiting for unread data. +func (p *http2pipe) BreakWithError(err error) { p.closeWithError(&p.breakErr, err, nil) } + +// closeWithErrorAndCode is like CloseWithError but also sets some code to run +// in the caller's goroutine before returning the error. +func (p *http2pipe) closeWithErrorAndCode(err error, fn func()) { p.closeWithError(&p.err, err, fn) } + +func (p *http2pipe) closeWithError(dst *error, err error, fn func()) { + if err == nil { + panic("err must be non-nil") + } + p.mu.Lock() + defer p.mu.Unlock() + if p.c.L == nil { + p.c.L = &p.mu + } + defer p.c.Signal() + if *dst != nil { + // Already been done. + return + } + p.readFn = fn + if dst == &p.breakErr { + if p.b != nil { + p.unread += p.b.Len() + } + p.b = nil + } + *dst = err + p.closeDoneLocked() +} + +// requires p.mu be held. +func (p *http2pipe) closeDoneLocked() { + if p.donec == nil { + return + } + // Close if unclosed. This isn't racy since we always + // hold p.mu while closing. + select { + case <-p.donec: + default: + close(p.donec) + } +} + +// Err returns the error (if any) first set by BreakWithError or CloseWithError. +func (p *http2pipe) Err() error { + p.mu.Lock() + defer p.mu.Unlock() + if p.breakErr != nil { + return p.breakErr + } + return p.err +} + +// Done returns a channel which is closed if and when this pipe is closed +// with CloseWithError. +func (p *http2pipe) Done() <-chan struct{} { + p.mu.Lock() + defer p.mu.Unlock() + if p.donec == nil { + p.donec = make(chan struct{}) + if p.err != nil || p.breakErr != nil { + // Already hit an error. + p.closeDoneLocked() + } + } + return p.donec +} + +const ( + http2prefaceTimeout = 10 * time.Second + http2firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway + http2handlerChunkWriteSize = 4 << 10 + http2defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to? + http2maxQueuedControlFrames = 10000 +) + +var ( + http2errClientDisconnected = errors.New("client disconnected") + http2errClosedBody = errors.New("body closed by handler") + http2errHandlerComplete = errors.New("http2: request body closed due to handler exiting") + http2errStreamClosed = errors.New("http2: stream closed") +) + +var http2responseWriterStatePool = sync.Pool{ + New: func() interface{} { + rws := &http2responseWriterState{} + rws.bw = bufio.NewWriterSize(http2chunkWriter{rws}, http2handlerChunkWriteSize) + return rws + }, +} + +// Test hooks. +var ( + http2testHookOnConn func() + http2testHookGetServerConn func(*http2serverConn) + http2testHookOnPanicMu *sync.Mutex // nil except in tests + http2testHookOnPanic func(sc *http2serverConn, panicVal interface{}) (rePanic bool) +) + +// Server is an HTTP/2 server. +type http2Server struct { + // MaxHandlers limits the number of http.Handler ServeHTTP goroutines + // which may run at a time over all connections. + // Negative or zero no limit. + // TODO: implement + MaxHandlers int + + // MaxConcurrentStreams optionally specifies the number of + // concurrent streams that each client may have open at a + // time. This is unrelated to the number of http.Handler goroutines + // which may be active globally, which is MaxHandlers. + // If zero, MaxConcurrentStreams defaults to at least 100, per + // the HTTP/2 spec's recommendations. + MaxConcurrentStreams uint32 + + // MaxReadFrameSize optionally specifies the largest frame + // this server is willing to read. A valid value is between + // 16k and 16M, inclusive. If zero or otherwise invalid, a + // default value is used. + MaxReadFrameSize uint32 + + // PermitProhibitedCipherSuites, if true, permits the use of + // cipher suites prohibited by the HTTP/2 spec. + PermitProhibitedCipherSuites bool + + // IdleTimeout specifies how long until idle clients should be + // closed with a GOAWAY frame. PING frames are not considered + // activity for the purposes of IdleTimeout. + IdleTimeout time.Duration + + // MaxUploadBufferPerConnection is the size of the initial flow + // control window for each connections. The HTTP/2 spec does not + // allow this to be smaller than 65535 or larger than 2^32-1. + // If the value is outside this range, a default value will be + // used instead. + MaxUploadBufferPerConnection int32 + + // MaxUploadBufferPerStream is the size of the initial flow control + // window for each stream. The HTTP/2 spec does not allow this to + // be larger than 2^32-1. If the value is zero or larger than the + // maximum, a default value will be used instead. + MaxUploadBufferPerStream int32 + + // NewWriteScheduler constructs a write scheduler for a connection. + // If nil, a default scheduler is chosen. + NewWriteScheduler func() http2WriteScheduler + + // Internal state. This is a pointer (rather than embedded directly) + // so that we don't embed a Mutex in this struct, which will make the + // struct non-copyable, which might break some callers. + state *http2serverInternalState +} + +func (s *http2Server) initialConnRecvWindowSize() int32 { + if s.MaxUploadBufferPerConnection > http2initialWindowSize { + return s.MaxUploadBufferPerConnection + } + return 1 << 20 +} + +func (s *http2Server) initialStreamRecvWindowSize() int32 { + if s.MaxUploadBufferPerStream > 0 { + return s.MaxUploadBufferPerStream + } + return 1 << 20 +} + +func (s *http2Server) maxReadFrameSize() uint32 { + if v := s.MaxReadFrameSize; v >= http2minMaxFrameSize && v <= http2maxFrameSize { + return v + } + return http2defaultMaxReadFrameSize +} + +func (s *http2Server) maxConcurrentStreams() uint32 { + if v := s.MaxConcurrentStreams; v > 0 { + return v + } + return http2defaultMaxStreams +} + +// maxQueuedControlFrames is the maximum number of control frames like +// SETTINGS, PING and RST_STREAM that will be queued for writing before +// the connection is closed to prevent memory exhaustion attacks. +func (s *http2Server) maxQueuedControlFrames() int { + // TODO: if anybody asks, add a Server field, and remember to define the + // behavior of negative values. + return http2maxQueuedControlFrames +} + +type http2serverInternalState struct { + mu sync.Mutex + activeConns map[*http2serverConn]struct{} +} + +func (s *http2serverInternalState) registerConn(sc *http2serverConn) { + if s == nil { + return // if the Server was used without calling ConfigureServer + } + s.mu.Lock() + s.activeConns[sc] = struct{}{} + s.mu.Unlock() +} + +func (s *http2serverInternalState) unregisterConn(sc *http2serverConn) { + if s == nil { + return // if the Server was used without calling ConfigureServer + } + s.mu.Lock() + delete(s.activeConns, sc) + s.mu.Unlock() +} + +func (s *http2serverInternalState) startGracefulShutdown() { + if s == nil { + return // if the Server was used without calling ConfigureServer + } + s.mu.Lock() + for sc := range s.activeConns { + sc.startGracefulShutdown() + } + s.mu.Unlock() +} + +// ConfigureServer adds HTTP/2 support to a net/http Server. +// +// The configuration conf may be nil. +// +// ConfigureServer must be called before s begins serving. +func http2ConfigureServer(s *Server, conf *http2Server) error { + if s == nil { + panic("nil *http.Server") + } + if conf == nil { + conf = new(http2Server) + } + conf.state = &http2serverInternalState{activeConns: make(map[*http2serverConn]struct{})} + if h1, h2 := s, conf; h2.IdleTimeout == 0 { + if h1.IdleTimeout != 0 { + h2.IdleTimeout = h1.IdleTimeout + } else { + h2.IdleTimeout = h1.ReadTimeout + } + } + s.RegisterOnShutdown(conf.state.startGracefulShutdown) + + if s.TLSConfig == nil { + s.TLSConfig = new(tls.Config) + } else if s.TLSConfig.CipherSuites != nil { + // If they already provided a CipherSuite list, return + // an error if it has a bad order or is missing + // ECDHE_RSA_WITH_AES_128_GCM_SHA256 or ECDHE_ECDSA_WITH_AES_128_GCM_SHA256. + haveRequired := false + sawBad := false + for i, cs := range s.TLSConfig.CipherSuites { + switch cs { + case tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + // Alternative MTI cipher to not discourage ECDSA-only servers. + // See http://golang.org/cl/30721 for further information. + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: + haveRequired = true + } + if http2isBadCipher(cs) { + sawBad = true + } else if sawBad { + return fmt.Errorf("http2: TLSConfig.CipherSuites index %d contains an HTTP/2-approved cipher suite (%#04x), but it comes after unapproved cipher suites. With this configuration, clients that don't support previous, approved cipher suites may be given an unapproved one and reject the connection.", i, cs) + } + } + if !haveRequired { + return fmt.Errorf("http2: TLSConfig.CipherSuites is missing an HTTP/2-required AES_128_GCM_SHA256 cipher (need at least one of TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 or TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256).") + } + } + + // Note: not setting MinVersion to tls.VersionTLS12, + // as we don't want to interfere with HTTP/1.1 traffic + // on the user's server. We enforce TLS 1.2 later once + // we accept a connection. Ideally this should be done + // during next-proto selection, but using TLS <1.2 with + // HTTP/2 is still the client's bug. + + s.TLSConfig.PreferServerCipherSuites = true + + haveNPN := false + for _, p := range s.TLSConfig.NextProtos { + if p == http2NextProtoTLS { + haveNPN = true + break + } + } + if !haveNPN { + s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, http2NextProtoTLS) + } + + if s.TLSNextProto == nil { + s.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){} + } + protoHandler := func(hs *Server, c *tls.Conn, h Handler) { + if http2testHookOnConn != nil { + http2testHookOnConn() + } + // The TLSNextProto interface predates contexts, so + // the net/http package passes down its per-connection + // base context via an exported but unadvertised + // method on the Handler. This is for internal + // net/http<=>http2 use only. + var ctx context.Context + type baseContexter interface { + BaseContext() context.Context + } + if bc, ok := h.(baseContexter); ok { + ctx = bc.BaseContext() + } + conf.ServeConn(c, &http2ServeConnOpts{ + Context: ctx, + Handler: h, + BaseConfig: hs, + }) + } + s.TLSNextProto[http2NextProtoTLS] = protoHandler + return nil +} + +// ServeConnOpts are options for the Server.ServeConn method. +type http2ServeConnOpts struct { + // Context is the base context to use. + // If nil, context.Background is used. + Context context.Context + + // BaseConfig optionally sets the base configuration + // for values. If nil, defaults are used. + BaseConfig *Server + + // Handler specifies which handler to use for processing + // requests. If nil, BaseConfig.Handler is used. If BaseConfig + // or BaseConfig.Handler is nil, http.DefaultServeMux is used. + Handler Handler +} + +func (o *http2ServeConnOpts) context() context.Context { + if o != nil && o.Context != nil { + return o.Context + } + return context.Background() +} + +func (o *http2ServeConnOpts) baseConfig() *Server { + if o != nil && o.BaseConfig != nil { + return o.BaseConfig + } + return new(Server) +} + +func (o *http2ServeConnOpts) handler() Handler { + if o != nil { + if o.Handler != nil { + return o.Handler + } + if o.BaseConfig != nil && o.BaseConfig.Handler != nil { + return o.BaseConfig.Handler + } + } + return DefaultServeMux +} + +// ServeConn serves HTTP/2 requests on the provided connection and +// blocks until the connection is no longer readable. +// +// ServeConn starts speaking HTTP/2 assuming that c has not had any +// reads or writes. It writes its initial settings frame and expects +// to be able to read the preface and settings frame from the +// client. If c has a ConnectionState method like a *tls.Conn, the +// ConnectionState is used to verify the TLS ciphersuite and to set +// the Request.TLS field in Handlers. +// +// ServeConn does not support h2c by itself. Any h2c support must be +// implemented in terms of providing a suitably-behaving net.Conn. +// +// The opts parameter is optional. If nil, default values are used. +func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { + baseCtx, cancel := http2serverConnBaseContext(c, opts) + defer cancel() + + sc := &http2serverConn{ + srv: s, + hs: opts.baseConfig(), + conn: c, + baseCtx: baseCtx, + remoteAddrStr: c.RemoteAddr().String(), + bw: http2newBufferedWriter(c), + handler: opts.handler(), + streams: make(map[uint32]*http2stream), + readFrameCh: make(chan http2readFrameResult), + wantWriteFrameCh: make(chan http2FrameWriteRequest, 8), + serveMsgCh: make(chan interface{}, 8), + wroteFrameCh: make(chan http2frameWriteResult, 1), // buffered; one send in writeFrameAsync + bodyReadCh: make(chan http2bodyReadMsg), // buffering doesn't matter either way + doneServing: make(chan struct{}), + clientMaxStreams: math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value" + advMaxStreams: s.maxConcurrentStreams(), + initialStreamSendWindowSize: http2initialWindowSize, + maxFrameSize: http2initialMaxFrameSize, + headerTableSize: http2initialHeaderTableSize, + serveG: http2newGoroutineLock(), + pushEnabled: true, + } + + s.state.registerConn(sc) + defer s.state.unregisterConn(sc) + + // The net/http package sets the write deadline from the + // http.Server.WriteTimeout during the TLS handshake, but then + // passes the connection off to us with the deadline already set. + // Write deadlines are set per stream in serverConn.newStream. + // Disarm the net.Conn write deadline here. + if sc.hs.WriteTimeout != 0 { + sc.conn.SetWriteDeadline(time.Time{}) + } + + if s.NewWriteScheduler != nil { + sc.writeSched = s.NewWriteScheduler() + } else { + sc.writeSched = http2NewRandomWriteScheduler() + } + + // These start at the RFC-specified defaults. If there is a higher + // configured value for inflow, that will be updated when we send a + // WINDOW_UPDATE shortly after sending SETTINGS. + sc.flow.add(http2initialWindowSize) + sc.inflow.add(http2initialWindowSize) + sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) + + fr := http2NewFramer(sc.bw, c) + fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) + fr.MaxHeaderListSize = sc.maxHeaderListSize() + fr.SetMaxReadFrameSize(s.maxReadFrameSize()) + sc.framer = fr + + if tc, ok := c.(http2connectionStater); ok { + sc.tlsState = new(tls.ConnectionState) + *sc.tlsState = tc.ConnectionState() + // 9.2 Use of TLS Features + // An implementation of HTTP/2 over TLS MUST use TLS + // 1.2 or higher with the restrictions on feature set + // and cipher suite described in this section. Due to + // implementation limitations, it might not be + // possible to fail TLS negotiation. An endpoint MUST + // immediately terminate an HTTP/2 connection that + // does not meet the TLS requirements described in + // this section with a connection error (Section + // 5.4.1) of type INADEQUATE_SECURITY. + if sc.tlsState.Version < tls.VersionTLS12 { + sc.rejectConn(http2ErrCodeInadequateSecurity, "TLS version too low") + return + } + + if sc.tlsState.ServerName == "" { + // Client must use SNI, but we don't enforce that anymore, + // since it was causing problems when connecting to bare IP + // addresses during development. + // + // TODO: optionally enforce? Or enforce at the time we receive + // a new request, and verify the ServerName matches the :authority? + // But that precludes proxy situations, perhaps. + // + // So for now, do nothing here again. + } + + if !s.PermitProhibitedCipherSuites && http2isBadCipher(sc.tlsState.CipherSuite) { + // "Endpoints MAY choose to generate a connection error + // (Section 5.4.1) of type INADEQUATE_SECURITY if one of + // the prohibited cipher suites are negotiated." + // + // We choose that. In my opinion, the spec is weak + // here. It also says both parties must support at least + // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 so there's no + // excuses here. If we really must, we could allow an + // "AllowInsecureWeakCiphers" option on the server later. + // Let's see how it plays out first. + sc.rejectConn(http2ErrCodeInadequateSecurity, fmt.Sprintf("Prohibited TLS 1.2 Cipher Suite: %x", sc.tlsState.CipherSuite)) + return + } + } + + if hook := http2testHookGetServerConn; hook != nil { + hook(sc) + } + sc.serve() +} + +func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx context.Context, cancel func()) { + ctx, cancel = context.WithCancel(opts.context()) + ctx = context.WithValue(ctx, LocalAddrContextKey, c.LocalAddr()) + if hs := opts.baseConfig(); hs != nil { + ctx = context.WithValue(ctx, ServerContextKey, hs) + } + return +} + +func (sc *http2serverConn) rejectConn(err http2ErrCode, debug string) { + sc.vlogf("http2: server rejecting conn: %v, %s", err, debug) + // ignoring errors. hanging up anyway. + sc.framer.WriteGoAway(0, err, []byte(debug)) + sc.bw.Flush() + sc.conn.Close() +} + +type http2serverConn struct { + // Immutable: + srv *http2Server + hs *Server + conn net.Conn + bw *http2bufferedWriter // writing to conn + handler Handler + baseCtx context.Context + framer *http2Framer + doneServing chan struct{} // closed when serverConn.serve ends + readFrameCh chan http2readFrameResult // written by serverConn.readFrames + wantWriteFrameCh chan http2FrameWriteRequest // from handlers -> serve + wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes + bodyReadCh chan http2bodyReadMsg // from handlers -> serve + serveMsgCh chan interface{} // misc messages & code to send to / run on the serve loop + flow http2flow // conn-wide (not stream-specific) outbound flow control + inflow http2flow // conn-wide inbound flow control + tlsState *tls.ConnectionState // shared by all handlers, like net/http + remoteAddrStr string + writeSched http2WriteScheduler + + // Everything following is owned by the serve loop; use serveG.check(): + serveG http2goroutineLock // used to verify funcs are on serve() + pushEnabled bool + sawFirstSettings bool // got the initial SETTINGS frame after the preface + needToSendSettingsAck bool + unackedSettings int // how many SETTINGS have we sent without ACKs? + queuedControlFrames int // control frames in the writeSched queue + clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit) + advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client + curClientStreams uint32 // number of open streams initiated by the client + curPushedStreams uint32 // number of open streams initiated by server push + maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests + maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes + streams map[uint32]*http2stream + initialStreamSendWindowSize int32 + maxFrameSize int32 + headerTableSize uint32 + peerMaxHeaderListSize uint32 // zero means unknown (default) + canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case + writingFrame bool // started writing a frame (on serve goroutine or separate) + writingFrameAsync bool // started a frame on its own goroutine but haven't heard back on wroteFrameCh + needsFrameFlush bool // last frame write wasn't a flush + inGoAway bool // we've started to or sent GOAWAY + inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop + needToSendGoAway bool // we need to schedule a GOAWAY frame write + goAwayCode http2ErrCode + shutdownTimer *time.Timer // nil until used + idleTimer *time.Timer // nil if unused + + // Owned by the writeFrameAsync goroutine: + headerWriteBuf bytes.Buffer + hpackEncoder *hpack.Encoder + + // Used by startGracefulShutdown. + shutdownOnce sync.Once +} + +func (sc *http2serverConn) maxHeaderListSize() uint32 { + n := sc.hs.MaxHeaderBytes + if n <= 0 { + n = DefaultMaxHeaderBytes + } + // http2's count is in a slightly different unit and includes 32 bytes per pair. + // So, take the net/http.Server value and pad it up a bit, assuming 10 headers. + const perFieldOverhead = 32 // per http2 spec + const typicalHeaders = 10 // conservative + return uint32(n + typicalHeaders*perFieldOverhead) +} + +func (sc *http2serverConn) curOpenStreams() uint32 { + sc.serveG.check() + return sc.curClientStreams + sc.curPushedStreams +} + +// stream represents a stream. This is the minimal metadata needed by +// the serve goroutine. Most of the actual stream state is owned by +// the http.Handler's goroutine in the responseWriter. Because the +// responseWriter's responseWriterState is recycled at the end of a +// handler, this struct intentionally has no pointer to the +// *responseWriter{,State} itself, as the Handler ending nils out the +// responseWriter's state field. +type http2stream struct { + // immutable: + sc *http2serverConn + id uint32 + body *http2pipe // non-nil if expecting DATA frames + cw http2closeWaiter // closed wait stream transitions to closed state + ctx context.Context + cancelCtx func() + + // owned by serverConn's serve loop: + bodyBytes int64 // body bytes seen so far + declBodyBytes int64 // or -1 if undeclared + flow http2flow // limits writing from Handler to client + inflow http2flow // what the client is allowed to POST/etc to us + state http2streamState + resetQueued bool // RST_STREAM queued for write; set by sc.resetStream + gotTrailerHeader bool // HEADER frame for trailers was seen + wroteHeaders bool // whether we wrote headers (not status 100) + writeDeadline *time.Timer // nil if unused + + trailer Header // accumulated trailers + reqTrailer Header // handler's Request.Trailer +} + +func (sc *http2serverConn) Framer() *http2Framer { return sc.framer } + +func (sc *http2serverConn) CloseConn() error { return sc.conn.Close() } + +func (sc *http2serverConn) Flush() error { return sc.bw.Flush() } + +func (sc *http2serverConn) HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) { + return sc.hpackEncoder, &sc.headerWriteBuf +} + +func (sc *http2serverConn) state(streamID uint32) (http2streamState, *http2stream) { + sc.serveG.check() + // http://tools.ietf.org/html/rfc7540#section-5.1 + if st, ok := sc.streams[streamID]; ok { + return st.state, st + } + // "The first use of a new stream identifier implicitly closes all + // streams in the "idle" state that might have been initiated by + // that peer with a lower-valued stream identifier. For example, if + // a client sends a HEADERS frame on stream 7 without ever sending a + // frame on stream 5, then stream 5 transitions to the "closed" + // state when the first frame for stream 7 is sent or received." + if streamID%2 == 1 { + if streamID <= sc.maxClientStreamID { + return http2stateClosed, nil + } + } else { + if streamID <= sc.maxPushPromiseID { + return http2stateClosed, nil + } + } + return http2stateIdle, nil +} + +// setConnState calls the net/http ConnState hook for this connection, if configured. +// Note that the net/http package does StateNew and StateClosed for us. +// There is currently no plan for StateHijacked or hijacking HTTP/2 connections. +func (sc *http2serverConn) setConnState(state ConnState) { + if sc.hs.ConnState != nil { + sc.hs.ConnState(sc.conn, state) + } +} + +func (sc *http2serverConn) vlogf(format string, args ...interface{}) { + if http2VerboseLogs { + sc.logf(format, args...) + } +} + +func (sc *http2serverConn) logf(format string, args ...interface{}) { + if lg := sc.hs.ErrorLog; lg != nil { + lg.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + +// errno returns v's underlying uintptr, else 0. +// +// TODO: remove this helper function once http2 can use build +// tags. See comment in isClosedConnError. +func http2errno(v error) uintptr { + if rv := reflect.ValueOf(v); rv.Kind() == reflect.Uintptr { + return uintptr(rv.Uint()) + } + return 0 +} + +// isClosedConnError reports whether err is an error from use of a closed +// network connection. +func http2isClosedConnError(err error) bool { + if err == nil { + return false + } + + // TODO: remove this string search and be more like the Windows + // case below. That might involve modifying the standard library + // to return better error types. + str := err.Error() + if strings.Contains(str, "use of closed network connection") { + return true + } + + // TODO(bradfitz): x/tools/cmd/bundle doesn't really support + // build tags, so I can't make an http2_windows.go file with + // Windows-specific stuff. Fix that and move this, once we + // have a way to bundle this into std's net/http somehow. + if runtime.GOOS == "windows" { + if oe, ok := err.(*net.OpError); ok && oe.Op == "read" { + if se, ok := oe.Err.(*os.SyscallError); ok && se.Syscall == "wsarecv" { + const WSAECONNABORTED = 10053 + const WSAECONNRESET = 10054 + if n := http2errno(se.Err); n == WSAECONNRESET || n == WSAECONNABORTED { + return true + } + } + } + } + return false +} + +func (sc *http2serverConn) condlogf(err error, format string, args ...interface{}) { + if err == nil { + return + } + if err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) || err == http2errPrefaceTimeout { + // Boring, expected errors. + sc.vlogf(format, args...) + } else { + sc.logf(format, args...) + } +} + +func (sc *http2serverConn) canonicalHeader(v string) string { + sc.serveG.check() + http2buildCommonHeaderMapsOnce() + cv, ok := http2commonCanonHeader[v] + if ok { + return cv + } + cv, ok = sc.canonHeader[v] + if ok { + return cv + } + if sc.canonHeader == nil { + sc.canonHeader = make(map[string]string) + } + cv = CanonicalHeaderKey(v) + sc.canonHeader[v] = cv + return cv +} + +type http2readFrameResult struct { + f http2Frame // valid until readMore is called + err error + + // readMore should be called once the consumer no longer needs or + // retains f. After readMore, f is invalid and more frames can be + // read. + readMore func() +} + +// readFrames is the loop that reads incoming frames. +// It takes care to only read one frame at a time, blocking until the +// consumer is done with the frame. +// It's run on its own goroutine. +func (sc *http2serverConn) readFrames() { + gate := make(http2gate) + gateDone := gate.Done + for { + f, err := sc.framer.ReadFrame() + select { + case sc.readFrameCh <- http2readFrameResult{f, err, gateDone}: + case <-sc.doneServing: + return + } + select { + case <-gate: + case <-sc.doneServing: + return + } + if http2terminalReadFrameError(err) { + return + } + } +} + +// frameWriteResult is the message passed from writeFrameAsync to the serve goroutine. +type http2frameWriteResult struct { + _ http2incomparable + wr http2FrameWriteRequest // what was written (or attempted) + err error // result of the writeFrame call +} + +// writeFrameAsync runs in its own goroutine and writes a single frame +// and then reports when it's done. +// At most one goroutine can be running writeFrameAsync at a time per +// serverConn. +func (sc *http2serverConn) writeFrameAsync(wr http2FrameWriteRequest) { + err := wr.write.writeFrame(sc) + sc.wroteFrameCh <- http2frameWriteResult{wr: wr, err: err} +} + +func (sc *http2serverConn) closeAllStreamsOnConnClose() { + sc.serveG.check() + for _, st := range sc.streams { + sc.closeStream(st, http2errClientDisconnected) + } +} + +func (sc *http2serverConn) stopShutdownTimer() { + sc.serveG.check() + if t := sc.shutdownTimer; t != nil { + t.Stop() + } +} + +func (sc *http2serverConn) notePanic() { + // Note: this is for serverConn.serve panicking, not http.Handler code. + if http2testHookOnPanicMu != nil { + http2testHookOnPanicMu.Lock() + defer http2testHookOnPanicMu.Unlock() + } + if http2testHookOnPanic != nil { + if e := recover(); e != nil { + if http2testHookOnPanic(sc, e) { + panic(e) + } + } + } +} + +func (sc *http2serverConn) serve() { + sc.serveG.check() + defer sc.notePanic() + defer sc.conn.Close() + defer sc.closeAllStreamsOnConnClose() + defer sc.stopShutdownTimer() + defer close(sc.doneServing) // unblocks handlers trying to send + + if http2VerboseLogs { + sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs) + } + + sc.writeFrame(http2FrameWriteRequest{ + write: http2writeSettings{ + {http2SettingMaxFrameSize, sc.srv.maxReadFrameSize()}, + {http2SettingMaxConcurrentStreams, sc.advMaxStreams}, + {http2SettingMaxHeaderListSize, sc.maxHeaderListSize()}, + {http2SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())}, + }, + }) + sc.unackedSettings++ + + // Each connection starts with intialWindowSize inflow tokens. + // If a higher value is configured, we add more tokens. + if diff := sc.srv.initialConnRecvWindowSize() - http2initialWindowSize; diff > 0 { + sc.sendWindowUpdate(nil, int(diff)) + } + + if err := sc.readPreface(); err != nil { + sc.condlogf(err, "http2: server: error reading preface from client %v: %v", sc.conn.RemoteAddr(), err) + return + } + // Now that we've got the preface, get us out of the + // "StateNew" state. We can't go directly to idle, though. + // Active means we read some data and anticipate a request. We'll + // do another Active when we get a HEADERS frame. + sc.setConnState(StateActive) + sc.setConnState(StateIdle) + + if sc.srv.IdleTimeout != 0 { + sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) + defer sc.idleTimer.Stop() + } + + go sc.readFrames() // closed by defer sc.conn.Close above + + settingsTimer := time.AfterFunc(http2firstSettingsTimeout, sc.onSettingsTimer) + defer settingsTimer.Stop() + + loopNum := 0 + for { + loopNum++ + select { + case wr := <-sc.wantWriteFrameCh: + if se, ok := wr.write.(http2StreamError); ok { + sc.resetStream(se) + break + } + sc.writeFrame(wr) + case res := <-sc.wroteFrameCh: + sc.wroteFrame(res) + case res := <-sc.readFrameCh: + if !sc.processFrameFromReader(res) { + return + } + res.readMore() + if settingsTimer != nil { + settingsTimer.Stop() + settingsTimer = nil + } + case m := <-sc.bodyReadCh: + sc.noteBodyRead(m.st, m.n) + case msg := <-sc.serveMsgCh: + switch v := msg.(type) { + case func(int): + v(loopNum) // for testing + case *http2serverMessage: + switch v { + case http2settingsTimerMsg: + sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr()) + return + case http2idleTimerMsg: + sc.vlogf("connection is idle") + sc.goAway(http2ErrCodeNo) + case http2shutdownTimerMsg: + sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) + return + case http2gracefulShutdownMsg: + sc.startGracefulShutdownInternal() + default: + panic("unknown timer") + } + case *http2startPushRequest: + sc.startPush(v) + default: + panic(fmt.Sprintf("unexpected type %T", v)) + } + } + + // If the peer is causing us to generate a lot of control frames, + // but not reading them from us, assume they are trying to make us + // run out of memory. + if sc.queuedControlFrames > sc.srv.maxQueuedControlFrames() { + sc.vlogf("http2: too many control frames in send queue, closing connection") + return + } + + // Start the shutdown timer after sending a GOAWAY. When sending GOAWAY + // with no error code (graceful shutdown), don't start the timer until + // all open streams have been completed. + sentGoAway := sc.inGoAway && !sc.needToSendGoAway && !sc.writingFrame + gracefulShutdownComplete := sc.goAwayCode == http2ErrCodeNo && sc.curOpenStreams() == 0 + if sentGoAway && sc.shutdownTimer == nil && (sc.goAwayCode != http2ErrCodeNo || gracefulShutdownComplete) { + sc.shutDownIn(http2goAwayTimeout) + } + } +} + +func (sc *http2serverConn) awaitGracefulShutdown(sharedCh <-chan struct{}, privateCh chan struct{}) { + select { + case <-sc.doneServing: + case <-sharedCh: + close(privateCh) + } +} + +type http2serverMessage int + +// Message values sent to serveMsgCh. +var ( + http2settingsTimerMsg = new(http2serverMessage) + http2idleTimerMsg = new(http2serverMessage) + http2shutdownTimerMsg = new(http2serverMessage) + http2gracefulShutdownMsg = new(http2serverMessage) +) + +func (sc *http2serverConn) onSettingsTimer() { sc.sendServeMsg(http2settingsTimerMsg) } + +func (sc *http2serverConn) onIdleTimer() { sc.sendServeMsg(http2idleTimerMsg) } + +func (sc *http2serverConn) onShutdownTimer() { sc.sendServeMsg(http2shutdownTimerMsg) } + +func (sc *http2serverConn) sendServeMsg(msg interface{}) { + sc.serveG.checkNotOn() // NOT + select { + case sc.serveMsgCh <- msg: + case <-sc.doneServing: + } +} + +var http2errPrefaceTimeout = errors.New("timeout waiting for client preface") + +// readPreface reads the ClientPreface greeting from the peer or +// returns errPrefaceTimeout on timeout, or an error if the greeting +// is invalid. +func (sc *http2serverConn) readPreface() error { + errc := make(chan error, 1) + go func() { + // Read the client preface + buf := make([]byte, len(http2ClientPreface)) + if _, err := io.ReadFull(sc.conn, buf); err != nil { + errc <- err + } else if !bytes.Equal(buf, http2clientPreface) { + errc <- fmt.Errorf("bogus greeting %q", buf) + } else { + errc <- nil + } + }() + timer := time.NewTimer(http2prefaceTimeout) // TODO: configurable on *Server? + defer timer.Stop() + select { + case <-timer.C: + return http2errPrefaceTimeout + case err := <-errc: + if err == nil { + if http2VerboseLogs { + sc.vlogf("http2: server: client %v said hello", sc.conn.RemoteAddr()) + } + } + return err + } +} + +var http2errChanPool = sync.Pool{ + New: func() interface{} { return make(chan error, 1) }, +} + +var http2writeDataPool = sync.Pool{ + New: func() interface{} { return new(http2writeData) }, +} + +// writeDataFromHandler writes DATA response frames from a handler on +// the given stream. +func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte, endStream bool) error { + ch := http2errChanPool.Get().(chan error) + writeArg := http2writeDataPool.Get().(*http2writeData) + *writeArg = http2writeData{stream.id, data, endStream} + err := sc.writeFrameFromHandler(http2FrameWriteRequest{ + write: writeArg, + stream: stream, + done: ch, + }) + if err != nil { + return err + } + var frameWriteDone bool // the frame write is done (successfully or not) + select { + case err = <-ch: + frameWriteDone = true + case <-sc.doneServing: + return http2errClientDisconnected + case <-stream.cw: + // If both ch and stream.cw were ready (as might + // happen on the final Write after an http.Handler + // ends), prefer the write result. Otherwise this + // might just be us successfully closing the stream. + // The writeFrameAsync and serve goroutines guarantee + // that the ch send will happen before the stream.cw + // close. + select { + case err = <-ch: + frameWriteDone = true + default: + return http2errStreamClosed + } + } + http2errChanPool.Put(ch) + if frameWriteDone { + http2writeDataPool.Put(writeArg) + } + return err +} + +// writeFrameFromHandler sends wr to sc.wantWriteFrameCh, but aborts +// if the connection has gone away. +// +// This must not be run from the serve goroutine itself, else it might +// deadlock writing to sc.wantWriteFrameCh (which is only mildly +// buffered and is read by serve itself). If you're on the serve +// goroutine, call writeFrame instead. +func (sc *http2serverConn) writeFrameFromHandler(wr http2FrameWriteRequest) error { + sc.serveG.checkNotOn() // NOT + select { + case sc.wantWriteFrameCh <- wr: + return nil + case <-sc.doneServing: + // Serve loop is gone. + // Client has closed their connection to the server. + return http2errClientDisconnected + } +} + +// writeFrame schedules a frame to write and sends it if there's nothing +// already being written. +// +// There is no pushback here (the serve goroutine never blocks). It's +// the http.Handlers that block, waiting for their previous frames to +// make it onto the wire +// +// If you're not on the serve goroutine, use writeFrameFromHandler instead. +func (sc *http2serverConn) writeFrame(wr http2FrameWriteRequest) { + sc.serveG.check() + + // If true, wr will not be written and wr.done will not be signaled. + var ignoreWrite bool + + // We are not allowed to write frames on closed streams. RFC 7540 Section + // 5.1.1 says: "An endpoint MUST NOT send frames other than PRIORITY on + // a closed stream." Our server never sends PRIORITY, so that exception + // does not apply. + // + // The serverConn might close an open stream while the stream's handler + // is still running. For example, the server might close a stream when it + // receives bad data from the client. If this happens, the handler might + // attempt to write a frame after the stream has been closed (since the + // handler hasn't yet been notified of the close). In this case, we simply + // ignore the frame. The handler will notice that the stream is closed when + // it waits for the frame to be written. + // + // As an exception to this rule, we allow sending RST_STREAM after close. + // This allows us to immediately reject new streams without tracking any + // state for those streams (except for the queued RST_STREAM frame). This + // may result in duplicate RST_STREAMs in some cases, but the client should + // ignore those. + if wr.StreamID() != 0 { + _, isReset := wr.write.(http2StreamError) + if state, _ := sc.state(wr.StreamID()); state == http2stateClosed && !isReset { + ignoreWrite = true + } + } + + // Don't send a 100-continue response if we've already sent headers. + // See golang.org/issue/14030. + switch wr.write.(type) { + case *http2writeResHeaders: + wr.stream.wroteHeaders = true + case http2write100ContinueHeadersFrame: + if wr.stream.wroteHeaders { + // We do not need to notify wr.done because this frame is + // never written with wr.done != nil. + if wr.done != nil { + panic("wr.done != nil for write100ContinueHeadersFrame") + } + ignoreWrite = true + } + } + + if !ignoreWrite { + if wr.isControl() { + sc.queuedControlFrames++ + // For extra safety, detect wraparounds, which should not happen, + // and pull the plug. + if sc.queuedControlFrames < 0 { + sc.conn.Close() + } + } + sc.writeSched.Push(wr) + } + sc.scheduleFrameWrite() +} + +// startFrameWrite starts a goroutine to write wr (in a separate +// goroutine since that might block on the network), and updates the +// serve goroutine's state about the world, updated from info in wr. +func (sc *http2serverConn) startFrameWrite(wr http2FrameWriteRequest) { + sc.serveG.check() + if sc.writingFrame { + panic("internal error: can only be writing one frame at a time") + } + + st := wr.stream + if st != nil { + switch st.state { + case http2stateHalfClosedLocal: + switch wr.write.(type) { + case http2StreamError, http2handlerPanicRST, http2writeWindowUpdate: + // RFC 7540 Section 5.1 allows sending RST_STREAM, PRIORITY, and WINDOW_UPDATE + // in this state. (We never send PRIORITY from the server, so that is not checked.) + default: + panic(fmt.Sprintf("internal error: attempt to send frame on a half-closed-local stream: %v", wr)) + } + case http2stateClosed: + panic(fmt.Sprintf("internal error: attempt to send frame on a closed stream: %v", wr)) + } + } + if wpp, ok := wr.write.(*http2writePushPromise); ok { + var err error + wpp.promisedID, err = wpp.allocatePromisedID() + if err != nil { + sc.writingFrameAsync = false + wr.replyToWriter(err) + return + } + } + + sc.writingFrame = true + sc.needsFrameFlush = true + if wr.write.staysWithinBuffer(sc.bw.Available()) { + sc.writingFrameAsync = false + err := wr.write.writeFrame(sc) + sc.wroteFrame(http2frameWriteResult{wr: wr, err: err}) + } else { + sc.writingFrameAsync = true + go sc.writeFrameAsync(wr) + } +} + +// errHandlerPanicked is the error given to any callers blocked in a read from +// Request.Body when the main goroutine panics. Since most handlers read in the +// main ServeHTTP goroutine, this will show up rarely. +var http2errHandlerPanicked = errors.New("http2: handler panicked") + +// wroteFrame is called on the serve goroutine with the result of +// whatever happened on writeFrameAsync. +func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { + sc.serveG.check() + if !sc.writingFrame { + panic("internal error: expected to be already writing a frame") + } + sc.writingFrame = false + sc.writingFrameAsync = false + + wr := res.wr + + if http2writeEndsStream(wr.write) { + st := wr.stream + if st == nil { + panic("internal error: expecting non-nil stream") + } + switch st.state { + case http2stateOpen: + // Here we would go to stateHalfClosedLocal in + // theory, but since our handler is done and + // the net/http package provides no mechanism + // for closing a ResponseWriter while still + // reading data (see possible TODO at top of + // this file), we go into closed state here + // anyway, after telling the peer we're + // hanging up on them. We'll transition to + // stateClosed after the RST_STREAM frame is + // written. + st.state = http2stateHalfClosedLocal + // Section 8.1: a server MAY request that the client abort + // transmission of a request without error by sending a + // RST_STREAM with an error code of NO_ERROR after sending + // a complete response. + sc.resetStream(http2streamError(st.id, http2ErrCodeNo)) + case http2stateHalfClosedRemote: + sc.closeStream(st, http2errHandlerComplete) + } + } else { + switch v := wr.write.(type) { + case http2StreamError: + // st may be unknown if the RST_STREAM was generated to reject bad input. + if st, ok := sc.streams[v.StreamID]; ok { + sc.closeStream(st, v) + } + case http2handlerPanicRST: + sc.closeStream(wr.stream, http2errHandlerPanicked) + } + } + + // Reply (if requested) to unblock the ServeHTTP goroutine. + wr.replyToWriter(res.err) + + sc.scheduleFrameWrite() +} + +// scheduleFrameWrite tickles the frame writing scheduler. +// +// If a frame is already being written, nothing happens. This will be called again +// when the frame is done being written. +// +// If a frame isn't being written and we need to send one, the best frame +// to send is selected by writeSched. +// +// If a frame isn't being written and there's nothing else to send, we +// flush the write buffer. +func (sc *http2serverConn) scheduleFrameWrite() { + sc.serveG.check() + if sc.writingFrame || sc.inFrameScheduleLoop { + return + } + sc.inFrameScheduleLoop = true + for !sc.writingFrameAsync { + if sc.needToSendGoAway { + sc.needToSendGoAway = false + sc.startFrameWrite(http2FrameWriteRequest{ + write: &http2writeGoAway{ + maxStreamID: sc.maxClientStreamID, + code: sc.goAwayCode, + }, + }) + continue + } + if sc.needToSendSettingsAck { + sc.needToSendSettingsAck = false + sc.startFrameWrite(http2FrameWriteRequest{write: http2writeSettingsAck{}}) + continue + } + if !sc.inGoAway || sc.goAwayCode == http2ErrCodeNo { + if wr, ok := sc.writeSched.Pop(); ok { + if wr.isControl() { + sc.queuedControlFrames-- + } + sc.startFrameWrite(wr) + continue + } + } + if sc.needsFrameFlush { + sc.startFrameWrite(http2FrameWriteRequest{write: http2flushFrameWriter{}}) + sc.needsFrameFlush = false // after startFrameWrite, since it sets this true + continue + } + break + } + sc.inFrameScheduleLoop = false +} + +// startGracefulShutdown gracefully shuts down a connection. This +// sends GOAWAY with ErrCodeNo to tell the client we're gracefully +// shutting down. The connection isn't closed until all current +// streams are done. +// +// startGracefulShutdown returns immediately; it does not wait until +// the connection has shut down. +func (sc *http2serverConn) startGracefulShutdown() { + sc.serveG.checkNotOn() // NOT + sc.shutdownOnce.Do(func() { sc.sendServeMsg(http2gracefulShutdownMsg) }) +} + +// After sending GOAWAY, the connection will close after goAwayTimeout. +// If we close the connection immediately after sending GOAWAY, there may +// be unsent data in our kernel receive buffer, which will cause the kernel +// to send a TCP RST on close() instead of a FIN. This RST will abort the +// connection immediately, whether or not the client had received the GOAWAY. +// +// Ideally we should delay for at least 1 RTT + epsilon so the client has +// a chance to read the GOAWAY and stop sending messages. Measuring RTT +// is hard, so we approximate with 1 second. See golang.org/issue/18701. +// +// This is a var so it can be shorter in tests, where all requests uses the +// loopback interface making the expected RTT very small. +// +// TODO: configurable? +var http2goAwayTimeout = 1 * time.Second + +func (sc *http2serverConn) startGracefulShutdownInternal() { + sc.goAway(http2ErrCodeNo) +} + +func (sc *http2serverConn) goAway(code http2ErrCode) { + sc.serveG.check() + if sc.inGoAway { + return + } + sc.inGoAway = true + sc.needToSendGoAway = true + sc.goAwayCode = code + sc.scheduleFrameWrite() +} + +func (sc *http2serverConn) shutDownIn(d time.Duration) { + sc.serveG.check() + sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer) +} + +func (sc *http2serverConn) resetStream(se http2StreamError) { + sc.serveG.check() + sc.writeFrame(http2FrameWriteRequest{write: se}) + if st, ok := sc.streams[se.StreamID]; ok { + st.resetQueued = true + } +} + +// processFrameFromReader processes the serve loop's read from readFrameCh from the +// frame-reading goroutine. +// processFrameFromReader returns whether the connection should be kept open. +func (sc *http2serverConn) processFrameFromReader(res http2readFrameResult) bool { + sc.serveG.check() + err := res.err + if err != nil { + if err == http2ErrFrameTooLarge { + sc.goAway(http2ErrCodeFrameSize) + return true // goAway will close the loop + } + clientGone := err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) + if clientGone { + // TODO: could we also get into this state if + // the peer does a half close + // (e.g. CloseWrite) because they're done + // sending frames but they're still wanting + // our open replies? Investigate. + // TODO: add CloseWrite to crypto/tls.Conn first + // so we have a way to test this? I suppose + // just for testing we could have a non-TLS mode. + return false + } + } else { + f := res.f + if http2VerboseLogs { + sc.vlogf("http2: server read frame %v", http2summarizeFrame(f)) + } + err = sc.processFrame(f) + if err == nil { + return true + } + } + + switch ev := err.(type) { + case http2StreamError: + sc.resetStream(ev) + return true + case http2goAwayFlowError: + sc.goAway(http2ErrCodeFlowControl) + return true + case http2ConnectionError: + sc.logf("http2: server connection error from %v: %v", sc.conn.RemoteAddr(), ev) + sc.goAway(http2ErrCode(ev)) + return true // goAway will handle shutdown + default: + if res.err != nil { + sc.vlogf("http2: server closing client connection; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err) + } else { + sc.logf("http2: server closing client connection: %v", err) + } + return false + } +} + +func (sc *http2serverConn) processFrame(f http2Frame) error { + sc.serveG.check() + + // First frame received must be SETTINGS. + if !sc.sawFirstSettings { + if _, ok := f.(*http2SettingsFrame); !ok { + return http2ConnectionError(http2ErrCodeProtocol) + } + sc.sawFirstSettings = true + } + + switch f := f.(type) { + case *http2SettingsFrame: + return sc.processSettings(f) + case *http2MetaHeadersFrame: + return sc.processHeaders(f) + case *http2WindowUpdateFrame: + return sc.processWindowUpdate(f) + case *http2PingFrame: + return sc.processPing(f) + case *http2DataFrame: + return sc.processData(f) + case *http2RSTStreamFrame: + return sc.processResetStream(f) + case *http2PriorityFrame: + return sc.processPriority(f) + case *http2GoAwayFrame: + return sc.processGoAway(f) + case *http2PushPromiseFrame: + // A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE + // frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR. + return http2ConnectionError(http2ErrCodeProtocol) + default: + sc.vlogf("http2: server ignoring frame: %v", f.Header()) + return nil + } +} + +func (sc *http2serverConn) processPing(f *http2PingFrame) error { + sc.serveG.check() + if f.IsAck() { + // 6.7 PING: " An endpoint MUST NOT respond to PING frames + // containing this flag." + return nil + } + if f.StreamID != 0 { + // "PING frames are not associated with any individual + // stream. If a PING frame is received with a stream + // identifier field value other than 0x0, the recipient MUST + // respond with a connection error (Section 5.4.1) of type + // PROTOCOL_ERROR." + return http2ConnectionError(http2ErrCodeProtocol) + } + if sc.inGoAway && sc.goAwayCode != http2ErrCodeNo { + return nil + } + sc.writeFrame(http2FrameWriteRequest{write: http2writePingAck{f}}) + return nil +} + +func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error { + sc.serveG.check() + switch { + case f.StreamID != 0: // stream-level flow control + state, st := sc.state(f.StreamID) + if state == http2stateIdle { + // Section 5.1: "Receiving any frame other than HEADERS + // or PRIORITY on a stream in this state MUST be + // treated as a connection error (Section 5.4.1) of + // type PROTOCOL_ERROR." + return http2ConnectionError(http2ErrCodeProtocol) + } + if st == nil { + // "WINDOW_UPDATE can be sent by a peer that has sent a + // frame bearing the END_STREAM flag. This means that a + // receiver could receive a WINDOW_UPDATE frame on a "half + // closed (remote)" or "closed" stream. A receiver MUST + // NOT treat this as an error, see Section 5.1." + return nil + } + if !st.flow.add(int32(f.Increment)) { + return http2streamError(f.StreamID, http2ErrCodeFlowControl) + } + default: // connection-level flow control + if !sc.flow.add(int32(f.Increment)) { + return http2goAwayFlowError{} + } + } + sc.scheduleFrameWrite() + return nil +} + +func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error { + sc.serveG.check() + + state, st := sc.state(f.StreamID) + if state == http2stateIdle { + // 6.4 "RST_STREAM frames MUST NOT be sent for a + // stream in the "idle" state. If a RST_STREAM frame + // identifying an idle stream is received, the + // recipient MUST treat this as a connection error + // (Section 5.4.1) of type PROTOCOL_ERROR. + return http2ConnectionError(http2ErrCodeProtocol) + } + if st != nil { + st.cancelCtx() + sc.closeStream(st, http2streamError(f.StreamID, f.ErrCode)) + } + return nil +} + +func (sc *http2serverConn) closeStream(st *http2stream, err error) { + sc.serveG.check() + if st.state == http2stateIdle || st.state == http2stateClosed { + panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state)) + } + st.state = http2stateClosed + if st.writeDeadline != nil { + st.writeDeadline.Stop() + } + if st.isPushed() { + sc.curPushedStreams-- + } else { + sc.curClientStreams-- + } + delete(sc.streams, st.id) + if len(sc.streams) == 0 { + sc.setConnState(StateIdle) + if sc.srv.IdleTimeout != 0 { + sc.idleTimer.Reset(sc.srv.IdleTimeout) + } + if http2h1ServerKeepAlivesDisabled(sc.hs) { + sc.startGracefulShutdownInternal() + } + } + if p := st.body; p != nil { + // Return any buffered unread bytes worth of conn-level flow control. + // See golang.org/issue/16481 + sc.sendWindowUpdate(nil, p.Len()) + + p.CloseWithError(err) + } + st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc + sc.writeSched.CloseStream(st.id) +} + +func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error { + sc.serveG.check() + if f.IsAck() { + sc.unackedSettings-- + if sc.unackedSettings < 0 { + // Why is the peer ACKing settings we never sent? + // The spec doesn't mention this case, but + // hang up on them anyway. + return http2ConnectionError(http2ErrCodeProtocol) + } + return nil + } + if f.NumSettings() > 100 || f.HasDuplicates() { + // This isn't actually in the spec, but hang up on + // suspiciously large settings frames or those with + // duplicate entries. + return http2ConnectionError(http2ErrCodeProtocol) + } + if err := f.ForeachSetting(sc.processSetting); err != nil { + return err + } + // TODO: judging by RFC 7540, Section 6.5.3 each SETTINGS frame should be + // acknowledged individually, even if multiple are received before the ACK. + sc.needToSendSettingsAck = true + sc.scheduleFrameWrite() + return nil +} + +func (sc *http2serverConn) processSetting(s http2Setting) error { + sc.serveG.check() + if err := s.Valid(); err != nil { + return err + } + if http2VerboseLogs { + sc.vlogf("http2: server processing setting %v", s) + } + switch s.ID { + case http2SettingHeaderTableSize: + sc.headerTableSize = s.Val + sc.hpackEncoder.SetMaxDynamicTableSize(s.Val) + case http2SettingEnablePush: + sc.pushEnabled = s.Val != 0 + case http2SettingMaxConcurrentStreams: + sc.clientMaxStreams = s.Val + case http2SettingInitialWindowSize: + return sc.processSettingInitialWindowSize(s.Val) + case http2SettingMaxFrameSize: + sc.maxFrameSize = int32(s.Val) // the maximum valid s.Val is < 2^31 + case http2SettingMaxHeaderListSize: + sc.peerMaxHeaderListSize = s.Val + default: + // Unknown setting: "An endpoint that receives a SETTINGS + // frame with any unknown or unsupported identifier MUST + // ignore that setting." + if http2VerboseLogs { + sc.vlogf("http2: server ignoring unknown setting %v", s) + } + } + return nil +} + +func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error { + sc.serveG.check() + // Note: val already validated to be within range by + // processSetting's Valid call. + + // "A SETTINGS frame can alter the initial flow control window + // size for all current streams. When the value of + // SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST + // adjust the size of all stream flow control windows that it + // maintains by the difference between the new value and the + // old value." + old := sc.initialStreamSendWindowSize + sc.initialStreamSendWindowSize = int32(val) + growth := int32(val) - old // may be negative + for _, st := range sc.streams { + if !st.flow.add(growth) { + // 6.9.2 Initial Flow Control Window Size + // "An endpoint MUST treat a change to + // SETTINGS_INITIAL_WINDOW_SIZE that causes any flow + // control window to exceed the maximum size as a + // connection error (Section 5.4.1) of type + // FLOW_CONTROL_ERROR." + return http2ConnectionError(http2ErrCodeFlowControl) + } + } + return nil +} + +func (sc *http2serverConn) processData(f *http2DataFrame) error { + sc.serveG.check() + if sc.inGoAway && sc.goAwayCode != http2ErrCodeNo { + return nil + } + data := f.Data() + + // "If a DATA frame is received whose stream is not in "open" + // or "half closed (local)" state, the recipient MUST respond + // with a stream error (Section 5.4.2) of type STREAM_CLOSED." + id := f.Header().StreamID + state, st := sc.state(id) + if id == 0 || state == http2stateIdle { + // Section 5.1: "Receiving any frame other than HEADERS + // or PRIORITY on a stream in this state MUST be + // treated as a connection error (Section 5.4.1) of + // type PROTOCOL_ERROR." + return http2ConnectionError(http2ErrCodeProtocol) + } + if st == nil || state != http2stateOpen || st.gotTrailerHeader || st.resetQueued { + // This includes sending a RST_STREAM if the stream is + // in stateHalfClosedLocal (which currently means that + // the http.Handler returned, so it's done reading & + // done writing). Try to stop the client from sending + // more DATA. + + // But still enforce their connection-level flow control, + // and return any flow control bytes since we're not going + // to consume them. + if sc.inflow.available() < int32(f.Length) { + return http2streamError(id, http2ErrCodeFlowControl) + } + // Deduct the flow control from inflow, since we're + // going to immediately add it back in + // sendWindowUpdate, which also schedules sending the + // frames. + sc.inflow.take(int32(f.Length)) + sc.sendWindowUpdate(nil, int(f.Length)) // conn-level + + if st != nil && st.resetQueued { + // Already have a stream error in flight. Don't send another. + return nil + } + return http2streamError(id, http2ErrCodeStreamClosed) + } + if st.body == nil { + panic("internal error: should have a body in this state") + } + + // Sender sending more than they'd declared? + if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes { + st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes)) + // RFC 7540, sec 8.1.2.6: A request or response is also malformed if the + // value of a content-length header field does not equal the sum of the + // DATA frame payload lengths that form the body. + return http2streamError(id, http2ErrCodeProtocol) + } + if f.Length > 0 { + // Check whether the client has flow control quota. + if st.inflow.available() < int32(f.Length) { + return http2streamError(id, http2ErrCodeFlowControl) + } + st.inflow.take(int32(f.Length)) + + if len(data) > 0 { + wrote, err := st.body.Write(data) + if err != nil { + sc.sendWindowUpdate(nil, int(f.Length)-wrote) + return http2streamError(id, http2ErrCodeStreamClosed) + } + if wrote != len(data) { + panic("internal error: bad Writer") + } + st.bodyBytes += int64(len(data)) + } + + // Return any padded flow control now, since we won't + // refund it later on body reads. + if pad := int32(f.Length) - int32(len(data)); pad > 0 { + sc.sendWindowUpdate32(nil, pad) + sc.sendWindowUpdate32(st, pad) + } + } + if f.StreamEnded() { + st.endStream() + } + return nil +} + +func (sc *http2serverConn) processGoAway(f *http2GoAwayFrame) error { + sc.serveG.check() + if f.ErrCode != http2ErrCodeNo { + sc.logf("http2: received GOAWAY %+v, starting graceful shutdown", f) + } else { + sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f) + } + sc.startGracefulShutdownInternal() + // http://tools.ietf.org/html/rfc7540#section-6.8 + // We should not create any new streams, which means we should disable push. + sc.pushEnabled = false + return nil +} + +// isPushed reports whether the stream is server-initiated. +func (st *http2stream) isPushed() bool { + return st.id%2 == 0 +} + +// endStream closes a Request.Body's pipe. It is called when a DATA +// frame says a request body is over (or after trailers). +func (st *http2stream) endStream() { + sc := st.sc + sc.serveG.check() + + if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes { + st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes", + st.declBodyBytes, st.bodyBytes)) + } else { + st.body.closeWithErrorAndCode(io.EOF, st.copyTrailersToHandlerRequest) + st.body.CloseWithError(io.EOF) + } + st.state = http2stateHalfClosedRemote +} + +// copyTrailersToHandlerRequest is run in the Handler's goroutine in +// its Request.Body.Read just before it gets io.EOF. +func (st *http2stream) copyTrailersToHandlerRequest() { + for k, vv := range st.trailer { + if _, ok := st.reqTrailer[k]; ok { + // Only copy it over it was pre-declared. + st.reqTrailer[k] = vv + } + } +} + +// onWriteTimeout is run on its own goroutine (from time.AfterFunc) +// when the stream's WriteTimeout has fired. +func (st *http2stream) onWriteTimeout() { + st.sc.writeFrameFromHandler(http2FrameWriteRequest{write: http2streamError(st.id, http2ErrCodeInternal)}) +} + +func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { + sc.serveG.check() + id := f.StreamID + if sc.inGoAway { + // Ignore. + return nil + } + // http://tools.ietf.org/html/rfc7540#section-5.1.1 + // Streams initiated by a client MUST use odd-numbered stream + // identifiers. [...] An endpoint that receives an unexpected + // stream identifier MUST respond with a connection error + // (Section 5.4.1) of type PROTOCOL_ERROR. + if id%2 != 1 { + return http2ConnectionError(http2ErrCodeProtocol) + } + // A HEADERS frame can be used to create a new stream or + // send a trailer for an open one. If we already have a stream + // open, let it process its own HEADERS frame (trailers at this + // point, if it's valid). + if st := sc.streams[f.StreamID]; st != nil { + if st.resetQueued { + // We're sending RST_STREAM to close the stream, so don't bother + // processing this frame. + return nil + } + // RFC 7540, sec 5.1: If an endpoint receives additional frames, other than + // WINDOW_UPDATE, PRIORITY, or RST_STREAM, for a stream that is in + // this state, it MUST respond with a stream error (Section 5.4.2) of + // type STREAM_CLOSED. + if st.state == http2stateHalfClosedRemote { + return http2streamError(id, http2ErrCodeStreamClosed) + } + return st.processTrailerHeaders(f) + } + + // [...] The identifier of a newly established stream MUST be + // numerically greater than all streams that the initiating + // endpoint has opened or reserved. [...] An endpoint that + // receives an unexpected stream identifier MUST respond with + // a connection error (Section 5.4.1) of type PROTOCOL_ERROR. + if id <= sc.maxClientStreamID { + return http2ConnectionError(http2ErrCodeProtocol) + } + sc.maxClientStreamID = id + + if sc.idleTimer != nil { + sc.idleTimer.Stop() + } + + // http://tools.ietf.org/html/rfc7540#section-5.1.2 + // [...] Endpoints MUST NOT exceed the limit set by their peer. An + // endpoint that receives a HEADERS frame that causes their + // advertised concurrent stream limit to be exceeded MUST treat + // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR + // or REFUSED_STREAM. + if sc.curClientStreams+1 > sc.advMaxStreams { + if sc.unackedSettings == 0 { + // They should know better. + return http2streamError(id, http2ErrCodeProtocol) + } + // Assume it's a network race, where they just haven't + // received our last SETTINGS update. But actually + // this can't happen yet, because we don't yet provide + // a way for users to adjust server parameters at + // runtime. + return http2streamError(id, http2ErrCodeRefusedStream) + } + + initialState := http2stateOpen + if f.StreamEnded() { + initialState = http2stateHalfClosedRemote + } + st := sc.newStream(id, 0, initialState) + + if f.HasPriority() { + if err := http2checkPriority(f.StreamID, f.Priority); err != nil { + return err + } + sc.writeSched.AdjustStream(st.id, f.Priority) + } + + rw, req, err := sc.newWriterAndRequest(st, f) + if err != nil { + return err + } + st.reqTrailer = req.Trailer + if st.reqTrailer != nil { + st.trailer = make(Header) + } + st.body = req.Body.(*http2requestBody).pipe // may be nil + st.declBodyBytes = req.ContentLength + + handler := sc.handler.ServeHTTP + if f.Truncated { + // Their header list was too long. Send a 431 error. + handler = http2handleHeaderListTooLong + } else if err := http2checkValidHTTP2RequestHeaders(req.Header); err != nil { + handler = http2new400Handler(err) + } + + // The net/http package sets the read deadline from the + // http.Server.ReadTimeout during the TLS handshake, but then + // passes the connection off to us with the deadline already + // set. Disarm it here after the request headers are read, + // similar to how the http1 server works. Here it's + // technically more like the http1 Server's ReadHeaderTimeout + // (in Go 1.8), though. That's a more sane option anyway. + if sc.hs.ReadTimeout != 0 { + sc.conn.SetReadDeadline(time.Time{}) + } + + go sc.runHandler(rw, req, handler) + return nil +} + +func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { + sc := st.sc + sc.serveG.check() + if st.gotTrailerHeader { + return http2ConnectionError(http2ErrCodeProtocol) + } + st.gotTrailerHeader = true + if !f.StreamEnded() { + return http2streamError(st.id, http2ErrCodeProtocol) + } + + if len(f.PseudoFields()) > 0 { + return http2streamError(st.id, http2ErrCodeProtocol) + } + if st.trailer != nil { + for _, hf := range f.RegularFields() { + key := sc.canonicalHeader(hf.Name) + if !httpguts.ValidTrailerHeader(key) { + // TODO: send more details to the peer somehow. But http2 has + // no way to send debug data at a stream level. Discuss with + // HTTP folk. + return http2streamError(st.id, http2ErrCodeProtocol) + } + st.trailer[key] = append(st.trailer[key], hf.Value) + } + } + st.endStream() + return nil +} + +func http2checkPriority(streamID uint32, p http2PriorityParam) error { + if streamID == p.StreamDep { + // Section 5.3.1: "A stream cannot depend on itself. An endpoint MUST treat + // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR." + // Section 5.3.3 says that a stream can depend on one of its dependencies, + // so it's only self-dependencies that are forbidden. + return http2streamError(streamID, http2ErrCodeProtocol) + } + return nil +} + +func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error { + if sc.inGoAway { + return nil + } + if err := http2checkPriority(f.StreamID, f.http2PriorityParam); err != nil { + return err + } + sc.writeSched.AdjustStream(f.StreamID, f.http2PriorityParam) + return nil +} + +func (sc *http2serverConn) newStream(id, pusherID uint32, state http2streamState) *http2stream { + sc.serveG.check() + if id == 0 { + panic("internal error: cannot create stream with id 0") + } + + ctx, cancelCtx := context.WithCancel(sc.baseCtx) + st := &http2stream{ + sc: sc, + id: id, + state: state, + ctx: ctx, + cancelCtx: cancelCtx, + } + st.cw.Init() + st.flow.conn = &sc.flow // link to conn-level counter + st.flow.add(sc.initialStreamSendWindowSize) + st.inflow.conn = &sc.inflow // link to conn-level counter + st.inflow.add(sc.srv.initialStreamRecvWindowSize()) + if sc.hs.WriteTimeout != 0 { + st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) + } + + sc.streams[id] = st + sc.writeSched.OpenStream(st.id, http2OpenStreamOptions{PusherID: pusherID}) + if st.isPushed() { + sc.curPushedStreams++ + } else { + sc.curClientStreams++ + } + if sc.curOpenStreams() == 1 { + sc.setConnState(StateActive) + } + + return st +} + +func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHeadersFrame) (*http2responseWriter, *Request, error) { + sc.serveG.check() + + rp := http2requestParam{ + method: f.PseudoValue("method"), + scheme: f.PseudoValue("scheme"), + authority: f.PseudoValue("authority"), + path: f.PseudoValue("path"), + } + + isConnect := rp.method == "CONNECT" + if isConnect { + if rp.path != "" || rp.scheme != "" || rp.authority == "" { + return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) + } + } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { + // See 8.1.2.6 Malformed Requests and Responses: + // + // Malformed requests or responses that are detected + // MUST be treated as a stream error (Section 5.4.2) + // of type PROTOCOL_ERROR." + // + // 8.1.2.3 Request Pseudo-Header Fields + // "All HTTP/2 requests MUST include exactly one valid + // value for the :method, :scheme, and :path + // pseudo-header fields" + return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) + } + + bodyOpen := !f.StreamEnded() + if rp.method == "HEAD" && bodyOpen { + // HEAD requests can't have bodies + return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) + } + + rp.header = make(Header) + for _, hf := range f.RegularFields() { + rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value) + } + if rp.authority == "" { + rp.authority = rp.header.Get("Host") + } + + rw, req, err := sc.newWriterAndRequestNoBody(st, rp) + if err != nil { + return nil, nil, err + } + if bodyOpen { + if vv, ok := rp.header["Content-Length"]; ok { + if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil { + req.ContentLength = int64(cl) + } else { + req.ContentLength = 0 + } + } else { + req.ContentLength = -1 + } + req.Body.(*http2requestBody).pipe = &http2pipe{ + b: &http2dataBuffer{expected: req.ContentLength}, + } + } + return rw, req, nil +} + +type http2requestParam struct { + method string + scheme, authority, path string + header Header +} + +func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2requestParam) (*http2responseWriter, *Request, error) { + sc.serveG.check() + + var tlsState *tls.ConnectionState // nil if not scheme https + if rp.scheme == "https" { + tlsState = sc.tlsState + } + + needsContinue := rp.header.Get("Expect") == "100-continue" + if needsContinue { + rp.header.Del("Expect") + } + // Merge Cookie headers into one "; "-delimited value. + if cookies := rp.header["Cookie"]; len(cookies) > 1 { + rp.header.Set("Cookie", strings.Join(cookies, "; ")) + } + + // Setup Trailers + var trailer Header + for _, v := range rp.header["Trailer"] { + for _, key := range strings.Split(v, ",") { + key = CanonicalHeaderKey(textproto.TrimString(key)) + switch key { + case "Transfer-Encoding", "Trailer", "Content-Length": + // Bogus. (copy of http1 rules) + // Ignore. + default: + if trailer == nil { + trailer = make(Header) + } + trailer[key] = nil + } + } + } + delete(rp.header, "Trailer") + + var url_ *url.URL + var requestURI string + if rp.method == "CONNECT" { + url_ = &url.URL{Host: rp.authority} + requestURI = rp.authority // mimic HTTP/1 server behavior + } else { + var err error + url_, err = url.ParseRequestURI(rp.path) + if err != nil { + return nil, nil, http2streamError(st.id, http2ErrCodeProtocol) + } + requestURI = rp.path + } + + body := &http2requestBody{ + conn: sc, + stream: st, + needsContinue: needsContinue, + } + req := &Request{ + Method: rp.method, + URL: url_, + RemoteAddr: sc.remoteAddrStr, + Header: rp.header, + RequestURI: requestURI, + Proto: "HTTP/2.0", + ProtoMajor: 2, + ProtoMinor: 0, + TLS: tlsState, + Host: rp.authority, + Body: body, + Trailer: trailer, + } + req = req.WithContext(st.ctx) + + rws := http2responseWriterStatePool.Get().(*http2responseWriterState) + bwSave := rws.bw + *rws = http2responseWriterState{} // zero all the fields + rws.conn = sc + rws.bw = bwSave + rws.bw.Reset(http2chunkWriter{rws}) + rws.stream = st + rws.req = req + rws.body = body + + rw := &http2responseWriter{rws: rws} + return rw, req, nil +} + +// Run on its own goroutine. +func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, handler func(ResponseWriter, *Request)) { + didPanic := true + defer func() { + rw.rws.stream.cancelCtx() + if didPanic { + e := recover() + sc.writeFrameFromHandler(http2FrameWriteRequest{ + write: http2handlerPanicRST{rw.rws.stream.id}, + stream: rw.rws.stream, + }) + // Same as net/http: + if e != nil && e != ErrAbortHandler { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf) + } + return + } + rw.handlerDone() + }() + handler(rw, req) + didPanic = false +} + +func http2handleHeaderListTooLong(w ResponseWriter, r *Request) { + // 10.5.1 Limits on Header Block Size: + // .. "A server that receives a larger header block than it is + // willing to handle can send an HTTP 431 (Request Header Fields Too + // Large) status code" + const statusRequestHeaderFieldsTooLarge = 431 // only in Go 1.6+ + w.WriteHeader(statusRequestHeaderFieldsTooLarge) + io.WriteString(w, "

HTTP Error 431

Request Header Field(s) Too Large

") +} + +// called from handler goroutines. +// h may be nil. +func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeResHeaders) error { + sc.serveG.checkNotOn() // NOT on + var errc chan error + if headerData.h != nil { + // If there's a header map (which we don't own), so we have to block on + // waiting for this frame to be written, so an http.Flush mid-handler + // writes out the correct value of keys, before a handler later potentially + // mutates it. + errc = http2errChanPool.Get().(chan error) + } + if err := sc.writeFrameFromHandler(http2FrameWriteRequest{ + write: headerData, + stream: st, + done: errc, + }); err != nil { + return err + } + if errc != nil { + select { + case err := <-errc: + http2errChanPool.Put(errc) + return err + case <-sc.doneServing: + return http2errClientDisconnected + case <-st.cw: + return http2errStreamClosed + } + } + return nil +} + +// called from handler goroutines. +func (sc *http2serverConn) write100ContinueHeaders(st *http2stream) { + sc.writeFrameFromHandler(http2FrameWriteRequest{ + write: http2write100ContinueHeadersFrame{st.id}, + stream: st, + }) +} + +// A bodyReadMsg tells the server loop that the http.Handler read n +// bytes of the DATA from the client on the given stream. +type http2bodyReadMsg struct { + st *http2stream + n int +} + +// called from handler goroutines. +// Notes that the handler for the given stream ID read n bytes of its body +// and schedules flow control tokens to be sent. +func (sc *http2serverConn) noteBodyReadFromHandler(st *http2stream, n int, err error) { + sc.serveG.checkNotOn() // NOT on + if n > 0 { + select { + case sc.bodyReadCh <- http2bodyReadMsg{st, n}: + case <-sc.doneServing: + } + } +} + +func (sc *http2serverConn) noteBodyRead(st *http2stream, n int) { + sc.serveG.check() + sc.sendWindowUpdate(nil, n) // conn-level + if st.state != http2stateHalfClosedRemote && st.state != http2stateClosed { + // Don't send this WINDOW_UPDATE if the stream is closed + // remotely. + sc.sendWindowUpdate(st, n) + } +} + +// st may be nil for conn-level +func (sc *http2serverConn) sendWindowUpdate(st *http2stream, n int) { + sc.serveG.check() + // "The legal range for the increment to the flow control + // window is 1 to 2^31-1 (2,147,483,647) octets." + // A Go Read call on 64-bit machines could in theory read + // a larger Read than this. Very unlikely, but we handle it here + // rather than elsewhere for now. + const maxUint31 = 1<<31 - 1 + for n >= maxUint31 { + sc.sendWindowUpdate32(st, maxUint31) + n -= maxUint31 + } + sc.sendWindowUpdate32(st, int32(n)) +} + +// st may be nil for conn-level +func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) { + sc.serveG.check() + if n == 0 { + return + } + if n < 0 { + panic("negative update") + } + var streamID uint32 + if st != nil { + streamID = st.id + } + sc.writeFrame(http2FrameWriteRequest{ + write: http2writeWindowUpdate{streamID: streamID, n: uint32(n)}, + stream: st, + }) + var ok bool + if st == nil { + ok = sc.inflow.add(n) + } else { + ok = st.inflow.add(n) + } + if !ok { + panic("internal error; sent too many window updates without decrements?") + } +} + +// requestBody is the Handler's Request.Body type. +// Read and Close may be called concurrently. +type http2requestBody struct { + _ http2incomparable + stream *http2stream + conn *http2serverConn + closed bool // for use by Close only + sawEOF bool // for use by Read only + pipe *http2pipe // non-nil if we have a HTTP entity message body + needsContinue bool // need to send a 100-continue +} + +func (b *http2requestBody) Close() error { + if b.pipe != nil && !b.closed { + b.pipe.BreakWithError(http2errClosedBody) + } + b.closed = true + return nil +} + +func (b *http2requestBody) Read(p []byte) (n int, err error) { + if b.needsContinue { + b.needsContinue = false + b.conn.write100ContinueHeaders(b.stream) + } + if b.pipe == nil || b.sawEOF { + return 0, io.EOF + } + n, err = b.pipe.Read(p) + if err == io.EOF { + b.sawEOF = true + } + if b.conn == nil && http2inTests { + return + } + b.conn.noteBodyReadFromHandler(b.stream, n, err) + return +} + +// responseWriter is the http.ResponseWriter implementation. It's +// intentionally small (1 pointer wide) to minimize garbage. The +// responseWriterState pointer inside is zeroed at the end of a +// request (in handlerDone) and calls on the responseWriter thereafter +// simply crash (caller's mistake), but the much larger responseWriterState +// and buffers are reused between multiple requests. +type http2responseWriter struct { + rws *http2responseWriterState +} + +// Optional http.ResponseWriter interfaces implemented. +var ( + _ CloseNotifier = (*http2responseWriter)(nil) + _ Flusher = (*http2responseWriter)(nil) + _ http2stringWriter = (*http2responseWriter)(nil) +) + +type http2responseWriterState struct { + // immutable within a request: + stream *http2stream + req *Request + body *http2requestBody // to close at end of request, if DATA frames didn't + conn *http2serverConn + + // TODO: adjust buffer writing sizes based on server config, frame size updates from peer, etc + bw *bufio.Writer // writing to a chunkWriter{this *responseWriterState} + + // mutated by http.Handler goroutine: + handlerHeader Header // nil until called + snapHeader Header // snapshot of handlerHeader at WriteHeader time + trailers []string // set in writeChunk + status int // status code passed to WriteHeader + wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet. + sentHeader bool // have we sent the header frame? + handlerDone bool // handler has finished + dirty bool // a Write failed; don't reuse this responseWriterState + + sentContentLen int64 // non-zero if handler set a Content-Length header + wroteBytes int64 + + closeNotifierMu sync.Mutex // guards closeNotifierCh + closeNotifierCh chan bool // nil until first used +} + +type http2chunkWriter struct{ rws *http2responseWriterState } + +func (cw http2chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) } + +func (rws *http2responseWriterState) hasTrailers() bool { return len(rws.trailers) > 0 } + +func (rws *http2responseWriterState) hasNonemptyTrailers() bool { + for _, trailer := range rws.trailers { + if _, ok := rws.handlerHeader[trailer]; ok { + return true + } + } + return false +} + +// declareTrailer is called for each Trailer header when the +// response header is written. It notes that a header will need to be +// written in the trailers at the end of the response. +func (rws *http2responseWriterState) declareTrailer(k string) { + k = CanonicalHeaderKey(k) + if !httpguts.ValidTrailerHeader(k) { + // Forbidden by RFC 7230, section 4.1.2. + rws.conn.logf("ignoring invalid trailer %q", k) + return + } + if !http2strSliceContains(rws.trailers, k) { + rws.trailers = append(rws.trailers, k) + } +} + +// writeChunk writes chunks from the bufio.Writer. But because +// bufio.Writer may bypass its chunking, sometimes p may be +// arbitrarily large. +// +// writeChunk is also responsible (on the first chunk) for sending the +// HEADER response. +func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { + if !rws.wroteHeader { + rws.writeHeader(200) + } + + isHeadResp := rws.req.Method == "HEAD" + if !rws.sentHeader { + rws.sentHeader = true + var ctype, clen string + if clen = rws.snapHeader.Get("Content-Length"); clen != "" { + rws.snapHeader.Del("Content-Length") + if cl, err := strconv.ParseUint(clen, 10, 63); err == nil { + rws.sentContentLen = int64(cl) + } else { + clen = "" + } + } + if clen == "" && rws.handlerDone && http2bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) { + clen = strconv.Itoa(len(p)) + } + _, hasContentType := rws.snapHeader["Content-Type"] + // If the Content-Encoding is non-blank, we shouldn't + // sniff the body. See Issue golang.org/issue/31753. + ce := rws.snapHeader.Get("Content-Encoding") + hasCE := len(ce) > 0 + if !hasCE && !hasContentType && http2bodyAllowedForStatus(rws.status) && len(p) > 0 { + ctype = DetectContentType(p) + } + var date string + if _, ok := rws.snapHeader["Date"]; !ok { + // TODO(bradfitz): be faster here, like net/http? measure. + date = time.Now().UTC().Format(TimeFormat) + } + + for _, v := range rws.snapHeader["Trailer"] { + http2foreachHeaderElement(v, rws.declareTrailer) + } + + // "Connection" headers aren't allowed in HTTP/2 (RFC 7540, 8.1.2.2), + // but respect "Connection" == "close" to mean sending a GOAWAY and tearing + // down the TCP connection when idle, like we do for HTTP/1. + // TODO: remove more Connection-specific header fields here, in addition + // to "Connection". + if _, ok := rws.snapHeader["Connection"]; ok { + v := rws.snapHeader.Get("Connection") + delete(rws.snapHeader, "Connection") + if v == "close" { + rws.conn.startGracefulShutdown() + } + } + + endStream := (rws.handlerDone && !rws.hasTrailers() && len(p) == 0) || isHeadResp + err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{ + streamID: rws.stream.id, + httpResCode: rws.status, + h: rws.snapHeader, + endStream: endStream, + contentType: ctype, + contentLength: clen, + date: date, + }) + if err != nil { + rws.dirty = true + return 0, err + } + if endStream { + return 0, nil + } + } + if isHeadResp { + return len(p), nil + } + if len(p) == 0 && !rws.handlerDone { + return 0, nil + } + + if rws.handlerDone { + rws.promoteUndeclaredTrailers() + } + + // only send trailers if they have actually been defined by the + // server handler. + hasNonemptyTrailers := rws.hasNonemptyTrailers() + endStream := rws.handlerDone && !hasNonemptyTrailers + if len(p) > 0 || endStream { + // only send a 0 byte DATA frame if we're ending the stream. + if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil { + rws.dirty = true + return 0, err + } + } + + if rws.handlerDone && hasNonemptyTrailers { + err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{ + streamID: rws.stream.id, + h: rws.handlerHeader, + trailers: rws.trailers, + endStream: true, + }) + if err != nil { + rws.dirty = true + } + return len(p), err + } + return len(p), nil +} + +// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys +// that, if present, signals that the map entry is actually for +// the response trailers, and not the response headers. The prefix +// is stripped after the ServeHTTP call finishes and the values are +// sent in the trailers. +// +// This mechanism is intended only for trailers that are not known +// prior to the headers being written. If the set of trailers is fixed +// or known before the header is written, the normal Go trailers mechanism +// is preferred: +// https://golang.org/pkg/net/http/#ResponseWriter +// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers +const http2TrailerPrefix = "Trailer:" + +// promoteUndeclaredTrailers permits http.Handlers to set trailers +// after the header has already been flushed. Because the Go +// ResponseWriter interface has no way to set Trailers (only the +// Header), and because we didn't want to expand the ResponseWriter +// interface, and because nobody used trailers, and because RFC 7230 +// says you SHOULD (but not must) predeclare any trailers in the +// header, the official ResponseWriter rules said trailers in Go must +// be predeclared, and then we reuse the same ResponseWriter.Header() +// map to mean both Headers and Trailers. When it's time to write the +// Trailers, we pick out the fields of Headers that were declared as +// trailers. That worked for a while, until we found the first major +// user of Trailers in the wild: gRPC (using them only over http2), +// and gRPC libraries permit setting trailers mid-stream without +// predeclaring them. So: change of plans. We still permit the old +// way, but we also permit this hack: if a Header() key begins with +// "Trailer:", the suffix of that key is a Trailer. Because ':' is an +// invalid token byte anyway, there is no ambiguity. (And it's already +// filtered out) It's mildly hacky, but not terrible. +// +// This method runs after the Handler is done and promotes any Header +// fields to be trailers. +func (rws *http2responseWriterState) promoteUndeclaredTrailers() { + for k, vv := range rws.handlerHeader { + if !strings.HasPrefix(k, http2TrailerPrefix) { + continue + } + trailerKey := strings.TrimPrefix(k, http2TrailerPrefix) + rws.declareTrailer(trailerKey) + rws.handlerHeader[CanonicalHeaderKey(trailerKey)] = vv + } + + if len(rws.trailers) > 1 { + sorter := http2sorterPool.Get().(*http2sorter) + sorter.SortStrings(rws.trailers) + http2sorterPool.Put(sorter) + } +} + +func (w *http2responseWriter) Flush() { + rws := w.rws + if rws == nil { + panic("Header called after Handler finished") + } + if rws.bw.Buffered() > 0 { + if err := rws.bw.Flush(); err != nil { + // Ignore the error. The frame writer already knows. + return + } + } else { + // The bufio.Writer won't call chunkWriter.Write + // (writeChunk with zero bytes, so we have to do it + // ourselves to force the HTTP response header and/or + // final DATA frame (with END_STREAM) to be sent. + rws.writeChunk(nil) + } +} + +func (w *http2responseWriter) CloseNotify() <-chan bool { + rws := w.rws + if rws == nil { + panic("CloseNotify called after Handler finished") + } + rws.closeNotifierMu.Lock() + ch := rws.closeNotifierCh + if ch == nil { + ch = make(chan bool, 1) + rws.closeNotifierCh = ch + cw := rws.stream.cw + go func() { + cw.Wait() // wait for close + ch <- true + }() + } + rws.closeNotifierMu.Unlock() + return ch +} + +func (w *http2responseWriter) Header() Header { + rws := w.rws + if rws == nil { + panic("Header called after Handler finished") + } + if rws.handlerHeader == nil { + rws.handlerHeader = make(Header) + } + return rws.handlerHeader +} + +// checkWriteHeaderCode is a copy of net/http's checkWriteHeaderCode. +func http2checkWriteHeaderCode(code int) { + // Issue 22880: require valid WriteHeader status codes. + // For now we only enforce that it's three digits. + // In the future we might block things over 599 (600 and above aren't defined + // at http://httpwg.org/specs/rfc7231.html#status.codes) + // and we might block under 200 (once we have more mature 1xx support). + // But for now any three digits. + // + // We used to send "HTTP/1.1 000 0" on the wire in responses but there's + // no equivalent bogus thing we can realistically send in HTTP/2, + // so we'll consistently panic instead and help people find their bugs + // early. (We can't return an error from WriteHeader even if we wanted to.) + if code < 100 || code > 999 { + panic(fmt.Sprintf("invalid WriteHeader code %v", code)) + } +} + +func (w *http2responseWriter) WriteHeader(code int) { + rws := w.rws + if rws == nil { + panic("WriteHeader called after Handler finished") + } + rws.writeHeader(code) +} + +func (rws *http2responseWriterState) writeHeader(code int) { + if !rws.wroteHeader { + http2checkWriteHeaderCode(code) + rws.wroteHeader = true + rws.status = code + if len(rws.handlerHeader) > 0 { + rws.snapHeader = http2cloneHeader(rws.handlerHeader) + } + } +} + +func http2cloneHeader(h Header) Header { + h2 := make(Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 + } + return h2 +} + +// The Life Of A Write is like this: +// +// * Handler calls w.Write or w.WriteString -> +// * -> rws.bw (*bufio.Writer) -> +// * (Handler might call Flush) +// * -> chunkWriter{rws} +// * -> responseWriterState.writeChunk(p []byte) +// * -> responseWriterState.writeChunk (most of the magic; see comment there) +func (w *http2responseWriter) Write(p []byte) (n int, err error) { + return w.write(len(p), p, "") +} + +func (w *http2responseWriter) WriteString(s string) (n int, err error) { + return w.write(len(s), nil, s) +} + +// either dataB or dataS is non-zero. +func (w *http2responseWriter) write(lenData int, dataB []byte, dataS string) (n int, err error) { + rws := w.rws + if rws == nil { + panic("Write called after Handler finished") + } + if !rws.wroteHeader { + w.WriteHeader(200) + } + if !http2bodyAllowedForStatus(rws.status) { + return 0, ErrBodyNotAllowed + } + rws.wroteBytes += int64(len(dataB)) + int64(len(dataS)) // only one can be set + if rws.sentContentLen != 0 && rws.wroteBytes > rws.sentContentLen { + // TODO: send a RST_STREAM + return 0, errors.New("http2: handler wrote more than declared Content-Length") + } + + if dataB != nil { + return rws.bw.Write(dataB) + } else { + return rws.bw.WriteString(dataS) + } +} + +func (w *http2responseWriter) handlerDone() { + rws := w.rws + dirty := rws.dirty + rws.handlerDone = true + w.Flush() + w.rws = nil + if !dirty { + // Only recycle the pool if all prior Write calls to + // the serverConn goroutine completed successfully. If + // they returned earlier due to resets from the peer + // there might still be write goroutines outstanding + // from the serverConn referencing the rws memory. See + // issue 20704. + http2responseWriterStatePool.Put(rws) + } +} + +// Push errors. +var ( + http2ErrRecursivePush = errors.New("http2: recursive push not allowed") + http2ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS") +) + +var _ Pusher = (*http2responseWriter)(nil) + +func (w *http2responseWriter) Push(target string, opts *PushOptions) error { + st := w.rws.stream + sc := st.sc + sc.serveG.checkNotOn() + + // No recursive pushes: "PUSH_PROMISE frames MUST only be sent on a peer-initiated stream." + // http://tools.ietf.org/html/rfc7540#section-6.6 + if st.isPushed() { + return http2ErrRecursivePush + } + + if opts == nil { + opts = new(PushOptions) + } + + // Default options. + if opts.Method == "" { + opts.Method = "GET" + } + if opts.Header == nil { + opts.Header = Header{} + } + wantScheme := "http" + if w.rws.req.TLS != nil { + wantScheme = "https" + } + + // Validate the request. + u, err := url.Parse(target) + if err != nil { + return err + } + if u.Scheme == "" { + if !strings.HasPrefix(target, "/") { + return fmt.Errorf("target must be an absolute URL or an absolute path: %q", target) + } + u.Scheme = wantScheme + u.Host = w.rws.req.Host + } else { + if u.Scheme != wantScheme { + return fmt.Errorf("cannot push URL with scheme %q from request with scheme %q", u.Scheme, wantScheme) + } + if u.Host == "" { + return errors.New("URL must have a host") + } + } + for k := range opts.Header { + if strings.HasPrefix(k, ":") { + return fmt.Errorf("promised request headers cannot include pseudo header %q", k) + } + // These headers are meaningful only if the request has a body, + // but PUSH_PROMISE requests cannot have a body. + // http://tools.ietf.org/html/rfc7540#section-8.2 + // Also disallow Host, since the promised URL must be absolute. + switch strings.ToLower(k) { + case "content-length", "content-encoding", "trailer", "te", "expect", "host": + return fmt.Errorf("promised request headers cannot include %q", k) + } + } + if err := http2checkValidHTTP2RequestHeaders(opts.Header); err != nil { + return err + } + + // The RFC effectively limits promised requests to GET and HEAD: + // "Promised requests MUST be cacheable [GET, HEAD, or POST], and MUST be safe [GET or HEAD]" + // http://tools.ietf.org/html/rfc7540#section-8.2 + if opts.Method != "GET" && opts.Method != "HEAD" { + return fmt.Errorf("method %q must be GET or HEAD", opts.Method) + } + + msg := &http2startPushRequest{ + parent: st, + method: opts.Method, + url: u, + header: http2cloneHeader(opts.Header), + done: http2errChanPool.Get().(chan error), + } + + select { + case <-sc.doneServing: + return http2errClientDisconnected + case <-st.cw: + return http2errStreamClosed + case sc.serveMsgCh <- msg: + } + + select { + case <-sc.doneServing: + return http2errClientDisconnected + case <-st.cw: + return http2errStreamClosed + case err := <-msg.done: + http2errChanPool.Put(msg.done) + return err + } +} + +type http2startPushRequest struct { + parent *http2stream + method string + url *url.URL + header Header + done chan error +} + +func (sc *http2serverConn) startPush(msg *http2startPushRequest) { + sc.serveG.check() + + // http://tools.ietf.org/html/rfc7540#section-6.6. + // PUSH_PROMISE frames MUST only be sent on a peer-initiated stream that + // is in either the "open" or "half-closed (remote)" state. + if msg.parent.state != http2stateOpen && msg.parent.state != http2stateHalfClosedRemote { + // responseWriter.Push checks that the stream is peer-initiated. + msg.done <- http2errStreamClosed + return + } + + // http://tools.ietf.org/html/rfc7540#section-6.6. + if !sc.pushEnabled { + msg.done <- ErrNotSupported + return + } + + // PUSH_PROMISE frames must be sent in increasing order by stream ID, so + // we allocate an ID for the promised stream lazily, when the PUSH_PROMISE + // is written. Once the ID is allocated, we start the request handler. + allocatePromisedID := func() (uint32, error) { + sc.serveG.check() + + // Check this again, just in case. Technically, we might have received + // an updated SETTINGS by the time we got around to writing this frame. + if !sc.pushEnabled { + return 0, ErrNotSupported + } + // http://tools.ietf.org/html/rfc7540#section-6.5.2. + if sc.curPushedStreams+1 > sc.clientMaxStreams { + return 0, http2ErrPushLimitReached + } + + // http://tools.ietf.org/html/rfc7540#section-5.1.1. + // Streams initiated by the server MUST use even-numbered identifiers. + // A server that is unable to establish a new stream identifier can send a GOAWAY + // frame so that the client is forced to open a new connection for new streams. + if sc.maxPushPromiseID+2 >= 1<<31 { + sc.startGracefulShutdownInternal() + return 0, http2ErrPushLimitReached + } + sc.maxPushPromiseID += 2 + promisedID := sc.maxPushPromiseID + + // http://tools.ietf.org/html/rfc7540#section-8.2. + // Strictly speaking, the new stream should start in "reserved (local)", then + // transition to "half closed (remote)" after sending the initial HEADERS, but + // we start in "half closed (remote)" for simplicity. + // See further comments at the definition of stateHalfClosedRemote. + promised := sc.newStream(promisedID, msg.parent.id, http2stateHalfClosedRemote) + rw, req, err := sc.newWriterAndRequestNoBody(promised, http2requestParam{ + method: msg.method, + scheme: msg.url.Scheme, + authority: msg.url.Host, + path: msg.url.RequestURI(), + header: http2cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE + }) + if err != nil { + // Should not happen, since we've already validated msg.url. + panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err)) + } + + go sc.runHandler(rw, req, sc.handler.ServeHTTP) + return promisedID, nil + } + + sc.writeFrame(http2FrameWriteRequest{ + write: &http2writePushPromise{ + streamID: msg.parent.id, + method: msg.method, + url: msg.url, + h: msg.header, + allocatePromisedID: allocatePromisedID, + }, + stream: msg.parent, + done: msg.done, + }) +} + +// foreachHeaderElement splits v according to the "#rule" construction +// in RFC 7230 section 7 and calls fn for each non-empty element. +func http2foreachHeaderElement(v string, fn func(string)) { + v = textproto.TrimString(v) + if v == "" { + return + } + if !strings.Contains(v, ",") { + fn(v) + return + } + for _, f := range strings.Split(v, ",") { + if f = textproto.TrimString(f); f != "" { + fn(f) + } + } +} + +// From http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2.2 +var http2connHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Connection", + "Transfer-Encoding", + "Upgrade", +} + +// checkValidHTTP2RequestHeaders checks whether h is a valid HTTP/2 request, +// per RFC 7540 Section 8.1.2.2. +// The returned error is reported to users. +func http2checkValidHTTP2RequestHeaders(h Header) error { + for _, k := range http2connHeaders { + if _, ok := h[k]; ok { + return fmt.Errorf("request header %q is not valid in HTTP/2", k) + } + } + te := h["Te"] + if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) { + return errors.New(`request header "TE" may only be "trailers" in HTTP/2`) + } + return nil +} + +func http2new400Handler(err error) HandlerFunc { + return func(w ResponseWriter, r *Request) { + Error(w, err.Error(), StatusBadRequest) + } +} + +// h1ServerKeepAlivesDisabled reports whether hs has its keep-alives +// disabled. See comments on h1ServerShutdownChan above for why +// the code is written this way. +func http2h1ServerKeepAlivesDisabled(hs *Server) bool { + var x interface{} = hs + type I interface { + doKeepAlives() bool + } + if hs, ok := x.(I); ok { + return !hs.doKeepAlives() + } + return false +} + +const ( + // transportDefaultConnFlow is how many connection-level flow control + // tokens we give the server at start-up, past the default 64k. + http2transportDefaultConnFlow = 1 << 30 + + // transportDefaultStreamFlow is how many stream-level flow + // control tokens we announce to the peer, and how many bytes + // we buffer per stream. + http2transportDefaultStreamFlow = 4 << 20 + + // transportDefaultStreamMinRefresh is the minimum number of bytes we'll send + // a stream-level WINDOW_UPDATE for at a time. + http2transportDefaultStreamMinRefresh = 4 << 10 + + http2defaultUserAgent = "Go-http-client/2.0" +) + +// Transport is an HTTP/2 Transport. +// +// A Transport internally caches connections to servers. It is safe +// for concurrent use by multiple goroutines. +type http2Transport struct { + // DialTLS specifies an optional dial function for creating + // TLS connections for requests. + // + // If DialTLS is nil, tls.Dial is used. + // + // If the returned net.Conn has a ConnectionState method like tls.Conn, + // it will be used to set http.Response.TLS. + DialTLS func(network, addr string, cfg *tls.Config) (net.Conn, error) + + // TLSClientConfig specifies the TLS configuration to use with + // tls.Client. If nil, the default configuration is used. + TLSClientConfig *tls.Config + + // ConnPool optionally specifies an alternate connection pool to use. + // If nil, the default is used. + ConnPool http2ClientConnPool + + // DisableCompression, if true, prevents the Transport from + // requesting compression with an "Accept-Encoding: gzip" + // request header when the Request contains no existing + // Accept-Encoding value. If the Transport requests gzip on + // its own and gets a gzipped response, it's transparently + // decoded in the Response.Body. However, if the user + // explicitly requested gzip it is not automatically + // uncompressed. + DisableCompression bool + + // AllowHTTP, if true, permits HTTP/2 requests using the insecure, + // plain-text "http" scheme. Note that this does not enable h2c support. + AllowHTTP bool + + // MaxHeaderListSize is the http2 SETTINGS_MAX_HEADER_LIST_SIZE to + // send in the initial settings frame. It is how many bytes + // of response headers are allowed. Unlike the http2 spec, zero here + // means to use a default limit (currently 10MB). If you actually + // want to advertise an unlimited value to the peer, Transport + // interprets the highest possible value here (0xffffffff or 1<<32-1) + // to mean no limit. + MaxHeaderListSize uint32 + + // StrictMaxConcurrentStreams controls whether the server's + // SETTINGS_MAX_CONCURRENT_STREAMS should be respected + // globally. If false, new TCP connections are created to the + // server as needed to keep each under the per-connection + // SETTINGS_MAX_CONCURRENT_STREAMS limit. If true, the + // server's SETTINGS_MAX_CONCURRENT_STREAMS is interpreted as + // a global limit and callers of RoundTrip block when needed, + // waiting for their turn. + StrictMaxConcurrentStreams bool + + // ReadIdleTimeout is the timeout after which a health check using ping + // frame will be carried out if no frame is received on the connection. + // Note that a ping response will is considered a received frame, so if + // there is no other traffic on the connection, the health check will + // be performed every ReadIdleTimeout interval. + // If zero, no health check is performed. + ReadIdleTimeout time.Duration + + // PingTimeout is the timeout after which the connection will be closed + // if a response to Ping is not received. + // Defaults to 15s. + PingTimeout time.Duration + + // t1, if non-nil, is the standard library Transport using + // this transport. Its settings are used (but not its + // RoundTrip method, etc). + t1 *Transport + + connPoolOnce sync.Once + connPoolOrDef http2ClientConnPool // non-nil version of ConnPool +} + +func (t *http2Transport) maxHeaderListSize() uint32 { + if t.MaxHeaderListSize == 0 { + return 10 << 20 + } + if t.MaxHeaderListSize == 0xffffffff { + return 0 + } + return t.MaxHeaderListSize +} + +func (t *http2Transport) disableCompression() bool { + return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression) +} + +func (t *http2Transport) pingTimeout() time.Duration { + if t.PingTimeout == 0 { + return 15 * time.Second + } + return t.PingTimeout + +} + +// ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2. +// It returns an error if t1 has already been HTTP/2-enabled. +// +// Use ConfigureTransports instead to configure the HTTP/2 Transport. +func http2ConfigureTransport(t1 *Transport) error { + _, err := http2ConfigureTransports(t1) + return err +} + +// ConfigureTransports configures a net/http HTTP/1 Transport to use HTTP/2. +// It returns a new HTTP/2 Transport for further configuration. +// It returns an error if t1 has already been HTTP/2-enabled. +func http2ConfigureTransports(t1 *Transport) (*http2Transport, error) { + return http2configureTransports(t1) +} + +func http2configureTransports(t1 *Transport) (*http2Transport, error) { + connPool := new(http2clientConnPool) + t2 := &http2Transport{ + ConnPool: http2noDialClientConnPool{connPool}, + t1: t1, + } + connPool.t = t2 + if err := http2registerHTTPSProtocol(t1, http2noDialH2RoundTripper{t2}); err != nil { + return nil, err + } + if t1.TLSClientConfig == nil { + t1.TLSClientConfig = new(tls.Config) + } + if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "h2") { + t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...) + } + if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") { + t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1") + } + upgradeFn := func(authority string, c *tls.Conn) RoundTripper { + addr := http2authorityAddr("https", authority) + if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { + go c.Close() + return http2erringRoundTripper{err} + } else if !used { + // Turns out we don't need this c. + // For example, two goroutines made requests to the same host + // at the same time, both kicking off TCP dials. (since protocol + // was unknown) + go c.Close() + } + return t2 + } + if m := t1.TLSNextProto; len(m) == 0 { + t1.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{ + "h2": upgradeFn, + } + } else { + m["h2"] = upgradeFn + } + return t2, nil +} + +func (t *http2Transport) connPool() http2ClientConnPool { + t.connPoolOnce.Do(t.initConnPool) + return t.connPoolOrDef +} + +func (t *http2Transport) initConnPool() { + if t.ConnPool != nil { + t.connPoolOrDef = t.ConnPool + } else { + t.connPoolOrDef = &http2clientConnPool{t: t} + } +} + +// ClientConn is the state of a single HTTP/2 client connection to an +// HTTP/2 server. +type http2ClientConn struct { + t *http2Transport + tconn net.Conn // usually *tls.Conn, except specialized impls + tlsState *tls.ConnectionState // nil only for specialized impls + reused uint32 // whether conn is being reused; atomic + singleUse bool // whether being used for a single http.Request + + // readLoop goroutine fields: + readerDone chan struct{} // closed on error + readerErr error // set before readerDone is closed + + idleTimeout time.Duration // or 0 for never + idleTimer *time.Timer + + mu sync.Mutex // guards following + cond *sync.Cond // hold mu; broadcast on flow/closed changes + flow http2flow // our conn-level flow control quota (cs.flow is per stream) + inflow http2flow // peer's conn-level flow control + closing bool + closed bool + wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back + goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received + goAwayDebug string // goAway frame's debug data, retained as a string + streams map[uint32]*http2clientStream // client-initiated + nextStreamID uint32 + pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams + pings map[[8]byte]chan struct{} // in flight ping data to notification channel + bw *bufio.Writer + br *bufio.Reader + fr *http2Framer + lastActive time.Time + lastIdle time.Time // time last idle + // Settings from peer: (also guarded by mu) + maxFrameSize uint32 + maxConcurrentStreams uint32 + peerMaxHeaderListSize uint64 + initialWindowSize uint32 + + hbuf bytes.Buffer // HPACK encoder writes into this + henc *hpack.Encoder + freeBuf [][]byte + + wmu sync.Mutex // held while writing; acquire AFTER mu if holding both + werr error // first write error that has occurred +} + +// clientStream is the state for a single HTTP/2 stream. One of these +// is created for each Transport.RoundTrip call. +type http2clientStream struct { + cc *http2ClientConn + req *Request + trace *httptrace.ClientTrace // or nil + ID uint32 + resc chan http2resAndError + bufPipe http2pipe // buffered pipe with the flow-controlled response payload + startedWrite bool // started request body write; guarded by cc.mu + requestedGzip bool + on100 func() // optional code to run if get a 100 continue response + + flow http2flow // guarded by cc.mu + inflow http2flow // guarded by cc.mu + bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read + readErr error // sticky read error; owned by transportResponseBody.Read + stopReqBody error // if non-nil, stop writing req body; guarded by cc.mu + didReset bool // whether we sent a RST_STREAM to the server; guarded by cc.mu + + peerReset chan struct{} // closed on peer reset + resetErr error // populated before peerReset is closed + + done chan struct{} // closed when stream remove from cc.streams map; close calls guarded by cc.mu + + // owned by clientConnReadLoop: + firstByte bool // got the first response byte + pastHeaders bool // got first MetaHeadersFrame (actual headers) + pastTrailers bool // got optional second MetaHeadersFrame (trailers) + num1xx uint8 // number of 1xx responses seen + + trailer Header // accumulated trailers + resTrailer *Header // client's Response.Trailer +} + +// awaitRequestCancel waits for the user to cancel a request or for the done +// channel to be signaled. A non-nil error is returned only if the request was +// canceled. +func http2awaitRequestCancel(req *Request, done <-chan struct{}) error { + ctx := req.Context() + if req.Cancel == nil && ctx.Done() == nil { + return nil + } + select { + case <-req.Cancel: + return http2errRequestCanceled + case <-ctx.Done(): + return ctx.Err() + case <-done: + return nil + } +} + +var http2got1xxFuncForTests func(int, textproto.MIMEHeader) error + +// get1xxTraceFunc returns the value of request's httptrace.ClientTrace.Got1xxResponse func, +// if any. It returns nil if not set or if the Go version is too old. +func (cs *http2clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) error { + if fn := http2got1xxFuncForTests; fn != nil { + return fn + } + return http2traceGot1xxResponseFunc(cs.trace) +} + +// awaitRequestCancel waits for the user to cancel a request, its context to +// expire, or for the request to be done (any way it might be removed from the +// cc.streams map: peer reset, successful completion, TCP connection breakage, +// etc). If the request is canceled, then cs will be canceled and closed. +func (cs *http2clientStream) awaitRequestCancel(req *Request) { + if err := http2awaitRequestCancel(req, cs.done); err != nil { + cs.cancelStream() + cs.bufPipe.CloseWithError(err) + } +} + +func (cs *http2clientStream) cancelStream() { + cc := cs.cc + cc.mu.Lock() + didReset := cs.didReset + cs.didReset = true + cc.mu.Unlock() + + if !didReset { + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + cc.forgetStreamID(cs.ID) + } +} + +// checkResetOrDone reports any error sent in a RST_STREAM frame by the +// server, or errStreamClosed if the stream is complete. +func (cs *http2clientStream) checkResetOrDone() error { + select { + case <-cs.peerReset: + return cs.resetErr + case <-cs.done: + return http2errStreamClosed + default: + return nil + } +} + +func (cs *http2clientStream) getStartedWrite() bool { + cc := cs.cc + cc.mu.Lock() + defer cc.mu.Unlock() + return cs.startedWrite +} + +func (cs *http2clientStream) abortRequestBodyWrite(err error) { + if err == nil { + panic("nil error") + } + cc := cs.cc + cc.mu.Lock() + cs.stopReqBody = err + cc.cond.Broadcast() + cc.mu.Unlock() +} + +type http2stickyErrWriter struct { + w io.Writer + err *error +} + +func (sew http2stickyErrWriter) Write(p []byte) (n int, err error) { + if *sew.err != nil { + return 0, *sew.err + } + n, err = sew.w.Write(p) + *sew.err = err + return +} + +// noCachedConnError is the concrete type of ErrNoCachedConn, which +// needs to be detected by net/http regardless of whether it's its +// bundled version (in h2_bundle.go with a rewritten type name) or +// from a user's x/net/http2. As such, as it has a unique method name +// (IsHTTP2NoCachedConnError) that net/http sniffs for via func +// isNoCachedConnError. +type http2noCachedConnError struct{} + +func (http2noCachedConnError) IsHTTP2NoCachedConnError() {} + +func (http2noCachedConnError) Error() string { return "http2: no cached connection was available" } + +// isNoCachedConnError reports whether err is of type noCachedConnError +// or its equivalent renamed type in net/http2's h2_bundle.go. Both types +// may coexist in the same running program. +func http2isNoCachedConnError(err error) bool { + _, ok := err.(interface{ IsHTTP2NoCachedConnError() }) + return ok +} + +var http2ErrNoCachedConn error = http2noCachedConnError{} + +// RoundTripOpt are options for the Transport.RoundTripOpt method. +type http2RoundTripOpt struct { + // OnlyCachedConn controls whether RoundTripOpt may + // create a new TCP connection. If set true and + // no cached connection is available, RoundTripOpt + // will return ErrNoCachedConn. + OnlyCachedConn bool +} + +func (t *http2Transport) RoundTrip(req *Request) (*Response, error) { + return t.RoundTripOpt(req, http2RoundTripOpt{}) +} + +// authorityAddr returns a given authority (a host/IP, or host:port / ip:port) +// and returns a host:port. The port 443 is added if needed. +func http2authorityAddr(scheme string, authority string) (addr string) { + host, port, err := net.SplitHostPort(authority) + if err != nil { // authority didn't have a port + port = "443" + if scheme == "http" { + port = "80" + } + host = authority + } + if a, err := idna.ToASCII(host); err == nil { + host = a + } + // IPv6 address literal, without a port: + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + return host + ":" + port + } + return net.JoinHostPort(host, port) +} + +// RoundTripOpt is like RoundTrip, but takes options. +func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Response, error) { + if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { + return nil, errors.New("http2: unsupported scheme") + } + + addr := http2authorityAddr(req.URL.Scheme, req.URL.Host) + for retry := 0; ; retry++ { + cc, err := t.connPool().GetClientConn(req, addr) + if err != nil { + t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err) + return nil, err + } + reused := !atomic.CompareAndSwapUint32(&cc.reused, 0, 1) + http2traceGotConn(req, cc, reused) + res, gotErrAfterReqBodyWrite, err := cc.roundTrip(req) + if err != nil && retry <= 6 { + if req, err = http2shouldRetryRequest(req, err, gotErrAfterReqBodyWrite); err == nil { + // After the first retry, do exponential backoff with 10% jitter. + if retry == 0 { + continue + } + backoff := float64(uint(1) << (uint(retry) - 1)) + backoff += backoff * (0.1 * mathrand.Float64()) + select { + case <-time.After(time.Second * time.Duration(backoff)): + continue + case <-req.Context().Done(): + return nil, req.Context().Err() + } + } + } + if err != nil { + t.vlogf("RoundTrip failure: %v", err) + return nil, err + } + return res, nil + } +} + +// CloseIdleConnections closes any connections which were previously +// connected from previous requests but are now sitting idle. +// It does not interrupt any connections currently in use. +func (t *http2Transport) CloseIdleConnections() { + if cp, ok := t.connPool().(http2clientConnPoolIdleCloser); ok { + cp.closeIdleConnections() + } +} + +var ( + http2errClientConnClosed = errors.New("http2: client conn is closed") + http2errClientConnUnusable = errors.New("http2: client conn not usable") + http2errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY") +) + +// shouldRetryRequest is called by RoundTrip when a request fails to get +// response headers. It is always called with a non-nil error. +// It returns either a request to retry (either the same request, or a +// modified clone), or an error if the request can't be replayed. +func http2shouldRetryRequest(req *Request, err error, afterBodyWrite bool) (*Request, error) { + if !http2canRetryError(err) { + return nil, err + } + // If the Body is nil (or http.NoBody), it's safe to reuse + // this request and its Body. + if req.Body == nil || req.Body == NoBody { + return req, nil + } + + // If the request body can be reset back to its original + // state via the optional req.GetBody, do that. + if req.GetBody != nil { + // TODO: consider a req.Body.Close here? or audit that all caller paths do? + body, err := req.GetBody() + if err != nil { + return nil, err + } + newReq := *req + newReq.Body = body + return &newReq, nil + } + + // The Request.Body can't reset back to the beginning, but we + // don't seem to have started to read from it yet, so reuse + // the request directly. The "afterBodyWrite" means the + // bodyWrite process has started, which becomes true before + // the first Read. + if !afterBodyWrite { + return req, nil + } + + return nil, fmt.Errorf("http2: Transport: cannot retry err [%v] after Request.Body was written; define Request.GetBody to avoid this error", err) +} + +func http2canRetryError(err error) bool { + if err == http2errClientConnUnusable || err == http2errClientConnGotGoAway { + return true + } + if se, ok := err.(http2StreamError); ok { + return se.Code == http2ErrCodeRefusedStream + } + return false +} + +func (t *http2Transport) dialClientConn(addr string, singleUse bool) (*http2ClientConn, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + tconn, err := t.dialTLS()("tcp", addr, t.newTLSConfig(host)) + if err != nil { + return nil, err + } + return t.newClientConn(tconn, singleUse) +} + +func (t *http2Transport) newTLSConfig(host string) *tls.Config { + cfg := new(tls.Config) + if t.TLSClientConfig != nil { + *cfg = *t.TLSClientConfig.Clone() + } + if !http2strSliceContains(cfg.NextProtos, http2NextProtoTLS) { + cfg.NextProtos = append([]string{http2NextProtoTLS}, cfg.NextProtos...) + } + if cfg.ServerName == "" { + cfg.ServerName = host + } + return cfg +} + +func (t *http2Transport) dialTLS() func(string, string, *tls.Config) (net.Conn, error) { + if t.DialTLS != nil { + return t.DialTLS + } + return t.dialTLSDefault +} + +func (t *http2Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (net.Conn, error) { + cn, err := tls.Dial(network, addr, cfg) + if err != nil { + return nil, err + } + if err := cn.Handshake(); err != nil { + return nil, err + } + if !cfg.InsecureSkipVerify { + if err := cn.VerifyHostname(cfg.ServerName); err != nil { + return nil, err + } + } + state := cn.ConnectionState() + if p := state.NegotiatedProtocol; p != http2NextProtoTLS { + return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, http2NextProtoTLS) + } + if !state.NegotiatedProtocolIsMutual { + return nil, errors.New("http2: could not negotiate protocol mutually") + } + return cn, nil +} + +// disableKeepAlives reports whether connections should be closed as +// soon as possible after handling the first request. +func (t *http2Transport) disableKeepAlives() bool { + return t.t1 != nil && t.t1.DisableKeepAlives +} + +func (t *http2Transport) expectContinueTimeout() time.Duration { + if t.t1 == nil { + return 0 + } + return t.t1.ExpectContinueTimeout +} + +func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { + return t.newClientConn(c, t.disableKeepAlives()) +} + +func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2ClientConn, error) { + cc := &http2ClientConn{ + t: t, + tconn: c, + readerDone: make(chan struct{}), + nextStreamID: 1, + maxFrameSize: 16 << 10, // spec default + initialWindowSize: 65535, // spec default + maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough. + peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead. + streams: make(map[uint32]*http2clientStream), + singleUse: singleUse, + wantSettingsAck: true, + pings: make(map[[8]byte]chan struct{}), + } + if d := t.idleConnTimeout(); d != 0 { + cc.idleTimeout = d + cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) + } + if http2VerboseLogs { + t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) + } + + cc.cond = sync.NewCond(&cc.mu) + cc.flow.add(int32(http2initialWindowSize)) + + // TODO: adjust this writer size to account for frame size + + // MTU + crypto/tls record padding. + cc.bw = bufio.NewWriter(http2stickyErrWriter{c, &cc.werr}) + cc.br = bufio.NewReader(c) + cc.fr = http2NewFramer(cc.bw, cc.br) + cc.fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) + cc.fr.MaxHeaderListSize = t.maxHeaderListSize() + + // TODO: SetMaxDynamicTableSize, SetMaxDynamicTableSizeLimit on + // henc in response to SETTINGS frames? + cc.henc = hpack.NewEncoder(&cc.hbuf) + + if t.AllowHTTP { + cc.nextStreamID = 3 + } + + if cs, ok := c.(http2connectionStater); ok { + state := cs.ConnectionState() + cc.tlsState = &state + } + + initialSettings := []http2Setting{ + {ID: http2SettingEnablePush, Val: 0}, + {ID: http2SettingInitialWindowSize, Val: http2transportDefaultStreamFlow}, + } + if max := t.maxHeaderListSize(); max != 0 { + initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxHeaderListSize, Val: max}) + } + + cc.bw.Write(http2clientPreface) + cc.fr.WriteSettings(initialSettings...) + cc.fr.WriteWindowUpdate(0, http2transportDefaultConnFlow) + cc.inflow.add(http2transportDefaultConnFlow + http2initialWindowSize) + cc.bw.Flush() + if cc.werr != nil { + cc.Close() + return nil, cc.werr + } + + go cc.readLoop() + return cc, nil +} + +func (cc *http2ClientConn) healthCheck() { + pingTimeout := cc.t.pingTimeout() + // We don't need to periodically ping in the health check, because the readLoop of ClientConn will + // trigger the healthCheck again if there is no frame received. + ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + defer cancel() + err := cc.Ping(ctx) + if err != nil { + cc.closeForLostPing() + cc.t.connPool().MarkDead(cc) + return + } +} + +func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { + cc.mu.Lock() + defer cc.mu.Unlock() + + old := cc.goAway + cc.goAway = f + + // Merge the previous and current GoAway error frames. + if cc.goAwayDebug == "" { + cc.goAwayDebug = string(f.DebugData()) + } + if old != nil && old.ErrCode != http2ErrCodeNo { + cc.goAway.ErrCode = old.ErrCode + } + last := f.LastStreamID + for streamID, cs := range cc.streams { + if streamID > last { + select { + case cs.resc <- http2resAndError{err: http2errClientConnGotGoAway}: + default: + } + } + } +} + +// CanTakeNewRequest reports whether the connection can take a new request, +// meaning it has not been closed or received or sent a GOAWAY. +func (cc *http2ClientConn) CanTakeNewRequest() bool { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.canTakeNewRequestLocked() +} + +// clientConnIdleState describes the suitability of a client +// connection to initiate a new RoundTrip request. +type http2clientConnIdleState struct { + canTakeNewRequest bool + freshConn bool // whether it's unused by any previous request +} + +func (cc *http2ClientConn) idleState() http2clientConnIdleState { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.idleStateLocked() +} + +func (cc *http2ClientConn) idleStateLocked() (st http2clientConnIdleState) { + if cc.singleUse && cc.nextStreamID > 1 { + return + } + var maxConcurrentOkay bool + if cc.t.StrictMaxConcurrentStreams { + // We'll tell the caller we can take a new request to + // prevent the caller from dialing a new TCP + // connection, but then we'll block later before + // writing it. + maxConcurrentOkay = true + } else { + maxConcurrentOkay = int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) + } + + st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay && + int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 && + !cc.tooIdleLocked() + st.freshConn = cc.nextStreamID == 1 && st.canTakeNewRequest + return +} + +func (cc *http2ClientConn) canTakeNewRequestLocked() bool { + st := cc.idleStateLocked() + return st.canTakeNewRequest +} + +// tooIdleLocked reports whether this connection has been been sitting idle +// for too much wall time. +func (cc *http2ClientConn) tooIdleLocked() bool { + // The Round(0) strips the monontonic clock reading so the + // times are compared based on their wall time. We don't want + // to reuse a connection that's been sitting idle during + // VM/laptop suspend if monotonic time was also frozen. + return cc.idleTimeout != 0 && !cc.lastIdle.IsZero() && time.Since(cc.lastIdle.Round(0)) > cc.idleTimeout +} + +// onIdleTimeout is called from a time.AfterFunc goroutine. It will +// only be called when we're idle, but because we're coming from a new +// goroutine, there could be a new request coming in at the same time, +// so this simply calls the synchronized closeIfIdle to shut down this +// connection. The timer could just call closeIfIdle, but this is more +// clear. +func (cc *http2ClientConn) onIdleTimeout() { + cc.closeIfIdle() +} + +func (cc *http2ClientConn) closeIfIdle() { + cc.mu.Lock() + if len(cc.streams) > 0 { + cc.mu.Unlock() + return + } + cc.closed = true + nextID := cc.nextStreamID + // TODO: do clients send GOAWAY too? maybe? Just Close: + cc.mu.Unlock() + + if http2VerboseLogs { + cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, nextID-2) + } + cc.tconn.Close() +} + +var http2shutdownEnterWaitStateHook = func() {} + +// Shutdown gracefully close the client connection, waiting for running streams to complete. +func (cc *http2ClientConn) Shutdown(ctx context.Context) error { + if err := cc.sendGoAway(); err != nil { + return err + } + // Wait for all in-flight streams to complete or connection to close + done := make(chan error, 1) + cancelled := false // guarded by cc.mu + go func() { + cc.mu.Lock() + defer cc.mu.Unlock() + for { + if len(cc.streams) == 0 || cc.closed { + cc.closed = true + done <- cc.tconn.Close() + break + } + if cancelled { + break + } + cc.cond.Wait() + } + }() + http2shutdownEnterWaitStateHook() + select { + case err := <-done: + return err + case <-ctx.Done(): + cc.mu.Lock() + // Free the goroutine above + cancelled = true + cc.cond.Broadcast() + cc.mu.Unlock() + return ctx.Err() + } +} + +func (cc *http2ClientConn) sendGoAway() error { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.wmu.Lock() + defer cc.wmu.Unlock() + if cc.closing { + // GOAWAY sent already + return nil + } + // Send a graceful shutdown frame to server + maxStreamID := cc.nextStreamID + if err := cc.fr.WriteGoAway(maxStreamID, http2ErrCodeNo, nil); err != nil { + return err + } + if err := cc.bw.Flush(); err != nil { + return err + } + // Prevent new requests + cc.closing = true + return nil +} + +// closes the client connection immediately. In-flight requests are interrupted. +// err is sent to streams. +func (cc *http2ClientConn) closeForError(err error) error { + cc.mu.Lock() + defer cc.cond.Broadcast() + defer cc.mu.Unlock() + for id, cs := range cc.streams { + select { + case cs.resc <- http2resAndError{err: err}: + default: + } + cs.bufPipe.CloseWithError(err) + delete(cc.streams, id) + } + cc.closed = true + return cc.tconn.Close() +} + +// Close closes the client connection immediately. +// +// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead. +func (cc *http2ClientConn) Close() error { + err := errors.New("http2: client connection force closed via ClientConn.Close") + return cc.closeForError(err) +} + +// closes the client connection immediately. In-flight requests are interrupted. +func (cc *http2ClientConn) closeForLostPing() error { + err := errors.New("http2: client connection lost") + return cc.closeForError(err) +} + +const http2maxAllocFrameSize = 512 << 10 + +// frameBuffer returns a scratch buffer suitable for writing DATA frames. +// They're capped at the min of the peer's max frame size or 512KB +// (kinda arbitrarily), but definitely capped so we don't allocate 4GB +// bufers. +func (cc *http2ClientConn) frameScratchBuffer() []byte { + cc.mu.Lock() + size := cc.maxFrameSize + if size > http2maxAllocFrameSize { + size = http2maxAllocFrameSize + } + for i, buf := range cc.freeBuf { + if len(buf) >= int(size) { + cc.freeBuf[i] = nil + cc.mu.Unlock() + return buf[:size] + } + } + cc.mu.Unlock() + return make([]byte, size) +} + +func (cc *http2ClientConn) putFrameScratchBuffer(buf []byte) { + cc.mu.Lock() + defer cc.mu.Unlock() + const maxBufs = 4 // arbitrary; 4 concurrent requests per conn? investigate. + if len(cc.freeBuf) < maxBufs { + cc.freeBuf = append(cc.freeBuf, buf) + return + } + for i, old := range cc.freeBuf { + if old == nil { + cc.freeBuf[i] = buf + return + } + } + // forget about it. +} + +// errRequestCanceled is a copy of net/http's errRequestCanceled because it's not +// exported. At least they'll be DeepEqual for h1-vs-h2 comparisons tests. +var http2errRequestCanceled = errors.New("net/http: request canceled") + +func http2commaSeparatedTrailers(req *Request) (string, error) { + keys := make([]string, 0, len(req.Trailer)) + for k := range req.Trailer { + k = CanonicalHeaderKey(k) + switch k { + case "Transfer-Encoding", "Trailer", "Content-Length": + return "", fmt.Errorf("invalid Trailer key %q", k) + } + keys = append(keys, k) + } + if len(keys) > 0 { + sort.Strings(keys) + return strings.Join(keys, ","), nil + } + return "", nil +} + +func (cc *http2ClientConn) responseHeaderTimeout() time.Duration { + if cc.t.t1 != nil { + return cc.t.t1.ResponseHeaderTimeout + } + // No way to do this (yet?) with just an http2.Transport. Probably + // no need. Request.Cancel this is the new way. We only need to support + // this for compatibility with the old http.Transport fields when + // we're doing transparent http2. + return 0 +} + +// checkConnHeaders checks whether req has any invalid connection-level headers. +// per RFC 7540 section 8.1.2.2: Connection-Specific Header Fields. +// Certain headers are special-cased as okay but not transmitted later. +func http2checkConnHeaders(req *Request) error { + if v := req.Header.Get("Upgrade"); v != "" { + return fmt.Errorf("http2: invalid Upgrade request header: %q", req.Header["Upgrade"]) + } + if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") { + return fmt.Errorf("http2: invalid Transfer-Encoding request header: %q", vv) + } + if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !strings.EqualFold(vv[0], "close") && !strings.EqualFold(vv[0], "keep-alive")) { + return fmt.Errorf("http2: invalid Connection request header: %q", vv) + } + return nil +} + +// actualContentLength returns a sanitized version of +// req.ContentLength, where 0 actually means zero (not unknown) and -1 +// means unknown. +func http2actualContentLength(req *Request) int64 { + if req.Body == nil || req.Body == NoBody { + return 0 + } + if req.ContentLength != 0 { + return req.ContentLength + } + return -1 +} + +func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { + resp, _, err := cc.roundTrip(req) + return resp, err +} + +func (cc *http2ClientConn) roundTrip(req *Request) (res *Response, gotErrAfterReqBodyWrite bool, err error) { + if err := http2checkConnHeaders(req); err != nil { + return nil, false, err + } + if cc.idleTimer != nil { + cc.idleTimer.Stop() + } + + trailers, err := http2commaSeparatedTrailers(req) + if err != nil { + return nil, false, err + } + hasTrailers := trailers != "" + + cc.mu.Lock() + if err := cc.awaitOpenSlotForRequest(req); err != nil { + cc.mu.Unlock() + return nil, false, err + } + + body := req.Body + contentLen := http2actualContentLength(req) + hasBody := contentLen != 0 + + // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? + var requestedGzip bool + if !cc.t.disableCompression() && + req.Header.Get("Accept-Encoding") == "" && + req.Header.Get("Range") == "" && + req.Method != "HEAD" { + // Request gzip only, not deflate. Deflate is ambiguous and + // not as universally supported anyway. + // See: https://zlib.net/zlib_faq.html#faq39 + // + // Note that we don't request this for HEAD requests, + // due to a bug in nginx: + // http://trac.nginx.org/nginx/ticket/358 + // https://golang.org/issue/5522 + // + // We don't request gzip if the request is for a range, since + // auto-decoding a portion of a gzipped document will just fail + // anyway. See https://golang.org/issue/8923 + requestedGzip = true + } + + // we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is + // sent by writeRequestBody below, along with any Trailers, + // again in form HEADERS{1}, CONTINUATION{0,}) + hdrs, err := cc.encodeHeaders(req, requestedGzip, trailers, contentLen) + if err != nil { + cc.mu.Unlock() + return nil, false, err + } + + cs := cc.newStream() + cs.req = req + cs.trace = httptrace.ContextClientTrace(req.Context()) + cs.requestedGzip = requestedGzip + bodyWriter := cc.t.getBodyWriterState(cs, body) + cs.on100 = bodyWriter.on100 + + defer func() { + cc.wmu.Lock() + werr := cc.werr + cc.wmu.Unlock() + if werr != nil { + cc.Close() + } + }() + + cc.wmu.Lock() + endStream := !hasBody && !hasTrailers + werr := cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs) + cc.wmu.Unlock() + http2traceWroteHeaders(cs.trace) + cc.mu.Unlock() + + if werr != nil { + if hasBody { + req.Body.Close() // per RoundTripper contract + bodyWriter.cancel() + } + cc.forgetStreamID(cs.ID) + // Don't bother sending a RST_STREAM (our write already failed; + // no need to keep writing) + http2traceWroteRequest(cs.trace, werr) + return nil, false, werr + } + + var respHeaderTimer <-chan time.Time + if hasBody { + bodyWriter.scheduleBodyWrite() + } else { + http2traceWroteRequest(cs.trace, nil) + if d := cc.responseHeaderTimeout(); d != 0 { + timer := time.NewTimer(d) + defer timer.Stop() + respHeaderTimer = timer.C + } + } + + readLoopResCh := cs.resc + bodyWritten := false + ctx := req.Context() + + handleReadLoopResponse := func(re http2resAndError) (*Response, bool, error) { + res := re.res + if re.err != nil || res.StatusCode > 299 { + // On error or status code 3xx, 4xx, 5xx, etc abort any + // ongoing write, assuming that the server doesn't care + // about our request body. If the server replied with 1xx or + // 2xx, however, then assume the server DOES potentially + // want our body (e.g. full-duplex streaming: + // golang.org/issue/13444). If it turns out the server + // doesn't, they'll RST_STREAM us soon enough. This is a + // heuristic to avoid adding knobs to Transport. Hopefully + // we can keep it. + bodyWriter.cancel() + cs.abortRequestBodyWrite(http2errStopReqBodyWrite) + if hasBody && !bodyWritten { + <-bodyWriter.resc + } + } + if re.err != nil { + cc.forgetStreamID(cs.ID) + return nil, cs.getStartedWrite(), re.err + } + res.Request = req + res.TLS = cc.tlsState + return res, false, nil + } + + for { + select { + case re := <-readLoopResCh: + return handleReadLoopResponse(re) + case <-respHeaderTimer: + if !hasBody || bodyWritten { + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + } else { + bodyWriter.cancel() + cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) + <-bodyWriter.resc + } + cc.forgetStreamID(cs.ID) + return nil, cs.getStartedWrite(), http2errTimeout + case <-ctx.Done(): + if !hasBody || bodyWritten { + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + } else { + bodyWriter.cancel() + cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) + <-bodyWriter.resc + } + cc.forgetStreamID(cs.ID) + return nil, cs.getStartedWrite(), ctx.Err() + case <-req.Cancel: + if !hasBody || bodyWritten { + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + } else { + bodyWriter.cancel() + cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) + <-bodyWriter.resc + } + cc.forgetStreamID(cs.ID) + return nil, cs.getStartedWrite(), http2errRequestCanceled + case <-cs.peerReset: + // processResetStream already removed the + // stream from the streams map; no need for + // forgetStreamID. + return nil, cs.getStartedWrite(), cs.resetErr + case err := <-bodyWriter.resc: + bodyWritten = true + // Prefer the read loop's response, if available. Issue 16102. + select { + case re := <-readLoopResCh: + return handleReadLoopResponse(re) + default: + } + if err != nil { + cc.forgetStreamID(cs.ID) + return nil, cs.getStartedWrite(), err + } + if d := cc.responseHeaderTimeout(); d != 0 { + timer := time.NewTimer(d) + defer timer.Stop() + respHeaderTimer = timer.C + } + } + } +} + +// awaitOpenSlotForRequest waits until len(streams) < maxConcurrentStreams. +// Must hold cc.mu. +func (cc *http2ClientConn) awaitOpenSlotForRequest(req *Request) error { + var waitingForConn chan struct{} + var waitingForConnErr error // guarded by cc.mu + for { + cc.lastActive = time.Now() + if cc.closed || !cc.canTakeNewRequestLocked() { + if waitingForConn != nil { + close(waitingForConn) + } + return http2errClientConnUnusable + } + cc.lastIdle = time.Time{} + if int64(len(cc.streams))+1 <= int64(cc.maxConcurrentStreams) { + if waitingForConn != nil { + close(waitingForConn) + } + return nil + } + // Unfortunately, we cannot wait on a condition variable and channel at + // the same time, so instead, we spin up a goroutine to check if the + // request is canceled while we wait for a slot to open in the connection. + if waitingForConn == nil { + waitingForConn = make(chan struct{}) + go func() { + if err := http2awaitRequestCancel(req, waitingForConn); err != nil { + cc.mu.Lock() + waitingForConnErr = err + cc.cond.Broadcast() + cc.mu.Unlock() + } + }() + } + cc.pendingRequests++ + cc.cond.Wait() + cc.pendingRequests-- + if waitingForConnErr != nil { + return waitingForConnErr + } + } +} + +// requires cc.wmu be held +func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, maxFrameSize int, hdrs []byte) error { + first := true // first frame written (HEADERS is first, then CONTINUATION) + for len(hdrs) > 0 && cc.werr == nil { + chunk := hdrs + if len(chunk) > maxFrameSize { + chunk = chunk[:maxFrameSize] + } + hdrs = hdrs[len(chunk):] + endHeaders := len(hdrs) == 0 + if first { + cc.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: streamID, + BlockFragment: chunk, + EndStream: endStream, + EndHeaders: endHeaders, + }) + first = false + } else { + cc.fr.WriteContinuation(streamID, endHeaders, chunk) + } + } + // TODO(bradfitz): this Flush could potentially block (as + // could the WriteHeaders call(s) above), which means they + // wouldn't respond to Request.Cancel being readable. That's + // rare, but this should probably be in a goroutine. + cc.bw.Flush() + return cc.werr +} + +// internal error values; they don't escape to callers +var ( + // abort request body write; don't send cancel + http2errStopReqBodyWrite = errors.New("http2: aborting request body write") + + // abort request body write, but send stream reset of cancel. + http2errStopReqBodyWriteAndCancel = errors.New("http2: canceling request") + + http2errReqBodyTooLong = errors.New("http2: request body larger than specified content length") +) + +func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) { + cc := cs.cc + sentEnd := false // whether we sent the final DATA frame w/ END_STREAM + buf := cc.frameScratchBuffer() + defer cc.putFrameScratchBuffer(buf) + + defer func() { + http2traceWroteRequest(cs.trace, err) + // TODO: write h12Compare test showing whether + // Request.Body is closed by the Transport, + // and in multiple cases: server replies <=299 and >299 + // while still writing request body + cerr := bodyCloser.Close() + if err == nil { + err = cerr + } + }() + + req := cs.req + hasTrailers := req.Trailer != nil + remainLen := http2actualContentLength(req) + hasContentLen := remainLen != -1 + + var sawEOF bool + for !sawEOF { + n, err := body.Read(buf[:len(buf)-1]) + if hasContentLen { + remainLen -= int64(n) + if remainLen == 0 && err == nil { + // The request body's Content-Length was predeclared and + // we just finished reading it all, but the underlying io.Reader + // returned the final chunk with a nil error (which is one of + // the two valid things a Reader can do at EOF). Because we'd prefer + // to send the END_STREAM bit early, double-check that we're actually + // at EOF. Subsequent reads should return (0, EOF) at this point. + // If either value is different, we return an error in one of two ways below. + var n1 int + n1, err = body.Read(buf[n:]) + remainLen -= int64(n1) + } + if remainLen < 0 { + err = http2errReqBodyTooLong + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, err) + return err + } + } + if err == io.EOF { + sawEOF = true + err = nil + } else if err != nil { + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, err) + return err + } + + remain := buf[:n] + for len(remain) > 0 && err == nil { + var allowed int32 + allowed, err = cs.awaitFlowControl(len(remain)) + switch { + case err == http2errStopReqBodyWrite: + return err + case err == http2errStopReqBodyWriteAndCancel: + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + return err + case err != nil: + return err + } + cc.wmu.Lock() + data := remain[:allowed] + remain = remain[allowed:] + sentEnd = sawEOF && len(remain) == 0 && !hasTrailers + err = cc.fr.WriteData(cs.ID, sentEnd, data) + if err == nil { + // TODO(bradfitz): this flush is for latency, not bandwidth. + // Most requests won't need this. Make this opt-in or + // opt-out? Use some heuristic on the body type? Nagel-like + // timers? Based on 'n'? Only last chunk of this for loop, + // unless flow control tokens are low? For now, always. + // If we change this, see comment below. + err = cc.bw.Flush() + } + cc.wmu.Unlock() + } + if err != nil { + return err + } + } + + if sentEnd { + // Already sent END_STREAM (which implies we have no + // trailers) and flushed, because currently all + // WriteData frames above get a flush. So we're done. + return nil + } + + var trls []byte + if hasTrailers { + cc.mu.Lock() + trls, err = cc.encodeTrailers(req) + cc.mu.Unlock() + if err != nil { + cc.writeStreamReset(cs.ID, http2ErrCodeInternal, err) + cc.forgetStreamID(cs.ID) + return err + } + } + + cc.mu.Lock() + maxFrameSize := int(cc.maxFrameSize) + cc.mu.Unlock() + + cc.wmu.Lock() + defer cc.wmu.Unlock() + + // Two ways to send END_STREAM: either with trailers, or + // with an empty DATA frame. + if len(trls) > 0 { + err = cc.writeHeaders(cs.ID, true, maxFrameSize, trls) + } else { + err = cc.fr.WriteData(cs.ID, true, nil) + } + if ferr := cc.bw.Flush(); ferr != nil && err == nil { + err = ferr + } + return err +} + +// awaitFlowControl waits for [1, min(maxBytes, cc.cs.maxFrameSize)] flow +// control tokens from the server. +// It returns either the non-zero number of tokens taken or an error +// if the stream is dead. +func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) { + cc := cs.cc + cc.mu.Lock() + defer cc.mu.Unlock() + for { + if cc.closed { + return 0, http2errClientConnClosed + } + if cs.stopReqBody != nil { + return 0, cs.stopReqBody + } + if err := cs.checkResetOrDone(); err != nil { + return 0, err + } + if a := cs.flow.available(); a > 0 { + take := a + if int(take) > maxBytes { + + take = int32(maxBytes) // can't truncate int; take is int32 + } + if take > int32(cc.maxFrameSize) { + take = int32(cc.maxFrameSize) + } + cs.flow.take(take) + return take, nil + } + cc.cond.Wait() + } +} + +// requires cc.mu be held. +func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) { + cc.hbuf.Reset() + + host := req.Host + if host == "" { + host = req.URL.Host + } + host, err := httpguts.PunycodeHostPort(host) + if err != nil { + return nil, err + } + + var path string + if req.Method != "CONNECT" { + path = req.URL.RequestURI() + if !http2validPseudoPath(path) { + orig := path + path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host) + if !http2validPseudoPath(path) { + if req.URL.Opaque != "" { + return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) + } else { + return nil, fmt.Errorf("invalid request :path %q", orig) + } + } + } + } + + // Check for any invalid headers and return an error before we + // potentially pollute our hpack state. (We want to be able to + // continue to reuse the hpack encoder for future requests) + for k, vv := range req.Header { + if !httpguts.ValidHeaderFieldName(k) { + return nil, fmt.Errorf("invalid HTTP header name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k) + } + } + } + + enumerateHeaders := func(f func(name, value string)) { + // 8.1.2.3 Request Pseudo-Header Fields + // The :path pseudo-header field includes the path and query parts of the + // target URI (the path-absolute production and optionally a '?' character + // followed by the query production (see Sections 3.3 and 3.4 of + // [RFC3986]). + f(":authority", host) + m := req.Method + if m == "" { + m = MethodGet + } + f(":method", m) + if req.Method != "CONNECT" { + f(":path", path) + f(":scheme", req.URL.Scheme) + } + if trailers != "" { + f("trailer", trailers) + } + + var didUA bool + for k, vv := range req.Header { + if strings.EqualFold(k, "host") || strings.EqualFold(k, "content-length") { + // Host is :authority, already sent. + // Content-Length is automatic, set below. + continue + } else if strings.EqualFold(k, "connection") || strings.EqualFold(k, "proxy-connection") || + strings.EqualFold(k, "transfer-encoding") || strings.EqualFold(k, "upgrade") || + strings.EqualFold(k, "keep-alive") { + // Per 8.1.2.2 Connection-Specific Header + // Fields, don't send connection-specific + // fields. We have already checked if any + // are error-worthy so just ignore the rest. + continue + } else if strings.EqualFold(k, "user-agent") { + // Match Go's http1 behavior: at most one + // User-Agent. If set to nil or empty string, + // then omit it. Otherwise if not mentioned, + // include the default (below). + didUA = true + if len(vv) < 1 { + continue + } + vv = vv[:1] + if vv[0] == "" { + continue + } + } else if strings.EqualFold(k, "cookie") { + // Per 8.1.2.5 To allow for better compression efficiency, the + // Cookie header field MAY be split into separate header fields, + // each with one or more cookie-pairs. + for _, v := range vv { + for { + p := strings.IndexByte(v, ';') + if p < 0 { + break + } + f("cookie", v[:p]) + p++ + // strip space after semicolon if any. + for p+1 <= len(v) && v[p] == ' ' { + p++ + } + v = v[p:] + } + if len(v) > 0 { + f("cookie", v) + } + } + continue + } + + for _, v := range vv { + f(k, v) + } + } + if http2shouldSendReqContentLength(req.Method, contentLength) { + f("content-length", strconv.FormatInt(contentLength, 10)) + } + if addGzipHeader { + f("accept-encoding", "gzip") + } + if !didUA { + f("user-agent", http2defaultUserAgent) + } + } + + // Do a first pass over the headers counting bytes to ensure + // we don't exceed cc.peerMaxHeaderListSize. This is done as a + // separate pass before encoding the headers to prevent + // modifying the hpack state. + hlSize := uint64(0) + enumerateHeaders(func(name, value string) { + hf := hpack.HeaderField{Name: name, Value: value} + hlSize += uint64(hf.Size()) + }) + + if hlSize > cc.peerMaxHeaderListSize { + return nil, http2errRequestHeaderListSize + } + + trace := httptrace.ContextClientTrace(req.Context()) + traceHeaders := http2traceHasWroteHeaderField(trace) + + // Header list size is ok. Write the headers. + enumerateHeaders(func(name, value string) { + name = strings.ToLower(name) + cc.writeHeader(name, value) + if traceHeaders { + http2traceWroteHeaderField(trace, name, value) + } + }) + + return cc.hbuf.Bytes(), nil +} + +// shouldSendReqContentLength reports whether the http2.Transport should send +// a "content-length" request header. This logic is basically a copy of the net/http +// transferWriter.shouldSendContentLength. +// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown). +// -1 means unknown. +func http2shouldSendReqContentLength(method string, contentLength int64) bool { + if contentLength > 0 { + return true + } + if contentLength < 0 { + return false + } + // For zero bodies, whether we send a content-length depends on the method. + // It also kinda doesn't matter for http2 either way, with END_STREAM. + switch method { + case "POST", "PUT", "PATCH": + return true + default: + return false + } +} + +// requires cc.mu be held. +func (cc *http2ClientConn) encodeTrailers(req *Request) ([]byte, error) { + cc.hbuf.Reset() + + hlSize := uint64(0) + for k, vv := range req.Trailer { + for _, v := range vv { + hf := hpack.HeaderField{Name: k, Value: v} + hlSize += uint64(hf.Size()) + } + } + if hlSize > cc.peerMaxHeaderListSize { + return nil, http2errRequestHeaderListSize + } + + for k, vv := range req.Trailer { + // Transfer-Encoding, etc.. have already been filtered at the + // start of RoundTrip + lowKey := strings.ToLower(k) + for _, v := range vv { + cc.writeHeader(lowKey, v) + } + } + return cc.hbuf.Bytes(), nil +} + +func (cc *http2ClientConn) writeHeader(name, value string) { + if http2VerboseLogs { + log.Printf("http2: Transport encoding header %q = %q", name, value) + } + cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) +} + +type http2resAndError struct { + _ http2incomparable + res *Response + err error +} + +// requires cc.mu be held. +func (cc *http2ClientConn) newStream() *http2clientStream { + cs := &http2clientStream{ + cc: cc, + ID: cc.nextStreamID, + resc: make(chan http2resAndError, 1), + peerReset: make(chan struct{}), + done: make(chan struct{}), + } + cs.flow.add(int32(cc.initialWindowSize)) + cs.flow.setConnFlow(&cc.flow) + cs.inflow.add(http2transportDefaultStreamFlow) + cs.inflow.setConnFlow(&cc.inflow) + cc.nextStreamID += 2 + cc.streams[cs.ID] = cs + return cs +} + +func (cc *http2ClientConn) forgetStreamID(id uint32) { + cc.streamByID(id, true) +} + +func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStream { + cc.mu.Lock() + defer cc.mu.Unlock() + cs := cc.streams[id] + if andRemove && cs != nil && !cc.closed { + cc.lastActive = time.Now() + delete(cc.streams, id) + if len(cc.streams) == 0 && cc.idleTimer != nil { + cc.idleTimer.Reset(cc.idleTimeout) + cc.lastIdle = time.Now() + } + close(cs.done) + // Wake up checkResetOrDone via clientStream.awaitFlowControl and + // wake up RoundTrip if there is a pending request. + cc.cond.Broadcast() + } + return cs +} + +// clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop. +type http2clientConnReadLoop struct { + _ http2incomparable + cc *http2ClientConn + closeWhenIdle bool +} + +// readLoop runs in its own goroutine and reads and dispatches frames. +func (cc *http2ClientConn) readLoop() { + rl := &http2clientConnReadLoop{cc: cc} + defer rl.cleanup() + cc.readerErr = rl.run() + if ce, ok := cc.readerErr.(http2ConnectionError); ok { + cc.wmu.Lock() + cc.fr.WriteGoAway(0, http2ErrCode(ce), nil) + cc.wmu.Unlock() + } +} + +// GoAwayError is returned by the Transport when the server closes the +// TCP connection after sending a GOAWAY frame. +type http2GoAwayError struct { + LastStreamID uint32 + ErrCode http2ErrCode + DebugData string +} + +func (e http2GoAwayError) Error() string { + return fmt.Sprintf("http2: server sent GOAWAY and closed the connection; LastStreamID=%v, ErrCode=%v, debug=%q", + e.LastStreamID, e.ErrCode, e.DebugData) +} + +func http2isEOFOrNetReadError(err error) bool { + if err == io.EOF { + return true + } + ne, ok := err.(*net.OpError) + return ok && ne.Op == "read" +} + +func (rl *http2clientConnReadLoop) cleanup() { + cc := rl.cc + defer cc.tconn.Close() + defer cc.t.connPool().MarkDead(cc) + defer close(cc.readerDone) + + if cc.idleTimer != nil { + cc.idleTimer.Stop() + } + + // Close any response bodies if the server closes prematurely. + // TODO: also do this if we've written the headers but not + // gotten a response yet. + err := cc.readerErr + cc.mu.Lock() + if cc.goAway != nil && http2isEOFOrNetReadError(err) { + err = http2GoAwayError{ + LastStreamID: cc.goAway.LastStreamID, + ErrCode: cc.goAway.ErrCode, + DebugData: cc.goAwayDebug, + } + } else if err == io.EOF { + err = io.ErrUnexpectedEOF + } + for _, cs := range cc.streams { + cs.bufPipe.CloseWithError(err) // no-op if already closed + select { + case cs.resc <- http2resAndError{err: err}: + default: + } + close(cs.done) + } + cc.closed = true + cc.cond.Broadcast() + cc.mu.Unlock() +} + +func (rl *http2clientConnReadLoop) run() error { + cc := rl.cc + rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse + gotReply := false // ever saw a HEADERS reply + gotSettings := false + readIdleTimeout := cc.t.ReadIdleTimeout + var t *time.Timer + if readIdleTimeout != 0 { + t = time.AfterFunc(readIdleTimeout, cc.healthCheck) + defer t.Stop() + } + for { + f, err := cc.fr.ReadFrame() + if t != nil { + t.Reset(readIdleTimeout) + } + if err != nil { + cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err) + } + if se, ok := err.(http2StreamError); ok { + if cs := cc.streamByID(se.StreamID, false); cs != nil { + cs.cc.writeStreamReset(cs.ID, se.Code, err) + cs.cc.forgetStreamID(cs.ID) + if se.Cause == nil { + se.Cause = cc.fr.errDetail + } + rl.endStreamError(cs, se) + } + continue + } else if err != nil { + return err + } + if http2VerboseLogs { + cc.vlogf("http2: Transport received %s", http2summarizeFrame(f)) + } + if !gotSettings { + if _, ok := f.(*http2SettingsFrame); !ok { + cc.logf("protocol error: received %T before a SETTINGS frame", f) + return http2ConnectionError(http2ErrCodeProtocol) + } + gotSettings = true + } + maybeIdle := false // whether frame might transition us to idle + + switch f := f.(type) { + case *http2MetaHeadersFrame: + err = rl.processHeaders(f) + maybeIdle = true + gotReply = true + case *http2DataFrame: + err = rl.processData(f) + maybeIdle = true + case *http2GoAwayFrame: + err = rl.processGoAway(f) + maybeIdle = true + case *http2RSTStreamFrame: + err = rl.processResetStream(f) + maybeIdle = true + case *http2SettingsFrame: + err = rl.processSettings(f) + case *http2PushPromiseFrame: + err = rl.processPushPromise(f) + case *http2WindowUpdateFrame: + err = rl.processWindowUpdate(f) + case *http2PingFrame: + err = rl.processPing(f) + default: + cc.logf("Transport: unhandled response frame type %T", f) + } + if err != nil { + if http2VerboseLogs { + cc.vlogf("http2: Transport conn %p received error from processing frame %v: %v", cc, http2summarizeFrame(f), err) + } + return err + } + if rl.closeWhenIdle && gotReply && maybeIdle { + cc.closeIfIdle() + } + } +} + +func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) error { + cc := rl.cc + cs := cc.streamByID(f.StreamID, false) + if cs == nil { + // We'd get here if we canceled a request while the + // server had its response still in flight. So if this + // was just something we canceled, ignore it. + return nil + } + if f.StreamEnded() { + // Issue 20521: If the stream has ended, streamByID() causes + // clientStream.done to be closed, which causes the request's bodyWriter + // to be closed with an errStreamClosed, which may be received by + // clientConn.RoundTrip before the result of processing these headers. + // Deferring stream closure allows the header processing to occur first. + // clientConn.RoundTrip may still receive the bodyWriter error first, but + // the fix for issue 16102 prioritises any response. + // + // Issue 22413: If there is no request body, we should close the + // stream before writing to cs.resc so that the stream is closed + // immediately once RoundTrip returns. + if cs.req.Body != nil { + defer cc.forgetStreamID(f.StreamID) + } else { + cc.forgetStreamID(f.StreamID) + } + } + if !cs.firstByte { + if cs.trace != nil { + // TODO(bradfitz): move first response byte earlier, + // when we first read the 9 byte header, not waiting + // until all the HEADERS+CONTINUATION frames have been + // merged. This works for now. + http2traceFirstResponseByte(cs.trace) + } + cs.firstByte = true + } + if !cs.pastHeaders { + cs.pastHeaders = true + } else { + return rl.processTrailers(cs, f) + } + + res, err := rl.handleResponse(cs, f) + if err != nil { + if _, ok := err.(http2ConnectionError); ok { + return err + } + // Any other error type is a stream error. + cs.cc.writeStreamReset(f.StreamID, http2ErrCodeProtocol, err) + cc.forgetStreamID(cs.ID) + cs.resc <- http2resAndError{err: err} + return nil // return nil from process* funcs to keep conn alive + } + if res == nil { + // (nil, nil) special case. See handleResponse docs. + return nil + } + cs.resTrailer = &res.Trailer + cs.resc <- http2resAndError{res: res} + return nil +} + +// may return error types nil, or ConnectionError. Any other error value +// is a StreamError of type ErrCodeProtocol. The returned error in that case +// is the detail. +// +// As a special case, handleResponse may return (nil, nil) to skip the +// frame (currently only used for 1xx responses). +func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http2MetaHeadersFrame) (*Response, error) { + if f.Truncated { + return nil, http2errResponseHeaderListSize + } + + status := f.PseudoValue("status") + if status == "" { + return nil, errors.New("malformed response from server: missing status pseudo header") + } + statusCode, err := strconv.Atoi(status) + if err != nil { + return nil, errors.New("malformed response from server: malformed non-numeric status pseudo header") + } + + regularFields := f.RegularFields() + strs := make([]string, len(regularFields)) + header := make(Header, len(regularFields)) + res := &Response{ + Proto: "HTTP/2.0", + ProtoMajor: 2, + Header: header, + StatusCode: statusCode, + Status: status + " " + StatusText(statusCode), + } + for _, hf := range regularFields { + key := CanonicalHeaderKey(hf.Name) + if key == "Trailer" { + t := res.Trailer + if t == nil { + t = make(Header) + res.Trailer = t + } + http2foreachHeaderElement(hf.Value, func(v string) { + t[CanonicalHeaderKey(v)] = nil + }) + } else { + vv := header[key] + if vv == nil && len(strs) > 0 { + // More than likely this will be a single-element key. + // Most headers aren't multi-valued. + // Set the capacity on strs[0] to 1, so any future append + // won't extend the slice into the other strings. + vv, strs = strs[:1:1], strs[1:] + vv[0] = hf.Value + header[key] = vv + } else { + header[key] = append(vv, hf.Value) + } + } + } + + if statusCode >= 100 && statusCode <= 199 { + cs.num1xx++ + const max1xxResponses = 5 // arbitrary bound on number of informational responses, same as net/http + if cs.num1xx > max1xxResponses { + return nil, errors.New("http2: too many 1xx informational responses") + } + if fn := cs.get1xxTraceFunc(); fn != nil { + if err := fn(statusCode, textproto.MIMEHeader(header)); err != nil { + return nil, err + } + } + if statusCode == 100 { + http2traceGot100Continue(cs.trace) + if cs.on100 != nil { + cs.on100() // forces any write delay timer to fire + } + } + cs.pastHeaders = false // do it all again + return nil, nil + } + + streamEnded := f.StreamEnded() + isHead := cs.req.Method == "HEAD" + if !streamEnded || isHead { + res.ContentLength = -1 + if clens := res.Header["Content-Length"]; len(clens) == 1 { + if cl, err := strconv.ParseUint(clens[0], 10, 63); err == nil { + res.ContentLength = int64(cl) + } else { + // TODO: care? unlike http/1, it won't mess up our framing, so it's + // more safe smuggling-wise to ignore. + } + } else if len(clens) > 1 { + // TODO: care? unlike http/1, it won't mess up our framing, so it's + // more safe smuggling-wise to ignore. + } + } + + if streamEnded || isHead { + res.Body = http2noBody + return res, nil + } + + cs.bufPipe = http2pipe{b: &http2dataBuffer{expected: res.ContentLength}} + cs.bytesRemain = res.ContentLength + res.Body = http2transportResponseBody{cs} + go cs.awaitRequestCancel(cs.req) + + if cs.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + res.Body = &http2gzipReader{body: res.Body} + res.Uncompressed = true + } + return res, nil +} + +func (rl *http2clientConnReadLoop) processTrailers(cs *http2clientStream, f *http2MetaHeadersFrame) error { + if cs.pastTrailers { + // Too many HEADERS frames for this stream. + return http2ConnectionError(http2ErrCodeProtocol) + } + cs.pastTrailers = true + if !f.StreamEnded() { + // We expect that any headers for trailers also + // has END_STREAM. + return http2ConnectionError(http2ErrCodeProtocol) + } + if len(f.PseudoFields()) > 0 { + // No pseudo header fields are defined for trailers. + // TODO: ConnectionError might be overly harsh? Check. + return http2ConnectionError(http2ErrCodeProtocol) + } + + trailer := make(Header) + for _, hf := range f.RegularFields() { + key := CanonicalHeaderKey(hf.Name) + trailer[key] = append(trailer[key], hf.Value) + } + cs.trailer = trailer + + rl.endStream(cs) + return nil +} + +// transportResponseBody is the concrete type of Transport.RoundTrip's +// Response.Body. It is an io.ReadCloser. On Read, it reads from cs.body. +// On Close it sends RST_STREAM if EOF wasn't already seen. +type http2transportResponseBody struct { + cs *http2clientStream +} + +func (b http2transportResponseBody) Read(p []byte) (n int, err error) { + cs := b.cs + cc := cs.cc + + if cs.readErr != nil { + return 0, cs.readErr + } + n, err = b.cs.bufPipe.Read(p) + if cs.bytesRemain != -1 { + if int64(n) > cs.bytesRemain { + n = int(cs.bytesRemain) + if err == nil { + err = errors.New("net/http: server replied with more than declared Content-Length; truncated") + cc.writeStreamReset(cs.ID, http2ErrCodeProtocol, err) + } + cs.readErr = err + return int(cs.bytesRemain), err + } + cs.bytesRemain -= int64(n) + if err == io.EOF && cs.bytesRemain > 0 { + err = io.ErrUnexpectedEOF + cs.readErr = err + return n, err + } + } + if n == 0 { + // No flow control tokens to send back. + return + } + + cc.mu.Lock() + defer cc.mu.Unlock() + + var connAdd, streamAdd int32 + // Check the conn-level first, before the stream-level. + if v := cc.inflow.available(); v < http2transportDefaultConnFlow/2 { + connAdd = http2transportDefaultConnFlow - v + cc.inflow.add(connAdd) + } + if err == nil { // No need to refresh if the stream is over or failed. + // Consider any buffered body data (read from the conn but not + // consumed by the client) when computing flow control for this + // stream. + v := int(cs.inflow.available()) + cs.bufPipe.Len() + if v < http2transportDefaultStreamFlow-http2transportDefaultStreamMinRefresh { + streamAdd = int32(http2transportDefaultStreamFlow - v) + cs.inflow.add(streamAdd) + } + } + if connAdd != 0 || streamAdd != 0 { + cc.wmu.Lock() + defer cc.wmu.Unlock() + if connAdd != 0 { + cc.fr.WriteWindowUpdate(0, http2mustUint31(connAdd)) + } + if streamAdd != 0 { + cc.fr.WriteWindowUpdate(cs.ID, http2mustUint31(streamAdd)) + } + cc.bw.Flush() + } + return +} + +var http2errClosedResponseBody = errors.New("http2: response body closed") + +func (b http2transportResponseBody) Close() error { + cs := b.cs + cc := cs.cc + + serverSentStreamEnd := cs.bufPipe.Err() == io.EOF + unread := cs.bufPipe.Len() + + if unread > 0 || !serverSentStreamEnd { + cc.mu.Lock() + cc.wmu.Lock() + if !serverSentStreamEnd { + cc.fr.WriteRSTStream(cs.ID, http2ErrCodeCancel) + cs.didReset = true + } + // Return connection-level flow control. + if unread > 0 { + cc.inflow.add(int32(unread)) + cc.fr.WriteWindowUpdate(0, uint32(unread)) + } + cc.bw.Flush() + cc.wmu.Unlock() + cc.mu.Unlock() + } + + cs.bufPipe.BreakWithError(http2errClosedResponseBody) + cc.forgetStreamID(cs.ID) + return nil +} + +func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { + cc := rl.cc + cs := cc.streamByID(f.StreamID, f.StreamEnded()) + data := f.Data() + if cs == nil { + cc.mu.Lock() + neverSent := cc.nextStreamID + cc.mu.Unlock() + if f.StreamID >= neverSent { + // We never asked for this. + cc.logf("http2: Transport received unsolicited DATA frame; closing connection") + return http2ConnectionError(http2ErrCodeProtocol) + } + // We probably did ask for this, but canceled. Just ignore it. + // TODO: be stricter here? only silently ignore things which + // we canceled, but not things which were closed normally + // by the peer? Tough without accumulating too much state. + + // But at least return their flow control: + if f.Length > 0 { + cc.mu.Lock() + cc.inflow.add(int32(f.Length)) + cc.mu.Unlock() + + cc.wmu.Lock() + cc.fr.WriteWindowUpdate(0, uint32(f.Length)) + cc.bw.Flush() + cc.wmu.Unlock() + } + return nil + } + if !cs.firstByte { + cc.logf("protocol error: received DATA before a HEADERS frame") + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + }) + return nil + } + if f.Length > 0 { + if cs.req.Method == "HEAD" && len(data) > 0 { + cc.logf("protocol error: received DATA on a HEAD request") + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + }) + return nil + } + // Check connection-level flow control. + cc.mu.Lock() + if cs.inflow.available() >= int32(f.Length) { + cs.inflow.take(int32(f.Length)) + } else { + cc.mu.Unlock() + return http2ConnectionError(http2ErrCodeFlowControl) + } + // Return any padded flow control now, since we won't + // refund it later on body reads. + var refund int + if pad := int(f.Length) - len(data); pad > 0 { + refund += pad + } + // Return len(data) now if the stream is already closed, + // since data will never be read. + didReset := cs.didReset + if didReset { + refund += len(data) + } + if refund > 0 { + cc.inflow.add(int32(refund)) + cc.wmu.Lock() + cc.fr.WriteWindowUpdate(0, uint32(refund)) + if !didReset { + cs.inflow.add(int32(refund)) + cc.fr.WriteWindowUpdate(cs.ID, uint32(refund)) + } + cc.bw.Flush() + cc.wmu.Unlock() + } + cc.mu.Unlock() + + if len(data) > 0 && !didReset { + if _, err := cs.bufPipe.Write(data); err != nil { + rl.endStreamError(cs, err) + return err + } + } + } + + if f.StreamEnded() { + rl.endStream(cs) + } + return nil +} + +func (rl *http2clientConnReadLoop) endStream(cs *http2clientStream) { + // TODO: check that any declared content-length matches, like + // server.go's (*stream).endStream method. + rl.endStreamError(cs, nil) +} + +func (rl *http2clientConnReadLoop) endStreamError(cs *http2clientStream, err error) { + var code func() + if err == nil { + err = io.EOF + code = cs.copyTrailers + } + if http2isConnectionCloseRequest(cs.req) { + rl.closeWhenIdle = true + } + cs.bufPipe.closeWithErrorAndCode(err, code) + + select { + case cs.resc <- http2resAndError{err: err}: + default: + } +} + +func (cs *http2clientStream) copyTrailers() { + for k, vv := range cs.trailer { + t := cs.resTrailer + if *t == nil { + *t = make(Header) + } + (*t)[k] = vv + } +} + +func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error { + cc := rl.cc + cc.t.connPool().MarkDead(cc) + if f.ErrCode != 0 { + // TODO: deal with GOAWAY more. particularly the error code + cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode) + } + cc.setGoAway(f) + return nil +} + +func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error { + cc := rl.cc + cc.mu.Lock() + defer cc.mu.Unlock() + + if f.IsAck() { + if cc.wantSettingsAck { + cc.wantSettingsAck = false + return nil + } + return http2ConnectionError(http2ErrCodeProtocol) + } + + err := f.ForeachSetting(func(s http2Setting) error { + switch s.ID { + case http2SettingMaxFrameSize: + cc.maxFrameSize = s.Val + case http2SettingMaxConcurrentStreams: + cc.maxConcurrentStreams = s.Val + case http2SettingMaxHeaderListSize: + cc.peerMaxHeaderListSize = uint64(s.Val) + case http2SettingInitialWindowSize: + // Values above the maximum flow-control + // window size of 2^31-1 MUST be treated as a + // connection error (Section 5.4.1) of type + // FLOW_CONTROL_ERROR. + if s.Val > math.MaxInt32 { + return http2ConnectionError(http2ErrCodeFlowControl) + } + + // Adjust flow control of currently-open + // frames by the difference of the old initial + // window size and this one. + delta := int32(s.Val) - int32(cc.initialWindowSize) + for _, cs := range cc.streams { + cs.flow.add(delta) + } + cc.cond.Broadcast() + + cc.initialWindowSize = s.Val + default: + // TODO(bradfitz): handle more settings? SETTINGS_HEADER_TABLE_SIZE probably. + cc.vlogf("Unhandled Setting: %v", s) + } + return nil + }) + if err != nil { + return err + } + + cc.wmu.Lock() + defer cc.wmu.Unlock() + + cc.fr.WriteSettingsAck() + cc.bw.Flush() + return cc.werr +} + +func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame) error { + cc := rl.cc + cs := cc.streamByID(f.StreamID, false) + if f.StreamID != 0 && cs == nil { + return nil + } + + cc.mu.Lock() + defer cc.mu.Unlock() + + fl := &cc.flow + if cs != nil { + fl = &cs.flow + } + if !fl.add(int32(f.Increment)) { + return http2ConnectionError(http2ErrCodeFlowControl) + } + cc.cond.Broadcast() + return nil +} + +func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) error { + cs := rl.cc.streamByID(f.StreamID, true) + if cs == nil { + // TODO: return error if server tries to RST_STEAM an idle stream + return nil + } + select { + case <-cs.peerReset: + // Already reset. + // This is the only goroutine + // which closes this, so there + // isn't a race. + default: + err := http2streamError(cs.ID, f.ErrCode) + cs.resetErr = err + close(cs.peerReset) + cs.bufPipe.CloseWithError(err) + cs.cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl + } + return nil +} + +// Ping sends a PING frame to the server and waits for the ack. +func (cc *http2ClientConn) Ping(ctx context.Context) error { + c := make(chan struct{}) + // Generate a random payload + var p [8]byte + for { + if _, err := rand.Read(p[:]); err != nil { + return err + } + cc.mu.Lock() + // check for dup before insert + if _, found := cc.pings[p]; !found { + cc.pings[p] = c + cc.mu.Unlock() + break + } + cc.mu.Unlock() + } + cc.wmu.Lock() + if err := cc.fr.WritePing(false, p); err != nil { + cc.wmu.Unlock() + return err + } + if err := cc.bw.Flush(); err != nil { + cc.wmu.Unlock() + return err + } + cc.wmu.Unlock() + select { + case <-c: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-cc.readerDone: + // connection closed + return cc.readerErr + } +} + +func (rl *http2clientConnReadLoop) processPing(f *http2PingFrame) error { + if f.IsAck() { + cc := rl.cc + cc.mu.Lock() + defer cc.mu.Unlock() + // If ack, notify listener if any + if c, ok := cc.pings[f.Data]; ok { + close(c) + delete(cc.pings, f.Data) + } + return nil + } + cc := rl.cc + cc.wmu.Lock() + defer cc.wmu.Unlock() + if err := cc.fr.WritePing(true, f.Data); err != nil { + return err + } + return cc.bw.Flush() +} + +func (rl *http2clientConnReadLoop) processPushPromise(f *http2PushPromiseFrame) error { + // We told the peer we don't want them. + // Spec says: + // "PUSH_PROMISE MUST NOT be sent if the SETTINGS_ENABLE_PUSH + // setting of the peer endpoint is set to 0. An endpoint that + // has set this setting and has received acknowledgement MUST + // treat the receipt of a PUSH_PROMISE frame as a connection + // error (Section 5.4.1) of type PROTOCOL_ERROR." + return http2ConnectionError(http2ErrCodeProtocol) +} + +func (cc *http2ClientConn) writeStreamReset(streamID uint32, code http2ErrCode, err error) { + // TODO: map err to more interesting error codes, once the + // HTTP community comes up with some. But currently for + // RST_STREAM there's no equivalent to GOAWAY frame's debug + // data, and the error codes are all pretty vague ("cancel"). + cc.wmu.Lock() + cc.fr.WriteRSTStream(streamID, code) + cc.bw.Flush() + cc.wmu.Unlock() +} + +var ( + http2errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit") + http2errRequestHeaderListSize = errors.New("http2: request header list larger than peer's advertised limit") +) + +func (cc *http2ClientConn) logf(format string, args ...interface{}) { + cc.t.logf(format, args...) +} + +func (cc *http2ClientConn) vlogf(format string, args ...interface{}) { + cc.t.vlogf(format, args...) +} + +func (t *http2Transport) vlogf(format string, args ...interface{}) { + if http2VerboseLogs { + t.logf(format, args...) + } +} + +func (t *http2Transport) logf(format string, args ...interface{}) { + log.Printf(format, args...) +} + +var http2noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil)) + +func http2strSliceContains(ss []string, s string) bool { + for _, v := range ss { + if v == s { + return true + } + } + return false +} + +type http2erringRoundTripper struct{ err error } + +func (rt http2erringRoundTripper) RoundTripErr() error { return rt.err } + +func (rt http2erringRoundTripper) RoundTrip(*Request) (*Response, error) { return nil, rt.err } + +// gzipReader wraps a response body so it can lazily +// call gzip.NewReader on the first call to Read +type http2gzipReader struct { + _ http2incomparable + body io.ReadCloser // underlying Response.Body + zr *gzip.Reader // lazily-initialized gzip reader + zerr error // sticky error +} + +func (gz *http2gzipReader) Read(p []byte) (n int, err error) { + if gz.zerr != nil { + return 0, gz.zerr + } + if gz.zr == nil { + gz.zr, err = gzip.NewReader(gz.body) + if err != nil { + gz.zerr = err + return 0, err + } + } + return gz.zr.Read(p) +} + +func (gz *http2gzipReader) Close() error { + return gz.body.Close() +} + +type http2errorReader struct{ err error } + +func (r http2errorReader) Read(p []byte) (int, error) { return 0, r.err } + +// bodyWriterState encapsulates various state around the Transport's writing +// of the request body, particularly regarding doing delayed writes of the body +// when the request contains "Expect: 100-continue". +type http2bodyWriterState struct { + cs *http2clientStream + timer *time.Timer // if non-nil, we're doing a delayed write + fnonce *sync.Once // to call fn with + fn func() // the code to run in the goroutine, writing the body + resc chan error // result of fn's execution + delay time.Duration // how long we should delay a delayed write for +} + +func (t *http2Transport) getBodyWriterState(cs *http2clientStream, body io.Reader) (s http2bodyWriterState) { + s.cs = cs + if body == nil { + return + } + resc := make(chan error, 1) + s.resc = resc + s.fn = func() { + cs.cc.mu.Lock() + cs.startedWrite = true + cs.cc.mu.Unlock() + resc <- cs.writeRequestBody(body, cs.req.Body) + } + s.delay = t.expectContinueTimeout() + if s.delay == 0 || + !httpguts.HeaderValuesContainsToken( + cs.req.Header["Expect"], + "100-continue") { + return + } + s.fnonce = new(sync.Once) + + // Arm the timer with a very large duration, which we'll + // intentionally lower later. It has to be large now because + // we need a handle to it before writing the headers, but the + // s.delay value is defined to not start until after the + // request headers were written. + const hugeDuration = 365 * 24 * time.Hour + s.timer = time.AfterFunc(hugeDuration, func() { + s.fnonce.Do(s.fn) + }) + return +} + +func (s http2bodyWriterState) cancel() { + if s.timer != nil { + if s.timer.Stop() { + s.resc <- nil + } + } +} + +func (s http2bodyWriterState) on100() { + if s.timer == nil { + // If we didn't do a delayed write, ignore the server's + // bogus 100 continue response. + return + } + s.timer.Stop() + go func() { s.fnonce.Do(s.fn) }() +} + +// scheduleBodyWrite starts writing the body, either immediately (in +// the common case) or after the delay timeout. It should not be +// called until after the headers have been written. +func (s http2bodyWriterState) scheduleBodyWrite() { + if s.timer == nil { + // We're not doing a delayed write (see + // getBodyWriterState), so just start the writing + // goroutine immediately. + go s.fn() + return + } + http2traceWait100Continue(s.cs.trace) + if s.timer.Stop() { + s.timer.Reset(s.delay) + } +} + +// isConnectionCloseRequest reports whether req should use its own +// connection for a single request and then close the connection. +func http2isConnectionCloseRequest(req *Request) bool { + return req.Close || httpguts.HeaderValuesContainsToken(req.Header["Connection"], "close") +} + +// registerHTTPSProtocol calls Transport.RegisterProtocol but +// converting panics into errors. +func http2registerHTTPSProtocol(t *Transport, rt http2noDialH2RoundTripper) (err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("%v", e) + } + }() + t.RegisterProtocol("https", rt) + return nil +} + +// noDialH2RoundTripper is a RoundTripper which only tries to complete the request +// if there's already has a cached connection to the host. +// (The field is exported so it can be accessed via reflect from net/http; tested +// by TestNoDialH2RoundTripperType) +type http2noDialH2RoundTripper struct{ *http2Transport } + +func (rt http2noDialH2RoundTripper) RoundTrip(req *Request) (*Response, error) { + res, err := rt.http2Transport.RoundTrip(req) + if http2isNoCachedConnError(err) { + return nil, ErrSkipAltProtocol + } + return res, err +} + +func (t *http2Transport) idleConnTimeout() time.Duration { + if t.t1 != nil { + return t.t1.IdleConnTimeout + } + return 0 +} + +func http2traceGetConn(req *Request, hostPort string) { + trace := httptrace.ContextClientTrace(req.Context()) + if trace == nil || trace.GetConn == nil { + return + } + trace.GetConn(hostPort) +} + +func http2traceGotConn(req *Request, cc *http2ClientConn, reused bool) { + trace := httptrace.ContextClientTrace(req.Context()) + if trace == nil || trace.GotConn == nil { + return + } + ci := httptrace.GotConnInfo{Conn: cc.tconn} + ci.Reused = reused + cc.mu.Lock() + ci.WasIdle = len(cc.streams) == 0 && reused + if ci.WasIdle && !cc.lastActive.IsZero() { + ci.IdleTime = time.Now().Sub(cc.lastActive) + } + cc.mu.Unlock() + + trace.GotConn(ci) +} + +func http2traceWroteHeaders(trace *httptrace.ClientTrace) { + if trace != nil && trace.WroteHeaders != nil { + trace.WroteHeaders() + } +} + +func http2traceGot100Continue(trace *httptrace.ClientTrace) { + if trace != nil && trace.Got100Continue != nil { + trace.Got100Continue() + } +} + +func http2traceWait100Continue(trace *httptrace.ClientTrace) { + if trace != nil && trace.Wait100Continue != nil { + trace.Wait100Continue() + } +} + +func http2traceWroteRequest(trace *httptrace.ClientTrace, err error) { + if trace != nil && trace.WroteRequest != nil { + trace.WroteRequest(httptrace.WroteRequestInfo{Err: err}) + } +} + +func http2traceFirstResponseByte(trace *httptrace.ClientTrace) { + if trace != nil && trace.GotFirstResponseByte != nil { + trace.GotFirstResponseByte() + } +} + +// writeFramer is implemented by any type that is used to write frames. +type http2writeFramer interface { + writeFrame(http2writeContext) error + + // staysWithinBuffer reports whether this writer promises that + // it will only write less than or equal to size bytes, and it + // won't Flush the write context. + staysWithinBuffer(size int) bool +} + +// writeContext is the interface needed by the various frame writer +// types below. All the writeFrame methods below are scheduled via the +// frame writing scheduler (see writeScheduler in writesched.go). +// +// This interface is implemented by *serverConn. +// +// TODO: decide whether to a) use this in the client code (which didn't +// end up using this yet, because it has a simpler design, not +// currently implementing priorities), or b) delete this and +// make the server code a bit more concrete. +type http2writeContext interface { + Framer() *http2Framer + Flush() error + CloseConn() error + // HeaderEncoder returns an HPACK encoder that writes to the + // returned buffer. + HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) +} + +// writeEndsStream reports whether w writes a frame that will transition +// the stream to a half-closed local state. This returns false for RST_STREAM, +// which closes the entire stream (not just the local half). +func http2writeEndsStream(w http2writeFramer) bool { + switch v := w.(type) { + case *http2writeData: + return v.endStream + case *http2writeResHeaders: + return v.endStream + case nil: + // This can only happen if the caller reuses w after it's + // been intentionally nil'ed out to prevent use. Keep this + // here to catch future refactoring breaking it. + panic("writeEndsStream called on nil writeFramer") + } + return false +} + +type http2flushFrameWriter struct{} + +func (http2flushFrameWriter) writeFrame(ctx http2writeContext) error { + return ctx.Flush() +} + +func (http2flushFrameWriter) staysWithinBuffer(max int) bool { return false } + +type http2writeSettings []http2Setting + +func (s http2writeSettings) staysWithinBuffer(max int) bool { + const settingSize = 6 // uint16 + uint32 + return http2frameHeaderLen+settingSize*len(s) <= max + +} + +func (s http2writeSettings) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteSettings([]http2Setting(s)...) +} + +type http2writeGoAway struct { + maxStreamID uint32 + code http2ErrCode +} + +func (p *http2writeGoAway) writeFrame(ctx http2writeContext) error { + err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil) + ctx.Flush() // ignore error: we're hanging up on them anyway + return err +} + +func (*http2writeGoAway) staysWithinBuffer(max int) bool { return false } // flushes + +type http2writeData struct { + streamID uint32 + p []byte + endStream bool +} + +func (w *http2writeData) String() string { + return fmt.Sprintf("writeData(stream=%d, p=%d, endStream=%v)", w.streamID, len(w.p), w.endStream) +} + +func (w *http2writeData) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteData(w.streamID, w.endStream, w.p) +} + +func (w *http2writeData) staysWithinBuffer(max int) bool { + return http2frameHeaderLen+len(w.p) <= max +} + +// handlerPanicRST is the message sent from handler goroutines when +// the handler panics. +type http2handlerPanicRST struct { + StreamID uint32 +} + +func (hp http2handlerPanicRST) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteRSTStream(hp.StreamID, http2ErrCodeInternal) +} + +func (hp http2handlerPanicRST) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + +func (se http2StreamError) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteRSTStream(se.StreamID, se.Code) +} + +func (se http2StreamError) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + +type http2writePingAck struct{ pf *http2PingFrame } + +func (w http2writePingAck) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WritePing(true, w.pf.Data) +} + +func (w http2writePingAck) staysWithinBuffer(max int) bool { + return http2frameHeaderLen+len(w.pf.Data) <= max +} + +type http2writeSettingsAck struct{} + +func (http2writeSettingsAck) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteSettingsAck() +} + +func (http2writeSettingsAck) staysWithinBuffer(max int) bool { return http2frameHeaderLen <= max } + +// splitHeaderBlock splits headerBlock into fragments so that each fragment fits +// in a single frame, then calls fn for each fragment. firstFrag/lastFrag are true +// for the first/last fragment, respectively. +func http2splitHeaderBlock(ctx http2writeContext, headerBlock []byte, fn func(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error) error { + // For now we're lazy and just pick the minimum MAX_FRAME_SIZE + // that all peers must support (16KB). Later we could care + // more and send larger frames if the peer advertised it, but + // there's little point. Most headers are small anyway (so we + // generally won't have CONTINUATION frames), and extra frames + // only waste 9 bytes anyway. + const maxFrameSize = 16384 + + first := true + for len(headerBlock) > 0 { + frag := headerBlock + if len(frag) > maxFrameSize { + frag = frag[:maxFrameSize] + } + headerBlock = headerBlock[len(frag):] + if err := fn(ctx, frag, first, len(headerBlock) == 0); err != nil { + return err + } + first = false + } + return nil +} + +// writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames +// for HTTP response headers or trailers from a server handler. +type http2writeResHeaders struct { + streamID uint32 + httpResCode int // 0 means no ":status" line + h Header // may be nil + trailers []string // if non-nil, which keys of h to write. nil means all. + endStream bool + + date string + contentType string + contentLength string +} + +func http2encKV(enc *hpack.Encoder, k, v string) { + if http2VerboseLogs { + log.Printf("http2: server encoding header %q = %q", k, v) + } + enc.WriteField(hpack.HeaderField{Name: k, Value: v}) +} + +func (w *http2writeResHeaders) staysWithinBuffer(max int) bool { + // TODO: this is a common one. It'd be nice to return true + // here and get into the fast path if we could be clever and + // calculate the size fast enough, or at least a conservative + // upper bound that usually fires. (Maybe if w.h and + // w.trailers are nil, so we don't need to enumerate it.) + // Otherwise I'm afraid that just calculating the length to + // answer this question would be slower than the ~2µs benefit. + return false +} + +func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error { + enc, buf := ctx.HeaderEncoder() + buf.Reset() + + if w.httpResCode != 0 { + http2encKV(enc, ":status", http2httpCodeString(w.httpResCode)) + } + + http2encodeHeaders(enc, w.h, w.trailers) + + if w.contentType != "" { + http2encKV(enc, "content-type", w.contentType) + } + if w.contentLength != "" { + http2encKV(enc, "content-length", w.contentLength) + } + if w.date != "" { + http2encKV(enc, "date", w.date) + } + + headerBlock := buf.Bytes() + if len(headerBlock) == 0 && w.trailers == nil { + panic("unexpected empty hpack") + } + + return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) +} + +func (w *http2writeResHeaders) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error { + if firstFrag { + return ctx.Framer().WriteHeaders(http2HeadersFrameParam{ + StreamID: w.streamID, + BlockFragment: frag, + EndStream: w.endStream, + EndHeaders: lastFrag, + }) + } else { + return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) + } +} + +// writePushPromise is a request to write a PUSH_PROMISE and 0+ CONTINUATION frames. +type http2writePushPromise struct { + streamID uint32 // pusher stream + method string // for :method + url *url.URL // for :scheme, :authority, :path + h Header + + // Creates an ID for a pushed stream. This runs on serveG just before + // the frame is written. The returned ID is copied to promisedID. + allocatePromisedID func() (uint32, error) + promisedID uint32 +} + +func (w *http2writePushPromise) staysWithinBuffer(max int) bool { + // TODO: see writeResHeaders.staysWithinBuffer + return false +} + +func (w *http2writePushPromise) writeFrame(ctx http2writeContext) error { + enc, buf := ctx.HeaderEncoder() + buf.Reset() + + http2encKV(enc, ":method", w.method) + http2encKV(enc, ":scheme", w.url.Scheme) + http2encKV(enc, ":authority", w.url.Host) + http2encKV(enc, ":path", w.url.RequestURI()) + http2encodeHeaders(enc, w.h, nil) + + headerBlock := buf.Bytes() + if len(headerBlock) == 0 { + panic("unexpected empty hpack") + } + + return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) +} + +func (w *http2writePushPromise) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error { + if firstFrag { + return ctx.Framer().WritePushPromise(http2PushPromiseParam{ + StreamID: w.streamID, + PromiseID: w.promisedID, + BlockFragment: frag, + EndHeaders: lastFrag, + }) + } else { + return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) + } +} + +type http2write100ContinueHeadersFrame struct { + streamID uint32 +} + +func (w http2write100ContinueHeadersFrame) writeFrame(ctx http2writeContext) error { + enc, buf := ctx.HeaderEncoder() + buf.Reset() + http2encKV(enc, ":status", "100") + return ctx.Framer().WriteHeaders(http2HeadersFrameParam{ + StreamID: w.streamID, + BlockFragment: buf.Bytes(), + EndStream: false, + EndHeaders: true, + }) +} + +func (w http2write100ContinueHeadersFrame) staysWithinBuffer(max int) bool { + // Sloppy but conservative: + return 9+2*(len(":status")+len("100")) <= max +} + +type http2writeWindowUpdate struct { + streamID uint32 // or 0 for conn-level + n uint32 +} + +func (wu http2writeWindowUpdate) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + +func (wu http2writeWindowUpdate) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n) +} + +// encodeHeaders encodes an http.Header. If keys is not nil, then (k, h[k]) +// is encoded only if k is in keys. +func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) { + if keys == nil { + sorter := http2sorterPool.Get().(*http2sorter) + // Using defer here, since the returned keys from the + // sorter.Keys method is only valid until the sorter + // is returned: + defer http2sorterPool.Put(sorter) + keys = sorter.Keys(h) + } + for _, k := range keys { + vv := h[k] + k = http2lowerHeader(k) + if !http2validWireHeaderFieldName(k) { + // Skip it as backup paranoia. Per + // golang.org/issue/14048, these should + // already be rejected at a higher level. + continue + } + isTE := k == "transfer-encoding" + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + // TODO: return an error? golang.org/issue/14048 + // For now just omit it. + continue + } + // TODO: more of "8.1.2.2 Connection-Specific Header Fields" + if isTE && v != "trailers" { + continue + } + http2encKV(enc, k, v) + } + } +} + +// WriteScheduler is the interface implemented by HTTP/2 write schedulers. +// Methods are never called concurrently. +type http2WriteScheduler interface { + // OpenStream opens a new stream in the write scheduler. + // It is illegal to call this with streamID=0 or with a streamID that is + // already open -- the call may panic. + OpenStream(streamID uint32, options http2OpenStreamOptions) + + // CloseStream closes a stream in the write scheduler. Any frames queued on + // this stream should be discarded. It is illegal to call this on a stream + // that is not open -- the call may panic. + CloseStream(streamID uint32) + + // AdjustStream adjusts the priority of the given stream. This may be called + // on a stream that has not yet been opened or has been closed. Note that + // RFC 7540 allows PRIORITY frames to be sent on streams in any state. See: + // https://tools.ietf.org/html/rfc7540#section-5.1 + AdjustStream(streamID uint32, priority http2PriorityParam) + + // Push queues a frame in the scheduler. In most cases, this will not be + // called with wr.StreamID()!=0 unless that stream is currently open. The one + // exception is RST_STREAM frames, which may be sent on idle or closed streams. + Push(wr http2FrameWriteRequest) + + // Pop dequeues the next frame to write. Returns false if no frames can + // be written. Frames with a given wr.StreamID() are Pop'd in the same + // order they are Push'd. No frames should be discarded except by CloseStream. + Pop() (wr http2FrameWriteRequest, ok bool) +} + +// OpenStreamOptions specifies extra options for WriteScheduler.OpenStream. +type http2OpenStreamOptions struct { + // PusherID is zero if the stream was initiated by the client. Otherwise, + // PusherID names the stream that pushed the newly opened stream. + PusherID uint32 +} + +// FrameWriteRequest is a request to write a frame. +type http2FrameWriteRequest struct { + // write is the interface value that does the writing, once the + // WriteScheduler has selected this frame to write. The write + // functions are all defined in write.go. + write http2writeFramer + + // stream is the stream on which this frame will be written. + // nil for non-stream frames like PING and SETTINGS. + stream *http2stream + + // done, if non-nil, must be a buffered channel with space for + // 1 message and is sent the return value from write (or an + // earlier error) when the frame has been written. + done chan error +} + +// StreamID returns the id of the stream this frame will be written to. +// 0 is used for non-stream frames such as PING and SETTINGS. +func (wr http2FrameWriteRequest) StreamID() uint32 { + if wr.stream == nil { + if se, ok := wr.write.(http2StreamError); ok { + // (*serverConn).resetStream doesn't set + // stream because it doesn't necessarily have + // one. So special case this type of write + // message. + return se.StreamID + } + return 0 + } + return wr.stream.id +} + +// isControl reports whether wr is a control frame for MaxQueuedControlFrames +// purposes. That includes non-stream frames and RST_STREAM frames. +func (wr http2FrameWriteRequest) isControl() bool { + return wr.stream == nil +} + +// DataSize returns the number of flow control bytes that must be consumed +// to write this entire frame. This is 0 for non-DATA frames. +func (wr http2FrameWriteRequest) DataSize() int { + if wd, ok := wr.write.(*http2writeData); ok { + return len(wd.p) + } + return 0 +} + +// Consume consumes min(n, available) bytes from this frame, where available +// is the number of flow control bytes available on the stream. Consume returns +// 0, 1, or 2 frames, where the integer return value gives the number of frames +// returned. +// +// If flow control prevents consuming any bytes, this returns (_, _, 0). If +// the entire frame was consumed, this returns (wr, _, 1). Otherwise, this +// returns (consumed, rest, 2), where 'consumed' contains the consumed bytes and +// 'rest' contains the remaining bytes. The consumed bytes are deducted from the +// underlying stream's flow control budget. +func (wr http2FrameWriteRequest) Consume(n int32) (http2FrameWriteRequest, http2FrameWriteRequest, int) { + var empty http2FrameWriteRequest + + // Non-DATA frames are always consumed whole. + wd, ok := wr.write.(*http2writeData) + if !ok || len(wd.p) == 0 { + return wr, empty, 1 + } + + // Might need to split after applying limits. + allowed := wr.stream.flow.available() + if n < allowed { + allowed = n + } + if wr.stream.sc.maxFrameSize < allowed { + allowed = wr.stream.sc.maxFrameSize + } + if allowed <= 0 { + return empty, empty, 0 + } + if len(wd.p) > int(allowed) { + wr.stream.flow.take(allowed) + consumed := http2FrameWriteRequest{ + stream: wr.stream, + write: &http2writeData{ + streamID: wd.streamID, + p: wd.p[:allowed], + // Even if the original had endStream set, there + // are bytes remaining because len(wd.p) > allowed, + // so we know endStream is false. + endStream: false, + }, + // Our caller is blocking on the final DATA frame, not + // this intermediate frame, so no need to wait. + done: nil, + } + rest := http2FrameWriteRequest{ + stream: wr.stream, + write: &http2writeData{ + streamID: wd.streamID, + p: wd.p[allowed:], + endStream: wd.endStream, + }, + done: wr.done, + } + return consumed, rest, 2 + } + + // The frame is consumed whole. + // NB: This cast cannot overflow because allowed is <= math.MaxInt32. + wr.stream.flow.take(int32(len(wd.p))) + return wr, empty, 1 +} + +// String is for debugging only. +func (wr http2FrameWriteRequest) String() string { + var des string + if s, ok := wr.write.(fmt.Stringer); ok { + des = s.String() + } else { + des = fmt.Sprintf("%T", wr.write) + } + return fmt.Sprintf("[FrameWriteRequest stream=%d, ch=%v, writer=%v]", wr.StreamID(), wr.done != nil, des) +} + +// replyToWriter sends err to wr.done and panics if the send must block +// This does nothing if wr.done is nil. +func (wr *http2FrameWriteRequest) replyToWriter(err error) { + if wr.done == nil { + return + } + select { + case wr.done <- err: + default: + panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wr.write)) + } + wr.write = nil // prevent use (assume it's tainted after wr.done send) +} + +// writeQueue is used by implementations of WriteScheduler. +type http2writeQueue struct { + s []http2FrameWriteRequest +} + +func (q *http2writeQueue) empty() bool { return len(q.s) == 0 } + +func (q *http2writeQueue) push(wr http2FrameWriteRequest) { + q.s = append(q.s, wr) +} + +func (q *http2writeQueue) shift() http2FrameWriteRequest { + if len(q.s) == 0 { + panic("invalid use of queue") + } + wr := q.s[0] + // TODO: less copy-happy queue. + copy(q.s, q.s[1:]) + q.s[len(q.s)-1] = http2FrameWriteRequest{} + q.s = q.s[:len(q.s)-1] + return wr +} + +// consume consumes up to n bytes from q.s[0]. If the frame is +// entirely consumed, it is removed from the queue. If the frame +// is partially consumed, the frame is kept with the consumed +// bytes removed. Returns true iff any bytes were consumed. +func (q *http2writeQueue) consume(n int32) (http2FrameWriteRequest, bool) { + if len(q.s) == 0 { + return http2FrameWriteRequest{}, false + } + consumed, rest, numresult := q.s[0].Consume(n) + switch numresult { + case 0: + return http2FrameWriteRequest{}, false + case 1: + q.shift() + case 2: + q.s[0] = rest + } + return consumed, true +} + +type http2writeQueuePool []*http2writeQueue + +// put inserts an unused writeQueue into the pool. + +// put inserts an unused writeQueue into the pool. +func (p *http2writeQueuePool) put(q *http2writeQueue) { + for i := range q.s { + q.s[i] = http2FrameWriteRequest{} + } + q.s = q.s[:0] + *p = append(*p, q) +} + +// get returns an empty writeQueue. +func (p *http2writeQueuePool) get() *http2writeQueue { + ln := len(*p) + if ln == 0 { + return new(http2writeQueue) + } + x := ln - 1 + q := (*p)[x] + (*p)[x] = nil + *p = (*p)[:x] + return q +} + +// RFC 7540, Section 5.3.5: the default weight is 16. +const http2priorityDefaultWeight = 15 // 16 = 15 + 1 + +// PriorityWriteSchedulerConfig configures a priorityWriteScheduler. +type http2PriorityWriteSchedulerConfig struct { + // MaxClosedNodesInTree controls the maximum number of closed streams to + // retain in the priority tree. Setting this to zero saves a small amount + // of memory at the cost of performance. + // + // See RFC 7540, Section 5.3.4: + // "It is possible for a stream to become closed while prioritization + // information ... is in transit. ... This potentially creates suboptimal + // prioritization, since the stream could be given a priority that is + // different from what is intended. To avoid these problems, an endpoint + // SHOULD retain stream prioritization state for a period after streams + // become closed. The longer state is retained, the lower the chance that + // streams are assigned incorrect or default priority values." + MaxClosedNodesInTree int + + // MaxIdleNodesInTree controls the maximum number of idle streams to + // retain in the priority tree. Setting this to zero saves a small amount + // of memory at the cost of performance. + // + // See RFC 7540, Section 5.3.4: + // Similarly, streams that are in the "idle" state can be assigned + // priority or become a parent of other streams. This allows for the + // creation of a grouping node in the dependency tree, which enables + // more flexible expressions of priority. Idle streams begin with a + // default priority (Section 5.3.5). + MaxIdleNodesInTree int + + // ThrottleOutOfOrderWrites enables write throttling to help ensure that + // data is delivered in priority order. This works around a race where + // stream B depends on stream A and both streams are about to call Write + // to queue DATA frames. If B wins the race, a naive scheduler would eagerly + // write as much data from B as possible, but this is suboptimal because A + // is a higher-priority stream. With throttling enabled, we write a small + // amount of data from B to minimize the amount of bandwidth that B can + // steal from A. + ThrottleOutOfOrderWrites bool +} + +// NewPriorityWriteScheduler constructs a WriteScheduler that schedules +// frames by following HTTP/2 priorities as described in RFC 7540 Section 5.3. +// If cfg is nil, default options are used. +func http2NewPriorityWriteScheduler(cfg *http2PriorityWriteSchedulerConfig) http2WriteScheduler { + if cfg == nil { + // For justification of these defaults, see: + // https://docs.google.com/document/d/1oLhNg1skaWD4_DtaoCxdSRN5erEXrH-KnLrMwEpOtFY + cfg = &http2PriorityWriteSchedulerConfig{ + MaxClosedNodesInTree: 10, + MaxIdleNodesInTree: 10, + ThrottleOutOfOrderWrites: false, + } + } + + ws := &http2priorityWriteScheduler{ + nodes: make(map[uint32]*http2priorityNode), + maxClosedNodesInTree: cfg.MaxClosedNodesInTree, + maxIdleNodesInTree: cfg.MaxIdleNodesInTree, + enableWriteThrottle: cfg.ThrottleOutOfOrderWrites, + } + ws.nodes[0] = &ws.root + if cfg.ThrottleOutOfOrderWrites { + ws.writeThrottleLimit = 1024 + } else { + ws.writeThrottleLimit = math.MaxInt32 + } + return ws +} + +type http2priorityNodeState int + +const ( + http2priorityNodeOpen http2priorityNodeState = iota + http2priorityNodeClosed + http2priorityNodeIdle +) + +// priorityNode is a node in an HTTP/2 priority tree. +// Each node is associated with a single stream ID. +// See RFC 7540, Section 5.3. +type http2priorityNode struct { + q http2writeQueue // queue of pending frames to write + id uint32 // id of the stream, or 0 for the root of the tree + weight uint8 // the actual weight is weight+1, so the value is in [1,256] + state http2priorityNodeState // open | closed | idle + bytes int64 // number of bytes written by this node, or 0 if closed + subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree + + // These links form the priority tree. + parent *http2priorityNode + kids *http2priorityNode // start of the kids list + prev, next *http2priorityNode // doubly-linked list of siblings +} + +func (n *http2priorityNode) setParent(parent *http2priorityNode) { + if n == parent { + panic("setParent to self") + } + if n.parent == parent { + return + } + // Unlink from current parent. + if parent := n.parent; parent != nil { + if n.prev == nil { + parent.kids = n.next + } else { + n.prev.next = n.next + } + if n.next != nil { + n.next.prev = n.prev + } + } + // Link to new parent. + // If parent=nil, remove n from the tree. + // Always insert at the head of parent.kids (this is assumed by walkReadyInOrder). + n.parent = parent + if parent == nil { + n.next = nil + n.prev = nil + } else { + n.next = parent.kids + n.prev = nil + if n.next != nil { + n.next.prev = n + } + parent.kids = n + } +} + +func (n *http2priorityNode) addBytes(b int64) { + n.bytes += b + for ; n != nil; n = n.parent { + n.subtreeBytes += b + } +} + +// walkReadyInOrder iterates over the tree in priority order, calling f for each node +// with a non-empty write queue. When f returns true, this function returns true and the +// walk halts. tmp is used as scratch space for sorting. +// +// f(n, openParent) takes two arguments: the node to visit, n, and a bool that is true +// if any ancestor p of n is still open (ignoring the root node). +func (n *http2priorityNode) walkReadyInOrder(openParent bool, tmp *[]*http2priorityNode, f func(*http2priorityNode, bool) bool) bool { + if !n.q.empty() && f(n, openParent) { + return true + } + if n.kids == nil { + return false + } + + // Don't consider the root "open" when updating openParent since + // we can't send data frames on the root stream (only control frames). + if n.id != 0 { + openParent = openParent || (n.state == http2priorityNodeOpen) + } + + // Common case: only one kid or all kids have the same weight. + // Some clients don't use weights; other clients (like web browsers) + // use mostly-linear priority trees. + w := n.kids.weight + needSort := false + for k := n.kids.next; k != nil; k = k.next { + if k.weight != w { + needSort = true + break + } + } + if !needSort { + for k := n.kids; k != nil; k = k.next { + if k.walkReadyInOrder(openParent, tmp, f) { + return true + } + } + return false + } + + // Uncommon case: sort the child nodes. We remove the kids from the parent, + // then re-insert after sorting so we can reuse tmp for future sort calls. + *tmp = (*tmp)[:0] + for n.kids != nil { + *tmp = append(*tmp, n.kids) + n.kids.setParent(nil) + } + sort.Sort(http2sortPriorityNodeSiblings(*tmp)) + for i := len(*tmp) - 1; i >= 0; i-- { + (*tmp)[i].setParent(n) // setParent inserts at the head of n.kids + } + for k := n.kids; k != nil; k = k.next { + if k.walkReadyInOrder(openParent, tmp, f) { + return true + } + } + return false +} + +type http2sortPriorityNodeSiblings []*http2priorityNode + +func (z http2sortPriorityNodeSiblings) Len() int { return len(z) } + +func (z http2sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] } + +func (z http2sortPriorityNodeSiblings) Less(i, k int) bool { + // Prefer the subtree that has sent fewer bytes relative to its weight. + // See sections 5.3.2 and 5.3.4. + wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes) + wk, bk := float64(z[k].weight+1), float64(z[k].subtreeBytes) + if bi == 0 && bk == 0 { + return wi >= wk + } + if bk == 0 { + return false + } + return bi/bk <= wi/wk +} + +type http2priorityWriteScheduler struct { + // root is the root of the priority tree, where root.id = 0. + // The root queues control frames that are not associated with any stream. + root http2priorityNode + + // nodes maps stream ids to priority tree nodes. + nodes map[uint32]*http2priorityNode + + // maxID is the maximum stream id in nodes. + maxID uint32 + + // lists of nodes that have been closed or are idle, but are kept in + // the tree for improved prioritization. When the lengths exceed either + // maxClosedNodesInTree or maxIdleNodesInTree, old nodes are discarded. + closedNodes, idleNodes []*http2priorityNode + + // From the config. + maxClosedNodesInTree int + maxIdleNodesInTree int + writeThrottleLimit int32 + enableWriteThrottle bool + + // tmp is scratch space for priorityNode.walkReadyInOrder to reduce allocations. + tmp []*http2priorityNode + + // pool of empty queues for reuse. + queuePool http2writeQueuePool +} + +func (ws *http2priorityWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { + // The stream may be currently idle but cannot be opened or closed. + if curr := ws.nodes[streamID]; curr != nil { + if curr.state != http2priorityNodeIdle { + panic(fmt.Sprintf("stream %d already opened", streamID)) + } + curr.state = http2priorityNodeOpen + return + } + + // RFC 7540, Section 5.3.5: + // "All streams are initially assigned a non-exclusive dependency on stream 0x0. + // Pushed streams initially depend on their associated stream. In both cases, + // streams are assigned a default weight of 16." + parent := ws.nodes[options.PusherID] + if parent == nil { + parent = &ws.root + } + n := &http2priorityNode{ + q: *ws.queuePool.get(), + id: streamID, + weight: http2priorityDefaultWeight, + state: http2priorityNodeOpen, + } + n.setParent(parent) + ws.nodes[streamID] = n + if streamID > ws.maxID { + ws.maxID = streamID + } +} + +func (ws *http2priorityWriteScheduler) CloseStream(streamID uint32) { + if streamID == 0 { + panic("violation of WriteScheduler interface: cannot close stream 0") + } + if ws.nodes[streamID] == nil { + panic(fmt.Sprintf("violation of WriteScheduler interface: unknown stream %d", streamID)) + } + if ws.nodes[streamID].state != http2priorityNodeOpen { + panic(fmt.Sprintf("violation of WriteScheduler interface: stream %d already closed", streamID)) + } + + n := ws.nodes[streamID] + n.state = http2priorityNodeClosed + n.addBytes(-n.bytes) + + q := n.q + ws.queuePool.put(&q) + n.q.s = nil + if ws.maxClosedNodesInTree > 0 { + ws.addClosedOrIdleNode(&ws.closedNodes, ws.maxClosedNodesInTree, n) + } else { + ws.removeNode(n) + } +} + +func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) { + if streamID == 0 { + panic("adjustPriority on root") + } + + // If streamID does not exist, there are two cases: + // - A closed stream that has been removed (this will have ID <= maxID) + // - An idle stream that is being used for "grouping" (this will have ID > maxID) + n := ws.nodes[streamID] + if n == nil { + if streamID <= ws.maxID || ws.maxIdleNodesInTree == 0 { + return + } + ws.maxID = streamID + n = &http2priorityNode{ + q: *ws.queuePool.get(), + id: streamID, + weight: http2priorityDefaultWeight, + state: http2priorityNodeIdle, + } + n.setParent(&ws.root) + ws.nodes[streamID] = n + ws.addClosedOrIdleNode(&ws.idleNodes, ws.maxIdleNodesInTree, n) + } + + // Section 5.3.1: A dependency on a stream that is not currently in the tree + // results in that stream being given a default priority (Section 5.3.5). + parent := ws.nodes[priority.StreamDep] + if parent == nil { + n.setParent(&ws.root) + n.weight = http2priorityDefaultWeight + return + } + + // Ignore if the client tries to make a node its own parent. + if n == parent { + return + } + + // Section 5.3.3: + // "If a stream is made dependent on one of its own dependencies, the + // formerly dependent stream is first moved to be dependent on the + // reprioritized stream's previous parent. The moved dependency retains + // its weight." + // + // That is: if parent depends on n, move parent to depend on n.parent. + for x := parent.parent; x != nil; x = x.parent { + if x == n { + parent.setParent(n.parent) + break + } + } + + // Section 5.3.3: The exclusive flag causes the stream to become the sole + // dependency of its parent stream, causing other dependencies to become + // dependent on the exclusive stream. + if priority.Exclusive { + k := parent.kids + for k != nil { + next := k.next + if k != n { + k.setParent(n) + } + k = next + } + } + + n.setParent(parent) + n.weight = priority.Weight +} + +func (ws *http2priorityWriteScheduler) Push(wr http2FrameWriteRequest) { + var n *http2priorityNode + if id := wr.StreamID(); id == 0 { + n = &ws.root + } else { + n = ws.nodes[id] + if n == nil { + // id is an idle or closed stream. wr should not be a HEADERS or + // DATA frame. However, wr can be a RST_STREAM. In this case, we + // push wr onto the root, rather than creating a new priorityNode, + // since RST_STREAM is tiny and the stream's priority is unknown + // anyway. See issue #17919. + if wr.DataSize() > 0 { + panic("add DATA on non-open stream") + } + n = &ws.root + } + } + n.q.push(wr) +} + +func (ws *http2priorityWriteScheduler) Pop() (wr http2FrameWriteRequest, ok bool) { + ws.root.walkReadyInOrder(false, &ws.tmp, func(n *http2priorityNode, openParent bool) bool { + limit := int32(math.MaxInt32) + if openParent { + limit = ws.writeThrottleLimit + } + wr, ok = n.q.consume(limit) + if !ok { + return false + } + n.addBytes(int64(wr.DataSize())) + // If B depends on A and B continuously has data available but A + // does not, gradually increase the throttling limit to allow B to + // steal more and more bandwidth from A. + if openParent { + ws.writeThrottleLimit += 1024 + if ws.writeThrottleLimit < 0 { + ws.writeThrottleLimit = math.MaxInt32 + } + } else if ws.enableWriteThrottle { + ws.writeThrottleLimit = 1024 + } + return true + }) + return wr, ok +} + +func (ws *http2priorityWriteScheduler) addClosedOrIdleNode(list *[]*http2priorityNode, maxSize int, n *http2priorityNode) { + if maxSize == 0 { + return + } + if len(*list) == maxSize { + // Remove the oldest node, then shift left. + ws.removeNode((*list)[0]) + x := (*list)[1:] + copy(*list, x) + *list = (*list)[:len(x)] + } + *list = append(*list, n) +} + +func (ws *http2priorityWriteScheduler) removeNode(n *http2priorityNode) { + for k := n.kids; k != nil; k = k.next { + k.setParent(n.parent) + } + n.setParent(nil) + delete(ws.nodes, n.id) +} + +// NewRandomWriteScheduler constructs a WriteScheduler that ignores HTTP/2 +// priorities. Control frames like SETTINGS and PING are written before DATA +// frames, but if no control frames are queued and multiple streams have queued +// HEADERS or DATA frames, Pop selects a ready stream arbitrarily. +func http2NewRandomWriteScheduler() http2WriteScheduler { + return &http2randomWriteScheduler{sq: make(map[uint32]*http2writeQueue)} +} + +type http2randomWriteScheduler struct { + // zero are frames not associated with a specific stream. + zero http2writeQueue + + // sq contains the stream-specific queues, keyed by stream ID. + // When a stream is idle, closed, or emptied, it's deleted + // from the map. + sq map[uint32]*http2writeQueue + + // pool of empty queues for reuse. + queuePool http2writeQueuePool +} + +func (ws *http2randomWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { + // no-op: idle streams are not tracked +} + +func (ws *http2randomWriteScheduler) CloseStream(streamID uint32) { + q, ok := ws.sq[streamID] + if !ok { + return + } + delete(ws.sq, streamID) + ws.queuePool.put(q) +} + +func (ws *http2randomWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) { + // no-op: priorities are ignored +} + +func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) { + id := wr.StreamID() + if id == 0 { + ws.zero.push(wr) + return + } + q, ok := ws.sq[id] + if !ok { + q = ws.queuePool.get() + ws.sq[id] = q + } + q.push(wr) +} + +func (ws *http2randomWriteScheduler) Pop() (http2FrameWriteRequest, bool) { + // Control frames first. + if !ws.zero.empty() { + return ws.zero.shift(), true + } + // Iterate over all non-idle streams until finding one that can be consumed. + for streamID, q := range ws.sq { + if wr, ok := q.consume(math.MaxInt32); ok { + if q.empty() { + delete(ws.sq, streamID) + ws.queuePool.put(q) + } + return wr, true + } + } + return http2FrameWriteRequest{}, false +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/header.go b/vendor/github.com/lesismal/llib/std/net/http/header.go new file mode 100644 index 0000000..a888141 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/header.go @@ -0,0 +1,263 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "github.com/lesismal/llib/std/net/http/httptrace" + "io" + "net/textproto" + "sort" + "strings" + "sync" + "time" +) + +// A Header represents the key-value pairs in an HTTP header. +// +// The keys should be in canonical form, as returned by +// CanonicalHeaderKey. +type Header map[string][]string + +// Add adds the key, value pair to the header. +// It appends to any existing values associated with key. +// The key is case insensitive; it is canonicalized by +// CanonicalHeaderKey. +func (h Header) Add(key, value string) { + textproto.MIMEHeader(h).Add(key, value) +} + +// Set sets the header entries associated with key to the +// single element value. It replaces any existing values +// associated with key. The key is case insensitive; it is +// canonicalized by textproto.CanonicalMIMEHeaderKey. +// To use non-canonical keys, assign to the map directly. +func (h Header) Set(key, value string) { + textproto.MIMEHeader(h).Set(key, value) +} + +// Get gets the first value associated with the given key. If +// there are no values associated with the key, Get returns "". +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. To use non-canonical keys, +// access the map directly. +func (h Header) Get(key string) string { + return textproto.MIMEHeader(h).Get(key) +} + +// Values returns all values associated with the given key. +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. To use non-canonical +// keys, access the map directly. +// The returned slice is not a copy. +func (h Header) Values(key string) []string { + return textproto.MIMEHeader(h).Values(key) +} + +// get is like Get, but key must already be in CanonicalHeaderKey form. +func (h Header) get(key string) string { + if v := h[key]; len(v) > 0 { + return v[0] + } + return "" +} + +// has reports whether h has the provided key defined, even if it's +// set to 0-length slice. +func (h Header) has(key string) bool { + _, ok := h[key] + return ok +} + +// Del deletes the values associated with key. +// The key is case insensitive; it is canonicalized by +// CanonicalHeaderKey. +func (h Header) Del(key string) { + textproto.MIMEHeader(h).Del(key) +} + +// Write writes a header in wire format. +func (h Header) Write(w io.Writer) error { + return h.write(w, nil) +} + +func (h Header) write(w io.Writer, trace *httptrace.ClientTrace) error { + return h.writeSubset(w, nil, trace) +} + +// Clone returns a copy of h or nil if h is nil. +func (h Header) Clone() Header { + if h == nil { + return nil + } + + // Find total number of values. + nv := 0 + for _, vv := range h { + nv += len(vv) + } + sv := make([]string, nv) // shared backing array for headers' values + h2 := make(Header, len(h)) + for k, vv := range h { + n := copy(sv, vv) + h2[k] = sv[:n:n] + sv = sv[n:] + } + return h2 +} + +var timeFormats = []string{ + TimeFormat, + time.RFC850, + time.ANSIC, +} + +// ParseTime parses a time header (such as the Date: header), +// trying each of the three formats allowed by HTTP/1.1: +// TimeFormat, time.RFC850, and time.ANSIC. +func ParseTime(text string) (t time.Time, err error) { + for _, layout := range timeFormats { + t, err = time.Parse(layout, text) + if err == nil { + return + } + } + return +} + +var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") + +// stringWriter implements WriteString on a Writer. +type stringWriter struct { + w io.Writer +} + +func (w stringWriter) WriteString(s string) (n int, err error) { + return w.w.Write([]byte(s)) +} + +type keyValues struct { + key string + values []string +} + +// A headerSorter implements sort.Interface by sorting a []keyValues +// by key. It's used as a pointer, so it can fit in a sort.Interface +// interface value without allocation. +type headerSorter struct { + kvs []keyValues +} + +func (s *headerSorter) Len() int { return len(s.kvs) } +func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } +func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } + +var headerSorterPool = sync.Pool{ + New: func() interface{} { return new(headerSorter) }, +} + +// sortedKeyValues returns h's keys sorted in the returned kvs +// slice. The headerSorter used to sort is also returned, for possible +// return to headerSorterCache. +func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *headerSorter) { + hs = headerSorterPool.Get().(*headerSorter) + if cap(hs.kvs) < len(h) { + hs.kvs = make([]keyValues, 0, len(h)) + } + kvs = hs.kvs[:0] + for k, vv := range h { + if !exclude[k] { + kvs = append(kvs, keyValues{k, vv}) + } + } + hs.kvs = kvs + sort.Sort(hs) + return kvs, hs +} + +// WriteSubset writes a header in wire format. +// If exclude is not nil, keys where exclude[key] == true are not written. +// Keys are not canonicalized before checking the exclude map. +func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { + return h.writeSubset(w, exclude, nil) +} + +func (h Header) writeSubset(w io.Writer, exclude map[string]bool, trace *httptrace.ClientTrace) error { + ws, ok := w.(io.StringWriter) + if !ok { + ws = stringWriter{w} + } + kvs, sorter := h.sortedKeyValues(exclude) + var formattedVals []string + for _, kv := range kvs { + for _, v := range kv.values { + v = headerNewlineToSpace.Replace(v) + v = textproto.TrimString(v) + for _, s := range []string{kv.key, ": ", v, "\r\n"} { + if _, err := ws.WriteString(s); err != nil { + headerSorterPool.Put(sorter) + return err + } + } + if trace != nil && trace.WroteHeaderField != nil { + formattedVals = append(formattedVals, v) + } + } + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField(kv.key, formattedVals) + formattedVals = nil + } + } + headerSorterPool.Put(sorter) + return nil +} + +// CanonicalHeaderKey returns the canonical format of the +// header key s. The canonicalization converts the first +// letter and any letter following a hyphen to upper case; +// the rest are converted to lowercase. For example, the +// canonical key for "accept-encoding" is "Accept-Encoding". +// If s contains a space or invalid header field bytes, it is +// returned without modifications. +func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } + +// hasToken reports whether token appears with v, ASCII +// case-insensitive, with space or comma boundaries. +// token must be all lowercase. +// v may contain mixed cased. +func hasToken(v, token string) bool { + if len(token) > len(v) || token == "" { + return false + } + if v == token { + return true + } + for sp := 0; sp <= len(v)-len(token); sp++ { + // Check that first character is good. + // The token is ASCII, so checking only a single byte + // is sufficient. We skip this potential starting + // position if both the first byte and its potential + // ASCII uppercase equivalent (b|0x20) don't match. + // False positives ('^' => '~') are caught by EqualFold. + if b := v[sp]; b != token[0] && b|0x20 != token[0] { + continue + } + // Check that start pos is on a valid token boundary. + if sp > 0 && !isTokenBoundary(v[sp-1]) { + continue + } + // Check that end pos is on a valid token boundary. + if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) { + continue + } + if strings.EqualFold(v[sp:sp+len(token)], token) { + return true + } + } + return false +} + +func isTokenBoundary(b byte) bool { + return b == ' ' || b == ',' || b == '\t' +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/header_test.go b/vendor/github.com/lesismal/llib/std/net/http/header_test.go new file mode 100644 index 0000000..4789362 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/header_test.go @@ -0,0 +1,253 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "internal/race" + "reflect" + "runtime" + "testing" + "time" +) + +var headerWriteTests = []struct { + h Header + exclude map[string]bool + expected string +}{ + {Header{}, nil, ""}, + { + Header{ + "Content-Type": {"text/html; charset=UTF-8"}, + "Content-Length": {"0"}, + }, + nil, + "Content-Length: 0\r\nContent-Type: text/html; charset=UTF-8\r\n", + }, + { + Header{ + "Content-Length": {"0", "1", "2"}, + }, + nil, + "Content-Length: 0\r\nContent-Length: 1\r\nContent-Length: 2\r\n", + }, + { + Header{ + "Expires": {"-1"}, + "Content-Length": {"0"}, + "Content-Encoding": {"gzip"}, + }, + map[string]bool{"Content-Length": true}, + "Content-Encoding: gzip\r\nExpires: -1\r\n", + }, + { + Header{ + "Expires": {"-1"}, + "Content-Length": {"0", "1", "2"}, + "Content-Encoding": {"gzip"}, + }, + map[string]bool{"Content-Length": true}, + "Content-Encoding: gzip\r\nExpires: -1\r\n", + }, + { + Header{ + "Expires": {"-1"}, + "Content-Length": {"0"}, + "Content-Encoding": {"gzip"}, + }, + map[string]bool{"Content-Length": true, "Expires": true, "Content-Encoding": true}, + "", + }, + { + Header{ + "Nil": nil, + "Empty": {}, + "Blank": {""}, + "Double-Blank": {"", ""}, + }, + nil, + "Blank: \r\nDouble-Blank: \r\nDouble-Blank: \r\n", + }, + // Tests header sorting when over the insertion sort threshold side: + { + Header{ + "k1": {"1a", "1b"}, + "k2": {"2a", "2b"}, + "k3": {"3a", "3b"}, + "k4": {"4a", "4b"}, + "k5": {"5a", "5b"}, + "k6": {"6a", "6b"}, + "k7": {"7a", "7b"}, + "k8": {"8a", "8b"}, + "k9": {"9a", "9b"}, + }, + map[string]bool{"k5": true}, + "k1: 1a\r\nk1: 1b\r\nk2: 2a\r\nk2: 2b\r\nk3: 3a\r\nk3: 3b\r\n" + + "k4: 4a\r\nk4: 4b\r\nk6: 6a\r\nk6: 6b\r\n" + + "k7: 7a\r\nk7: 7b\r\nk8: 8a\r\nk8: 8b\r\nk9: 9a\r\nk9: 9b\r\n", + }, +} + +func TestHeaderWrite(t *testing.T) { + var buf bytes.Buffer + for i, test := range headerWriteTests { + test.h.WriteSubset(&buf, test.exclude) + if buf.String() != test.expected { + t.Errorf("#%d:\n got: %q\nwant: %q", i, buf.String(), test.expected) + } + buf.Reset() + } +} + +var parseTimeTests = []struct { + h Header + err bool +}{ + {Header{"Date": {""}}, true}, + {Header{"Date": {"invalid"}}, true}, + {Header{"Date": {"1994-11-06T08:49:37Z00:00"}}, true}, + {Header{"Date": {"Sun, 06 Nov 1994 08:49:37 GMT"}}, false}, + {Header{"Date": {"Sunday, 06-Nov-94 08:49:37 GMT"}}, false}, + {Header{"Date": {"Sun Nov 6 08:49:37 1994"}}, false}, +} + +func TestParseTime(t *testing.T) { + expect := time.Date(1994, 11, 6, 8, 49, 37, 0, time.UTC) + for i, test := range parseTimeTests { + d, err := ParseTime(test.h.Get("Date")) + if err != nil { + if !test.err { + t.Errorf("#%d:\n got err: %v", i, err) + } + continue + } + if test.err { + t.Errorf("#%d:\n should err", i) + continue + } + if !expect.Equal(d) { + t.Errorf("#%d:\n got: %v\nwant: %v", i, d, expect) + } + } +} + +type hasTokenTest struct { + header string + token string + want bool +} + +var hasTokenTests = []hasTokenTest{ + {"", "", false}, + {"", "foo", false}, + {"foo", "foo", true}, + {"foo ", "foo", true}, + {" foo", "foo", true}, + {" foo ", "foo", true}, + {"foo,bar", "foo", true}, + {"bar,foo", "foo", true}, + {"bar, foo", "foo", true}, + {"bar,foo, baz", "foo", true}, + {"bar, foo,baz", "foo", true}, + {"bar,foo, baz", "foo", true}, + {"bar, foo, baz", "foo", true}, + {"FOO", "foo", true}, + {"FOO ", "foo", true}, + {" FOO", "foo", true}, + {" FOO ", "foo", true}, + {"FOO,BAR", "foo", true}, + {"BAR,FOO", "foo", true}, + {"BAR, FOO", "foo", true}, + {"BAR,FOO, baz", "foo", true}, + {"BAR, FOO,BAZ", "foo", true}, + {"BAR,FOO, BAZ", "foo", true}, + {"BAR, FOO, BAZ", "foo", true}, + {"foobar", "foo", false}, + {"barfoo ", "foo", false}, +} + +func TestHasToken(t *testing.T) { + for _, tt := range hasTokenTests { + if hasToken(tt.header, tt.token) != tt.want { + t.Errorf("hasToken(%q, %q) = %v; want %v", tt.header, tt.token, !tt.want, tt.want) + } + } +} + +func TestNilHeaderClone(t *testing.T) { + t1 := Header(nil) + t2 := t1.Clone() + if t2 != nil { + t.Errorf("cloned header does not match original: got: %+v; want: %+v", t2, nil) + } +} + +var testHeader = Header{ + "Content-Length": {"123"}, + "Content-Type": {"text/plain"}, + "Date": {"some date at some time Z"}, + "Server": {DefaultUserAgent}, +} + +var buf bytes.Buffer + +func BenchmarkHeaderWriteSubset(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + buf.Reset() + testHeader.WriteSubset(&buf, nil) + } +} + +func TestHeaderWriteSubsetAllocs(t *testing.T) { + if testing.Short() { + t.Skip("skipping alloc test in short mode") + } + if race.Enabled { + t.Skip("skipping test under race detector") + } + if runtime.GOMAXPROCS(0) > 1 { + t.Skip("skipping; GOMAXPROCS>1") + } + n := testing.AllocsPerRun(100, func() { + buf.Reset() + testHeader.WriteSubset(&buf, nil) + }) + if n > 0 { + t.Errorf("allocs = %g; want 0", n) + } +} + +// Issue 34878: test that every call to +// cloneOrMakeHeader never returns a nil Header. +func TestCloneOrMakeHeader(t *testing.T) { + tests := []struct { + name string + in, want Header + }{ + {"nil", nil, Header{}}, + {"empty", Header{}, Header{}}, + { + name: "non-empty", + in: Header{"foo": {"bar"}}, + want: Header{"foo": {"bar"}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := cloneOrMakeHeader(tt.in) + if got == nil { + t.Fatal("unexpected nil Header") + } + if !reflect.DeepEqual(got, tt.want) { + t.Fatalf("Got: %#v\nWant: %#v", got, tt.want) + } + got.Add("A", "B") + got.Get("A") + }) + } +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/http.go b/vendor/github.com/lesismal/llib/std/net/http/http.go new file mode 100644 index 0000000..4c5054b --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/http.go @@ -0,0 +1,168 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:generate bundle -o=h2_bundle.go -prefix=http2 -tags=!nethttpomithttp2 golang.org/x/net/http2 + +package http + +import ( + "io" + "strconv" + "strings" + "time" + "unicode/utf8" + + "golang.org/x/net/http/httpguts" +) + +// incomparable is a zero-width, non-comparable type. Adding it to a struct +// makes that struct also non-comparable, and generally doesn't add +// any size (as long as it's first). +type incomparable [0]func() + +// maxInt64 is the effective "infinite" value for the Server and +// Transport's byte-limiting readers. +const maxInt64 = 1<<63 - 1 + +// aLongTimeAgo is a non-zero time, far in the past, used for +// immediate cancellation of network operations. +var aLongTimeAgo = time.Unix(1, 0) + +// omitBundledHTTP2 is set by omithttp2.go when the nethttpomithttp2 +// build tag is set. That means h2_bundle.go isn't compiled in and we +// shouldn't try to use it. +var omitBundledHTTP2 bool + +// TODO(bradfitz): move common stuff here. The other files have accumulated +// generic http stuff in random places. + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. +type contextKey struct { + name string +} + +func (k *contextKey) String() string { return "net/http context value " + k.name } + +// Given a string of the form "host", "host:port", or "[ipv6::address]:port", +// return true if the string includes a port. +func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") } + +// removeEmptyPort strips the empty port in ":port" to "" +// as mandated by RFC 3986 Section 6.2.3. +func removeEmptyPort(host string) string { + if hasPort(host) { + return strings.TrimSuffix(host, ":") + } + return host +} + +func isNotToken(r rune) bool { + return !httpguts.IsTokenRune(r) +} + +func isASCII(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] >= utf8.RuneSelf { + return false + } + } + return true +} + +// stringContainsCTLByte reports whether s contains any ASCII control character. +func stringContainsCTLByte(s string) bool { + for i := 0; i < len(s); i++ { + b := s[i] + if b < ' ' || b == 0x7f { + return true + } + } + return false +} + +func hexEscapeNonASCII(s string) string { + newLen := 0 + for i := 0; i < len(s); i++ { + if s[i] >= utf8.RuneSelf { + newLen += 3 + } else { + newLen++ + } + } + if newLen == len(s) { + return s + } + b := make([]byte, 0, newLen) + for i := 0; i < len(s); i++ { + if s[i] >= utf8.RuneSelf { + b = append(b, '%') + b = strconv.AppendInt(b, int64(s[i]), 16) + } else { + b = append(b, s[i]) + } + } + return string(b) +} + +// NoBody is an io.ReadCloser with no bytes. Read always returns EOF +// and Close always returns nil. It can be used in an outgoing client +// request to explicitly signal that a request has zero bytes. +// An alternative, however, is to simply set Request.Body to nil. +var NoBody = noBody{} + +type noBody struct{} + +func (noBody) Read([]byte) (int, error) { return 0, io.EOF } +func (noBody) Close() error { return nil } +func (noBody) WriteTo(io.Writer) (int64, error) { return 0, nil } + +var ( + // verify that an io.Copy from NoBody won't require a buffer: + _ io.WriterTo = NoBody + _ io.ReadCloser = NoBody +) + +// PushOptions describes options for Pusher.Push. +type PushOptions struct { + // Method specifies the HTTP method for the promised request. + // If set, it must be "GET" or "HEAD". Empty means "GET". + Method string + + // Header specifies additional promised request headers. This cannot + // include HTTP/2 pseudo header fields like ":path" and ":scheme", + // which will be added automatically. + Header Header +} + +// Pusher is the interface implemented by ResponseWriters that support +// HTTP/2 server push. For more background, see +// https://tools.ietf.org/html/rfc7540#section-8.2. +type Pusher interface { + // Push initiates an HTTP/2 server push. This constructs a synthetic + // request using the given target and options, serializes that request + // into a PUSH_PROMISE frame, then dispatches that request using the + // server's request handler. If opts is nil, default options are used. + // + // The target must either be an absolute path (like "/path") or an absolute + // URL that contains a valid host and the same scheme as the parent request. + // If the target is a path, it will inherit the scheme and host of the + // parent request. + // + // The HTTP/2 spec disallows recursive pushes and cross-authority pushes. + // Push may or may not detect these invalid pushes; however, invalid + // pushes will be detected and canceled by conforming clients. + // + // Handlers that wish to push URL X should call Push before sending any + // data that may trigger a request for URL X. This avoids a race where the + // client issues requests for X before receiving the PUSH_PROMISE for X. + // + // Push will run in a separate goroutine making the order of arrival + // non-deterministic. Any required synchronization needs to be implemented + // by the caller. + // + // Push returns ErrNotSupported if the client has disabled push or if push + // is not supported on the underlying connection. + Push(target string, opts *PushOptions) error +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/http_test.go b/vendor/github.com/lesismal/llib/std/net/http/http_test.go new file mode 100644 index 0000000..3f1d7ce --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/http_test.go @@ -0,0 +1,158 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests of internal functions and things with no better homes. + +package http + +import ( + "bytes" + "internal/testenv" + "net/url" + "os/exec" + "reflect" + "testing" +) + +func TestForeachHeaderElement(t *testing.T) { + tests := []struct { + in string + want []string + }{ + {"Foo", []string{"Foo"}}, + {" Foo", []string{"Foo"}}, + {"Foo ", []string{"Foo"}}, + {" Foo ", []string{"Foo"}}, + + {"foo", []string{"foo"}}, + {"anY-cAsE", []string{"anY-cAsE"}}, + + {"", nil}, + {",,,, , ,, ,,, ,", nil}, + + {" Foo,Bar, Baz,lower,,Quux ", []string{"Foo", "Bar", "Baz", "lower", "Quux"}}, + } + for _, tt := range tests { + var got []string + foreachHeaderElement(tt.in, func(v string) { + got = append(got, v) + }) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("foreachHeaderElement(%q) = %q; want %q", tt.in, got, tt.want) + } + } +} + +func TestCleanHost(t *testing.T) { + tests := []struct { + in, want string + }{ + {"www.google.com", "www.google.com"}, + {"www.google.com foo", "www.google.com"}, + {"www.google.com/foo", "www.google.com"}, + {" first character is a space", ""}, + {"[1::6]:8080", "[1::6]:8080"}, + + // Punycode: + {"гофер.рф/foo", "xn--c1ae0ajs.xn--p1ai"}, + {"bücher.de", "xn--bcher-kva.de"}, + {"bücher.de:8080", "xn--bcher-kva.de:8080"}, + // Verify we convert to lowercase before punycode: + {"BÜCHER.de", "xn--bcher-kva.de"}, + {"BÜCHER.de:8080", "xn--bcher-kva.de:8080"}, + // Verify we normalize to NFC before punycode: + {"gophér.nfc", "xn--gophr-esa.nfc"}, // NFC input; no work needed + {"goph\u0065\u0301r.nfd", "xn--gophr-esa.nfd"}, // NFD input + } + for _, tt := range tests { + got := cleanHost(tt.in) + if tt.want != got { + t.Errorf("cleanHost(%q) = %q, want %q", tt.in, got, tt.want) + } + } +} + +// Test that cmd/go doesn't link in the HTTP server. +// +// This catches accidental dependencies between the HTTP transport and +// server code. +func TestCmdGoNoHTTPServer(t *testing.T) { + t.Parallel() + goBin := testenv.GoToolPath(t) + out, err := exec.Command(goBin, "tool", "nm", goBin).CombinedOutput() + if err != nil { + t.Fatalf("go tool nm: %v: %s", err, out) + } + wantSym := map[string]bool{ + // Verify these exist: (sanity checking this test) + "net/http.(*Client).do": true, + "net/http.(*Transport).RoundTrip": true, + + // Verify these don't exist: + "net/http.http2Server": false, + "net/http.(*Server).Serve": false, + "net/http.(*ServeMux).ServeHTTP": false, + "net/http.DefaultServeMux": false, + } + for sym, want := range wantSym { + got := bytes.Contains(out, []byte(sym)) + if !want && got { + t.Errorf("cmd/go unexpectedly links in HTTP server code; found symbol %q in cmd/go", sym) + } + if want && !got { + t.Errorf("expected to find symbol %q in cmd/go; not found", sym) + } + } +} + +// Tests that the nethttpomithttp2 build tag doesn't rot too much, +// even if there's not a regular builder on it. +func TestOmitHTTP2(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + t.Parallel() + goTool := testenv.GoToolPath(t) + out, err := exec.Command(goTool, "test", "-short", "-tags=nethttpomithttp2", "net/http").CombinedOutput() + if err != nil { + t.Fatalf("go test -short failed: %v, %s", err, out) + } +} + +// Tests that the nethttpomithttp2 build tag at least type checks +// in short mode. +// The TestOmitHTTP2 test above actually runs tests (in long mode). +func TestOmitHTTP2Vet(t *testing.T) { + t.Parallel() + goTool := testenv.GoToolPath(t) + out, err := exec.Command(goTool, "vet", "-tags=nethttpomithttp2", "net/http").CombinedOutput() + if err != nil { + t.Fatalf("go vet failed: %v, %s", err, out) + } +} + +var valuesCount int + +func BenchmarkCopyValues(b *testing.B) { + b.ReportAllocs() + src := url.Values{ + "a": {"1", "2", "3", "4", "5"}, + "b": {"2", "2", "3", "4", "5"}, + "c": {"3", "2", "3", "4", "5"}, + "d": {"4", "2", "3", "4", "5"}, + "e": {"1", "1", "2", "3", "4", "5", "6", "7", "abcdef", "l", "a", "b", "c", "d", "z"}, + "j": {"1", "2"}, + "m": nil, + } + for i := 0; i < b.N; i++ { + dst := url.Values{"a": {"b"}, "b": {"2"}, "c": {"3"}, "d": {"4"}, "j": nil, "m": {"x"}} + copyValues(dst, src) + if valuesCount = len(dst["a"]); valuesCount != 6 { + b.Fatalf(`%d items in dst["a"] but expected 6`, valuesCount) + } + } + if valuesCount == 0 { + b.Fatal("Benchmark wasn't run") + } +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/httptest/example_test.go b/vendor/github.com/lesismal/llib/std/net/http/httptest/example_test.go new file mode 100644 index 0000000..aa720db --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/httptest/example_test.go @@ -0,0 +1,99 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest_test + +import ( + "fmt" + "github.com/lesismal/llib/std/net/http/httptest" + "io" + "log" + "net/http" +) + +func ExampleResponseRecorder() { + handler := func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "Hello World!") + } + + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + w := httptest.NewRecorder() + handler(w, req) + + resp := w.Result() + body, _ := io.ReadAll(resp.Body) + + fmt.Println(resp.StatusCode) + fmt.Println(resp.Header.Get("Content-Type")) + fmt.Println(string(body)) + + // Output: + // 200 + // text/html; charset=utf-8 + // Hello World! +} + +func ExampleServer() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, client") + })) + defer ts.Close() + + res, err := http.Get(ts.URL) + if err != nil { + log.Fatal(err) + } + greeting, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%s", greeting) + // Output: Hello, client +} + +func ExampleServer_hTTP2() { + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Hello, %s", r.Proto) + })) + ts.EnableHTTP2 = true + ts.StartTLS() + defer ts.Close() + + res, err := ts.Client().Get(ts.URL) + if err != nil { + log.Fatal(err) + } + greeting, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + log.Fatal(err) + } + fmt.Printf("%s", greeting) + + // Output: Hello, HTTP/2.0 +} + +func ExampleNewTLSServer() { + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, client") + })) + defer ts.Close() + + client := ts.Client() + res, err := client.Get(ts.URL) + if err != nil { + log.Fatal(err) + } + + greeting, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%s", greeting) + // Output: Hello, client +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/httptest/httptest.go b/vendor/github.com/lesismal/llib/std/net/http/httptest/httptest.go new file mode 100644 index 0000000..9bedefd --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/httptest/httptest.go @@ -0,0 +1,90 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package httptest provides utilities for HTTP testing. +package httptest + +import ( + "bufio" + "bytes" + "crypto/tls" + "io" + "net/http" + "strings" +) + +// NewRequest returns a new incoming server Request, suitable +// for passing to an http.Handler for testing. +// +// The target is the RFC 7230 "request-target": it may be either a +// path or an absolute URL. If target is an absolute URL, the host name +// from the URL is used. Otherwise, "example.com" is used. +// +// The TLS field is set to a non-nil dummy value if target has scheme +// "https". +// +// The Request.Proto is always HTTP/1.1. +// +// An empty method means "GET". +// +// The provided body may be nil. If the body is of type *bytes.Reader, +// *strings.Reader, or *bytes.Buffer, the Request.ContentLength is +// set. +// +// NewRequest panics on error for ease of use in testing, where a +// panic is acceptable. +// +// To generate a client HTTP request instead of a server request, see +// the NewRequest function in the net/http package. +func NewRequest(method, target string, body io.Reader) *http.Request { + if method == "" { + method = "GET" + } + req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(method + " " + target + " HTTP/1.0\r\n\r\n"))) + if err != nil { + panic("invalid NewRequest arguments; " + err.Error()) + } + + // HTTP/1.0 was used above to avoid needing a Host field. Change it to 1.1 here. + req.Proto = "HTTP/1.1" + req.ProtoMinor = 1 + req.Close = false + + if body != nil { + switch v := body.(type) { + case *bytes.Buffer: + req.ContentLength = int64(v.Len()) + case *bytes.Reader: + req.ContentLength = int64(v.Len()) + case *strings.Reader: + req.ContentLength = int64(v.Len()) + default: + req.ContentLength = -1 + } + if rc, ok := body.(io.ReadCloser); ok { + req.Body = rc + } else { + req.Body = io.NopCloser(body) + } + } + + // 192.0.2.0/24 is "TEST-NET" in RFC 5737 for use solely in + // documentation and example source code and should not be + // used publicly. + req.RemoteAddr = "192.0.2.1:1234" + + if req.Host == "" { + req.Host = "example.com" + } + + if strings.HasPrefix(target, "https://") { + req.TLS = &tls.ConnectionState{ + Version: tls.VersionTLS12, + HandshakeComplete: true, + ServerName: req.Host, + } + } + + return req +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/httptest/httptest_test.go b/vendor/github.com/lesismal/llib/std/net/http/httptest/httptest_test.go new file mode 100644 index 0000000..071add6 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/httptest/httptest_test.go @@ -0,0 +1,179 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest + +import ( + "crypto/tls" + "io" + "net/http" + "net/url" + "reflect" + "strings" + "testing" +) + +func TestNewRequest(t *testing.T) { + for _, tt := range [...]struct { + name string + + method, uri string + body io.Reader + + want *http.Request + wantBody string + }{ + { + name: "Empty method means GET", + method: "", + uri: "/", + body: nil, + want: &http.Request{ + Method: "GET", + Host: "example.com", + URL: &url.URL{Path: "/"}, + Header: http.Header{}, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + RemoteAddr: "192.0.2.1:1234", + RequestURI: "/", + }, + wantBody: "", + }, + + { + name: "GET with full URL", + method: "GET", + uri: "http://foo.com/path/%2f/bar/", + body: nil, + want: &http.Request{ + Method: "GET", + Host: "foo.com", + URL: &url.URL{ + Scheme: "http", + Path: "/path///bar/", + RawPath: "/path/%2f/bar/", + Host: "foo.com", + }, + Header: http.Header{}, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + RemoteAddr: "192.0.2.1:1234", + RequestURI: "http://foo.com/path/%2f/bar/", + }, + wantBody: "", + }, + + { + name: "GET with full https URL", + method: "GET", + uri: "https://foo.com/path/", + body: nil, + want: &http.Request{ + Method: "GET", + Host: "foo.com", + URL: &url.URL{ + Scheme: "https", + Path: "/path/", + Host: "foo.com", + }, + Header: http.Header{}, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + RemoteAddr: "192.0.2.1:1234", + RequestURI: "https://foo.com/path/", + TLS: &tls.ConnectionState{ + Version: tls.VersionTLS12, + HandshakeComplete: true, + ServerName: "foo.com", + }, + }, + wantBody: "", + }, + + { + name: "Post with known length", + method: "POST", + uri: "/", + body: strings.NewReader("foo"), + want: &http.Request{ + Method: "POST", + Host: "example.com", + URL: &url.URL{Path: "/"}, + Header: http.Header{}, + Proto: "HTTP/1.1", + ContentLength: 3, + ProtoMajor: 1, + ProtoMinor: 1, + RemoteAddr: "192.0.2.1:1234", + RequestURI: "/", + }, + wantBody: "foo", + }, + + { + name: "Post with unknown length", + method: "POST", + uri: "/", + body: struct{ io.Reader }{strings.NewReader("foo")}, + want: &http.Request{ + Method: "POST", + Host: "example.com", + URL: &url.URL{Path: "/"}, + Header: http.Header{}, + Proto: "HTTP/1.1", + ContentLength: -1, + ProtoMajor: 1, + ProtoMinor: 1, + RemoteAddr: "192.0.2.1:1234", + RequestURI: "/", + }, + wantBody: "foo", + }, + + { + name: "OPTIONS *", + method: "OPTIONS", + uri: "*", + want: &http.Request{ + Method: "OPTIONS", + Host: "example.com", + URL: &url.URL{Path: "*"}, + Header: http.Header{}, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + RemoteAddr: "192.0.2.1:1234", + RequestURI: "*", + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + got := NewRequest(tt.method, tt.uri, tt.body) + slurp, err := io.ReadAll(got.Body) + if err != nil { + t.Errorf("ReadAll: %v", err) + } + if string(slurp) != tt.wantBody { + t.Errorf("Body = %q; want %q", slurp, tt.wantBody) + } + got.Body = nil // before DeepEqual + if !reflect.DeepEqual(got.URL, tt.want.URL) { + t.Errorf("Request.URL mismatch:\n got: %#v\nwant: %#v", got.URL, tt.want.URL) + } + if !reflect.DeepEqual(got.Header, tt.want.Header) { + t.Errorf("Request.Header mismatch:\n got: %#v\nwant: %#v", got.Header, tt.want.Header) + } + if !reflect.DeepEqual(got.TLS, tt.want.TLS) { + t.Errorf("Request.TLS mismatch:\n got: %#v\nwant: %#v", got.TLS, tt.want.TLS) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Request mismatch:\n got: %#v\nwant: %#v", got, tt.want) + } + }) + } +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/httptest/recorder.go b/vendor/github.com/lesismal/llib/std/net/http/httptest/recorder.go new file mode 100644 index 0000000..2428482 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/httptest/recorder.go @@ -0,0 +1,234 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/textproto" + "strconv" + "strings" + + "golang.org/x/net/http/httpguts" +) + +// ResponseRecorder is an implementation of http.ResponseWriter that +// records its mutations for later inspection in tests. +type ResponseRecorder struct { + // Code is the HTTP response code set by WriteHeader. + // + // Note that if a Handler never calls WriteHeader or Write, + // this might end up being 0, rather than the implicit + // http.StatusOK. To get the implicit value, use the Result + // method. + Code int + + // HeaderMap contains the headers explicitly set by the Handler. + // It is an internal detail. + // + // Deprecated: HeaderMap exists for historical compatibility + // and should not be used. To access the headers returned by a handler, + // use the Response.Header map as returned by the Result method. + HeaderMap http.Header + + // Body is the buffer to which the Handler's Write calls are sent. + // If nil, the Writes are silently discarded. + Body *bytes.Buffer + + // Flushed is whether the Handler called Flush. + Flushed bool + + result *http.Response // cache of Result's return value + snapHeader http.Header // snapshot of HeaderMap at first Write + wroteHeader bool +} + +// NewRecorder returns an initialized ResponseRecorder. +func NewRecorder() *ResponseRecorder { + return &ResponseRecorder{ + HeaderMap: make(http.Header), + Body: new(bytes.Buffer), + Code: 200, + } +} + +// DefaultRemoteAddr is the default remote address to return in RemoteAddr if +// an explicit DefaultRemoteAddr isn't set on ResponseRecorder. +const DefaultRemoteAddr = "1.2.3.4" + +// Header implements http.ResponseWriter. It returns the response +// headers to mutate within a handler. To test the headers that were +// written after a handler completes, use the Result method and see +// the returned Response value's Header. +func (rw *ResponseRecorder) Header() http.Header { + m := rw.HeaderMap + if m == nil { + m = make(http.Header) + rw.HeaderMap = m + } + return m +} + +// writeHeader writes a header if it was not written yet and +// detects Content-Type if needed. +// +// bytes or str are the beginning of the response body. +// We pass both to avoid unnecessarily generate garbage +// in rw.WriteString which was created for performance reasons. +// Non-nil bytes win. +func (rw *ResponseRecorder) writeHeader(b []byte, str string) { + if rw.wroteHeader { + return + } + if len(str) > 512 { + str = str[:512] + } + + m := rw.Header() + + _, hasType := m["Content-Type"] + hasTE := m.Get("Transfer-Encoding") != "" + if !hasType && !hasTE { + if b == nil { + b = []byte(str) + } + m.Set("Content-Type", http.DetectContentType(b)) + } + + rw.WriteHeader(200) +} + +// Write implements http.ResponseWriter. The data in buf is written to +// rw.Body, if not nil. +func (rw *ResponseRecorder) Write(buf []byte) (int, error) { + rw.writeHeader(buf, "") + if rw.Body != nil { + rw.Body.Write(buf) + } + return len(buf), nil +} + +// WriteString implements io.StringWriter. The data in str is written +// to rw.Body, if not nil. +func (rw *ResponseRecorder) WriteString(str string) (int, error) { + rw.writeHeader(nil, str) + if rw.Body != nil { + rw.Body.WriteString(str) + } + return len(str), nil +} + +// WriteHeader implements http.ResponseWriter. +func (rw *ResponseRecorder) WriteHeader(code int) { + if rw.wroteHeader { + return + } + rw.Code = code + rw.wroteHeader = true + if rw.HeaderMap == nil { + rw.HeaderMap = make(http.Header) + } + rw.snapHeader = rw.HeaderMap.Clone() +} + +// Flush implements http.Flusher. To test whether Flush was +// called, see rw.Flushed. +func (rw *ResponseRecorder) Flush() { + if !rw.wroteHeader { + rw.WriteHeader(200) + } + rw.Flushed = true +} + +// Result returns the response generated by the handler. +// +// The returned Response will have at least its StatusCode, +// Header, Body, and optionally Trailer populated. +// More fields may be populated in the future, so callers should +// not DeepEqual the result in tests. +// +// The Response.Header is a snapshot of the headers at the time of the +// first write call, or at the time of this call, if the handler never +// did a write. +// +// The Response.Body is guaranteed to be non-nil and Body.Read call is +// guaranteed to not return any error other than io.EOF. +// +// Result must only be called after the handler has finished running. +func (rw *ResponseRecorder) Result() *http.Response { + if rw.result != nil { + return rw.result + } + if rw.snapHeader == nil { + rw.snapHeader = rw.HeaderMap.Clone() + } + res := &http.Response{ + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + StatusCode: rw.Code, + Header: rw.snapHeader, + } + rw.result = res + if res.StatusCode == 0 { + res.StatusCode = 200 + } + res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode)) + if rw.Body != nil { + res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes())) + } else { + res.Body = http.NoBody + } + res.ContentLength = parseContentLength(res.Header.Get("Content-Length")) + + if trailers, ok := rw.snapHeader["Trailer"]; ok { + res.Trailer = make(http.Header, len(trailers)) + for _, k := range trailers { + k = http.CanonicalHeaderKey(k) + if !httpguts.ValidTrailerHeader(k) { + // Ignore since forbidden by RFC 7230, section 4.1.2. + continue + } + vv, ok := rw.HeaderMap[k] + if !ok { + continue + } + vv2 := make([]string, len(vv)) + copy(vv2, vv) + res.Trailer[k] = vv2 + } + } + for k, vv := range rw.HeaderMap { + if !strings.HasPrefix(k, http.TrailerPrefix) { + continue + } + if res.Trailer == nil { + res.Trailer = make(http.Header) + } + for _, v := range vv { + res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v) + } + } + return res +} + +// parseContentLength trims whitespace from s and returns -1 if no value +// is set, or the value if it's >= 0. +// +// This a modified version of same function found in net/http/transfer.go. This +// one just ignores an invalid header. +func parseContentLength(cl string) int64 { + cl = textproto.TrimString(cl) + if cl == "" { + return -1 + } + n, err := strconv.ParseUint(cl, 10, 63) + if err != nil { + return -1 + } + return int64(n) +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/httptest/recorder_test.go b/vendor/github.com/lesismal/llib/std/net/http/httptest/recorder_test.go new file mode 100644 index 0000000..a865e87 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/httptest/recorder_test.go @@ -0,0 +1,347 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest + +import ( + "fmt" + "io" + "net/http" + "testing" +) + +func TestRecorder(t *testing.T) { + type checkFunc func(*ResponseRecorder) error + check := func(fns ...checkFunc) []checkFunc { return fns } + + hasStatus := func(wantCode int) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Code != wantCode { + return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode) + } + return nil + } + } + hasResultStatus := func(want string) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Result().Status != want { + return fmt.Errorf("Result().Status = %q; want %q", rec.Result().Status, want) + } + return nil + } + } + hasResultStatusCode := func(wantCode int) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Result().StatusCode != wantCode { + return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode) + } + return nil + } + } + hasResultContents := func(want string) checkFunc { + return func(rec *ResponseRecorder) error { + contentBytes, err := io.ReadAll(rec.Result().Body) + if err != nil { + return err + } + contents := string(contentBytes) + if contents != want { + return fmt.Errorf("Result().Body = %s; want %s", contents, want) + } + return nil + } + } + hasContents := func(want string) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Body.String() != want { + return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want) + } + return nil + } + } + hasFlush := func(want bool) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Flushed != want { + return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want) + } + return nil + } + } + hasOldHeader := func(key, want string) checkFunc { + return func(rec *ResponseRecorder) error { + if got := rec.HeaderMap.Get(key); got != want { + return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want) + } + return nil + } + } + hasHeader := func(key, want string) checkFunc { + return func(rec *ResponseRecorder) error { + if got := rec.Result().Header.Get(key); got != want { + return fmt.Errorf("final header %s = %q; want %q", key, got, want) + } + return nil + } + } + hasNotHeaders := func(keys ...string) checkFunc { + return func(rec *ResponseRecorder) error { + for _, k := range keys { + v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)] + if ok { + return fmt.Errorf("unexpected header %s with value %q", k, v) + } + } + return nil + } + } + hasTrailer := func(key, want string) checkFunc { + return func(rec *ResponseRecorder) error { + if got := rec.Result().Trailer.Get(key); got != want { + return fmt.Errorf("trailer %s = %q; want %q", key, got, want) + } + return nil + } + } + hasNotTrailers := func(keys ...string) checkFunc { + return func(rec *ResponseRecorder) error { + trailers := rec.Result().Trailer + for _, k := range keys { + _, ok := trailers[http.CanonicalHeaderKey(k)] + if ok { + return fmt.Errorf("unexpected trailer %s", k) + } + } + return nil + } + } + hasContentLength := func(length int64) checkFunc { + return func(rec *ResponseRecorder) error { + if got := rec.Result().ContentLength; got != length { + return fmt.Errorf("ContentLength = %d; want %d", got, length) + } + return nil + } + } + + for _, tt := range [...]struct { + name string + h func(w http.ResponseWriter, r *http.Request) + checks []checkFunc + }{ + { + "200 default", + func(w http.ResponseWriter, r *http.Request) {}, + check(hasStatus(200), hasContents("")), + }, + { + "first code only", + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(201) + w.WriteHeader(202) + w.Write([]byte("hi")) + }, + check(hasStatus(201), hasContents("hi")), + }, + { + "write sends 200", + func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi first")) + w.WriteHeader(201) + w.WriteHeader(202) + }, + check(hasStatus(200), hasContents("hi first"), hasFlush(false)), + }, + { + "write string", + func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "hi first") + }, + check( + hasStatus(200), + hasContents("hi first"), + hasFlush(false), + hasHeader("Content-Type", "text/plain; charset=utf-8"), + ), + }, + { + "flush", + func(w http.ResponseWriter, r *http.Request) { + w.(http.Flusher).Flush() // also sends a 200 + w.WriteHeader(201) + }, + check(hasStatus(200), hasFlush(true), hasContentLength(-1)), + }, + { + "Content-Type detection", + func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "") + }, + check(hasHeader("Content-Type", "text/html; charset=utf-8")), + }, + { + "no Content-Type detection with Transfer-Encoding", + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Transfer-Encoding", "some encoding") + io.WriteString(w, "") + }, + check(hasHeader("Content-Type", "")), // no header + }, + { + "no Content-Type detection if set explicitly", + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "some/type") + io.WriteString(w, "") + }, + check(hasHeader("Content-Type", "some/type")), + }, + { + "Content-Type detection doesn't crash if HeaderMap is nil", + func(w http.ResponseWriter, r *http.Request) { + // Act as if the user wrote new(httptest.ResponseRecorder) + // rather than using NewRecorder (which initializes + // HeaderMap) + w.(*ResponseRecorder).HeaderMap = nil + io.WriteString(w, "") + }, + check(hasHeader("Content-Type", "text/html; charset=utf-8")), + }, + { + "Header is not changed after write", + func(w http.ResponseWriter, r *http.Request) { + hdr := w.Header() + hdr.Set("Key", "correct") + w.WriteHeader(200) + hdr.Set("Key", "incorrect") + }, + check(hasHeader("Key", "correct")), + }, + { + "Trailer headers are correctly recorded", + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Non-Trailer", "correct") + w.Header().Set("Trailer", "Trailer-A") + w.Header().Add("Trailer", "Trailer-B") + w.Header().Add("Trailer", "Trailer-C") + io.WriteString(w, "") + w.Header().Set("Non-Trailer", "incorrect") + w.Header().Set("Trailer-A", "valuea") + w.Header().Set("Trailer-C", "valuec") + w.Header().Set("Trailer-NotDeclared", "should be omitted") + w.Header().Set("Trailer:Trailer-D", "with prefix") + }, + check( + hasStatus(200), + hasHeader("Content-Type", "text/html; charset=utf-8"), + hasHeader("Non-Trailer", "correct"), + hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"), + hasTrailer("Trailer-A", "valuea"), + hasTrailer("Trailer-C", "valuec"), + hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"), + hasTrailer("Trailer-D", "with prefix"), + ), + }, + { + "Header set without any write", // Issue 15560 + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Foo", "1") + + // Simulate somebody using + // new(ResponseRecorder) instead of + // using the constructor which sets + // this to 200 + w.(*ResponseRecorder).Code = 0 + }, + check( + hasOldHeader("X-Foo", "1"), + hasStatus(0), + hasHeader("X-Foo", "1"), + hasResultStatus("200 OK"), + hasResultStatusCode(200), + ), + }, + { + "HeaderMap vs FinalHeaders", // more for Issue 15560 + func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Set("X-Foo", "1") + w.Write([]byte("hi")) + h.Set("X-Foo", "2") + h.Set("X-Bar", "2") + }, + check( + hasOldHeader("X-Foo", "2"), + hasOldHeader("X-Bar", "2"), + hasHeader("X-Foo", "1"), + hasNotHeaders("X-Bar"), + ), + }, + { + "setting Content-Length header", + func(w http.ResponseWriter, r *http.Request) { + body := "Some body" + contentLength := fmt.Sprintf("%d", len(body)) + w.Header().Set("Content-Length", contentLength) + io.WriteString(w, body) + }, + check(hasStatus(200), hasContents("Some body"), hasContentLength(9)), + }, + { + "nil ResponseRecorder.Body", // Issue 26642 + func(w http.ResponseWriter, r *http.Request) { + w.(*ResponseRecorder).Body = nil + io.WriteString(w, "hi") + }, + check(hasResultContents("")), // check we don't crash reading the body + + }, + } { + t.Run(tt.name, func(t *testing.T) { + r, _ := http.NewRequest("GET", "http://foo.com/", nil) + h := http.HandlerFunc(tt.h) + rec := NewRecorder() + h.ServeHTTP(rec, r) + for _, check := range tt.checks { + if err := check(rec); err != nil { + t.Error(err) + } + } + }) + } +} + +// issue 39017 - disallow Content-Length values such as "+3" +func TestParseContentLength(t *testing.T) { + tests := []struct { + cl string + want int64 + }{ + { + cl: "3", + want: 3, + }, + { + cl: "+3", + want: -1, + }, + { + cl: "-3", + want: -1, + }, + { + // max int64, for safe conversion before returning + cl: "9223372036854775807", + want: 9223372036854775807, + }, + { + cl: "9223372036854775808", + want: -1, + }, + } + + for _, tt := range tests { + if got := parseContentLength(tt.cl); got != tt.want { + t.Errorf("%q:\n\tgot=%d\n\twant=%d", tt.cl, got, tt.want) + } + } +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/httptest/server.go b/vendor/github.com/lesismal/llib/std/net/http/httptest/server.go new file mode 100644 index 0000000..52ba927 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/httptest/server.go @@ -0,0 +1,383 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Implementation of Server + +package httptest + +import ( + "crypto/tls" + "crypto/x509" + "flag" + "fmt" + "github.com/lesismal/llib/std/net/http/internal" + "log" + "net" + "net/http" + "os" + "strings" + "sync" + "time" +) + +// A Server is an HTTP server listening on a system-chosen port on the +// local loopback interface, for use in end-to-end HTTP tests. +type Server struct { + URL string // base URL of form http://ipaddr:port with no trailing slash + Listener net.Listener + + // EnableHTTP2 controls whether HTTP/2 is enabled + // on the server. It must be set between calling + // NewUnstartedServer and calling Server.StartTLS. + EnableHTTP2 bool + + // TLS is the optional TLS configuration, populated with a new config + // after TLS is started. If set on an unstarted server before StartTLS + // is called, existing fields are copied into the new config. + TLS *tls.Config + + // Config may be changed after calling NewUnstartedServer and + // before Start or StartTLS. + Config *http.Server + + // certificate is a parsed version of the TLS config certificate, if present. + certificate *x509.Certificate + + // wg counts the number of outstanding HTTP requests on this server. + // Close blocks until all requests are finished. + wg sync.WaitGroup + + mu sync.Mutex // guards closed and conns + closed bool + conns map[net.Conn]http.ConnState // except terminal states + + // client is configured for use with the server. + // Its transport is automatically closed when Close is called. + client *http.Client +} + +func newLocalListener() net.Listener { + if serveFlag != "" { + l, err := net.Listen("tcp", serveFlag) + if err != nil { + panic(fmt.Sprintf("httptest: failed to listen on %v: %v", serveFlag, err)) + } + return l + } + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { + panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err)) + } + } + return l +} + +// When debugging a particular http server-based test, +// this flag lets you run +// go test -run=BrokenTest -httptest.serve=127.0.0.1:8000 +// to start the broken server so you can interact with it manually. +// We only register this flag if it looks like the caller knows about it +// and is trying to use it as we don't want to pollute flags and this +// isn't really part of our API. Don't depend on this. +var serveFlag string + +func init() { + if strSliceContainsPrefix(os.Args, "-httptest.serve=") || strSliceContainsPrefix(os.Args, "--httptest.serve=") { + flag.StringVar(&serveFlag, "httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks.") + } +} + +func strSliceContainsPrefix(v []string, pre string) bool { + for _, s := range v { + if strings.HasPrefix(s, pre) { + return true + } + } + return false +} + +// NewServer starts and returns a new Server. +// The caller should call Close when finished, to shut it down. +func NewServer(handler http.Handler) *Server { + ts := NewUnstartedServer(handler) + ts.Start() + return ts +} + +// NewUnstartedServer returns a new Server but doesn't start it. +// +// After changing its configuration, the caller should call Start or +// StartTLS. +// +// The caller should call Close when finished, to shut it down. +func NewUnstartedServer(handler http.Handler) *Server { + return &Server{ + Listener: newLocalListener(), + Config: &http.Server{Handler: handler}, + } +} + +// Start starts a server from NewUnstartedServer. +func (s *Server) Start() { + if s.URL != "" { + panic("Server already started") + } + if s.client == nil { + s.client = &http.Client{Transport: &http.Transport{}} + } + s.URL = "http://" + s.Listener.Addr().String() + s.wrap() + s.goServe() + if serveFlag != "" { + fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL) + select {} + } +} + +// StartTLS starts TLS on a server from NewUnstartedServer. +func (s *Server) StartTLS() { + if s.URL != "" { + panic("Server already started") + } + if s.client == nil { + s.client = &http.Client{Transport: &http.Transport{}} + } + cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey) + if err != nil { + panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) + } + + existingConfig := s.TLS + if existingConfig != nil { + s.TLS = existingConfig.Clone() + } else { + s.TLS = new(tls.Config) + } + if s.TLS.NextProtos == nil { + nextProtos := []string{"http/1.1"} + if s.EnableHTTP2 { + nextProtos = []string{"h2"} + } + s.TLS.NextProtos = nextProtos + } + if len(s.TLS.Certificates) == 0 { + s.TLS.Certificates = []tls.Certificate{cert} + } + s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0]) + if err != nil { + panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) + } + certpool := x509.NewCertPool() + certpool.AddCert(s.certificate) + s.client.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: certpool, + }, + ForceAttemptHTTP2: s.EnableHTTP2, + } + s.Listener = tls.NewListener(s.Listener, s.TLS) + s.URL = "https://" + s.Listener.Addr().String() + s.wrap() + s.goServe() +} + +// NewTLSServer starts and returns a new Server using TLS. +// The caller should call Close when finished, to shut it down. +func NewTLSServer(handler http.Handler) *Server { + ts := NewUnstartedServer(handler) + ts.StartTLS() + return ts +} + +type closeIdleTransport interface { + CloseIdleConnections() +} + +// Close shuts down the server and blocks until all outstanding +// requests on this server have completed. +func (s *Server) Close() { + s.mu.Lock() + if !s.closed { + s.closed = true + s.Listener.Close() + s.Config.SetKeepAlivesEnabled(false) + for c, st := range s.conns { + // Force-close any idle connections (those between + // requests) and new connections (those which connected + // but never sent a request). StateNew connections are + // super rare and have only been seen (in + // previously-flaky tests) in the case of + // socket-late-binding races from the http Client + // dialing this server and then getting an idle + // connection before the dial completed. There is thus + // a connected connection in StateNew with no + // associated Request. We only close StateIdle and + // StateNew because they're not doing anything. It's + // possible StateNew is about to do something in a few + // milliseconds, but a previous CL to check again in a + // few milliseconds wasn't liked (early versions of + // https://golang.org/cl/15151) so now we just + // forcefully close StateNew. The docs for Server.Close say + // we wait for "outstanding requests", so we don't close things + // in StateActive. + if st == http.StateIdle || st == http.StateNew { + s.closeConn(c) + } + } + // If this server doesn't shut down in 5 seconds, tell the user why. + t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo) + defer t.Stop() + } + s.mu.Unlock() + + // Not part of httptest.Server's correctness, but assume most + // users of httptest.Server will be using the standard + // transport, so help them out and close any idle connections for them. + if t, ok := http.DefaultTransport.(closeIdleTransport); ok { + t.CloseIdleConnections() + } + + // Also close the client idle connections. + if s.client != nil { + if t, ok := s.client.Transport.(closeIdleTransport); ok { + t.CloseIdleConnections() + } + } + + s.wg.Wait() +} + +func (s *Server) logCloseHangDebugInfo() { + s.mu.Lock() + defer s.mu.Unlock() + var buf strings.Builder + buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n") + for c, st := range s.conns { + fmt.Fprintf(&buf, " %T %p %v in state %v\n", c, c, c.RemoteAddr(), st) + } + log.Print(buf.String()) +} + +// CloseClientConnections closes any open HTTP connections to the test Server. +func (s *Server) CloseClientConnections() { + s.mu.Lock() + nconn := len(s.conns) + ch := make(chan struct{}, nconn) + for c := range s.conns { + go s.closeConnChan(c, ch) + } + s.mu.Unlock() + + // Wait for outstanding closes to finish. + // + // Out of paranoia for making a late change in Go 1.6, we + // bound how long this can wait, since golang.org/issue/14291 + // isn't fully understood yet. At least this should only be used + // in tests. + timer := time.NewTimer(5 * time.Second) + defer timer.Stop() + for i := 0; i < nconn; i++ { + select { + case <-ch: + case <-timer.C: + // Too slow. Give up. + return + } + } +} + +// Certificate returns the certificate used by the server, or nil if +// the server doesn't use TLS. +func (s *Server) Certificate() *x509.Certificate { + return s.certificate +} + +// Client returns an HTTP client configured for making requests to the server. +// It is configured to trust the server's TLS test certificate and will +// close its idle connections on Server.Close. +func (s *Server) Client() *http.Client { + return s.client +} + +func (s *Server) goServe() { + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.Config.Serve(s.Listener) + }() +} + +// wrap installs the connection state-tracking hook to know which +// connections are idle. +func (s *Server) wrap() { + oldHook := s.Config.ConnState + s.Config.ConnState = func(c net.Conn, cs http.ConnState) { + s.mu.Lock() + defer s.mu.Unlock() + switch cs { + case http.StateNew: + s.wg.Add(1) + if _, exists := s.conns[c]; exists { + panic("invalid state transition") + } + if s.conns == nil { + s.conns = make(map[net.Conn]http.ConnState) + } + s.conns[c] = cs + if s.closed { + // Probably just a socket-late-binding dial from + // the default transport that lost the race (and + // thus this connection is now idle and will + // never be used). + s.closeConn(c) + } + case http.StateActive: + if oldState, ok := s.conns[c]; ok { + if oldState != http.StateNew && oldState != http.StateIdle { + panic("invalid state transition") + } + s.conns[c] = cs + } + case http.StateIdle: + if oldState, ok := s.conns[c]; ok { + if oldState != http.StateActive { + panic("invalid state transition") + } + s.conns[c] = cs + } + if s.closed { + s.closeConn(c) + } + case http.StateHijacked, http.StateClosed: + s.forgetConn(c) + } + if oldHook != nil { + oldHook(c, cs) + } + } +} + +// closeConn closes c. +// s.mu must be held. +func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) } + +// closeConnChan is like closeConn, but takes an optional channel to receive a value +// when the goroutine closing c is done. +func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) { + c.Close() + if done != nil { + done <- struct{}{} + } +} + +// forgetConn removes c from the set of tracked conns and decrements it from the +// waitgroup, unless it was previously removed. +// s.mu must be held. +func (s *Server) forgetConn(c net.Conn) { + if _, ok := s.conns[c]; ok { + delete(s.conns, c) + s.wg.Done() + } +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/httptest/server_test.go b/vendor/github.com/lesismal/llib/std/net/http/httptest/server_test.go new file mode 100644 index 0000000..39568b3 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/httptest/server_test.go @@ -0,0 +1,240 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest + +import ( + "bufio" + "io" + "net" + "net/http" + "testing" +) + +type newServerFunc func(http.Handler) *Server + +var newServers = map[string]newServerFunc{ + "NewServer": NewServer, + "NewTLSServer": NewTLSServer, + + // The manual variants of newServer create a Server manually by only filling + // in the exported fields of Server. + "NewServerManual": func(h http.Handler) *Server { + ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}} + ts.Start() + return ts + }, + "NewTLSServerManual": func(h http.Handler) *Server { + ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}} + ts.StartTLS() + return ts + }, +} + +func TestServer(t *testing.T) { + for _, name := range []string{"NewServer", "NewServerManual"} { + t.Run(name, func(t *testing.T) { + newServer := newServers[name] + t.Run("Server", func(t *testing.T) { testServer(t, newServer) }) + t.Run("GetAfterClose", func(t *testing.T) { testGetAfterClose(t, newServer) }) + t.Run("ServerCloseBlocking", func(t *testing.T) { testServerCloseBlocking(t, newServer) }) + t.Run("ServerCloseClientConnections", func(t *testing.T) { testServerCloseClientConnections(t, newServer) }) + t.Run("ServerClientTransportType", func(t *testing.T) { testServerClientTransportType(t, newServer) }) + }) + } + for _, name := range []string{"NewTLSServer", "NewTLSServerManual"} { + t.Run(name, func(t *testing.T) { + newServer := newServers[name] + t.Run("ServerClient", func(t *testing.T) { testServerClient(t, newServer) }) + t.Run("TLSServerClientTransportType", func(t *testing.T) { testTLSServerClientTransportType(t, newServer) }) + }) + } +} + +func testServer(t *testing.T, newServer newServerFunc) { + ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello")) + })) + defer ts.Close() + res, err := http.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + got, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatal(err) + } + if string(got) != "hello" { + t.Errorf("got %q, want hello", string(got)) + } +} + +// Issue 12781 +func testGetAfterClose(t *testing.T, newServer newServerFunc) { + ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello")) + })) + + res, err := http.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + got, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(got) != "hello" { + t.Fatalf("got %q, want hello", string(got)) + } + + ts.Close() + + res, err = http.Get(ts.URL) + if err == nil { + body, _ := io.ReadAll(res.Body) + t.Fatalf("Unexpected response after close: %v, %v, %s", res.Status, res.Header, body) + } +} + +func testServerCloseBlocking(t *testing.T, newServer newServerFunc) { + ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello")) + })) + dial := func() net.Conn { + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + return c + } + + // Keep one connection in StateNew (connected, but not sending anything) + cnew := dial() + defer cnew.Close() + + // Keep one connection in StateIdle (idle after a request) + cidle := dial() + defer cidle.Close() + cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n")) + _, err := http.ReadResponse(bufio.NewReader(cidle), nil) + if err != nil { + t.Fatal(err) + } + + ts.Close() // test we don't hang here forever. +} + +// Issue 14290 +func testServerCloseClientConnections(t *testing.T, newServer newServerFunc) { + var s *Server + s = newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.CloseClientConnections() + })) + defer s.Close() + res, err := http.Get(s.URL) + if err == nil { + res.Body.Close() + t.Fatalf("Unexpected response: %#v", res) + } +} + +// Tests that the Server.Client method works and returns an http.Client that can hit +// NewTLSServer without cert warnings. +func testServerClient(t *testing.T, newTLSServer newServerFunc) { + ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello")) + })) + defer ts.Close() + client := ts.Client() + res, err := client.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + got, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatal(err) + } + if string(got) != "hello" { + t.Errorf("got %q, want hello", string(got)) + } +} + +// Tests that the Server.Client.Transport interface is implemented +// by a *http.Transport. +func testServerClientTransportType(t *testing.T, newServer newServerFunc) { + ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + defer ts.Close() + client := ts.Client() + if _, ok := client.Transport.(*http.Transport); !ok { + t.Errorf("got %T, want *http.Transport", client.Transport) + } +} + +// Tests that the TLS Server.Client.Transport interface is implemented +// by a *http.Transport. +func testTLSServerClientTransportType(t *testing.T, newTLSServer newServerFunc) { + ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + defer ts.Close() + client := ts.Client() + if _, ok := client.Transport.(*http.Transport); !ok { + t.Errorf("got %T, want *http.Transport", client.Transport) + } +} + +type onlyCloseListener struct { + net.Listener +} + +func (onlyCloseListener) Close() error { return nil } + +// Issue 19729: panic in Server.Close for values created directly +// without a constructor (so the unexported client field is nil). +func TestServerZeroValueClose(t *testing.T) { + ts := &Server{ + Listener: onlyCloseListener{}, + Config: &http.Server{}, + } + + ts.Close() // tests that it doesn't panic +} + +func TestTLSServerWithHTTP2(t *testing.T) { + modes := []struct { + name string + wantProto string + }{ + {"http1", "HTTP/1.1"}, + {"http2", "HTTP/2.0"}, + } + + for _, tt := range modes { + t.Run(tt.name, func(t *testing.T) { + cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Proto", r.Proto) + })) + + switch tt.name { + case "http2": + cst.EnableHTTP2 = true + cst.StartTLS() + default: + cst.Start() + } + + defer cst.Close() + + res, err := cst.Client().Get(cst.URL) + if err != nil { + t.Fatalf("Failed to make request: %v", err) + } + if g, w := res.Header.Get("X-Proto"), tt.wantProto; g != w { + t.Fatalf("X-Proto header mismatch:\n\tgot: %q\n\twant: %q", g, w) + } + }) + } +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/httptrace/example_test.go b/vendor/github.com/lesismal/llib/std/net/http/httptrace/example_test.go new file mode 100644 index 0000000..836a50d --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/httptrace/example_test.go @@ -0,0 +1,29 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptrace_test + +import ( + "fmt" + "github.com/lesismal/llib/std/net/http/httptrace" + "log" + "net/http" +) + +func Example() { + req, _ := http.NewRequest("GET", "http://example.com", nil) + trace := &httptrace.ClientTrace{ + GotConn: func(connInfo httptrace.GotConnInfo) { + fmt.Printf("Got Conn: %+v\n", connInfo) + }, + DNSDone: func(dnsInfo httptrace.DNSDoneInfo) { + fmt.Printf("DNS Info: %+v\n", dnsInfo) + }, + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + _, err := http.DefaultTransport.RoundTrip(req) + if err != nil { + log.Fatal(err) + } +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/httptrace/trace.go b/vendor/github.com/lesismal/llib/std/net/http/httptrace/trace.go new file mode 100644 index 0000000..5e34065 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/httptrace/trace.go @@ -0,0 +1,256 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package httptrace provides mechanisms to trace the events within +// HTTP client requests. +package httptrace + +import ( + "context" + "crypto/tls" + "net" + "net/textproto" + "reflect" + "time" + + "github.com/lesismal/llib/std/internal/nettrace" +) + +// unique type to prevent assignment. +type clientEventContextKey struct{} + +// ContextClientTrace returns the ClientTrace associated with the +// provided context. If none, it returns nil. +func ContextClientTrace(ctx context.Context) *ClientTrace { + trace, _ := ctx.Value(clientEventContextKey{}).(*ClientTrace) + return trace +} + +// WithClientTrace returns a new context based on the provided parent +// ctx. HTTP client requests made with the returned context will use +// the provided trace hooks, in addition to any previous hooks +// registered with ctx. Any hooks defined in the provided trace will +// be called first. +func WithClientTrace(ctx context.Context, trace *ClientTrace) context.Context { + if trace == nil { + panic("nil trace") + } + old := ContextClientTrace(ctx) + trace.compose(old) + + ctx = context.WithValue(ctx, clientEventContextKey{}, trace) + if trace.hasNetHooks() { + nt := &nettrace.Trace{ + ConnectStart: trace.ConnectStart, + ConnectDone: trace.ConnectDone, + } + if trace.DNSStart != nil { + nt.DNSStart = func(name string) { + trace.DNSStart(DNSStartInfo{Host: name}) + } + } + if trace.DNSDone != nil { + nt.DNSDone = func(netIPs []interface{}, coalesced bool, err error) { + addrs := make([]net.IPAddr, len(netIPs)) + for i, ip := range netIPs { + addrs[i] = ip.(net.IPAddr) + } + trace.DNSDone(DNSDoneInfo{ + Addrs: addrs, + Coalesced: coalesced, + Err: err, + }) + } + } + ctx = context.WithValue(ctx, nettrace.TraceKey{}, nt) + } + return ctx +} + +// ClientTrace is a set of hooks to run at various stages of an outgoing +// HTTP request. Any particular hook may be nil. Functions may be +// called concurrently from different goroutines and some may be called +// after the request has completed or failed. +// +// ClientTrace currently traces a single HTTP request & response +// during a single round trip and has no hooks that span a series +// of redirected requests. +// +// See https://blog.golang.org/http-tracing for more. +type ClientTrace struct { + // GetConn is called before a connection is created or + // retrieved from an idle pool. The hostPort is the + // "host:port" of the target or proxy. GetConn is called even + // if there's already an idle cached connection available. + GetConn func(hostPort string) + + // GotConn is called after a successful connection is + // obtained. There is no hook for failure to obtain a + // connection; instead, use the error from + // Transport.RoundTrip. + GotConn func(GotConnInfo) + + // PutIdleConn is called when the connection is returned to + // the idle pool. If err is nil, the connection was + // successfully returned to the idle pool. If err is non-nil, + // it describes why not. PutIdleConn is not called if + // connection reuse is disabled via Transport.DisableKeepAlives. + // PutIdleConn is called before the caller's Response.Body.Close + // call returns. + // For HTTP/2, this hook is not currently used. + PutIdleConn func(err error) + + // GotFirstResponseByte is called when the first byte of the response + // headers is available. + GotFirstResponseByte func() + + // Got100Continue is called if the server replies with a "100 + // Continue" response. + Got100Continue func() + + // Got1xxResponse is called for each 1xx informational response header + // returned before the final non-1xx response. Got1xxResponse is called + // for "100 Continue" responses, even if Got100Continue is also defined. + // If it returns an error, the client request is aborted with that error value. + Got1xxResponse func(code int, header textproto.MIMEHeader) error + + // DNSStart is called when a DNS lookup begins. + DNSStart func(DNSStartInfo) + + // DNSDone is called when a DNS lookup ends. + DNSDone func(DNSDoneInfo) + + // ConnectStart is called when a new connection's Dial begins. + // If net.Dialer.DualStack (IPv6 "Happy Eyeballs") support is + // enabled, this may be called multiple times. + ConnectStart func(network, addr string) + + // ConnectDone is called when a new connection's Dial + // completes. The provided err indicates whether the + // connection completedly successfully. + // If net.Dialer.DualStack ("Happy Eyeballs") support is + // enabled, this may be called multiple times. + ConnectDone func(network, addr string, err error) + + // TLSHandshakeStart is called when the TLS handshake is started. When + // connecting to an HTTPS site via an HTTP proxy, the handshake happens + // after the CONNECT request is processed by the proxy. + TLSHandshakeStart func() + + // TLSHandshakeDone is called after the TLS handshake with either the + // successful handshake's connection state, or a non-nil error on handshake + // failure. + TLSHandshakeDone func(tls.ConnectionState, error) + + // WroteHeaderField is called after the Transport has written + // each request header. At the time of this call the values + // might be buffered and not yet written to the network. + WroteHeaderField func(key string, value []string) + + // WroteHeaders is called after the Transport has written + // all request headers. + WroteHeaders func() + + // Wait100Continue is called if the Request specified + // "Expect: 100-continue" and the Transport has written the + // request headers but is waiting for "100 Continue" from the + // server before writing the request body. + Wait100Continue func() + + // WroteRequest is called with the result of writing the + // request and any body. It may be called multiple times + // in the case of retried requests. + WroteRequest func(WroteRequestInfo) +} + +// WroteRequestInfo contains information provided to the WroteRequest +// hook. +type WroteRequestInfo struct { + // Err is any error encountered while writing the Request. + Err error +} + +// compose modifies t such that it respects the previously-registered hooks in old, +// subject to the composition policy requested in t.Compose. +func (t *ClientTrace) compose(old *ClientTrace) { + if old == nil { + return + } + tv := reflect.ValueOf(t).Elem() + ov := reflect.ValueOf(old).Elem() + structType := tv.Type() + for i := 0; i < structType.NumField(); i++ { + tf := tv.Field(i) + hookType := tf.Type() + if hookType.Kind() != reflect.Func { + continue + } + of := ov.Field(i) + if of.IsNil() { + continue + } + if tf.IsNil() { + tf.Set(of) + continue + } + + // Make a copy of tf for tf to call. (Otherwise it + // creates a recursive call cycle and stack overflows) + tfCopy := reflect.ValueOf(tf.Interface()) + + // We need to call both tf and of in some order. + newFunc := reflect.MakeFunc(hookType, func(args []reflect.Value) []reflect.Value { + tfCopy.Call(args) + return of.Call(args) + }) + tv.Field(i).Set(newFunc) + } +} + +// DNSStartInfo contains information about a DNS request. +type DNSStartInfo struct { + Host string +} + +// DNSDoneInfo contains information about the results of a DNS lookup. +type DNSDoneInfo struct { + // Addrs are the IPv4 and/or IPv6 addresses found in the DNS + // lookup. The contents of the slice should not be mutated. + Addrs []net.IPAddr + + // Err is any error that occurred during the DNS lookup. + Err error + + // Coalesced is whether the Addrs were shared with another + // caller who was doing the same DNS lookup concurrently. + Coalesced bool +} + +func (t *ClientTrace) hasNetHooks() bool { + if t == nil { + return false + } + return t.DNSStart != nil || t.DNSDone != nil || t.ConnectStart != nil || t.ConnectDone != nil +} + +// GotConnInfo is the argument to the ClientTrace.GotConn function and +// contains information about the obtained connection. +type GotConnInfo struct { + // Conn is the connection that was obtained. It is owned by + // the http.Transport and should not be read, written or + // closed by users of ClientTrace. + Conn net.Conn + + // Reused is whether this connection has been previously + // used for another HTTP request. + Reused bool + + // WasIdle is whether this connection was obtained from an + // idle pool. + WasIdle bool + + // IdleTime reports how long the connection was previously + // idle, if WasIdle is true. + IdleTime time.Duration +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/httptrace/trace_test.go b/vendor/github.com/lesismal/llib/std/net/http/httptrace/trace_test.go new file mode 100644 index 0000000..bb57ada --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/httptrace/trace_test.go @@ -0,0 +1,89 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptrace + +import ( + "bytes" + "context" + "testing" +) + +func TestWithClientTrace(t *testing.T) { + var buf bytes.Buffer + connectStart := func(b byte) func(network, addr string) { + return func(network, addr string) { + buf.WriteByte(b) + } + } + + ctx := context.Background() + oldtrace := &ClientTrace{ + ConnectStart: connectStart('O'), + } + ctx = WithClientTrace(ctx, oldtrace) + newtrace := &ClientTrace{ + ConnectStart: connectStart('N'), + } + ctx = WithClientTrace(ctx, newtrace) + trace := ContextClientTrace(ctx) + + buf.Reset() + trace.ConnectStart("net", "addr") + if got, want := buf.String(), "NO"; got != want { + t.Errorf("got %q; want %q", got, want) + } +} + +func TestCompose(t *testing.T) { + var buf bytes.Buffer + var testNum int + + connectStart := func(b byte) func(network, addr string) { + return func(network, addr string) { + if addr != "addr" { + t.Errorf(`%d. args for %q case = %q, %q; want addr of "addr"`, testNum, b, network, addr) + } + buf.WriteByte(b) + } + } + + tests := [...]struct { + trace, old *ClientTrace + want string + }{ + 0: { + want: "T", + trace: &ClientTrace{ + ConnectStart: connectStart('T'), + }, + }, + 1: { + want: "TO", + trace: &ClientTrace{ + ConnectStart: connectStart('T'), + }, + old: &ClientTrace{ConnectStart: connectStart('O')}, + }, + 2: { + want: "O", + trace: &ClientTrace{}, + old: &ClientTrace{ConnectStart: connectStart('O')}, + }, + } + for i, tt := range tests { + testNum = i + buf.Reset() + + tr := *tt.trace + tr.compose(tt.old) + if tr.ConnectStart != nil { + tr.ConnectStart("net", "addr") + } + if got := buf.String(); got != tt.want { + t.Errorf("%d. got = %q; want %q", i, got, tt.want) + } + } + +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/httputil/dump.go b/vendor/github.com/lesismal/llib/std/net/http/httputil/dump.go new file mode 100644 index 0000000..2948f27 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/httputil/dump.go @@ -0,0 +1,340 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httputil + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + "time" +) + +// drainBody reads all of b to memory and then returns two equivalent +// ReadClosers yielding the same bytes. +// +// It returns an error if the initial slurp of all bytes fails. It does not attempt +// to make the returned ReadClosers have identical error-matching behavior. +func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { + if b == nil || b == http.NoBody { + // No copying needed. Preserve the magic sentinel meaning of NoBody. + return http.NoBody, http.NoBody, nil + } + var buf bytes.Buffer + if _, err = buf.ReadFrom(b); err != nil { + return nil, b, err + } + if err = b.Close(); err != nil { + return nil, b, err + } + return io.NopCloser(&buf), io.NopCloser(bytes.NewReader(buf.Bytes())), nil +} + +// dumpConn is a net.Conn which writes to Writer and reads from Reader +type dumpConn struct { + io.Writer + io.Reader +} + +func (c *dumpConn) Close() error { return nil } +func (c *dumpConn) LocalAddr() net.Addr { return nil } +func (c *dumpConn) RemoteAddr() net.Addr { return nil } +func (c *dumpConn) SetDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil } + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +// outGoingLength is a copy of the unexported +// (*http.Request).outgoingLength method. +func outgoingLength(req *http.Request) int64 { + if req.Body == nil || req.Body == http.NoBody { + return 0 + } + if req.ContentLength != 0 { + return req.ContentLength + } + return -1 +} + +// DumpRequestOut is like DumpRequest but for outgoing client requests. It +// includes any headers that the standard http.Transport adds, such as +// User-Agent. +func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { + save := req.Body + dummyBody := false + if !body { + contentLength := outgoingLength(req) + if contentLength != 0 { + req.Body = io.NopCloser(io.LimitReader(neverEnding('x'), contentLength)) + dummyBody = true + } + } else { + var err error + save, req.Body, err = drainBody(req.Body) + if err != nil { + return nil, err + } + } + + // Since we're using the actual Transport code to write the request, + // switch to http so the Transport doesn't try to do an SSL + // negotiation with our dumpConn and its bytes.Buffer & pipe. + // The wire format for https and http are the same, anyway. + reqSend := req + if req.URL.Scheme == "https" { + reqSend = new(http.Request) + *reqSend = *req + reqSend.URL = new(url.URL) + *reqSend.URL = *req.URL + reqSend.URL.Scheme = "http" + } + + // Use the actual Transport code to record what we would send + // on the wire, but not using TCP. Use a Transport with a + // custom dialer that returns a fake net.Conn that waits + // for the full input (and recording it), and then responds + // with a dummy response. + var buf bytes.Buffer // records the output + pr, pw := io.Pipe() + defer pr.Close() + defer pw.Close() + dr := &delegateReader{c: make(chan io.Reader)} + + t := &http.Transport{ + Dial: func(net, addr string) (net.Conn, error) { + return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil + }, + } + defer t.CloseIdleConnections() + + // We need this channel to ensure that the reader + // goroutine exits if t.RoundTrip returns an error. + // See golang.org/issue/32571. + quitReadCh := make(chan struct{}) + // Wait for the request before replying with a dummy response: + go func() { + req, err := http.ReadRequest(bufio.NewReader(pr)) + if err == nil { + // Ensure all the body is read; otherwise + // we'll get a partial dump. + io.Copy(io.Discard, req.Body) + req.Body.Close() + } + select { + case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"): + case <-quitReadCh: + // Ensure delegateReader.Read doesn't block forever if we get an error. + close(dr.c) + } + }() + + _, err := t.RoundTrip(reqSend) + + req.Body = save + if err != nil { + pw.Close() + dr.err = err + close(quitReadCh) + return nil, err + } + dump := buf.Bytes() + + // If we used a dummy body above, remove it now. + // TODO: if the req.ContentLength is large, we allocate memory + // unnecessarily just to slice it off here. But this is just + // a debug function, so this is acceptable for now. We could + // discard the body earlier if this matters. + if dummyBody { + if i := bytes.Index(dump, []byte("\r\n\r\n")); i >= 0 { + dump = dump[:i+4] + } + } + return dump, nil +} + +// delegateReader is a reader that delegates to another reader, +// once it arrives on a channel. +type delegateReader struct { + c chan io.Reader + err error // only used if r is nil and c is closed. + r io.Reader // nil until received from c +} + +func (r *delegateReader) Read(p []byte) (int, error) { + if r.r == nil { + var ok bool + if r.r, ok = <-r.c; !ok { + return 0, r.err + } + } + return r.r.Read(p) +} + +// Return value if nonempty, def otherwise. +func valueOrDefault(value, def string) string { + if value != "" { + return value + } + return def +} + +var reqWriteExcludeHeaderDump = map[string]bool{ + "Host": true, // not in Header map anyway + "Transfer-Encoding": true, + "Trailer": true, +} + +// DumpRequest returns the given request in its HTTP/1.x wire +// representation. It should only be used by servers to debug client +// requests. The returned representation is an approximation only; +// some details of the initial request are lost while parsing it into +// an http.Request. In particular, the order and case of header field +// names are lost. The order of values in multi-valued headers is kept +// intact. HTTP/2 requests are dumped in HTTP/1.x form, not in their +// original binary representations. +// +// If body is true, DumpRequest also returns the body. To do so, it +// consumes req.Body and then replaces it with a new io.ReadCloser +// that yields the same bytes. If DumpRequest returns an error, +// the state of req is undefined. +// +// The documentation for http.Request.Write details which fields +// of req are included in the dump. +func DumpRequest(req *http.Request, body bool) ([]byte, error) { + var err error + save := req.Body + if !body || req.Body == nil { + req.Body = nil + } else { + save, req.Body, err = drainBody(req.Body) + if err != nil { + return nil, err + } + } + + var b bytes.Buffer + + // By default, print out the unmodified req.RequestURI, which + // is always set for incoming server requests. But because we + // previously used req.URL.RequestURI and the docs weren't + // always so clear about when to use DumpRequest vs + // DumpRequestOut, fall back to the old way if the caller + // provides a non-server Request. + reqURI := req.RequestURI + if reqURI == "" { + reqURI = req.URL.RequestURI() + } + + fmt.Fprintf(&b, "%s %s HTTP/%d.%d\r\n", valueOrDefault(req.Method, "GET"), + reqURI, req.ProtoMajor, req.ProtoMinor) + + absRequestURI := strings.HasPrefix(req.RequestURI, "http://") || strings.HasPrefix(req.RequestURI, "https://") + if !absRequestURI { + host := req.Host + if host == "" && req.URL != nil { + host = req.URL.Host + } + if host != "" { + fmt.Fprintf(&b, "Host: %s\r\n", host) + } + } + + chunked := len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" + if len(req.TransferEncoding) > 0 { + fmt.Fprintf(&b, "Transfer-Encoding: %s\r\n", strings.Join(req.TransferEncoding, ",")) + } + if req.Close { + fmt.Fprintf(&b, "Connection: close\r\n") + } + + err = req.Header.WriteSubset(&b, reqWriteExcludeHeaderDump) + if err != nil { + return nil, err + } + + io.WriteString(&b, "\r\n") + + if req.Body != nil { + var dest io.Writer = &b + if chunked { + dest = NewChunkedWriter(dest) + } + _, err = io.Copy(dest, req.Body) + if chunked { + dest.(io.Closer).Close() + io.WriteString(&b, "\r\n") + } + } + + req.Body = save + if err != nil { + return nil, err + } + return b.Bytes(), nil +} + +// errNoBody is a sentinel error value used by failureToReadBody so we +// can detect that the lack of body was intentional. +var errNoBody = errors.New("sentinel error value") + +// failureToReadBody is a io.ReadCloser that just returns errNoBody on +// Read. It's swapped in when we don't actually want to consume +// the body, but need a non-nil one, and want to distinguish the +// error from reading the dummy body. +type failureToReadBody struct{} + +func (failureToReadBody) Read([]byte) (int, error) { return 0, errNoBody } +func (failureToReadBody) Close() error { return nil } + +// emptyBody is an instance of empty reader. +var emptyBody = io.NopCloser(strings.NewReader("")) + +// DumpResponse is like DumpRequest but dumps a response. +func DumpResponse(resp *http.Response, body bool) ([]byte, error) { + var b bytes.Buffer + var err error + save := resp.Body + savecl := resp.ContentLength + + if !body { + // For content length of zero. Make sure the body is an empty + // reader, instead of returning error through failureToReadBody{}. + if resp.ContentLength == 0 { + resp.Body = emptyBody + } else { + resp.Body = failureToReadBody{} + } + } else if resp.Body == nil { + resp.Body = emptyBody + } else { + save, resp.Body, err = drainBody(resp.Body) + if err != nil { + return nil, err + } + } + err = resp.Write(&b) + if err == errNoBody { + err = nil + } + resp.Body = save + resp.ContentLength = savecl + if err != nil { + return nil, err + } + return b.Bytes(), nil +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/httputil/dump_test.go b/vendor/github.com/lesismal/llib/std/net/http/httputil/dump_test.go new file mode 100644 index 0000000..8168b2e --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/httputil/dump_test.go @@ -0,0 +1,519 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httputil + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "math/rand" + "net/http" + "net/url" + "runtime" + "runtime/pprof" + "strings" + "testing" + "time" +) + +type eofReader struct{} + +func (n eofReader) Close() error { return nil } + +func (n eofReader) Read([]byte) (int, error) { return 0, io.EOF } + +type dumpTest struct { + // Either Req or GetReq can be set/nil but not both. + Req *http.Request + GetReq func() *http.Request + + Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body + + WantDump string + WantDumpOut string + MustError bool // if true, the test is expected to throw an error + NoBody bool // if true, set DumpRequest{,Out} body to false +} + +var dumpTests = []dumpTest{ + // HTTP/1.1 => chunked coding; body; empty trailer + { + Req: &http.Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + TransferEncoding: []string{"chunked"}, + }, + + Body: []byte("abcdef"), + + WantDump: "GET /search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("abcdef") + chunk(""), + }, + + // Verify that DumpRequest preserves the HTTP version number, doesn't add a Host, + // and doesn't add a User-Agent. + { + Req: &http.Request{ + Method: "GET", + URL: mustParseURL("/foo"), + ProtoMajor: 1, + ProtoMinor: 0, + Header: http.Header{ + "X-Foo": []string{"X-Bar"}, + }, + }, + + WantDump: "GET /foo HTTP/1.0\r\n" + + "X-Foo: X-Bar\r\n\r\n", + }, + + { + Req: mustNewRequest("GET", "http://example.com/foo", nil), + + WantDumpOut: "GET /foo HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Accept-Encoding: gzip\r\n\r\n", + }, + + // Test that an https URL doesn't try to do an SSL negotiation + // with a bytes.Buffer and hang with all goroutines not + // runnable. + { + Req: mustNewRequest("GET", "https://example.com/foo", nil), + WantDumpOut: "GET /foo HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Accept-Encoding: gzip\r\n\r\n", + }, + + // Request with Body, but Dump requested without it. + { + Req: &http.Request{ + Method: "POST", + URL: &url.URL{ + Scheme: "http", + Host: "post.tld", + Path: "/", + }, + ContentLength: 6, + ProtoMajor: 1, + ProtoMinor: 1, + }, + + Body: []byte("abcdef"), + + WantDumpOut: "POST / HTTP/1.1\r\n" + + "Host: post.tld\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Content-Length: 6\r\n" + + "Accept-Encoding: gzip\r\n\r\n", + + NoBody: true, + }, + + // Request with Body > 8196 (default buffer size) + { + Req: &http.Request{ + Method: "POST", + URL: &url.URL{ + Scheme: "http", + Host: "post.tld", + Path: "/", + }, + Header: http.Header{ + "Content-Length": []string{"8193"}, + }, + + ContentLength: 8193, + ProtoMajor: 1, + ProtoMinor: 1, + }, + + Body: bytes.Repeat([]byte("a"), 8193), + + WantDumpOut: "POST / HTTP/1.1\r\n" + + "Host: post.tld\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Content-Length: 8193\r\n" + + "Accept-Encoding: gzip\r\n\r\n" + + strings.Repeat("a", 8193), + WantDump: "POST / HTTP/1.1\r\n" + + "Host: post.tld\r\n" + + "Content-Length: 8193\r\n\r\n" + + strings.Repeat("a", 8193), + }, + + { + GetReq: func() *http.Request { + return mustReadRequest("GET http://foo.com/ HTTP/1.1\r\n" + + "User-Agent: blah\r\n\r\n") + }, + NoBody: true, + WantDump: "GET http://foo.com/ HTTP/1.1\r\n" + + "User-Agent: blah\r\n\r\n", + }, + + // Issue #7215. DumpRequest should return the "Content-Length" when set + { + GetReq: func() *http.Request { + return mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n" + + "Content-Length: 3\r\n" + + "\r\nkey1=name1&key2=name2") + }, + WantDump: "POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n" + + "Content-Length: 3\r\n" + + "\r\nkey", + }, + // Issue #7215. DumpRequest should return the "Content-Length" in ReadRequest + { + GetReq: func() *http.Request { + return mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n" + + "Content-Length: 0\r\n" + + "\r\nkey1=name1&key2=name2") + }, + WantDump: "POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n" + + "Content-Length: 0\r\n\r\n", + }, + + // Issue #7215. DumpRequest should not return the "Content-Length" if unset + { + GetReq: func() *http.Request { + return mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n" + + "\r\nkey1=name1&key2=name2") + }, + WantDump: "POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n\r\n", + }, + + // Issue 18506: make drainBody recognize NoBody. Otherwise + // this was turning into a chunked request. + { + Req: mustNewRequest("POST", "http://example.com/foo", http.NoBody), + WantDumpOut: "POST /foo HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Content-Length: 0\r\n" + + "Accept-Encoding: gzip\r\n\r\n", + }, + + // Issue 34504: a non-nil Body without ContentLength set should be chunked + { + Req: &http.Request{ + Method: "PUT", + URL: &url.URL{ + Scheme: "http", + Host: "post.tld", + Path: "/test", + }, + ContentLength: 0, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Body: &eofReader{}, + }, + NoBody: true, + WantDumpOut: "PUT /test HTTP/1.1\r\n" + + "Host: post.tld\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Transfer-Encoding: chunked\r\n" + + "Accept-Encoding: gzip\r\n\r\n", + }, +} + +func TestDumpRequest(t *testing.T) { + // Make a copy of dumpTests and add 10 new cases with an empty URL + // to test that no goroutines are leaked. See golang.org/issue/32571. + // 10 seems to be a decent number which always triggers the failure. + dumpTests := dumpTests[:] + for i := 0; i < 10; i++ { + dumpTests = append(dumpTests, dumpTest{ + Req: mustNewRequest("GET", "", nil), + MustError: true, + }) + } + numg0 := runtime.NumGoroutine() + for i, tt := range dumpTests { + if tt.Req != nil && tt.GetReq != nil || tt.Req == nil && tt.GetReq == nil { + t.Errorf("#%d: either .Req(%p) or .GetReq(%p) can be set/nil but not both", i, tt.Req, tt.GetReq) + continue + } + + freshReq := func(ti dumpTest) *http.Request { + req := ti.Req + if req == nil { + req = ti.GetReq() + } + + if req.Header == nil { + req.Header = make(http.Header) + } + + if ti.Body == nil { + return req + } + switch b := ti.Body.(type) { + case []byte: + req.Body = io.NopCloser(bytes.NewReader(b)) + case func() io.ReadCloser: + req.Body = b() + default: + t.Fatalf("Test %d: unsupported Body of %T", i, ti.Body) + } + return req + } + + if tt.WantDump != "" { + req := freshReq(tt) + dump, err := DumpRequest(req, !tt.NoBody) + if err != nil { + t.Errorf("DumpRequest #%d: %s\nWantDump:\n%s", i, err, tt.WantDump) + continue + } + if string(dump) != tt.WantDump { + t.Errorf("DumpRequest %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantDump, string(dump)) + continue + } + } + + if tt.MustError { + req := freshReq(tt) + _, err := DumpRequestOut(req, !tt.NoBody) + if err == nil { + t.Errorf("DumpRequestOut #%d: expected an error, got nil", i) + } + continue + } + + if tt.WantDumpOut != "" { + req := freshReq(tt) + dump, err := DumpRequestOut(req, !tt.NoBody) + if err != nil { + t.Errorf("DumpRequestOut #%d: %s", i, err) + continue + } + if string(dump) != tt.WantDumpOut { + t.Errorf("DumpRequestOut %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantDumpOut, string(dump)) + continue + } + } + } + + // Validate we haven't leaked any goroutines. + var dg int + dl := deadline(t, 5*time.Second, time.Second) + for time.Now().Before(dl) { + if dg = runtime.NumGoroutine() - numg0; dg <= 4 { + // No unexpected goroutines. + return + } + + // Allow goroutines to schedule and die off. + runtime.Gosched() + } + + buf := make([]byte, 4096) + buf = buf[:runtime.Stack(buf, true)] + t.Errorf("Unexpectedly large number of new goroutines: %d new: %s", dg, buf) +} + +// deadline returns the time which is needed before t.Deadline() +// if one is configured and it is s greater than needed in the future, +// otherwise defaultDelay from the current time. +func deadline(t *testing.T, defaultDelay, needed time.Duration) time.Time { + if dl, ok := t.Deadline(); ok { + if dl = dl.Add(-needed); dl.After(time.Now()) { + // Allow an arbitrarily long delay. + return dl + } + } + + // No deadline configured or its closer than needed from now + // so just use the default. + return time.Now().Add(defaultDelay) +} + +func chunk(s string) string { + return fmt.Sprintf("%x\r\n%s\r\n", len(s), s) +} + +func mustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(fmt.Sprintf("Error parsing URL %q: %v", s, err)) + } + return u +} + +func mustNewRequest(method, url string, body io.Reader) *http.Request { + req, err := http.NewRequest(method, url, body) + if err != nil { + panic(fmt.Sprintf("NewRequest(%q, %q, %p) err = %v", method, url, body, err)) + } + return req +} + +func mustReadRequest(s string) *http.Request { + req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(s))) + if err != nil { + panic(err) + } + return req +} + +var dumpResTests = []struct { + res *http.Response + body bool + want string +}{ + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 50, + Header: http.Header{ + "Foo": []string{"Bar"}, + }, + Body: io.NopCloser(strings.NewReader("foo")), // shouldn't be used + }, + body: false, // to verify we see 50, not empty or 3. + want: `HTTP/1.1 200 OK +Content-Length: 50 +Foo: Bar`, + }, + + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 3, + Body: io.NopCloser(strings.NewReader("foo")), + }, + body: true, + want: `HTTP/1.1 200 OK +Content-Length: 3 + +foo`, + }, + + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: -1, + Body: io.NopCloser(strings.NewReader("foo")), + TransferEncoding: []string{"chunked"}, + }, + body: true, + want: `HTTP/1.1 200 OK +Transfer-Encoding: chunked + +3 +foo +0`, + }, + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 0, + Header: http.Header{ + // To verify if headers are not filtered out. + "Foo1": []string{"Bar1"}, + "Foo2": []string{"Bar2"}, + }, + Body: nil, + }, + body: false, // to verify we see 0, not empty. + want: `HTTP/1.1 200 OK +Foo1: Bar1 +Foo2: Bar2 +Content-Length: 0`, + }, +} + +func TestDumpResponse(t *testing.T) { + for i, tt := range dumpResTests { + gotb, err := DumpResponse(tt.res, tt.body) + if err != nil { + t.Errorf("%d. DumpResponse = %v", i, err) + continue + } + got := string(gotb) + got = strings.TrimSpace(got) + got = strings.ReplaceAll(got, "\r", "") + + if got != tt.want { + t.Errorf("%d.\nDumpResponse got:\n%s\n\nWant:\n%s\n", i, got, tt.want) + } + } +} + +// Issue 38352: Check for deadlock on cancelled requests. +func TestDumpRequestOutIssue38352(t *testing.T) { + if testing.Short() { + return + } + t.Parallel() + + timeout := 10 * time.Second + if deadline, ok := t.Deadline(); ok { + timeout = time.Until(deadline) + timeout -= time.Second * 2 // Leave 2 seconds to report failures. + } + for i := 0; i < 1000; i++ { + delay := time.Duration(rand.Intn(5)) * time.Millisecond + ctx, cancel := context.WithTimeout(context.Background(), delay) + defer cancel() + + r := bytes.NewBuffer(make([]byte, 10000)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://example.com", r) + if err != nil { + t.Fatal(err) + } + + out := make(chan error) + go func() { + _, err = DumpRequestOut(req, true) + out <- err + }() + + select { + case <-out: + case <-time.After(timeout): + b := &bytes.Buffer{} + fmt.Fprintf(b, "deadlock detected on iteration %d after %s with delay: %v\n", i, timeout, delay) + pprof.Lookup("goroutine").WriteTo(b, 1) + t.Fatal(b.String()) + } + } +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/httputil/example_test.go b/vendor/github.com/lesismal/llib/std/net/http/httputil/example_test.go new file mode 100644 index 0000000..9ab8afd --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/httputil/example_test.go @@ -0,0 +1,123 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httputil_test + +import ( + "fmt" + "github.com/lesismal/llib/std/net/http/httptest" + "github.com/lesismal/llib/std/net/http/httputil" + "io" + "log" + "net/http" + "net/url" + "strings" +) + +func ExampleDumpRequest() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dump, err := httputil.DumpRequest(r, true) + if err != nil { + http.Error(w, fmt.Sprint(err), http.StatusInternalServerError) + return + } + + fmt.Fprintf(w, "%q", dump) + })) + defer ts.Close() + + const body = "Go is a general-purpose language designed with systems programming in mind." + req, err := http.NewRequest("POST", ts.URL, strings.NewReader(body)) + if err != nil { + log.Fatal(err) + } + req.Host = "www.example.org" + resp, err := http.DefaultClient.Do(req) + if err != nil { + log.Fatal(err) + } + defer resp.Body.Close() + + b, err := io.ReadAll(resp.Body) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%s", b) + + // Output: + // "POST / HTTP/1.1\r\nHost: www.example.org\r\nAccept-Encoding: gzip\r\nContent-Length: 75\r\nUser-Agent: Go-http-client/1.1\r\n\r\nGo is a general-purpose language designed with systems programming in mind." +} + +func ExampleDumpRequestOut() { + const body = "Go is a general-purpose language designed with systems programming in mind." + req, err := http.NewRequest("PUT", "http://www.example.org", strings.NewReader(body)) + if err != nil { + log.Fatal(err) + } + + dump, err := httputil.DumpRequestOut(req, true) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%q", dump) + + // Output: + // "PUT / HTTP/1.1\r\nHost: www.example.org\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 75\r\nAccept-Encoding: gzip\r\n\r\nGo is a general-purpose language designed with systems programming in mind." +} + +func ExampleDumpResponse() { + const body = "Go is a general-purpose language designed with systems programming in mind." + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Date", "Wed, 19 Jul 1972 19:00:00 GMT") + fmt.Fprintln(w, body) + })) + defer ts.Close() + + resp, err := http.Get(ts.URL) + if err != nil { + log.Fatal(err) + } + defer resp.Body.Close() + + dump, err := httputil.DumpResponse(resp, true) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%q", dump) + + // Output: + // "HTTP/1.1 200 OK\r\nContent-Length: 76\r\nContent-Type: text/plain; charset=utf-8\r\nDate: Wed, 19 Jul 1972 19:00:00 GMT\r\n\r\nGo is a general-purpose language designed with systems programming in mind.\n" +} + +func ExampleReverseProxy() { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "this call was relayed by the reverse proxy") + })) + defer backendServer.Close() + + rpURL, err := url.Parse(backendServer.URL) + if err != nil { + log.Fatal(err) + } + frontendProxy := httptest.NewServer(httputil.NewSingleHostReverseProxy(rpURL)) + defer frontendProxy.Close() + + resp, err := http.Get(frontendProxy.URL) + if err != nil { + log.Fatal(err) + } + + b, err := io.ReadAll(resp.Body) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%s", b) + + // Output: + // this call was relayed by the reverse proxy +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/httputil/httputil.go b/vendor/github.com/lesismal/llib/std/net/http/httputil/httputil.go new file mode 100644 index 0000000..b05dc11 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/httputil/httputil.go @@ -0,0 +1,41 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package httputil provides HTTP utility functions, complementing the +// more common ones in the net/http package. +package httputil + +import ( + "github.com/lesismal/llib/std/net/http/internal" + "io" +) + +// NewChunkedReader returns a new chunkedReader that translates the data read from r +// out of HTTP "chunked" format before returning it. +// The chunkedReader returns io.EOF when the final 0-length chunk is read. +// +// NewChunkedReader is not needed by normal applications. The http package +// automatically decodes chunking when reading response bodies. +func NewChunkedReader(r io.Reader) io.Reader { + return internal.NewChunkedReader(r) +} + +// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP +// "chunked" format before writing them to w. Closing the returned chunkedWriter +// sends the final 0-length chunk that marks the end of the stream but does +// not send the final CRLF that appears after trailers; trailers and the last +// CRLF must be written separately. +// +// NewChunkedWriter is not needed by normal applications. The http +// package adds chunking automatically if handlers don't set a +// Content-Length header. Using NewChunkedWriter inside a handler +// would result in double chunking or chunking with a Content-Length +// length, both of which are wrong. +func NewChunkedWriter(w io.Writer) io.WriteCloser { + return internal.NewChunkedWriter(w) +} + +// ErrLineTooLong is returned when reading malformed chunked data +// with lines that are too long. +var ErrLineTooLong = internal.ErrLineTooLong diff --git a/vendor/github.com/lesismal/llib/std/net/http/httputil/persist.go b/vendor/github.com/lesismal/llib/std/net/http/httputil/persist.go new file mode 100644 index 0000000..84b116d --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/httputil/persist.go @@ -0,0 +1,431 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httputil + +import ( + "bufio" + "errors" + "io" + "net" + "net/http" + "net/textproto" + "sync" +) + +var ( + // Deprecated: No longer used. + ErrPersistEOF = &http.ProtocolError{ErrorString: "persistent connection closed"} + + // Deprecated: No longer used. + ErrClosed = &http.ProtocolError{ErrorString: "connection closed by user"} + + // Deprecated: No longer used. + ErrPipeline = &http.ProtocolError{ErrorString: "pipeline error"} +) + +// This is an API usage error - the local side is closed. +// ErrPersistEOF (above) reports that the remote side is closed. +var errClosed = errors.New("i/o operation on closed connection") + +// ServerConn is an artifact of Go's early HTTP implementation. +// It is low-level, old, and unused by Go's current HTTP stack. +// We should have deleted it before Go 1. +// +// Deprecated: Use the Server in package net/http instead. +type ServerConn struct { + mu sync.Mutex // read-write protects the following fields + c net.Conn + r *bufio.Reader + re, we error // read/write errors + lastbody io.ReadCloser + nread, nwritten int + pipereq map[*http.Request]uint + + pipe textproto.Pipeline +} + +// NewServerConn is an artifact of Go's early HTTP implementation. +// It is low-level, old, and unused by Go's current HTTP stack. +// We should have deleted it before Go 1. +// +// Deprecated: Use the Server in package net/http instead. +func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn { + if r == nil { + r = bufio.NewReader(c) + } + return &ServerConn{c: c, r: r, pipereq: make(map[*http.Request]uint)} +} + +// Hijack detaches the ServerConn and returns the underlying connection as well +// as the read-side bufio which may have some left over data. Hijack may be +// called before Read has signaled the end of the keep-alive logic. The user +// should not call Hijack while Read or Write is in progress. +func (sc *ServerConn) Hijack() (net.Conn, *bufio.Reader) { + sc.mu.Lock() + defer sc.mu.Unlock() + c := sc.c + r := sc.r + sc.c = nil + sc.r = nil + return c, r +} + +// Close calls Hijack and then also closes the underlying connection. +func (sc *ServerConn) Close() error { + c, _ := sc.Hijack() + if c != nil { + return c.Close() + } + return nil +} + +// Read returns the next request on the wire. An ErrPersistEOF is returned if +// it is gracefully determined that there are no more requests (e.g. after the +// first request on an HTTP/1.0 connection, or after a Connection:close on a +// HTTP/1.1 connection). +func (sc *ServerConn) Read() (*http.Request, error) { + var req *http.Request + var err error + + // Ensure ordered execution of Reads and Writes + id := sc.pipe.Next() + sc.pipe.StartRequest(id) + defer func() { + sc.pipe.EndRequest(id) + if req == nil { + sc.pipe.StartResponse(id) + sc.pipe.EndResponse(id) + } else { + // Remember the pipeline id of this request + sc.mu.Lock() + sc.pipereq[req] = id + sc.mu.Unlock() + } + }() + + sc.mu.Lock() + if sc.we != nil { // no point receiving if write-side broken or closed + defer sc.mu.Unlock() + return nil, sc.we + } + if sc.re != nil { + defer sc.mu.Unlock() + return nil, sc.re + } + if sc.r == nil { // connection closed by user in the meantime + defer sc.mu.Unlock() + return nil, errClosed + } + r := sc.r + lastbody := sc.lastbody + sc.lastbody = nil + sc.mu.Unlock() + + // Make sure body is fully consumed, even if user does not call body.Close + if lastbody != nil { + // body.Close is assumed to be idempotent and multiple calls to + // it should return the error that its first invocation + // returned. + err = lastbody.Close() + if err != nil { + sc.mu.Lock() + defer sc.mu.Unlock() + sc.re = err + return nil, err + } + } + + req, err = http.ReadRequest(r) + sc.mu.Lock() + defer sc.mu.Unlock() + if err != nil { + if err == io.ErrUnexpectedEOF { + // A close from the opposing client is treated as a + // graceful close, even if there was some unparse-able + // data before the close. + sc.re = ErrPersistEOF + return nil, sc.re + } else { + sc.re = err + return req, err + } + } + sc.lastbody = req.Body + sc.nread++ + if req.Close { + sc.re = ErrPersistEOF + return req, sc.re + } + return req, err +} + +// Pending returns the number of unanswered requests +// that have been received on the connection. +func (sc *ServerConn) Pending() int { + sc.mu.Lock() + defer sc.mu.Unlock() + return sc.nread - sc.nwritten +} + +// Write writes resp in response to req. To close the connection gracefully, set the +// Response.Close field to true. Write should be considered operational until +// it returns an error, regardless of any errors returned on the Read side. +func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error { + + // Retrieve the pipeline ID of this request/response pair + sc.mu.Lock() + id, ok := sc.pipereq[req] + delete(sc.pipereq, req) + if !ok { + sc.mu.Unlock() + return ErrPipeline + } + sc.mu.Unlock() + + // Ensure pipeline order + sc.pipe.StartResponse(id) + defer sc.pipe.EndResponse(id) + + sc.mu.Lock() + if sc.we != nil { + defer sc.mu.Unlock() + return sc.we + } + if sc.c == nil { // connection closed by user in the meantime + defer sc.mu.Unlock() + return ErrClosed + } + c := sc.c + if sc.nread <= sc.nwritten { + defer sc.mu.Unlock() + return errors.New("persist server pipe count") + } + if resp.Close { + // After signaling a keep-alive close, any pipelined unread + // requests will be lost. It is up to the user to drain them + // before signaling. + sc.re = ErrPersistEOF + } + sc.mu.Unlock() + + err := resp.Write(c) + sc.mu.Lock() + defer sc.mu.Unlock() + if err != nil { + sc.we = err + return err + } + sc.nwritten++ + + return nil +} + +// ClientConn is an artifact of Go's early HTTP implementation. +// It is low-level, old, and unused by Go's current HTTP stack. +// We should have deleted it before Go 1. +// +// Deprecated: Use Client or Transport in package net/http instead. +type ClientConn struct { + mu sync.Mutex // read-write protects the following fields + c net.Conn + r *bufio.Reader + re, we error // read/write errors + lastbody io.ReadCloser + nread, nwritten int + pipereq map[*http.Request]uint + + pipe textproto.Pipeline + writeReq func(*http.Request, io.Writer) error +} + +// NewClientConn is an artifact of Go's early HTTP implementation. +// It is low-level, old, and unused by Go's current HTTP stack. +// We should have deleted it before Go 1. +// +// Deprecated: Use the Client or Transport in package net/http instead. +func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn { + if r == nil { + r = bufio.NewReader(c) + } + return &ClientConn{ + c: c, + r: r, + pipereq: make(map[*http.Request]uint), + writeReq: (*http.Request).Write, + } +} + +// NewProxyClientConn is an artifact of Go's early HTTP implementation. +// It is low-level, old, and unused by Go's current HTTP stack. +// We should have deleted it before Go 1. +// +// Deprecated: Use the Client or Transport in package net/http instead. +func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn { + cc := NewClientConn(c, r) + cc.writeReq = (*http.Request).WriteProxy + return cc +} + +// Hijack detaches the ClientConn and returns the underlying connection as well +// as the read-side bufio which may have some left over data. Hijack may be +// called before the user or Read have signaled the end of the keep-alive +// logic. The user should not call Hijack while Read or Write is in progress. +func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) { + cc.mu.Lock() + defer cc.mu.Unlock() + c = cc.c + r = cc.r + cc.c = nil + cc.r = nil + return +} + +// Close calls Hijack and then also closes the underlying connection. +func (cc *ClientConn) Close() error { + c, _ := cc.Hijack() + if c != nil { + return c.Close() + } + return nil +} + +// Write writes a request. An ErrPersistEOF error is returned if the connection +// has been closed in an HTTP keep-alive sense. If req.Close equals true, the +// keep-alive connection is logically closed after this request and the opposing +// server is informed. An ErrUnexpectedEOF indicates the remote closed the +// underlying TCP connection, which is usually considered as graceful close. +func (cc *ClientConn) Write(req *http.Request) error { + var err error + + // Ensure ordered execution of Writes + id := cc.pipe.Next() + cc.pipe.StartRequest(id) + defer func() { + cc.pipe.EndRequest(id) + if err != nil { + cc.pipe.StartResponse(id) + cc.pipe.EndResponse(id) + } else { + // Remember the pipeline id of this request + cc.mu.Lock() + cc.pipereq[req] = id + cc.mu.Unlock() + } + }() + + cc.mu.Lock() + if cc.re != nil { // no point sending if read-side closed or broken + defer cc.mu.Unlock() + return cc.re + } + if cc.we != nil { + defer cc.mu.Unlock() + return cc.we + } + if cc.c == nil { // connection closed by user in the meantime + defer cc.mu.Unlock() + return errClosed + } + c := cc.c + if req.Close { + // We write the EOF to the write-side error, because there + // still might be some pipelined reads + cc.we = ErrPersistEOF + } + cc.mu.Unlock() + + err = cc.writeReq(req, c) + cc.mu.Lock() + defer cc.mu.Unlock() + if err != nil { + cc.we = err + return err + } + cc.nwritten++ + + return nil +} + +// Pending returns the number of unanswered requests +// that have been sent on the connection. +func (cc *ClientConn) Pending() int { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.nwritten - cc.nread +} + +// Read reads the next response from the wire. A valid response might be +// returned together with an ErrPersistEOF, which means that the remote +// requested that this be the last request serviced. Read can be called +// concurrently with Write, but not with another Read. +func (cc *ClientConn) Read(req *http.Request) (resp *http.Response, err error) { + // Retrieve the pipeline ID of this request/response pair + cc.mu.Lock() + id, ok := cc.pipereq[req] + delete(cc.pipereq, req) + if !ok { + cc.mu.Unlock() + return nil, ErrPipeline + } + cc.mu.Unlock() + + // Ensure pipeline order + cc.pipe.StartResponse(id) + defer cc.pipe.EndResponse(id) + + cc.mu.Lock() + if cc.re != nil { + defer cc.mu.Unlock() + return nil, cc.re + } + if cc.r == nil { // connection closed by user in the meantime + defer cc.mu.Unlock() + return nil, errClosed + } + r := cc.r + lastbody := cc.lastbody + cc.lastbody = nil + cc.mu.Unlock() + + // Make sure body is fully consumed, even if user does not call body.Close + if lastbody != nil { + // body.Close is assumed to be idempotent and multiple calls to + // it should return the error that its first invocation + // returned. + err = lastbody.Close() + if err != nil { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.re = err + return nil, err + } + } + + resp, err = http.ReadResponse(r, req) + cc.mu.Lock() + defer cc.mu.Unlock() + if err != nil { + cc.re = err + return resp, err + } + cc.lastbody = resp.Body + + cc.nread++ + + if resp.Close { + cc.re = ErrPersistEOF // don't send any more requests + return resp, cc.re + } + return resp, err +} + +// Do is convenience method that writes a request and reads a response. +func (cc *ClientConn) Do(req *http.Request) (*http.Response, error) { + err := cc.Write(req) + if err != nil { + return nil, err + } + return cc.Read(req) +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/httputil/reverseproxy.go b/vendor/github.com/lesismal/llib/std/net/http/httputil/reverseproxy.go new file mode 100644 index 0000000..4e36958 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/httputil/reverseproxy.go @@ -0,0 +1,617 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP reverse proxy handler + +package httputil + +import ( + "context" + "fmt" + "io" + "log" + "net" + "net/http" + "net/textproto" + "net/url" + "strings" + "sync" + "time" + + "golang.org/x/net/http/httpguts" +) + +// ReverseProxy is an HTTP Handler that takes an incoming request and +// sends it to another server, proxying the response back to the +// client. +// +// ReverseProxy by default sets the client IP as the value of the +// X-Forwarded-For header. +// +// If an X-Forwarded-For header already exists, the client IP is +// appended to the existing values. As a special case, if the header +// exists in the Request.Header map but has a nil value (such as when +// set by the Director func), the X-Forwarded-For header is +// not modified. +// +// To prevent IP spoofing, be sure to delete any pre-existing +// X-Forwarded-For header coming from the client or +// an untrusted proxy. +type ReverseProxy struct { + // Director must be a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + // Director must not access the provided Request + // after returning. + Director func(*http.Request) + + // The transport used to perform proxy requests. + // If nil, http.DefaultTransport is used. + Transport http.RoundTripper + + // FlushInterval specifies the flush interval + // to flush to the client while copying the + // response body. + // If zero, no periodic flushing is done. + // A negative value means to flush immediately + // after each write to the client. + // The FlushInterval is ignored when ReverseProxy + // recognizes a response as a streaming response, or + // if its ContentLength is -1; for such responses, writes + // are flushed to the client immediately. + FlushInterval time.Duration + + // ErrorLog specifies an optional logger for errors + // that occur when attempting to proxy the request. + // If nil, logging is done via the log package's standard logger. + ErrorLog *log.Logger + + // BufferPool optionally specifies a buffer pool to + // get byte slices for use by io.CopyBuffer when + // copying HTTP response bodies. + BufferPool BufferPool + + // ModifyResponse is an optional function that modifies the + // Response from the backend. It is called if the backend + // returns a response at all, with any HTTP status code. + // If the backend is unreachable, the optional ErrorHandler is + // called without any call to ModifyResponse. + // + // If ModifyResponse returns an error, ErrorHandler is called + // with its error value. If ErrorHandler is nil, its default + // implementation is used. + ModifyResponse func(*http.Response) error + + // ErrorHandler is an optional function that handles errors + // reaching the backend or errors from ModifyResponse. + // + // If nil, the default is to log the provided error and return + // a 502 Status Bad Gateway response. + ErrorHandler func(http.ResponseWriter, *http.Request, error) +} + +// A BufferPool is an interface for getting and returning temporary +// byte slices for use by io.CopyBuffer. +type BufferPool interface { + Get() []byte + Put([]byte) +} + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +func joinURLPath(a, b *url.URL) (path, rawpath string) { + if a.RawPath == "" && b.RawPath == "" { + return singleJoiningSlash(a.Path, b.Path), "" + } + // Same as singleJoiningSlash, but uses EscapedPath to determine + // whether a slash should be added + apath := a.EscapedPath() + bpath := b.EscapedPath() + + aslash := strings.HasSuffix(apath, "/") + bslash := strings.HasPrefix(bpath, "/") + + switch { + case aslash && bslash: + return a.Path + b.Path[1:], apath + bpath[1:] + case !aslash && !bslash: + return a.Path + "/" + b.Path, apath + "/" + bpath + } + return a.Path + b.Path, apath + bpath +} + +// NewSingleHostReverseProxy returns a new ReverseProxy that routes +// URLs to the scheme, host, and base path provided in target. If the +// target's path is "/base" and the incoming request was for "/dir", +// the target request will be for /base/dir. +// NewSingleHostReverseProxy does not rewrite the Host header. +// To rewrite Host headers, use ReverseProxy directly with a custom +// Director policy. +func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { + targetQuery := target.RawQuery + director := func(req *http.Request) { + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL) + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } + if _, ok := req.Header["User-Agent"]; !ok { + // explicitly disable User-Agent so it's not set to default value + req.Header.Set("User-Agent", "") + } + } + return &ReverseProxy{Director: director} +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +// Hop-by-hop headers. These are removed when sent to the backend. +// As of RFC 7230, hop-by-hop headers are required to appear in the +// Connection header field. These are the headers defined by the +// obsoleted RFC 2616 (section 13.5.1) and are used for backward +// compatibility. +var hopHeaders = []string{ + "Connection", + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 + "Transfer-Encoding", + "Upgrade", +} + +func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) { + p.logf("http: proxy error: %v", err) + rw.WriteHeader(http.StatusBadGateway) +} + +func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) { + if p.ErrorHandler != nil { + return p.ErrorHandler + } + return p.defaultErrorHandler +} + +// modifyResponse conditionally runs the optional ModifyResponse hook +// and reports whether the request should proceed. +func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool { + if p.ModifyResponse == nil { + return true + } + if err := p.ModifyResponse(res); err != nil { + res.Body.Close() + p.getErrorHandler()(rw, req, err) + return false + } + return true +} + +func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + transport := p.Transport + if transport == nil { + transport = http.DefaultTransport + } + + ctx := req.Context() + if cn, ok := rw.(http.CloseNotifier); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + defer cancel() + notifyChan := cn.CloseNotify() + go func() { + select { + case <-notifyChan: + cancel() + case <-ctx.Done(): + } + }() + } + + outreq := req.Clone(ctx) + if req.ContentLength == 0 { + outreq.Body = nil // Issue 16036: nil Body for http.Transport retries + } + if outreq.Header == nil { + outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate + } + + p.Director(outreq) + outreq.Close = false + + reqUpType := upgradeType(outreq.Header) + removeConnectionHeaders(outreq.Header) + + // Remove hop-by-hop headers to the backend. Especially + // important is "Connection" because we want a persistent + // connection, regardless of what the client sent to us. + for _, h := range hopHeaders { + hv := outreq.Header.Get(h) + if hv == "" { + continue + } + if h == "Te" && hv == "trailers" { + // Issue 21096: tell backend applications that + // care about trailer support that we support + // trailers. (We do, but we don't go out of + // our way to advertise that unless the + // incoming client request thought it was + // worth mentioning) + continue + } + outreq.Header.Del(h) + } + + // After stripping all the hop-by-hop connection headers above, add back any + // necessary for protocol upgrades, such as for websockets. + if reqUpType != "" { + outreq.Header.Set("Connection", "Upgrade") + outreq.Header.Set("Upgrade", reqUpType) + } + + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + prior, ok := outreq.Header["X-Forwarded-For"] + omit := ok && prior == nil // Issue 38079: nil now means don't populate the header + if len(prior) > 0 { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + if !omit { + outreq.Header.Set("X-Forwarded-For", clientIP) + } + } + + res, err := transport.RoundTrip(outreq) + if err != nil { + p.getErrorHandler()(rw, outreq, err) + return + } + + // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) + if res.StatusCode == http.StatusSwitchingProtocols { + if !p.modifyResponse(rw, res, outreq) { + return + } + p.handleUpgradeResponse(rw, outreq, res) + return + } + + removeConnectionHeaders(res.Header) + + for _, h := range hopHeaders { + res.Header.Del(h) + } + + if !p.modifyResponse(rw, res, outreq) { + return + } + + copyHeader(rw.Header(), res.Header) + + // The "Trailer" header isn't included in the Transport's response, + // at least for *http.Transport. Build it up from Trailer. + announcedTrailers := len(res.Trailer) + if announcedTrailers > 0 { + trailerKeys := make([]string, 0, len(res.Trailer)) + for k := range res.Trailer { + trailerKeys = append(trailerKeys, k) + } + rw.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) + } + + rw.WriteHeader(res.StatusCode) + + err = p.copyResponse(rw, res.Body, p.flushInterval(res)) + if err != nil { + defer res.Body.Close() + // Since we're streaming the response, if we run into an error all we can do + // is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler + // on read error while copying body. + if !shouldPanicOnCopyError(req) { + p.logf("suppressing panic for copyResponse error in test; copy error: %v", err) + return + } + panic(http.ErrAbortHandler) + } + res.Body.Close() // close now, instead of defer, to populate res.Trailer + + if len(res.Trailer) > 0 { + // Force chunking if we saw a response trailer. + // This prevents net/http from calculating the length for short + // bodies and adding a Content-Length. + if fl, ok := rw.(http.Flusher); ok { + fl.Flush() + } + } + + if len(res.Trailer) == announcedTrailers { + copyHeader(rw.Header(), res.Trailer) + return + } + + for k, vv := range res.Trailer { + k = http.TrailerPrefix + k + for _, v := range vv { + rw.Header().Add(k, v) + } + } +} + +var inOurTests bool // whether we're in our own tests + +// shouldPanicOnCopyError reports whether the reverse proxy should +// panic with http.ErrAbortHandler. This is the right thing to do by +// default, but Go 1.10 and earlier did not, so existing unit tests +// weren't expecting panics. Only panic in our own tests, or when +// running under the HTTP server. +func shouldPanicOnCopyError(req *http.Request) bool { + if inOurTests { + // Our tests know to handle this panic. + return true + } + if req.Context().Value(http.ServerContextKey) != nil { + // We seem to be running under an HTTP server, so + // it'll recover the panic. + return true + } + // Otherwise act like Go 1.10 and earlier to not break + // existing tests. + return false +} + +// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h. +// See RFC 7230, section 6.1 +func removeConnectionHeaders(h http.Header) { + for _, f := range h["Connection"] { + for _, sf := range strings.Split(f, ",") { + if sf = textproto.TrimString(sf); sf != "" { + h.Del(sf) + } + } + } +} + +// flushInterval returns the p.FlushInterval value, conditionally +// overriding its value for a specific request/response. +func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration { + resCT := res.Header.Get("Content-Type") + + // For Server-Sent Events responses, flush immediately. + // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream + if resCT == "text/event-stream" { + return -1 // negative means immediately + } + + // We might have the case of streaming for which Content-Length might be unset. + if res.ContentLength == -1 { + return -1 + } + + return p.FlushInterval +} + +func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { + if flushInterval != 0 { + if wf, ok := dst.(writeFlusher); ok { + mlw := &maxLatencyWriter{ + dst: wf, + latency: flushInterval, + } + defer mlw.stop() + + // set up initial timer so headers get flushed even if body writes are delayed + mlw.flushPending = true + mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) + + dst = mlw + } + } + + var buf []byte + if p.BufferPool != nil { + buf = p.BufferPool.Get() + defer p.BufferPool.Put(buf) + } + _, err := p.copyBuffer(dst, src, buf) + return err +} + +// copyBuffer returns any write errors or non-EOF read errors, and the amount +// of bytes written. +func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { + if len(buf) == 0 { + buf = make([]byte, 32*1024) + } + var written int64 + for { + nr, rerr := src.Read(buf) + if rerr != nil && rerr != io.EOF && rerr != context.Canceled { + p.logf("httputil: ReverseProxy read error during body copy: %v", rerr) + } + if nr > 0 { + nw, werr := dst.Write(buf[:nr]) + if nw > 0 { + written += int64(nw) + } + if werr != nil { + return written, werr + } + if nr != nw { + return written, io.ErrShortWrite + } + } + if rerr != nil { + if rerr == io.EOF { + rerr = nil + } + return written, rerr + } + } +} + +func (p *ReverseProxy) logf(format string, args ...interface{}) { + if p.ErrorLog != nil { + p.ErrorLog.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + +type writeFlusher interface { + io.Writer + http.Flusher +} + +type maxLatencyWriter struct { + dst writeFlusher + latency time.Duration // non-zero; negative means to flush immediately + + mu sync.Mutex // protects t, flushPending, and dst.Flush + t *time.Timer + flushPending bool +} + +func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + n, err = m.dst.Write(p) + if m.latency < 0 { + m.dst.Flush() + return + } + if m.flushPending { + return + } + if m.t == nil { + m.t = time.AfterFunc(m.latency, m.delayedFlush) + } else { + m.t.Reset(m.latency) + } + m.flushPending = true + return +} + +func (m *maxLatencyWriter) delayedFlush() { + m.mu.Lock() + defer m.mu.Unlock() + if !m.flushPending { // if stop was called but AfterFunc already started this goroutine + return + } + m.dst.Flush() + m.flushPending = false +} + +func (m *maxLatencyWriter) stop() { + m.mu.Lock() + defer m.mu.Unlock() + m.flushPending = false + if m.t != nil { + m.t.Stop() + } +} + +func upgradeType(h http.Header) string { + if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") { + return "" + } + return strings.ToLower(h.Get("Upgrade")) +} + +func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { + reqUpType := upgradeType(req.Header) + resUpType := upgradeType(res.Header) + if reqUpType != resUpType { + p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) + return + } + + hj, ok := rw.(http.Hijacker) + if !ok { + p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) + return + } + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) + return + } + + backConnCloseCh := make(chan bool) + go func() { + // Ensure that the cancelation of a request closes the backend. + // See issue https://golang.org/issue/35559. + select { + case <-req.Context().Done(): + case <-backConnCloseCh: + } + backConn.Close() + }() + + defer close(backConnCloseCh) + + conn, brw, err := hj.Hijack() + if err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) + return + } + defer conn.Close() + + copyHeader(rw.Header(), res.Header) + + res.Header = rw.Header() + res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above + if err := res.Write(brw); err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) + return + } + if err := brw.Flush(); err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) + return + } + errc := make(chan error, 1) + spc := switchProtocolCopier{user: conn, backend: backConn} + go spc.copyToBackend(errc) + go spc.copyFromBackend(errc) + <-errc + return +} + +// switchProtocolCopier exists so goroutines proxying data back and +// forth have nice names in stacks. +type switchProtocolCopier struct { + user, backend io.ReadWriter +} + +func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { + _, err := io.Copy(c.user, c.backend) + errc <- err +} + +func (c switchProtocolCopier) copyToBackend(errc chan<- error) { + _, err := io.Copy(c.backend, c.user) + errc <- err +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/httputil/reverseproxy_test.go b/vendor/github.com/lesismal/llib/std/net/http/httputil/reverseproxy_test.go new file mode 100644 index 0000000..5c78376 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/httputil/reverseproxy_test.go @@ -0,0 +1,1420 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Reverse proxy tests. + +package httputil + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "github.com/lesismal/llib/std/net/http/httptest" + "io" + "log" + "net/http" + "net/url" + "os" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +const fakeHopHeader = "X-Fake-Hop-Header-For-Test" + +func init() { + inOurTests = true + hopHeaders = append(hopHeaders, fakeHopHeader) +} + +func TestReverseProxy(t *testing.T) { + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" && r.FormValue("mode") == "hangup" { + c, _, _ := w.(http.Hijacker).Hijack() + c.Close() + return + } + if len(r.TransferEncoding) > 0 { + t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding) + } + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + if c := r.Header.Get("Connection"); c != "" { + t.Errorf("handler got Connection header value %q", c) + } + if c := r.Header.Get("Te"); c != "trailers" { + t.Errorf("handler got Te header value %q; want 'trailers'", c) + } + if c := r.Header.Get("Upgrade"); c != "" { + t.Errorf("handler got Upgrade header value %q", c) + } + if c := r.Header.Get("Proxy-Connection"); c != "" { + t.Errorf("handler got Proxy-Connection header value %q", c) + } + if g, e := r.Host, "some-name"; g != e { + t.Errorf("backend got Host header %q, want %q", g, e) + } + w.Header().Set("Trailers", "not a special header field name") + w.Header().Set("Trailer", "X-Trailer") + w.Header().Set("X-Foo", "bar") + w.Header().Set("Upgrade", "foo") + w.Header().Set(fakeHopHeader, "foo") + w.Header().Add("X-Multi-Value", "foo") + w.Header().Add("X-Multi-Value", "bar") + http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"}) + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + w.Header().Set("X-Trailer", "trailer_value") + w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Header.Set("Connection", "close") + getReq.Header.Set("Te", "trailers") + getReq.Header.Set("Proxy-Connection", "should be deleted") + getReq.Header.Set("Upgrade", "foo") + getReq.Close = true + res, err := frontendClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + if g, e := res.Header.Get("X-Foo"), "bar"; g != e { + t.Errorf("got X-Foo %q; expected %q", g, e) + } + if c := res.Header.Get(fakeHopHeader); c != "" { + t.Errorf("got %s header value %q", fakeHopHeader, c) + } + if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e { + t.Errorf("header Trailers = %q; want %q", g, e) + } + if g, e := len(res.Header["X-Multi-Value"]), 2; g != e { + t.Errorf("got %d X-Multi-Value header values; expected %d", g, e) + } + if g, e := len(res.Header["Set-Cookie"]), 1; g != e { + t.Fatalf("got %d SetCookies, want %d", g, e) + } + if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) { + t.Errorf("before reading body, Trailer = %#v; want %#v", g, e) + } + if cookie := res.Cookies()[0]; cookie.Name != "flavor" { + t.Errorf("unexpected cookie %q", cookie.Name) + } + bodyBytes, _ := io.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } + if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e { + t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e) + } + if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e { + t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e) + } + + // Test that a backend failing to be reached or one which doesn't return + // a response results in a StatusBadGateway. + getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil) + getReq.Close = true + res, err = frontendClient.Do(getReq) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != http.StatusBadGateway { + t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status) + } + +} + +// Issue 16875: remove any proxied headers mentioned in the "Connection" +// header value. +func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { + const fakeConnectionToken = "X-Fake-Connection-Token" + const backendResponse = "I am the backend" + + // someConnHeader is some arbitrary header to be declared as a hop-by-hop header + // in the Request's Connection header. + const someConnHeader = "X-Some-Conn-Header" + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if c := r.Header.Get("Connection"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Connection", c) + } + if c := r.Header.Get(fakeConnectionToken); c != "" { + t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) + } + if c := r.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) + } + w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken) + w.Header().Add("Connection", someConnHeader) + w.Header().Set(someConnHeader, "should be deleted") + w.Header().Set(fakeConnectionToken, "should be deleted") + io.WriteString(w, backendResponse) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxyHandler.ServeHTTP(w, r) + if c := r.Header.Get(someConnHeader); c != "should be deleted" { + t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") + } + if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" { + t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted") + } + c := r.Header["Connection"] + var cf []string + for _, f := range c { + for _, sf := range strings.Split(f, ",") { + if sf = strings.TrimSpace(sf); sf != "" { + cf = append(cf, sf) + } + } + } + sort.Strings(cf) + expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken} + sort.Strings(expectedValues) + if !reflect.DeepEqual(cf, expectedValues) { + t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues) + } + })) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken) + getReq.Header.Add("Connection", someConnHeader) + getReq.Header.Set(someConnHeader, "should be deleted") + getReq.Header.Set(fakeConnectionToken, "should be deleted") + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("reading body: %v", err) + } + if got, want := string(bodyBytes), backendResponse; got != want { + t.Errorf("got body %q; want %q", got, want) + } + if c := res.Header.Get("Connection"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Connection", c) + } + if c := res.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) + } + if c := res.Header.Get(fakeConnectionToken); c != "" { + t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) + } +} + +func TestXForwardedFor(t *testing.T) { + const prevForwardedFor = "client ip" + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { + t.Errorf("X-Forwarded-For didn't contain prior data") + } + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Header.Set("Connection", "close") + getReq.Header.Set("X-Forwarded-For", prevForwardedFor) + getReq.Close = true + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + bodyBytes, _ := io.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + +// Issue 38079: don't append to X-Forwarded-For if it's present but nil +func TestXForwardedFor_Omit(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if v := r.Header.Get("X-Forwarded-For"); v != "" { + t.Errorf("got X-Forwarded-For header: %q", v) + } + w.Write([]byte("hi")) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + oldDirector := proxyHandler.Director + proxyHandler.Director = func(r *http.Request) { + r.Header["X-Forwarded-For"] = nil + oldDirector(r) + } + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Close = true + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + res.Body.Close() +} + +var proxyQueryTests = []struct { + baseSuffix string // suffix to add to backend URL + reqSuffix string // suffix to add to frontend's request URL + want string // what backend should see for final request URL (without ?) +}{ + {"", "", ""}, + {"?sta=tic", "?us=er", "sta=tic&us=er"}, + {"", "?us=er", "us=er"}, + {"?sta=tic", "", "sta=tic"}, +} + +func TestReverseProxyQuery(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Got-Query", r.URL.RawQuery) + w.Write([]byte("hi")) + })) + defer backend.Close() + + for i, tt := range proxyQueryTests { + backendURL, err := url.Parse(backend.URL + tt.baseSuffix) + if err != nil { + t.Fatal(err) + } + frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) + req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) + req.Close = true + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("%d. Get: %v", i, err) + } + if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e { + t.Errorf("%d. got query %q; expected %q", i, g, e) + } + res.Body.Close() + frontend.Close() + } +} + +func TestReverseProxyFlushInterval(t *testing.T) { + const expected = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(expected)) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = time.Microsecond + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { + t.Errorf("got body %q; expected %q", bodyBytes, expected) + } +} + +func TestReverseProxyFlushIntervalHeaders(t *testing.T) { + const expected = "hi" + stopCh := make(chan struct{}) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("MyHeader", expected) + w.WriteHeader(200) + w.(http.Flusher).Flush() + <-stopCh + })) + defer backend.Close() + defer close(stopCh) + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = time.Microsecond + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + + ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second) + defer cancel() + req = req.WithContext(ctx) + + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + + if res.Header.Get("MyHeader") != expected { + t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected) + } +} + +func TestReverseProxyCancellation(t *testing.T) { + const backendResponse = "I am the backend" + + reqInFlight := make(chan struct{}) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + close(reqInFlight) // cause the client to cancel its request + + select { + case <-time.After(10 * time.Second): + // Note: this should only happen in broken implementations, and the + // closenotify case should be instantaneous. + t.Error("Handler never saw CloseNotify") + return + case <-w.(http.CloseNotifier).CloseNotify(): + } + + w.WriteHeader(http.StatusOK) + w.Write([]byte(backendResponse)) + })) + + defer backend.Close() + + backend.Config.ErrorLog = log.New(io.Discard, "", 0) + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + + // Discards errors of the form: + // http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + go func() { + <-reqInFlight + frontendClient.Transport.(*http.Transport).CancelRequest(getReq) + }() + res, err := frontendClient.Do(getReq) + if res != nil { + t.Errorf("got response %v; want nil", res.Status) + } + if err == nil { + // This should be an error like: + // Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079: + // use of closed network connection + t.Error("Server.Client().Do() returned nil error; want non-nil error") + } +} + +func req(t *testing.T, v string) *http.Request { + req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v))) + if err != nil { + t.Fatal(err) + } + return req +} + +// Issue 12344 +func TestNilBody(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi")) + })) + defer backend.Close() + + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + backURL, _ := url.Parse(backend.URL) + rp := NewSingleHostReverseProxy(backURL) + r := req(t, "GET / HTTP/1.0\r\n\r\n") + r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working + rp.ServeHTTP(w, r) + })) + defer frontend.Close() + + res, err := http.Get(frontend.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + slurp, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(slurp) != "hi" { + t.Errorf("Got %q; want %q", slurp, "hi") + } +} + +// Issue 15524 +func TestUserAgentHeader(t *testing.T) { + const explicitUA = "explicit UA" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/noua" { + if c := r.Header.Get("User-Agent"); c != "" { + t.Errorf("handler got non-empty User-Agent header %q", c) + } + return + } + if c := r.Header.Get("User-Agent"); c != explicitUA { + t.Errorf("handler got unexpected User-Agent header %q", c) + } + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Header.Set("User-Agent", explicitUA) + getReq.Close = true + res, err := frontendClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + res.Body.Close() + + getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil) + getReq.Header.Set("User-Agent", "") + getReq.Close = true + res, err = frontendClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + res.Body.Close() +} + +type bufferPool struct { + get func() []byte + put func([]byte) +} + +func (bp bufferPool) Get() []byte { return bp.get() } +func (bp bufferPool) Put(v []byte) { bp.put(v) } + +func TestReverseProxyGetPutBuffer(t *testing.T) { + const msg = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, msg) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + var ( + mu sync.Mutex + log []string + ) + addLog := func(event string) { + mu.Lock() + defer mu.Unlock() + log = append(log, event) + } + rp := NewSingleHostReverseProxy(backendURL) + const size = 1234 + rp.BufferPool = bufferPool{ + get: func() []byte { + addLog("getBuf") + return make([]byte, size) + }, + put: func(p []byte) { + addLog("putBuf-" + strconv.Itoa(len(p))) + }, + } + frontend := httptest.NewServer(rp) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + slurp, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatalf("reading body: %v", err) + } + if string(slurp) != msg { + t.Errorf("msg = %q; want %q", slurp, msg) + } + wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)} + mu.Lock() + defer mu.Unlock() + if !reflect.DeepEqual(log, wantLog) { + t.Errorf("Log events = %q; want %q", log, wantLog) + } +} + +func TestReverseProxy_Post(t *testing.T) { + const backendResponse = "I am the backend" + const backendStatus = 200 + var requestBody = bytes.Repeat([]byte("a"), 1<<20) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + slurp, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("Backend body read = %v", err) + } + if len(slurp) != len(requestBody) { + t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody)) + } + if !bytes.Equal(slurp, requestBody) { + t.Error("Backend read wrong request body.") // 1MB; omitting details + } + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody)) + res, err := frontend.Client().Do(postReq) + if err != nil { + t.Fatalf("Do: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + bodyBytes, _ := io.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + +type RoundTripperFunc func(*http.Request) (*http.Response, error) + +func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +// Issue 16036: send a Request with a nil Body when possible +func TestReverseProxy_NilBody(t *testing.T) { + backendURL, _ := url.Parse("http://fake.tld/") + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Body != nil { + t.Error("Body != nil; want a nil Body") + } + return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") + }) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + res, err := frontend.Client().Get(frontend.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != 502 { + t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status) + } +} + +// Issue 33142: always allocate the request headers +func TestReverseProxy_AllocatedHeader(t *testing.T) { + proxyHandler := new(ReverseProxy) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + proxyHandler.Director = func(*http.Request) {} // noop + proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Header == nil { + t.Error("Header == nil; want a non-nil Header") + } + return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") + }) + + proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{ + Method: "GET", + URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"}, + Proto: "HTTP/1.0", + ProtoMajor: 1, + }) +} + +// Issue 14237. Test ModifyResponse and that an error from it +// causes the proxy to return StatusBadGateway, or StatusOK otherwise. +func TestReverseProxyModifyResponse(t *testing.T) { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod")) + })) + defer backendServer.Close() + + rpURL, _ := url.Parse(backendServer.URL) + rproxy := NewSingleHostReverseProxy(rpURL) + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + rproxy.ModifyResponse = func(resp *http.Response) error { + if resp.Header.Get("X-Hit-Mod") != "true" { + return fmt.Errorf("tried to by-pass proxy") + } + return nil + } + + frontendProxy := httptest.NewServer(rproxy) + defer frontendProxy.Close() + + tests := []struct { + url string + wantCode int + }{ + {frontendProxy.URL + "/mod", http.StatusOK}, + {frontendProxy.URL + "/schedule", http.StatusBadGateway}, + } + + for i, tt := range tests { + resp, err := http.Get(tt.url) + if err != nil { + t.Fatalf("failed to reach proxy: %v", err) + } + if g, e := resp.StatusCode, tt.wantCode; g != e { + t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e) + } + resp.Body.Close() + } +} + +type failingRoundTripper struct{} + +func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + return nil, errors.New("some error") +} + +type staticResponseRoundTripper struct{ res *http.Response } + +func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + return rt.res, nil +} + +func TestReverseProxyErrorHandler(t *testing.T) { + tests := []struct { + name string + wantCode int + errorHandler func(http.ResponseWriter, *http.Request, error) + transport http.RoundTripper // defaults to failingRoundTripper + modifyResponse func(*http.Response) error + }{ + { + name: "default", + wantCode: http.StatusBadGateway, + }, + { + name: "errorhandler", + wantCode: http.StatusTeapot, + errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, + }, + { + name: "modifyresponse_noerr", + transport: staticResponseRoundTripper{ + &http.Response{StatusCode: 345, Body: http.NoBody}, + }, + modifyResponse: func(res *http.Response) error { + res.StatusCode++ + return nil + }, + errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, + wantCode: 346, + }, + { + name: "modifyresponse_err", + transport: staticResponseRoundTripper{ + &http.Response{StatusCode: 345, Body: http.NoBody}, + }, + modifyResponse: func(res *http.Response) error { + res.StatusCode++ + return errors.New("some error to trigger errorHandler") + }, + errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, + wantCode: http.StatusTeapot, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + target := &url.URL{ + Scheme: "http", + Host: "dummy.tld", + Path: "/", + } + rproxy := NewSingleHostReverseProxy(target) + rproxy.Transport = tt.transport + rproxy.ModifyResponse = tt.modifyResponse + if rproxy.Transport == nil { + rproxy.Transport = failingRoundTripper{} + } + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + if tt.errorHandler != nil { + rproxy.ErrorHandler = tt.errorHandler + } + frontendProxy := httptest.NewServer(rproxy) + defer frontendProxy.Close() + + resp, err := http.Get(frontendProxy.URL + "/test") + if err != nil { + t.Fatalf("failed to reach proxy: %v", err) + } + if g, e := resp.StatusCode, tt.wantCode; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + resp.Body.Close() + }) + } +} + +// Issue 16659: log errors from short read +func TestReverseProxy_CopyBuffer(t *testing.T) { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + out := "this call was relayed by the reverse proxy" + // Coerce a wrong content length to induce io.UnexpectedEOF + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) + fmt.Fprintln(w, out) + })) + defer backendServer.Close() + + rpURL, err := url.Parse(backendServer.URL) + if err != nil { + t.Fatal(err) + } + + var proxyLog bytes.Buffer + rproxy := NewSingleHostReverseProxy(rpURL) + rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile) + donec := make(chan bool, 1) + frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { donec <- true }() + rproxy.ServeHTTP(w, r) + })) + defer frontendProxy.Close() + + if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil { + t.Fatalf("want non-nil error") + } + // The race detector complains about the proxyLog usage in logf in copyBuffer + // and our usage below with proxyLog.Bytes() so we're explicitly using a + // channel to ensure that the ReverseProxy's ServeHTTP is done before we + // continue after Get. + <-donec + + expected := []string{ + "EOF", + "read", + } + for _, phrase := range expected { + if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) { + t.Errorf("expected log to contain phrase %q", phrase) + } + } +} + +type staticTransport struct { + res *http.Response +} + +func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) { + return t.res, nil +} + +func BenchmarkServeHTTP(b *testing.B) { + res := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("")), + } + proxy := &ReverseProxy{ + Director: func(*http.Request) {}, + Transport: &staticTransport{res}, + } + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + proxy.ServeHTTP(w, r) + } +} + +func TestServeHTTPDeepCopy(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello Gopher!")) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + type result struct { + before, after string + } + + resultChan := make(chan result, 1) + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + before := r.URL.String() + proxyHandler.ServeHTTP(w, r) + after := r.URL.String() + resultChan <- result{before: before, after: after} + })) + defer frontend.Close() + + want := result{before: "/", after: "/"} + + res, err := frontend.Client().Get(frontend.URL) + if err != nil { + t.Fatalf("Do: %v", err) + } + res.Body.Close() + + got := <-resultChan + if got != want { + t.Errorf("got = %+v; want = %+v", got, want) + } +} + +// Issue 18327: verify we always do a deep copy of the Request.Header map +// before any mutations. +func TestClonesRequestHeaders(t *testing.T) { + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + req, _ := http.NewRequest("GET", "http://foo.tld/", nil) + req.RemoteAddr = "1.2.3.4:56789" + rp := &ReverseProxy{ + Director: func(req *http.Request) { + req.Header.Set("From-Director", "1") + }, + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if v := req.Header.Get("From-Director"); v != "1" { + t.Errorf("From-Directory value = %q; want 1", v) + } + return nil, io.EOF + }), + } + rp.ServeHTTP(httptest.NewRecorder(), req) + + if req.Header.Get("From-Director") == "1" { + t.Error("Director header mutation modified caller's request") + } + if req.Header.Get("X-Forwarded-For") != "" { + t.Error("X-Forward-For header mutation modified caller's request") + } + +} + +type roundTripperFunc func(req *http.Request) (*http.Response, error) + +func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +func TestModifyResponseClosesBody(t *testing.T) { + req, _ := http.NewRequest("GET", "http://foo.tld/", nil) + req.RemoteAddr = "1.2.3.4:56789" + closeCheck := new(checkCloser) + logBuf := new(bytes.Buffer) + outErr := errors.New("ModifyResponse error") + rp := &ReverseProxy{ + Director: func(req *http.Request) {}, + Transport: &staticTransport{&http.Response{ + StatusCode: 200, + Body: closeCheck, + }}, + ErrorLog: log.New(logBuf, "", 0), + ModifyResponse: func(*http.Response) error { + return outErr + }, + } + rec := httptest.NewRecorder() + rp.ServeHTTP(rec, req) + res := rec.Result() + if g, e := res.StatusCode, http.StatusBadGateway; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + if !closeCheck.closed { + t.Errorf("body should have been closed") + } + if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) { + t.Errorf("ErrorLog %q does not contain %q", g, e) + } +} + +type checkCloser struct { + closed bool +} + +func (cc *checkCloser) Close() error { + cc.closed = true + return nil +} + +func (cc *checkCloser) Read(b []byte) (int, error) { + return len(b), nil +} + +// Issue 23643: panic on body copy error +func TestReverseProxy_PanicBodyError(t *testing.T) { + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + out := "this call was relayed by the reverse proxy" + // Coerce a wrong content length to induce io.ErrUnexpectedEOF + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) + fmt.Fprintln(w, out) + })) + defer backendServer.Close() + + rpURL, err := url.Parse(backendServer.URL) + if err != nil { + t.Fatal(err) + } + + rproxy := NewSingleHostReverseProxy(rpURL) + + // Ensure that the handler panics when the body read encounters an + // io.ErrUnexpectedEOF + defer func() { + err := recover() + if err == nil { + t.Fatal("handler should have panicked") + } + if err != http.ErrAbortHandler { + t.Fatal("expected ErrAbortHandler, got", err) + } + }() + req, _ := http.NewRequest("GET", "http://foo.tld/", nil) + rproxy.ServeHTTP(httptest.NewRecorder(), req) +} + +func TestSelectFlushInterval(t *testing.T) { + tests := []struct { + name string + p *ReverseProxy + res *http.Response + want time.Duration + }{ + { + name: "default", + res: &http.Response{}, + p: &ReverseProxy{FlushInterval: 123}, + want: 123, + }, + { + name: "server-sent events overrides non-zero", + res: &http.Response{ + Header: http.Header{ + "Content-Type": {"text/event-stream"}, + }, + }, + p: &ReverseProxy{FlushInterval: 123}, + want: -1, + }, + { + name: "server-sent events overrides zero", + res: &http.Response{ + Header: http.Header{ + "Content-Type": {"text/event-stream"}, + }, + }, + p: &ReverseProxy{FlushInterval: 0}, + want: -1, + }, + { + name: "Content-Length: -1, overrides non-zero", + res: &http.Response{ + ContentLength: -1, + }, + p: &ReverseProxy{FlushInterval: 123}, + want: -1, + }, + { + name: "Content-Length: -1, overrides zero", + res: &http.Response{ + ContentLength: -1, + }, + p: &ReverseProxy{FlushInterval: 0}, + want: -1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.p.flushInterval(tt.res) + if got != tt.want { + t.Errorf("flushLatency = %v; want %v", got, tt.want) + } + }) + } +} + +func TestReverseProxyWebSocket(t *testing.T) { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if upgradeType(r.Header) != "websocket" { + t.Error("unexpected backend request") + http.Error(w, "unexpected request", 400) + return + } + c, _, err := w.(http.Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer c.Close() + io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n") + bs := bufio.NewScanner(c) + if !bs.Scan() { + t.Errorf("backend failed to read line from client: %v", bs.Err()) + return + } + fmt.Fprintf(c, "backend got %q\n", bs.Text()) + })) + defer backendServer.Close() + + backURL, _ := url.Parse(backendServer.URL) + rproxy := NewSingleHostReverseProxy(backURL) + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + rproxy.ModifyResponse = func(res *http.Response) error { + res.Header.Add("X-Modified", "true") + return nil + } + + handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("X-Header", "X-Value") + rproxy.ServeHTTP(rw, req) + if got, want := rw.Header().Get("X-Modified"), "true"; got != want { + t.Errorf("response writer X-Modified header = %q; want %q", got, want) + } + }) + + frontendProxy := httptest.NewServer(handler) + defer frontendProxy.Close() + + req, _ := http.NewRequest("GET", frontendProxy.URL, nil) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + + c := frontendProxy.Client() + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 101 { + t.Fatalf("status = %v; want 101", res.Status) + } + + got := res.Header.Get("X-Header") + want := "X-Value" + if got != want { + t.Errorf("Header(XHeader) = %q; want %q", got, want) + } + + if upgradeType(res.Header) != "websocket" { + t.Fatalf("not websocket upgrade; got %#v", res.Header) + } + rwc, ok := res.Body.(io.ReadWriteCloser) + if !ok { + t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body) + } + defer rwc.Close() + + if got, want := res.Header.Get("X-Modified"), "true"; got != want { + t.Errorf("response X-Modified header = %q; want %q", got, want) + } + + io.WriteString(rwc, "Hello\n") + bs := bufio.NewScanner(rwc) + if !bs.Scan() { + t.Fatalf("Scan: %v", bs.Err()) + } + got = bs.Text() + want = `backend got "Hello"` + if got != want { + t.Errorf("got %#q, want %#q", got, want) + } +} + +func TestReverseProxyWebSocketCancelation(t *testing.T) { + n := 5 + triggerCancelCh := make(chan bool, n) + nthResponse := func(i int) string { + return fmt.Sprintf("backend response #%d\n", i) + } + terminalMsg := "final message" + + cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if g, ws := upgradeType(r.Header), "websocket"; g != ws { + t.Errorf("Unexpected upgrade type %q, want %q", g, ws) + http.Error(w, "Unexpected request", 400) + return + } + conn, bufrw, err := w.(http.Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer conn.Close() + + upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n" + if _, err := io.WriteString(conn, upgradeMsg); err != nil { + t.Error(err) + return + } + if _, _, err := bufrw.ReadLine(); err != nil { + t.Errorf("Failed to read line from client: %v", err) + return + } + + for i := 0; i < n; i++ { + if _, err := bufrw.WriteString(nthResponse(i)); err != nil { + select { + case <-triggerCancelCh: + default: + t.Errorf("Writing response #%d failed: %v", i, err) + } + return + } + bufrw.Flush() + time.Sleep(time.Second) + } + if _, err := bufrw.WriteString(terminalMsg); err != nil { + select { + case <-triggerCancelCh: + default: + t.Errorf("Failed to write terminal message: %v", err) + } + } + bufrw.Flush() + })) + defer cst.Close() + + backendURL, _ := url.Parse(cst.URL) + rproxy := NewSingleHostReverseProxy(backendURL) + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + rproxy.ModifyResponse = func(res *http.Response) error { + res.Header.Add("X-Modified", "true") + return nil + } + + handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("X-Header", "X-Value") + ctx, cancel := context.WithCancel(req.Context()) + go func() { + <-triggerCancelCh + cancel() + }() + rproxy.ServeHTTP(rw, req.WithContext(ctx)) + }) + + frontendProxy := httptest.NewServer(handler) + defer frontendProxy.Close() + + req, _ := http.NewRequest("GET", frontendProxy.URL, nil) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + + res, err := frontendProxy.Client().Do(req) + if err != nil { + t.Fatalf("Dialing to frontend proxy: %v", err) + } + defer res.Body.Close() + if g, w := res.StatusCode, 101; g != w { + t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w) + } + + if g, w := res.Header.Get("X-Header"), "X-Value"; g != w { + t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w) + } + + if g, w := upgradeType(res.Header), "websocket"; g != w { + t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w) + } + + rwc, ok := res.Body.(io.ReadWriteCloser) + if !ok { + t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body) + } + + if got, want := res.Header.Get("X-Modified"), "true"; got != want { + t.Errorf("response X-Modified header = %q; want %q", got, want) + } + + if _, err := io.WriteString(rwc, "Hello\n"); err != nil { + t.Fatalf("Failed to write first message: %v", err) + } + + // Read loop. + + br := bufio.NewReader(rwc) + for { + line, err := br.ReadString('\n') + switch { + case line == terminalMsg: // this case before "err == io.EOF" + t.Fatalf("The websocket request was not canceled, unfortunately!") + + case err == io.EOF: + return + + case err != nil: + t.Fatalf("Unexpected error: %v", err) + + case line == nthResponse(0): // We've gotten the first response back + // Let's trigger a cancel. + close(triggerCancelCh) + } + } +} + +func TestUnannouncedTrailer(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + res, err := frontendClient.Get(frontend.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + + io.ReadAll(res.Body) + + if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w { + t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w) + } + +} + +func TestSingleJoinSlash(t *testing.T) { + tests := []struct { + slasha string + slashb string + expected string + }{ + {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"}, + {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"}, + {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"}, + {"https://www.google.com", "", "https://www.google.com/"}, + {"", "favicon.ico", "/favicon.ico"}, + } + for _, tt := range tests { + if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected { + t.Errorf("singleJoiningSlash(%q,%q) want %q got %q", + tt.slasha, + tt.slashb, + tt.expected, + got) + } + } +} + +func TestJoinURLPath(t *testing.T) { + tests := []struct { + a *url.URL + b *url.URL + wantPath string + wantRaw string + }{ + {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""}, + {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"}, + {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, + {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, + {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"}, + {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"}, + } + + for _, tt := range tests { + p, rp := joinURLPath(tt.a, tt.b) + if p != tt.wantPath || rp != tt.wantRaw { + t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)", + tt.a.Path, tt.a.RawPath, + tt.b.Path, tt.b.RawPath, + tt.wantPath, tt.wantRaw, + p, rp) + } + } +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/internal/chunked.go b/vendor/github.com/lesismal/llib/std/net/http/internal/chunked.go new file mode 100644 index 0000000..f06e572 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/internal/chunked.go @@ -0,0 +1,255 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The wire protocol for HTTP's "chunked" Transfer-Encoding. + +// Package internal contains HTTP internals shared by net/http and +// net/http/httputil. +package internal + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" +) + +const maxLineLength = 4096 // assumed <= bufio.defaultBufSize + +var ErrLineTooLong = errors.New("header line too long") + +// NewChunkedReader returns a new chunkedReader that translates the data read from r +// out of HTTP "chunked" format before returning it. +// The chunkedReader returns io.EOF when the final 0-length chunk is read. +// +// NewChunkedReader is not needed by normal applications. The http package +// automatically decodes chunking when reading response bodies. +func NewChunkedReader(r io.Reader) io.Reader { + br, ok := r.(*bufio.Reader) + if !ok { + br = bufio.NewReader(r) + } + return &chunkedReader{r: br} +} + +type chunkedReader struct { + r *bufio.Reader + n uint64 // unread bytes in chunk + err error + buf [2]byte + checkEnd bool // whether need to check for \r\n chunk footer +} + +func (cr *chunkedReader) beginChunk() { + // chunk-size CRLF + var line []byte + line, cr.err = readChunkLine(cr.r) + if cr.err != nil { + return + } + cr.n, cr.err = parseHexUint(line) + if cr.err != nil { + return + } + if cr.n == 0 { + cr.err = io.EOF + } +} + +func (cr *chunkedReader) chunkHeaderAvailable() bool { + n := cr.r.Buffered() + if n > 0 { + peek, _ := cr.r.Peek(n) + return bytes.IndexByte(peek, '\n') >= 0 + } + return false +} + +func (cr *chunkedReader) Read(b []uint8) (n int, err error) { + for cr.err == nil { + if cr.checkEnd { + if n > 0 && cr.r.Buffered() < 2 { + // We have some data. Return early (per the io.Reader + // contract) instead of potentially blocking while + // reading more. + break + } + if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil { + if string(cr.buf[:]) != "\r\n" { + cr.err = errors.New("malformed chunked encoding") + break + } + } + cr.checkEnd = false + } + if cr.n == 0 { + if n > 0 && !cr.chunkHeaderAvailable() { + // We've read enough. Don't potentially block + // reading a new chunk header. + break + } + cr.beginChunk() + continue + } + if len(b) == 0 { + break + } + rbuf := b + if uint64(len(rbuf)) > cr.n { + rbuf = rbuf[:cr.n] + } + var n0 int + n0, cr.err = cr.r.Read(rbuf) + n += n0 + b = b[n0:] + cr.n -= uint64(n0) + // If we're at the end of a chunk, read the next two + // bytes to verify they are "\r\n". + if cr.n == 0 && cr.err == nil { + cr.checkEnd = true + } + } + return n, cr.err +} + +// Read a line of bytes (up to \n) from b. +// Give up if the line exceeds maxLineLength. +// The returned bytes are owned by the bufio.Reader +// so they are only valid until the next bufio read. +func readChunkLine(b *bufio.Reader) ([]byte, error) { + p, err := b.ReadSlice('\n') + if err != nil { + // We always know when EOF is coming. + // If the caller asked for a line, there should be a line. + if err == io.EOF { + err = io.ErrUnexpectedEOF + } else if err == bufio.ErrBufferFull { + err = ErrLineTooLong + } + return nil, err + } + if len(p) >= maxLineLength { + return nil, ErrLineTooLong + } + p = trimTrailingWhitespace(p) + p, err = removeChunkExtension(p) + if err != nil { + return nil, err + } + return p, nil +} + +func trimTrailingWhitespace(b []byte) []byte { + for len(b) > 0 && isASCIISpace(b[len(b)-1]) { + b = b[:len(b)-1] + } + return b +} + +func isASCIISpace(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' +} + +// removeChunkExtension removes any chunk-extension from p. +// For example, +// "0" => "0" +// "0;token" => "0" +// "0;token=val" => "0" +// `0;token="quoted string"` => "0" +func removeChunkExtension(p []byte) ([]byte, error) { + semi := bytes.IndexByte(p, ';') + if semi == -1 { + return p, nil + } + // TODO: care about exact syntax of chunk extensions? We're + // ignoring and stripping them anyway. For now just never + // return an error. + return p[:semi], nil +} + +// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP +// "chunked" format before writing them to w. Closing the returned chunkedWriter +// sends the final 0-length chunk that marks the end of the stream but does +// not send the final CRLF that appears after trailers; trailers and the last +// CRLF must be written separately. +// +// NewChunkedWriter is not needed by normal applications. The http +// package adds chunking automatically if handlers don't set a +// Content-Length header. Using newChunkedWriter inside a handler +// would result in double chunking or chunking with a Content-Length +// length, both of which are wrong. +func NewChunkedWriter(w io.Writer) io.WriteCloser { + return &chunkedWriter{w} +} + +// Writing to chunkedWriter translates to writing in HTTP chunked Transfer +// Encoding wire format to the underlying Wire chunkedWriter. +type chunkedWriter struct { + Wire io.Writer +} + +// Write the contents of data as one chunk to Wire. +// NOTE: Note that the corresponding chunk-writing procedure in Conn.Write has +// a bug since it does not check for success of io.WriteString +func (cw *chunkedWriter) Write(data []byte) (n int, err error) { + + // Don't send 0-length data. It looks like EOF for chunked encoding. + if len(data) == 0 { + return 0, nil + } + + if _, err = fmt.Fprintf(cw.Wire, "%x\r\n", len(data)); err != nil { + return 0, err + } + if n, err = cw.Wire.Write(data); err != nil { + return + } + if n != len(data) { + err = io.ErrShortWrite + return + } + if _, err = io.WriteString(cw.Wire, "\r\n"); err != nil { + return + } + if bw, ok := cw.Wire.(*FlushAfterChunkWriter); ok { + err = bw.Flush() + } + return +} + +func (cw *chunkedWriter) Close() error { + _, err := io.WriteString(cw.Wire, "0\r\n") + return err +} + +// FlushAfterChunkWriter signals from the caller of NewChunkedWriter +// that each chunk should be followed by a flush. It is used by the +// http.Transport code to keep the buffering behavior for headers and +// trailers, but flush out chunks aggressively in the middle for +// request bodies which may be generated slowly. See Issue 6574. +type FlushAfterChunkWriter struct { + *bufio.Writer +} + +func parseHexUint(v []byte) (n uint64, err error) { + for i, b := range v { + switch { + case '0' <= b && b <= '9': + b = b - '0' + case 'a' <= b && b <= 'f': + b = b - 'a' + 10 + case 'A' <= b && b <= 'F': + b = b - 'A' + 10 + default: + return 0, errors.New("invalid byte in chunk length") + } + if i == 16 { + return 0, errors.New("http chunk length too large") + } + n <<= 4 + n |= uint64(b) + } + return +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/internal/chunked_test.go b/vendor/github.com/lesismal/llib/std/net/http/internal/chunked_test.go new file mode 100644 index 0000000..08152ed --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/internal/chunked_test.go @@ -0,0 +1,213 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal + +import ( + "bufio" + "bytes" + "fmt" + "io" + "strings" + "testing" +) + +func TestChunk(t *testing.T) { + var b bytes.Buffer + + w := NewChunkedWriter(&b) + const chunk1 = "hello, " + const chunk2 = "world! 0123456789abcdef" + w.Write([]byte(chunk1)) + w.Write([]byte(chunk2)) + w.Close() + + if g, e := b.String(), "7\r\nhello, \r\n17\r\nworld! 0123456789abcdef\r\n0\r\n"; g != e { + t.Fatalf("chunk writer wrote %q; want %q", g, e) + } + + r := NewChunkedReader(&b) + data, err := io.ReadAll(r) + if err != nil { + t.Logf(`data: "%s"`, data) + t.Fatalf("ReadAll from reader: %v", err) + } + if g, e := string(data), chunk1+chunk2; g != e { + t.Errorf("chunk reader read %q; want %q", g, e) + } +} + +func TestChunkReadMultiple(t *testing.T) { + // Bunch of small chunks, all read together. + { + var b bytes.Buffer + w := NewChunkedWriter(&b) + w.Write([]byte("foo")) + w.Write([]byte("bar")) + w.Close() + + r := NewChunkedReader(&b) + buf := make([]byte, 10) + n, err := r.Read(buf) + if n != 6 || err != io.EOF { + t.Errorf("Read = %d, %v; want 6, EOF", n, err) + } + buf = buf[:n] + if string(buf) != "foobar" { + t.Errorf("Read = %q; want %q", buf, "foobar") + } + } + + // One big chunk followed by a little chunk, but the small bufio.Reader size + // should prevent the second chunk header from being read. + { + var b bytes.Buffer + w := NewChunkedWriter(&b) + // fillBufChunk is 11 bytes + 3 bytes header + 2 bytes footer = 16 bytes, + // the same as the bufio ReaderSize below (the minimum), so even + // though we're going to try to Read with a buffer larger enough to also + // receive "foo", the second chunk header won't be read yet. + const fillBufChunk = "0123456789a" + const shortChunk = "foo" + w.Write([]byte(fillBufChunk)) + w.Write([]byte(shortChunk)) + w.Close() + + r := NewChunkedReader(bufio.NewReaderSize(&b, 16)) + buf := make([]byte, len(fillBufChunk)+len(shortChunk)) + n, err := r.Read(buf) + if n != len(fillBufChunk) || err != nil { + t.Errorf("Read = %d, %v; want %d, nil", n, err, len(fillBufChunk)) + } + buf = buf[:n] + if string(buf) != fillBufChunk { + t.Errorf("Read = %q; want %q", buf, fillBufChunk) + } + + n, err = r.Read(buf) + if n != len(shortChunk) || err != io.EOF { + t.Errorf("Read = %d, %v; want %d, EOF", n, err, len(shortChunk)) + } + } + + // And test that we see an EOF chunk, even though our buffer is already full: + { + r := NewChunkedReader(bufio.NewReader(strings.NewReader("3\r\nfoo\r\n0\r\n"))) + buf := make([]byte, 3) + n, err := r.Read(buf) + if n != 3 || err != io.EOF { + t.Errorf("Read = %d, %v; want 3, EOF", n, err) + } + if string(buf) != "foo" { + t.Errorf("buf = %q; want foo", buf) + } + } +} + +func TestChunkReaderAllocs(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + var buf bytes.Buffer + w := NewChunkedWriter(&buf) + a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc") + w.Write(a) + w.Write(b) + w.Write(c) + w.Close() + + readBuf := make([]byte, len(a)+len(b)+len(c)+1) + byter := bytes.NewReader(buf.Bytes()) + bufr := bufio.NewReader(byter) + mallocs := testing.AllocsPerRun(100, func() { + byter.Seek(0, io.SeekStart) + bufr.Reset(byter) + r := NewChunkedReader(bufr) + n, err := io.ReadFull(r, readBuf) + if n != len(readBuf)-1 { + t.Fatalf("read %d bytes; want %d", n, len(readBuf)-1) + } + if err != io.ErrUnexpectedEOF { + t.Fatalf("read error = %v; want ErrUnexpectedEOF", err) + } + }) + if mallocs > 1.5 { + t.Errorf("mallocs = %v; want 1", mallocs) + } +} + +func TestParseHexUint(t *testing.T) { + type testCase struct { + in string + want uint64 + wantErr string + } + tests := []testCase{ + {"x", 0, "invalid byte in chunk length"}, + {"0000000000000000", 0, ""}, + {"0000000000000001", 1, ""}, + {"ffffffffffffffff", 1<<64 - 1, ""}, + {"000000000000bogus", 0, "invalid byte in chunk length"}, + {"00000000000000000", 0, "http chunk length too large"}, // could accept if we wanted + {"10000000000000000", 0, "http chunk length too large"}, + {"00000000000000001", 0, "http chunk length too large"}, // could accept if we wanted + } + for i := uint64(0); i <= 1234; i++ { + tests = append(tests, testCase{in: fmt.Sprintf("%x", i), want: i}) + } + for _, tt := range tests { + got, err := parseHexUint([]byte(tt.in)) + if tt.wantErr != "" { + if !strings.Contains(fmt.Sprint(err), tt.wantErr) { + t.Errorf("parseHexUint(%q) = %v, %v; want error %q", tt.in, got, err, tt.wantErr) + } + } else { + if err != nil || got != tt.want { + t.Errorf("parseHexUint(%q) = %v, %v; want %v", tt.in, got, err, tt.want) + } + } + } +} + +func TestChunkReadingIgnoresExtensions(t *testing.T) { + in := "7;ext=\"some quoted string\"\r\n" + // token=quoted string + "hello, \r\n" + + "17;someext\r\n" + // token without value + "world! 0123456789abcdef\r\n" + + "0;someextension=sometoken\r\n" // token=token + data, err := io.ReadAll(NewChunkedReader(strings.NewReader(in))) + if err != nil { + t.Fatalf("ReadAll = %q, %v", data, err) + } + if g, e := string(data), "hello, world! 0123456789abcdef"; g != e { + t.Errorf("read %q; want %q", g, e) + } +} + +// Issue 17355: ChunkedReader shouldn't block waiting for more data +// if it can return something. +func TestChunkReadPartial(t *testing.T) { + pr, pw := io.Pipe() + go func() { + pw.Write([]byte("7\r\n1234567")) + }() + cr := NewChunkedReader(pr) + readBuf := make([]byte, 7) + n, err := cr.Read(readBuf) + if err != nil { + t.Fatal(err) + } + want := "1234567" + if n != 7 || string(readBuf) != want { + t.Fatalf("Read: %v %q; want %d, %q", n, readBuf[:n], len(want), want) + } + go func() { + pw.Write([]byte("xx")) + }() + _, err = cr.Read(readBuf) + if got := fmt.Sprint(err); !strings.Contains(got, "malformed") { + t.Fatalf("second read = %v; want malformed error", err) + } + +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/internal/testcert.go b/vendor/github.com/lesismal/llib/std/net/http/internal/testcert.go new file mode 100644 index 0000000..2284a83 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/internal/testcert.go @@ -0,0 +1,45 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal + +import "strings" + +// LocalhostCert is a PEM-encoded TLS cert with SAN IPs +// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT. +// generated from src/crypto/tls: +// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h +var LocalhostCert = []byte(`-----BEGIN CERTIFICATE----- +MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS +MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw +MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB +iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4 +iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul +rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO +BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw +AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA +AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9 +tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs +h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM +fblo6RBxUQ== +-----END CERTIFICATE-----`) + +// LocalhostKey is the private key for localhostCert. +var LocalhostKey = []byte(testingKey(`-----BEGIN RSA TESTING KEY----- +MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9 +SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB +l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB +AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet +3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb +uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H +qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp +jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY +fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U +fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU +y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX +qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo +f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA== +-----END RSA TESTING KEY-----`)) + +func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } diff --git a/vendor/github.com/lesismal/llib/std/net/http/jar.go b/vendor/github.com/lesismal/llib/std/net/http/jar.go new file mode 100644 index 0000000..5c3de0d --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/jar.go @@ -0,0 +1,27 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "net/url" +) + +// A CookieJar manages storage and use of cookies in HTTP requests. +// +// Implementations of CookieJar must be safe for concurrent use by multiple +// goroutines. +// +// The net/http/cookiejar package provides a CookieJar implementation. +type CookieJar interface { + // SetCookies handles the receipt of the cookies in a reply for the + // given URL. It may or may not choose to save the cookies, depending + // on the jar's policy and implementation. + SetCookies(u *url.URL, cookies []*Cookie) + + // Cookies returns the cookies to send in a request for the given URL. + // It is up to the implementation to honor the standard cookie use + // restrictions such as in RFC 6265. + Cookies(u *url.URL) []*Cookie +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/main_test.go b/vendor/github.com/lesismal/llib/std/net/http/main_test.go new file mode 100644 index 0000000..6564627 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/main_test.go @@ -0,0 +1,171 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "fmt" + "io" + "log" + "net/http" + "os" + "runtime" + "sort" + "strings" + "testing" + "time" +) + +var quietLog = log.New(io.Discard, "", 0) + +func TestMain(m *testing.M) { + v := m.Run() + if v == 0 && goroutineLeaked() { + os.Exit(1) + } + os.Exit(v) +} + +func interestingGoroutines() (gs []string) { + buf := make([]byte, 2<<20) + buf = buf[:runtime.Stack(buf, true)] + for _, g := range strings.Split(string(buf), "\n\n") { + sl := strings.SplitN(g, "\n", 2) + if len(sl) != 2 { + continue + } + stack := strings.TrimSpace(sl[1]) + if stack == "" || + strings.Contains(stack, "testing.(*M).before.func1") || + strings.Contains(stack, "os/signal.signal_recv") || + strings.Contains(stack, "created by net.startServer") || + strings.Contains(stack, "created by testing.RunTests") || + strings.Contains(stack, "closeWriteAndWait") || + strings.Contains(stack, "testing.Main(") || + // These only show up with GOTRACEBACK=2; Issue 5005 (comment 28) + strings.Contains(stack, "runtime.goexit") || + strings.Contains(stack, "created by runtime.gc") || + strings.Contains(stack, "net/http_test.interestingGoroutines") || + strings.Contains(stack, "runtime.MHeap_Scavenger") { + continue + } + gs = append(gs, stack) + } + sort.Strings(gs) + return +} + +// Verify the other tests didn't leave any goroutines running. +func goroutineLeaked() bool { + if testing.Short() || runningBenchmarks() { + // Don't worry about goroutine leaks in -short mode or in + // benchmark mode. Too distracting when there are false positives. + return false + } + + var stackCount map[string]int + for i := 0; i < 5; i++ { + n := 0 + stackCount = make(map[string]int) + gs := interestingGoroutines() + for _, g := range gs { + stackCount[g]++ + n++ + } + if n == 0 { + return false + } + // Wait for goroutines to schedule and die off: + time.Sleep(100 * time.Millisecond) + } + fmt.Fprintf(os.Stderr, "Too many goroutines running after net/http test(s).\n") + for stack, count := range stackCount { + fmt.Fprintf(os.Stderr, "%d instances of:\n%s\n", count, stack) + } + return true +} + +// setParallel marks t as a parallel test if we're in short mode +// (all.bash), but as a serial test otherwise. Using t.Parallel isn't +// compatible with the afterTest func in non-short mode. +func setParallel(t *testing.T) { + if strings.Contains(t.Name(), "HTTP2") { + http.CondSkipHTTP2(t) + } + if testing.Short() { + t.Parallel() + } +} + +func runningBenchmarks() bool { + for i, arg := range os.Args { + if strings.HasPrefix(arg, "-test.bench=") && !strings.HasSuffix(arg, "=") { + return true + } + if arg == "-test.bench" && i < len(os.Args)-1 && os.Args[i+1] != "" { + return true + } + } + return false +} + +func afterTest(t testing.TB) { + http.DefaultTransport.(*http.Transport).CloseIdleConnections() + if testing.Short() { + return + } + var bad string + badSubstring := map[string]string{ + ").readLoop(": "a Transport", + ").writeLoop(": "a Transport", + "created by net/http/httptest.(*Server).Start": "an httptest.Server", + "timeoutHandler": "a TimeoutHandler", + "net.(*netFD).connect(": "a timing out dial", + ").noteClientGone(": "a closenotifier sender", + } + var stacks string + for i := 0; i < 10; i++ { + bad = "" + stacks = strings.Join(interestingGoroutines(), "\n\n") + for substr, what := range badSubstring { + if strings.Contains(stacks, substr) { + bad = what + } + } + if bad == "" { + return + } + // Bad stuff found, but goroutines might just still be + // shutting down, so give it some time. + time.Sleep(250 * time.Millisecond) + } + t.Errorf("Test appears to have leaked %s:\n%s", bad, stacks) +} + +// waitCondition reports whether fn eventually returned true, +// checking immediately and then every checkEvery amount, +// until waitFor has elapsed, at which point it returns false. +func waitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool { + deadline := time.Now().Add(waitFor) + for time.Now().Before(deadline) { + if fn() { + return true + } + time.Sleep(checkEvery) + } + return false +} + +// waitErrCondition is like waitCondition but with errors instead of bools. +func waitErrCondition(waitFor, checkEvery time.Duration, fn func() error) error { + deadline := time.Now().Add(waitFor) + var err error + for time.Now().Before(deadline) { + if err = fn(); err == nil { + return nil + } + time.Sleep(checkEvery) + } + return err +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/method.go b/vendor/github.com/lesismal/llib/std/net/http/method.go new file mode 100644 index 0000000..6f46155 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/method.go @@ -0,0 +1,20 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +// Common HTTP methods. +// +// Unless otherwise noted, these are defined in RFC 7231 section 4.3. +const ( + MethodGet = "GET" + MethodHead = "HEAD" + MethodPost = "POST" + MethodPut = "PUT" + MethodPatch = "PATCH" // RFC 5789 + MethodDelete = "DELETE" + MethodConnect = "CONNECT" + MethodOptions = "OPTIONS" + MethodTrace = "TRACE" +) diff --git a/vendor/github.com/lesismal/llib/std/net/http/omithttp2.go b/vendor/github.com/lesismal/llib/std/net/http/omithttp2.go new file mode 100644 index 0000000..30c6e48 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/omithttp2.go @@ -0,0 +1,71 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build nethttpomithttp2 + +package http + +import ( + "errors" + "sync" + "time" +) + +func init() { + omitBundledHTTP2 = true +} + +const noHTTP2 = "no bundled HTTP/2" // should never see this + +var http2errRequestCanceled = errors.New("net/http: request canceled") + +var http2goAwayTimeout = 1 * time.Second + +const http2NextProtoTLS = "h2" + +type http2Transport struct { + MaxHeaderListSize uint32 + ConnPool interface{} +} + +func (*http2Transport) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) } +func (*http2Transport) CloseIdleConnections() {} + +type http2noDialH2RoundTripper struct{} + +func (http2noDialH2RoundTripper) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) } + +type http2noDialClientConnPool struct { + http2clientConnPool http2clientConnPool +} + +type http2clientConnPool struct { + mu *sync.Mutex + conns map[string][]struct{} +} + +func http2configureTransports(*Transport) (*http2Transport, error) { panic(noHTTP2) } + +func http2isNoCachedConnError(err error) bool { + _, ok := err.(interface{ IsHTTP2NoCachedConnError() }) + return ok +} + +type http2Server struct { + NewWriteScheduler func() http2WriteScheduler +} + +type http2WriteScheduler interface{} + +func http2NewPriorityWriteScheduler(interface{}) http2WriteScheduler { panic(noHTTP2) } + +func http2ConfigureServer(s *Server, conf *http2Server) error { panic(noHTTP2) } + +var http2ErrNoCachedConn = http2noCachedConnError{} + +type http2noCachedConnError struct{} + +func (http2noCachedConnError) IsHTTP2NoCachedConnError() {} + +func (http2noCachedConnError) Error() string { return "http2: no cached connection was available" } diff --git a/vendor/github.com/lesismal/llib/std/net/http/pprof/pprof.go b/vendor/github.com/lesismal/llib/std/net/http/pprof/pprof.go new file mode 100644 index 0000000..09731de --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/pprof/pprof.go @@ -0,0 +1,449 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package pprof serves via its HTTP server runtime profiling data +// in the format expected by the pprof visualization tool. +// +// The package is typically only imported for the side effect of +// registering its HTTP handlers. +// The handled paths all begin with /debug/pprof/. +// +// To use pprof, link this package into your program: +// import _ "github.com/lesismal/llib/std/net/http/pprof" +// +// If your application is not already running an http server, you +// need to start one. Add "net/http" and "log" to your imports and +// the following code to your main function: +// +// go func() { +// log.Println(http.ListenAndServe("localhost:6060", nil)) +// }() +// +// If you are not using DefaultServeMux, you will have to register handlers +// with the mux you are using. +// +// Then use the pprof tool to look at the heap profile: +// +// go tool pprof http://localhost:6060/debug/pprof/heap +// +// Or to look at a 30-second CPU profile: +// +// go tool pprof http://localhost:6060/debug/pprof/profile?seconds=30 +// +// Or to look at the goroutine blocking profile, after calling +// runtime.SetBlockProfileRate in your program: +// +// go tool pprof http://localhost:6060/debug/pprof/block +// +// Or to look at the holders of contended mutexes, after calling +// runtime.SetMutexProfileFraction in your program: +// +// go tool pprof http://localhost:6060/debug/pprof/mutex +// +// The package also exports a handler that serves execution trace data +// for the "go tool trace" command. To collect a 5-second execution trace: +// +// wget -O trace.out http://localhost:6060/debug/pprof/trace?seconds=5 +// go tool trace trace.out +// +// To view all available profiles, open http://localhost:6060/debug/pprof/ +// in your browser. +// +// For a study of the facility in action, visit +// +// https://blog.golang.org/2011/06/profiling-go-programs.html +// +package pprof + +import ( + "bufio" + "bytes" + "context" + "fmt" + "html" + "internal/profile" + "io" + "log" + "net/http" + "net/url" + "os" + "runtime" + "runtime/pprof" + "runtime/trace" + "sort" + "strconv" + "strings" + "time" +) + +func init() { + http.HandleFunc("/debug/pprof/", Index) + http.HandleFunc("/debug/pprof/cmdline", Cmdline) + http.HandleFunc("/debug/pprof/profile", Profile) + http.HandleFunc("/debug/pprof/symbol", Symbol) + http.HandleFunc("/debug/pprof/trace", Trace) +} + +// Cmdline responds with the running program's +// command line, with arguments separated by NUL bytes. +// The package initialization registers it as /debug/pprof/cmdline. +func Cmdline(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + fmt.Fprint(w, strings.Join(os.Args, "\x00")) +} + +func sleep(r *http.Request, d time.Duration) { + select { + case <-time.After(d): + case <-r.Context().Done(): + } +} + +func durationExceedsWriteTimeout(r *http.Request, seconds float64) bool { + srv, ok := r.Context().Value(http.ServerContextKey).(*http.Server) + return ok && srv.WriteTimeout != 0 && seconds >= srv.WriteTimeout.Seconds() +} + +func serveError(w http.ResponseWriter, status int, txt string) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Go-Pprof", "1") + w.Header().Del("Content-Disposition") + w.WriteHeader(status) + fmt.Fprintln(w, txt) +} + +// Profile responds with the pprof-formatted cpu profile. +// Profiling lasts for duration specified in seconds GET parameter, or for 30 seconds if not specified. +// The package initialization registers it as /debug/pprof/profile. +func Profile(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Content-Type-Options", "nosniff") + sec, err := strconv.ParseInt(r.FormValue("seconds"), 10, 64) + if sec <= 0 || err != nil { + sec = 30 + } + + if durationExceedsWriteTimeout(r, float64(sec)) { + serveError(w, http.StatusBadRequest, "profile duration exceeds server's WriteTimeout") + return + } + + // Set Content Type assuming StartCPUProfile will work, + // because if it does it starts writing. + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Disposition", `attachment; filename="profile"`) + if err := pprof.StartCPUProfile(w); err != nil { + // StartCPUProfile failed, so no writes yet. + serveError(w, http.StatusInternalServerError, + fmt.Sprintf("Could not enable CPU profiling: %s", err)) + return + } + sleep(r, time.Duration(sec)*time.Second) + pprof.StopCPUProfile() +} + +// Trace responds with the execution trace in binary form. +// Tracing lasts for duration specified in seconds GET parameter, or for 1 second if not specified. +// The package initialization registers it as /debug/pprof/trace. +func Trace(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Content-Type-Options", "nosniff") + sec, err := strconv.ParseFloat(r.FormValue("seconds"), 64) + if sec <= 0 || err != nil { + sec = 1 + } + + if durationExceedsWriteTimeout(r, sec) { + serveError(w, http.StatusBadRequest, "profile duration exceeds server's WriteTimeout") + return + } + + // Set Content Type assuming trace.Start will work, + // because if it does it starts writing. + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Disposition", `attachment; filename="trace"`) + if err := trace.Start(w); err != nil { + // trace.Start failed, so no writes yet. + serveError(w, http.StatusInternalServerError, + fmt.Sprintf("Could not enable tracing: %s", err)) + return + } + sleep(r, time.Duration(sec*float64(time.Second))) + trace.Stop() +} + +// Symbol looks up the program counters listed in the request, +// responding with a table mapping program counters to function names. +// The package initialization registers it as /debug/pprof/symbol. +func Symbol(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + + // We have to read the whole POST body before + // writing any output. Buffer the output here. + var buf bytes.Buffer + + // We don't know how many symbols we have, but we + // do have symbol information. Pprof only cares whether + // this number is 0 (no symbols available) or > 0. + fmt.Fprintf(&buf, "num_symbols: 1\n") + + var b *bufio.Reader + if r.Method == "POST" { + b = bufio.NewReader(r.Body) + } else { + b = bufio.NewReader(strings.NewReader(r.URL.RawQuery)) + } + + for { + word, err := b.ReadSlice('+') + if err == nil { + word = word[0 : len(word)-1] // trim + + } + pc, _ := strconv.ParseUint(string(word), 0, 64) + if pc != 0 { + f := runtime.FuncForPC(uintptr(pc)) + if f != nil { + fmt.Fprintf(&buf, "%#x %s\n", pc, f.Name()) + } + } + + // Wait until here to check for err; the last + // symbol will have an err because it doesn't end in +. + if err != nil { + if err != io.EOF { + fmt.Fprintf(&buf, "reading request: %v\n", err) + } + break + } + } + + w.Write(buf.Bytes()) +} + +// Handler returns an HTTP handler that serves the named profile. +func Handler(name string) http.Handler { + return handler(name) +} + +type handler string + +func (name handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Content-Type-Options", "nosniff") + p := pprof.Lookup(string(name)) + if p == nil { + serveError(w, http.StatusNotFound, "Unknown profile") + return + } + if sec := r.FormValue("seconds"); sec != "" { + name.serveDeltaProfile(w, r, p, sec) + return + } + gc, _ := strconv.Atoi(r.FormValue("gc")) + if name == "heap" && gc > 0 { + runtime.GC() + } + debug, _ := strconv.Atoi(r.FormValue("debug")) + if debug != 0 { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + } else { + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, name)) + } + p.WriteTo(w, debug) +} + +func (name handler) serveDeltaProfile(w http.ResponseWriter, r *http.Request, p *pprof.Profile, secStr string) { + sec, err := strconv.ParseInt(secStr, 10, 64) + if err != nil || sec <= 0 { + serveError(w, http.StatusBadRequest, `invalid value for "seconds" - must be a positive integer`) + return + } + if !profileSupportsDelta[name] { + serveError(w, http.StatusBadRequest, `"seconds" parameter is not supported for this profile type`) + return + } + // 'name' should be a key in profileSupportsDelta. + if durationExceedsWriteTimeout(r, float64(sec)) { + serveError(w, http.StatusBadRequest, "profile duration exceeds server's WriteTimeout") + return + } + debug, _ := strconv.Atoi(r.FormValue("debug")) + if debug != 0 { + serveError(w, http.StatusBadRequest, "seconds and debug params are incompatible") + return + } + p0, err := collectProfile(p) + if err != nil { + serveError(w, http.StatusInternalServerError, "failed to collect profile") + return + } + + t := time.NewTimer(time.Duration(sec) * time.Second) + defer t.Stop() + + select { + case <-r.Context().Done(): + err := r.Context().Err() + if err == context.DeadlineExceeded { + serveError(w, http.StatusRequestTimeout, err.Error()) + } else { // TODO: what's a good status code for cancelled requests? 400? + serveError(w, http.StatusInternalServerError, err.Error()) + } + return + case <-t.C: + } + + p1, err := collectProfile(p) + if err != nil { + serveError(w, http.StatusInternalServerError, "failed to collect profile") + return + } + ts := p1.TimeNanos + dur := p1.TimeNanos - p0.TimeNanos + + p0.Scale(-1) + + p1, err = profile.Merge([]*profile.Profile{p0, p1}) + if err != nil { + serveError(w, http.StatusInternalServerError, "failed to compute delta") + return + } + + p1.TimeNanos = ts // set since we don't know what profile.Merge set for TimeNanos. + p1.DurationNanos = dur + + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s-delta"`, name)) + p1.Write(w) +} + +func collectProfile(p *pprof.Profile) (*profile.Profile, error) { + var buf bytes.Buffer + if err := p.WriteTo(&buf, 0); err != nil { + return nil, err + } + ts := time.Now().UnixNano() + p0, err := profile.Parse(&buf) + if err != nil { + return nil, err + } + p0.TimeNanos = ts + return p0, nil +} + +var profileSupportsDelta = map[handler]bool{ + "allocs": true, + "block": true, + "goroutine": true, + "heap": true, + "mutex": true, + "threadcreate": true, +} + +var profileDescriptions = map[string]string{ + "allocs": "A sampling of all past memory allocations", + "block": "Stack traces that led to blocking on synchronization primitives", + "cmdline": "The command line invocation of the current program", + "goroutine": "Stack traces of all current goroutines", + "heap": "A sampling of memory allocations of live objects. You can specify the gc GET parameter to run GC before taking the heap sample.", + "mutex": "Stack traces of holders of contended mutexes", + "profile": "CPU profile. You can specify the duration in the seconds GET parameter. After you get the profile file, use the go tool pprof command to investigate the profile.", + "threadcreate": "Stack traces that led to the creation of new OS threads", + "trace": "A trace of execution of the current program. You can specify the duration in the seconds GET parameter. After you get the trace file, use the go tool trace command to investigate the trace.", +} + +type profileEntry struct { + Name string + Href string + Desc string + Count int +} + +// Index responds with the pprof-formatted profile named by the request. +// For example, "/debug/pprof/heap" serves the "heap" profile. +// Index responds to a request for "/debug/pprof/" with an HTML page +// listing the available profiles. +func Index(w http.ResponseWriter, r *http.Request) { + if strings.HasPrefix(r.URL.Path, "/debug/pprof/") { + name := strings.TrimPrefix(r.URL.Path, "/debug/pprof/") + if name != "" { + handler(name).ServeHTTP(w, r) + return + } + } + + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("Content-Type", "text/html; charset=utf-8") + + var profiles []profileEntry + for _, p := range pprof.Profiles() { + profiles = append(profiles, profileEntry{ + Name: p.Name(), + Href: p.Name(), + Desc: profileDescriptions[p.Name()], + Count: p.Count(), + }) + } + + // Adding other profiles exposed from within this package + for _, p := range []string{"cmdline", "profile", "trace"} { + profiles = append(profiles, profileEntry{ + Name: p, + Href: p, + Desc: profileDescriptions[p], + }) + } + + sort.Slice(profiles, func(i, j int) bool { + return profiles[i].Name < profiles[j].Name + }) + + if err := indexTmplExecute(w, profiles); err != nil { + log.Print(err) + } +} + +func indexTmplExecute(w io.Writer, profiles []profileEntry) error { + var b bytes.Buffer + b.WriteString(` + +/debug/pprof/ + + + +/debug/pprof/
+
+Types of profiles available: + + +`) + + for _, profile := range profiles { + link := &url.URL{Path: profile.Href, RawQuery: "debug=1"} + fmt.Fprintf(&b, "\n", profile.Count, link, html.EscapeString(profile.Name)) + } + + b.WriteString(`
CountProfile
%d%s
+full goroutine stack dump +
+

+Profile Descriptions: +

    +`) + for _, profile := range profiles { + fmt.Fprintf(&b, "
  • %s:
    %s
  • \n", html.EscapeString(profile.Name), html.EscapeString(profile.Desc)) + } + b.WriteString(`
+

+ +`) + + _, err := w.Write(b.Bytes()) + return err +} diff --git a/vendor/github.com/lesismal/llib/std/net/http/pprof/pprof_test.go b/vendor/github.com/lesismal/llib/std/net/http/pprof/pprof_test.go new file mode 100644 index 0000000..ada46f3 --- /dev/null +++ b/vendor/github.com/lesismal/llib/std/net/http/pprof/pprof_test.go @@ -0,0 +1,258 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pprof + +import ( + "bytes" + "fmt" + "github.com/lesismal/llib/std/net/http/httptest" + "internal/profile" + "io" + "net/http" + "runtime" + "runtime/pprof" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +// TestDescriptions checks that the profile names under runtime/pprof package +// have a key in the description map. +func TestDescriptions(t *testing.T) { + for _, p := range pprof.Profiles() { + _, ok := profileDescriptions[p.Name()] + if ok != true { + t.Errorf("%s does not exist in profileDescriptions map\n", p.Name()) + } + } +} + +func TestHandlers(t *testing.T) { + testCases := []struct { + path string + handler http.HandlerFunc + statusCode int + contentType string + contentDisposition string + resp []byte + }{ + {"/debug/pprof/