DERO-HE STARGATE Testnet Release33

This commit is contained in:
Captain 2021-12-01 15:43:13 +00:00
parent 2388bdd2f0
commit 9300ef9b14
No known key found for this signature in database
GPG Key ID: 18CDB3ED5E85D2D4
497 changed files with 118941 additions and 5537 deletions

View File

@ -1283,7 +1283,6 @@ func (chain *Blockchain) IS_TX_Valid(txhash crypto.Hash) (valid_blid crypto.Hash
for _, bltxhash := range bl.Tx_hashes {
if bltxhash == txhash {
exist_list = append(exist_list, blid)
//break , this is removed so as this case can be tested well
}
}
}

View File

@ -11,6 +11,7 @@
// Register a name, limit names of 5 or less length
Function Register(name String) Uint64
10 IF EXISTS(name) THEN GOTO 50 // if name is already used, it cannot reregistered
15 IF STRLEN(name) >= 64 THEN GOTO 50 // skip names misuse
20 IF STRLEN(name) >= 6 THEN GOTO 40
30 IF SIGNER() == address_raw("deto1qyvyeyzrcm2fzf6kyq7egkes2ufgny5xn77y6typhfx9s7w3mvyd5qqynr5hx") THEN GOTO 40
35 IF SIGNER() != address_raw("deto1qy0ehnqjpr0wxqnknyc66du2fsxyktppkr8m8e6jvplp954klfjz2qqdzcd8p") THEN GOTO 50

View File

@ -249,24 +249,24 @@ func (chain *Blockchain) Create_new_miner_block(miner_address rpc.Address) (cbl
}
if tx.IsProofRequired() && len(bl.Tips) == 2 {
if tx.BLID == bl.Tips[0] || tx.BLID == bl.Tips[1] {
if tx.BLID == bl.Tips[0] || tx.BLID == bl.Tips[1] { // delay txs by a block if they would collide
logger.V(8).Info("not selecting tx due to probable collision", "txid", tx_hash_list_sorted[i].Hash)
continue
}
} else {
version, err := chain.ReadBlockSnapshotVersion(tx.BLID)
if err != nil {
continue
}
hash, err := chain.Load_Merkle_Hash(version)
if err != nil {
continue
}
}
if hash != tx.Payloads[0].Statement.Roothash {
//return fmt.Errorf("Tx statement roothash mismatch expected %x actual %x", tx.Payloads[0].Statement.Roothash, hash[:])
continue
}
version, err := chain.ReadBlockSnapshotVersion(tx.BLID)
if err != nil {
continue
}
hash, err := chain.Load_Merkle_Hash(version)
if err != nil {
continue
}
if hash != tx.Payloads[0].Statement.Roothash {
//return fmt.Errorf("Tx statement roothash mismatch expected %x actual %x", tx.Payloads[0].Statement.Roothash, hash[:])
continue
}
if height-int64(tx.Height) < TX_VALIDITY_HEIGHT {

View File

@ -277,8 +277,22 @@ func (chain *Blockchain) Find_Blocks_Height_Range(startheight, stopheight int64)
}
_, topos_end := chain.Store.Topo_store.binarySearchHeight(stopheight)
lowest := topos_start[0]
for _, t := range topos_start {
if t < lowest {
lowest = t
}
}
highest := topos_end[0]
for _, t := range topos_end {
if t > highest {
highest = t
}
}
blid_map := map[crypto.Hash]bool{}
for i := topos_start[0]; i <= topos_end[0]; i++ {
for i := lowest; i <= highest; i++ {
if toporecord, err := chain.Store.Topo_store.Read(i); err != nil {
panic(err)
} else {

View File

@ -20,7 +20,9 @@ import "io"
import "os"
import "fmt"
import "time"
import "net/url"
import "crypto/rand"
import "crypto/tls"
import "sync"
import "runtime"
import "math/big"
@ -31,7 +33,6 @@ import "os/signal"
import "sync/atomic"
import "strings"
import "strconv"
import "context"
import "github.com/go-logr/logr"
@ -48,13 +49,10 @@ import "github.com/docopt/docopt-go"
import "github.com/deroproject/derohe/pow"
import "github.com/gorilla/websocket"
import "github.com/deroproject/derohe/glue/rwc"
import "github.com/creachadair/jrpc2"
import "github.com/creachadair/jrpc2/channel"
var mutex sync.RWMutex
var job rpc.GetBlockTemplate_Result
var job_counter int64
var maxdelay int = 10000
var threads int
var iterations int = 100
@ -67,8 +65,8 @@ var hash_rate uint64
var Difficulty uint64
var our_height int64
var block_counter int
var mini_block_counter int
var block_counter uint64
var mini_block_counter uint64
var logger logr.Logger
var command_line string = `dero-miner
@ -96,15 +94,6 @@ If daemon running on local machine no requirement of '--daemon-rpc-address' argu
`
var Exit_In_Progress = make(chan bool)
func Notify_broadcaster(req *jrpc2.Request) {
switch req.Method() {
case "Block", "MiniBlock", "Height":
go rpc_client.update_job()
default:
logger.V(1).Info("Notification received but not handled", "method", req.Method())
}
}
func main() {
var err error
@ -165,9 +154,9 @@ func main() {
}
if !globals.Arguments["--testnet"].(bool) {
daemon_rpc_address = "127.0.0.1:10102"
daemon_rpc_address = "127.0.0.1:10100"
} else {
daemon_rpc_address = "127.0.0.1:40402"
daemon_rpc_address = "127.0.0.1:10100"
}
if globals.Arguments["--daemon-rpc-address"] != nil {
@ -234,17 +223,12 @@ func main() {
go func() {
last_our_height := int64(0)
last_best_height := int64(0)
last_peer_count := uint64(0)
last_topo_height := int64(0)
last_mempool_tx_count := 0
last_counter := uint64(0)
last_counter_time := time.Now()
last_mining_state := false
_ = last_mining_state
_ = last_peer_count
_ = last_topo_height
_ = last_mempool_tx_count
mining := true
for {
@ -254,27 +238,12 @@ func main() {
default:
}
best_height, best_topo_height := int64(0), int64(0)
peer_count := uint64(0)
mempool_tx_count := 0
best_height := int64(0)
// only update prompt if needed
if last_our_height != our_height || last_best_height != best_height || last_counter != counter {
// choose color based on urgency
color := "\033[33m" // default is green color
/*if our_height < best_height {
color = "\033[33m" // make prompt yellow
} else if our_height > best_height {
color = "\033[31m" // make prompt red
}*/
color := "\033[33m" // default is green color
pcolor := "\033[32m" // default is green color
/*if peer_count < 1 {
pcolor = "\033[31m" // make prompt red
} else if peer_count <= 8 {
pcolor = "\033[33m" // make prompt yellow
}*/
mining_string := ""
@ -313,16 +282,10 @@ func main() {
testnet_string = "\033[31m TESTNET"
}
extra := fmt.Sprintf("%f", float32(mini_block_counter)/float32(block_counter))
l.SetPrompt(fmt.Sprintf("\033[1m\033[32mDERO Miner: \033[0m"+color+"Height %d "+pcolor+" BLOCKS %d MiniBlocks %d \033[32mNW %s %s>%s> avg %s >\033[0m ", our_height, block_counter, mini_block_counter, hash_rate_string, mining_string, testnet_string, extra))
l.SetPrompt(fmt.Sprintf("\033[1m\033[32mDERO Miner: \033[0m"+color+"Height %d "+pcolor+" BLOCKS %d MiniBlocks %d \033[32mNW %s %s>%s>>\033[0m ", our_height, block_counter, mini_block_counter, hash_rate_string, mining_string, testnet_string))
l.Refresh()
last_our_height = our_height
last_best_height = best_height
last_peer_count = peer_count
last_mempool_tx_count = mempool_tx_count
last_topo_height = best_topo_height
}
time.Sleep(1 * time.Second)
}
@ -348,11 +311,10 @@ func main() {
threads = 255
}
go increase_delay()
go getwork()
go getwork(wallet_address)
for i := 0; i < threads; i++ {
go rpc_client.mineblock(i)
go mineblock(i)
}
for {
@ -426,91 +388,61 @@ func random_execution(wg *sync.WaitGroup, iterations int) {
runtime.UnlockOSThread()
}
func increase_delay() {
for {
time.Sleep(time.Second)
maxdelay++
}
}
type Client struct {
WS *websocket.Conn
RPC *jrpc2.Client
Connected bool
}
var rpc_client = &Client{}
// continuously get work
func getwork() {
var connection *websocket.Conn
var connection_mutex sync.Mutex
func getwork(wallet_address string) {
var err error
for {
rpc_client.WS, _, err = websocket.DefaultDialer.Dial("ws://"+daemon_rpc_address+"/ws", nil)
u := url.URL{Scheme: "wss", Host: daemon_rpc_address, Path: "/ws/" + wallet_address}
logger.Info("connecting to ", "url", u.String())
dialer := websocket.DefaultDialer
dialer.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
connection, _, err = websocket.DefaultDialer.Dial(u.String(), nil)
if err != nil {
logger.Error(err, "Error connecting to server", "server adress", daemon_rpc_address)
logger.Info("Will try in 10 secs", "server adress", daemon_rpc_address)
rpc_client.Connected = false
time.Sleep(10 * time.Second)
continue
}
input_output := rwc.New(rpc_client.WS)
rpc_client.RPC = jrpc2.NewClient(channel.RawJSON(input_output, input_output), &jrpc2.ClientOptions{OnNotify: Notify_broadcaster})
rpc_client.Connected = true
var result rpc.GetBlockTemplate_Result
wait_for_another_job:
for {
if err = rpc_client.update_job(); err != nil {
break
}
time.Sleep(100 * time.Millisecond)
if err = connection.ReadJSON(&result); err != nil {
logger.Error(err, "connection error")
continue
}
time.Sleep(4 * time.Second)
}
}
func (cli *Client) update_job() (err error) {
defer globals.Recover(1)
var result rpc.GetBlockTemplate_Result
if err = rpc_client.Call("DERO.GetBlockTemplate", rpc.GetBlockTemplate_Params{Wallet_Address: wallet_address}, &result); err == nil {
mutex.Lock()
job = result
maxdelay = 0
job_counter++
mutex.Unlock()
if job.LastError != "" {
logger.Error(nil, "received error", "err", job.LastError)
}
block_counter = job.Blocks
mini_block_counter = job.MiniBlocks
hash_rate = job.Difficultyuint64
our_height = int64(job.Height)
Difficulty = job.Difficultyuint64
} else {
rpc_client.WS.Close()
rpc_client.Connected = false
logger.Error(err, "Error receiving block template")
//fmt.Printf("recv: %s", result)
goto wait_for_another_job
}
return err
}
func (cli *Client) Call(method string, params interface{}, result interface{}) error {
return cli.RPC.CallResult(context.Background(), method, params, result)
}
// tests connectivity when connectivity to daemon
func (rpc_client *Client) test_connectivity() (err error) {
var info rpc.GetInfo_Result
if err = rpc_client.Call("DERO.GetInfo", nil, &info); err != nil {
logger.V(1).Error(err, "DERO.GetInfo Call failed:")
return
}
return nil
}
func (rpc_client *Client) mineblock(tid int) {
func mineblock(tid int) {
var diff big.Int
var work [block.MINIBLOCK_SIZE]byte
@ -518,19 +450,16 @@ func (rpc_client *Client) mineblock(tid int) {
runtime.LockOSThread()
threadaffinity()
var local_job_counter int64
i := uint32(0)
for {
mutex.RLock()
myjob := job
local_job_counter = job_counter
mutex.RUnlock()
if rpc_client.Connected == false {
time.Sleep(10 * time.Millisecond)
continue
}
n, err := hex.Decode(work[:], []byte(myjob.Blockhashing_blob))
if err != nil || n != block.MINIBLOCK_SIZE {
logger.Error(err, "Blockwork could not decoded successfully", "blockwork", myjob.Blockhashing_blob, "n", n, "job", myjob)
@ -543,40 +472,27 @@ func (rpc_client *Client) mineblock(tid int) {
diff.SetString(myjob.Difficulty, 10)
if work[0]&0xf != 1 { // check version
logger.Error(nil, "Unknown version", "version", work[0]&0x1f)
logger.Error(nil, "Unknown version, please check for updates", "version", work[0]&0x1f)
time.Sleep(time.Second)
continue
}
for {
for local_job_counter == job_counter { // update job when it comes, expected rate 1 per second
i++
binary.BigEndian.PutUint32(nonce_buf, i)
if i&0x3ff == 0x3ff { // get updated job every 250 millisecs
break
}
powhash := pow.Pow(work[:])
atomic.AddUint64(&counter, 1)
if CheckPowHashBig(powhash, &diff) == true {
logger.V(1).Info("Successfully found DERO miniblock", "difficulty", myjob.Difficulty, "height", myjob.Height)
maxdelay = 200
var result rpc.SubmitBlock_Result
if err = rpc_client.Call("DERO.SubmitBlock", rpc.SubmitBlock_Params{JobID: myjob.JobID, MiniBlockhashing_blob: fmt.Sprintf("%x", work[:])}, &result); err == nil {
func() {
defer globals.Recover(1)
connection_mutex.Lock()
defer connection_mutex.Unlock()
connection.WriteJSON(rpc.SubmitBlock_Params{JobID: myjob.JobID, MiniBlockhashing_blob: fmt.Sprintf("%x", work[:])})
}()
if result.MiniBlock {
mini_block_counter++
} else {
block_counter++
}
logger.V(2).Info("submitting block", "result", result)
go rpc_client.update_job()
} else {
logger.Error(err, "error submitting block")
rpc_client.update_job()
break
}
}
}
}
@ -584,7 +500,6 @@ func (rpc_client *Client) mineblock(tid int) {
func usage(w io.Writer) {
io.WriteString(w, "commands:\n")
//io.WriteString(w, completer.Tree(" "))
io.WriteString(w, "\t\033[1mhelp\033[0m\t\tthis help\n")
io.WriteString(w, "\t\033[1mstatus\033[0m\t\tShow general information\n")
io.WriteString(w, "\t\033[1mbye\033[0m\t\tQuit the miner\n")

View File

@ -890,7 +890,7 @@ func valid_registration_or_display_error(l *readline.Instance, wallet *walletapi
// show the transfers to the user originating from this account
func show_transfers(l *readline.Instance, wallet *walletapi.Wallet_Disk, scid crypto.Hash, limit uint64) {
if wallet.GetMode() { // if wallet is in offline mode , we cannot do anything
if wallet.GetMode() && walletapi.IsDaemonOnline() { // if wallet is in offline mode , we cannot do anything
if err := wallet.Sync_Wallet_Memory_With_Daemon_internal(scid); err != nil {
logger.Error(err, "Error syncing wallet", "scid", scid.String())
return

View File

@ -59,7 +59,7 @@ var command_line string = `derod
DERO : A secure, private blockchain with smart-contracts
Usage:
derod [--help] [--version] [--testnet] [--debug] [--sync-node] [--timeisinsync] [--fastsync] [--socks-proxy=<socks_ip:port>] [--data-dir=<directory>] [--p2p-bind=<0.0.0.0:18089>] [--add-exclusive-node=<ip:port>]... [--add-priority-node=<ip:port>]... [--min-peers=<11>] [--rpc-bind=<127.0.0.1:9999>] [--node-tag=<unique name>] [--prune-history=<50>] [--integrator-address=<address>] [--clog-level=1] [--flog-level=1]
derod [--help] [--version] [--testnet] [--debug] [--sync-node] [--timeisinsync] [--fastsync] [--socks-proxy=<socks_ip:port>] [--data-dir=<directory>] [--p2p-bind=<0.0.0.0:18089>] [--add-exclusive-node=<ip:port>]... [--add-priority-node=<ip:port>]... [--min-peers=<11>] [--rpc-bind=<127.0.0.1:9999>] [--getwork-bind=<0.0.0.0:18089>] [--node-tag=<unique name>] [--prune-history=<50>] [--integrator-address=<address>] [--clog-level=1] [--flog-level=1]
derod -h | --help
derod --version
@ -76,6 +76,7 @@ Options:
--data-dir=<directory> Store blockchain data at this location
--rpc-bind=<127.0.0.1:9999> RPC listens on this ip:port
--p2p-bind=<0.0.0.0:18089> p2p server listens on this ip:port, specify port 0 to disable listening server
--getwork-bind=<0.0.0.0:10100> getwork server listens on this ip:port, specify port 0 to disable listening server
--add-exclusive-node=<ip:port> Connect to specific peer only
--add-priority-node=<ip:port> Maintain persistant connection to specified peer
--sync-node Sync node automatically with the seeds nodes. This option is for rare use.
@ -211,6 +212,8 @@ func main() {
p2p.P2P_Init(params)
rpcserver, _ := derodrpc.RPCServer_Start(params)
go derodrpc.Getwork_server()
// setup function pointers
chain.P2P_Block_Relayer = func(cbl *block.Complete_Block, peerid uint64) {
p2p.Broadcast_Block(cbl, peerid)
@ -284,7 +287,7 @@ func main() {
testnet_string += " " + strconv.Itoa(chain.MiniBlocks.Count()) + " " + globals.GetOffset().Round(time.Millisecond).String() + "|" + globals.GetOffsetNTP().Round(time.Millisecond).String() + "|" + globals.GetOffsetP2P().Round(time.Millisecond).String()
l.SetPrompt(fmt.Sprintf("\033[1m\033[32mDERO HE: \033[0m"+color+"%d/%d [%d/%d] "+pcolor+"P %d TXp %d:%d \033[32mNW %s >%s>>\033[0m ", our_height, topo_height, best_height, best_topo_height, peer_count, mempool_tx_count, regpool_tx_count, hash_rate_string, testnet_string))
l.SetPrompt(fmt.Sprintf("\033[1m\033[32mDERO HE: \033[0m"+color+"%d/%d [%d/%d] "+pcolor+"P %d TXp %d:%d \033[32mNW %s >Miners %d %s>>\033[0m ", our_height, topo_height, best_height, best_topo_height, peer_count, mempool_tx_count, regpool_tx_count, hash_rate_string, derodrpc.CountMiners(), testnet_string))
l.Refresh()
last_second = time.Now().Unix()
last_our_height = our_height
@ -491,6 +494,19 @@ func readline_loop(l *readline.Instance, chain *blockchain.Blockchain, logger lo
logger.Error(fmt.Errorf("regpool_delete_tx needs a single transaction id as argument"), "")
}
case command == "mempool_dump": // dump mempool to directory
tx_hash_list_sorted := chain.Mempool.Mempool_List_TX_SortedInfo() // hash of all tx expected to be included within this block , sorted by fees
os.Mkdir(filepath.Join(globals.GetDataDirectory(), "mempool"), 0755)
count := 0
for _, txi := range tx_hash_list_sorted {
if tx := chain.Mempool.Mempool_Get_TX(txi.Hash); tx != nil {
os.WriteFile(filepath.Join(globals.GetDataDirectory(), "mempool", txi.Hash.String()), tx.Serialize(), 0755)
count++
}
}
logger.Info("flushed mempool to driectory", "count", count, "dir", filepath.Join(globals.GetDataDirectory(), "mempool"))
case command == "mempool_print":
chain.Mempool.Mempool_Print()
@ -981,6 +997,7 @@ var completer = readline.NewPrefixCompleter(
readline.PcItem("help"),
readline.PcItem("diff"),
readline.PcItem("gc"),
readline.PcItem("mempool_dump"),
readline.PcItem("mempool_flush"),
readline.PcItem("mempool_delete_tx"),
readline.PcItem("mempool_print"),

View File

@ -0,0 +1,341 @@
package rpc
import (
"flag"
"fmt"
"net/http"
"time"
"github.com/lesismal/llib/std/crypto/tls"
"github.com/lesismal/nbio/nbhttp"
"github.com/lesismal/nbio/nbhttp/websocket"
)
import "github.com/lesismal/nbio"
import "github.com/lesismal/nbio/logging"
import "net"
import "bytes"
import "encoding/hex"
import "encoding/json"
import "runtime"
import "strings"
import "math/big"
import "crypto/ecdsa"
import "crypto/elliptic"
//import "crypto/tls"
import "crypto/rand"
import "crypto/x509"
import "encoding/pem"
import "github.com/deroproject/derohe/globals"
import "github.com/deroproject/derohe/rpc"
import "github.com/deroproject/graviton"
import "github.com/go-logr/logr"
// this file implements the non-blocking job streamer
// only job is to stream jobs to thousands of workers, if any is successful,accept and report back
import "sync"
var memPool = sync.Pool{
New: func() interface{} {
return make([]byte, 16*1024)
},
}
var logger_getwork logr.Logger
var (
svr *nbhttp.Server
print = flag.Bool("print", false, "stdout output of echoed data")
)
type user_session struct {
blocks uint64
miniblocks uint64
lasterr string
address rpc.Address
valid_address bool
address_sum [32]byte
}
var client_list_mutex sync.Mutex
var client_list = map[*websocket.Conn]*user_session{}
func CountMiners() int {
client_list_mutex.Lock()
defer client_list_mutex.Unlock()
return len(client_list)
}
func SendJob() {
var params rpc.GetBlockTemplate_Result
var buf bytes.Buffer
encoder := json.NewEncoder(&buf)
// get a block template, and then we will fill the address here as optimization
bl, mbl, _, _, err := chain.Create_new_block_template_mining(chain.IntegratorAddress())
if err != nil {
return
}
prev_hash := ""
for i := range bl.Tips {
prev_hash = prev_hash + bl.Tips[i].String()
}
params.JobID = fmt.Sprintf("%d.%d.%s", bl.Timestamp, 0, "notified")
diff := chain.Get_Difficulty_At_Tips(bl.Tips)
params.Height = bl.Height
params.Prev_Hash = prev_hash
params.Difficultyuint64 = diff.Uint64()
params.Difficulty = diff.String()
client_list_mutex.Lock()
defer client_list_mutex.Unlock()
for k, v := range client_list {
if !mbl.Final { //write miners address only if possible
copy(mbl.KeyHash[:], v.address_sum[:])
}
for i := range mbl.Nonce { // give each user different work
mbl.Nonce[i] = globals.Global_Random.Uint32() // fill with randomness
}
if v.lasterr != "" {
params.LastError = v.lasterr
v.lasterr = ""
}
if !v.valid_address && !chain.IsAddressHashValid(false, v.address_sum) {
params.LastError = "unregistered miner or you need to wait 15 mins"
} else {
v.valid_address = true
}
params.Blockhashing_blob = fmt.Sprintf("%x", mbl.Serialize())
params.Blocks = v.blocks
params.MiniBlocks = v.miniblocks
encoder.Encode(params)
k.WriteMessage(websocket.TextMessage, buf.Bytes())
buf.Reset()
}
}
func newUpgrader() *websocket.Upgrader {
u := websocket.NewUpgrader()
u.OnMessage(func(c *websocket.Conn, messageType websocket.MessageType, data []byte) {
// echo
c.WriteMessage(messageType, data)
if messageType != websocket.TextMessage {
return
}
sess := c.Session().(*user_session)
client_list_mutex.Lock()
client_list_mutex.Unlock()
var p rpc.SubmitBlock_Params
if err := json.Unmarshal(data, &p); err != nil {
}
mbl_block_data_bytes, err := hex.DecodeString(p.MiniBlockhashing_blob)
if err != nil {
//logger.Info("Submitting block could not be decoded")
sess.lasterr = fmt.Sprintf("Submitted block could not be decoded. err: %s", err)
return
}
var tstamp, extra uint64
fmt.Sscanf(p.JobID, "%d.%d", &tstamp, &extra)
_, blid, sresult, err := chain.Accept_new_block(tstamp, mbl_block_data_bytes)
if sresult {
//logger.Infof("Submitted block %s accepted", blid)
if blid.IsZero() {
sess.miniblocks++
} else {
sess.blocks++
}
}
})
u.OnClose(func(c *websocket.Conn, err error) {
client_list_mutex.Lock()
delete(client_list, c)
client_list_mutex.Unlock()
})
return u
}
func onWebsocket(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.Path, "/ws/") {
http.NotFound(w, r)
return
}
address := strings.TrimPrefix(r.URL.Path, "/ws/")
addr, err := globals.ParseValidateAddress(address)
if err != nil {
fmt.Fprintf(w, "err: %s\n", err)
return
}
addr_raw := addr.PublicKey.EncodeCompressed()
upgrader := newUpgrader()
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
//panic(err)
return
}
wsConn := conn.(*websocket.Conn)
session := user_session{address: *addr, address_sum: graviton.Sum(addr_raw)}
wsConn.SetSession(&session)
client_list_mutex.Lock()
client_list[wsConn] = &session
client_list_mutex.Unlock()
}
func Getwork_server() {
var err error
logger_getwork = globals.Logger.WithName("GETWORK")
logging.SetLevel(logging.LevelNone) //LevelDebug)//LevelNone)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{generate_random_tls_cert()},
InsecureSkipVerify: true,
}
mux := &http.ServeMux{}
mux.HandleFunc("/", onWebsocket) // handle everything
default_address := fmt.Sprintf("0.0.0.0:%d", globals.Config.GETWORK_Default_Port)
if _, ok := globals.Arguments["--getwork-bind"]; ok && globals.Arguments["--getwork-bind"] != nil {
addr, err := net.ResolveTCPAddr("tcp", globals.Arguments["--getwork-bind"].(string))
if err != nil {
logger_getwork.Error(err, "--getwork-bind address is invalid")
return
} else {
if addr.Port == 0 {
logger_getwork.Info("GETWORK server is disabled, No ports will be opened for miners to get work")
return
} else {
default_address = addr.String()
}
}
}
logger_getwork.Info("GETWORK will listen", "address", default_address)
svr = nbhttp.NewServer(nbhttp.Config{
Name: "GETWORK",
Network: "tcp",
AddrsTLS: []string{default_address},
TLSConfig: tlsConfig,
Handler: mux,
MaxLoad: 10 * 1024,
MaxWriteBufferSize: 32 * 1024,
ReleaseWebsocketPayload: true,
KeepaliveTime: 240 * time.Hour, // we expects all miners to find a block every 10 days,
NPoller: runtime.NumCPU(),
})
svr.OnReadBufferAlloc(func(c *nbio.Conn) []byte {
return memPool.Get().([]byte)
})
svr.OnReadBufferFree(func(c *nbio.Conn, b []byte) {
memPool.Put(b)
})
globals.Cron.AddFunc("@every 2s", SendJob) // if daemon restart automaticaly send job
if err = svr.Start(); err != nil {
logger_getwork.Error(err, "nbio.Start failed.")
return
}
logger.Info("GETWORK/Websocket server started")
svr.Wait()
defer svr.Stop()
}
// generate default tls cert to encrypt everything
// NOTE: this does NOT protect from individual active man-in-the-middle attacks
func generate_random_tls_cert() tls.Certificate {
/* RSA can do only 500 exchange per second, we need to be faster
* reference https://github.com/golang/go/issues/20058
key, err := rsa.GenerateKey(rand.Reader, 512) // current using minimum size
if err != nil {
log.Fatal("Private key cannot be created.", err.Error())
}
// Generate a pem block with the private key
keyPem := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
*/
// EC256 does roughly 20000 exchanges per second
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
b, err := x509.MarshalECPrivateKey(key)
if err != nil {
logger.Error(err, "Unable to marshal ECDSA private key")
panic(err)
}
// Generate a pem block with the private key
keyPem := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: b})
tml := x509.Certificate{
SerialNumber: big.NewInt(int64(time.Now().UnixNano())),
// TODO do we need to add more parameters to make our certificate more authentic
// and thwart traffic identification as a mass scale
// you can add any attr that you need
NotBefore: time.Now().AddDate(0, -1, 0),
NotAfter: time.Now().AddDate(1, 0, 0),
// you have to generate a different serial number each execution
/*
Subject: pkix.Name{
CommonName: "New Name",
Organization: []string{"New Org."},
},
BasicConstraintsValid: true, // even basic constraints are not required
*/
}
cert, err := x509.CreateCertificate(rand.Reader, &tml, &tml, &key.PublicKey, key)
if err != nil {
logger.Error(err, "Certificate cannot be created.")
panic(err)
}
// Generate a pem block with the certificate
certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert})
tlsCert, err := tls.X509KeyPair(certPem, keyPem)
if err != nil {
logger.Error(err, "Certificate cannot be loaded.")
panic(err)
}
return tlsCert
}

View File

@ -117,10 +117,7 @@ func Notify_MiniBlock_Addition() {
chain.RPC_NotifyNewMiniBlock.L.Unlock()
go func() {
defer globals.Recover(2)
client_connections.Range(func(key, value interface{}) bool {
key.(*jrpc2.Server).Notify(context.Background(), "MiniBlock", nil)
return true
})
SendJob()
}()
}
}

View File

@ -76,6 +76,7 @@ type CHAIN_CONFIG struct {
Name string
Network_ID uuid.UUID // network ID
GETWORK_Default_Port int // used for miner getwork as effeciently as poosible
P2P_Default_Port int
RPC_Default_Port int
Wallet_RPC_Default_Port int
@ -87,6 +88,7 @@ type CHAIN_CONFIG struct {
var Mainnet = CHAIN_CONFIG{Name: "mainnet",
Network_ID: uuid.FromBytesOrNil([]byte{0x59, 0xd7, 0xf7, 0xe9, 0xdd, 0x48, 0xd5, 0xfd, 0x13, 0x0a, 0xf6, 0xe0, 0x9a, 0x44, 0x45, 0x0}),
GETWORK_Default_Port: 10100,
P2P_Default_Port: 10101,
RPC_Default_Port: 10102,
Wallet_RPC_Default_Port: 10103,
@ -103,7 +105,8 @@ var Mainnet = CHAIN_CONFIG{Name: "mainnet",
}
var Testnet = CHAIN_CONFIG{Name: "testnet", // testnet will always have last 3 bytes 0
Network_ID: uuid.FromBytesOrNil([]byte{0x59, 0xd7, 0xf7, 0xe9, 0xdd, 0x48, 0xd5, 0xfd, 0x13, 0x0a, 0xf6, 0xe0, 0x73, 0x00, 0x00, 0x00}),
Network_ID: uuid.FromBytesOrNil([]byte{0x59, 0xd7, 0xf7, 0xe9, 0xdd, 0x48, 0xd5, 0xfd, 0x13, 0x0a, 0xf6, 0xe0, 0x74, 0x00, 0x00, 0x00}),
GETWORK_Default_Port: 10100,
P2P_Default_Port: 40401,
RPC_Default_Port: 40402,
Wallet_RPC_Default_Port: 40403,

View File

@ -20,4 +20,4 @@ import "github.com/blang/semver/v4"
// right now it has to be manually changed
// do we need to include git commitsha??
var Version = semver.MustParse("3.4.93-1.DEROHE.STARGATE+25112021")
var Version = semver.MustParse("3.4.94-1.DEROHE.STARGATE+25112021")

View File

@ -114,6 +114,9 @@ type (
Height uint64 `json:"height"`
Prev_Hash string `json:"prev_hash"`
EpochMilli uint64 `json:"epochmilli"`
Blocks uint64 `json:"blocks"` // number of blocks found
MiniBlocks uint64 `json:"miniblocks"` // number of miniblocks found
LastError string `json:"lasterror"` // last error
Status string `json:"status"`
}
)
@ -197,14 +200,14 @@ type (
} // no params
GetTransaction_Result struct {
Txs_as_hex []string `json:"txs_as_hex"`
Txs_as_json []string `json:"txs_as_json"`
Txs_as_json []string `json:"txs_as_json,omitempty"`
Txs []Tx_Related_Info `json:"txs"`
Status string `json:"status"`
}
Tx_Related_Info struct {
As_Hex string `json:"as_hex"`
As_Json string `json:"as_json"`
As_Json string `json:"as_json,omitempty"`
Block_Height int64 `json:"block_height"`
Reward uint64 `json:"reward"` // miner tx rewards are decided by the protocol during execution
Ignored bool `json:"ignored"` // tell whether this tx is okau as per client protocol or bein ignored
@ -261,17 +264,8 @@ type (
Tx_as_hex string `json:"tx_as_hex"`
}
SendRawTransaction_Result struct {
Status string `json:"status"`
DoubleSpend bool `json:"double_spend"`
FeeTooLow bool `json:"fee_too_low"`
InvalidInput bool `json:"invalid_input"`
InvalidOutput bool `json:"invalid_output"`
Low_Mixin bool `json:"low_mixin"`
Non_rct bool `json:"not_rct"`
NotRelayed bool `json:"not_relayed"`
Overspend bool `json:"overspend"`
TooBig bool `json:"too_big"`
Reason string `json:"string"`
Status string `json:"status"`
Reason string `json:"string"`
}
)

27
vendor/github.com/lesismal/llib/LICENSE generated vendored Normal file
View File

@ -0,0 +1,27 @@
Copyright (c) 2021 lesismal. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

12
vendor/github.com/lesismal/llib/README.md generated vendored Normal file
View File

@ -0,0 +1,12 @@
# llib - [lesismal](https://github.com/lesismal)'s lib
[![GoDoc][1]][2] [![MIT licensed][3]][4] [![Go Version][5]][6]
[1]: https://godoc.org/github.com/lesismal/llib?status.svg
[2]: https://godoc.org/github.com/lesismal/llib
[3]: https://img.shields.io/badge/license-BSD-blue.svg
[4]: LICENSE
[5]: https://img.shields.io/badge/go-%3E%3D1.16-30dff3?style=flat-square&logo=go
[6]: https://github.com/lesismal/llib
Less Is More :smile:

225
vendor/github.com/lesismal/llib/bytes/buffer.go generated vendored Normal file
View File

@ -0,0 +1,225 @@
package bytes
import (
"errors"
)
var (
ErrInvalidLength = errors.New("invalid length")
ErrInvalidPosition = errors.New("invalid position")
ErrNotEnougth = errors.New("bytes not enougth")
)
// Buffer .
type Buffer struct {
total int
buffers [][]byte
onRelease func(b []byte)
}
// Len .
func (bb *Buffer) Len() int {
return bb.total
}
// Push .
func (bb *Buffer) Push(b []byte) {
if len(b) == 0 {
return
}
bb.buffers = append(bb.buffers, b)
bb.total += len(b)
}
// Pop .
func (bb *Buffer) Pop(n int) ([]byte, error) {
if n < 0 {
return nil, ErrInvalidLength
}
if bb.total < n {
return nil, ErrNotEnougth
}
bb.total -= n
var buf = bb.buffers[0]
if len(buf) >= n {
ret := buf[:n]
bb.buffers[0] = bb.buffers[0][n:]
if len(bb.buffers[0]) == 0 {
bb.releaseHead()
}
return ret, nil
}
var ret = make([]byte, n)[0:0]
for n > 0 {
if len(buf) >= n {
ret = append(ret, buf[:n]...)
bb.buffers[0] = bb.buffers[0][n:]
if len(bb.buffers[0]) == 0 {
bb.releaseHead()
}
return ret, nil
}
ret = append(ret, buf...)
bb.releaseHead()
n -= len(buf)
buf = bb.buffers[0]
}
return ret, nil
}
// Append .
func (bb *Buffer) Append(b []byte) {
if len(b) == 0 {
return
}
n := len(bb.buffers)
if n == 0 {
bb.buffers = append(bb.buffers, b)
return
}
bb.buffers[n-1] = append(bb.buffers[n-1], b...)
bb.total += len(b)
}
// Head .
func (bb *Buffer) Head(n int) ([]byte, error) {
if n < 0 {
return nil, ErrInvalidLength
}
if bb.total < n {
return nil, ErrNotEnougth
}
if len(bb.buffers[0]) >= n {
return bb.buffers[0][:n], nil
}
ret := make([]byte, n)
copied := 0
for i := 0; n > 0; i++ {
buf := bb.buffers[i]
if len(buf) >= n {
copy(ret[copied:], buf[:n])
return ret, nil
} else {
copy(ret[copied:], buf)
n -= len(buf)
copied += len(buf)
}
}
return ret, nil
}
// Sub .
func (bb *Buffer) Sub(from, to int) ([]byte, error) {
if from < 0 || to < 0 || to < from {
return nil, ErrInvalidPosition
}
if bb.total < to {
return nil, ErrNotEnougth
}
if len(bb.buffers[0]) >= to {
return bb.buffers[0][from:to], nil
}
n := to - from
ret := make([]byte, n)
copied := 0
for i := 0; n > 0; i++ {
buf := bb.buffers[i]
if len(buf) >= from+n {
copy(ret[copied:], buf[from:from+n])
return ret, nil
} else {
if len(buf) > from {
if from > 0 {
buf = buf[from:]
from = 0
}
copy(ret[copied:], buf)
copied += len(buf)
n -= len(buf)
} else {
from -= len(buf)
}
}
}
return ret, nil
}
// Write .
func (bb *Buffer) Write(b []byte) {
bb.Push(b)
}
// Read .
func (bb *Buffer) Read(n int) ([]byte, error) {
return bb.Pop(n)
}
// ReadAll .
func (bb *Buffer) ReadAll() ([]byte, error) {
if len(bb.buffers) == 0 {
return nil, nil
}
ret := append([]byte{}, bb.buffers[0]...)
if bb.onRelease != nil {
bb.onRelease(bb.buffers[0])
for i := 1; i < len(bb.buffers); i++ {
ret = append(ret, bb.buffers[i]...)
bb.onRelease(bb.buffers[i])
}
} else {
for i := 1; i < len(bb.buffers); i++ {
ret = append(ret, bb.buffers[i]...)
}
}
bb.buffers = nil
bb.total = 0
return ret, nil
}
// Reset .
func (bb *Buffer) Reset() {
if bb.onRelease != nil {
for i := 0; i < len(bb.buffers); i++ {
bb.onRelease(bb.buffers[i])
}
}
bb.buffers = nil
bb.total = 0
}
func (bb *Buffer) OnRelease(onRelease func(b []byte)) {
bb.onRelease = onRelease
}
func (bb *Buffer) releaseHead() {
if bb.onRelease != nil {
bb.onRelease(bb.buffers[0])
}
switch len(bb.buffers) {
case 1:
bb.buffers = nil
default:
bb.buffers = bb.buffers[1:]
}
}
// NewBuffer .
func NewBuffer() *Buffer {
return &Buffer{}
}

108
vendor/github.com/lesismal/llib/bytes/buffer_test.go generated vendored Normal file
View File

@ -0,0 +1,108 @@
package bytes
import (
"testing"
)
func TestBuffer(t *testing.T) {
str := "hello world"
buffer := NewBuffer()
buffer.Write([]byte("hel"))
buffer.Write([]byte("lo world"))
b, err := buffer.ReadAll()
if err != nil {
t.Fatal(err)
}
if string(b) != str {
t.Fatal(string(b))
}
buffer.Write([]byte("hel"))
buffer.Write([]byte("lo "))
buffer.Write([]byte("wor"))
buffer.Write([]byte("ld"))
for i := 0; i < len(str); i++ {
for j := i; j < len(str); j++ {
sub, err := buffer.Sub(i, j)
if err != nil {
t.Fatal(err)
}
if string(sub) != string([]byte(str)[i:j]) {
t.Fatalf("[%v:%v] %v != %v", i, j, string(sub), string([]byte(str)[i:j]))
}
}
}
for i := 0; i < len(str); i++ {
for j := i; j < len(str); j++ {
buffer.Write([]byte("hel"))
buffer.Write([]byte("lo "))
buffer.Write([]byte("wor"))
buffer.Write([]byte("ld"))
b, err = buffer.Read(j)
if err != nil {
t.Fatal(err)
}
if string(b) != string([]byte(str)[:j]) {
t.Fatalf("[%v:%v] %v != %v", i, j, string(b), string([]byte(str)[:j]))
}
buffer.Reset()
}
}
for i := 0; i < len(str); i++ {
for j := i; j < len(str); j++ {
buffer.Write([]byte("hel"))
buffer.Write([]byte("lo "))
buffer.Write([]byte("wor"))
buffer.Write([]byte("ld"))
buffer.Read(i)
b, err = buffer.Read(j - i)
if err != nil {
t.Fatal(err)
}
if string(b) != string([]byte(str)[i:j]) {
t.Fatalf("[%v:%v] %v != %v", i, j, string(b), string([]byte(str)[i:j]))
}
buffer.Reset()
}
}
buffer.Append([]byte("hello"))
buffer.Append([]byte(" world"))
if string(buffer.buffers[0]) != "hello world" {
t.Fatal(string(buffer.buffers[0]))
}
b, err = buffer.ReadAll()
if err != nil {
t.Fatal(err)
}
if string(b) != "hello world" {
t.Fatal(string(b))
}
buffer.Reset()
buffer.Push([]byte("hello "))
buffer.Push([]byte("world"))
if string(buffer.buffers[0]) != "hello " {
t.Fatal(string(buffer.buffers[0]))
}
buffer.Pop(1)
if string(buffer.buffers[0]) != "ello " {
t.Fatal(string(buffer.buffers[0]))
}
buffer.Pop(5)
if string(buffer.buffers[0]) != "world" {
t.Fatal(string(buffer.buffers[0]))
}
buffer.ReadAll()
if len(buffer.buffers) != 0 {
t.Fatal(string(buffer.buffers[0]))
}
}

84
vendor/github.com/lesismal/llib/bytes/pool.go generated vendored Normal file
View File

@ -0,0 +1,84 @@
// Copyright 2020 lesismal. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package bytes
import (
"sync"
)
// maxAppendSize represents the max size to append to a slice.
const maxAppendSize = 1024 * 1024 * 4
// Pool is the default instance of []byte pool.
// User can customize a Pool implementation and reset this instance if needed.
var Pool interface {
Get() []byte
GetN(size int) []byte
Put(b []byte)
} = NewPool(64)
// bufferPool is a default implementatiion of []byte Pool.
type bufferPool struct {
sync.Pool
MinSize int
}
// NewPool creates and returns a bufferPool instance.
// All slice created by this instance has an initial cap of minSize.
func NewPool(minSize int) *bufferPool {
if minSize <= 0 {
minSize = 64
}
bp := &bufferPool{
MinSize: minSize,
}
bp.Pool.New = func() interface{} {
buf := make([]byte, bp.MinSize)
return &buf
}
return bp
}
// Get gets a slice from the pool and returns it with length 0.
// User can append the slice and should Put it back to the pool after being used over.
func (bp *bufferPool) Get() []byte {
pbuf := bp.Pool.Get().(*[]byte)
return (*pbuf)[0:0]
}
// GetN returns a slice with length size.
// To reuse slices as possible,
// if the cap of the slice got from the pool is not enough,
// It will append the slice,
// or put the slice back to the pool and create a new slice with cap of size.
//
// User can use the slice both by the size or append it,
// and should Put it back to the pool after being used over.
func (bp *bufferPool) GetN(size int) []byte {
pbuf := bp.Pool.Get().(*[]byte)
need := size - cap(*pbuf)
if need > 0 {
if need <= maxAppendSize {
*pbuf = (*pbuf)[:cap(*pbuf)]
*pbuf = append(*pbuf, make([]byte, need)...)
} else {
bp.Pool.Put(pbuf)
newBuf := make([]byte, size)
pbuf = &newBuf
}
}
return (*pbuf)[:size]
}
// Put puts a slice back to the pool.
// If the slice's cap is smaller than MinSize,
// it will not be put back to the pool but dropped.
func (bp *bufferPool) Put(b []byte) {
if cap(b) < bp.MinSize {
return
}
bp.Pool.Put(&b)
}

22
vendor/github.com/lesismal/llib/bytes/pool_test.go generated vendored Normal file
View File

@ -0,0 +1,22 @@
package bytes
import "testing"
func TestMemPool(t *testing.T) {
const minMemSize = 64
pool := NewPool(minMemSize)
for i := 0; i < 1024*1024; i++ {
buf := pool.GetN(i)
if len(buf) != i {
t.Fatalf("invalid length: %v != %v", len(buf), i)
}
pool.Put(buf)
}
for i := 1024 * 1024; i < 1024*1024*1024; i += 1024 * 1024 {
buf := pool.GetN(i)
if len(buf) != i {
t.Fatalf("invalid length: %v != %v", len(buf), i)
}
pool.Put(buf)
}
}

60
vendor/github.com/lesismal/llib/concurrent/batch.go generated vendored Normal file
View File

@ -0,0 +1,60 @@
// Copyright 2020 lesismal. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package concurrent
import (
"sync"
)
var (
_defaultBatch = NewBatch()
)
type call struct {
mux sync.RWMutex
ret interface{}
err error
}
// Batch .
type Batch struct {
_mux sync.Mutex
_callings map[interface{}]*call
}
// Do .
func (o *Batch) Do(key interface{}, f func() (interface{}, error)) (interface{}, error) {
o._mux.Lock()
c, ok := o._callings[key]
if ok {
o._mux.Unlock()
c.mux.RLock()
c.mux.RUnlock()
return c.ret, c.err
}
c = &call{}
c.mux.Lock()
o._callings[key] = c
o._mux.Unlock()
c.ret, c.err = f()
c.mux.Unlock()
o._mux.Lock()
delete(o._callings, key)
o._mux.Unlock()
return c.ret, c.err
}
// NewBatch .
func NewBatch() *Batch {
return &Batch{_callings: map[interface{}]*call{}}
}
// Do .
func Do(key interface{}, f func() (interface{}, error)) (interface{}, error) {
return _defaultBatch.Do(key, f)
}

View File

@ -0,0 +1,34 @@
// Copyright 2020 lesismal. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package concurrent
import (
"log"
"testing"
"time"
)
func TestBatch(t *testing.T) {
batchCall := func() (interface{}, error) {
time.Sleep(time.Second)
return time.Now().Format("2006/01/02 15:04:05.000"), nil
}
for i := 0; i < 10; i++ {
go func(id int) {
ret, err := Do(3, batchCall)
log.Println("Batch().Do():", id, ret, err)
}(2)
}
func(id int) {
ret, err := Do(3, batchCall)
log.Println("Batch().Do():", id, ret, err)
}(1)
func(id int) {
ret, err := Do(3, batchCall)
log.Println("Batch().Do():", id, ret, err)
}(3)
time.Sleep(time.Second)
}

100
vendor/github.com/lesismal/llib/concurrent/map.go generated vendored Normal file
View File

@ -0,0 +1,100 @@
package concurrent
import (
"sync"
"sync/atomic"
"github.com/cespare/xxhash"
)
type bucket struct {
mux sync.RWMutex
values map[string]interface{}
}
func (b *bucket) Get(k string) (interface{}, bool) {
b.mux.RLock()
v, ok := b.values[k]
b.mux.RUnlock()
return v, ok
}
func (b *bucket) Set(k string, v interface{}) bool {
b.mux.Lock()
_, exsist := b.values[k]
b.values[k] = v
b.mux.Unlock()
return !exsist
}
func (b *bucket) Delete(k string) bool {
b.mux.Lock()
_, exsist := b.values[k]
delete(b.values, k)
b.mux.Unlock()
return exsist
}
func (b *bucket) forEach(f func(k string, v interface{}) bool) bool {
success := false
b.mux.RLock()
for k, v := range b.values {
success = f(k, v)
if !success {
break
}
}
b.mux.RUnlock()
return success
}
type Map struct {
size int64
buckets []*bucket
}
func (m *Map) Get(k string) (interface{}, bool) {
i := hash(k) % uint64(len(m.buckets))
return m.buckets[i].Get(k)
}
func (m *Map) Set(k string, v interface{}) {
i := hash(k) % uint64(len(m.buckets))
if m.buckets[i].Set(k, v) {
atomic.AddInt64(&m.size, 1)
}
}
func (m *Map) Delete(k string) {
i := hash(k) % uint64(len(m.buckets))
if m.buckets[i].Delete(k) {
atomic.AddInt64(&m.size, -1)
}
}
func (m *Map) Size() int64 {
return atomic.LoadInt64(&m.size)
}
func (m *Map) ForEach(f func(k string, v interface{}) bool) {
for _, b := range m.buckets {
if !b.forEach(f) {
return
}
}
}
func NewMap(bucketNum int) *Map {
if bucketNum <= 0 {
bucketNum = 64
}
m := &Map{buckets: make([]*bucket, bucketNum)}
for i := 0; i < bucketNum; i++ {
m.buckets[i] = &bucket{values: map[string]interface{}{}}
}
return m
}
func hash(k string) uint64 {
return xxhash.Sum64String(k)
}

109
vendor/github.com/lesismal/llib/concurrent/map_test.go generated vendored Normal file
View File

@ -0,0 +1,109 @@
// Copyright 2020 lesismal. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package concurrent
import (
"fmt"
"log"
"testing"
)
func TestMap(t *testing.T) {
m := NewMap(64)
size := 100000
for i := 0; i < size; i++ {
k := fmt.Sprintf("key_%d", i)
v := fmt.Sprintf("value_%d", i)
vv, ok := m.Get(k)
if ok {
log.Fatalf("[%v] exists: '%v'", k, vv)
}
m.Set(k, v)
vv, ok = m.Get(k)
if !ok {
log.Fatalf("[%v] does not exist: '%v'", k, vv)
}
if v != vv {
log.Fatalf("invalid value: '%v' for key [%v] ", vv, k)
}
}
cnt := 0
m.ForEach(func(k string, v interface{}) bool {
if k[3:] != (v.(string))[5:] {
log.Fatalf("invalid key-value: '%v', '%v'", k, v)
}
cnt++
return true
})
if cnt != size {
log.Fatalf("invalid ForEach num: %v, want: %v", cnt, size)
}
if m.Size() != int64(size) {
log.Fatalf("invalid size: %v, want: %v", m.Size(), size)
}
for i := 0; i < size; i++ {
k := fmt.Sprintf("key_%d", i)
v := fmt.Sprintf("value_%d", i)
vv, ok := m.Get(k)
if !ok {
log.Fatalf("[%v] does not exist: '%v'", k, vv)
}
if v != vv {
log.Fatalf("invalid value: '%v' for key [%v]", vv, k)
}
m.Delete(k)
if m.Size() != int64(size-i-1) {
log.Fatalf("invalid size: %v, want: %v", m.Size(), int64(size-i-1))
}
}
for i := 0; i < size; i++ {
k := fmt.Sprintf("key_%d", i)
vv, ok := m.Get(k)
if ok {
log.Fatalf("[%v] exists: '%v'", k, vv)
}
if m.Size() != 0 {
log.Fatalf("invalid size: %v, want: %v", m.Size(), 0)
}
}
}
func BenchmarkMapSet(b *testing.B) {
m := NewMap(64)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
k := fmt.Sprintf("key_%d", i)
v := fmt.Sprintf("value_%d", i)
m.Set(k, v)
}
}
func BenchmarkMapGet(b *testing.B) {
m := NewMap(64)
for i := 0; i < b.N; i++ {
k := fmt.Sprintf("key_%d", i)
v := fmt.Sprintf("value_%d", i)
m.Set(k, v)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
k := fmt.Sprintf("key_%d", i)
v := fmt.Sprintf("value_%d", i)
vv, ok := m.Get(k)
if !ok {
log.Fatalf("[%v] does not exist: '%v'", k, vv)
}
if v != vv {
log.Fatalf("invalid value: '%v' for key [%v], want: %v", vv, k, v)
}
}
}

56
vendor/github.com/lesismal/llib/concurrent/mutex.go generated vendored Normal file
View File

@ -0,0 +1,56 @@
// Copyright 2020 lesismal. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package concurrent
import (
"sync"
)
var (
_defaultMux = NewMutex()
)
// Mutex .
type Mutex struct {
_mux sync.Mutex
_muxes map[interface{}]*sync.Mutex
}
// Lock .
func (m *Mutex) Lock(key interface{}) {
m._mux.Lock()
mux, ok := m._muxes[key]
if !ok {
mux = &sync.Mutex{}
m._muxes[key] = mux
}
m._mux.Unlock()
mux.Lock()
}
// Unlock .
func (m *Mutex) Unlock(key interface{}) {
m._mux.Lock()
mux, ok := m._muxes[key]
m._mux.Unlock()
if ok {
mux.Unlock()
}
}
// NewMutex .
func NewMutex() *Mutex {
return &Mutex{_muxes: map[interface{}]*sync.Mutex{}}
}
// // Lock .
// func Lock(key interface{}) {
// _defaultMux.Lock(key)
// }
// // Unlock .
// func Unlock(key interface{}) {
// _defaultMux.Unlock(key)
// }

View File

@ -0,0 +1,26 @@
// Copyright 2020 lesismal. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package concurrent
import (
"log"
"testing"
"time"
)
func TestMutex(t *testing.T) {
mux := NewMutex()
muxPrint := func(id int) {
for i := 0; i < 3; i++ {
mux.Lock(1)
time.Sleep(time.Second / 100)
log.Println("mux print:", id, i)
mux.Unlock(1)
}
}
go muxPrint(2)
muxPrint(1)
time.Sleep(time.Second / 10)
}

View File

@ -0,0 +1,38 @@
// Copyright 2020 lesismal. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package concurrent
import (
"log"
"testing"
"time"
)
func TestRWMutex(t *testing.T) {
rwmux := NewRWMutex()
rwmuxRLockPrint := func(id int) {
for i := 0; i < 3; i++ {
rwmux.RLock(2)
time.Sleep(time.Second / 100)
log.Println("rwmux print:", id, i)
rwmux.RUnlock(2)
}
}
go rwmuxRLockPrint(2)
rwmuxRLockPrint(1)
rwmuxLockPrint := func(id int) {
for i := 0; i < 3; i++ {
rwmux.Lock(2)
time.Sleep(time.Second / 100)
log.Println("rwmux print:", id, i)
rwmux.Unlock(2)
}
}
go rwmuxLockPrint(2)
rwmuxLockPrint(1)
time.Sleep(time.Second / 10)
}

88
vendor/github.com/lesismal/llib/concurrent/rwmutext.go generated vendored Normal file
View File

@ -0,0 +1,88 @@
// Copyright 2020 lesismal. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package concurrent
import (
"sync"
)
var (
_defaultRWMux = NewRWMutex()
)
// RWMutex .
type RWMutex struct {
_mux sync.Mutex
_rwmuxes map[interface{}]*sync.RWMutex
}
// Lock .
func (m *RWMutex) Lock(key interface{}) {
m._mux.Lock()
mux, ok := m._rwmuxes[key]
if !ok {
mux = &sync.RWMutex{}
m._rwmuxes[key] = mux
}
m._mux.Unlock()
mux.Lock()
}
// Unlock .
func (m *RWMutex) Unlock(key interface{}) {
m._mux.Lock()
mux, ok := m._rwmuxes[key]
m._mux.Unlock()
if ok {
mux.Unlock()
}
}
// RLock .
func (m *RWMutex) RLock(key interface{}) {
m._mux.Lock()
mux, ok := m._rwmuxes[key]
if !ok {
mux = &sync.RWMutex{}
m._rwmuxes[key] = mux
}
m._mux.Unlock()
mux.RLock()
}
// RUnlock .
func (m *RWMutex) RUnlock(key interface{}) {
m._mux.Lock()
mux, ok := m._rwmuxes[key]
m._mux.Unlock()
if ok {
mux.RUnlock()
}
}
// NewRWMutex .
func NewRWMutex() *RWMutex {
return &RWMutex{_rwmuxes: map[interface{}]*sync.RWMutex{}}
}
// Lock .
func Lock(key interface{}) {
_defaultRWMux.Lock(key)
}
// Unlock .
func Unlock(key interface{}) {
_defaultRWMux.Unlock(key)
}
// RLock .
func RLock(key interface{}) {
_defaultRWMux.RLock(key)
}
// RUnlock .
func RUnlock(key interface{}) {
_defaultRWMux.RUnlock(key)
}

9
vendor/github.com/lesismal/llib/go.mod generated vendored Normal file
View File

@ -0,0 +1,9 @@
module github.com/lesismal/llib
go 1.16
require (
github.com/cespare/xxhash v1.1.0 // indirect
golang.org/x/crypto v0.0.0-20210513122933-cd7d49e622d5
golang.org/x/net v0.0.0-20210510120150-4163338589ed
)

17
vendor/github.com/lesismal/llib/go.sum generated vendored Normal file
View File

@ -0,0 +1,17 @@
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
golang.org/x/crypto v0.0.0-20210513122933-cd7d49e622d5 h1:N6Jp/LCiEoIBX56BZSR2bepK5GtbSC2DDOYT742mMfE=
golang.org/x/crypto v0.0.0-20210513122933-cd7d49e622d5/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210510120150-4163338589ed h1:p9UgmWI9wKpfYmgaV/IZKGdXc5qEK45tDwwwDyjS26I=
golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da h1:b3NXsE2LusjYGGjL5bxEVZZORm/YEFFrWFjR8eFrw/c=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@ -0,0 +1,99 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import "strconv"
type alert uint8
const (
// alert level
alertLevelWarning = 1
alertLevelError = 2
)
const (
alertCloseNotify alert = 0
alertUnexpectedMessage alert = 10
alertBadRecordMAC alert = 20
alertDecryptionFailed alert = 21
alertRecordOverflow alert = 22
alertDecompressionFailure alert = 30
alertHandshakeFailure alert = 40
alertBadCertificate alert = 42
alertUnsupportedCertificate alert = 43
alertCertificateRevoked alert = 44
alertCertificateExpired alert = 45
alertCertificateUnknown alert = 46
alertIllegalParameter alert = 47
alertUnknownCA alert = 48
alertAccessDenied alert = 49
alertDecodeError alert = 50
alertDecryptError alert = 51
alertExportRestriction alert = 60
alertProtocolVersion alert = 70
alertInsufficientSecurity alert = 71
alertInternalError alert = 80
alertInappropriateFallback alert = 86
alertUserCanceled alert = 90
alertNoRenegotiation alert = 100
alertMissingExtension alert = 109
alertUnsupportedExtension alert = 110
alertCertificateUnobtainable alert = 111
alertUnrecognizedName alert = 112
alertBadCertificateStatusResponse alert = 113
alertBadCertificateHashValue alert = 114
alertUnknownPSKIdentity alert = 115
alertCertificateRequired alert = 116
alertNoApplicationProtocol alert = 120
)
var alertText = map[alert]string{
alertCloseNotify: "close notify",
alertUnexpectedMessage: "unexpected message",
alertBadRecordMAC: "bad record MAC",
alertDecryptionFailed: "decryption failed",
alertRecordOverflow: "record overflow",
alertDecompressionFailure: "decompression failure",
alertHandshakeFailure: "handshake failure",
alertBadCertificate: "bad certificate",
alertUnsupportedCertificate: "unsupported certificate",
alertCertificateRevoked: "revoked certificate",
alertCertificateExpired: "expired certificate",
alertCertificateUnknown: "unknown certificate",
alertIllegalParameter: "illegal parameter",
alertUnknownCA: "unknown certificate authority",
alertAccessDenied: "access denied",
alertDecodeError: "error decoding message",
alertDecryptError: "error decrypting message",
alertExportRestriction: "export restriction",
alertProtocolVersion: "protocol version not supported",
alertInsufficientSecurity: "insufficient security level",
alertInternalError: "internal error",
alertInappropriateFallback: "inappropriate fallback",
alertUserCanceled: "user canceled",
alertNoRenegotiation: "no renegotiation",
alertMissingExtension: "missing extension",
alertUnsupportedExtension: "unsupported extension",
alertCertificateUnobtainable: "certificate unobtainable",
alertUnrecognizedName: "unrecognized name",
alertBadCertificateStatusResponse: "bad certificate status response",
alertBadCertificateHashValue: "bad certificate hash value",
alertUnknownPSKIdentity: "unknown PSK identity",
alertCertificateRequired: "certificate required",
alertNoApplicationProtocol: "no application protocol",
}
func (e alert) String() string {
s, ok := alertText[e]
if ok {
return "tls: " + s
}
return "tls: alert(" + strconv.Itoa(int(e)) + ")"
}
func (e alert) Error() string {
return e.String()
}

289
vendor/github.com/lesismal/llib/std/crypto/tls/auth.go generated vendored Normal file
View File

@ -0,0 +1,289 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"errors"
"fmt"
"hash"
"io"
)
// verifyHandshakeSignature verifies a signature against pre-hashed
// (if required) handshake contents.
func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error {
switch sigType {
case signatureECDSA:
pubKey, ok := pubkey.(*ecdsa.PublicKey)
if !ok {
return fmt.Errorf("expected an ECDSA public key, got %T", pubkey)
}
if !ecdsa.VerifyASN1(pubKey, signed, sig) {
return errors.New("ECDSA verification failure")
}
case signatureEd25519:
pubKey, ok := pubkey.(ed25519.PublicKey)
if !ok {
return fmt.Errorf("expected an Ed25519 public key, got %T", pubkey)
}
if !ed25519.Verify(pubKey, signed, sig) {
return errors.New("Ed25519 verification failure")
}
case signaturePKCS1v15:
pubKey, ok := pubkey.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("expected an RSA public key, got %T", pubkey)
}
if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, signed, sig); err != nil {
return err
}
case signatureRSAPSS:
pubKey, ok := pubkey.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("expected an RSA public key, got %T", pubkey)
}
signOpts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash}
if err := rsa.VerifyPSS(pubKey, hashFunc, signed, sig, signOpts); err != nil {
return err
}
default:
return errors.New("internal error: unknown signature type")
}
return nil
}
const (
serverSignatureContext = "TLS 1.3, server CertificateVerify\x00"
clientSignatureContext = "TLS 1.3, client CertificateVerify\x00"
)
var signaturePadding = []byte{
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
}
// signedMessage returns the pre-hashed (if necessary) message to be signed by
// certificate keys in TLS 1.3. See RFC 8446, Section 4.4.3.
func signedMessage(sigHash crypto.Hash, context string, transcript hash.Hash) []byte {
if sigHash == directSigning {
b := &bytes.Buffer{}
b.Write(signaturePadding)
io.WriteString(b, context)
b.Write(transcript.Sum(nil))
return b.Bytes()
}
h := sigHash.New()
h.Write(signaturePadding)
io.WriteString(h, context)
h.Write(transcript.Sum(nil))
return h.Sum(nil)
}
// typeAndHashFromSignatureScheme returns the corresponding signature type and
// crypto.Hash for a given TLS SignatureScheme.
func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType uint8, hash crypto.Hash, err error) {
switch signatureAlgorithm {
case PKCS1WithSHA1, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512:
sigType = signaturePKCS1v15
case PSSWithSHA256, PSSWithSHA384, PSSWithSHA512:
sigType = signatureRSAPSS
case ECDSAWithSHA1, ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512:
sigType = signatureECDSA
case Ed25519:
sigType = signatureEd25519
default:
return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm)
}
switch signatureAlgorithm {
case PKCS1WithSHA1, ECDSAWithSHA1:
hash = crypto.SHA1
case PKCS1WithSHA256, PSSWithSHA256, ECDSAWithP256AndSHA256:
hash = crypto.SHA256
case PKCS1WithSHA384, PSSWithSHA384, ECDSAWithP384AndSHA384:
hash = crypto.SHA384
case PKCS1WithSHA512, PSSWithSHA512, ECDSAWithP521AndSHA512:
hash = crypto.SHA512
case Ed25519:
hash = directSigning
default:
return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm)
}
return sigType, hash, nil
}
// legacyTypeAndHashFromPublicKey returns the fixed signature type and crypto.Hash for
// a given public key used with TLS 1.0 and 1.1, before the introduction of
// signature algorithm negotiation.
func legacyTypeAndHashFromPublicKey(pub crypto.PublicKey) (sigType uint8, hash crypto.Hash, err error) {
switch pub.(type) {
case *rsa.PublicKey:
return signaturePKCS1v15, crypto.MD5SHA1, nil
case *ecdsa.PublicKey:
return signatureECDSA, crypto.SHA1, nil
case ed25519.PublicKey:
// RFC 8422 specifies support for Ed25519 in TLS 1.0 and 1.1,
// but it requires holding on to a handshake transcript to do a
// full signature, and not even OpenSSL bothers with the
// complexity, so we can't even test it properly.
return 0, 0, fmt.Errorf("tls: Ed25519 public keys are not supported before TLS 1.2")
default:
return 0, 0, fmt.Errorf("tls: unsupported public key: %T", pub)
}
}
var rsaSignatureSchemes = []struct {
scheme SignatureScheme
minModulusBytes int
maxVersion uint16
}{
// RSA-PSS is used with PSSSaltLengthEqualsHash, and requires
// emLen >= hLen + sLen + 2
{PSSWithSHA256, crypto.SHA256.Size()*2 + 2, VersionTLS13},
{PSSWithSHA384, crypto.SHA384.Size()*2 + 2, VersionTLS13},
{PSSWithSHA512, crypto.SHA512.Size()*2 + 2, VersionTLS13},
// PKCS #1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires
// emLen >= len(prefix) + hLen + 11
// TLS 1.3 dropped support for PKCS #1 v1.5 in favor of RSA-PSS.
{PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11, VersionTLS12},
{PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11, VersionTLS12},
{PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11, VersionTLS12},
{PKCS1WithSHA1, 15 + crypto.SHA1.Size() + 11, VersionTLS12},
}
// signatureSchemesForCertificate returns the list of supported SignatureSchemes
// for a given certificate, based on the public key and the protocol version,
// and optionally filtered by its explicit SupportedSignatureAlgorithms.
//
// This function must be kept in sync with supportedSignatureAlgorithms.
func signatureSchemesForCertificate(version uint16, cert *Certificate) []SignatureScheme {
priv, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return nil
}
var sigAlgs []SignatureScheme
switch pub := priv.Public().(type) {
case *ecdsa.PublicKey:
if version != VersionTLS13 {
// In TLS 1.2 and earlier, ECDSA algorithms are not
// constrained to a single curve.
sigAlgs = []SignatureScheme{
ECDSAWithP256AndSHA256,
ECDSAWithP384AndSHA384,
ECDSAWithP521AndSHA512,
ECDSAWithSHA1,
}
break
}
switch pub.Curve {
case elliptic.P256():
sigAlgs = []SignatureScheme{ECDSAWithP256AndSHA256}
case elliptic.P384():
sigAlgs = []SignatureScheme{ECDSAWithP384AndSHA384}
case elliptic.P521():
sigAlgs = []SignatureScheme{ECDSAWithP521AndSHA512}
default:
return nil
}
case *rsa.PublicKey:
size := pub.Size()
sigAlgs = make([]SignatureScheme, 0, len(rsaSignatureSchemes))
for _, candidate := range rsaSignatureSchemes {
if size >= candidate.minModulusBytes && version <= candidate.maxVersion {
sigAlgs = append(sigAlgs, candidate.scheme)
}
}
case ed25519.PublicKey:
sigAlgs = []SignatureScheme{Ed25519}
default:
return nil
}
if cert.SupportedSignatureAlgorithms != nil {
var filteredSigAlgs []SignatureScheme
for _, sigAlg := range sigAlgs {
if isSupportedSignatureAlgorithm(sigAlg, cert.SupportedSignatureAlgorithms) {
filteredSigAlgs = append(filteredSigAlgs, sigAlg)
}
}
return filteredSigAlgs
}
return sigAlgs
}
// selectSignatureScheme picks a SignatureScheme from the peer's preference list
// that works with the selected certificate. It's only called for protocol
// versions that support signature algorithms, so TLS 1.2 and 1.3.
func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureScheme) (SignatureScheme, error) {
supportedAlgs := signatureSchemesForCertificate(vers, c)
if len(supportedAlgs) == 0 {
return 0, unsupportedCertificateError(c)
}
if len(peerAlgs) == 0 && vers == VersionTLS12 {
// For TLS 1.2, if the client didn't send signature_algorithms then we
// can assume that it supports SHA1. See RFC 5246, Section 7.4.1.4.1.
peerAlgs = []SignatureScheme{PKCS1WithSHA1, ECDSAWithSHA1}
}
// Pick signature scheme in the peer's preference order, as our
// preference order is not configurable.
for _, preferredAlg := range peerAlgs {
if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) {
return preferredAlg, nil
}
}
return 0, errors.New("tls: peer doesn't support any of the certificate's signature algorithms")
}
// unsupportedCertificateError returns a helpful error for certificates with
// an unsupported private key.
func unsupportedCertificateError(cert *Certificate) error {
switch cert.PrivateKey.(type) {
case rsa.PrivateKey, ecdsa.PrivateKey:
return fmt.Errorf("tls: unsupported certificate: private key is %T, expected *%T",
cert.PrivateKey, cert.PrivateKey)
case *ed25519.PrivateKey:
return fmt.Errorf("tls: unsupported certificate: private key is *ed25519.PrivateKey, expected ed25519.PrivateKey")
}
signer, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return fmt.Errorf("tls: certificate private key (%T) does not implement crypto.Signer",
cert.PrivateKey)
}
switch pub := signer.Public().(type) {
case *ecdsa.PublicKey:
switch pub.Curve {
case elliptic.P256():
case elliptic.P384():
case elliptic.P521():
default:
return fmt.Errorf("tls: unsupported certificate curve (%s)", pub.Curve.Params().Name)
}
case *rsa.PublicKey:
return fmt.Errorf("tls: certificate RSA key size too small for supported signature algorithms")
case ed25519.PublicKey:
default:
return fmt.Errorf("tls: unsupported certificate key (%T)", pub)
}
if cert.SupportedSignatureAlgorithms != nil {
return fmt.Errorf("tls: peer doesn't support the certificate custom signature algorithms")
}
return fmt.Errorf("tls: internal error: unsupported key (%T)", cert.PrivateKey)
}

View File

@ -0,0 +1,168 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"crypto"
"testing"
)
func TestSignatureSelection(t *testing.T) {
rsaCert := &Certificate{
Certificate: [][]byte{testRSACertificate},
PrivateKey: testRSAPrivateKey,
}
pkcs1Cert := &Certificate{
Certificate: [][]byte{testRSACertificate},
PrivateKey: testRSAPrivateKey,
SupportedSignatureAlgorithms: []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256},
}
ecdsaCert := &Certificate{
Certificate: [][]byte{testP256Certificate},
PrivateKey: testP256PrivateKey,
}
ed25519Cert := &Certificate{
Certificate: [][]byte{testEd25519Certificate},
PrivateKey: testEd25519PrivateKey,
}
tests := []struct {
cert *Certificate
peerSigAlgs []SignatureScheme
tlsVersion uint16
expectedSigAlg SignatureScheme
expectedSigType uint8
expectedHash crypto.Hash
}{
{rsaCert, []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256}, VersionTLS12, PKCS1WithSHA1, signaturePKCS1v15, crypto.SHA1},
{rsaCert, []SignatureScheme{PKCS1WithSHA512, PKCS1WithSHA1}, VersionTLS12, PKCS1WithSHA512, signaturePKCS1v15, crypto.SHA512},
{rsaCert, []SignatureScheme{PSSWithSHA256, PKCS1WithSHA256}, VersionTLS12, PSSWithSHA256, signatureRSAPSS, crypto.SHA256},
{pkcs1Cert, []SignatureScheme{PSSWithSHA256, PKCS1WithSHA256}, VersionTLS12, PKCS1WithSHA256, signaturePKCS1v15, crypto.SHA256},
{rsaCert, []SignatureScheme{PSSWithSHA384, PKCS1WithSHA1}, VersionTLS13, PSSWithSHA384, signatureRSAPSS, crypto.SHA384},
{ecdsaCert, []SignatureScheme{ECDSAWithSHA1}, VersionTLS12, ECDSAWithSHA1, signatureECDSA, crypto.SHA1},
{ecdsaCert, []SignatureScheme{ECDSAWithP256AndSHA256}, VersionTLS12, ECDSAWithP256AndSHA256, signatureECDSA, crypto.SHA256},
{ecdsaCert, []SignatureScheme{ECDSAWithP256AndSHA256}, VersionTLS13, ECDSAWithP256AndSHA256, signatureECDSA, crypto.SHA256},
{ed25519Cert, []SignatureScheme{Ed25519}, VersionTLS12, Ed25519, signatureEd25519, directSigning},
{ed25519Cert, []SignatureScheme{Ed25519}, VersionTLS13, Ed25519, signatureEd25519, directSigning},
// TLS 1.2 without signature_algorithms extension
{rsaCert, nil, VersionTLS12, PKCS1WithSHA1, signaturePKCS1v15, crypto.SHA1},
{ecdsaCert, nil, VersionTLS12, ECDSAWithSHA1, signatureECDSA, crypto.SHA1},
// TLS 1.2 does not restrict the ECDSA curve (our ecdsaCert is P-256)
{ecdsaCert, []SignatureScheme{ECDSAWithP384AndSHA384}, VersionTLS12, ECDSAWithP384AndSHA384, signatureECDSA, crypto.SHA384},
}
for testNo, test := range tests {
sigAlg, err := selectSignatureScheme(test.tlsVersion, test.cert, test.peerSigAlgs)
if err != nil {
t.Errorf("test[%d]: unexpected selectSignatureScheme error: %v", testNo, err)
}
if test.expectedSigAlg != sigAlg {
t.Errorf("test[%d]: expected signature scheme %v, got %v", testNo, test.expectedSigAlg, sigAlg)
}
sigType, hashFunc, err := typeAndHashFromSignatureScheme(sigAlg)
if err != nil {
t.Errorf("test[%d]: unexpected typeAndHashFromSignatureScheme error: %v", testNo, err)
}
if test.expectedSigType != sigType {
t.Errorf("test[%d]: expected signature algorithm %#x, got %#x", testNo, test.expectedSigType, sigType)
}
if test.expectedHash != hashFunc {
t.Errorf("test[%d]: expected hash function %#x, got %#x", testNo, test.expectedHash, hashFunc)
}
}
brokenCert := &Certificate{
Certificate: [][]byte{testRSACertificate},
PrivateKey: testRSAPrivateKey,
SupportedSignatureAlgorithms: []SignatureScheme{Ed25519},
}
badTests := []struct {
cert *Certificate
peerSigAlgs []SignatureScheme
tlsVersion uint16
}{
{rsaCert, []SignatureScheme{ECDSAWithP256AndSHA256, ECDSAWithSHA1}, VersionTLS12},
{ecdsaCert, []SignatureScheme{PKCS1WithSHA256, PKCS1WithSHA1}, VersionTLS12},
{rsaCert, []SignatureScheme{0}, VersionTLS12},
{ed25519Cert, []SignatureScheme{ECDSAWithP256AndSHA256, ECDSAWithSHA1}, VersionTLS12},
{ecdsaCert, []SignatureScheme{Ed25519}, VersionTLS12},
{brokenCert, []SignatureScheme{Ed25519}, VersionTLS12},
{brokenCert, []SignatureScheme{PKCS1WithSHA256}, VersionTLS12},
// RFC 5246, Section 7.4.1.4.1, says to only consider {sha1,ecdsa} as
// default when the extension is missing, and RFC 8422 does not update
// it. Anyway, if a stack supports Ed25519 it better support sigalgs.
{ed25519Cert, nil, VersionTLS12},
// TLS 1.3 has no default signature_algorithms.
{rsaCert, nil, VersionTLS13},
{ecdsaCert, nil, VersionTLS13},
{ed25519Cert, nil, VersionTLS13},
// Wrong curve, which TLS 1.3 checks
{ecdsaCert, []SignatureScheme{ECDSAWithP384AndSHA384}, VersionTLS13},
// TLS 1.3 does not support PKCS1v1.5 or SHA-1.
{rsaCert, []SignatureScheme{PKCS1WithSHA256}, VersionTLS13},
{pkcs1Cert, []SignatureScheme{PSSWithSHA256, PKCS1WithSHA256}, VersionTLS13},
{ecdsaCert, []SignatureScheme{ECDSAWithSHA1}, VersionTLS13},
// The key can be too small for the hash.
{rsaCert, []SignatureScheme{PSSWithSHA512}, VersionTLS12},
}
for testNo, test := range badTests {
sigAlg, err := selectSignatureScheme(test.tlsVersion, test.cert, test.peerSigAlgs)
if err == nil {
t.Errorf("test[%d]: unexpected success, got %v", testNo, sigAlg)
}
}
}
func TestLegacyTypeAndHash(t *testing.T) {
sigType, hashFunc, err := legacyTypeAndHashFromPublicKey(testRSAPrivateKey.Public())
if err != nil {
t.Errorf("RSA: unexpected error: %v", err)
}
if expectedSigType := signaturePKCS1v15; expectedSigType != sigType {
t.Errorf("RSA: expected signature type %#x, got %#x", expectedSigType, sigType)
}
if expectedHashFunc := crypto.MD5SHA1; expectedHashFunc != hashFunc {
t.Errorf("RSA: expected hash %#x, got %#x", expectedHashFunc, hashFunc)
}
sigType, hashFunc, err = legacyTypeAndHashFromPublicKey(testECDSAPrivateKey.Public())
if err != nil {
t.Errorf("ECDSA: unexpected error: %v", err)
}
if expectedSigType := signatureECDSA; expectedSigType != sigType {
t.Errorf("ECDSA: expected signature type %#x, got %#x", expectedSigType, sigType)
}
if expectedHashFunc := crypto.SHA1; expectedHashFunc != hashFunc {
t.Errorf("ECDSA: expected hash %#x, got %#x", expectedHashFunc, hashFunc)
}
// Ed25519 is not supported by TLS 1.0 and 1.1.
_, _, err = legacyTypeAndHashFromPublicKey(testEd25519PrivateKey.Public())
if err == nil {
t.Errorf("Ed25519: unexpected success")
}
}
// TestSupportedSignatureAlgorithms checks that all supportedSignatureAlgorithms
// have valid type and hash information.
func TestSupportedSignatureAlgorithms(t *testing.T) {
for _, sigAlg := range supportedSignatureAlgorithms {
sigType, hash, err := typeAndHashFromSignatureScheme(sigAlg)
if err != nil {
t.Errorf("%v: unexpected error: %v", sigAlg, err)
}
if sigType == 0 {
t.Errorf("%v: missing signature type", sigAlg)
}
if hash == 0 && sigAlg != Ed25519 {
t.Errorf("%v: missing hash", sigAlg)
}
}
}

View File

@ -0,0 +1,516 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/hmac"
"crypto/rc4"
"crypto/sha1"
"crypto/sha256"
"crypto/x509"
"fmt"
"hash"
"golang.org/x/crypto/chacha20poly1305"
)
// CipherSuite is a TLS cipher suite. Note that most functions in this package
// accept and expose cipher suite IDs instead of this type.
type CipherSuite struct {
ID uint16
Name string
// Supported versions is the list of TLS protocol versions that can
// negotiate this cipher suite.
SupportedVersions []uint16
// Insecure is true if the cipher suite has known security issues
// due to its primitives, design, or implementation.
Insecure bool
}
var (
supportedUpToTLS12 = []uint16{VersionTLS10, VersionTLS11, VersionTLS12}
supportedOnlyTLS12 = []uint16{VersionTLS12}
supportedOnlyTLS13 = []uint16{VersionTLS13}
)
// CipherSuites returns a list of cipher suites currently implemented by this
// package, excluding those with security issues, which are returned by
// InsecureCipherSuites.
//
// The list is sorted by ID. Note that the default cipher suites selected by
// this package might depend on logic that can't be captured by a static list.
func CipherSuites() []*CipherSuite {
return []*CipherSuite{
{TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, false},
{TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
{TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
{TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
{TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
{TLS_AES_128_GCM_SHA256, "TLS_AES_128_GCM_SHA256", supportedOnlyTLS13, false},
{TLS_AES_256_GCM_SHA384, "TLS_AES_256_GCM_SHA384", supportedOnlyTLS13, false},
{TLS_CHACHA20_POLY1305_SHA256, "TLS_CHACHA20_POLY1305_SHA256", supportedOnlyTLS13, false},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
{TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false},
{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false},
}
}
// InsecureCipherSuites returns a list of cipher suites currently implemented by
// this package and which have security issues.
//
// Most applications should not use the cipher suites in this list, and should
// only use those returned by CipherSuites.
func InsecureCipherSuites() []*CipherSuite {
// RC4 suites are broken because RC4 is.
// CBC-SHA256 suites have no Lucky13 countermeasures.
return []*CipherSuite{
{TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
{TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
{TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
{TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
}
}
// CipherSuiteName returns the standard name for the passed cipher suite ID
// (e.g. "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"), or a fallback representation
// of the ID value if the cipher suite is not implemented by this package.
func CipherSuiteName(id uint16) string {
for _, c := range CipherSuites() {
if c.ID == id {
return c.Name
}
}
for _, c := range InsecureCipherSuites() {
if c.ID == id {
return c.Name
}
}
return fmt.Sprintf("0x%04X", id)
}
// a keyAgreement implements the client and server side of a TLS key agreement
// protocol by generating and processing key exchange messages.
type keyAgreement interface {
// On the server side, the first two methods are called in order.
// In the case that the key agreement protocol doesn't use a
// ServerKeyExchange message, generateServerKeyExchange can return nil,
// nil.
generateServerKeyExchange(*Config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error)
processClientKeyExchange(*Config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error)
// On the client side, the next two methods are called in order.
// This method may not be called if the server doesn't send a
// ServerKeyExchange message.
processServerKeyExchange(*Config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error
generateClientKeyExchange(*Config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error)
}
const (
// suiteECDHE indicates that the cipher suite involves elliptic curve
// Diffie-Hellman. This means that it should only be selected when the
// client indicates that it supports ECC with a curve and point format
// that we're happy with.
suiteECDHE = 1 << iota
// suiteECSign indicates that the cipher suite involves an ECDSA or
// EdDSA signature and therefore may only be selected when the server's
// certificate is ECDSA or EdDSA. If this is not set then the cipher suite
// is RSA based.
suiteECSign
// suiteTLS12 indicates that the cipher suite should only be advertised
// and accepted when using TLS 1.2.
suiteTLS12
// suiteSHA384 indicates that the cipher suite uses SHA384 as the
// handshake hash.
suiteSHA384
// suiteDefaultOff indicates that this cipher suite is not included by
// default.
suiteDefaultOff
)
// A cipherSuite is a specific combination of key agreement, cipher and MAC function.
type cipherSuite struct {
id uint16
// the lengths, in bytes, of the key material needed for each component.
keyLen int
macLen int
ivLen int
ka func(version uint16) keyAgreement
// flags is a bitmask of the suite* values, above.
flags int
cipher func(key, iv []byte, isRead bool) interface{}
mac func(key []byte) hash.Hash
aead func(key, fixedNonce []byte) aead
}
var cipherSuites = []*cipherSuite{
// Ciphersuite order is chosen so that ECDHE comes before plain RSA and
// AEADs are the top preference.
{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305},
{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305},
{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM},
{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadAESGCM},
{TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
{TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil},
{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil},
{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil},
{TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM},
{TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
{TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil},
{TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil},
{TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil},
{TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil},
{TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil},
// RC4-based cipher suites are disabled by default.
{TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, suiteDefaultOff, cipherRC4, macSHA1, nil},
{TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE | suiteDefaultOff, cipherRC4, macSHA1, nil},
{TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteDefaultOff, cipherRC4, macSHA1, nil},
}
// selectCipherSuite returns the first cipher suite from ids which is also in
// supportedIDs and passes the ok filter.
func selectCipherSuite(ids, supportedIDs []uint16, ok func(*cipherSuite) bool) *cipherSuite {
for _, id := range ids {
candidate := cipherSuiteByID(id)
if candidate == nil || !ok(candidate) {
continue
}
for _, suppID := range supportedIDs {
if id == suppID {
return candidate
}
}
}
return nil
}
// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash
// algorithm to be used with HKDF. See RFC 8446, Appendix B.4.
type cipherSuiteTLS13 struct {
id uint16
keyLen int
aead func(key, fixedNonce []byte) aead
hash crypto.Hash
}
var cipherSuitesTLS13 = []*cipherSuiteTLS13{
{TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256},
{TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256},
{TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384},
}
func cipherRC4(key, iv []byte, isRead bool) interface{} {
cipher, _ := rc4.NewCipher(key)
return cipher
}
func cipher3DES(key, iv []byte, isRead bool) interface{} {
block, _ := des.NewTripleDESCipher(key)
if isRead {
return cipher.NewCBCDecrypter(block, iv)
}
return cipher.NewCBCEncrypter(block, iv)
}
func cipherAES(key, iv []byte, isRead bool) interface{} {
block, _ := aes.NewCipher(key)
if isRead {
return cipher.NewCBCDecrypter(block, iv)
}
return cipher.NewCBCEncrypter(block, iv)
}
// macSHA1 returns a SHA-1 based constant time MAC.
func macSHA1(key []byte) hash.Hash {
return hmac.New(newConstantTimeHash(sha1.New), key)
}
// macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and
// is currently only used in disabled-by-default cipher suites.
func macSHA256(key []byte) hash.Hash {
return hmac.New(sha256.New, key)
}
type aead interface {
cipher.AEAD
// explicitNonceLen returns the number of bytes of explicit nonce
// included in each record. This is eight for older AEADs and
// zero for modern ones.
explicitNonceLen() int
}
const (
aeadNonceLength = 12
noncePrefixLength = 4
)
// prefixNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to
// each call.
type prefixNonceAEAD struct {
// nonce contains the fixed part of the nonce in the first four bytes.
nonce [aeadNonceLength]byte
aead cipher.AEAD
}
func (f *prefixNonceAEAD) NonceSize() int { return aeadNonceLength - noncePrefixLength }
func (f *prefixNonceAEAD) Overhead() int { return f.aead.Overhead() }
func (f *prefixNonceAEAD) explicitNonceLen() int { return f.NonceSize() }
func (f *prefixNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
copy(f.nonce[4:], nonce)
return f.aead.Seal(out, f.nonce[:], plaintext, additionalData)
}
func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
copy(f.nonce[4:], nonce)
return f.aead.Open(out, f.nonce[:], ciphertext, additionalData)
}
// xoredNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce
// before each call.
type xorNonceAEAD struct {
nonceMask [aeadNonceLength]byte
aead cipher.AEAD
}
func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number
func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() }
func (f *xorNonceAEAD) explicitNonceLen() int { return 0 }
func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData)
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
return result
}
func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData)
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
return result, err
}
func aeadAESGCM(key, noncePrefix []byte) aead {
if len(noncePrefix) != noncePrefixLength {
panic("tls: internal error: wrong nonce length")
}
aes, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
aead, err := cipher.NewGCM(aes)
if err != nil {
panic(err)
}
ret := &prefixNonceAEAD{aead: aead}
copy(ret.nonce[:], noncePrefix)
return ret
}
func aeadAESGCMTLS13(key, nonceMask []byte) aead {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}
aes, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
aead, err := cipher.NewGCM(aes)
if err != nil {
panic(err)
}
ret := &xorNonceAEAD{aead: aead}
copy(ret.nonceMask[:], nonceMask)
return ret
}
func aeadChaCha20Poly1305(key, nonceMask []byte) aead {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}
aead, err := chacha20poly1305.New(key)
if err != nil {
panic(err)
}
ret := &xorNonceAEAD{aead: aead}
copy(ret.nonceMask[:], nonceMask)
return ret
}
type constantTimeHash interface {
hash.Hash
ConstantTimeSum(b []byte) []byte
}
// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, and replaces
// with that all calls to Sum. It's used to obtain a ConstantTimeSum-based HMAC.
type cthWrapper struct {
h constantTimeHash
}
func (c *cthWrapper) Size() int { return c.h.Size() }
func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() }
func (c *cthWrapper) Reset() { c.h.Reset() }
func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) }
func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) }
func newConstantTimeHash(h func() hash.Hash) func() hash.Hash {
return func() hash.Hash {
return &cthWrapper{h().(constantTimeHash)}
}
}
// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3.
func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte {
h.Reset()
h.Write(seq)
h.Write(header)
h.Write(data)
res := h.Sum(out)
if extra != nil {
h.Write(extra)
}
return res
}
func rsaKA(version uint16) keyAgreement {
return rsaKeyAgreement{}
}
func ecdheECDSAKA(version uint16) keyAgreement {
return &ecdheKeyAgreement{
isRSA: false,
version: version,
}
}
func ecdheRSAKA(version uint16) keyAgreement {
return &ecdheKeyAgreement{
isRSA: true,
version: version,
}
}
// mutualCipherSuite returns a cipherSuite given a list of supported
// ciphersuites and the id requested by the peer.
func mutualCipherSuite(have []uint16, want uint16) *cipherSuite {
for _, id := range have {
if id == want {
return cipherSuiteByID(id)
}
}
return nil
}
func cipherSuiteByID(id uint16) *cipherSuite {
for _, cipherSuite := range cipherSuites {
if cipherSuite.id == id {
return cipherSuite
}
}
return nil
}
func mutualCipherSuiteTLS13(have []uint16, want uint16) *cipherSuiteTLS13 {
for _, id := range have {
if id == want {
return cipherSuiteTLS13ByID(id)
}
}
return nil
}
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 {
for _, cipherSuite := range cipherSuitesTLS13 {
if cipherSuite.id == id {
return cipherSuite
}
}
return nil
}
// A list of cipher suite IDs that are, or have been, implemented by this
// package.
//
// See https://www.iana.org/assignments/tls-parameters/tls-parameters.xml
const (
// TLS 1.0 - 1.2 cipher suites.
TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005
TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a
TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f
TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035
TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003c
TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009c
TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009d
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a
TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc023
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc027
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca8
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca9
// TLS 1.3 cipher suites.
TLS_AES_128_GCM_SHA256 uint16 = 0x1301
TLS_AES_256_GCM_SHA384 uint16 = 0x1302
TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303
// TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator
// that the client is doing version fallback. See RFC 7507.
TLS_FALLBACK_SCSV uint16 = 0x5600
// Legacy names for the corresponding cipher suites with the correct _SHA256
// suffix, retained for backward compatibility.
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
)

1563
vendor/github.com/lesismal/llib/std/crypto/tls/common.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,116 @@
// Code generated by "stringer -type=SignatureScheme,CurveID,ClientAuthType -output=common_string.go"; DO NOT EDIT.
package tls
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[PKCS1WithSHA256-1025]
_ = x[PKCS1WithSHA384-1281]
_ = x[PKCS1WithSHA512-1537]
_ = x[PSSWithSHA256-2052]
_ = x[PSSWithSHA384-2053]
_ = x[PSSWithSHA512-2054]
_ = x[ECDSAWithP256AndSHA256-1027]
_ = x[ECDSAWithP384AndSHA384-1283]
_ = x[ECDSAWithP521AndSHA512-1539]
_ = x[Ed25519-2055]
_ = x[PKCS1WithSHA1-513]
_ = x[ECDSAWithSHA1-515]
}
const (
_SignatureScheme_name_0 = "PKCS1WithSHA1"
_SignatureScheme_name_1 = "ECDSAWithSHA1"
_SignatureScheme_name_2 = "PKCS1WithSHA256"
_SignatureScheme_name_3 = "ECDSAWithP256AndSHA256"
_SignatureScheme_name_4 = "PKCS1WithSHA384"
_SignatureScheme_name_5 = "ECDSAWithP384AndSHA384"
_SignatureScheme_name_6 = "PKCS1WithSHA512"
_SignatureScheme_name_7 = "ECDSAWithP521AndSHA512"
_SignatureScheme_name_8 = "PSSWithSHA256PSSWithSHA384PSSWithSHA512Ed25519"
)
var (
_SignatureScheme_index_8 = [...]uint8{0, 13, 26, 39, 46}
)
func (i SignatureScheme) String() string {
switch {
case i == 513:
return _SignatureScheme_name_0
case i == 515:
return _SignatureScheme_name_1
case i == 1025:
return _SignatureScheme_name_2
case i == 1027:
return _SignatureScheme_name_3
case i == 1281:
return _SignatureScheme_name_4
case i == 1283:
return _SignatureScheme_name_5
case i == 1537:
return _SignatureScheme_name_6
case i == 1539:
return _SignatureScheme_name_7
case 2052 <= i && i <= 2055:
i -= 2052
return _SignatureScheme_name_8[_SignatureScheme_index_8[i]:_SignatureScheme_index_8[i+1]]
default:
return "SignatureScheme(" + strconv.FormatInt(int64(i), 10) + ")"
}
}
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[CurveP256-23]
_ = x[CurveP384-24]
_ = x[CurveP521-25]
_ = x[X25519-29]
}
const (
_CurveID_name_0 = "CurveP256CurveP384CurveP521"
_CurveID_name_1 = "X25519"
)
var (
_CurveID_index_0 = [...]uint8{0, 9, 18, 27}
)
func (i CurveID) String() string {
switch {
case 23 <= i && i <= 25:
i -= 23
return _CurveID_name_0[_CurveID_index_0[i]:_CurveID_index_0[i+1]]
case i == 29:
return _CurveID_name_1
default:
return "CurveID(" + strconv.FormatInt(int64(i), 10) + ")"
}
}
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[NoClientCert-0]
_ = x[RequestClientCert-1]
_ = x[RequireAnyClientCert-2]
_ = x[VerifyClientCertIfGiven-3]
_ = x[RequireAndVerifyClientCert-4]
}
const _ClientAuthType_name = "NoClientCertRequestClientCertRequireAnyClientCertVerifyClientCertIfGivenRequireAndVerifyClientCert"
var _ClientAuthType_index = [...]uint8{0, 12, 29, 49, 72, 98}
func (i ClientAuthType) String() string {
if i < 0 || i >= ClientAuthType(len(_ClientAuthType_index)-1) {
return "ClientAuthType(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _ClientAuthType_name[_ClientAuthType_index[i]:_ClientAuthType_index[i+1]]
}

1775
vendor/github.com/lesismal/llib/std/crypto/tls/conn.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,287 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"io"
"net"
"testing"
)
func TestRoundUp(t *testing.T) {
if roundUp(0, 16) != 0 ||
roundUp(1, 16) != 16 ||
roundUp(15, 16) != 16 ||
roundUp(16, 16) != 16 ||
roundUp(17, 16) != 32 {
t.Error("roundUp broken")
}
}
// will be initialized with {0, 255, 255, ..., 255}
var padding255Bad = [256]byte{}
// will be initialized with {255, 255, 255, ..., 255}
var padding255Good = [256]byte{255}
var paddingTests = []struct {
in []byte
good bool
expectedLen int
}{
{[]byte{1, 2, 3, 4, 0}, true, 4},
{[]byte{1, 2, 3, 4, 0, 1}, false, 0},
{[]byte{1, 2, 3, 4, 99, 99}, false, 0},
{[]byte{1, 2, 3, 4, 1, 1}, true, 4},
{[]byte{1, 2, 3, 2, 2, 2}, true, 3},
{[]byte{1, 2, 3, 3, 3, 3}, true, 2},
{[]byte{1, 2, 3, 4, 3, 3}, false, 0},
{[]byte{1, 4, 4, 4, 4, 4}, true, 1},
{[]byte{5, 5, 5, 5, 5, 5}, true, 0},
{[]byte{6, 6, 6, 6, 6, 6}, false, 0},
{padding255Bad[:], false, 0},
{padding255Good[:], true, 0},
}
func TestRemovePadding(t *testing.T) {
for i := 1; i < len(padding255Bad); i++ {
padding255Bad[i] = 255
padding255Good[i] = 255
}
for i, test := range paddingTests {
paddingLen, good := extractPadding(test.in)
expectedGood := byte(255)
if !test.good {
expectedGood = 0
}
if good != expectedGood {
t.Errorf("#%d: wrong validity, want:%d got:%d", i, expectedGood, good)
}
if good == 255 && len(test.in)-paddingLen != test.expectedLen {
t.Errorf("#%d: got %d, want %d", i, len(test.in)-paddingLen, test.expectedLen)
}
}
}
var certExampleCom = `308201713082011ba003020102021005a75ddf21014d5f417083b7a010ba2e300d06092a864886f70d01010b050030123110300e060355040a130741636d6520436f301e170d3136303831373231343135335a170d3137303831373231343135335a30123110300e060355040a130741636d6520436f305c300d06092a864886f70d0101010500034b003048024100b37f0fdd67e715bf532046ac34acbd8fdc4dabe2b598588f3f58b1f12e6219a16cbfe54d2b4b665396013589262360b6721efa27d546854f17cc9aeec6751db10203010001a34d304b300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff0402300030160603551d11040f300d820b6578616d706c652e636f6d300d06092a864886f70d01010b050003410059fc487866d3d855503c8e064ca32aac5e9babcece89ec597f8b2b24c17867f4a5d3b4ece06e795bfc5448ccbd2ffca1b3433171ebf3557a4737b020565350a0`
var certWildcardExampleCom = `308201743082011ea003020102021100a7aa6297c9416a4633af8bec2958c607300d06092a864886f70d01010b050030123110300e060355040a130741636d6520436f301e170d3136303831373231343231395a170d3137303831373231343231395a30123110300e060355040a130741636d6520436f305c300d06092a864886f70d0101010500034b003048024100b105afc859a711ee864114e7d2d46c2dcbe392d3506249f6c2285b0eb342cc4bf2d803677c61c0abde443f084745c1a6d62080e5664ef2cc8f50ad8a0ab8870b0203010001a34f304d300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff0402300030180603551d110411300f820d2a2e6578616d706c652e636f6d300d06092a864886f70d01010b0500034100af26088584d266e3f6566360cf862c7fecc441484b098b107439543144a2b93f20781988281e108c6d7656934e56950e1e5f2bcf38796b814ccb729445856c34`
var certFooExampleCom = `308201753082011fa00302010202101bbdb6070b0aeffc49008cde74deef29300d06092a864886f70d01010b050030123110300e060355040a130741636d6520436f301e170d3136303831373231343234345a170d3137303831373231343234345a30123110300e060355040a130741636d6520436f305c300d06092a864886f70d0101010500034b003048024100f00ac69d8ca2829f26216c7b50f1d4bbabad58d447706476cd89a2f3e1859943748aa42c15eedc93ac7c49e40d3b05ed645cb6b81c4efba60d961f44211a54eb0203010001a351304f300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff04023000301a0603551d1104133011820f666f6f2e6578616d706c652e636f6d300d06092a864886f70d01010b0500034100a0957fca6d1e0f1ef4b247348c7a8ca092c29c9c0ecc1898ea6b8065d23af6d922a410dd2335a0ea15edd1394cef9f62c9e876a21e35250a0b4fe1ddceba0f36`
func TestCertificateSelection(t *testing.T) {
config := Config{
Certificates: []Certificate{
{
Certificate: [][]byte{fromHex(certExampleCom)},
},
{
Certificate: [][]byte{fromHex(certWildcardExampleCom)},
},
{
Certificate: [][]byte{fromHex(certFooExampleCom)},
},
},
}
config.BuildNameToCertificate()
pointerToIndex := func(c *Certificate) int {
for i := range config.Certificates {
if c == &config.Certificates[i] {
return i
}
}
return -1
}
certificateForName := func(name string) *Certificate {
clientHello := &ClientHelloInfo{
ServerName: name,
}
if cert, err := config.getCertificate(clientHello); err != nil {
t.Errorf("unable to get certificate for name '%s': %s", name, err)
return nil
} else {
return cert
}
}
if n := pointerToIndex(certificateForName("example.com")); n != 0 {
t.Errorf("example.com returned certificate %d, not 0", n)
}
if n := pointerToIndex(certificateForName("bar.example.com")); n != 1 {
t.Errorf("bar.example.com returned certificate %d, not 1", n)
}
if n := pointerToIndex(certificateForName("foo.example.com")); n != 2 {
t.Errorf("foo.example.com returned certificate %d, not 2", n)
}
if n := pointerToIndex(certificateForName("foo.bar.example.com")); n != 0 {
t.Errorf("foo.bar.example.com returned certificate %d, not 0", n)
}
}
// Run with multiple crypto configs to test the logic for computing TLS record overheads.
func runDynamicRecordSizingTest(t *testing.T, config *Config) {
clientConn, serverConn := localPipe(t)
serverConfig := config.Clone()
serverConfig.DynamicRecordSizingDisabled = false
tlsConn := Server(serverConn, serverConfig)
handshakeDone := make(chan struct{})
recordSizesChan := make(chan []int, 1)
defer func() { <-recordSizesChan }() // wait for the goroutine to exit
go func() {
// This goroutine performs a TLS handshake over clientConn and
// then reads TLS records until EOF. It writes a slice that
// contains all the record sizes to recordSizesChan.
defer close(recordSizesChan)
defer clientConn.Close()
tlsConn := Client(clientConn, config)
if err := tlsConn.Handshake(); err != nil {
t.Errorf("Error from client handshake: %v", err)
return
}
close(handshakeDone)
var recordHeader [recordHeaderLen]byte
var record []byte
var recordSizes []int
for {
n, err := io.ReadFull(clientConn, recordHeader[:])
if err == io.EOF {
break
}
if err != nil || n != len(recordHeader) {
t.Errorf("io.ReadFull = %d, %v", n, err)
return
}
length := int(recordHeader[3])<<8 | int(recordHeader[4])
if len(record) < length {
record = make([]byte, length)
}
n, err = io.ReadFull(clientConn, record[:length])
if err != nil || n != length {
t.Errorf("io.ReadFull = %d, %v", n, err)
return
}
recordSizes = append(recordSizes, recordHeaderLen+length)
}
recordSizesChan <- recordSizes
}()
if err := tlsConn.Handshake(); err != nil {
t.Fatalf("Error from server handshake: %s", err)
}
<-handshakeDone
// The server writes these plaintexts in order.
plaintext := bytes.Join([][]byte{
bytes.Repeat([]byte("x"), recordSizeBoostThreshold),
bytes.Repeat([]byte("y"), maxPlaintext*2),
bytes.Repeat([]byte("z"), maxPlaintext),
}, nil)
if _, err := tlsConn.Write(plaintext); err != nil {
t.Fatalf("Error from server write: %s", err)
}
if err := tlsConn.Close(); err != nil {
t.Fatalf("Error from server close: %s", err)
}
recordSizes := <-recordSizesChan
if recordSizes == nil {
t.Fatalf("Client encountered an error")
}
// Drop the size of the second to last record, which is likely to be
// truncated, and the last record, which is a close_notify alert.
recordSizes = recordSizes[:len(recordSizes)-2]
// recordSizes should contain a series of records smaller than
// tcpMSSEstimate followed by some larger than maxPlaintext.
seenLargeRecord := false
for i, size := range recordSizes {
if !seenLargeRecord {
if size > (i+1)*tcpMSSEstimate {
t.Fatalf("Record #%d has size %d, which is too large too soon", i, size)
}
if size >= maxPlaintext {
seenLargeRecord = true
}
} else if size <= maxPlaintext {
t.Fatalf("Record #%d has size %d but should be full sized", i, size)
}
}
if !seenLargeRecord {
t.Fatalf("No large records observed")
}
}
func TestDynamicRecordSizingWithStreamCipher(t *testing.T) {
config := testConfig.Clone()
config.MaxVersion = VersionTLS12
config.CipherSuites = []uint16{TLS_RSA_WITH_RC4_128_SHA}
runDynamicRecordSizingTest(t, config)
}
func TestDynamicRecordSizingWithCBC(t *testing.T) {
config := testConfig.Clone()
config.MaxVersion = VersionTLS12
config.CipherSuites = []uint16{TLS_RSA_WITH_AES_256_CBC_SHA}
runDynamicRecordSizingTest(t, config)
}
func TestDynamicRecordSizingWithAEAD(t *testing.T) {
config := testConfig.Clone()
config.MaxVersion = VersionTLS12
config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}
runDynamicRecordSizingTest(t, config)
}
func TestDynamicRecordSizingWithTLSv13(t *testing.T) {
config := testConfig.Clone()
runDynamicRecordSizingTest(t, config)
}
// hairpinConn is a net.Conn that makes a “hairpin” call when closed, back into
// the tls.Conn which is calling it.
type hairpinConn struct {
net.Conn
tlsConn *Conn
}
func (conn *hairpinConn) Close() error {
conn.tlsConn.ConnectionState()
return nil
}
func TestHairpinInClose(t *testing.T) {
// This tests that the underlying net.Conn can call back into the
// tls.Conn when being closed without deadlocking.
client, server := localPipe(t)
defer server.Close()
defer client.Close()
conn := &hairpinConn{client, nil}
tlsConn := Server(conn, &Config{
GetCertificate: func(*ClientHelloInfo) (*Certificate, error) {
panic("unreachable")
},
})
conn.tlsConn = tlsConn
// This call should not deadlock.
tlsConn.Close()
}

View File

@ -0,0 +1,232 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls_test
import (
"crypto/tls"
"crypto/x509"
"log"
"net/http"
"net/http/httptest"
"os"
"time"
)
// zeroSource is an io.Reader that returns an unlimited number of zero bytes.
type zeroSource struct{}
func (zeroSource) Read(b []byte) (n int, err error) {
for i := range b {
b[i] = 0
}
return len(b), nil
}
func ExampleDial() {
// Connecting with a custom root-certificate set.
const rootPEM = `
-- GlobalSign Root R2, valid until Dec 15, 2021
-----BEGIN CERTIFICATE-----
MIIDujCCAqKgAwIBAgILBAAAAAABD4Ym5g0wDQYJKoZIhvcNAQEFBQAwTDEgMB4G
A1UECxMXR2xvYmFsU2lnbiBSb290IENBIC0gUjIxEzARBgNVBAoTCkdsb2JhbFNp
Z24xEzARBgNVBAMTCkdsb2JhbFNpZ24wHhcNMDYxMjE1MDgwMDAwWhcNMjExMjE1
MDgwMDAwWjBMMSAwHgYDVQQLExdHbG9iYWxTaWduIFJvb3QgQ0EgLSBSMjETMBEG
A1UEChMKR2xvYmFsU2lnbjETMBEGA1UEAxMKR2xvYmFsU2lnbjCCASIwDQYJKoZI
hvcNAQEBBQADggEPADCCAQoCggEBAKbPJA6+Lm8omUVCxKs+IVSbC9N/hHD6ErPL
v4dfxn+G07IwXNb9rfF73OX4YJYJkhD10FPe+3t+c4isUoh7SqbKSaZeqKeMWhG8
eoLrvozps6yWJQeXSpkqBy+0Hne/ig+1AnwblrjFuTosvNYSuetZfeLQBoZfXklq
tTleiDTsvHgMCJiEbKjNS7SgfQx5TfC4LcshytVsW33hoCmEofnTlEnLJGKRILzd
C9XZzPnqJworc5HGnRusyMvo4KD0L5CLTfuwNhv2GXqF4G3yYROIXJ/gkwpRl4pa
zq+r1feqCapgvdzZX99yqWATXgAByUr6P6TqBwMhAo6CygPCm48CAwEAAaOBnDCB
mTAOBgNVHQ8BAf8EBAMCAQYwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQUm+IH
V2ccHsBqBt5ZtJot39wZhi4wNgYDVR0fBC8wLTAroCmgJ4YlaHR0cDovL2NybC5n
bG9iYWxzaWduLm5ldC9yb290LXIyLmNybDAfBgNVHSMEGDAWgBSb4gdXZxwewGoG
3lm0mi3f3BmGLjANBgkqhkiG9w0BAQUFAAOCAQEAmYFThxxol4aR7OBKuEQLq4Gs
J0/WwbgcQ3izDJr86iw8bmEbTUsp9Z8FHSbBuOmDAGJFtqkIk7mpM0sYmsL4h4hO
291xNBrBVNpGP+DTKqttVCL1OmLNIG+6KYnX3ZHu01yiPqFbQfXf5WRDLenVOavS
ot+3i9DAgBkcRcAtjOj4LaR0VknFBbVPFd5uRHg5h6h+u/N5GJG79G+dwfCMNYxd
AfvDbbnvRG15RjF+Cv6pgsH/76tuIMRQyV+dTZsXjAzlAcmgQWpzU/qlULRuJQ/7
TBj0/VLZjmmx6BEP3ojY+x1J96relc8geMJgEtslQIxq/H5COEBkEveegeGTLg==
-----END CERTIFICATE-----`
// First, create the set of root certificates. For this example we only
// have one. It's also possible to omit this in order to use the
// default root set of the current operating system.
roots := x509.NewCertPool()
ok := roots.AppendCertsFromPEM([]byte(rootPEM))
if !ok {
panic("failed to parse root certificate")
}
conn, err := tls.Dial("tcp", "mail.google.com:443", &tls.Config{
RootCAs: roots,
})
if err != nil {
panic("failed to connect: " + err.Error())
}
conn.Close()
}
func ExampleConfig_keyLogWriter() {
// Debugging TLS applications by decrypting a network traffic capture.
// WARNING: Use of KeyLogWriter compromises security and should only be
// used for debugging.
// Dummy test HTTP server for the example with insecure random so output is
// reproducible.
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
server.TLS = &tls.Config{
Rand: zeroSource{}, // for example only; don't do this.
}
server.StartTLS()
defer server.Close()
// Typically the log would go to an open file:
// w, err := os.OpenFile("tls-secrets.txt", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
w := os.Stdout
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
KeyLogWriter: w,
Rand: zeroSource{}, // for reproducible output; don't do this.
InsecureSkipVerify: true, // test server certificate is not trusted.
},
},
}
resp, err := client.Get(server.URL)
if err != nil {
log.Fatalf("Failed to get URL: %v", err)
}
resp.Body.Close()
// The resulting file can be used with Wireshark to decrypt the TLS
// connection by setting (Pre)-Master-Secret log filename in SSL Protocol
// preferences.
}
func ExampleLoadX509KeyPair() {
cert, err := tls.LoadX509KeyPair("testdata/example-cert.pem", "testdata/example-key.pem")
if err != nil {
log.Fatal(err)
}
cfg := &tls.Config{Certificates: []tls.Certificate{cert}}
listener, err := tls.Listen("tcp", ":2000", cfg)
if err != nil {
log.Fatal(err)
}
_ = listener
}
func ExampleX509KeyPair() {
certPem := []byte(`-----BEGIN CERTIFICATE-----
MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw
DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow
EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d
7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B
5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr
BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1
NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l
Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc
6MF9+Yw1Yy0t
-----END CERTIFICATE-----`)
keyPem := []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----`)
cert, err := tls.X509KeyPair(certPem, keyPem)
if err != nil {
log.Fatal(err)
}
cfg := &tls.Config{Certificates: []tls.Certificate{cert}}
listener, err := tls.Listen("tcp", ":2000", cfg)
if err != nil {
log.Fatal(err)
}
_ = listener
}
func ExampleX509KeyPair_httpServer() {
certPem := []byte(`-----BEGIN CERTIFICATE-----
MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw
DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow
EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d
7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B
5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr
BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1
NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l
Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc
6MF9+Yw1Yy0t
-----END CERTIFICATE-----`)
keyPem := []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----`)
cert, err := tls.X509KeyPair(certPem, keyPem)
if err != nil {
log.Fatal(err)
}
cfg := &tls.Config{Certificates: []tls.Certificate{cert}}
srv := &http.Server{
TLSConfig: cfg,
ReadTimeout: time.Minute,
WriteTimeout: time.Minute,
}
log.Fatal(srv.ListenAndServeTLS("", ""))
}
func ExampleConfig_verifyConnection() {
// VerifyConnection can be used to replace and customize connection
// verification. This example shows a VerifyConnection implementation that
// will be approximately equivalent to what crypto/tls does normally to
// verify the peer's certificate.
// Client side configuration.
_ = &tls.Config{
// Set InsecureSkipVerify to skip the default validation we are
// replacing. This will not disable VerifyConnection.
InsecureSkipVerify: true,
VerifyConnection: func(cs tls.ConnectionState) error {
opts := x509.VerifyOptions{
DNSName: cs.ServerName,
Intermediates: x509.NewCertPool(),
}
for _, cert := range cs.PeerCertificates[1:] {
opts.Intermediates.AddCert(cert)
}
_, err := cs.PeerCertificates[0].Verify(opts)
return err
},
}
// Server side configuration.
_ = &tls.Config{
// Require client certificates (or VerifyConnection will run anyway and
// panic accessing cs.PeerCertificates[0]) but don't verify them with the
// default verifier. This will not disable VerifyConnection.
ClientAuth: tls.RequireAnyClientCert,
VerifyConnection: func(cs tls.ConnectionState) error {
opts := x509.VerifyOptions{
DNSName: cs.ServerName,
Intermediates: x509.NewCertPool(),
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
for _, cert := range cs.PeerCertificates[1:] {
opts.Intermediates.AddCert(cert)
}
_, err := cs.PeerCertificates[0].Verify(opts)
return err
},
}
// Note that when certificates are not handled by the default verifier
// ConnectionState.VerifiedChains will be nil.
}

View File

@ -0,0 +1,172 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build ignore
// Generate a self-signed X.509 certificate for a TLS server. Outputs to
// 'cert.pem' and 'key.pem' and will overwrite existing files.
package main
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"flag"
"log"
"math/big"
"net"
"os"
"strings"
"time"
)
var (
host = flag.String("host", "", "Comma-separated hostnames and IPs to generate a certificate for")
validFrom = flag.String("start-date", "", "Creation date formatted as Jan 1 15:04:05 2011")
validFor = flag.Duration("duration", 365*24*time.Hour, "Duration that certificate is valid for")
isCA = flag.Bool("ca", false, "whether this cert should be its own Certificate Authority")
rsaBits = flag.Int("rsa-bits", 2048, "Size of RSA key to generate. Ignored if --ecdsa-curve is set")
ecdsaCurve = flag.String("ecdsa-curve", "", "ECDSA curve to use to generate a key. Valid values are P224, P256 (recommended), P384, P521")
ed25519Key = flag.Bool("ed25519", false, "Generate an Ed25519 key")
)
func publicKey(priv interface{}) interface{} {
switch k := priv.(type) {
case *rsa.PrivateKey:
return &k.PublicKey
case *ecdsa.PrivateKey:
return &k.PublicKey
case ed25519.PrivateKey:
return k.Public().(ed25519.PublicKey)
default:
return nil
}
}
func main() {
flag.Parse()
if len(*host) == 0 {
log.Fatalf("Missing required --host parameter")
}
var priv interface{}
var err error
switch *ecdsaCurve {
case "":
if *ed25519Key {
_, priv, err = ed25519.GenerateKey(rand.Reader)
} else {
priv, err = rsa.GenerateKey(rand.Reader, *rsaBits)
}
case "P224":
priv, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
case "P256":
priv, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
case "P384":
priv, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
case "P521":
priv, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
default:
log.Fatalf("Unrecognized elliptic curve: %q", *ecdsaCurve)
}
if err != nil {
log.Fatalf("Failed to generate private key: %v", err)
}
// ECDSA, ED25519 and RSA subject keys should have the DigitalSignature
// KeyUsage bits set in the x509.Certificate template
keyUsage := x509.KeyUsageDigitalSignature
// Only RSA subject keys should have the KeyEncipherment KeyUsage bits set. In
// the context of TLS this KeyUsage is particular to RSA key exchange and
// authentication.
if _, isRSA := priv.(*rsa.PrivateKey); isRSA {
keyUsage |= x509.KeyUsageKeyEncipherment
}
var notBefore time.Time
if len(*validFrom) == 0 {
notBefore = time.Now()
} else {
notBefore, err = time.Parse("Jan 2 15:04:05 2006", *validFrom)
if err != nil {
log.Fatalf("Failed to parse creation date: %v", err)
}
}
notAfter := notBefore.Add(*validFor)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
log.Fatalf("Failed to generate serial number: %v", err)
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Acme Co"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: keyUsage,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
hosts := strings.Split(*host, ",")
for _, h := range hosts {
if ip := net.ParseIP(h); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, h)
}
}
if *isCA {
template.IsCA = true
template.KeyUsage |= x509.KeyUsageCertSign
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv)
if err != nil {
log.Fatalf("Failed to create certificate: %v", err)
}
certOut, err := os.Create("cert.pem")
if err != nil {
log.Fatalf("Failed to open cert.pem for writing: %v", err)
}
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
log.Fatalf("Failed to write data to cert.pem: %v", err)
}
if err := certOut.Close(); err != nil {
log.Fatalf("Error closing cert.pem: %v", err)
}
log.Print("wrote cert.pem\n")
keyOut, err := os.OpenFile("key.pem", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
log.Fatalf("Failed to open key.pem for writing: %v", err)
return
}
privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
log.Fatalf("Unable to marshal private key: %v", err)
}
if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
log.Fatalf("Failed to write data to key.pem: %v", err)
}
if err := keyOut.Close(); err != nil {
log.Fatalf("Error closing key.pem: %v", err)
}
log.Print("wrote key.pem\n")
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,685 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"crypto"
"crypto/hmac"
"crypto/rsa"
"errors"
"hash"
"sync/atomic"
"time"
)
type clientHandshakeStateTLS13 struct {
c *Conn
serverHello *serverHelloMsg
hello *clientHelloMsg
ecdheParams ecdheParameters
session *ClientSessionState
earlySecret []byte
binderKey []byte
certReq *certificateRequestMsgTLS13
usingPSK bool
sentDummyCCS bool
suite *cipherSuiteTLS13
transcript hash.Hash
masterSecret []byte
trafficSecret []byte // client_application_traffic_secret_0
}
// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheParams, and,
// optionally, hs.session, hs.earlySecret and hs.binderKey to be set.
func (hs *clientHandshakeStateTLS13) handshake() error {
c := hs.c
// The server must not select TLS 1.3 in a renegotiation. See RFC 8446,
// sections 4.1.2 and 4.1.3.
if c.handshakes > 0 {
c.sendAlert(alertProtocolVersion)
return errors.New("tls: server selected TLS 1.3 in a renegotiation")
}
// Consistency check on the presence of a keyShare and its parameters.
if hs.ecdheParams == nil || len(hs.hello.keyShares) != 1 {
return c.sendAlert(alertInternalError)
}
if err := hs.checkServerHelloOrHRR(); err != nil {
return err
}
hs.transcript = hs.suite.hash.New()
hs.transcript.Write(hs.hello.marshal())
if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
if err := hs.processHelloRetryRequest(); err != nil {
return err
}
}
hs.transcript.Write(hs.serverHello.marshal())
c.buffering = true
if err := hs.processServerHello(); err != nil {
return err
}
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
if err := hs.establishHandshakeKeys(); err != nil {
return err
}
if err := hs.readServerParameters(); err != nil {
return err
}
if err := hs.readServerCertificate(); err != nil {
return err
}
if err := hs.readServerFinished(); err != nil {
return err
}
if err := hs.sendClientCertificate(); err != nil {
return err
}
if err := hs.sendClientFinished(); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
atomic.StoreUint32(&c.handshakeStatus, 1)
return nil
}
// checkServerHelloOrHRR does validity checks that apply to both ServerHello and
// HelloRetryRequest messages. It sets hs.suite.
func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error {
c := hs.c
if hs.serverHello.supportedVersion == 0 {
c.sendAlert(alertMissingExtension)
return errors.New("tls: server selected TLS 1.3 using the legacy version field")
}
if hs.serverHello.supportedVersion != VersionTLS13 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected an invalid version after a HelloRetryRequest")
}
if hs.serverHello.vers != VersionTLS12 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server sent an incorrect legacy version")
}
if hs.serverHello.ocspStapling ||
hs.serverHello.ticketSupported ||
hs.serverHello.secureRenegotiationSupported ||
len(hs.serverHello.secureRenegotiation) != 0 ||
len(hs.serverHello.alpnProtocol) != 0 ||
len(hs.serverHello.scts) != 0 {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server sent a ServerHello extension forbidden in TLS 1.3")
}
if !bytes.Equal(hs.hello.sessionId, hs.serverHello.sessionId) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server did not echo the legacy session ID")
}
if hs.serverHello.compressionMethod != compressionNone {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected unsupported compression format")
}
selectedSuite := mutualCipherSuiteTLS13(hs.hello.cipherSuites, hs.serverHello.cipherSuite)
if hs.suite != nil && selectedSuite != hs.suite {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server changed cipher suite after a HelloRetryRequest")
}
if selectedSuite == nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server chose an unconfigured cipher suite")
}
hs.suite = selectedSuite
c.cipherSuite = hs.suite.id
return nil
}
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
if hs.sentDummyCCS {
return nil
}
hs.sentDummyCCS = true
_, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
return err
}
// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and
// resends hs.hello, and reads the new ServerHello into hs.serverHello.
func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
c := hs.c
// The first ClientHello gets double-hashed into the transcript upon a
// HelloRetryRequest. (The idea is that the server might offload transcript
// storage to the client in the cookie.) See RFC 8446, Section 4.4.1.
chHash := hs.transcript.Sum(nil)
hs.transcript.Reset()
hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
hs.transcript.Write(chHash)
hs.transcript.Write(hs.serverHello.marshal())
// The only HelloRetryRequest extensions we support are key_share and
// cookie, and clients must abort the handshake if the HRR would not result
// in any change in the ClientHello.
if hs.serverHello.selectedGroup == 0 && hs.serverHello.cookie == nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server sent an unnecessary HelloRetryRequest message")
}
if hs.serverHello.cookie != nil {
hs.hello.cookie = hs.serverHello.cookie
}
if hs.serverHello.serverShare.group != 0 {
c.sendAlert(alertDecodeError)
return errors.New("tls: received malformed key_share extension")
}
// If the server sent a key_share extension selecting a group, ensure it's
// a group we advertised but did not send a key share for, and send a key
// share for it this time.
if curveID := hs.serverHello.selectedGroup; curveID != 0 {
curveOK := false
for _, id := range hs.hello.supportedCurves {
if id == curveID {
curveOK = true
break
}
}
if !curveOK {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected unsupported group")
}
if hs.ecdheParams.CurveID() == curveID {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share")
}
if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok {
c.sendAlert(alertInternalError)
return errors.New("tls: CurvePreferences includes unsupported curve")
}
params, err := generateECDHEParameters(c.config.rand(), curveID)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
hs.ecdheParams = params
hs.hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}}
}
hs.hello.raw = nil
if len(hs.hello.pskIdentities) > 0 {
pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
if pskSuite == nil {
return c.sendAlert(alertInternalError)
}
if pskSuite.hash == hs.suite.hash {
// Update binders and obfuscated_ticket_age.
ticketAge := uint32(c.config.time().Sub(hs.session.receivedAt) / time.Millisecond)
hs.hello.pskIdentities[0].obfuscatedTicketAge = ticketAge + hs.session.ageAdd
transcript := hs.suite.hash.New()
transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
transcript.Write(chHash)
transcript.Write(hs.serverHello.marshal())
transcript.Write(hs.hello.marshalWithoutBinders())
pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)}
hs.hello.updateBinders(pskBinders)
} else {
// Server selected a cipher suite incompatible with the PSK.
hs.hello.pskIdentities = nil
hs.hello.pskBinders = nil
}
}
hs.transcript.Write(hs.hello.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
return err
}
msg, err := c.readHandshake()
if err != nil {
return err
}
serverHello, ok := msg.(*serverHelloMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(serverHello, msg)
}
hs.serverHello = serverHello
if err := hs.checkServerHelloOrHRR(); err != nil {
return err
}
return nil
}
func (hs *clientHandshakeStateTLS13) processServerHello() error {
c := hs.c
if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {
c.sendAlert(alertUnexpectedMessage)
return errors.New("tls: server sent two HelloRetryRequest messages")
}
if len(hs.serverHello.cookie) != 0 {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server sent a cookie in a normal ServerHello")
}
if hs.serverHello.selectedGroup != 0 {
c.sendAlert(alertDecodeError)
return errors.New("tls: malformed key_share extension")
}
if hs.serverHello.serverShare.group == 0 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server did not send a key share")
}
if hs.serverHello.serverShare.group != hs.ecdheParams.CurveID() {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected unsupported group")
}
if !hs.serverHello.selectedIdentityPresent {
return nil
}
if int(hs.serverHello.selectedIdentity) >= len(hs.hello.pskIdentities) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected an invalid PSK")
}
if len(hs.hello.pskIdentities) != 1 || hs.session == nil {
return c.sendAlert(alertInternalError)
}
pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
if pskSuite == nil {
return c.sendAlert(alertInternalError)
}
if pskSuite.hash != hs.suite.hash {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected an invalid PSK and cipher suite pair")
}
hs.usingPSK = true
c.didResume = true
c.peerCertificates = hs.session.serverCertificates
c.verifiedChains = hs.session.verifiedChains
c.ocspResponse = hs.session.ocspResponse
c.scts = hs.session.scts
return nil
}
func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
c := hs.c
sharedKey := hs.ecdheParams.SharedKey(hs.serverHello.serverShare.data)
if sharedKey == nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid server key share")
}
earlySecret := hs.earlySecret
if !hs.usingPSK {
earlySecret = hs.suite.extract(nil, nil)
}
handshakeSecret := hs.suite.extract(sharedKey,
hs.suite.deriveSecret(earlySecret, "derived", nil))
clientSecret := hs.suite.deriveSecret(handshakeSecret,
clientHandshakeTrafficLabel, hs.transcript)
c.out.setTrafficSecret(hs.suite, clientSecret)
serverSecret := hs.suite.deriveSecret(handshakeSecret,
serverHandshakeTrafficLabel, hs.transcript)
c.in.setTrafficSecret(hs.suite, serverSecret)
err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.hello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
hs.masterSecret = hs.suite.extract(nil,
hs.suite.deriveSecret(handshakeSecret, "derived", nil))
return nil
}
func (hs *clientHandshakeStateTLS13) readServerParameters() error {
c := hs.c
msg, err := c.readHandshake()
if err != nil {
return err
}
encryptedExtensions, ok := msg.(*encryptedExtensionsMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(encryptedExtensions, msg)
}
hs.transcript.Write(encryptedExtensions.marshal())
if encryptedExtensions.alpnProtocol != "" {
if len(hs.hello.alpnProtocols) == 0 {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server advertised unrequested ALPN extension")
}
if mutualProtocol([]string{encryptedExtensions.alpnProtocol}, hs.hello.alpnProtocols) == "" {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server selected unadvertised ALPN protocol")
}
c.clientProtocol = encryptedExtensions.alpnProtocol
}
return nil
}
func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
c := hs.c
// Either a PSK or a certificate is always used, but not both.
// See RFC 8446, Section 4.1.1.
if hs.usingPSK {
// Make sure the connection is still being verified whether or not this
// is a resumption. Resumptions currently don't reverify certificates so
// they don't call verifyServerCertificate. See Issue 31641.
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
return nil
}
msg, err := c.readHandshake()
if err != nil {
return err
}
certReq, ok := msg.(*certificateRequestMsgTLS13)
if ok {
hs.transcript.Write(certReq.marshal())
hs.certReq = certReq
msg, err = c.readHandshake()
if err != nil {
return err
}
}
certMsg, ok := msg.(*certificateMsgTLS13)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
if len(certMsg.certificate.Certificate) == 0 {
c.sendAlert(alertDecodeError)
return errors.New("tls: received empty certificates message")
}
hs.transcript.Write(certMsg.marshal())
c.scts = certMsg.certificate.SignedCertificateTimestamps
c.ocspResponse = certMsg.certificate.OCSPStaple
if err := c.verifyServerCertificate(certMsg.certificate.Certificate); err != nil {
return err
}
msg, err = c.readHandshake()
if err != nil {
return err
}
certVerify, ok := msg.(*certificateVerifyMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certVerify, msg)
}
// See RFC 8446, Section 4.4.3.
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: certificate used with invalid signature algorithm")
}
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: certificate used with invalid signature algorithm")
}
signed := signedMessage(sigHash, serverSignatureContext, hs.transcript)
if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
sigHash, signed, certVerify.signature); err != nil {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid signature by the server certificate: " + err.Error())
}
hs.transcript.Write(certVerify.marshal())
return nil
}
func (hs *clientHandshakeStateTLS13) readServerFinished() error {
c := hs.c
msg, err := c.readHandshake()
if err != nil {
return err
}
finished, ok := msg.(*finishedMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(finished, msg)
}
expectedMAC := hs.suite.finishedHash(c.in.trafficSecret, hs.transcript)
if !hmac.Equal(expectedMAC, finished.verifyData) {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid server finished hash")
}
hs.transcript.Write(finished.marshal())
// Derive secrets that take context through the server Finished.
hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret,
clientApplicationTrafficLabel, hs.transcript)
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
serverApplicationTrafficLabel, hs.transcript)
c.in.setTrafficSecret(hs.suite, serverSecret)
err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.hello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript)
return nil
}
func (hs *clientHandshakeStateTLS13) sendClientCertificate() error {
c := hs.c
if hs.certReq == nil {
return nil
}
cert, err := c.getClientCertificate(&CertificateRequestInfo{
AcceptableCAs: hs.certReq.certificateAuthorities,
SignatureSchemes: hs.certReq.supportedSignatureAlgorithms,
Version: c.vers,
})
if err != nil {
return err
}
certMsg := new(certificateMsgTLS13)
certMsg.certificate = *cert
certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0
certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0
hs.transcript.Write(certMsg.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
return err
}
// If we sent an empty certificate message, skip the CertificateVerify.
if len(cert.Certificate) == 0 {
return nil
}
certVerifyMsg := new(certificateVerifyMsg)
certVerifyMsg.hasSignatureAlgorithm = true
certVerifyMsg.signatureAlgorithm, err = selectSignatureScheme(c.vers, cert, hs.certReq.supportedSignatureAlgorithms)
if err != nil {
// getClientCertificate returned a certificate incompatible with the
// CertificateRequestInfo supported signature algorithms.
c.sendAlert(alertHandshakeFailure)
return err
}
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerifyMsg.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
signed := signedMessage(sigHash, clientSignatureContext, hs.transcript)
signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts)
if err != nil {
c.sendAlert(alertInternalError)
return errors.New("tls: failed to sign handshake: " + err.Error())
}
certVerifyMsg.signature = sig
hs.transcript.Write(certVerifyMsg.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil {
return err
}
return nil
}
func (hs *clientHandshakeStateTLS13) sendClientFinished() error {
c := hs.c
finished := &finishedMsg{
verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript),
}
hs.transcript.Write(finished.marshal())
if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
return err
}
c.out.setTrafficSecret(hs.suite, hs.trafficSecret)
if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil {
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
resumptionLabel, hs.transcript)
}
return nil
}
func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
if !c.isClient {
c.sendAlert(alertUnexpectedMessage)
return errors.New("tls: received new session ticket from a client")
}
if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil {
return nil
}
// See RFC 8446, Section 4.6.1.
if msg.lifetime == 0 {
return nil
}
lifetime := time.Duration(msg.lifetime) * time.Second
if lifetime > maxSessionTicketLifetime {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: received a session ticket with invalid lifetime")
}
cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
if cipherSuite == nil || c.resumptionSecret == nil {
return c.sendAlert(alertInternalError)
}
// Save the resumption_master_secret and nonce instead of deriving the PSK
// to do the least amount of work on NewSessionTicket messages before we
// know if the ticket will be used. Forward secrecy of resumed connections
// is guaranteed by the requirement for pskModeDHE.
session := &ClientSessionState{
sessionTicket: msg.label,
vers: c.vers,
cipherSuite: c.cipherSuite,
masterSecret: c.resumptionSecret,
serverCertificates: c.peerCertificates,
verifiedChains: c.verifiedChains,
receivedAt: c.config.time(),
nonce: msg.nonce,
useBy: c.config.time().Add(lifetime),
ageAdd: msg.ageAdd,
ocspResponse: c.ocspResponse,
scts: c.scts,
}
cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
c.config.ClientSessionCache.Put(cacheKey, session)
return nil
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,465 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"math/rand"
"reflect"
"strings"
"testing"
"testing/quick"
"time"
)
var tests = []interface{}{
&clientHelloMsg{},
&serverHelloMsg{},
&finishedMsg{},
&certificateMsg{},
&certificateRequestMsg{},
&certificateVerifyMsg{
hasSignatureAlgorithm: true,
},
&certificateStatusMsg{},
&clientKeyExchangeMsg{},
&newSessionTicketMsg{},
&sessionState{},
&sessionStateTLS13{},
&encryptedExtensionsMsg{},
&endOfEarlyDataMsg{},
&keyUpdateMsg{},
&newSessionTicketMsgTLS13{},
&certificateRequestMsgTLS13{},
&certificateMsgTLS13{},
}
func TestMarshalUnmarshal(t *testing.T) {
rand := rand.New(rand.NewSource(time.Now().UnixNano()))
for i, iface := range tests {
ty := reflect.ValueOf(iface).Type()
n := 100
if testing.Short() {
n = 5
}
for j := 0; j < n; j++ {
v, ok := quick.Value(ty, rand)
if !ok {
t.Errorf("#%d: failed to create value", i)
break
}
m1 := v.Interface().(handshakeMessage)
marshaled := m1.marshal()
m2 := iface.(handshakeMessage)
if !m2.unmarshal(marshaled) {
t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
break
}
m2.marshal() // to fill any marshal cache in the message
if !reflect.DeepEqual(m1, m2) {
t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
break
}
if i >= 3 {
// The first three message types (ClientHello,
// ServerHello and Finished) are allowed to
// have parsable prefixes because the extension
// data is optional and the length of the
// Finished varies across versions.
for j := 0; j < len(marshaled); j++ {
if m2.unmarshal(marshaled[0:j]) {
t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
break
}
}
}
}
}
}
func TestFuzz(t *testing.T) {
rand := rand.New(rand.NewSource(0))
for _, iface := range tests {
m := iface.(handshakeMessage)
for j := 0; j < 1000; j++ {
len := rand.Intn(100)
bytes := randomBytes(len, rand)
// This just looks for crashes due to bounds errors etc.
m.unmarshal(bytes)
}
}
}
func randomBytes(n int, rand *rand.Rand) []byte {
r := make([]byte, n)
if _, err := rand.Read(r); err != nil {
panic("rand.Read failed: " + err.Error())
}
return r
}
func randomString(n int, rand *rand.Rand) string {
b := randomBytes(n, rand)
return string(b)
}
func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &clientHelloMsg{}
m.vers = uint16(rand.Intn(65536))
m.random = randomBytes(32, rand)
m.sessionId = randomBytes(rand.Intn(32), rand)
m.cipherSuites = make([]uint16, rand.Intn(63)+1)
for i := 0; i < len(m.cipherSuites); i++ {
cs := uint16(rand.Int31())
if cs == scsvRenegotiation {
cs += 1
}
m.cipherSuites[i] = cs
}
m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
if rand.Intn(10) > 5 {
m.serverName = randomString(rand.Intn(255), rand)
for strings.HasSuffix(m.serverName, ".") {
m.serverName = m.serverName[:len(m.serverName)-1]
}
}
m.ocspStapling = rand.Intn(10) > 5
m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
for i := range m.supportedCurves {
m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
}
if rand.Intn(10) > 5 {
m.ticketSupported = true
if rand.Intn(10) > 5 {
m.sessionTicket = randomBytes(rand.Intn(300), rand)
} else {
m.sessionTicket = make([]byte, 0)
}
}
if rand.Intn(10) > 5 {
m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
}
if rand.Intn(10) > 5 {
m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms
}
for i := 0; i < rand.Intn(5); i++ {
m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
}
if rand.Intn(10) > 5 {
m.scts = true
}
if rand.Intn(10) > 5 {
m.secureRenegotiationSupported = true
m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
}
for i := 0; i < rand.Intn(5); i++ {
m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
}
if rand.Intn(10) > 5 {
m.cookie = randomBytes(rand.Intn(500)+1, rand)
}
for i := 0; i < rand.Intn(5); i++ {
var ks keyShare
ks.group = CurveID(rand.Intn(30000) + 1)
ks.data = randomBytes(rand.Intn(200)+1, rand)
m.keyShares = append(m.keyShares, ks)
}
switch rand.Intn(3) {
case 1:
m.pskModes = []uint8{pskModeDHE}
case 2:
m.pskModes = []uint8{pskModeDHE, pskModePlain}
}
for i := 0; i < rand.Intn(5); i++ {
var psk pskIdentity
psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
psk.label = randomBytes(rand.Intn(500)+1, rand)
m.pskIdentities = append(m.pskIdentities, psk)
m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
}
if rand.Intn(10) > 5 {
m.earlyData = true
}
return reflect.ValueOf(m)
}
func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &serverHelloMsg{}
m.vers = uint16(rand.Intn(65536))
m.random = randomBytes(32, rand)
m.sessionId = randomBytes(rand.Intn(32), rand)
m.cipherSuite = uint16(rand.Int31())
m.compressionMethod = uint8(rand.Intn(256))
m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
if rand.Intn(10) > 5 {
m.ocspStapling = true
}
if rand.Intn(10) > 5 {
m.ticketSupported = true
}
if rand.Intn(10) > 5 {
m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
}
for i := 0; i < rand.Intn(4); i++ {
m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
}
if rand.Intn(10) > 5 {
m.secureRenegotiationSupported = true
m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
}
if rand.Intn(10) > 5 {
m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
}
if rand.Intn(10) > 5 {
m.cookie = randomBytes(rand.Intn(500)+1, rand)
}
if rand.Intn(10) > 5 {
for i := 0; i < rand.Intn(5); i++ {
m.serverShare.group = CurveID(rand.Intn(30000) + 1)
m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
}
} else if rand.Intn(10) > 5 {
m.selectedGroup = CurveID(rand.Intn(30000) + 1)
}
if rand.Intn(10) > 5 {
m.selectedIdentityPresent = true
m.selectedIdentity = uint16(rand.Intn(0xffff))
}
return reflect.ValueOf(m)
}
func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &encryptedExtensionsMsg{}
if rand.Intn(10) > 5 {
m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
}
return reflect.ValueOf(m)
}
func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateMsg{}
numCerts := rand.Intn(20)
m.certificates = make([][]byte, numCerts)
for i := 0; i < numCerts; i++ {
m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
}
return reflect.ValueOf(m)
}
func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateRequestMsg{}
m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
for i := 0; i < rand.Intn(100); i++ {
m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
}
return reflect.ValueOf(m)
}
func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateVerifyMsg{}
m.hasSignatureAlgorithm = true
m.signatureAlgorithm = SignatureScheme(rand.Intn(30000))
m.signature = randomBytes(rand.Intn(15)+1, rand)
return reflect.ValueOf(m)
}
func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateStatusMsg{}
m.response = randomBytes(rand.Intn(10)+1, rand)
return reflect.ValueOf(m)
}
func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &clientKeyExchangeMsg{}
m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
return reflect.ValueOf(m)
}
func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &finishedMsg{}
m.verifyData = randomBytes(12, rand)
return reflect.ValueOf(m)
}
func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &newSessionTicketMsg{}
m.ticket = randomBytes(rand.Intn(4), rand)
return reflect.ValueOf(m)
}
func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
s := &sessionState{}
s.vers = uint16(rand.Intn(10000))
s.cipherSuite = uint16(rand.Intn(10000))
s.masterSecret = randomBytes(rand.Intn(100)+1, rand)
s.createdAt = uint64(rand.Int63())
for i := 0; i < rand.Intn(20); i++ {
s.certificates = append(s.certificates, randomBytes(rand.Intn(500)+1, rand))
}
return reflect.ValueOf(s)
}
func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
s := &sessionStateTLS13{}
s.cipherSuite = uint16(rand.Intn(10000))
s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand)
s.createdAt = uint64(rand.Int63())
for i := 0; i < rand.Intn(2)+1; i++ {
s.certificate.Certificate = append(
s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
}
if rand.Intn(10) > 5 {
s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
}
if rand.Intn(10) > 5 {
for i := 0; i < rand.Intn(2)+1; i++ {
s.certificate.SignedCertificateTimestamps = append(
s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
}
}
return reflect.ValueOf(s)
}
func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &endOfEarlyDataMsg{}
return reflect.ValueOf(m)
}
func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &keyUpdateMsg{}
m.updateRequested = rand.Intn(10) > 5
return reflect.ValueOf(m)
}
func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
m := &newSessionTicketMsgTLS13{}
m.lifetime = uint32(rand.Intn(500000))
m.ageAdd = uint32(rand.Intn(500000))
m.nonce = randomBytes(rand.Intn(100), rand)
m.label = randomBytes(rand.Intn(1000), rand)
if rand.Intn(10) > 5 {
m.maxEarlyData = uint32(rand.Intn(500000))
}
return reflect.ValueOf(m)
}
func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateRequestMsgTLS13{}
if rand.Intn(10) > 5 {
m.ocspStapling = true
}
if rand.Intn(10) > 5 {
m.scts = true
}
if rand.Intn(10) > 5 {
m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
}
if rand.Intn(10) > 5 {
m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms
}
if rand.Intn(10) > 5 {
m.certificateAuthorities = make([][]byte, 3)
for i := 0; i < 3; i++ {
m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand)
}
}
return reflect.ValueOf(m)
}
func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateMsgTLS13{}
for i := 0; i < rand.Intn(2)+1; i++ {
m.certificate.Certificate = append(
m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
}
if rand.Intn(10) > 5 {
m.ocspStapling = true
m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
}
if rand.Intn(10) > 5 {
m.scts = true
for i := 0; i < rand.Intn(2)+1; i++ {
m.certificate.SignedCertificateTimestamps = append(
m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
}
}
return reflect.ValueOf(m)
}
func TestRejectEmptySCTList(t *testing.T) {
// RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid.
var random [32]byte
sct := []byte{0x42, 0x42, 0x42, 0x42}
serverHello := serverHelloMsg{
vers: VersionTLS12,
random: random[:],
scts: [][]byte{sct},
}
serverHelloBytes := serverHello.marshal()
var serverHelloCopy serverHelloMsg
if !serverHelloCopy.unmarshal(serverHelloBytes) {
t.Fatal("Failed to unmarshal initial message")
}
// Change serverHelloBytes so that the SCT list is empty
i := bytes.Index(serverHelloBytes, sct)
if i < 0 {
t.Fatal("Cannot find SCT in ServerHello")
}
var serverHelloEmptySCT []byte
serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
// Append the extension length and SCT list length for an empty list.
serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
// Update the handshake message length.
serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
// Update the extensions length
serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
t.Fatal("Unmarshaled ServerHello with empty SCT list")
}
}
func TestRejectEmptySCT(t *testing.T) {
// Not only must the SCT list be non-empty, but the SCT elements must
// not be zero length.
var random [32]byte
serverHello := serverHelloMsg{
vers: VersionTLS12,
random: random[:],
scts: [][]byte{nil},
}
serverHelloBytes := serverHello.marshal()
var serverHelloCopy serverHelloMsg
if serverHelloCopy.unmarshal(serverHelloBytes) {
t.Fatal("Unmarshaled ServerHello with zero-length SCT")
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,971 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"crypto"
"crypto/hmac"
"crypto/rsa"
"errors"
"hash"
"io"
"sync/atomic"
"time"
)
// maxClientPSKIdentities is the number of client PSK identities the server will
// attempt to validate. It will ignore the rest not to let cheap ClientHello
// messages cause too much work in session ticket decryption attempts.
const maxClientPSKIdentities = 5
type serverHandshakeStateTLS13 struct {
c *Conn
clientHello *clientHelloMsg
hello *serverHelloMsg
sentDummyCCS bool
usingPSK bool
suite *cipherSuiteTLS13
cert *Certificate
sigAlg SignatureScheme
earlySecret []byte
sharedKey []byte
handshakeSecret []byte
masterSecret []byte
trafficSecret []byte // client_application_traffic_secret_0
transcript hash.Hash
clientFinished []byte
err error
}
func (hs *serverHandshakeStateTLS13) handshake() error {
c := hs.c
if c.handshakeStatusAsync >= stateServerHandshake13HandshakeDone {
return nil
}
if hs.err != nil && hs.err != errDataNotEnough {
return hs.err
}
// For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2.
if err := hs.processClientHello(); err != nil {
hs.err = err
return err
}
if err := hs.checkForResumption(); err != nil {
hs.err = err
return err
}
if err := hs.pickCertificate(); err != nil {
hs.err = err
return err
}
c.buffering = true
if err := hs.sendServerParameters(); err != nil {
hs.err = err
return err
}
if err := hs.sendServerCertificate(); err != nil {
hs.err = err
return err
}
if err := hs.sendServerFinished(); err != nil {
hs.err = err
return err
}
// Note that at this point we could start sending application data without
// waiting for the client's second flight, but the application might not
// expect the lack of replay protection of the ClientHello parameters.
if _, err := c.flush(); err != nil {
hs.err = err
return err
}
if err := hs.readClientCertificate(); err != nil {
hs.err = err
return err
}
if err := hs.readClientFinished(); err != nil {
hs.err = err
return err
}
c.handshakeStatusAsync = stateServerHandshake13HandshakeDone
atomic.StoreUint32(&c.handshakeStatus, 1)
return nil
}
func (hs *serverHandshakeStateTLS13) processClientHello() error {
c := hs.c
if c.handshakeStatusAsync >= stateServerHandshake13ProcessClientHello {
return nil
}
c.handshakeStatusAsync = stateServerHandshake13ProcessClientHello
hs.hello = new(serverHelloMsg)
// TLS 1.3 froze the ServerHello.legacy_version field, and uses
// supported_versions instead. See RFC 8446, sections 4.1.3 and 4.2.1.
hs.hello.vers = VersionTLS12
hs.hello.supportedVersion = c.vers
if len(hs.clientHello.supportedVersions) == 0 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client used the legacy version field to negotiate TLS 1.3")
}
// Abort if the client is doing a fallback and landing lower than what we
// support. See RFC 7507, which however does not specify the interaction
// with supported_versions. The only difference is that with
// supported_versions a client has a chance to attempt a [TLS 1.2, TLS 1.4]
// handshake in case TLS 1.3 is broken but 1.2 is not. Alas, in that case,
// it will have to drop the TLS_FALLBACK_SCSV protection if it falls back to
// TLS 1.2, because a TLS 1.3 server would abort here. The situation before
// supported_versions was not better because there was just no way to do a
// TLS 1.4 handshake without risking the server selecting TLS 1.3.
for _, id := range hs.clientHello.cipherSuites {
if id == TLS_FALLBACK_SCSV {
// Use c.vers instead of max(supported_versions) because an attacker
// could defeat this by adding an arbitrary high version otherwise.
if c.vers < c.config.maxSupportedVersion() {
c.sendAlert(alertInappropriateFallback)
return errors.New("tls: client using inappropriate protocol fallback")
}
break
}
}
if len(hs.clientHello.compressionMethods) != 1 ||
hs.clientHello.compressionMethods[0] != compressionNone {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: TLS 1.3 client supports illegal compression methods")
}
hs.hello.random = make([]byte, 32)
if _, err := io.ReadFull(c.config.rand(), hs.hello.random); err != nil {
c.sendAlert(alertInternalError)
return err
}
if len(hs.clientHello.secureRenegotiation) != 0 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: initial handshake had non-empty renegotiation extension")
}
if hs.clientHello.earlyData {
// See RFC 8446, Section 4.2.10 for the complicated behavior required
// here. The scenario is that a different server at our address offered
// to accept early data in the past, which we can't handle. For now, all
// 0-RTT enabled session tickets need to expire before a Go server can
// replace a server or join a pool. That's the same requirement that
// applies to mixing or replacing with any TLS 1.2 server.
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: client sent unexpected early data")
}
hs.hello.sessionId = hs.clientHello.sessionId
hs.hello.compressionMethod = compressionNone
var preferenceList, supportedList []uint16
if c.config.PreferServerCipherSuites {
preferenceList = defaultCipherSuitesTLS13()
supportedList = hs.clientHello.cipherSuites
// If the client does not seem to have hardware support for AES-GCM,
// prefer other AEAD ciphers even if we prioritized AES-GCM ciphers
// by default.
if !aesgcmPreferred(hs.clientHello.cipherSuites) {
preferenceList = deprioritizeAES(preferenceList)
}
} else {
preferenceList = hs.clientHello.cipherSuites
supportedList = defaultCipherSuitesTLS13()
// If we don't have hardware support for AES-GCM, prefer other AEAD
// ciphers even if the client prioritized AES-GCM.
if !hasAESGCMHardwareSupport {
preferenceList = deprioritizeAES(preferenceList)
}
}
for _, suiteID := range preferenceList {
hs.suite = mutualCipherSuiteTLS13(supportedList, suiteID)
if hs.suite != nil {
break
}
}
if hs.suite == nil {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: no cipher suite supported by both client and server")
}
c.cipherSuite = hs.suite.id
hs.hello.cipherSuite = hs.suite.id
hs.transcript = hs.suite.hash.New()
// Pick the ECDHE group in server preference order, but give priority to
// groups with a key share, to avoid a HelloRetryRequest round-trip.
var selectedGroup CurveID
var clientKeyShare *keyShare
GroupSelection:
for _, preferredGroup := range c.config.curvePreferences() {
for _, ks := range hs.clientHello.keyShares {
if ks.group == preferredGroup {
selectedGroup = ks.group
clientKeyShare = &ks
break GroupSelection
}
}
if selectedGroup != 0 {
continue
}
for _, group := range hs.clientHello.supportedCurves {
if group == preferredGroup {
selectedGroup = group
break
}
}
}
if selectedGroup == 0 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: no ECDHE curve supported by both client and server")
}
if clientKeyShare == nil {
if err := hs.doHelloRetryRequest(selectedGroup); err != nil {
return err
}
clientKeyShare = &hs.clientHello.keyShares[0]
}
if _, ok := curveForCurveID(selectedGroup); selectedGroup != X25519 && !ok {
c.sendAlert(alertInternalError)
return errors.New("tls: CurvePreferences includes unsupported curve")
}
params, err := generateECDHEParameters(c.config.rand(), selectedGroup)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
hs.hello.serverShare = keyShare{group: selectedGroup, data: params.PublicKey()}
hs.sharedKey = params.SharedKey(clientKeyShare.data)
if hs.sharedKey == nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid client key share")
}
c.serverName = hs.clientHello.serverName
return nil
}
func (hs *serverHandshakeStateTLS13) checkForResumption() error {
c := hs.c
if c.handshakeStatusAsync >= stateServerHandshake13CheckForResumption {
return nil
}
c.handshakeStatusAsync = stateServerHandshake13CheckForResumption
if c.config.SessionTicketsDisabled {
return nil
}
modeOK := false
for _, mode := range hs.clientHello.pskModes {
if mode == pskModeDHE {
modeOK = true
break
}
}
if !modeOK {
return nil
}
if len(hs.clientHello.pskIdentities) != len(hs.clientHello.pskBinders) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid or missing PSK binders")
}
if len(hs.clientHello.pskIdentities) == 0 {
return nil
}
for i, identity := range hs.clientHello.pskIdentities {
if i >= maxClientPSKIdentities {
break
}
plaintext, _ := c.decryptTicket(identity.label)
if plaintext == nil {
continue
}
sessionState := new(sessionStateTLS13)
if ok := sessionState.unmarshal(plaintext); !ok {
continue
}
createdAt := time.Unix(int64(sessionState.createdAt), 0)
if c.config.time().Sub(createdAt) > maxSessionTicketLifetime {
continue
}
// We don't check the obfuscated ticket age because it's affected by
// clock skew and it's only a freshness signal useful for shrinking the
// window for replay attacks, which don't affect us as we don't do 0-RTT.
pskSuite := cipherSuiteTLS13ByID(sessionState.cipherSuite)
if pskSuite == nil || pskSuite.hash != hs.suite.hash {
continue
}
// PSK connections don't re-establish client certificates, but carry
// them over in the session ticket. Ensure the presence of client certs
// in the ticket is consistent with the configured requirements.
sessionHasClientCerts := len(sessionState.certificate.Certificate) != 0
needClientCerts := requiresClientCert(c.config.ClientAuth)
if needClientCerts && !sessionHasClientCerts {
continue
}
if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
continue
}
psk := hs.suite.expandLabel(sessionState.resumptionSecret, "resumption",
nil, hs.suite.hash.Size())
hs.earlySecret = hs.suite.extract(psk, nil)
binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil)
// Clone the transcript in case a HelloRetryRequest was recorded.
transcript := cloneHash(hs.transcript, hs.suite.hash)
if transcript == nil {
c.sendAlert(alertInternalError)
return errors.New("tls: internal error: failed to clone hash")
}
transcript.Write(hs.clientHello.marshalWithoutBinders())
pskBinder := hs.suite.finishedHash(binderKey, transcript)
if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid PSK binder")
}
c.didResume = true
if err := c.processCertsFromClient(sessionState.certificate); err != nil {
return err
}
hs.hello.selectedIdentityPresent = true
hs.hello.selectedIdentity = uint16(i)
hs.usingPSK = true
return nil
}
return nil
}
// cloneHash uses the encoding.BinaryMarshaler and encoding.BinaryUnmarshaler
// interfaces implemented by standard library hashes to clone the state of in
// to a new instance of h. It returns nil if the operation fails.
func cloneHash(in hash.Hash, h crypto.Hash) hash.Hash {
// Recreate the interface to avoid importing encoding.
type binaryMarshaler interface {
MarshalBinary() (data []byte, err error)
UnmarshalBinary(data []byte) error
}
marshaler, ok := in.(binaryMarshaler)
if !ok {
return nil
}
state, err := marshaler.MarshalBinary()
if err != nil {
return nil
}
out := h.New()
unmarshaler, ok := out.(binaryMarshaler)
if !ok {
return nil
}
if err := unmarshaler.UnmarshalBinary(state); err != nil {
return nil
}
return out
}
func (hs *serverHandshakeStateTLS13) pickCertificate() error {
c := hs.c
if c.handshakeStatusAsync >= stateServerHandshake13PickCertificate {
return nil
}
c.handshakeStatusAsync = stateServerHandshake13PickCertificate
// Only one of PSK and certificates are used at a time.
if hs.usingPSK {
return nil
}
// signature_algorithms is required in TLS 1.3. See RFC 8446, Section 4.2.3.
if len(hs.clientHello.supportedSignatureAlgorithms) == 0 {
return c.sendAlert(alertMissingExtension)
}
certificate, err := c.config.getCertificate(clientHelloInfo(c, hs.clientHello))
if err != nil {
if err == errNoCertificates {
c.sendAlert(alertUnrecognizedName)
} else {
c.sendAlert(alertInternalError)
}
return err
}
hs.sigAlg, err = selectSignatureScheme(c.vers, certificate, hs.clientHello.supportedSignatureAlgorithms)
if err != nil {
// getCertificate returned a certificate that is unsupported or
// incompatible with the client's signature algorithms.
c.sendAlert(alertHandshakeFailure)
return err
}
hs.cert = certificate
return nil
}
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
if hs.sentDummyCCS {
return nil
}
hs.sentDummyCCS = true
_, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
return err
}
func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error {
c := hs.c
// The first ClientHello gets double-hashed into the transcript upon a
// HelloRetryRequest. See RFC 8446, Section 4.4.1.
hs.transcript.Write(hs.clientHello.marshal())
chHash := hs.transcript.Sum(nil)
hs.transcript.Reset()
hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
hs.transcript.Write(chHash)
helloRetryRequest := &serverHelloMsg{
vers: hs.hello.vers,
random: helloRetryRequestRandom,
sessionId: hs.hello.sessionId,
cipherSuite: hs.hello.cipherSuite,
compressionMethod: hs.hello.compressionMethod,
supportedVersion: hs.hello.supportedVersion,
selectedGroup: selectedGroup,
}
hs.transcript.Write(helloRetryRequest.marshal())
if _, err := c.writeRecord(recordTypeHandshake, helloRetryRequest.marshal()); err != nil {
return err
}
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
msg, err := c.readHandshake()
if err != nil {
return err
}
clientHello, ok := msg.(*clientHelloMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(clientHello, msg)
}
if len(clientHello.keyShares) != 1 || clientHello.keyShares[0].group != selectedGroup {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client sent invalid key share in second ClientHello")
}
if clientHello.earlyData {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client indicated early data in second ClientHello")
}
if illegalClientHelloChange(clientHello, hs.clientHello) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client illegally modified second ClientHello")
}
hs.clientHello = clientHello
return nil
}
// illegalClientHelloChange reports whether the two ClientHello messages are
// different, with the exception of the changes allowed before and after a
// HelloRetryRequest. See RFC 8446, Section 4.1.2.
func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool {
if len(ch.supportedVersions) != len(ch1.supportedVersions) ||
len(ch.cipherSuites) != len(ch1.cipherSuites) ||
len(ch.supportedCurves) != len(ch1.supportedCurves) ||
len(ch.supportedSignatureAlgorithms) != len(ch1.supportedSignatureAlgorithms) ||
len(ch.supportedSignatureAlgorithmsCert) != len(ch1.supportedSignatureAlgorithmsCert) ||
len(ch.alpnProtocols) != len(ch1.alpnProtocols) {
return true
}
for i := range ch.supportedVersions {
if ch.supportedVersions[i] != ch1.supportedVersions[i] {
return true
}
}
for i := range ch.cipherSuites {
if ch.cipherSuites[i] != ch1.cipherSuites[i] {
return true
}
}
for i := range ch.supportedCurves {
if ch.supportedCurves[i] != ch1.supportedCurves[i] {
return true
}
}
for i := range ch.supportedSignatureAlgorithms {
if ch.supportedSignatureAlgorithms[i] != ch1.supportedSignatureAlgorithms[i] {
return true
}
}
for i := range ch.supportedSignatureAlgorithmsCert {
if ch.supportedSignatureAlgorithmsCert[i] != ch1.supportedSignatureAlgorithmsCert[i] {
return true
}
}
for i := range ch.alpnProtocols {
if ch.alpnProtocols[i] != ch1.alpnProtocols[i] {
return true
}
}
return ch.vers != ch1.vers ||
!bytes.Equal(ch.random, ch1.random) ||
!bytes.Equal(ch.sessionId, ch1.sessionId) ||
!bytes.Equal(ch.compressionMethods, ch1.compressionMethods) ||
ch.serverName != ch1.serverName ||
ch.ocspStapling != ch1.ocspStapling ||
!bytes.Equal(ch.supportedPoints, ch1.supportedPoints) ||
ch.ticketSupported != ch1.ticketSupported ||
!bytes.Equal(ch.sessionTicket, ch1.sessionTicket) ||
ch.secureRenegotiationSupported != ch1.secureRenegotiationSupported ||
!bytes.Equal(ch.secureRenegotiation, ch1.secureRenegotiation) ||
ch.scts != ch1.scts ||
!bytes.Equal(ch.cookie, ch1.cookie) ||
!bytes.Equal(ch.pskModes, ch1.pskModes)
}
func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
c := hs.c
if c.handshakeStatusAsync >= stateServerHandshake13SendServerParameters {
return nil
}
c.handshakeStatusAsync = stateServerHandshake13SendServerParameters
hs.transcript.Write(hs.clientHello.marshal())
hs.transcript.Write(hs.hello.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
return err
}
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
earlySecret := hs.earlySecret
if earlySecret == nil {
earlySecret = hs.suite.extract(nil, nil)
}
hs.handshakeSecret = hs.suite.extract(hs.sharedKey,
hs.suite.deriveSecret(earlySecret, "derived", nil))
clientSecret := hs.suite.deriveSecret(hs.handshakeSecret,
clientHandshakeTrafficLabel, hs.transcript)
c.in.setTrafficSecret(hs.suite, clientSecret)
serverSecret := hs.suite.deriveSecret(hs.handshakeSecret,
serverHandshakeTrafficLabel, hs.transcript)
c.out.setTrafficSecret(hs.suite, serverSecret)
err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.clientHello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
encryptedExtensions := new(encryptedExtensionsMsg)
if len(hs.clientHello.alpnProtocols) > 0 {
if selectedProto := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); selectedProto != "" {
encryptedExtensions.alpnProtocol = selectedProto
c.clientProtocol = selectedProto
}
}
hs.transcript.Write(encryptedExtensions.marshal())
if _, err := c.writeRecord(recordTypeHandshake, encryptedExtensions.marshal()); err != nil {
return err
}
return nil
}
func (hs *serverHandshakeStateTLS13) requestClientCert() bool {
return hs.c.config.ClientAuth >= RequestClientCert && !hs.usingPSK
}
func (hs *serverHandshakeStateTLS13) sendServerCertificate() error {
c := hs.c
if c.handshakeStatusAsync >= stateServerHandshake13SendServerCertificate {
return nil
}
c.handshakeStatusAsync = stateServerHandshake13SendServerCertificate
// Only one of PSK and certificates are used at a time.
if hs.usingPSK {
return nil
}
if hs.requestClientCert() {
// Request a client certificate
certReq := new(certificateRequestMsgTLS13)
certReq.ocspStapling = true
certReq.scts = true
certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms
if c.config.ClientCAs != nil {
certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
}
hs.transcript.Write(certReq.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil {
return err
}
}
certMsg := new(certificateMsgTLS13)
certMsg.certificate = *hs.cert
certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0
certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0
hs.transcript.Write(certMsg.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
return err
}
certVerifyMsg := new(certificateVerifyMsg)
certVerifyMsg.hasSignatureAlgorithm = true
certVerifyMsg.signatureAlgorithm = hs.sigAlg
sigType, sigHash, err := typeAndHashFromSignatureScheme(hs.sigAlg)
if err != nil {
return c.sendAlert(alertInternalError)
}
signed := signedMessage(sigHash, serverSignatureContext, hs.transcript)
signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
sig, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts)
if err != nil {
public := hs.cert.PrivateKey.(crypto.Signer).Public()
if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS &&
rsaKey.N.BitLen()/8 < sigHash.Size()*2+2 { // key too small for RSA-PSS
c.sendAlert(alertHandshakeFailure)
} else {
c.sendAlert(alertInternalError)
}
return errors.New("tls: failed to sign handshake: " + err.Error())
}
certVerifyMsg.signature = sig
hs.transcript.Write(certVerifyMsg.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil {
return err
}
return nil
}
func (hs *serverHandshakeStateTLS13) sendServerFinished() error {
c := hs.c
if c.handshakeStatusAsync >= stateServerHandshake13SendServerFinished {
return nil
}
c.handshakeStatusAsync = stateServerHandshake13SendServerFinished
finished := &finishedMsg{
verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript),
}
hs.transcript.Write(finished.marshal())
if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
return err
}
// Derive secrets that take context through the server Finished.
hs.masterSecret = hs.suite.extract(nil,
hs.suite.deriveSecret(hs.handshakeSecret, "derived", nil))
hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret,
clientApplicationTrafficLabel, hs.transcript)
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
serverApplicationTrafficLabel, hs.transcript)
c.out.setTrafficSecret(hs.suite, serverSecret)
err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.clientHello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript)
// If we did not request client certificates, at this point we can
// precompute the client finished and roll the transcript forward to send
// session tickets in our first flight.
if !hs.requestClientCert() {
if err := hs.sendSessionTickets(); err != nil {
return err
}
}
return nil
}
func (hs *serverHandshakeStateTLS13) shouldSendSessionTickets() bool {
if hs.c.config.SessionTicketsDisabled {
return false
}
// Don't send tickets the client wouldn't use. See RFC 8446, Section 4.2.9.
for _, pskMode := range hs.clientHello.pskModes {
if pskMode == pskModeDHE {
return true
}
}
return false
}
func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
c := hs.c
hs.clientFinished = hs.suite.finishedHash(c.in.trafficSecret, hs.transcript)
finishedMsg := &finishedMsg{
verifyData: hs.clientFinished,
}
hs.transcript.Write(finishedMsg.marshal())
if !hs.shouldSendSessionTickets() {
return nil
}
resumptionSecret := hs.suite.deriveSecret(hs.masterSecret,
resumptionLabel, hs.transcript)
m := new(newSessionTicketMsgTLS13)
var certsFromClient [][]byte
for _, cert := range c.peerCertificates {
certsFromClient = append(certsFromClient, cert.Raw)
}
state := sessionStateTLS13{
cipherSuite: hs.suite.id,
createdAt: uint64(c.config.time().Unix()),
resumptionSecret: resumptionSecret,
certificate: Certificate{
Certificate: certsFromClient,
OCSPStaple: c.ocspResponse,
SignedCertificateTimestamps: c.scts,
},
}
var err error
m.label, err = c.encryptTicket(state.marshal())
if err != nil {
return err
}
m.lifetime = uint32(maxSessionTicketLifetime / time.Second)
if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil {
return err
}
return nil
}
func (hs *serverHandshakeStateTLS13) readClientCertificate() error {
c := hs.c
if c.handshakeStatusAsync >= stateServerHandshake13ReadClientCertificate {
return nil
}
if !hs.requestClientCert() {
// Make sure the connection is still being verified whether or not
// the server requested a client certificate.
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate
return nil
}
// If we requested a client certificate, then the client must send a
// certificate message. If it's empty, no CertificateVerify is sent.
if c.certMsg == nil {
msg, err := c.readHandshake()
if err != nil {
if err != errDataNotEnough {
c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate
}
return err
}
certMsg, ok := msg.(*certificateMsgTLS13)
if !ok {
c.sendAlert(alertUnexpectedMessage)
c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate
return unexpectedMessageError(certMsg, msg)
}
c.certMsg = certMsg
hs.transcript.Write(certMsg.marshal())
if err := c.processCertsFromClient(certMsg.certificate); err != nil {
c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate
return err
}
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate
return err
}
}
}
i := 0
if len(c.certMsg.certificate.Certificate) != 0 {
if len(c.certMsgVerified) == 0 {
c.certMsgVerified = make([]bool, len(c.certMsg.certificate.Certificate))
}
for ; i < len(c.certMsg.certificate.Certificate); i++ {
if c.certMsgVerified[i] {
} else {
msg, err := c.readHandshake()
if err != nil {
if err != errDataNotEnough {
c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate
}
return err
}
certVerify, ok := msg.(*certificateVerifyMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate
return unexpectedMessageError(certVerify, msg)
}
// See RFC 8446, Section 4.4.3.
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) {
c.sendAlert(alertIllegalParameter)
c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate
return errors.New("tls: client certificate used with invalid signature algorithm")
}
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
if err != nil {
c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate
return c.sendAlert(alertInternalError)
}
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
c.sendAlert(alertIllegalParameter)
c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate
return errors.New("tls: client certificate used with invalid signature algorithm")
}
signed := signedMessage(sigHash, clientSignatureContext, hs.transcript)
if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
sigHash, signed, certVerify.signature); err != nil {
c.sendAlert(alertDecryptError)
c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
}
hs.transcript.Write(certVerify.marshal())
c.certMsgVerified[i] = true
}
}
}
// If we waited until the client certificates to send session tickets, we
// are ready to do it now.
if i == len(c.certMsg.certificate.Certificate) {
if err := hs.sendSessionTickets(); err != nil {
c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate
return err
}
c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate
return nil
}
c.handshakeStatusAsync = stateServerHandshake13ReadClientCertificate
return nil
}
func (hs *serverHandshakeStateTLS13) readClientFinished() error {
c := hs.c
if c.handshakeStatusAsync >= stateServerHandshake13ReadClientFinished {
return nil
}
msg, err := c.readHandshake()
if err != nil {
if err != errDataNotEnough {
c.handshakeStatusAsync = stateServerHandshake13ReadClientFinished
}
return err
}
finished, ok := msg.(*finishedMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
c.handshakeStatusAsync = stateServerHandshake13ReadClientFinished
return unexpectedMessageError(finished, msg)
}
if !hmac.Equal(hs.clientFinished, finished.verifyData) {
c.sendAlert(alertDecryptError)
c.handshakeStatusAsync = stateServerHandshake13ReadClientFinished
return errors.New("tls: invalid client finished hash")
}
c.in.setTrafficSecret(hs.suite, hs.trafficSecret)
c.handshakeStatusAsync = stateServerHandshake13ReadClientFinished
return nil
}

View File

@ -0,0 +1,535 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bufio"
"crypto/ed25519"
"crypto/x509"
"encoding/hex"
"errors"
"flag"
"fmt"
"io"
"net"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"
)
// TLS reference tests run a connection against a reference implementation
// (OpenSSL) of TLS and record the bytes of the resulting connection. The Go
// code, during a test, is configured with deterministic randomness and so the
// reference test can be reproduced exactly in the future.
//
// In order to save everyone who wishes to run the tests from needing the
// reference implementation installed, the reference connections are saved in
// files in the testdata directory. Thus running the tests involves nothing
// external, but creating and updating them requires the reference
// implementation.
//
// Tests can be updated by running them with the -update flag. This will cause
// the test files for failing tests to be regenerated. Since the reference
// implementation will always generate fresh random numbers, large parts of the
// reference connection will always change.
var (
update = flag.Bool("update", false, "update golden files on failure")
fast = flag.Bool("fast", false, "impose a quick, possibly flaky timeout on recorded tests")
keyFile = flag.String("keylog", "", "destination file for KeyLogWriter")
)
func runTestAndUpdateIfNeeded(t *testing.T, name string, run func(t *testing.T, update bool), wait bool) {
success := t.Run(name, func(t *testing.T) {
if !*update && !wait {
t.Parallel()
}
run(t, false)
})
if !success && *update {
t.Run(name+"#update", func(t *testing.T) {
run(t, true)
})
}
}
// checkOpenSSLVersion ensures that the version of OpenSSL looks reasonable
// before updating the test data.
func checkOpenSSLVersion() error {
if !*update {
return nil
}
openssl := exec.Command("openssl", "version")
output, err := openssl.CombinedOutput()
if err != nil {
return err
}
version := string(output)
if strings.HasPrefix(version, "OpenSSL 1.1.1") {
return nil
}
println("***********************************************")
println("")
println("You need to build OpenSSL 1.1.1 from source in order")
println("to update the test data.")
println("")
println("Configure it with:")
println("./Configure enable-weak-ssl-ciphers no-shared")
println("and then add the apps/ directory at the front of your PATH.")
println("***********************************************")
return errors.New("version of OpenSSL does not appear to be suitable for updating test data")
}
// recordingConn is a net.Conn that records the traffic that passes through it.
// WriteTo can be used to produce output that can be later be loaded with
// ParseTestData.
type recordingConn struct {
net.Conn
sync.Mutex
flows [][]byte
reading bool
}
func (r *recordingConn) Read(b []byte) (n int, err error) {
if n, err = r.Conn.Read(b); n == 0 {
return
}
b = b[:n]
r.Lock()
defer r.Unlock()
if l := len(r.flows); l == 0 || !r.reading {
buf := make([]byte, len(b))
copy(buf, b)
r.flows = append(r.flows, buf)
} else {
r.flows[l-1] = append(r.flows[l-1], b[:n]...)
}
r.reading = true
return
}
func (r *recordingConn) Write(b []byte) (n int, err error) {
if n, err = r.Conn.Write(b); n == 0 {
return
}
b = b[:n]
r.Lock()
defer r.Unlock()
if l := len(r.flows); l == 0 || r.reading {
buf := make([]byte, len(b))
copy(buf, b)
r.flows = append(r.flows, buf)
} else {
r.flows[l-1] = append(r.flows[l-1], b[:n]...)
}
r.reading = false
return
}
// WriteTo writes Go source code to w that contains the recorded traffic.
func (r *recordingConn) WriteTo(w io.Writer) (int64, error) {
// TLS always starts with a client to server flow.
clientToServer := true
var written int64
for i, flow := range r.flows {
source, dest := "client", "server"
if !clientToServer {
source, dest = dest, source
}
n, err := fmt.Fprintf(w, ">>> Flow %d (%s to %s)\n", i+1, source, dest)
written += int64(n)
if err != nil {
return written, err
}
dumper := hex.Dumper(w)
n, err = dumper.Write(flow)
written += int64(n)
if err != nil {
return written, err
}
err = dumper.Close()
if err != nil {
return written, err
}
clientToServer = !clientToServer
}
return written, nil
}
func parseTestData(r io.Reader) (flows [][]byte, err error) {
var currentFlow []byte
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
// If the line starts with ">>> " then it marks the beginning
// of a new flow.
if strings.HasPrefix(line, ">>> ") {
if len(currentFlow) > 0 || len(flows) > 0 {
flows = append(flows, currentFlow)
currentFlow = nil
}
continue
}
// Otherwise the line is a line of hex dump that looks like:
// 00000170 fc f5 06 bf (...) |.....X{&?......!|
// (Some bytes have been omitted from the middle section.)
if i := strings.IndexByte(line, ' '); i >= 0 {
line = line[i:]
} else {
return nil, errors.New("invalid test data")
}
if i := strings.IndexByte(line, '|'); i >= 0 {
line = line[:i]
} else {
return nil, errors.New("invalid test data")
}
hexBytes := strings.Fields(line)
for _, hexByte := range hexBytes {
val, err := strconv.ParseUint(hexByte, 16, 8)
if err != nil {
return nil, errors.New("invalid hex byte in test data: " + err.Error())
}
currentFlow = append(currentFlow, byte(val))
}
}
if len(currentFlow) > 0 {
flows = append(flows, currentFlow)
}
return flows, nil
}
// tempFile creates a temp file containing contents and returns its path.
func tempFile(contents string) string {
file, err := os.CreateTemp("", "go-tls-test")
if err != nil {
panic("failed to create temp file: " + err.Error())
}
path := file.Name()
file.WriteString(contents)
file.Close()
return path
}
// localListener is set up by TestMain and used by localPipe to create Conn
// pairs like net.Pipe, but connected by an actual buffered TCP connection.
var localListener struct {
mu sync.Mutex
addr net.Addr
ch chan net.Conn
}
const localFlakes = 0 // change to 1 or 2 to exercise localServer/localPipe handling of mismatches
func localServer(l net.Listener) {
for n := 0; ; n++ {
c, err := l.Accept()
if err != nil {
return
}
if localFlakes == 1 && n%2 == 0 {
c.Close()
continue
}
localListener.ch <- c
}
}
var isConnRefused = func(err error) bool { return false }
func localPipe(t testing.TB) (net.Conn, net.Conn) {
localListener.mu.Lock()
defer localListener.mu.Unlock()
addr := localListener.addr
var err error
Dialing:
// We expect a rare mismatch, but probably not 5 in a row.
for i := 0; i < 5; i++ {
tooSlow := time.NewTimer(1 * time.Second)
defer tooSlow.Stop()
var c1 net.Conn
c1, err = net.Dial(addr.Network(), addr.String())
if err != nil {
if runtime.GOOS == "dragonfly" && (isConnRefused(err) || os.IsTimeout(err)) {
// golang.org/issue/29583: Dragonfly sometimes returns a spurious
// ECONNREFUSED or ETIMEDOUT.
<-tooSlow.C
continue
}
t.Fatalf("localPipe: %v", err)
}
if localFlakes == 2 && i == 0 {
c1.Close()
continue
}
for {
select {
case <-tooSlow.C:
t.Logf("localPipe: timeout waiting for %v", c1.LocalAddr())
c1.Close()
continue Dialing
case c2 := <-localListener.ch:
if c2.RemoteAddr().String() == c1.LocalAddr().String() {
return c1, c2
}
t.Logf("localPipe: unexpected connection: %v != %v", c2.RemoteAddr(), c1.LocalAddr())
c2.Close()
}
}
}
t.Fatalf("localPipe: failed to connect: %v", err)
panic("unreachable")
}
// zeroSource is an io.Reader that returns an unlimited number of zero bytes.
type zeroSource struct{}
func (zeroSource) Read(b []byte) (n int, err error) {
for i := range b {
b[i] = 0
}
return len(b), nil
}
func allCipherSuites() []uint16 {
ids := make([]uint16, len(cipherSuites))
for i, suite := range cipherSuites {
ids[i] = suite.id
}
return ids
}
var testConfig *Config
func TestMain(m *testing.M) {
flag.Parse()
os.Exit(runMain(m))
}
func runMain(m *testing.M) int {
// TLS 1.3 cipher suites preferences are not configurable and change based
// on the architecture. Force them to the version with AES acceleration for
// test consistency.
once.Do(initDefaultCipherSuites)
varDefaultCipherSuitesTLS13 = []uint16{
TLS_AES_128_GCM_SHA256,
TLS_CHACHA20_POLY1305_SHA256,
TLS_AES_256_GCM_SHA384,
}
// Set up localPipe.
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
l, err = net.Listen("tcp6", "[::1]:0")
}
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to open local listener: %v", err)
os.Exit(1)
}
localListener.ch = make(chan net.Conn)
localListener.addr = l.Addr()
defer l.Close()
go localServer(l)
if err := checkOpenSSLVersion(); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v", err)
os.Exit(1)
}
testConfig = &Config{
Time: func() time.Time { return time.Unix(0, 0) },
Rand: zeroSource{},
Certificates: make([]Certificate, 2),
InsecureSkipVerify: true,
CipherSuites: allCipherSuites(),
}
testConfig.Certificates[0].Certificate = [][]byte{testRSACertificate}
testConfig.Certificates[0].PrivateKey = testRSAPrivateKey
testConfig.Certificates[1].Certificate = [][]byte{testSNICertificate}
testConfig.Certificates[1].PrivateKey = testRSAPrivateKey
testConfig.BuildNameToCertificate()
if *keyFile != "" {
f, err := os.OpenFile(*keyFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
panic("failed to open -keylog file: " + err.Error())
}
testConfig.KeyLogWriter = f
defer f.Close()
}
return m.Run()
}
func testHandshake(t *testing.T, clientConfig, serverConfig *Config) (serverState, clientState ConnectionState, err error) {
const sentinel = "SENTINEL\n"
c, s := localPipe(t)
errChan := make(chan error)
go func() {
cli := Client(c, clientConfig)
err := cli.Handshake()
if err != nil {
errChan <- fmt.Errorf("client: %v", err)
c.Close()
return
}
defer cli.Close()
clientState = cli.ConnectionState()
buf, err := io.ReadAll(cli)
if err != nil {
t.Errorf("failed to call cli.Read: %v", err)
}
if got := string(buf); got != sentinel {
t.Errorf("read %q from TLS connection, but expected %q", got, sentinel)
}
errChan <- nil
}()
server := Server(s, serverConfig)
err = server.Handshake()
if err == nil {
serverState = server.ConnectionState()
if _, err := io.WriteString(server, sentinel); err != nil {
t.Errorf("failed to call server.Write: %v", err)
}
if err := server.Close(); err != nil {
t.Errorf("failed to call server.Close: %v", err)
}
err = <-errChan
} else {
s.Close()
<-errChan
}
return
}
func fromHex(s string) []byte {
b, _ := hex.DecodeString(s)
return b
}
var testRSACertificate = fromHex("3082024b308201b4a003020102020900e8f09d3fe25beaa6300d06092a864886f70d01010b0500301f310b3009060355040a1302476f3110300e06035504031307476f20526f6f74301e170d3136303130313030303030305a170d3235303130313030303030305a301a310b3009060355040a1302476f310b300906035504031302476f30819f300d06092a864886f70d010101050003818d0030818902818100db467d932e12270648bc062821ab7ec4b6a25dfe1e5245887a3647a5080d92425bc281c0be97799840fb4f6d14fd2b138bc2a52e67d8d4099ed62238b74a0b74732bc234f1d193e596d9747bf3589f6c613cc0b041d4d92b2b2423775b1c3bbd755dce2054cfa163871d1e24c4f31d1a508baab61443ed97a77562f414c852d70203010001a38193308190300e0603551d0f0101ff0404030205a0301d0603551d250416301406082b0601050507030106082b06010505070302300c0603551d130101ff0402300030190603551d0e041204109f91161f43433e49a6de6db680d79f60301b0603551d230414301280104813494d137e1631bba301d5acab6e7b30190603551d1104123010820e6578616d706c652e676f6c616e67300d06092a864886f70d01010b0500038181009d30cc402b5b50a061cbbae55358e1ed8328a9581aa938a495a1ac315a1a84663d43d32dd90bf297dfd320643892243a00bccf9c7db74020015faad3166109a276fd13c3cce10c5ceeb18782f16c04ed73bbb343778d0c1cf10fa1d8408361c94c722b9daedb4606064df4c1b33ec0d1bd42d4dbfe3d1360845c21d33be9fae7")
var testRSACertificateIssuer = fromHex("3082021930820182a003020102020900ca5e4e811a965964300d06092a864886f70d01010b0500301f310b3009060355040a1302476f3110300e06035504031307476f20526f6f74301e170d3136303130313030303030305a170d3235303130313030303030305a301f310b3009060355040a1302476f3110300e06035504031307476f20526f6f7430819f300d06092a864886f70d010101050003818d0030818902818100d667b378bb22f34143b6cd2008236abefaf2852adf3ab05e01329e2c14834f5105df3f3073f99dab5442d45ee5f8f57b0111c8cb682fbb719a86944eebfffef3406206d898b8c1b1887797c9c5006547bb8f00e694b7a063f10839f269f2c34fff7a1f4b21fbcd6bfdfb13ac792d1d11f277b5c5b48600992203059f2a8f8cc50203010001a35d305b300e0603551d0f0101ff040403020204301d0603551d250416301406082b0601050507030106082b06010505070302300f0603551d130101ff040530030101ff30190603551d0e041204104813494d137e1631bba301d5acab6e7b300d06092a864886f70d01010b050003818100c1154b4bab5266221f293766ae4138899bd4c5e36b13cee670ceeaa4cbdf4f6679017e2fe649765af545749fe4249418a56bd38a04b81e261f5ce86b8d5c65413156a50d12449554748c59a30c515bc36a59d38bddf51173e899820b282e40aa78c806526fd184fb6b4cf186ec728edffa585440d2b3225325f7ab580e87dd76")
// testRSAPSSCertificate has signatureAlgorithm rsassaPss, but subjectPublicKeyInfo
// algorithm rsaEncryption, for use with the rsa_pss_rsae_* SignatureSchemes.
// See also TestRSAPSSKeyError. testRSAPSSCertificate is self-signed.
var testRSAPSSCertificate = fromHex("308202583082018da003020102021100f29926eb87ea8a0db9fcc247347c11b0304106092a864886f70d01010a3034a00f300d06096086480165030402010500a11c301a06092a864886f70d010108300d06096086480165030402010500a20302012030123110300e060355040a130741636d6520436f301e170d3137313132333136313631305a170d3138313132333136313631305a30123110300e060355040a130741636d6520436f30819f300d06092a864886f70d010101050003818d0030818902818100db467d932e12270648bc062821ab7ec4b6a25dfe1e5245887a3647a5080d92425bc281c0be97799840fb4f6d14fd2b138bc2a52e67d8d4099ed62238b74a0b74732bc234f1d193e596d9747bf3589f6c613cc0b041d4d92b2b2423775b1c3bbd755dce2054cfa163871d1e24c4f31d1a508baab61443ed97a77562f414c852d70203010001a3463044300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff04023000300f0603551d110408300687047f000001304106092a864886f70d01010a3034a00f300d06096086480165030402010500a11c301a06092a864886f70d010108300d06096086480165030402010500a20302012003818100cdac4ef2ce5f8d79881042707f7cbf1b5a8a00ef19154b40151771006cd41626e5496d56da0c1a139fd84695593cb67f87765e18aa03ea067522dd78d2a589b8c92364e12838ce346c6e067b51f1a7e6f4b37ffab13f1411896679d18e880e0ba09e302ac067efca460288e9538122692297ad8093d4f7dd701424d7700a46a1")
var testECDSACertificate = fromHex("3082020030820162020900b8bf2d47a0d2ebf4300906072a8648ce3d04013045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464301e170d3132313132323135303633325a170d3232313132303135303633325a3045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c746430819b301006072a8648ce3d020106052b81040023038186000400c4a1edbe98f90b4873367ec316561122f23d53c33b4d213dcd6b75e6f6b0dc9adf26c1bcb287f072327cb3642f1c90bcea6823107efee325c0483a69e0286dd33700ef0462dd0da09c706283d881d36431aa9e9731bd96b068c09b23de76643f1a5c7fe9120e5858b65f70dd9bd8ead5d7f5d5ccb9b69f30665b669a20e227e5bffe3b300906072a8648ce3d040103818c0030818802420188a24febe245c5487d1bacf5ed989dae4770c05e1bb62fbdf1b64db76140d311a2ceee0b7e927eff769dc33b7ea53fcefa10e259ec472d7cacda4e970e15a06fd00242014dfcbe67139c2d050ebd3fa38c25c13313830d9406bbd4377af6ec7ac9862eddd711697f857c56defb31782be4c7780daecbbe9e4e3624317b6a0f399512078f2a")
var testEd25519Certificate = fromHex("3082012e3081e1a00302010202100f431c425793941de987e4f1ad15005d300506032b657030123110300e060355040a130741636d6520436f301e170d3139303531363231333830315a170d3230303531353231333830315a30123110300e060355040a130741636d6520436f302a300506032b65700321003fe2152ee6e3ef3f4e854a7577a3649eede0bf842ccc92268ffa6f3483aaec8fa34d304b300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff0402300030160603551d11040f300d820b6578616d706c652e636f6d300506032b65700341006344ed9cc4be5324539fd2108d9fe82108909539e50dc155ff2c16b71dfcab7d4dd4e09313d0a942e0b66bfe5d6748d79f50bc6ccd4b03837cf20858cdaccf0c")
var testSNICertificate = fromHex("0441883421114c81480804c430820237308201a0a003020102020900e8f09d3fe25beaa6300d06092a864886f70d01010b0500301f310b3009060355040a1302476f3110300e06035504031307476f20526f6f74301e170d3136303130313030303030305a170d3235303130313030303030305a3023310b3009060355040a1302476f311430120603550403130b736e69746573742e636f6d30819f300d06092a864886f70d010101050003818d0030818902818100db467d932e12270648bc062821ab7ec4b6a25dfe1e5245887a3647a5080d92425bc281c0be97799840fb4f6d14fd2b138bc2a52e67d8d4099ed62238b74a0b74732bc234f1d193e596d9747bf3589f6c613cc0b041d4d92b2b2423775b1c3bbd755dce2054cfa163871d1e24c4f31d1a508baab61443ed97a77562f414c852d70203010001a3773075300e0603551d0f0101ff0404030205a0301d0603551d250416301406082b0601050507030106082b06010505070302300c0603551d130101ff0402300030190603551d0e041204109f91161f43433e49a6de6db680d79f60301b0603551d230414301280104813494d137e1631bba301d5acab6e7b300d06092a864886f70d01010b0500038181007beeecff0230dbb2e7a334af65430b7116e09f327c3bbf918107fc9c66cb497493207ae9b4dbb045cb63d605ec1b5dd485bb69124d68fa298dc776699b47632fd6d73cab57042acb26f083c4087459bc5a3bb3ca4d878d7fe31016b7bc9a627438666566e3389bfaeebe6becc9a0093ceed18d0f9ac79d56f3a73f18188988ed")
var testP256Certificate = fromHex("308201693082010ea00302010202105012dc24e1124ade4f3e153326ff27bf300a06082a8648ce3d04030230123110300e060355040a130741636d6520436f301e170d3137303533313232343934375a170d3138303533313232343934375a30123110300e060355040a130741636d6520436f3059301306072a8648ce3d020106082a8648ce3d03010703420004c02c61c9b16283bbcc14956d886d79b358aa614596975f78cece787146abf74c2d5dc578c0992b4f3c631373479ebf3892efe53d21c4f4f1cc9a11c3536b7f75a3463044300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff04023000300f0603551d1104083006820474657374300a06082a8648ce3d0403020349003046022100963712d6226c7b2bef41512d47e1434131aaca3ba585d666c924df71ac0448b3022100f4d05c725064741aef125f243cdbccaa2a5d485927831f221c43023bd5ae471a")
var testRSAPrivateKey, _ = x509.ParsePKCS1PrivateKey(fromHex("3082025b02010002818100db467d932e12270648bc062821ab7ec4b6a25dfe1e5245887a3647a5080d92425bc281c0be97799840fb4f6d14fd2b138bc2a52e67d8d4099ed62238b74a0b74732bc234f1d193e596d9747bf3589f6c613cc0b041d4d92b2b2423775b1c3bbd755dce2054cfa163871d1e24c4f31d1a508baab61443ed97a77562f414c852d702030100010281800b07fbcf48b50f1388db34b016298b8217f2092a7c9a04f77db6775a3d1279b62ee9951f7e371e9de33f015aea80660760b3951dc589a9f925ed7de13e8f520e1ccbc7498ce78e7fab6d59582c2386cc07ed688212a576ff37833bd5943483b5554d15a0b9b4010ed9bf09f207e7e9805f649240ed6c1256ed75ab7cd56d9671024100fded810da442775f5923debae4ac758390a032a16598d62f059bb2e781a9c2f41bfa015c209f966513fe3bf5a58717cbdb385100de914f88d649b7d15309fa49024100dd10978c623463a1802c52f012cfa72ff5d901f25a2292446552c2568b1840e49a312e127217c2186615aae4fb6602a4f6ebf3f3d160f3b3ad04c592f65ae41f02400c69062ca781841a09de41ed7a6d9f54adc5d693a2c6847949d9e1358555c9ac6a8d9e71653ac77beb2d3abaf7bb1183aa14278956575dbebf525d0482fd72d90240560fe1900ba36dae3022115fd952f2399fb28e2975a1c3e3d0b679660bdcb356cc189d611cfdd6d87cd5aea45aa30a2082e8b51e94c2f3dd5d5c6036a8a615ed0240143993d80ece56f877cb80048335701eb0e608cc0c1ca8c2227b52edf8f1ac99c562f2541b5ce81f0515af1c5b4770dba53383964b4b725ff46fdec3d08907df"))
var testECDSAPrivateKey, _ = x509.ParseECPrivateKey(fromHex("3081dc0201010442019883e909ad0ac9ea3d33f9eae661f1785206970f8ca9a91672f1eedca7a8ef12bd6561bb246dda5df4b4d5e7e3a92649bc5d83a0bf92972e00e62067d0c7bd99d7a00706052b81040023a18189038186000400c4a1edbe98f90b4873367ec316561122f23d53c33b4d213dcd6b75e6f6b0dc9adf26c1bcb287f072327cb3642f1c90bcea6823107efee325c0483a69e0286dd33700ef0462dd0da09c706283d881d36431aa9e9731bd96b068c09b23de76643f1a5c7fe9120e5858b65f70dd9bd8ead5d7f5d5ccb9b69f30665b669a20e227e5bffe3b"))
var testP256PrivateKey, _ = x509.ParseECPrivateKey(fromHex("30770201010420012f3b52bc54c36ba3577ad45034e2e8efe1e6999851284cb848725cfe029991a00a06082a8648ce3d030107a14403420004c02c61c9b16283bbcc14956d886d79b358aa614596975f78cece787146abf74c2d5dc578c0992b4f3c631373479ebf3892efe53d21c4f4f1cc9a11c3536b7f75"))
var testEd25519PrivateKey = ed25519.PrivateKey(fromHex("3a884965e76b3f55e5faf9615458a92354894234de3ec9f684d46d55cebf3dc63fe2152ee6e3ef3f4e854a7577a3649eede0bf842ccc92268ffa6f3483aaec8f"))
const clientCertificatePEM = `
-----BEGIN CERTIFICATE-----
MIIB7zCCAVigAwIBAgIQXBnBiWWDVW/cC8m5k5/pvDANBgkqhkiG9w0BAQsFADAS
MRAwDgYDVQQKEwdBY21lIENvMB4XDTE2MDgxNzIxNTIzMVoXDTE3MDgxNzIxNTIz
MVowEjEQMA4GA1UEChMHQWNtZSBDbzCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkC
gYEAum+qhr3Pv5/y71yUYHhv6BPy0ZZvzdkybiI3zkH5yl0prOEn2mGi7oHLEMff
NFiVhuk9GeZcJ3NgyI14AvQdpJgJoxlwaTwlYmYqqyIjxXuFOE8uCXMyp70+m63K
hAfmDzr/d8WdQYUAirab7rCkPy1MTOZCPrtRyN1IVPQMjkcCAwEAAaNGMEQwDgYD
VR0PAQH/BAQDAgWgMBMGA1UdJQQMMAoGCCsGAQUFBwMBMAwGA1UdEwEB/wQCMAAw
DwYDVR0RBAgwBocEfwAAATANBgkqhkiG9w0BAQsFAAOBgQBGq0Si+yhU+Fpn+GKU
8ZqyGJ7ysd4dfm92lam6512oFmyc9wnTN+RLKzZ8Aa1B0jLYw9KT+RBrjpW5LBeK
o0RIvFkTgxYEiKSBXCUNmAysEbEoVr4dzWFihAm/1oDGRY2CLLTYg5vbySK3KhIR
e/oCO8HJ/+rJnahJ05XX1Q7lNQ==
-----END CERTIFICATE-----`
var clientKeyPEM = testingKey(`
-----BEGIN RSA TESTING KEY-----
MIICXQIBAAKBgQC6b6qGvc+/n/LvXJRgeG/oE/LRlm/N2TJuIjfOQfnKXSms4Sfa
YaLugcsQx980WJWG6T0Z5lwnc2DIjXgC9B2kmAmjGXBpPCViZiqrIiPFe4U4Ty4J
czKnvT6brcqEB+YPOv93xZ1BhQCKtpvusKQ/LUxM5kI+u1HI3UhU9AyORwIDAQAB
AoGAEJZ03q4uuMb7b26WSQsOMeDsftdatT747LGgs3pNRkMJvTb/O7/qJjxoG+Mc
qeSj0TAZXp+PXXc3ikCECAc+R8rVMfWdmp903XgO/qYtmZGCorxAHEmR80SrfMXv
PJnznLQWc8U9nphQErR+tTESg7xWEzmFcPKwnZd1xg8ERYkCQQDTGtrFczlB2b/Z
9TjNMqUlMnTLIk/a/rPE2fLLmAYhK5sHnJdvDURaH2mF4nso0EGtENnTsh6LATnY
dkrxXGm9AkEA4hXHG2q3MnhgK1Z5hjv+Fnqd+8bcbII9WW4flFs15EKoMgS1w/PJ
zbsySaSy5IVS8XeShmT9+3lrleed4sy+UwJBAJOOAbxhfXP5r4+5R6ql66jES75w
jUCVJzJA5ORJrn8g64u2eGK28z/LFQbv9wXgCwfc72R468BdawFSLa/m2EECQGbZ
rWiFla26IVXV0xcD98VWJsTBZMlgPnSOqoMdM1kSEd4fUmlAYI/dFzV1XYSkOmVr
FhdZnklmpVDeu27P4c0CQQCuCOup0FlJSBpWY1TTfun/KMBkBatMz0VMA3d7FKIU
csPezl677Yjo8u1r/KzeI6zLg87Z8E6r6ZWNc9wBSZK6
-----END RSA TESTING KEY-----`)
const clientECDSACertificatePEM = `
-----BEGIN CERTIFICATE-----
MIIB/DCCAV4CCQCaMIRsJjXZFzAJBgcqhkjOPQQBMEUxCzAJBgNVBAYTAkFVMRMw
EQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0
eSBMdGQwHhcNMTIxMTE0MTMyNTUzWhcNMjIxMTEyMTMyNTUzWjBBMQswCQYDVQQG
EwJBVTEMMAoGA1UECBMDTlNXMRAwDgYDVQQHEwdQeXJtb250MRIwEAYDVQQDEwlK
b2VsIFNpbmcwgZswEAYHKoZIzj0CAQYFK4EEACMDgYYABACVjJF1FMBexFe01MNv
ja5oHt1vzobhfm6ySD6B5U7ixohLZNz1MLvT/2XMW/TdtWo+PtAd3kfDdq0Z9kUs
jLzYHQFMH3CQRnZIi4+DzEpcj0B22uCJ7B0rxE4wdihBsmKo+1vx+U56jb0JuK7q
ixgnTy5w/hOWusPTQBbNZU6sER7m8TAJBgcqhkjOPQQBA4GMADCBiAJCAOAUxGBg
C3JosDJdYUoCdFzCgbkWqD8pyDbHgf9stlvZcPE4O1BIKJTLCRpS8V3ujfK58PDa
2RU6+b0DeoeiIzXsAkIBo9SKeDUcSpoj0gq+KxAxnZxfvuiRs9oa9V2jI/Umi0Vw
jWVim34BmT0Y9hCaOGGbLlfk+syxis7iI6CH8OFnUes=
-----END CERTIFICATE-----`
var clientECDSAKeyPEM = testingKey(`
-----BEGIN EC PARAMETERS-----
BgUrgQQAIw==
-----END EC PARAMETERS-----
-----BEGIN EC TESTING KEY-----
MIHcAgEBBEIBkJN9X4IqZIguiEVKMqeBUP5xtRsEv4HJEtOpOGLELwO53SD78Ew8
k+wLWoqizS3NpQyMtrU8JFdWfj+C57UNkOugBwYFK4EEACOhgYkDgYYABACVjJF1
FMBexFe01MNvja5oHt1vzobhfm6ySD6B5U7ixohLZNz1MLvT/2XMW/TdtWo+PtAd
3kfDdq0Z9kUsjLzYHQFMH3CQRnZIi4+DzEpcj0B22uCJ7B0rxE4wdihBsmKo+1vx
+U56jb0JuK7qixgnTy5w/hOWusPTQBbNZU6sER7m8Q==
-----END EC TESTING KEY-----`)
const clientEd25519CertificatePEM = `
-----BEGIN CERTIFICATE-----
MIIBLjCB4aADAgECAhAX0YGTviqMISAQJRXoNCNPMAUGAytlcDASMRAwDgYDVQQK
EwdBY21lIENvMB4XDTE5MDUxNjIxNTQyNloXDTIwMDUxNTIxNTQyNlowEjEQMA4G
A1UEChMHQWNtZSBDbzAqMAUGAytlcAMhAAvgtWC14nkwPb7jHuBQsQTIbcd4bGkv
xRStmmNveRKRo00wSzAOBgNVHQ8BAf8EBAMCBaAwEwYDVR0lBAwwCgYIKwYBBQUH
AwIwDAYDVR0TAQH/BAIwADAWBgNVHREEDzANggtleGFtcGxlLmNvbTAFBgMrZXAD
QQD8GRcqlKUx+inILn9boF2KTjRAOdazENwZ/qAicbP1j6FYDc308YUkv+Y9FN/f
7Q7hF9gRomDQijcjKsJGqjoI
-----END CERTIFICATE-----`
var clientEd25519KeyPEM = testingKey(`
-----BEGIN TESTING KEY-----
MC4CAQAwBQYDK2VwBCIEINifzf07d9qx3d44e0FSbV4mC/xQxT644RRbpgNpin7I
-----END TESTING KEY-----`)

View File

@ -0,0 +1,18 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package tls
import (
"errors"
"syscall"
)
func init() {
isConnRefused = func(err error) bool {
return errors.Is(err, syscall.ECONNREFUSED)
}
}

View File

@ -0,0 +1,334 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"crypto"
"crypto/md5"
"crypto/rsa"
"crypto/sha1"
"crypto/x509"
"errors"
"fmt"
"io"
)
var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message")
var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message")
// rsaKeyAgreement implements the standard TLS key agreement where the client
// encrypts the pre-master secret to the server's public key.
type rsaKeyAgreement struct{}
func (ka rsaKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
return nil, nil
}
func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
if len(ckx.ciphertext) < 2 {
return nil, errClientKeyExchange
}
ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1])
if ciphertextLen != len(ckx.ciphertext)-2 {
return nil, errClientKeyExchange
}
ciphertext := ckx.ciphertext[2:]
priv, ok := cert.PrivateKey.(crypto.Decrypter)
if !ok {
return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter")
}
// Perform constant time RSA PKCS #1 v1.5 decryption
preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48})
if err != nil {
return nil, err
}
// We don't check the version number in the premaster secret. For one,
// by checking it, we would leak information about the validity of the
// encrypted pre-master secret. Secondly, it provides only a small
// benefit against a downgrade attack and some implementations send the
// wrong version anyway. See the discussion at the end of section
// 7.4.7.1 of RFC 4346.
return preMasterSecret, nil
}
func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
return errors.New("tls: unexpected ServerKeyExchange")
}
func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
preMasterSecret := make([]byte, 48)
preMasterSecret[0] = byte(clientHello.vers >> 8)
preMasterSecret[1] = byte(clientHello.vers)
_, err := io.ReadFull(config.rand(), preMasterSecret[2:])
if err != nil {
return nil, nil, err
}
encrypted, err := rsa.EncryptPKCS1v15(config.rand(), cert.PublicKey.(*rsa.PublicKey), preMasterSecret)
if err != nil {
return nil, nil, err
}
ckx := new(clientKeyExchangeMsg)
ckx.ciphertext = make([]byte, len(encrypted)+2)
ckx.ciphertext[0] = byte(len(encrypted) >> 8)
ckx.ciphertext[1] = byte(len(encrypted))
copy(ckx.ciphertext[2:], encrypted)
return preMasterSecret, ckx, nil
}
// sha1Hash calculates a SHA1 hash over the given byte slices.
func sha1Hash(slices [][]byte) []byte {
hsha1 := sha1.New()
for _, slice := range slices {
hsha1.Write(slice)
}
return hsha1.Sum(nil)
}
// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the
// concatenation of an MD5 and SHA1 hash.
func md5SHA1Hash(slices [][]byte) []byte {
md5sha1 := make([]byte, md5.Size+sha1.Size)
hmd5 := md5.New()
for _, slice := range slices {
hmd5.Write(slice)
}
copy(md5sha1, hmd5.Sum(nil))
copy(md5sha1[md5.Size:], sha1Hash(slices))
return md5sha1
}
// hashForServerKeyExchange hashes the given slices and returns their digest
// using the given hash function (for >= TLS 1.2) or using a default based on
// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't
// do pre-hashing, it returns the concatenation of the slices.
func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte {
if sigType == signatureEd25519 {
var signed []byte
for _, slice := range slices {
signed = append(signed, slice...)
}
return signed
}
if version >= VersionTLS12 {
h := hashFunc.New()
for _, slice := range slices {
h.Write(slice)
}
digest := h.Sum(nil)
return digest
}
if sigType == signatureECDSA {
return sha1Hash(slices)
}
return md5SHA1Hash(slices)
}
// ecdheKeyAgreement implements a TLS key agreement where the server
// generates an ephemeral EC public/private key pair and signs it. The
// pre-master secret is then calculated using ECDH. The signature may
// be ECDSA, Ed25519 or RSA.
type ecdheKeyAgreement struct {
version uint16
isRSA bool
params ecdheParameters
// ckx and preMasterSecret are generated in processServerKeyExchange
// and returned in generateClientKeyExchange.
ckx *clientKeyExchangeMsg
preMasterSecret []byte
}
func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
var curveID CurveID
for _, c := range clientHello.supportedCurves {
if config.supportsCurve(c) {
curveID = c
break
}
}
if curveID == 0 {
return nil, errors.New("tls: no supported elliptic curves offered")
}
if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok {
return nil, errors.New("tls: CurvePreferences includes unsupported curve")
}
params, err := generateECDHEParameters(config.rand(), curveID)
if err != nil {
return nil, err
}
ka.params = params
// See RFC 4492, Section 5.4.
ecdhePublic := params.PublicKey()
serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic))
serverECDHEParams[0] = 3 // named curve
serverECDHEParams[1] = byte(curveID >> 8)
serverECDHEParams[2] = byte(curveID)
serverECDHEParams[3] = byte(len(ecdhePublic))
copy(serverECDHEParams[4:], ecdhePublic)
priv, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey)
}
var signatureAlgorithm SignatureScheme
var sigType uint8
var sigHash crypto.Hash
if ka.version >= VersionTLS12 {
signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms)
if err != nil {
return nil, err
}
sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm)
if err != nil {
return nil, err
}
} else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(priv.Public())
if err != nil {
return nil, err
}
}
if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
return nil, errors.New("tls: certificate cannot be used with the selected cipher suite")
}
signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, hello.random, serverECDHEParams)
signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
sig, err := priv.Sign(config.rand(), signed, signOpts)
if err != nil {
return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error())
}
skx := new(serverKeyExchangeMsg)
sigAndHashLen := 0
if ka.version >= VersionTLS12 {
sigAndHashLen = 2
}
skx.key = make([]byte, len(serverECDHEParams)+sigAndHashLen+2+len(sig))
copy(skx.key, serverECDHEParams)
k := skx.key[len(serverECDHEParams):]
if ka.version >= VersionTLS12 {
k[0] = byte(signatureAlgorithm >> 8)
k[1] = byte(signatureAlgorithm)
k = k[2:]
}
k[0] = byte(len(sig) >> 8)
k[1] = byte(len(sig))
copy(k[2:], sig)
return skx, nil
}
func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
return nil, errClientKeyExchange
}
preMasterSecret := ka.params.SharedKey(ckx.ciphertext[1:])
if preMasterSecret == nil {
return nil, errClientKeyExchange
}
return preMasterSecret, nil
}
func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
if len(skx.key) < 4 {
return errServerKeyExchange
}
if skx.key[0] != 3 { // named curve
return errors.New("tls: server selected unsupported curve")
}
curveID := CurveID(skx.key[1])<<8 | CurveID(skx.key[2])
publicLen := int(skx.key[3])
if publicLen+4 > len(skx.key) {
return errServerKeyExchange
}
serverECDHEParams := skx.key[:4+publicLen]
publicKey := serverECDHEParams[4:]
sig := skx.key[4+publicLen:]
if len(sig) < 2 {
return errServerKeyExchange
}
if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok {
return errors.New("tls: server selected unsupported curve")
}
params, err := generateECDHEParameters(config.rand(), curveID)
if err != nil {
return err
}
ka.params = params
ka.preMasterSecret = params.SharedKey(publicKey)
if ka.preMasterSecret == nil {
return errServerKeyExchange
}
ourPublicKey := params.PublicKey()
ka.ckx = new(clientKeyExchangeMsg)
ka.ckx.ciphertext = make([]byte, 1+len(ourPublicKey))
ka.ckx.ciphertext[0] = byte(len(ourPublicKey))
copy(ka.ckx.ciphertext[1:], ourPublicKey)
var sigType uint8
var sigHash crypto.Hash
if ka.version >= VersionTLS12 {
signatureAlgorithm := SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1])
sig = sig[2:]
if len(sig) < 2 {
return errServerKeyExchange
}
if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) {
return errors.New("tls: certificate used with invalid signature algorithm")
}
sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm)
if err != nil {
return err
}
} else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey)
if err != nil {
return err
}
}
if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
return errServerKeyExchange
}
sigLen := int(sig[0])<<8 | int(sig[1])
if sigLen+2 != len(sig) {
return errServerKeyExchange
}
sig = sig[2:]
signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, serverHello.random, serverECDHEParams)
if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil {
return errors.New("tls: invalid signature by the server certificate: " + err.Error())
}
return nil
}
func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
if ka.ckx == nil {
return nil, nil, errors.New("tls: missing ServerKeyExchange message")
}
return ka.preMasterSecret, ka.ckx, nil
}

View File

@ -0,0 +1,199 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"crypto/elliptic"
"crypto/hmac"
"errors"
"hash"
"io"
"math/big"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/hkdf"
)
// This file contains the functions necessary to compute the TLS 1.3 key
// schedule. See RFC 8446, Section 7.
const (
resumptionBinderLabel = "res binder"
clientHandshakeTrafficLabel = "c hs traffic"
serverHandshakeTrafficLabel = "s hs traffic"
clientApplicationTrafficLabel = "c ap traffic"
serverApplicationTrafficLabel = "s ap traffic"
exporterLabel = "exp master"
resumptionLabel = "res master"
trafficUpdateLabel = "traffic upd"
)
// expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1.
func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte {
var hkdfLabel cryptobyte.Builder
hkdfLabel.AddUint16(uint16(length))
hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes([]byte("tls13 "))
b.AddBytes([]byte(label))
})
hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(context)
})
out := make([]byte, length)
n, err := hkdf.Expand(c.hash.New, secret, hkdfLabel.BytesOrPanic()).Read(out)
if err != nil || n != length {
panic("tls: HKDF-Expand-Label invocation failed unexpectedly")
}
return out
}
// deriveSecret implements Derive-Secret from RFC 8446, Section 7.1.
func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte {
if transcript == nil {
transcript = c.hash.New()
}
return c.expandLabel(secret, label, transcript.Sum(nil), c.hash.Size())
}
// extract implements HKDF-Extract with the cipher suite hash.
func (c *cipherSuiteTLS13) extract(newSecret, currentSecret []byte) []byte {
if newSecret == nil {
newSecret = make([]byte, c.hash.Size())
}
return hkdf.Extract(c.hash.New, newSecret, currentSecret)
}
// nextTrafficSecret generates the next traffic secret, given the current one,
// according to RFC 8446, Section 7.2.
func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte {
return c.expandLabel(trafficSecret, trafficUpdateLabel, nil, c.hash.Size())
}
// trafficKey generates traffic keys according to RFC 8446, Section 7.3.
func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) {
key = c.expandLabel(trafficSecret, "key", nil, c.keyLen)
iv = c.expandLabel(trafficSecret, "iv", nil, aeadNonceLength)
return
}
// finishedHash generates the Finished verify_data or PskBinderEntry according
// to RFC 8446, Section 4.4.4. See sections 4.4 and 4.2.11.2 for the baseKey
// selection.
func (c *cipherSuiteTLS13) finishedHash(baseKey []byte, transcript hash.Hash) []byte {
finishedKey := c.expandLabel(baseKey, "finished", nil, c.hash.Size())
verifyData := hmac.New(c.hash.New, finishedKey)
verifyData.Write(transcript.Sum(nil))
return verifyData.Sum(nil)
}
// exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to
// RFC 8446, Section 7.5.
func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript hash.Hash) func(string, []byte, int) ([]byte, error) {
expMasterSecret := c.deriveSecret(masterSecret, exporterLabel, transcript)
return func(label string, context []byte, length int) ([]byte, error) {
secret := c.deriveSecret(expMasterSecret, label, nil)
h := c.hash.New()
h.Write(context)
return c.expandLabel(secret, "exporter", h.Sum(nil), length), nil
}
}
// ecdheParameters implements Diffie-Hellman with either NIST curves or X25519,
// according to RFC 8446, Section 4.2.8.2.
type ecdheParameters interface {
CurveID() CurveID
PublicKey() []byte
SharedKey(peerPublicKey []byte) []byte
}
func generateECDHEParameters(rand io.Reader, curveID CurveID) (ecdheParameters, error) {
if curveID == X25519 {
privateKey := make([]byte, curve25519.ScalarSize)
if _, err := io.ReadFull(rand, privateKey); err != nil {
return nil, err
}
publicKey, err := curve25519.X25519(privateKey, curve25519.Basepoint)
if err != nil {
return nil, err
}
return &x25519Parameters{privateKey: privateKey, publicKey: publicKey}, nil
}
curve, ok := curveForCurveID(curveID)
if !ok {
return nil, errors.New("tls: internal error: unsupported curve")
}
p := &nistParameters{curveID: curveID}
var err error
p.privateKey, p.x, p.y, err = elliptic.GenerateKey(curve, rand)
if err != nil {
return nil, err
}
return p, nil
}
func curveForCurveID(id CurveID) (elliptic.Curve, bool) {
switch id {
case CurveP256:
return elliptic.P256(), true
case CurveP384:
return elliptic.P384(), true
case CurveP521:
return elliptic.P521(), true
default:
return nil, false
}
}
type nistParameters struct {
privateKey []byte
x, y *big.Int // public key
curveID CurveID
}
func (p *nistParameters) CurveID() CurveID {
return p.curveID
}
func (p *nistParameters) PublicKey() []byte {
curve, _ := curveForCurveID(p.curveID)
return elliptic.Marshal(curve, p.x, p.y)
}
func (p *nistParameters) SharedKey(peerPublicKey []byte) []byte {
curve, _ := curveForCurveID(p.curveID)
// Unmarshal also checks whether the given point is on the curve.
x, y := elliptic.Unmarshal(curve, peerPublicKey)
if x == nil {
return nil
}
xShared, _ := curve.ScalarMult(x, y, p.privateKey)
sharedKey := make([]byte, (curve.Params().BitSize+7)/8)
return xShared.FillBytes(sharedKey)
}
type x25519Parameters struct {
privateKey []byte
publicKey []byte
}
func (p *x25519Parameters) CurveID() CurveID {
return X25519
}
func (p *x25519Parameters) PublicKey() []byte {
return p.publicKey[:]
}
func (p *x25519Parameters) SharedKey(peerPublicKey []byte) []byte {
sharedKey, err := curve25519.X25519(p.privateKey, peerPublicKey)
if err != nil {
return nil
}
return sharedKey
}

View File

@ -0,0 +1,175 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"encoding/hex"
"hash"
"strings"
"testing"
"unicode"
)
// This file contains tests derived from draft-ietf-tls-tls13-vectors-07.
func parseVector(v string) []byte {
v = strings.Map(func(c rune) rune {
if unicode.IsSpace(c) {
return -1
}
return c
}, v)
parts := strings.Split(v, ":")
v = parts[len(parts)-1]
res, err := hex.DecodeString(v)
if err != nil {
panic(err)
}
return res
}
func TestDeriveSecret(t *testing.T) {
chTranscript := cipherSuitesTLS13[0].hash.New()
chTranscript.Write(parseVector(`
payload (512 octets): 01 00 01 fc 03 03 1b c3 ce b6 bb e3 9c ff
93 83 55 b5 a5 0a db 6d b2 1b 7a 6a f6 49 d7 b4 bc 41 9d 78 76
48 7d 95 00 00 06 13 01 13 03 13 02 01 00 01 cd 00 00 00 0b 00
09 00 00 06 73 65 72 76 65 72 ff 01 00 01 00 00 0a 00 14 00 12
00 1d 00 17 00 18 00 19 01 00 01 01 01 02 01 03 01 04 00 33 00
26 00 24 00 1d 00 20 e4 ff b6 8a c0 5f 8d 96 c9 9d a2 66 98 34
6c 6b e1 64 82 ba dd da fe 05 1a 66 b4 f1 8d 66 8f 0b 00 2a 00
00 00 2b 00 03 02 03 04 00 0d 00 20 00 1e 04 03 05 03 06 03 02
03 08 04 08 05 08 06 04 01 05 01 06 01 02 01 04 02 05 02 06 02
02 02 00 2d 00 02 01 01 00 1c 00 02 40 01 00 15 00 57 00 00 00
00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00
00 29 00 dd 00 b8 00 b2 2c 03 5d 82 93 59 ee 5f f7 af 4e c9 00
00 00 00 26 2a 64 94 dc 48 6d 2c 8a 34 cb 33 fa 90 bf 1b 00 70
ad 3c 49 88 83 c9 36 7c 09 a2 be 78 5a bc 55 cd 22 60 97 a3 a9
82 11 72 83 f8 2a 03 a1 43 ef d3 ff 5d d3 6d 64 e8 61 be 7f d6
1d 28 27 db 27 9c ce 14 50 77 d4 54 a3 66 4d 4e 6d a4 d2 9e e0
37 25 a6 a4 da fc d0 fc 67 d2 ae a7 05 29 51 3e 3d a2 67 7f a5
90 6c 5b 3f 7d 8f 92 f2 28 bd a4 0d da 72 14 70 f9 fb f2 97 b5
ae a6 17 64 6f ac 5c 03 27 2e 97 07 27 c6 21 a7 91 41 ef 5f 7d
e6 50 5e 5b fb c3 88 e9 33 43 69 40 93 93 4a e4 d3 57 fa d6 aa
cb 00 21 20 3a dd 4f b2 d8 fd f8 22 a0 ca 3c f7 67 8e f5 e8 8d
ae 99 01 41 c5 92 4d 57 bb 6f a3 1b 9e 5f 9d`))
type args struct {
secret []byte
label string
transcript hash.Hash
}
tests := []struct {
name string
args args
want []byte
}{
{
`derive secret for handshake "tls13 derived"`,
args{
parseVector(`PRK (32 octets): 33 ad 0a 1c 60 7e c0 3b 09 e6 cd 98 93 68 0c e2
10 ad f3 00 aa 1f 26 60 e1 b2 2e 10 f1 70 f9 2a`),
"derived",
nil,
},
parseVector(`expanded (32 octets): 6f 26 15 a1 08 c7 02 c5 67 8f 54 fc 9d ba
b6 97 16 c0 76 18 9c 48 25 0c eb ea c3 57 6c 36 11 ba`),
},
{
`derive secret "tls13 c e traffic"`,
args{
parseVector(`PRK (32 octets): 9b 21 88 e9 b2 fc 6d 64 d7 1d c3 29 90 0e 20 bb
41 91 50 00 f6 78 aa 83 9c bb 79 7c b7 d8 33 2c`),
"c e traffic",
chTranscript,
},
parseVector(`expanded (32 octets): 3f bb e6 a6 0d eb 66 c3 0a 32 79 5a ba 0e
ff 7e aa 10 10 55 86 e7 be 5c 09 67 8d 63 b6 ca ab 62`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := cipherSuitesTLS13[0]
if got := c.deriveSecret(tt.args.secret, tt.args.label, tt.args.transcript); !bytes.Equal(got, tt.want) {
t.Errorf("cipherSuiteTLS13.deriveSecret() = % x, want % x", got, tt.want)
}
})
}
}
func TestTrafficKey(t *testing.T) {
trafficSecret := parseVector(
`PRK (32 octets): b6 7b 7d 69 0c c1 6c 4e 75 e5 42 13 cb 2d 37 b4
e9 c9 12 bc de d9 10 5d 42 be fd 59 d3 91 ad 38`)
wantKey := parseVector(
`key expanded (16 octets): 3f ce 51 60 09 c2 17 27 d0 f2 e4 e8 6e
e4 03 bc`)
wantIV := parseVector(
`iv expanded (12 octets): 5d 31 3e b2 67 12 76 ee 13 00 0b 30`)
c := cipherSuitesTLS13[0]
gotKey, gotIV := c.trafficKey(trafficSecret)
if !bytes.Equal(gotKey, wantKey) {
t.Errorf("cipherSuiteTLS13.trafficKey() gotKey = % x, want % x", gotKey, wantKey)
}
if !bytes.Equal(gotIV, wantIV) {
t.Errorf("cipherSuiteTLS13.trafficKey() gotIV = % x, want % x", gotIV, wantIV)
}
}
func TestExtract(t *testing.T) {
type args struct {
newSecret []byte
currentSecret []byte
}
tests := []struct {
name string
args args
want []byte
}{
{
`extract secret "early"`,
args{
nil,
nil,
},
parseVector(`secret (32 octets): 33 ad 0a 1c 60 7e c0 3b 09 e6 cd 98 93 68 0c
e2 10 ad f3 00 aa 1f 26 60 e1 b2 2e 10 f1 70 f9 2a`),
},
{
`extract secret "master"`,
args{
nil,
parseVector(`salt (32 octets): 43 de 77 e0 c7 77 13 85 9a 94 4d b9 db 25 90 b5
31 90 a6 5b 3e e2 e4 f1 2d d7 a0 bb 7c e2 54 b4`),
},
parseVector(`secret (32 octets): 18 df 06 84 3d 13 a0 8b f2 a4 49 84 4c 5f 8a
47 80 01 bc 4d 4c 62 79 84 d5 a4 1d a8 d0 40 29 19`),
},
{
`extract secret "handshake"`,
args{
parseVector(`IKM (32 octets): 8b d4 05 4f b5 5b 9d 63 fd fb ac f9 f0 4b 9f 0d
35 e6 d6 3f 53 75 63 ef d4 62 72 90 0f 89 49 2d`),
parseVector(`salt (32 octets): 6f 26 15 a1 08 c7 02 c5 67 8f 54 fc 9d ba b6 97
16 c0 76 18 9c 48 25 0c eb ea c3 57 6c 36 11 ba`),
},
parseVector(`secret (32 octets): 1d c8 26 e9 36 06 aa 6f dc 0a ad c1 2f 74 1b
01 04 6a a6 b9 9f 69 1e d2 21 a9 f0 ca 04 3f be ac`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := cipherSuitesTLS13[0]
if got := c.extract(tt.args.newSecret, tt.args.currentSecret); !bytes.Equal(got, tt.want) {
t.Errorf("cipherSuiteTLS13.extract() = % x, want % x", got, tt.want)
}
})
}
}

View File

@ -0,0 +1,108 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"os"
"os/exec"
"path/filepath"
"testing"
"github.com/lesismal/llib/std/internal/testenv"
)
// Tests that the linker is able to remove references to the Client or Server if unused.
func TestLinkerGC(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
t.Parallel()
goBin := testenv.GoToolPath(t)
testenv.MustHaveGoBuild(t)
tests := []struct {
name string
program string
want []string
bad []string
}{
{
name: "empty_import",
program: `package main
import _ "crypto/tls"
func main() {}
`,
bad: []string{
"tls.(*Conn)",
"type.crypto/tls.clientHandshakeState",
"type.crypto/tls.serverHandshakeState",
},
},
{
name: "client_and_server",
program: `package main
import "crypto/tls"
func main() {
tls.Dial("", "", nil)
tls.Server(nil, nil)
}
`,
want: []string{
"crypto/tls.(*Conn).clientHandshake",
"crypto/tls.(*Conn).serverHandshake",
},
},
{
name: "only_client",
program: `package main
import "crypto/tls"
func main() { tls.Dial("", "", nil) }
`,
want: []string{
"crypto/tls.(*Conn).clientHandshake",
},
bad: []string{
"crypto/tls.(*Conn).serverHandshake",
},
},
// TODO: add only_server like func main() { tls.Server(nil, nil) }
// That currently brings in the client via Conn.handleRenegotiation.
}
tmpDir := t.TempDir()
goFile := filepath.Join(tmpDir, "x.go")
exeFile := filepath.Join(tmpDir, "x.exe")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := os.WriteFile(goFile, []byte(tt.program), 0644); err != nil {
t.Fatal(err)
}
os.Remove(exeFile)
cmd := exec.Command(goBin, "build", "-o", "x.exe", "x.go")
cmd.Dir = tmpDir
if out, err := cmd.CombinedOutput(); err != nil {
t.Fatalf("compile: %v, %s", err, out)
}
cmd = exec.Command(goBin, "tool", "nm", "x.exe")
cmd.Dir = tmpDir
nm, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("nm: %v, %s", err, nm)
}
for _, sym := range tt.want {
if !bytes.Contains(nm, []byte(sym)) {
t.Errorf("expected symbol %q not found", sym)
}
}
for _, sym := range tt.bad {
if bytes.Contains(nm, []byte(sym)) {
t.Errorf("unexpected symbol %q found", sym)
}
}
})
}
}

283
vendor/github.com/lesismal/llib/std/crypto/tls/prf.go generated vendored Normal file
View File

@ -0,0 +1,283 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"crypto"
"crypto/hmac"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"errors"
"fmt"
"hash"
)
// Split a premaster secret in two as specified in RFC 4346, Section 5.
func splitPreMasterSecret(secret []byte) (s1, s2 []byte) {
s1 = secret[0 : (len(secret)+1)/2]
s2 = secret[len(secret)/2:]
return
}
// pHash implements the P_hash function, as defined in RFC 4346, Section 5.
func pHash(result, secret, seed []byte, hash func() hash.Hash) {
h := hmac.New(hash, secret)
h.Write(seed)
a := h.Sum(nil)
j := 0
for j < len(result) {
h.Reset()
h.Write(a)
h.Write(seed)
b := h.Sum(nil)
copy(result[j:], b)
j += len(b)
h.Reset()
h.Write(a)
a = h.Sum(nil)
}
}
// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5.
func prf10(result, secret, label, seed []byte) {
hashSHA1 := sha1.New
hashMD5 := md5.New
labelAndSeed := make([]byte, len(label)+len(seed))
copy(labelAndSeed, label)
copy(labelAndSeed[len(label):], seed)
s1, s2 := splitPreMasterSecret(secret)
pHash(result, s1, labelAndSeed, hashMD5)
result2 := make([]byte, len(result))
pHash(result2, s2, labelAndSeed, hashSHA1)
for i, b := range result2 {
result[i] ^= b
}
}
// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, Section 5.
func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) {
return func(result, secret, label, seed []byte) {
labelAndSeed := make([]byte, len(label)+len(seed))
copy(labelAndSeed, label)
copy(labelAndSeed[len(label):], seed)
pHash(result, secret, labelAndSeed, hashFunc)
}
}
const (
masterSecretLength = 48 // Length of a master secret in TLS 1.1.
finishedVerifyLength = 12 // Length of verify_data in a Finished message.
)
var masterSecretLabel = []byte("master secret")
var keyExpansionLabel = []byte("key expansion")
var clientFinishedLabel = []byte("client finished")
var serverFinishedLabel = []byte("server finished")
func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) {
switch version {
case VersionTLS10, VersionTLS11:
return prf10, crypto.Hash(0)
case VersionTLS12:
if suite.flags&suiteSHA384 != 0 {
return prf12(sha512.New384), crypto.SHA384
}
return prf12(sha256.New), crypto.SHA256
default:
panic("unknown version")
}
}
func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) {
prf, _ := prfAndHashForVersion(version, suite)
return prf
}
// masterFromPreMasterSecret generates the master secret from the pre-master
// secret. See RFC 5246, Section 8.1.
func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte {
seed := make([]byte, 0, len(clientRandom)+len(serverRandom))
seed = append(seed, clientRandom...)
seed = append(seed, serverRandom...)
masterSecret := make([]byte, masterSecretLength)
prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed)
return masterSecret
}
// keysFromMasterSecret generates the connection keys from the master
// secret, given the lengths of the MAC key, cipher key and IV, as defined in
// RFC 2246, Section 6.3.
func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) {
seed := make([]byte, 0, len(serverRandom)+len(clientRandom))
seed = append(seed, serverRandom...)
seed = append(seed, clientRandom...)
n := 2*macLen + 2*keyLen + 2*ivLen
keyMaterial := make([]byte, n)
prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed)
clientMAC = keyMaterial[:macLen]
keyMaterial = keyMaterial[macLen:]
serverMAC = keyMaterial[:macLen]
keyMaterial = keyMaterial[macLen:]
clientKey = keyMaterial[:keyLen]
keyMaterial = keyMaterial[keyLen:]
serverKey = keyMaterial[:keyLen]
keyMaterial = keyMaterial[keyLen:]
clientIV = keyMaterial[:ivLen]
keyMaterial = keyMaterial[ivLen:]
serverIV = keyMaterial[:ivLen]
return
}
func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash {
var buffer []byte
if version >= VersionTLS12 {
buffer = []byte{}
}
prf, hash := prfAndHashForVersion(version, cipherSuite)
if hash != 0 {
return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf}
}
return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf}
}
// A finishedHash calculates the hash of a set of handshake messages suitable
// for including in a Finished message.
type finishedHash struct {
client hash.Hash
server hash.Hash
// Prior to TLS 1.2, an additional MD5 hash is required.
clientMD5 hash.Hash
serverMD5 hash.Hash
// In TLS 1.2, a full buffer is sadly required.
buffer []byte
version uint16
prf func(result, secret, label, seed []byte)
}
func (h *finishedHash) Write(msg []byte) (n int, err error) {
h.client.Write(msg)
h.server.Write(msg)
if h.version < VersionTLS12 {
h.clientMD5.Write(msg)
h.serverMD5.Write(msg)
}
if h.buffer != nil {
h.buffer = append(h.buffer, msg...)
}
return len(msg), nil
}
func (h finishedHash) Sum() []byte {
if h.version >= VersionTLS12 {
return h.client.Sum(nil)
}
out := make([]byte, 0, md5.Size+sha1.Size)
out = h.clientMD5.Sum(out)
return h.client.Sum(out)
}
// clientSum returns the contents of the verify_data member of a client's
// Finished message.
func (h finishedHash) clientSum(masterSecret []byte) []byte {
out := make([]byte, finishedVerifyLength)
h.prf(out, masterSecret, clientFinishedLabel, h.Sum())
return out
}
// serverSum returns the contents of the verify_data member of a server's
// Finished message.
func (h finishedHash) serverSum(masterSecret []byte) []byte {
out := make([]byte, finishedVerifyLength)
h.prf(out, masterSecret, serverFinishedLabel, h.Sum())
return out
}
// hashForClientCertificate returns the handshake messages so far, pre-hashed if
// necessary, suitable for signing by a TLS client certificate.
func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash, masterSecret []byte) []byte {
if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil {
panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer")
}
if sigType == signatureEd25519 {
return h.buffer
}
if h.version >= VersionTLS12 {
hash := hashAlg.New()
hash.Write(h.buffer)
return hash.Sum(nil)
}
if sigType == signatureECDSA {
return h.server.Sum(nil)
}
return h.Sum()
}
// discardHandshakeBuffer is called when there is no more need to
// buffer the entirety of the handshake messages.
func (h *finishedHash) discardHandshakeBuffer() {
h.buffer = nil
}
// noExportedKeyingMaterial is used as a value of
// ConnectionState.ekm when renegotiation is enabled and thus
// we wish to fail all key-material export requests.
func noExportedKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled")
}
// ekmFromMasterSecret generates exported keying material as defined in RFC 5705.
func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) {
return func(label string, context []byte, length int) ([]byte, error) {
switch label {
case "client finished", "server finished", "master secret", "key expansion":
// These values are reserved and may not be used.
return nil, fmt.Errorf("crypto/tls: reserved ExportKeyingMaterial label: %s", label)
}
seedLen := len(serverRandom) + len(clientRandom)
if context != nil {
seedLen += 2 + len(context)
}
seed := make([]byte, 0, seedLen)
seed = append(seed, clientRandom...)
seed = append(seed, serverRandom...)
if context != nil {
if len(context) >= 1<<16 {
return nil, fmt.Errorf("crypto/tls: ExportKeyingMaterial context too long")
}
seed = append(seed, byte(len(context)>>8), byte(len(context)))
seed = append(seed, context...)
}
keyMaterial := make([]byte, length)
prfForVersion(version, suite)(keyMaterial, masterSecret, []byte(label), seed)
return keyMaterial, nil
}
}

View File

@ -0,0 +1,140 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"encoding/hex"
"testing"
)
type testSplitPreMasterSecretTest struct {
in, out1, out2 string
}
var testSplitPreMasterSecretTests = []testSplitPreMasterSecretTest{
{"", "", ""},
{"00", "00", "00"},
{"0011", "00", "11"},
{"001122", "0011", "1122"},
{"00112233", "0011", "2233"},
}
func TestSplitPreMasterSecret(t *testing.T) {
for i, test := range testSplitPreMasterSecretTests {
in, _ := hex.DecodeString(test.in)
out1, out2 := splitPreMasterSecret(in)
s1 := hex.EncodeToString(out1)
s2 := hex.EncodeToString(out2)
if s1 != test.out1 || s2 != test.out2 {
t.Errorf("#%d: got: (%s, %s) want: (%s, %s)", i, s1, s2, test.out1, test.out2)
}
}
}
type testKeysFromTest struct {
version uint16
suite *cipherSuite
preMasterSecret string
clientRandom, serverRandom string
masterSecret string
clientMAC, serverMAC string
clientKey, serverKey string
macLen, keyLen int
contextKeyingMaterial, noContextKeyingMaterial string
}
func TestKeysFromPreMasterSecret(t *testing.T) {
for i, test := range testKeysFromTests {
in, _ := hex.DecodeString(test.preMasterSecret)
clientRandom, _ := hex.DecodeString(test.clientRandom)
serverRandom, _ := hex.DecodeString(test.serverRandom)
masterSecret := masterFromPreMasterSecret(test.version, test.suite, in, clientRandom, serverRandom)
if s := hex.EncodeToString(masterSecret); s != test.masterSecret {
t.Errorf("#%d: bad master secret %s, want %s", i, s, test.masterSecret)
continue
}
clientMAC, serverMAC, clientKey, serverKey, _, _ := keysFromMasterSecret(test.version, test.suite, masterSecret, clientRandom, serverRandom, test.macLen, test.keyLen, 0)
clientMACString := hex.EncodeToString(clientMAC)
serverMACString := hex.EncodeToString(serverMAC)
clientKeyString := hex.EncodeToString(clientKey)
serverKeyString := hex.EncodeToString(serverKey)
if clientMACString != test.clientMAC ||
serverMACString != test.serverMAC ||
clientKeyString != test.clientKey ||
serverKeyString != test.serverKey {
t.Errorf("#%d: got: (%s, %s, %s, %s) want: (%s, %s, %s, %s)", i, clientMACString, serverMACString, clientKeyString, serverKeyString, test.clientMAC, test.serverMAC, test.clientKey, test.serverKey)
}
ekm := ekmFromMasterSecret(test.version, test.suite, masterSecret, clientRandom, serverRandom)
contextKeyingMaterial, err := ekm("label", []byte("context"), 32)
if err != nil {
t.Fatalf("ekmFromMasterSecret failed: %v", err)
}
noContextKeyingMaterial, err := ekm("label", nil, 32)
if err != nil {
t.Fatalf("ekmFromMasterSecret failed: %v", err)
}
if hex.EncodeToString(contextKeyingMaterial) != test.contextKeyingMaterial ||
hex.EncodeToString(noContextKeyingMaterial) != test.noContextKeyingMaterial {
t.Errorf("#%d: got keying material: (%s, %s) want: (%s, %s)", i, contextKeyingMaterial, noContextKeyingMaterial, test.contextKeyingMaterial, test.noContextKeyingMaterial)
}
}
}
// These test vectors were generated from GnuTLS using `gnutls-cli --insecure -d 9 `
var testKeysFromTests = []testKeysFromTest{
{
VersionTLS10,
cipherSuiteByID(TLS_RSA_WITH_RC4_128_SHA),
"0302cac83ad4b1db3b9ab49ad05957de2a504a634a386fc600889321e1a971f57479466830ac3e6f468e87f5385fa0c5",
"4ae66303755184a3917fcb44880605fcc53baa01912b22ed94473fc69cebd558",
"4ae663020ec16e6bb5130be918cfcafd4d765979a3136a5d50c593446e4e44db",
"3d851bab6e5556e959a16bc36d66cfae32f672bfa9ecdef6096cbb1b23472df1da63dbbd9827606413221d149ed08ceb",
"805aaa19b3d2c0a0759a4b6c9959890e08480119",
"2d22f9fe519c075c16448305ceee209fc24ad109",
"d50b5771244f850cd8117a9ccafe2cf1",
"e076e33206b30507a85c32855acd0919",
20,
16,
"4d1bb6fc278c37d27aa6e2a13c2e079095d143272c2aa939da33d88c1c0cec22",
"93fba89599b6321ae538e27c6548ceb8b46821864318f5190d64a375e5d69d41",
},
{
VersionTLS10,
cipherSuiteByID(TLS_RSA_WITH_RC4_128_SHA),
"03023f7527316bc12cbcd69e4b9e8275d62c028f27e65c745cfcddc7ce01bd3570a111378b63848127f1c36e5f9e4890",
"4ae66364b5ea56b20ce4e25555aed2d7e67f42788dd03f3fee4adae0459ab106",
"4ae66363ab815cbf6a248b87d6b556184e945e9b97fbdf247858b0bdafacfa1c",
"7d64be7c80c59b740200b4b9c26d0baaa1c5ae56705acbcf2307fe62beb4728c19392c83f20483801cce022c77645460",
"97742ed60a0554ca13f04f97ee193177b971e3b0",
"37068751700400e03a8477a5c7eec0813ab9e0dc",
"207cddbc600d2a200abac6502053ee5c",
"df3f94f6e1eacc753b815fe16055cd43",
20,
16,
"2c9f8961a72b97cbe76553b5f954caf8294fc6360ef995ac1256fe9516d0ce7f",
"274f19c10291d188857ad8878e2119f5aa437d4da556601cf1337aff23154016",
},
{
VersionTLS10,
cipherSuiteByID(TLS_RSA_WITH_RC4_128_SHA),
"832d515f1d61eebb2be56ba0ef79879efb9b527504abb386fb4310ed5d0e3b1f220d3bb6b455033a2773e6d8bdf951d278a187482b400d45deb88a5d5a6bb7d6a7a1decc04eb9ef0642876cd4a82d374d3b6ff35f0351dc5d411104de431375355addc39bfb1f6329fb163b0bc298d658338930d07d313cd980a7e3d9196cac1",
"4ae663b2ee389c0de147c509d8f18f5052afc4aaf9699efe8cb05ece883d3a5e",
"4ae664d503fd4cff50cfc1fb8fc606580f87b0fcdac9554ba0e01d785bdf278e",
"1aff2e7a2c4279d0126f57a65a77a8d9d0087cf2733366699bec27eb53d5740705a8574bb1acc2abbe90e44f0dd28d6c",
"3c7647c93c1379a31a609542aa44e7f117a70085",
"0d73102994be74a575a3ead8532590ca32a526d4",
"ac7581b0b6c10d85bbd905ffbf36c65e",
"ff07edde49682b45466bd2e39464b306",
20,
16,
"678b0d43f607de35241dc7e9d1a7388a52c35033a1a0336d4d740060a6638fe2",
"f3b4ac743f015ef21d79978297a53da3e579ee047133f38c234d829c0f907dab",
},
}

View File

@ -0,0 +1,185 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"errors"
"io"
"golang.org/x/crypto/cryptobyte"
)
// sessionState contains the information that is serialized into a session
// ticket in order to later resume a connection.
type sessionState struct {
vers uint16
cipherSuite uint16
createdAt uint64
masterSecret []byte // opaque master_secret<1..2^16-1>;
// struct { opaque certificate<1..2^24-1> } Certificate;
certificates [][]byte // Certificate certificate_list<0..2^24-1>;
// usedOldKey is true if the ticket from which this session came from
// was encrypted with an older key and thus should be refreshed.
usedOldKey bool
}
func (m *sessionState) marshal() []byte {
var b cryptobyte.Builder
b.AddUint16(m.vers)
b.AddUint16(m.cipherSuite)
addUint64(&b, m.createdAt)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.masterSecret)
})
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
for _, cert := range m.certificates {
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(cert)
})
}
})
return b.BytesOrPanic()
}
func (m *sessionState) unmarshal(data []byte) bool {
*m = sessionState{usedOldKey: m.usedOldKey}
s := cryptobyte.String(data)
if ok := s.ReadUint16(&m.vers) &&
s.ReadUint16(&m.cipherSuite) &&
readUint64(&s, &m.createdAt) &&
readUint16LengthPrefixed(&s, &m.masterSecret) &&
len(m.masterSecret) != 0; !ok {
return false
}
var certList cryptobyte.String
if !s.ReadUint24LengthPrefixed(&certList) {
return false
}
for !certList.Empty() {
var cert []byte
if !readUint24LengthPrefixed(&certList, &cert) {
return false
}
m.certificates = append(m.certificates, cert)
}
return s.Empty()
}
// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first
// version (revision = 0) doesn't carry any of the information needed for 0-RTT
// validation and the nonce is always empty.
type sessionStateTLS13 struct {
// uint8 version = 0x0304;
// uint8 revision = 0;
cipherSuite uint16
createdAt uint64
resumptionSecret []byte // opaque resumption_master_secret<1..2^8-1>;
certificate Certificate // CertificateEntry certificate_list<0..2^24-1>;
}
func (m *sessionStateTLS13) marshal() []byte {
var b cryptobyte.Builder
b.AddUint16(VersionTLS13)
b.AddUint8(0) // revision
b.AddUint16(m.cipherSuite)
addUint64(&b, m.createdAt)
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.resumptionSecret)
})
marshalCertificate(&b, m.certificate)
return b.BytesOrPanic()
}
func (m *sessionStateTLS13) unmarshal(data []byte) bool {
*m = sessionStateTLS13{}
s := cryptobyte.String(data)
var version uint16
var revision uint8
return s.ReadUint16(&version) &&
version == VersionTLS13 &&
s.ReadUint8(&revision) &&
revision == 0 &&
s.ReadUint16(&m.cipherSuite) &&
readUint64(&s, &m.createdAt) &&
readUint8LengthPrefixed(&s, &m.resumptionSecret) &&
len(m.resumptionSecret) != 0 &&
unmarshalCertificate(&s, &m.certificate) &&
s.Empty()
}
func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
if len(c.ticketKeys) == 0 {
return nil, errors.New("tls: internal error: session ticket keys unavailable")
}
encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size)
keyName := encrypted[:ticketKeyNameLen]
iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
macBytes := encrypted[len(encrypted)-sha256.Size:]
if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
return nil, err
}
key := c.ticketKeys[0]
copy(keyName, key.keyName[:])
block, err := aes.NewCipher(key.aesKey[:])
if err != nil {
return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
}
cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state)
mac := hmac.New(sha256.New, key.hmacKey[:])
mac.Write(encrypted[:len(encrypted)-sha256.Size])
mac.Sum(macBytes[:0])
return encrypted, nil
}
func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) {
if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size {
return nil, false
}
keyName := encrypted[:ticketKeyNameLen]
iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
macBytes := encrypted[len(encrypted)-sha256.Size:]
ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size]
keyIndex := -1
for i, candidateKey := range c.ticketKeys {
if bytes.Equal(keyName, candidateKey.keyName[:]) {
keyIndex = i
break
}
}
if keyIndex == -1 {
return nil, false
}
key := &c.ticketKeys[keyIndex]
mac := hmac.New(sha256.New, key.hmacKey[:])
mac.Write(encrypted[:len(encrypted)-sha256.Size])
expected := mac.Sum(nil)
if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
return nil, false
}
block, err := aes.NewCipher(key.aesKey[:])
if err != nil {
return nil, false
}
plaintext = make([]byte, len(ciphertext))
cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
return plaintext, keyIndex > 0
}

430
vendor/github.com/lesismal/llib/std/crypto/tls/tls.go generated vendored Normal file
View File

@ -0,0 +1,430 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package tls partially implements TLS 1.2, as specified in RFC 5246,
// and TLS 1.3, as specified in RFC 8446.
package tls
// BUG(agl): The crypto/tls package only implements some countermeasures
// against Lucky13 attacks on CBC-mode encryption, and only on SHA1
// variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and
// https://www.imperialviolet.org/2013/02/04/luckythirteen.html.
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"net"
"os"
"strings"
"time"
)
// NewConn returns a new TLS server side connection
// using conn as the underlying transport.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func NewConn(conn net.Conn, config *Config, isClient bool, isNonBlock bool, v ...interface{}) *Conn {
c := &Conn{
conn: conn,
config: config,
isClient: isClient,
isNonBlock: isNonBlock,
}
c.handshakeFn = c.serverHandshake
if isClient {
c.handshakeFn = c.clientHandshake
}
if len(v) > 0 {
if allocator, ok := v[0].(Allocator); ok {
c.allocator = allocator
}
}
if c.allocator == nil {
c.allocator = &NativeAllocator{}
}
return c
}
// Server returns a new TLS server side connection
// using conn as the underlying transport.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Server(conn net.Conn, config *Config) *Conn {
c := &Conn{
conn: conn,
config: config,
}
c.handshakeFn = c.serverHandshake
return c
}
// Client returns a new TLS client side connection
// using conn as the underlying transport.
// The config cannot be nil: users must set either ServerName or
// InsecureSkipVerify in the config.
func Client(conn net.Conn, config *Config) *Conn {
c := &Conn{
conn: conn,
config: config,
isClient: true,
}
c.handshakeFn = c.clientHandshake
return c
}
// A listener implements a network listener (net.Listener) for TLS connections.
type listener struct {
net.Listener
config *Config
allocator Allocator
}
// Accept waits for and returns the next incoming TLS connection.
// The returned connection is of type *Conn.
func (l *listener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
tlsConn := Server(c, l.config)
tlsConn.allocator = l.allocator
return tlsConn, nil
}
// NewListener creates a Listener which accepts connections from an inner
// Listener and wraps each connection with Server.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func NewListener(inner net.Listener, config *Config, v ...interface{}) net.Listener {
l := new(listener)
l.Listener = inner
l.config = config
if len(v) > 0 {
if allocator, ok := v[0].(Allocator); ok {
l.allocator = allocator
}
}
return l
}
// Listen creates a TLS listener accepting connections on the
// given network address using net.Listen.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Listen(network, laddr string, config *Config) (net.Listener, error) {
if config == nil || len(config.Certificates) == 0 &&
config.GetCertificate == nil && config.GetConfigForClient == nil {
return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
}
l, err := net.Listen(network, laddr)
if err != nil {
return nil, err
}
return NewListener(l, config), nil
}
type timeoutError struct{}
func (timeoutError) Error() string { return "tls: DialWithDialer timed out" }
func (timeoutError) Timeout() bool { return true }
func (timeoutError) Temporary() bool { return true }
// DialWithDialer connects to the given network address using dialer.Dial and
// then initiates a TLS handshake, returning the resulting TLS connection. Any
// timeout or deadline given in the dialer apply to connection and TLS
// handshake as a whole.
//
// DialWithDialer interprets a nil configuration as equivalent to the zero
// configuration; see the documentation of Config for the defaults.
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config, v ...interface{}) (*Conn, error) {
return dial(context.Background(), dialer, network, addr, config, v...)
}
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config, v ...interface{}) (*Conn, error) {
// We want the Timeout and Deadline values from dialer to cover the
// whole process: TCP connection and TLS handshake. This means that we
// also need to start our own timers now.
timeout := netDialer.Timeout
if !netDialer.Deadline.IsZero() {
deadlineTimeout := time.Until(netDialer.Deadline)
if timeout == 0 || deadlineTimeout < timeout {
timeout = deadlineTimeout
}
}
// hsErrCh is non-nil if we might not wait for Handshake to complete.
var hsErrCh chan error
if timeout != 0 || ctx.Done() != nil {
hsErrCh = make(chan error, 2)
}
if timeout != 0 {
timer := time.AfterFunc(timeout, func() {
hsErrCh <- timeoutError{}
})
defer timer.Stop()
}
rawConn, err := netDialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]
if config == nil {
config = defaultConfig()
}
// If no ServerName is set, infer the ServerName
// from the hostname we're connecting to.
if config.ServerName == "" {
// Make a copy to avoid polluting argument or default.
c := config.Clone()
c.ServerName = hostname
config = c
}
conn := Client(rawConn, config)
if len(v) > 0 {
if allocator, ok := v[0].(Allocator); ok {
conn.allocator = allocator
}
}
if conn.allocator == nil {
conn.allocator = &NativeAllocator{}
}
if hsErrCh == nil {
err = conn.Handshake()
} else {
go func() {
hsErrCh <- conn.Handshake()
}()
select {
case <-ctx.Done():
err = ctx.Err()
case err = <-hsErrCh:
if err != nil {
// If the error was due to the context
// closing, prefer the context's error, rather
// than some random network teardown error.
if e := ctx.Err(); e != nil {
err = e
}
}
}
}
if err != nil {
rawConn.Close()
return nil, err
}
return conn, nil
}
// Dial connects to the given network address using net.Dial
// and then initiates a TLS handshake, returning the resulting
// TLS connection.
// Dial interprets a nil configuration as equivalent to
// the zero configuration; see the documentation of Config
// for the defaults.
func Dial(network, addr string, config *Config, v ...interface{}) (*Conn, error) {
return DialWithDialer(new(net.Dialer), network, addr, config, v...)
}
// Dialer dials TLS connections given a configuration and a Dialer for the
// underlying connection.
type Dialer struct {
// NetDialer is the optional dialer to use for the TLS connections'
// underlying TCP connections.
// A nil NetDialer is equivalent to the net.Dialer zero value.
NetDialer *net.Dialer
// Config is the TLS configuration to use for new connections.
// A nil configuration is equivalent to the zero
// configuration; see the documentation of Config for the
// defaults.
Config *Config
}
// Dial connects to the given network address and initiates a TLS
// handshake, returning the resulting TLS connection.
//
// The returned Conn, if any, will always be of type *Conn.
func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
return d.DialContext(context.Background(), network, addr)
}
func (d *Dialer) netDialer() *net.Dialer {
if d.NetDialer != nil {
return d.NetDialer
}
return new(net.Dialer)
}
// DialContext connects to the given network address and initiates a TLS
// handshake, returning the resulting TLS connection.
//
// The provided Context must be non-nil. If the context expires before
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
//
// The returned Conn, if any, will always be of type *Conn.
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
if err != nil {
// Don't return c (a typed nil) in an interface.
return nil, err
}
return c, nil
}
// LoadX509KeyPair reads and parses a public/private key pair from a pair
// of files. The files must contain PEM encoded data. The certificate file
// may contain intermediate certificates following the leaf certificate to
// form a certificate chain. On successful return, Certificate.Leaf will
// be nil because the parsed form of the certificate is not retained.
func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
certPEMBlock, err := os.ReadFile(certFile)
if err != nil {
return Certificate{}, err
}
keyPEMBlock, err := os.ReadFile(keyFile)
if err != nil {
return Certificate{}, err
}
return X509KeyPair(certPEMBlock, keyPEMBlock)
}
// X509KeyPair parses a public/private key pair from a pair of
// PEM encoded data. On successful return, Certificate.Leaf will be nil because
// the parsed form of the certificate is not retained.
func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
fail := func(err error) (Certificate, error) { return Certificate{}, err }
var cert Certificate
var skippedBlockTypes []string
for {
var certDERBlock *pem.Block
certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
if certDERBlock == nil {
break
}
if certDERBlock.Type == "CERTIFICATE" {
cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
} else {
skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type)
}
}
if len(cert.Certificate) == 0 {
if len(skippedBlockTypes) == 0 {
return fail(errors.New("tls: failed to find any PEM data in certificate input"))
}
if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") {
return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched"))
}
return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
}
skippedBlockTypes = skippedBlockTypes[:0]
var keyDERBlock *pem.Block
for {
keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
if keyDERBlock == nil {
if len(skippedBlockTypes) == 0 {
return fail(errors.New("tls: failed to find any PEM data in key input"))
}
if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" {
return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key"))
}
return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
}
if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
break
}
skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type)
}
// We don't need to parse the public key for TLS, but we so do anyway
// to check that it looks sane and matches the private key.
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return fail(err)
}
cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
if err != nil {
return fail(err)
}
switch pub := x509Cert.PublicKey.(type) {
case *rsa.PublicKey:
priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
if !ok {
return fail(errors.New("tls: private key type does not match public key type"))
}
if pub.N.Cmp(priv.N) != 0 {
return fail(errors.New("tls: private key does not match public key"))
}
case *ecdsa.PublicKey:
priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
if !ok {
return fail(errors.New("tls: private key type does not match public key type"))
}
if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
return fail(errors.New("tls: private key does not match public key"))
}
case ed25519.PublicKey:
priv, ok := cert.PrivateKey.(ed25519.PrivateKey)
if !ok {
return fail(errors.New("tls: private key type does not match public key type"))
}
if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) {
return fail(errors.New("tls: private key does not match public key"))
}
default:
return fail(errors.New("tls: unknown public key algorithm"))
}
return cert, nil
}
// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates
// PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys.
// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three.
func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
return key, nil
}
if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
switch key := key.(type) {
case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
return key, nil
default:
return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping")
}
}
if key, err := x509.ParseECPrivateKey(der); err == nil {
return key, nil
}
return nil, errors.New("tls: failed to parse private key")
}

File diff suppressed because it is too large Load Diff

226
vendor/github.com/lesismal/llib/std/internal/cpu/cpu.go generated vendored Normal file
View File

@ -0,0 +1,226 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package cpu implements processor feature detection
// used by the Go standard library.
package cpu
// DebugOptions is set to true by the runtime if the OS supports reading
// GODEBUG early in runtime startup.
// This should not be changed after it is initialized.
var DebugOptions bool
// CacheLinePad is used to pad structs to avoid false sharing.
type CacheLinePad struct{ _ [CacheLinePadSize]byte }
// CacheLineSize is the CPU's assumed cache line size.
// There is currently no runtime detection of the real cache line size
// so we use the constant per GOARCH CacheLinePadSize as an approximation.
var CacheLineSize uintptr = CacheLinePadSize
// The booleans in X86 contain the correspondingly named cpuid feature bit.
// HasAVX and HasAVX2 are only set if the OS does support XMM and YMM registers
// in addition to the cpuid feature bit being set.
// The struct is padded to avoid false sharing.
var X86 struct {
_ CacheLinePad
HasAES bool
HasADX bool
HasAVX bool
HasAVX2 bool
HasBMI1 bool
HasBMI2 bool
HasERMS bool
HasFMA bool
HasOSXSAVE bool
HasPCLMULQDQ bool
HasPOPCNT bool
HasSSE2 bool
HasSSE3 bool
HasSSSE3 bool
HasSSE41 bool
HasSSE42 bool
_ CacheLinePad
}
// The booleans in ARM contain the correspondingly named cpu feature bit.
// The struct is padded to avoid false sharing.
var ARM struct {
_ CacheLinePad
HasVFPv4 bool
HasIDIVA bool
_ CacheLinePad
}
// The booleans in ARM64 contain the correspondingly named cpu feature bit.
// The struct is padded to avoid false sharing.
var ARM64 struct {
_ CacheLinePad
HasAES bool
HasPMULL bool
HasSHA1 bool
HasSHA2 bool
HasCRC32 bool
HasATOMICS bool
HasCPUID bool
IsNeoverseN1 bool
IsZeus bool
_ CacheLinePad
}
var MIPS64X struct {
_ CacheLinePad
HasMSA bool // MIPS SIMD architecture
_ CacheLinePad
}
// For ppc64(le), it is safe to check only for ISA level starting on ISA v3.00,
// since there are no optional categories. There are some exceptions that also
// require kernel support to work (darn, scv), so there are feature bits for
// those as well. The minimum processor requirement is POWER8 (ISA 2.07).
// The struct is padded to avoid false sharing.
var PPC64 struct {
_ CacheLinePad
HasDARN bool // Hardware random number generator (requires kernel enablement)
HasSCV bool // Syscall vectored (requires kernel enablement)
IsPOWER8 bool // ISA v2.07 (POWER8)
IsPOWER9 bool // ISA v3.00 (POWER9)
_ CacheLinePad
}
var S390X struct {
_ CacheLinePad
HasZARCH bool // z architecture mode is active [mandatory]
HasSTFLE bool // store facility list extended [mandatory]
HasLDISP bool // long (20-bit) displacements [mandatory]
HasEIMM bool // 32-bit immediates [mandatory]
HasDFP bool // decimal floating point
HasETF3EH bool // ETF-3 enhanced
HasMSA bool // message security assist (CPACF)
HasAES bool // KM-AES{128,192,256} functions
HasAESCBC bool // KMC-AES{128,192,256} functions
HasAESCTR bool // KMCTR-AES{128,192,256} functions
HasAESGCM bool // KMA-GCM-AES{128,192,256} functions
HasGHASH bool // KIMD-GHASH function
HasSHA1 bool // K{I,L}MD-SHA-1 functions
HasSHA256 bool // K{I,L}MD-SHA-256 functions
HasSHA512 bool // K{I,L}MD-SHA-512 functions
HasSHA3 bool // K{I,L}MD-SHA3-{224,256,384,512} and K{I,L}MD-SHAKE-{128,256} functions
HasVX bool // vector facility. Note: the runtime sets this when it processes auxv records.
HasVXE bool // vector-enhancements facility 1
HasKDSA bool // elliptic curve functions
HasECDSA bool // NIST curves
HasEDDSA bool // Edwards curves
_ CacheLinePad
}
// Initialize examines the processor and sets the relevant variables above.
// This is called by the runtime package early in program initialization,
// before normal init functions are run. env is set by runtime if the OS supports
// cpu feature options in GODEBUG.
func Initialize(env string) {
doinit()
processOptions(env)
}
// options contains the cpu debug options that can be used in GODEBUG.
// Options are arch dependent and are added by the arch specific doinit functions.
// Features that are mandatory for the specific GOARCH should not be added to options
// (e.g. SSE2 on amd64).
var options []option
// Option names should be lower case. e.g. avx instead of AVX.
type option struct {
Name string
Feature *bool
Specified bool // whether feature value was specified in GODEBUG
Enable bool // whether feature should be enabled
Required bool // whether feature is mandatory and can not be disabled
}
// processOptions enables or disables CPU feature values based on the parsed env string.
// The env string is expected to be of the form cpu.feature1=value1,cpu.feature2=value2...
// where feature names is one of the architecture specific list stored in the
// cpu packages options variable and values are either 'on' or 'off'.
// If env contains cpu.all=off then all cpu features referenced through the options
// variable are disabled. Other feature names and values result in warning messages.
func processOptions(env string) {
field:
for env != "" {
field := ""
i := indexByte(env, ',')
if i < 0 {
field, env = env, ""
} else {
field, env = env[:i], env[i+1:]
}
if len(field) < 4 || field[:4] != "cpu." {
continue
}
i = indexByte(field, '=')
if i < 0 {
print("GODEBUG: no value specified for \"", field, "\"\n")
continue
}
key, value := field[4:i], field[i+1:] // e.g. "SSE2", "on"
var enable bool
switch value {
case "on":
enable = true
case "off":
enable = false
default:
print("GODEBUG: value \"", value, "\" not supported for cpu option \"", key, "\"\n")
continue field
}
if key == "all" {
for i := range options {
options[i].Specified = true
options[i].Enable = enable || options[i].Required
}
continue field
}
for i := range options {
if options[i].Name == key {
options[i].Specified = true
options[i].Enable = enable
continue field
}
}
print("GODEBUG: unknown cpu feature \"", key, "\"\n")
}
for _, o := range options {
if !o.Specified {
continue
}
if o.Enable && !*o.Feature {
print("GODEBUG: can not enable \"", o.Name, "\", missing CPU support\n")
continue
}
if !o.Enable && o.Required {
print("GODEBUG: can not disable \"", o.Name, "\", required CPU feature\n")
continue
}
*o.Feature = o.Enable
}
}
// indexByte returns the index of the first instance of c in s,
// or -1 if c is not present in s.
func indexByte(s string, c byte) int {
for i := 0; i < len(s); i++ {
if s[i] == c {
return i
}
}
return -1
}

View File

@ -0,0 +1,6 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This assembly file exists to allow internal/cpu to call
// non-exported runtime functions that use "go:linkname".

View File

@ -0,0 +1,7 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
const GOARCH = "386"

View File

@ -0,0 +1,7 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
const GOARCH = "amd64"

View File

@ -0,0 +1,34 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
const CacheLinePadSize = 32
// arm doesn't have a 'cpuid' equivalent, so we rely on HWCAP/HWCAP2.
// These are initialized by archauxv() and should not be changed after they are
// initialized.
var HWCap uint
var HWCap2 uint
// HWCAP/HWCAP2 bits. These are exposed by Linux and FreeBSD.
const (
hwcap_VFPv4 = 1 << 16
hwcap_IDIVA = 1 << 17
)
func doinit() {
options = []option{
{Name: "vfpv4", Feature: &ARM.HasVFPv4},
{Name: "idiva", Feature: &ARM.HasIDIVA},
}
// HWCAP feature bits
ARM.HasVFPv4 = isSet(HWCap, hwcap_VFPv4)
ARM.HasIDIVA = isSet(HWCap, hwcap_IDIVA)
}
func isSet(hwc uint, value uint) bool {
return hwc&value != 0
}

View File

@ -0,0 +1,28 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
const CacheLinePadSize = 64
func doinit() {
options = []option{
{Name: "aes", Feature: &ARM64.HasAES},
{Name: "pmull", Feature: &ARM64.HasPMULL},
{Name: "sha1", Feature: &ARM64.HasSHA1},
{Name: "sha2", Feature: &ARM64.HasSHA2},
{Name: "crc32", Feature: &ARM64.HasCRC32},
{Name: "atomics", Feature: &ARM64.HasATOMICS},
{Name: "cpuid", Feature: &ARM64.HasCPUID},
{Name: "isNeoverseN1", Feature: &ARM64.IsNeoverseN1},
{Name: "isZeus", Feature: &ARM64.IsZeus},
}
// arm64 uses different ways to detect CPU features at runtime depending on the operating system.
osInit()
}
func getisar0() uint64
func getMIDR() uint64

View File

@ -0,0 +1,18 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
#include "textflag.h"
// func getisar0() uint64
TEXT ·getisar0(SB),NOSPLIT,$0
// get Instruction Set Attributes 0 into R0
MRS ID_AA64ISAR0_EL1, R0
MOVD R0, ret+0(FP)
RET
// func getMIDR() uint64
TEXT ·getMIDR(SB), NOSPLIT, $0-8
MRS MIDR_EL1, R0
MOVD R0, ret+0(FP)
RET

View File

@ -0,0 +1,11 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build arm64
package cpu
func osInit() {
hwcapInit("android")
}

View File

@ -0,0 +1,34 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build arm64
// +build darwin
// +build !ios
package cpu
func osInit() {
ARM64.HasATOMICS = sysctlEnabled([]byte("hw.optional.armv8_1_atomics\x00"))
ARM64.HasCRC32 = sysctlEnabled([]byte("hw.optional.armv8_crc32\x00"))
// There are no hw.optional sysctl values for the below features on Mac OS 11.0
// to detect their supported state dynamically. Assume the CPU features that
// Apple Silicon M1 supports to be available as a minimal set of features
// to all Go programs running on darwin/arm64.
ARM64.HasAES = true
ARM64.HasPMULL = true
ARM64.HasSHA1 = true
ARM64.HasSHA2 = true
}
//go:noescape
func getsysctlbyname(name []byte) (int32, int32)
func sysctlEnabled(name []byte) bool {
ret, value := getsysctlbyname(name)
if ret < 0 {
return false
}
return value > 0
}

View File

@ -0,0 +1,45 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build arm64
package cpu
func osInit() {
// Retrieve info from system register ID_AA64ISAR0_EL1.
isar0 := getisar0()
// ID_AA64ISAR0_EL1
switch extractBits(isar0, 4, 7) {
case 1:
ARM64.HasAES = true
case 2:
ARM64.HasAES = true
ARM64.HasPMULL = true
}
switch extractBits(isar0, 8, 11) {
case 1:
ARM64.HasSHA1 = true
}
switch extractBits(isar0, 12, 15) {
case 1, 2:
ARM64.HasSHA2 = true
}
switch extractBits(isar0, 16, 19) {
case 1:
ARM64.HasCRC32 = true
}
switch extractBits(isar0, 20, 23) {
case 2:
ARM64.HasATOMICS = true
}
}
func extractBits(data uint64, start, end uint) uint {
return (uint)(data>>start) & ((1 << (end - start + 1)) - 1)
}

View File

@ -0,0 +1,63 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build arm64
// +build linux
package cpu
// HWCap may be initialized by archauxv and
// should not be changed after it was initialized.
var HWCap uint
// HWCAP bits. These are exposed by Linux.
const (
hwcap_AES = 1 << 3
hwcap_PMULL = 1 << 4
hwcap_SHA1 = 1 << 5
hwcap_SHA2 = 1 << 6
hwcap_CRC32 = 1 << 7
hwcap_ATOMICS = 1 << 8
hwcap_CPUID = 1 << 11
)
func hwcapInit(os string) {
// HWCap was populated by the runtime from the auxiliary vector.
// Use HWCap information since reading aarch64 system registers
// is not supported in user space on older linux kernels.
ARM64.HasAES = isSet(HWCap, hwcap_AES)
ARM64.HasPMULL = isSet(HWCap, hwcap_PMULL)
ARM64.HasSHA1 = isSet(HWCap, hwcap_SHA1)
ARM64.HasSHA2 = isSet(HWCap, hwcap_SHA2)
ARM64.HasCRC32 = isSet(HWCap, hwcap_CRC32)
ARM64.HasCPUID = isSet(HWCap, hwcap_CPUID)
// The Samsung S9+ kernel reports support for atomics, but not all cores
// actually support them, resulting in SIGILL. See issue #28431.
// TODO(elias.naur): Only disable the optimization on bad chipsets on android.
ARM64.HasATOMICS = isSet(HWCap, hwcap_ATOMICS) && os != "android"
// Check to see if executing on a NeoverseN1 and in order to do that,
// check the AUXV for the CPUID bit. The getMIDR function executes an
// instruction which would normally be an illegal instruction, but it's
// trapped by the kernel, the value sanitized and then returned. Without
// the CPUID bit the kernel will not trap the instruction and the process
// will be terminated with SIGILL.
if ARM64.HasCPUID {
midr := getMIDR()
part_num := uint16((midr >> 4) & 0xfff)
implementor := byte((midr >> 24) & 0xff)
if implementor == 'A' && part_num == 0xd0c {
ARM64.IsNeoverseN1 = true
}
if implementor == 'A' && part_num == 0xd40 {
ARM64.IsZeus = true
}
}
}
func isSet(hwc uint, value uint) bool {
return hwc&value != 0
}

View File

@ -0,0 +1,13 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build arm64
// +build linux
// +build !android
package cpu
func osInit() {
hwcapInit("linux")
}

View File

@ -0,0 +1,17 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build arm64
// +build !linux
// +build !freebsd
// +build !android
// +build !darwin ios
package cpu
func osInit() {
// Other operating systems do not support reading HWCap from auxiliary vector,
// reading privileged aarch64 system registers or sysctl in user space to detect
// CPU features at runtime.
}

View File

@ -0,0 +1,10 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
const CacheLinePadSize = 32
func doinit() {
}

View File

@ -0,0 +1,32 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build mips64 mips64le
package cpu
const CacheLinePadSize = 32
// This is initialized by archauxv and should not be changed after it is
// initialized.
var HWCap uint
// HWCAP bits. These are exposed by the Linux kernel 5.4.
const (
// CPU features
hwcap_MIPS_MSA = 1 << 1
)
func doinit() {
options = []option{
{Name: "msa", Feature: &MIPS64X.HasMSA},
}
// HWCAP feature bits
MIPS64X.HasMSA = isSet(HWCap, hwcap_MIPS_MSA)
}
func isSet(hwc uint, value uint) bool {
return hwc&value != 0
}

View File

@ -0,0 +1,10 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
const CacheLinePadSize = 32
func doinit() {
}

View File

@ -0,0 +1,19 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !386
// +build !amd64
package cpu
// Name returns the CPU name given by the vendor
// if it can be read directly from memory or by CPU instructions.
// If the CPU name can not be determined an empty string is returned.
//
// Implementations that use the Operating System (e.g. sysctl or /sys/)
// to gather CPU information for display should be placed in internal/sysinfo.
func Name() string {
// "A CPU has no name".
return ""
}

View File

@ -0,0 +1,23 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build ppc64 ppc64le
package cpu
const CacheLinePadSize = 128
func doinit() {
options = []option{
{Name: "darn", Feature: &PPC64.HasDARN},
{Name: "scv", Feature: &PPC64.HasSCV},
{Name: "power9", Feature: &PPC64.IsPOWER9},
}
osinit()
}
func isSet(hwc uint, value uint) bool {
return hwc&value != 0
}

View File

@ -0,0 +1,21 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build ppc64 ppc64le
package cpu
const (
// getsystemcfg constants
_SC_IMPL = 2
_IMPL_POWER9 = 0x20000
)
func osinit() {
impl := getsystemcfg(_SC_IMPL)
PPC64.IsPOWER9 = isSet(impl, _IMPL_POWER9)
}
// getsystemcfg is defined in runtime/os2_aix.go
func getsystemcfg(label uint) uint

View File

@ -0,0 +1,29 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build ppc64 ppc64le
package cpu
// ppc64 doesn't have a 'cpuid' equivalent, so we rely on HWCAP/HWCAP2.
// These are initialized by archauxv and should not be changed after they are
// initialized.
var HWCap uint
var HWCap2 uint
// HWCAP bits. These are exposed by Linux.
const (
// ISA Level
hwcap2_ARCH_3_00 = 0x00800000
// CPU features
hwcap2_DARN = 0x00200000
hwcap2_SCV = 0x00100000
)
func osinit() {
PPC64.IsPOWER9 = isSet(HWCap2, hwcap2_ARCH_3_00)
PPC64.HasDARN = isSet(HWCap2, hwcap2_DARN)
PPC64.HasSCV = isSet(HWCap2, hwcap2_SCV)
}

View File

@ -0,0 +1,10 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
const CacheLinePadSize = 32
func doinit() {
}

View File

@ -0,0 +1,205 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
const CacheLinePadSize = 256
var HWCap uint
// bitIsSet reports whether the bit at index is set. The bit index
// is in big endian order, so bit index 0 is the leftmost bit.
func bitIsSet(bits []uint64, index uint) bool {
return bits[index/64]&((1<<63)>>(index%64)) != 0
}
// function is the function code for the named function.
type function uint8
const (
// KM{,A,C,CTR} function codes
aes128 function = 18 // AES-128
aes192 function = 19 // AES-192
aes256 function = 20 // AES-256
// K{I,L}MD function codes
sha1 function = 1 // SHA-1
sha256 function = 2 // SHA-256
sha512 function = 3 // SHA-512
sha3_224 function = 32 // SHA3-224
sha3_256 function = 33 // SHA3-256
sha3_384 function = 34 // SHA3-384
sha3_512 function = 35 // SHA3-512
shake128 function = 36 // SHAKE-128
shake256 function = 37 // SHAKE-256
// KLMD function codes
ghash function = 65 // GHASH
)
const (
// KDSA function codes
ecdsaVerifyP256 function = 1 // NIST P256
ecdsaVerifyP384 function = 2 // NIST P384
ecdsaVerifyP521 function = 3 // NIST P521
ecdsaSignP256 function = 9 // NIST P256
ecdsaSignP384 function = 10 // NIST P384
ecdsaSignP521 function = 11 // NIST P521
eddsaVerifyEd25519 function = 32 // Curve25519
eddsaVerifyEd448 function = 36 // Curve448
eddsaSignEd25519 function = 40 // Curve25519
eddsaSignEd448 function = 44 // Curve448
)
// queryResult contains the result of a Query function
// call. Bits are numbered in big endian order so the
// leftmost bit (the MSB) is at index 0.
type queryResult struct {
bits [2]uint64
}
// Has reports whether the given functions are present.
func (q *queryResult) Has(fns ...function) bool {
if len(fns) == 0 {
panic("no function codes provided")
}
for _, f := range fns {
if !bitIsSet(q.bits[:], uint(f)) {
return false
}
}
return true
}
// facility is a bit index for the named facility.
type facility uint8
const (
// mandatory facilities
zarch facility = 1 // z architecture mode is active
stflef facility = 7 // store-facility-list-extended
ldisp facility = 18 // long-displacement
eimm facility = 21 // extended-immediate
// miscellaneous facilities
dfp facility = 42 // decimal-floating-point
etf3eh facility = 30 // extended-translation 3 enhancement
// cryptography facilities
msa facility = 17 // message-security-assist
msa3 facility = 76 // message-security-assist extension 3
msa4 facility = 77 // message-security-assist extension 4
msa5 facility = 57 // message-security-assist extension 5
msa8 facility = 146 // message-security-assist extension 8
msa9 facility = 155 // message-security-assist extension 9
// vector facilities
vxe facility = 135 // vector-enhancements 1
// Note: vx requires kernel support
// and so must be fetched from HWCAP.
hwcap_VX = 1 << 11 // vector facility
)
// facilityList contains the result of an STFLE call.
// Bits are numbered in big endian order so the
// leftmost bit (the MSB) is at index 0.
type facilityList struct {
bits [4]uint64
}
// Has reports whether the given facilities are present.
func (s *facilityList) Has(fs ...facility) bool {
if len(fs) == 0 {
panic("no facility bits provided")
}
for _, f := range fs {
if !bitIsSet(s.bits[:], uint(f)) {
return false
}
}
return true
}
// The following feature detection functions are defined in cpu_s390x.s.
// They are likely to be expensive to call so the results should be cached.
func stfle() facilityList
func kmQuery() queryResult
func kmcQuery() queryResult
func kmctrQuery() queryResult
func kmaQuery() queryResult
func kimdQuery() queryResult
func klmdQuery() queryResult
func kdsaQuery() queryResult
func doinit() {
options = []option{
{Name: "zarch", Feature: &S390X.HasZARCH},
{Name: "stfle", Feature: &S390X.HasSTFLE},
{Name: "ldisp", Feature: &S390X.HasLDISP},
{Name: "msa", Feature: &S390X.HasMSA},
{Name: "eimm", Feature: &S390X.HasEIMM},
{Name: "dfp", Feature: &S390X.HasDFP},
{Name: "etf3eh", Feature: &S390X.HasETF3EH},
{Name: "vx", Feature: &S390X.HasVX},
{Name: "vxe", Feature: &S390X.HasVXE},
{Name: "kdsa", Feature: &S390X.HasKDSA},
}
aes := []function{aes128, aes192, aes256}
facilities := stfle()
S390X.HasZARCH = facilities.Has(zarch)
S390X.HasSTFLE = facilities.Has(stflef)
S390X.HasLDISP = facilities.Has(ldisp)
S390X.HasEIMM = facilities.Has(eimm)
S390X.HasDFP = facilities.Has(dfp)
S390X.HasETF3EH = facilities.Has(etf3eh)
S390X.HasMSA = facilities.Has(msa)
if S390X.HasMSA {
// cipher message
km, kmc := kmQuery(), kmcQuery()
S390X.HasAES = km.Has(aes...)
S390X.HasAESCBC = kmc.Has(aes...)
if facilities.Has(msa4) {
kmctr := kmctrQuery()
S390X.HasAESCTR = kmctr.Has(aes...)
}
if facilities.Has(msa8) {
kma := kmaQuery()
S390X.HasAESGCM = kma.Has(aes...)
}
// compute message digest
kimd := kimdQuery() // intermediate (no padding)
klmd := klmdQuery() // last (padding)
S390X.HasSHA1 = kimd.Has(sha1) && klmd.Has(sha1)
S390X.HasSHA256 = kimd.Has(sha256) && klmd.Has(sha256)
S390X.HasSHA512 = kimd.Has(sha512) && klmd.Has(sha512)
S390X.HasGHASH = kimd.Has(ghash) // KLMD-GHASH does not exist
sha3 := []function{
sha3_224, sha3_256, sha3_384, sha3_512,
shake128, shake256,
}
S390X.HasSHA3 = kimd.Has(sha3...) && klmd.Has(sha3...)
S390X.HasKDSA = facilities.Has(msa9) // elliptic curves
if S390X.HasKDSA {
kdsa := kdsaQuery()
S390X.HasECDSA = kdsa.Has(ecdsaVerifyP256, ecdsaSignP256, ecdsaVerifyP384, ecdsaSignP384, ecdsaVerifyP521, ecdsaSignP521)
S390X.HasEDDSA = kdsa.Has(eddsaVerifyEd25519, eddsaSignEd25519, eddsaVerifyEd448, eddsaSignEd448)
}
}
S390X.HasVX = isSet(HWCap, hwcap_VX)
if S390X.HasVX {
S390X.HasVXE = facilities.Has(vxe)
}
}
func isSet(hwc uint, value uint) bool {
return hwc&value != 0
}

View File

@ -0,0 +1,63 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
#include "textflag.h"
// func stfle() facilityList
TEXT ·stfle(SB), NOSPLIT|NOFRAME, $0-32
MOVD $ret+0(FP), R1
MOVD $3, R0 // last doubleword index to store
XC $32, (R1), (R1) // clear 4 doublewords (32 bytes)
WORD $0xb2b01000 // store facility list extended (STFLE)
RET
// func kmQuery() queryResult
TEXT ·kmQuery(SB), NOSPLIT|NOFRAME, $0-16
MOVD $0, R0 // set function code to 0 (KM-Query)
MOVD $ret+0(FP), R1 // address of 16-byte return value
WORD $0xB92E0024 // cipher message (KM)
RET
// func kmcQuery() queryResult
TEXT ·kmcQuery(SB), NOSPLIT|NOFRAME, $0-16
MOVD $0, R0 // set function code to 0 (KMC-Query)
MOVD $ret+0(FP), R1 // address of 16-byte return value
WORD $0xB92F0024 // cipher message with chaining (KMC)
RET
// func kmctrQuery() queryResult
TEXT ·kmctrQuery(SB), NOSPLIT|NOFRAME, $0-16
MOVD $0, R0 // set function code to 0 (KMCTR-Query)
MOVD $ret+0(FP), R1 // address of 16-byte return value
WORD $0xB92D4024 // cipher message with counter (KMCTR)
RET
// func kmaQuery() queryResult
TEXT ·kmaQuery(SB), NOSPLIT|NOFRAME, $0-16
MOVD $0, R0 // set function code to 0 (KMA-Query)
MOVD $ret+0(FP), R1 // address of 16-byte return value
WORD $0xb9296024 // cipher message with authentication (KMA)
RET
// func kimdQuery() queryResult
TEXT ·kimdQuery(SB), NOSPLIT|NOFRAME, $0-16
MOVD $0, R0 // set function code to 0 (KIMD-Query)
MOVD $ret+0(FP), R1 // address of 16-byte return value
WORD $0xB93E0024 // compute intermediate message digest (KIMD)
RET
// func klmdQuery() queryResult
TEXT ·klmdQuery(SB), NOSPLIT|NOFRAME, $0-16
MOVD $0, R0 // set function code to 0 (KLMD-Query)
MOVD $ret+0(FP), R1 // address of 16-byte return value
WORD $0xB93F0024 // compute last message digest (KLMD)
RET
// func kdsaQuery() queryResult
TEXT ·kdsaQuery(SB), NOSPLIT|NOFRAME, $0-16
MOVD $0, R0 // set function code to 0 (KLMD-Query)
MOVD $ret+0(FP), R1 // address of 16-byte return value
WORD $0xB93A0008 // compute digital signature authentication
RET

View File

@ -0,0 +1,63 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu_test
import (
"errors"
. "internal/cpu"
"os"
"regexp"
"testing"
)
func getFeatureList() ([]string, error) {
cpuinfo, err := os.ReadFile("/proc/cpuinfo")
if err != nil {
return nil, err
}
r := regexp.MustCompile("features\\s*:\\s*(.*)")
b := r.FindSubmatch(cpuinfo)
if len(b) < 2 {
return nil, errors.New("no feature list in /proc/cpuinfo")
}
return regexp.MustCompile("\\s+").Split(string(b[1]), -1), nil
}
func TestS390XAgainstCPUInfo(t *testing.T) {
// mapping of linux feature strings to S390X fields
mapping := make(map[string]*bool)
for _, option := range Options {
mapping[option.Name] = option.Feature
}
// these must be true on the machines Go supports
mandatory := make(map[string]bool)
mandatory["zarch"] = false
mandatory["eimm"] = false
mandatory["ldisp"] = false
mandatory["stfle"] = false
features, err := getFeatureList()
if err != nil {
t.Error(err)
}
for _, feature := range features {
if _, ok := mandatory[feature]; ok {
mandatory[feature] = true
}
if flag, ok := mapping[feature]; ok {
if !*flag {
t.Errorf("feature '%v' not detected", feature)
}
} else {
t.Logf("no entry for '%v'", feature)
}
}
for k, v := range mandatory {
if !v {
t.Errorf("mandatory feature '%v' not detected", k)
}
}
}

View File

@ -0,0 +1,83 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu_test
import (
. "internal/cpu"
"internal/testenv"
"os"
"os/exec"
"runtime"
"strings"
"testing"
)
func TestMinimalFeatures(t *testing.T) {
// TODO: maybe do MustSupportFeatureDectection(t) ?
if runtime.GOARCH == "arm64" {
switch runtime.GOOS {
case "linux", "android", "darwin":
default:
t.Skipf("%s/%s is not supported", runtime.GOOS, runtime.GOARCH)
}
}
for _, o := range Options {
if o.Required && !*o.Feature {
t.Errorf("%v expected true, got false", o.Name)
}
}
}
func MustHaveDebugOptionsSupport(t *testing.T) {
if !DebugOptions {
t.Skipf("skipping test: cpu feature options not supported by OS")
}
}
func MustSupportFeatureDectection(t *testing.T) {
// TODO: add platforms that do not have CPU feature detection support.
}
func runDebugOptionsTest(t *testing.T, test string, options string) {
MustHaveDebugOptionsSupport(t)
testenv.MustHaveExec(t)
env := "GODEBUG=" + options
cmd := exec.Command(os.Args[0], "-test.run="+test)
cmd.Env = append(cmd.Env, env)
output, err := cmd.CombinedOutput()
lines := strings.Fields(string(output))
lastline := lines[len(lines)-1]
got := strings.TrimSpace(lastline)
want := "PASS"
if err != nil || got != want {
t.Fatalf("%s with %s: want %s, got %v", test, env, want, got)
}
}
func TestDisableAllCapabilities(t *testing.T) {
MustSupportFeatureDectection(t)
runDebugOptionsTest(t, "TestAllCapabilitiesDisabled", "cpu.all=off")
}
func TestAllCapabilitiesDisabled(t *testing.T) {
MustHaveDebugOptionsSupport(t)
if os.Getenv("GODEBUG") != "cpu.all=off" {
t.Skipf("skipping test: GODEBUG=cpu.all=off not set")
}
for _, o := range Options {
want := o.Required
if got := *o.Feature; got != want {
t.Errorf("%v: expected %v, got %v", o.Name, want, got)
}
}
}

View File

@ -0,0 +1,10 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
const CacheLinePadSize = 64
func doinit() {
}

View File

@ -0,0 +1,163 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build 386 amd64
package cpu
const CacheLinePadSize = 64
// cpuid is implemented in cpu_x86.s.
func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32)
// xgetbv with ecx = 0 is implemented in cpu_x86.s.
func xgetbv() (eax, edx uint32)
const (
// edx bits
cpuid_SSE2 = 1 << 26
// ecx bits
cpuid_SSE3 = 1 << 0
cpuid_PCLMULQDQ = 1 << 1
cpuid_SSSE3 = 1 << 9
cpuid_FMA = 1 << 12
cpuid_SSE41 = 1 << 19
cpuid_SSE42 = 1 << 20
cpuid_POPCNT = 1 << 23
cpuid_AES = 1 << 25
cpuid_OSXSAVE = 1 << 27
cpuid_AVX = 1 << 28
// ebx bits
cpuid_BMI1 = 1 << 3
cpuid_AVX2 = 1 << 5
cpuid_BMI2 = 1 << 8
cpuid_ERMS = 1 << 9
cpuid_ADX = 1 << 19
)
var maxExtendedFunctionInformation uint32
func doinit() {
options = []option{
{Name: "adx", Feature: &X86.HasADX},
{Name: "aes", Feature: &X86.HasAES},
{Name: "avx", Feature: &X86.HasAVX},
{Name: "avx2", Feature: &X86.HasAVX2},
{Name: "bmi1", Feature: &X86.HasBMI1},
{Name: "bmi2", Feature: &X86.HasBMI2},
{Name: "erms", Feature: &X86.HasERMS},
{Name: "fma", Feature: &X86.HasFMA},
{Name: "pclmulqdq", Feature: &X86.HasPCLMULQDQ},
{Name: "popcnt", Feature: &X86.HasPOPCNT},
{Name: "sse3", Feature: &X86.HasSSE3},
{Name: "sse41", Feature: &X86.HasSSE41},
{Name: "sse42", Feature: &X86.HasSSE42},
{Name: "ssse3", Feature: &X86.HasSSSE3},
// These capabilities should always be enabled on amd64:
{Name: "sse2", Feature: &X86.HasSSE2, Required: GOARCH == "amd64"},
}
maxID, _, _, _ := cpuid(0, 0)
if maxID < 1 {
return
}
maxExtendedFunctionInformation, _, _, _ = cpuid(0x80000000, 0)
_, _, ecx1, edx1 := cpuid(1, 0)
X86.HasSSE2 = isSet(edx1, cpuid_SSE2)
X86.HasSSE3 = isSet(ecx1, cpuid_SSE3)
X86.HasPCLMULQDQ = isSet(ecx1, cpuid_PCLMULQDQ)
X86.HasSSSE3 = isSet(ecx1, cpuid_SSSE3)
X86.HasSSE41 = isSet(ecx1, cpuid_SSE41)
X86.HasSSE42 = isSet(ecx1, cpuid_SSE42)
X86.HasPOPCNT = isSet(ecx1, cpuid_POPCNT)
X86.HasAES = isSet(ecx1, cpuid_AES)
// OSXSAVE can be false when using older Operating Systems
// or when explicitly disabled on newer Operating Systems by
// e.g. setting the xsavedisable boot option on Windows 10.
X86.HasOSXSAVE = isSet(ecx1, cpuid_OSXSAVE)
// The FMA instruction set extension only has VEX prefixed instructions.
// VEX prefixed instructions require OSXSAVE to be enabled.
// See Intel 64 and IA-32 Architecture Software Developers Manual Volume 2
// Section 2.4 "AVX and SSE Instruction Exception Specification"
X86.HasFMA = isSet(ecx1, cpuid_FMA) && X86.HasOSXSAVE
osSupportsAVX := false
// For XGETBV, OSXSAVE bit is required and sufficient.
if X86.HasOSXSAVE {
eax, _ := xgetbv()
// Check if XMM and YMM registers have OS support.
osSupportsAVX = isSet(eax, 1<<1) && isSet(eax, 1<<2)
}
X86.HasAVX = isSet(ecx1, cpuid_AVX) && osSupportsAVX
if maxID < 7 {
return
}
_, ebx7, _, _ := cpuid(7, 0)
X86.HasBMI1 = isSet(ebx7, cpuid_BMI1)
X86.HasAVX2 = isSet(ebx7, cpuid_AVX2) && osSupportsAVX
X86.HasBMI2 = isSet(ebx7, cpuid_BMI2)
X86.HasERMS = isSet(ebx7, cpuid_ERMS)
X86.HasADX = isSet(ebx7, cpuid_ADX)
}
func isSet(hwc uint32, value uint32) bool {
return hwc&value != 0
}
// Name returns the CPU name given by the vendor.
// If the CPU name can not be determined an
// empty string is returned.
func Name() string {
if maxExtendedFunctionInformation < 0x80000004 {
return ""
}
data := make([]byte, 0, 3*4*4)
var eax, ebx, ecx, edx uint32
eax, ebx, ecx, edx = cpuid(0x80000002, 0)
data = appendBytes(data, eax, ebx, ecx, edx)
eax, ebx, ecx, edx = cpuid(0x80000003, 0)
data = appendBytes(data, eax, ebx, ecx, edx)
eax, ebx, ecx, edx = cpuid(0x80000004, 0)
data = appendBytes(data, eax, ebx, ecx, edx)
// Trim leading spaces.
for len(data) > 0 && data[0] == ' ' {
data = data[1:]
}
// Trim tail after and including the first null byte.
for i, c := range data {
if c == '\x00' {
data = data[:i]
break
}
}
return string(data)
}
func appendBytes(b []byte, args ...uint32) []byte {
for _, arg := range args {
b = append(b,
byte((arg >> 0)),
byte((arg >> 8)),
byte((arg >> 16)),
byte((arg >> 24)))
}
return b
}

View File

@ -0,0 +1,26 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build 386 amd64
#include "textflag.h"
// func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32)
TEXT ·cpuid(SB), NOSPLIT, $0-24
MOVL eaxArg+0(FP), AX
MOVL ecxArg+4(FP), CX
CPUID
MOVL AX, eax+8(FP)
MOVL BX, ebx+12(FP)
MOVL CX, ecx+16(FP)
MOVL DX, edx+20(FP)
RET
// func xgetbv() (eax, edx uint32)
TEXT ·xgetbv(SB),NOSPLIT,$0-8
MOVL $0, CX
XGETBV
MOVL AX, eax+0(FP)
MOVL DX, edx+4(FP)
RET

View File

@ -0,0 +1,54 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build 386 amd64
package cpu_test
import (
. "internal/cpu"
"os"
"runtime"
"testing"
)
func TestX86ifAVX2hasAVX(t *testing.T) {
if X86.HasAVX2 && !X86.HasAVX {
t.Fatalf("HasAVX expected true when HasAVX2 is true, got false")
}
}
func TestDisableSSE2(t *testing.T) {
runDebugOptionsTest(t, "TestSSE2DebugOption", "cpu.sse2=off")
}
func TestSSE2DebugOption(t *testing.T) {
MustHaveDebugOptionsSupport(t)
if os.Getenv("GODEBUG") != "cpu.sse2=off" {
t.Skipf("skipping test: GODEBUG=cpu.sse2=off not set")
}
want := runtime.GOARCH != "386" // SSE2 can only be disabled on 386.
if got := X86.HasSSE2; got != want {
t.Errorf("X86.HasSSE2 on %s expected %v, got %v", runtime.GOARCH, want, got)
}
}
func TestDisableSSE3(t *testing.T) {
runDebugOptionsTest(t, "TestSSE3DebugOption", "cpu.sse3=off")
}
func TestSSE3DebugOption(t *testing.T) {
MustHaveDebugOptionsSupport(t)
if os.Getenv("GODEBUG") != "cpu.sse3=off" {
t.Skipf("skipping test: GODEBUG=cpu.sse3=off not set")
}
want := false
if got := X86.HasSSE3; got != want {
t.Errorf("X86.HasSSE3 expected %v, got %v", want, got)
}
}

View File

@ -0,0 +1,9 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
var (
Options = options
)

View File

@ -0,0 +1,45 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package nettrace contains internal hooks for tracing activity in
// the net package. This package is purely internal for use by the
// net/http/httptrace package and has no stable API exposed to end
// users.
package nettrace
// TraceKey is a context.Context Value key. Its associated value should
// be a *Trace struct.
type TraceKey struct{}
// LookupIPAltResolverKey is a context.Context Value key used by tests to
// specify an alternate resolver func.
// It is not exposed to outsider users. (But see issue 12503)
// The value should be the same type as lookupIP:
// func lookupIP(ctx context.Context, host string) ([]IPAddr, error)
type LookupIPAltResolverKey struct{}
// Trace contains a set of hooks for tracing events within
// the net package. Any specific hook may be nil.
type Trace struct {
// DNSStart is called with the hostname of a DNS lookup
// before it begins.
DNSStart func(name string)
// DNSDone is called after a DNS lookup completes (or fails).
// The coalesced parameter is whether singleflight de-dupped
// the call. The addrs are of type net.IPAddr but can't
// actually be for circular dependency reasons.
DNSDone func(netIPs []interface{}, coalesced bool, err error)
// ConnectStart is called before a Dial, excluding Dials made
// during DNS lookups. In the case of DualStack (Happy Eyeballs)
// dialing, this may be called multiple times, from multiple
// goroutines.
ConnectStart func(network, addr string)
// ConnectStart is called after a Dial with the results, excluding
// Dials made during DNS lookups. It may also be called multiple
// times, like ConnectStart.
ConnectDone func(network, addr string, err error)
}

View File

@ -0,0 +1,308 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package testenv provides information about what functionality
// is available in different testing environments run by the Go team.
//
// It is an internal package because these details are specific
// to the Go team's test setup (on build.golang.org) and not
// fundamental to tests in general.
package testenv
import (
"errors"
"flag"
"internal/cfg"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"testing"
)
// Builder reports the name of the builder running this test
// (for example, "linux-amd64" or "windows-386-gce").
// If the test is not running on the build infrastructure,
// Builder returns the empty string.
func Builder() string {
return os.Getenv("GO_BUILDER_NAME")
}
// HasGoBuild reports whether the current system can build programs with ``go build''
// and then run them with os.StartProcess or exec.Command.
func HasGoBuild() bool {
if os.Getenv("GO_GCFLAGS") != "" {
// It's too much work to require every caller of the go command
// to pass along "-gcflags="+os.Getenv("GO_GCFLAGS").
// For now, if $GO_GCFLAGS is set, report that we simply can't
// run go build.
return false
}
switch runtime.GOOS {
case "android", "js", "ios":
return false
}
return true
}
// MustHaveGoBuild checks that the current system can build programs with ``go build''
// and then run them with os.StartProcess or exec.Command.
// If not, MustHaveGoBuild calls t.Skip with an explanation.
func MustHaveGoBuild(t testing.TB) {
if os.Getenv("GO_GCFLAGS") != "" {
t.Skipf("skipping test: 'go build' not compatible with setting $GO_GCFLAGS")
}
if !HasGoBuild() {
t.Skipf("skipping test: 'go build' not available on %s/%s", runtime.GOOS, runtime.GOARCH)
}
}
// HasGoRun reports whether the current system can run programs with ``go run.''
func HasGoRun() bool {
// For now, having go run and having go build are the same.
return HasGoBuild()
}
// MustHaveGoRun checks that the current system can run programs with ``go run.''
// If not, MustHaveGoRun calls t.Skip with an explanation.
func MustHaveGoRun(t testing.TB) {
if !HasGoRun() {
t.Skipf("skipping test: 'go run' not available on %s/%s", runtime.GOOS, runtime.GOARCH)
}
}
// GoToolPath reports the path to the Go tool.
// It is a convenience wrapper around GoTool.
// If the tool is unavailable GoToolPath calls t.Skip.
// If the tool should be available and isn't, GoToolPath calls t.Fatal.
func GoToolPath(t testing.TB) string {
MustHaveGoBuild(t)
path, err := GoTool()
if err != nil {
t.Fatal(err)
}
// Add all environment variables that affect the Go command to test metadata.
// Cached test results will be invalidate when these variables change.
// See golang.org/issue/32285.
for _, envVar := range strings.Fields(cfg.KnownEnv) {
os.Getenv(envVar)
}
return path
}
// GoTool reports the path to the Go tool.
func GoTool() (string, error) {
if !HasGoBuild() {
return "", errors.New("platform cannot run go tool")
}
var exeSuffix string
if runtime.GOOS == "windows" {
exeSuffix = ".exe"
}
path := filepath.Join(runtime.GOROOT(), "bin", "go"+exeSuffix)
if _, err := os.Stat(path); err == nil {
return path, nil
}
goBin, err := exec.LookPath("go" + exeSuffix)
if err != nil {
return "", errors.New("cannot find go tool: " + err.Error())
}
return goBin, nil
}
// HasExec reports whether the current system can start new processes
// using os.StartProcess or (more commonly) exec.Command.
func HasExec() bool {
switch runtime.GOOS {
case "js", "ios":
return false
}
return true
}
// HasSrc reports whether the entire source tree is available under GOROOT.
func HasSrc() bool {
switch runtime.GOOS {
case "ios":
return false
}
return true
}
// MustHaveExec checks that the current system can start new processes
// using os.StartProcess or (more commonly) exec.Command.
// If not, MustHaveExec calls t.Skip with an explanation.
func MustHaveExec(t testing.TB) {
if !HasExec() {
t.Skipf("skipping test: cannot exec subprocess on %s/%s", runtime.GOOS, runtime.GOARCH)
}
}
var execPaths sync.Map // path -> error
// MustHaveExecPath checks that the current system can start the named executable
// using os.StartProcess or (more commonly) exec.Command.
// If not, MustHaveExecPath calls t.Skip with an explanation.
func MustHaveExecPath(t testing.TB, path string) {
MustHaveExec(t)
err, found := execPaths.Load(path)
if !found {
_, err = exec.LookPath(path)
err, _ = execPaths.LoadOrStore(path, err)
}
if err != nil {
t.Skipf("skipping test: %s: %s", path, err)
}
}
// HasExternalNetwork reports whether the current system can use
// external (non-localhost) networks.
func HasExternalNetwork() bool {
return !testing.Short() && runtime.GOOS != "js"
}
// MustHaveExternalNetwork checks that the current system can use
// external (non-localhost) networks.
// If not, MustHaveExternalNetwork calls t.Skip with an explanation.
func MustHaveExternalNetwork(t testing.TB) {
if runtime.GOOS == "js" {
t.Skipf("skipping test: no external network on %s", runtime.GOOS)
}
if testing.Short() {
t.Skipf("skipping test: no external network in -short mode")
}
}
var haveCGO bool
// HasCGO reports whether the current system can use cgo.
func HasCGO() bool {
return haveCGO
}
// MustHaveCGO calls t.Skip if cgo is not available.
func MustHaveCGO(t testing.TB) {
if !haveCGO {
t.Skipf("skipping test: no cgo")
}
}
// CanInternalLink reports whether the current system can link programs with
// internal linking.
// (This is the opposite of cmd/internal/sys.MustLinkExternal. Keep them in sync.)
func CanInternalLink() bool {
switch runtime.GOOS {
case "android":
if runtime.GOARCH != "arm64" {
return false
}
case "ios":
if runtime.GOARCH == "arm64" {
return false
}
}
return true
}
// MustInternalLink checks that the current system can link programs with internal
// linking.
// If not, MustInternalLink calls t.Skip with an explanation.
func MustInternalLink(t testing.TB) {
if !CanInternalLink() {
t.Skipf("skipping test: internal linking on %s/%s is not supported", runtime.GOOS, runtime.GOARCH)
}
}
// HasSymlink reports whether the current system can use os.Symlink.
func HasSymlink() bool {
ok, _ := hasSymlink()
return ok
}
// MustHaveSymlink reports whether the current system can use os.Symlink.
// If not, MustHaveSymlink calls t.Skip with an explanation.
func MustHaveSymlink(t testing.TB) {
ok, reason := hasSymlink()
if !ok {
t.Skipf("skipping test: cannot make symlinks on %s/%s%s", runtime.GOOS, runtime.GOARCH, reason)
}
}
// HasLink reports whether the current system can use os.Link.
func HasLink() bool {
// From Android release M (Marshmallow), hard linking files is blocked
// and an attempt to call link() on a file will return EACCES.
// - https://code.google.com/p/android-developer-preview/issues/detail?id=3150
return runtime.GOOS != "plan9" && runtime.GOOS != "android"
}
// MustHaveLink reports whether the current system can use os.Link.
// If not, MustHaveLink calls t.Skip with an explanation.
func MustHaveLink(t testing.TB) {
if !HasLink() {
t.Skipf("skipping test: hardlinks are not supported on %s/%s", runtime.GOOS, runtime.GOARCH)
}
}
var flaky = flag.Bool("flaky", false, "run known-flaky tests too")
func SkipFlaky(t testing.TB, issue int) {
t.Helper()
if !*flaky {
t.Skipf("skipping known flaky test without the -flaky flag; see golang.org/issue/%d", issue)
}
}
func SkipFlakyNet(t testing.TB) {
t.Helper()
if v, _ := strconv.ParseBool(os.Getenv("GO_BUILDER_FLAKY_NET")); v {
t.Skip("skipping test on builder known to have frequent network failures")
}
}
// CleanCmdEnv will fill cmd.Env with the environment, excluding certain
// variables that could modify the behavior of the Go tools such as
// GODEBUG and GOTRACEBACK.
func CleanCmdEnv(cmd *exec.Cmd) *exec.Cmd {
if cmd.Env != nil {
panic("environment already set")
}
for _, env := range os.Environ() {
// Exclude GODEBUG from the environment to prevent its output
// from breaking tests that are trying to parse other command output.
if strings.HasPrefix(env, "GODEBUG=") {
continue
}
// Exclude GOTRACEBACK for the same reason.
if strings.HasPrefix(env, "GOTRACEBACK=") {
continue
}
cmd.Env = append(cmd.Env, env)
}
return cmd
}
// CPUIsSlow reports whether the CPU running the test is suspected to be slow.
func CPUIsSlow() bool {
switch runtime.GOARCH {
case "arm", "mips", "mipsle", "mips64", "mips64le":
return true
}
return false
}
// SkipIfShortAndSlow skips t if -short is set and the CPU running the test is
// suspected to be slow.
//
// (This is useful for CPU-intensive tests that otherwise complete quickly.)
func SkipIfShortAndSlow(t testing.TB) {
if testing.Short() && CPUIsSlow() {
t.Helper()
t.Skipf("skipping test in -short mode on %s", runtime.GOARCH)
}
}

View File

@ -0,0 +1,11 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build cgo
package testenv
func init() {
haveCGO = true
}

View File

@ -0,0 +1,20 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !windows
package testenv
import (
"runtime"
)
func hasSymlink() (ok bool, reason string) {
switch runtime.GOOS {
case "android", "plan9":
return false, ""
}
return true, ""
}

View File

@ -0,0 +1,47 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testenv
import (
"os"
"path/filepath"
"sync"
"syscall"
)
var symlinkOnce sync.Once
var winSymlinkErr error
func initWinHasSymlink() {
tmpdir, err := os.MkdirTemp("", "symtest")
if err != nil {
panic("failed to create temp directory: " + err.Error())
}
defer os.RemoveAll(tmpdir)
err = os.Symlink("target", filepath.Join(tmpdir, "symlink"))
if err != nil {
err = err.(*os.LinkError).Err
switch err {
case syscall.EWINDOWS, syscall.ERROR_PRIVILEGE_NOT_HELD:
winSymlinkErr = err
}
}
}
func hasSymlink() (ok bool, reason string) {
symlinkOnce.Do(initWinHasSymlink)
switch winSymlinkErr {
case nil:
return true, ""
case syscall.EWINDOWS:
return false, ": symlinks are not supported on your version of Windows"
case syscall.ERROR_PRIVILEGE_NOT_HELD:
return false, ": you don't have enough privileges to create symlinks"
}
return false, ""
}

View File

@ -0,0 +1,132 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http_test
import (
"bufio"
"bytes"
"crypto/tls"
"crypto/x509"
"fmt"
"github.com/lesismal/llib/std/net/http/httptest"
"io"
. "net/http"
"strings"
"testing"
)
func TestNextProtoUpgrade(t *testing.T) {
setParallel(t)
defer afterTest(t)
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "path=%s,proto=", r.URL.Path)
if r.TLS != nil {
w.Write([]byte(r.TLS.NegotiatedProtocol))
}
if r.RemoteAddr == "" {
t.Error("request with no RemoteAddr")
}
if r.Body == nil {
t.Errorf("request with nil Body")
}
}))
ts.TLS = &tls.Config{
NextProtos: []string{"unhandled-proto", "tls-0.9"},
}
ts.Config.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){
"tls-0.9": handleTLSProtocol09,
}
ts.StartTLS()
defer ts.Close()
// Normal request, without NPN.
{
c := ts.Client()
res, err := c.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if want := "path=/,proto="; string(body) != want {
t.Errorf("plain request = %q; want %q", body, want)
}
}
// Request to an advertised but unhandled NPN protocol.
// Server will hang up.
{
certPool := x509.NewCertPool()
certPool.AddCert(ts.Certificate())
tr := &Transport{
TLSClientConfig: &tls.Config{
RootCAs: certPool,
NextProtos: []string{"unhandled-proto"},
},
}
defer tr.CloseIdleConnections()
c := &Client{
Transport: tr,
}
res, err := c.Get(ts.URL)
if err == nil {
defer res.Body.Close()
var buf bytes.Buffer
res.Write(&buf)
t.Errorf("expected error on unhandled-proto request; got: %s", buf.Bytes())
}
}
// Request using the "tls-0.9" protocol, which we register here.
// It is HTTP/0.9 over TLS.
{
c := ts.Client()
tlsConfig := c.Transport.(*Transport).TLSClientConfig
tlsConfig.NextProtos = []string{"tls-0.9"}
conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
if err != nil {
t.Fatal(err)
}
conn.Write([]byte("GET /foo\n"))
body, err := io.ReadAll(conn)
if err != nil {
t.Fatal(err)
}
if want := "path=/foo,proto=tls-0.9"; string(body) != want {
t.Errorf("plain request = %q; want %q", body, want)
}
}
}
// handleTLSProtocol09 implements the HTTP/0.9 protocol over TLS, for the
// TestNextProtoUpgrade test.
func handleTLSProtocol09(srv *Server, conn *tls.Conn, h Handler) {
br := bufio.NewReader(conn)
line, err := br.ReadString('\n')
if err != nil {
return
}
line = strings.TrimSpace(line)
path := strings.TrimPrefix(line, "GET ")
if path == line {
return
}
req, _ := NewRequest("GET", path, nil)
req.Proto = "HTTP/0.9"
req.ProtoMajor = 0
req.ProtoMinor = 9
rw := &http09Writer{conn, make(Header)}
h.ServeHTTP(rw, req)
}
type http09Writer struct {
io.Writer
h Header
}
func (w http09Writer) Header() Header { return w.h }
func (w http09Writer) WriteHeader(int) {} // no headers

View File

@ -0,0 +1,220 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements CGI from the perspective of a child
// process.
package cgi
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
)
// Request returns the HTTP request as represented in the current
// environment. This assumes the current program is being run
// by a web server in a CGI environment.
// The returned Request's Body is populated, if applicable.
func Request() (*http.Request, error) {
r, err := RequestFromMap(envMap(os.Environ()))
if err != nil {
return nil, err
}
if r.ContentLength > 0 {
r.Body = io.NopCloser(io.LimitReader(os.Stdin, r.ContentLength))
}
return r, nil
}
func envMap(env []string) map[string]string {
m := make(map[string]string)
for _, kv := range env {
if idx := strings.Index(kv, "="); idx != -1 {
m[kv[:idx]] = kv[idx+1:]
}
}
return m
}
// RequestFromMap creates an http.Request from CGI variables.
// The returned Request's Body field is not populated.
func RequestFromMap(params map[string]string) (*http.Request, error) {
r := new(http.Request)
r.Method = params["REQUEST_METHOD"]
if r.Method == "" {
return nil, errors.New("cgi: no REQUEST_METHOD in environment")
}
r.Proto = params["SERVER_PROTOCOL"]
var ok bool
r.ProtoMajor, r.ProtoMinor, ok = http.ParseHTTPVersion(r.Proto)
if !ok {
return nil, errors.New("cgi: invalid SERVER_PROTOCOL version")
}
r.Close = true
r.Trailer = http.Header{}
r.Header = http.Header{}
r.Host = params["HTTP_HOST"]
if lenstr := params["CONTENT_LENGTH"]; lenstr != "" {
clen, err := strconv.ParseInt(lenstr, 10, 64)
if err != nil {
return nil, errors.New("cgi: bad CONTENT_LENGTH in environment: " + lenstr)
}
r.ContentLength = clen
}
if ct := params["CONTENT_TYPE"]; ct != "" {
r.Header.Set("Content-Type", ct)
}
// Copy "HTTP_FOO_BAR" variables to "Foo-Bar" Headers
for k, v := range params {
if !strings.HasPrefix(k, "HTTP_") || k == "HTTP_HOST" {
continue
}
r.Header.Add(strings.ReplaceAll(k[5:], "_", "-"), v)
}
uriStr := params["REQUEST_URI"]
if uriStr == "" {
// Fallback to SCRIPT_NAME, PATH_INFO and QUERY_STRING.
uriStr = params["SCRIPT_NAME"] + params["PATH_INFO"]
s := params["QUERY_STRING"]
if s != "" {
uriStr += "?" + s
}
}
// There's apparently a de-facto standard for this.
// https://web.archive.org/web/20170105004655/http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636
if s := params["HTTPS"]; s == "on" || s == "ON" || s == "1" {
r.TLS = &tls.ConnectionState{HandshakeComplete: true}
}
if r.Host != "" {
// Hostname is provided, so we can reasonably construct a URL.
rawurl := r.Host + uriStr
if r.TLS == nil {
rawurl = "http://" + rawurl
} else {
rawurl = "https://" + rawurl
}
url, err := url.Parse(rawurl)
if err != nil {
return nil, errors.New("cgi: failed to parse host and REQUEST_URI into a URL: " + rawurl)
}
r.URL = url
}
// Fallback logic if we don't have a Host header or the URL
// failed to parse
if r.URL == nil {
url, err := url.Parse(uriStr)
if err != nil {
return nil, errors.New("cgi: failed to parse REQUEST_URI into a URL: " + uriStr)
}
r.URL = url
}
// Request.RemoteAddr has its port set by Go's standard http
// server, so we do here too.
remotePort, _ := strconv.Atoi(params["REMOTE_PORT"]) // zero if unset or invalid
r.RemoteAddr = net.JoinHostPort(params["REMOTE_ADDR"], strconv.Itoa(remotePort))
return r, nil
}
// Serve executes the provided Handler on the currently active CGI
// request, if any. If there's no current CGI environment
// an error is returned. The provided handler may be nil to use
// http.DefaultServeMux.
func Serve(handler http.Handler) error {
req, err := Request()
if err != nil {
return err
}
if req.Body == nil {
req.Body = http.NoBody
}
if handler == nil {
handler = http.DefaultServeMux
}
rw := &response{
req: req,
header: make(http.Header),
bufw: bufio.NewWriter(os.Stdout),
}
handler.ServeHTTP(rw, req)
rw.Write(nil) // make sure a response is sent
if err = rw.bufw.Flush(); err != nil {
return err
}
return nil
}
type response struct {
req *http.Request
header http.Header
code int
wroteHeader bool
wroteCGIHeader bool
bufw *bufio.Writer
}
func (r *response) Flush() {
r.bufw.Flush()
}
func (r *response) Header() http.Header {
return r.header
}
func (r *response) Write(p []byte) (n int, err error) {
if !r.wroteHeader {
r.WriteHeader(http.StatusOK)
}
if !r.wroteCGIHeader {
r.writeCGIHeader(p)
}
return r.bufw.Write(p)
}
func (r *response) WriteHeader(code int) {
if r.wroteHeader {
// Note: explicitly using Stderr, as Stdout is our HTTP output.
fmt.Fprintf(os.Stderr, "CGI attempted to write header twice on request for %s", r.req.URL)
return
}
r.wroteHeader = true
r.code = code
}
// writeCGIHeader finalizes the header sent to the client and writes it to the output.
// p is not written by writeHeader, but is the first chunk of the body
// that will be written. It is sniffed for a Content-Type if none is
// set explicitly.
func (r *response) writeCGIHeader(p []byte) {
if r.wroteCGIHeader {
return
}
r.wroteCGIHeader = true
fmt.Fprintf(r.bufw, "Status: %d %s\r\n", r.code, http.StatusText(r.code))
if _, hasType := r.header["Content-Type"]; !hasType {
r.header.Set("Content-Type", http.DetectContentType(p))
}
r.header.Write(r.bufw)
r.bufw.WriteString("\r\n")
r.bufw.Flush()
}

View File

@ -0,0 +1,208 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Tests for CGI (the child process perspective)
package cgi
import (
"bufio"
"bytes"
"github.com/lesismal/llib/std/net/http/httptest"
"net/http"
"strings"
"testing"
)
func TestRequest(t *testing.T) {
env := map[string]string{
"SERVER_PROTOCOL": "HTTP/1.1",
"REQUEST_METHOD": "GET",
"HTTP_HOST": "example.com",
"HTTP_REFERER": "elsewhere",
"HTTP_USER_AGENT": "goclient",
"HTTP_FOO_BAR": "baz",
"REQUEST_URI": "/path?a=b",
"CONTENT_LENGTH": "123",
"CONTENT_TYPE": "text/xml",
"REMOTE_ADDR": "5.6.7.8",
"REMOTE_PORT": "54321",
}
req, err := RequestFromMap(env)
if err != nil {
t.Fatalf("RequestFromMap: %v", err)
}
if g, e := req.UserAgent(), "goclient"; e != g {
t.Errorf("expected UserAgent %q; got %q", e, g)
}
if g, e := req.Method, "GET"; e != g {
t.Errorf("expected Method %q; got %q", e, g)
}
if g, e := req.Header.Get("Content-Type"), "text/xml"; e != g {
t.Errorf("expected Content-Type %q; got %q", e, g)
}
if g, e := req.ContentLength, int64(123); e != g {
t.Errorf("expected ContentLength %d; got %d", e, g)
}
if g, e := req.Referer(), "elsewhere"; e != g {
t.Errorf("expected Referer %q; got %q", e, g)
}
if req.Header == nil {
t.Fatalf("unexpected nil Header")
}
if g, e := req.Header.Get("Foo-Bar"), "baz"; e != g {
t.Errorf("expected Foo-Bar %q; got %q", e, g)
}
if g, e := req.URL.String(), "http://example.com/path?a=b"; e != g {
t.Errorf("expected URL %q; got %q", e, g)
}
if g, e := req.FormValue("a"), "b"; e != g {
t.Errorf("expected FormValue(a) %q; got %q", e, g)
}
if req.Trailer == nil {
t.Errorf("unexpected nil Trailer")
}
if req.TLS != nil {
t.Errorf("expected nil TLS")
}
if e, g := "5.6.7.8:54321", req.RemoteAddr; e != g {
t.Errorf("RemoteAddr: got %q; want %q", g, e)
}
}
func TestRequestWithTLS(t *testing.T) {
env := map[string]string{
"SERVER_PROTOCOL": "HTTP/1.1",
"REQUEST_METHOD": "GET",
"HTTP_HOST": "example.com",
"HTTP_REFERER": "elsewhere",
"REQUEST_URI": "/path?a=b",
"CONTENT_TYPE": "text/xml",
"HTTPS": "1",
"REMOTE_ADDR": "5.6.7.8",
}
req, err := RequestFromMap(env)
if err != nil {
t.Fatalf("RequestFromMap: %v", err)
}
if g, e := req.URL.String(), "https://example.com/path?a=b"; e != g {
t.Errorf("expected URL %q; got %q", e, g)
}
if req.TLS == nil {
t.Errorf("expected non-nil TLS")
}
}
func TestRequestWithoutHost(t *testing.T) {
env := map[string]string{
"SERVER_PROTOCOL": "HTTP/1.1",
"HTTP_HOST": "",
"REQUEST_METHOD": "GET",
"REQUEST_URI": "/path?a=b",
"CONTENT_LENGTH": "123",
}
req, err := RequestFromMap(env)
if err != nil {
t.Fatalf("RequestFromMap: %v", err)
}
if req.URL == nil {
t.Fatalf("unexpected nil URL")
}
if g, e := req.URL.String(), "/path?a=b"; e != g {
t.Errorf("URL = %q; want %q", g, e)
}
}
func TestRequestWithoutRequestURI(t *testing.T) {
env := map[string]string{
"SERVER_PROTOCOL": "HTTP/1.1",
"HTTP_HOST": "example.com",
"REQUEST_METHOD": "GET",
"SCRIPT_NAME": "/dir/scriptname",
"PATH_INFO": "/p1/p2",
"QUERY_STRING": "a=1&b=2",
"CONTENT_LENGTH": "123",
}
req, err := RequestFromMap(env)
if err != nil {
t.Fatalf("RequestFromMap: %v", err)
}
if req.URL == nil {
t.Fatalf("unexpected nil URL")
}
if g, e := req.URL.String(), "http://example.com/dir/scriptname/p1/p2?a=1&b=2"; e != g {
t.Errorf("URL = %q; want %q", g, e)
}
}
func TestRequestWithoutRemotePort(t *testing.T) {
env := map[string]string{
"SERVER_PROTOCOL": "HTTP/1.1",
"HTTP_HOST": "example.com",
"REQUEST_METHOD": "GET",
"REQUEST_URI": "/path?a=b",
"CONTENT_LENGTH": "123",
"REMOTE_ADDR": "5.6.7.8",
}
req, err := RequestFromMap(env)
if err != nil {
t.Fatalf("RequestFromMap: %v", err)
}
if e, g := "5.6.7.8:0", req.RemoteAddr; e != g {
t.Errorf("RemoteAddr: got %q; want %q", g, e)
}
}
func TestResponse(t *testing.T) {
var tests = []struct {
name string
body string
wantCT string
}{
{
name: "no body",
wantCT: "text/plain; charset=utf-8",
},
{
name: "html",
body: "<html><head><title>test page</title></head><body>This is a body</body></html>",
wantCT: "text/html; charset=utf-8",
},
{
name: "text",
body: strings.Repeat("gopher", 86),
wantCT: "text/plain; charset=utf-8",
},
{
name: "jpg",
body: "\xFF\xD8\xFF" + strings.Repeat("B", 1024),
wantCT: "image/jpeg",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
resp := response{
req: httptest.NewRequest("GET", "/", nil),
header: http.Header{},
bufw: bufio.NewWriter(&buf),
}
n, err := resp.Write([]byte(tt.body))
if err != nil {
t.Errorf("Write: unexpected %v", err)
}
if want := len(tt.body); n != want {
t.Errorf("reported short Write: got %v want %v", n, want)
}
resp.writeCGIHeader(nil)
resp.Flush()
if got := resp.Header().Get("Content-Type"); got != tt.wantCT {
t.Errorf("wrong content-type: got %q, want %q", got, tt.wantCT)
}
if !bytes.HasSuffix(buf.Bytes(), []byte(tt.body)) {
t.Errorf("body was not correctly written")
}
})
}
}

View File

@ -0,0 +1,408 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements the host side of CGI (being the webserver
// parent process).
// Package cgi implements CGI (Common Gateway Interface) as specified
// in RFC 3875.
//
// Note that using CGI means starting a new process to handle each
// request, which is typically less efficient than using a
// long-running server. This package is intended primarily for
// compatibility with existing systems.
package cgi
import (
"bufio"
"fmt"
"io"
"log"
"net"
"net/http"
"net/textproto"
"os"
"os/exec"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
"golang.org/x/net/http/httpguts"
)
var trailingPort = regexp.MustCompile(`:([0-9]+)$`)
var osDefaultInheritEnv = func() []string {
switch runtime.GOOS {
case "darwin", "ios":
return []string{"DYLD_LIBRARY_PATH"}
case "linux", "freebsd", "netbsd", "openbsd":
return []string{"LD_LIBRARY_PATH"}
case "hpux":
return []string{"LD_LIBRARY_PATH", "SHLIB_PATH"}
case "irix":
return []string{"LD_LIBRARY_PATH", "LD_LIBRARYN32_PATH", "LD_LIBRARY64_PATH"}
case "illumos", "solaris":
return []string{"LD_LIBRARY_PATH", "LD_LIBRARY_PATH_32", "LD_LIBRARY_PATH_64"}
case "windows":
return []string{"SystemRoot", "COMSPEC", "PATHEXT", "WINDIR"}
}
return nil
}()
// Handler runs an executable in a subprocess with a CGI environment.
type Handler struct {
Path string // path to the CGI executable
Root string // root URI prefix of handler or empty for "/"
// Dir specifies the CGI executable's working directory.
// If Dir is empty, the base directory of Path is used.
// If Path has no base directory, the current working
// directory is used.
Dir string
Env []string // extra environment variables to set, if any, as "key=value"
InheritEnv []string // environment variables to inherit from host, as "key"
Logger *log.Logger // optional log for errors or nil to use log.Print
Args []string // optional arguments to pass to child process
Stderr io.Writer // optional stderr for the child process; nil means os.Stderr
// PathLocationHandler specifies the root http Handler that
// should handle internal redirects when the CGI process
// returns a Location header value starting with a "/", as
// specified in RFC 3875 § 6.3.2. This will likely be
// http.DefaultServeMux.
//
// If nil, a CGI response with a local URI path is instead sent
// back to the client and not redirected internally.
PathLocationHandler http.Handler
}
func (h *Handler) stderr() io.Writer {
if h.Stderr != nil {
return h.Stderr
}
return os.Stderr
}
// removeLeadingDuplicates remove leading duplicate in environments.
// It's possible to override environment like following.
// cgi.Handler{
// ...
// Env: []string{"SCRIPT_FILENAME=foo.php"},
// }
func removeLeadingDuplicates(env []string) (ret []string) {
for i, e := range env {
found := false
if eq := strings.IndexByte(e, '='); eq != -1 {
keq := e[:eq+1] // "key="
for _, e2 := range env[i+1:] {
if strings.HasPrefix(e2, keq) {
found = true
break
}
}
}
if !found {
ret = append(ret, e)
}
}
return
}
func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
root := h.Root
if root == "" {
root = "/"
}
if len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" {
rw.WriteHeader(http.StatusBadRequest)
rw.Write([]byte("Chunked request bodies are not supported by CGI."))
return
}
pathInfo := req.URL.Path
if root != "/" && strings.HasPrefix(pathInfo, root) {
pathInfo = pathInfo[len(root):]
}
port := "80"
if matches := trailingPort.FindStringSubmatch(req.Host); len(matches) != 0 {
port = matches[1]
}
env := []string{
"SERVER_SOFTWARE=go",
"SERVER_NAME=" + req.Host,
"SERVER_PROTOCOL=HTTP/1.1",
"HTTP_HOST=" + req.Host,
"GATEWAY_INTERFACE=CGI/1.1",
"REQUEST_METHOD=" + req.Method,
"QUERY_STRING=" + req.URL.RawQuery,
"REQUEST_URI=" + req.URL.RequestURI(),
"PATH_INFO=" + pathInfo,
"SCRIPT_NAME=" + root,
"SCRIPT_FILENAME=" + h.Path,
"SERVER_PORT=" + port,
}
if remoteIP, remotePort, err := net.SplitHostPort(req.RemoteAddr); err == nil {
env = append(env, "REMOTE_ADDR="+remoteIP, "REMOTE_HOST="+remoteIP, "REMOTE_PORT="+remotePort)
} else {
// could not parse ip:port, let's use whole RemoteAddr and leave REMOTE_PORT undefined
env = append(env, "REMOTE_ADDR="+req.RemoteAddr, "REMOTE_HOST="+req.RemoteAddr)
}
if req.TLS != nil {
env = append(env, "HTTPS=on")
}
for k, v := range req.Header {
k = strings.Map(upperCaseAndUnderscore, k)
if k == "PROXY" {
// See Issue 16405
continue
}
joinStr := ", "
if k == "COOKIE" {
joinStr = "; "
}
env = append(env, "HTTP_"+k+"="+strings.Join(v, joinStr))
}
if req.ContentLength > 0 {
env = append(env, fmt.Sprintf("CONTENT_LENGTH=%d", req.ContentLength))
}
if ctype := req.Header.Get("Content-Type"); ctype != "" {
env = append(env, "CONTENT_TYPE="+ctype)
}
envPath := os.Getenv("PATH")
if envPath == "" {
envPath = "/bin:/usr/bin:/usr/ucb:/usr/bsd:/usr/local/bin"
}
env = append(env, "PATH="+envPath)
for _, e := range h.InheritEnv {
if v := os.Getenv(e); v != "" {
env = append(env, e+"="+v)
}
}
for _, e := range osDefaultInheritEnv {
if v := os.Getenv(e); v != "" {
env = append(env, e+"="+v)
}
}
if h.Env != nil {
env = append(env, h.Env...)
}
env = removeLeadingDuplicates(env)
var cwd, path string
if h.Dir != "" {
path = h.Path
cwd = h.Dir
} else {
cwd, path = filepath.Split(h.Path)
}
if cwd == "" {
cwd = "."
}
internalError := func(err error) {
rw.WriteHeader(http.StatusInternalServerError)
h.printf("CGI error: %v", err)
}
cmd := &exec.Cmd{
Path: path,
Args: append([]string{h.Path}, h.Args...),
Dir: cwd,
Env: env,
Stderr: h.stderr(),
}
if req.ContentLength != 0 {
cmd.Stdin = req.Body
}
stdoutRead, err := cmd.StdoutPipe()
if err != nil {
internalError(err)
return
}
err = cmd.Start()
if err != nil {
internalError(err)
return
}
if hook := testHookStartProcess; hook != nil {
hook(cmd.Process)
}
defer cmd.Wait()
defer stdoutRead.Close()
linebody := bufio.NewReaderSize(stdoutRead, 1024)
headers := make(http.Header)
statusCode := 0
headerLines := 0
sawBlankLine := false
for {
line, isPrefix, err := linebody.ReadLine()
if isPrefix {
rw.WriteHeader(http.StatusInternalServerError)
h.printf("cgi: long header line from subprocess.")
return
}
if err == io.EOF {
break
}
if err != nil {
rw.WriteHeader(http.StatusInternalServerError)
h.printf("cgi: error reading headers: %v", err)
return
}
if len(line) == 0 {
sawBlankLine = true
break
}
headerLines++
parts := strings.SplitN(string(line), ":", 2)
if len(parts) < 2 {
h.printf("cgi: bogus header line: %s", string(line))
continue
}
header, val := parts[0], parts[1]
if !httpguts.ValidHeaderFieldName(header) {
h.printf("cgi: invalid header name: %q", header)
continue
}
val = textproto.TrimString(val)
switch {
case header == "Status":
if len(val) < 3 {
h.printf("cgi: bogus status (short): %q", val)
return
}
code, err := strconv.Atoi(val[0:3])
if err != nil {
h.printf("cgi: bogus status: %q", val)
h.printf("cgi: line was %q", line)
return
}
statusCode = code
default:
headers.Add(header, val)
}
}
if headerLines == 0 || !sawBlankLine {
rw.WriteHeader(http.StatusInternalServerError)
h.printf("cgi: no headers")
return
}
if loc := headers.Get("Location"); loc != "" {
if strings.HasPrefix(loc, "/") && h.PathLocationHandler != nil {
h.handleInternalRedirect(rw, req, loc)
return
}
if statusCode == 0 {
statusCode = http.StatusFound
}
}
if statusCode == 0 && headers.Get("Content-Type") == "" {
rw.WriteHeader(http.StatusInternalServerError)
h.printf("cgi: missing required Content-Type in headers")
return
}
if statusCode == 0 {
statusCode = http.StatusOK
}
// Copy headers to rw's headers, after we've decided not to
// go into handleInternalRedirect, which won't want its rw
// headers to have been touched.
for k, vv := range headers {
for _, v := range vv {
rw.Header().Add(k, v)
}
}
rw.WriteHeader(statusCode)
_, err = io.Copy(rw, linebody)
if err != nil {
h.printf("cgi: copy error: %v", err)
// And kill the child CGI process so we don't hang on
// the deferred cmd.Wait above if the error was just
// the client (rw) going away. If it was a read error
// (because the child died itself), then the extra
// kill of an already-dead process is harmless (the PID
// won't be reused until the Wait above).
cmd.Process.Kill()
}
}
func (h *Handler) printf(format string, v ...interface{}) {
if h.Logger != nil {
h.Logger.Printf(format, v...)
} else {
log.Printf(format, v...)
}
}
func (h *Handler) handleInternalRedirect(rw http.ResponseWriter, req *http.Request, path string) {
url, err := req.URL.Parse(path)
if err != nil {
rw.WriteHeader(http.StatusInternalServerError)
h.printf("cgi: error resolving local URI path %q: %v", path, err)
return
}
// TODO: RFC 3875 isn't clear if only GET is supported, but it
// suggests so: "Note that any message-body attached to the
// request (such as for a POST request) may not be available
// to the resource that is the target of the redirect." We
// should do some tests against Apache to see how it handles
// POST, HEAD, etc. Does the internal redirect get the same
// method or just GET? What about incoming headers?
// (e.g. Cookies) Which headers, if any, are copied into the
// second request?
newReq := &http.Request{
Method: "GET",
URL: url,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
Host: url.Host,
RemoteAddr: req.RemoteAddr,
TLS: req.TLS,
}
h.PathLocationHandler.ServeHTTP(rw, newReq)
}
func upperCaseAndUnderscore(r rune) rune {
switch {
case r >= 'a' && r <= 'z':
return r - ('a' - 'A')
case r == '-':
return '_'
case r == '=':
// Maybe not part of the CGI 'spec' but would mess up
// the environment in any case, as Go represents the
// environment as a slice of "key=value" strings.
return '_'
}
// TODO: other transformations in spec or practice?
return r
}
var testHookStartProcess func(*os.Process) // nil except for some tests

View File

@ -0,0 +1,578 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Tests for package cgi
package cgi
import (
"bufio"
"bytes"
"fmt"
"github.com/lesismal/llib/std/net/http/httptest"
"io"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
"testing"
"time"
)
func newRequest(httpreq string) *http.Request {
buf := bufio.NewReader(strings.NewReader(httpreq))
req, err := http.ReadRequest(buf)
if err != nil {
panic("cgi: bogus http request in test: " + httpreq)
}
req.RemoteAddr = "1.2.3.4:1234"
return req
}
func runCgiTest(t *testing.T, h *Handler,
httpreq string,
expectedMap map[string]string, checks ...func(reqInfo map[string]string)) *httptest.ResponseRecorder {
rw := httptest.NewRecorder()
req := newRequest(httpreq)
h.ServeHTTP(rw, req)
runResponseChecks(t, rw, expectedMap, checks...)
return rw
}
func runResponseChecks(t *testing.T, rw *httptest.ResponseRecorder,
expectedMap map[string]string, checks ...func(reqInfo map[string]string)) {
// Make a map to hold the test map that the CGI returns.
m := make(map[string]string)
m["_body"] = rw.Body.String()
linesRead := 0
readlines:
for {
line, err := rw.Body.ReadString('\n')
switch {
case err == io.EOF:
break readlines
case err != nil:
t.Fatalf("unexpected error reading from CGI: %v", err)
}
linesRead++
trimmedLine := strings.TrimRight(line, "\r\n")
split := strings.SplitN(trimmedLine, "=", 2)
if len(split) != 2 {
t.Fatalf("Unexpected %d parts from invalid line number %v: %q; existing map=%v",
len(split), linesRead, line, m)
}
m[split[0]] = split[1]
}
for key, expected := range expectedMap {
got := m[key]
if key == "cwd" {
// For Windows. golang.org/issue/4645.
fi1, _ := os.Stat(got)
fi2, _ := os.Stat(expected)
if os.SameFile(fi1, fi2) {
got = expected
}
}
if got != expected {
t.Errorf("for key %q got %q; expected %q", key, got, expected)
}
}
for _, check := range checks {
check(m)
}
}
var cgiTested, cgiWorks bool
func check(t *testing.T) {
if !cgiTested {
cgiTested = true
cgiWorks = exec.Command("./testdata/test.cgi").Run() == nil
}
if !cgiWorks {
// No Perl on Windows, needed by test.cgi
// TODO: make the child process be Go, not Perl.
t.Skip("Skipping test: test.cgi failed.")
}
}
func TestCGIBasicGet(t *testing.T) {
check(t)
h := &Handler{
Path: "testdata/test.cgi",
Root: "/test.cgi",
}
expectedMap := map[string]string{
"test": "Hello CGI",
"param-a": "b",
"param-foo": "bar",
"env-GATEWAY_INTERFACE": "CGI/1.1",
"env-HTTP_HOST": "example.com",
"env-PATH_INFO": "",
"env-QUERY_STRING": "foo=bar&a=b",
"env-REMOTE_ADDR": "1.2.3.4",
"env-REMOTE_HOST": "1.2.3.4",
"env-REMOTE_PORT": "1234",
"env-REQUEST_METHOD": "GET",
"env-REQUEST_URI": "/test.cgi?foo=bar&a=b",
"env-SCRIPT_FILENAME": "testdata/test.cgi",
"env-SCRIPT_NAME": "/test.cgi",
"env-SERVER_NAME": "example.com",
"env-SERVER_PORT": "80",
"env-SERVER_SOFTWARE": "go",
}
replay := runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
if expected, got := "text/html", replay.Header().Get("Content-Type"); got != expected {
t.Errorf("got a Content-Type of %q; expected %q", got, expected)
}
if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected {
t.Errorf("got a X-Test-Header of %q; expected %q", got, expected)
}
}
func TestCGIEnvIPv6(t *testing.T) {
check(t)
h := &Handler{
Path: "testdata/test.cgi",
Root: "/test.cgi",
}
expectedMap := map[string]string{
"test": "Hello CGI",
"param-a": "b",
"param-foo": "bar",
"env-GATEWAY_INTERFACE": "CGI/1.1",
"env-HTTP_HOST": "example.com",
"env-PATH_INFO": "",
"env-QUERY_STRING": "foo=bar&a=b",
"env-REMOTE_ADDR": "2000::3000",
"env-REMOTE_HOST": "2000::3000",
"env-REMOTE_PORT": "12345",
"env-REQUEST_METHOD": "GET",
"env-REQUEST_URI": "/test.cgi?foo=bar&a=b",
"env-SCRIPT_FILENAME": "testdata/test.cgi",
"env-SCRIPT_NAME": "/test.cgi",
"env-SERVER_NAME": "example.com",
"env-SERVER_PORT": "80",
"env-SERVER_SOFTWARE": "go",
}
rw := httptest.NewRecorder()
req := newRequest("GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n")
req.RemoteAddr = "[2000::3000]:12345"
h.ServeHTTP(rw, req)
runResponseChecks(t, rw, expectedMap)
}
func TestCGIBasicGetAbsPath(t *testing.T) {
check(t)
pwd, err := os.Getwd()
if err != nil {
t.Fatalf("getwd error: %v", err)
}
h := &Handler{
Path: pwd + "/testdata/test.cgi",
Root: "/test.cgi",
}
expectedMap := map[string]string{
"env-REQUEST_URI": "/test.cgi?foo=bar&a=b",
"env-SCRIPT_FILENAME": pwd + "/testdata/test.cgi",
"env-SCRIPT_NAME": "/test.cgi",
}
runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
}
func TestPathInfo(t *testing.T) {
check(t)
h := &Handler{
Path: "testdata/test.cgi",
Root: "/test.cgi",
}
expectedMap := map[string]string{
"param-a": "b",
"env-PATH_INFO": "/extrapath",
"env-QUERY_STRING": "a=b",
"env-REQUEST_URI": "/test.cgi/extrapath?a=b",
"env-SCRIPT_FILENAME": "testdata/test.cgi",
"env-SCRIPT_NAME": "/test.cgi",
}
runCgiTest(t, h, "GET /test.cgi/extrapath?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
}
func TestPathInfoDirRoot(t *testing.T) {
check(t)
h := &Handler{
Path: "testdata/test.cgi",
Root: "/myscript/",
}
expectedMap := map[string]string{
"env-PATH_INFO": "bar",
"env-QUERY_STRING": "a=b",
"env-REQUEST_URI": "/myscript/bar?a=b",
"env-SCRIPT_FILENAME": "testdata/test.cgi",
"env-SCRIPT_NAME": "/myscript/",
}
runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
}
func TestDupHeaders(t *testing.T) {
check(t)
h := &Handler{
Path: "testdata/test.cgi",
}
expectedMap := map[string]string{
"env-REQUEST_URI": "/myscript/bar?a=b",
"env-SCRIPT_FILENAME": "testdata/test.cgi",
"env-HTTP_COOKIE": "nom=NOM; yum=YUM",
"env-HTTP_X_FOO": "val1, val2",
}
runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\n"+
"Cookie: nom=NOM\n"+
"Cookie: yum=YUM\n"+
"X-Foo: val1\n"+
"X-Foo: val2\n"+
"Host: example.com\n\n",
expectedMap)
}
// Issue 16405: CGI+http.Transport differing uses of HTTP_PROXY.
// Verify we don't set the HTTP_PROXY environment variable.
// Hope nobody was depending on it. It's not a known header, though.
func TestDropProxyHeader(t *testing.T) {
check(t)
h := &Handler{
Path: "testdata/test.cgi",
}
expectedMap := map[string]string{
"env-REQUEST_URI": "/myscript/bar?a=b",
"env-SCRIPT_FILENAME": "testdata/test.cgi",
"env-HTTP_X_FOO": "a",
}
runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\n"+
"X-Foo: a\n"+
"Proxy: should_be_stripped\n"+
"Host: example.com\n\n",
expectedMap,
func(reqInfo map[string]string) {
if v, ok := reqInfo["env-HTTP_PROXY"]; ok {
t.Errorf("HTTP_PROXY = %q; should be absent", v)
}
})
}
func TestPathInfoNoRoot(t *testing.T) {
check(t)
h := &Handler{
Path: "testdata/test.cgi",
Root: "",
}
expectedMap := map[string]string{
"env-PATH_INFO": "/bar",
"env-QUERY_STRING": "a=b",
"env-REQUEST_URI": "/bar?a=b",
"env-SCRIPT_FILENAME": "testdata/test.cgi",
"env-SCRIPT_NAME": "/",
}
runCgiTest(t, h, "GET /bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
}
func TestCGIBasicPost(t *testing.T) {
check(t)
postReq := `POST /test.cgi?a=b HTTP/1.0
Host: example.com
Content-Type: application/x-www-form-urlencoded
Content-Length: 15
postfoo=postbar`
h := &Handler{
Path: "testdata/test.cgi",
Root: "/test.cgi",
}
expectedMap := map[string]string{
"test": "Hello CGI",
"param-postfoo": "postbar",
"env-REQUEST_METHOD": "POST",
"env-CONTENT_LENGTH": "15",
"env-REQUEST_URI": "/test.cgi?a=b",
}
runCgiTest(t, h, postReq, expectedMap)
}
func chunk(s string) string {
return fmt.Sprintf("%x\r\n%s\r\n", len(s), s)
}
// The CGI spec doesn't allow chunked requests.
func TestCGIPostChunked(t *testing.T) {
check(t)
postReq := `POST /test.cgi?a=b HTTP/1.1
Host: example.com
Content-Type: application/x-www-form-urlencoded
Transfer-Encoding: chunked
` + chunk("postfoo") + chunk("=") + chunk("postbar") + chunk("")
h := &Handler{
Path: "testdata/test.cgi",
Root: "/test.cgi",
}
expectedMap := map[string]string{}
resp := runCgiTest(t, h, postReq, expectedMap)
if got, expected := resp.Code, http.StatusBadRequest; got != expected {
t.Fatalf("Expected %v response code from chunked request body; got %d",
expected, got)
}
}
func TestRedirect(t *testing.T) {
check(t)
h := &Handler{
Path: "testdata/test.cgi",
Root: "/test.cgi",
}
rec := runCgiTest(t, h, "GET /test.cgi?loc=http://foo.com/ HTTP/1.0\nHost: example.com\n\n", nil)
if e, g := 302, rec.Code; e != g {
t.Errorf("expected status code %d; got %d", e, g)
}
if e, g := "http://foo.com/", rec.Header().Get("Location"); e != g {
t.Errorf("expected Location header of %q; got %q", e, g)
}
}
func TestInternalRedirect(t *testing.T) {
check(t)
baseHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
fmt.Fprintf(rw, "basepath=%s\n", req.URL.Path)
fmt.Fprintf(rw, "remoteaddr=%s\n", req.RemoteAddr)
})
h := &Handler{
Path: "testdata/test.cgi",
Root: "/test.cgi",
PathLocationHandler: baseHandler,
}
expectedMap := map[string]string{
"basepath": "/foo",
"remoteaddr": "1.2.3.4:1234",
}
runCgiTest(t, h, "GET /test.cgi?loc=/foo HTTP/1.0\nHost: example.com\n\n", expectedMap)
}
// TestCopyError tests that we kill the process if there's an error copying
// its output. (for example, from the client having gone away)
func TestCopyError(t *testing.T) {
check(t)
if runtime.GOOS == "windows" {
t.Skipf("skipping test on %q", runtime.GOOS)
}
h := &Handler{
Path: "testdata/test.cgi",
Root: "/test.cgi",
}
ts := httptest.NewServer(h)
defer ts.Close()
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatal(err)
}
req, _ := http.NewRequest("GET", "http://example.com/test.cgi?bigresponse=1", nil)
err = req.Write(conn)
if err != nil {
t.Fatalf("Write: %v", err)
}
res, err := http.ReadResponse(bufio.NewReader(conn), req)
if err != nil {
t.Fatalf("ReadResponse: %v", err)
}
pidstr := res.Header.Get("X-CGI-Pid")
if pidstr == "" {
t.Fatalf("expected an X-CGI-Pid header in response")
}
pid, err := strconv.Atoi(pidstr)
if err != nil {
t.Fatalf("invalid X-CGI-Pid value")
}
var buf [5000]byte
n, err := io.ReadFull(res.Body, buf[:])
if err != nil {
t.Fatalf("ReadFull: %d bytes, %v", n, err)
}
childRunning := func() bool {
return isProcessRunning(pid)
}
if !childRunning() {
t.Fatalf("pre-conn.Close, expected child to be running")
}
conn.Close()
tries := 0
for tries < 25 && childRunning() {
time.Sleep(50 * time.Millisecond * time.Duration(tries))
tries++
}
if childRunning() {
t.Fatalf("post-conn.Close, expected child to be gone")
}
}
func TestDirUnix(t *testing.T) {
check(t)
if runtime.GOOS == "windows" {
t.Skipf("skipping test on %q", runtime.GOOS)
}
cwd, _ := os.Getwd()
h := &Handler{
Path: "testdata/test.cgi",
Root: "/test.cgi",
Dir: cwd,
}
expectedMap := map[string]string{
"cwd": cwd,
}
runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
cwd, _ = os.Getwd()
cwd = filepath.Join(cwd, "testdata")
h = &Handler{
Path: "testdata/test.cgi",
Root: "/test.cgi",
}
expectedMap = map[string]string{
"cwd": cwd,
}
runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
}
func findPerl(t *testing.T) string {
t.Helper()
perl, err := exec.LookPath("perl")
if err != nil {
t.Skip("Skipping test: perl not found.")
}
perl, _ = filepath.Abs(perl)
cmd := exec.Command(perl, "-e", "print 123")
cmd.Env = []string{"PATH=/garbage"}
out, err := cmd.Output()
if err != nil || string(out) != "123" {
t.Skipf("Skipping test: %s is not functional", perl)
}
return perl
}
func TestDirWindows(t *testing.T) {
if runtime.GOOS != "windows" {
t.Skip("Skipping windows specific test.")
}
cgifile, _ := filepath.Abs("testdata/test.cgi")
perl := findPerl(t)
cwd, _ := os.Getwd()
h := &Handler{
Path: perl,
Root: "/test.cgi",
Dir: cwd,
Args: []string{cgifile},
Env: []string{"SCRIPT_FILENAME=" + cgifile},
}
expectedMap := map[string]string{
"cwd": cwd,
}
runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
// If not specify Dir on windows, working directory should be
// base directory of perl.
cwd, _ = filepath.Split(perl)
if cwd != "" && cwd[len(cwd)-1] == filepath.Separator {
cwd = cwd[:len(cwd)-1]
}
h = &Handler{
Path: perl,
Root: "/test.cgi",
Args: []string{cgifile},
Env: []string{"SCRIPT_FILENAME=" + cgifile},
}
expectedMap = map[string]string{
"cwd": cwd,
}
runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
}
func TestEnvOverride(t *testing.T) {
check(t)
cgifile, _ := filepath.Abs("testdata/test.cgi")
perl := findPerl(t)
cwd, _ := os.Getwd()
h := &Handler{
Path: perl,
Root: "/test.cgi",
Dir: cwd,
Args: []string{cgifile},
Env: []string{
"SCRIPT_FILENAME=" + cgifile,
"REQUEST_URI=/foo/bar",
"PATH=/wibble"},
}
expectedMap := map[string]string{
"cwd": cwd,
"env-SCRIPT_FILENAME": cgifile,
"env-REQUEST_URI": "/foo/bar",
"env-PATH": "/wibble",
}
runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
}
func TestHandlerStderr(t *testing.T) {
check(t)
var stderr bytes.Buffer
h := &Handler{
Path: "testdata/test.cgi",
Root: "/test.cgi",
Stderr: &stderr,
}
rw := httptest.NewRecorder()
req := newRequest("GET /test.cgi?writestderr=1 HTTP/1.0\nHost: example.com\n\n")
h.ServeHTTP(rw, req)
if got, want := stderr.String(), "Hello, stderr!\n"; got != want {
t.Errorf("Stderr = %q; want %q", got, want)
}
}
func TestRemoveLeadingDuplicates(t *testing.T) {
tests := []struct {
env []string
want []string
}{
{
env: []string{"a=b", "b=c", "a=b2"},
want: []string{"b=c", "a=b2"},
},
{
env: []string{"a=b", "b=c", "d", "e=f"},
want: []string{"a=b", "b=c", "d", "e=f"},
},
}
for _, tt := range tests {
got := removeLeadingDuplicates(tt.env)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("removeLeadingDuplicates(%q) = %q; want %q", tt.env, got, tt.want)
}
}
}

View File

@ -0,0 +1,295 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Tests a Go CGI program running under a Go CGI host process.
// Further, the two programs are the same binary, just checking
// their environment to figure out what mode to run in.
package cgi
import (
"bytes"
"errors"
"fmt"
"github.com/lesismal/llib/std/net/http/httptest"
"internal/testenv"
"io"
"net/http"
"net/url"
"os"
"strings"
"testing"
"time"
)
// This test is a CGI host (testing host.go) that runs its own binary
// as a child process testing the other half of CGI (child.go).
func TestHostingOurselves(t *testing.T) {
testenv.MustHaveExec(t)
h := &Handler{
Path: os.Args[0],
Root: "/test.go",
Args: []string{"-test.run=TestBeChildCGIProcess"},
}
expectedMap := map[string]string{
"test": "Hello CGI-in-CGI",
"param-a": "b",
"param-foo": "bar",
"env-GATEWAY_INTERFACE": "CGI/1.1",
"env-HTTP_HOST": "example.com",
"env-PATH_INFO": "",
"env-QUERY_STRING": "foo=bar&a=b",
"env-REMOTE_ADDR": "1.2.3.4",
"env-REMOTE_HOST": "1.2.3.4",
"env-REMOTE_PORT": "1234",
"env-REQUEST_METHOD": "GET",
"env-REQUEST_URI": "/test.go?foo=bar&a=b",
"env-SCRIPT_FILENAME": os.Args[0],
"env-SCRIPT_NAME": "/test.go",
"env-SERVER_NAME": "example.com",
"env-SERVER_PORT": "80",
"env-SERVER_SOFTWARE": "go",
}
replay := runCgiTest(t, h, "GET /test.go?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
if expected, got := "text/plain; charset=utf-8", replay.Header().Get("Content-Type"); got != expected {
t.Errorf("got a Content-Type of %q; expected %q", got, expected)
}
if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected {
t.Errorf("got a X-Test-Header of %q; expected %q", got, expected)
}
}
type customWriterRecorder struct {
w io.Writer
*httptest.ResponseRecorder
}
func (r *customWriterRecorder) Write(p []byte) (n int, err error) {
return r.w.Write(p)
}
type limitWriter struct {
w io.Writer
n int
}
func (w *limitWriter) Write(p []byte) (n int, err error) {
if len(p) > w.n {
p = p[:w.n]
}
if len(p) > 0 {
n, err = w.w.Write(p)
w.n -= n
}
if w.n == 0 {
err = errors.New("past write limit")
}
return
}
// If there's an error copying the child's output to the parent, test
// that we kill the child.
func TestKillChildAfterCopyError(t *testing.T) {
testenv.MustHaveExec(t)
defer func() { testHookStartProcess = nil }()
proc := make(chan *os.Process, 1)
testHookStartProcess = func(p *os.Process) {
proc <- p
}
h := &Handler{
Path: os.Args[0],
Root: "/test.go",
Args: []string{"-test.run=TestBeChildCGIProcess"},
}
req, _ := http.NewRequest("GET", "http://example.com/test.cgi?write-forever=1", nil)
rec := httptest.NewRecorder()
var out bytes.Buffer
const writeLen = 50 << 10
rw := &customWriterRecorder{&limitWriter{&out, writeLen}, rec}
donec := make(chan bool, 1)
go func() {
h.ServeHTTP(rw, req)
donec <- true
}()
select {
case <-donec:
if out.Len() != writeLen || out.Bytes()[0] != 'a' {
t.Errorf("unexpected output: %q", out.Bytes())
}
case <-time.After(5 * time.Second):
t.Errorf("timeout. ServeHTTP hung and didn't kill the child process?")
select {
case p := <-proc:
p.Kill()
t.Logf("killed process")
default:
t.Logf("didn't kill process")
}
}
}
// Test that a child handler writing only headers works.
// golang.org/issue/7196
func TestChildOnlyHeaders(t *testing.T) {
testenv.MustHaveExec(t)
h := &Handler{
Path: os.Args[0],
Root: "/test.go",
Args: []string{"-test.run=TestBeChildCGIProcess"},
}
expectedMap := map[string]string{
"_body": "",
}
replay := runCgiTest(t, h, "GET /test.go?no-body=1 HTTP/1.0\nHost: example.com\n\n", expectedMap)
if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected {
t.Errorf("got a X-Test-Header of %q; expected %q", got, expected)
}
}
// Test that a child handler does not receive a nil Request Body.
// golang.org/issue/39190
func TestNilRequestBody(t *testing.T) {
testenv.MustHaveExec(t)
h := &Handler{
Path: os.Args[0],
Root: "/test.go",
Args: []string{"-test.run=TestBeChildCGIProcess"},
}
expectedMap := map[string]string{
"nil-request-body": "false",
}
_ = runCgiTest(t, h, "POST /test.go?nil-request-body=1 HTTP/1.0\nHost: example.com\n\n", expectedMap)
_ = runCgiTest(t, h, "POST /test.go?nil-request-body=1 HTTP/1.0\nHost: example.com\nContent-Length: 0\n\n", expectedMap)
}
func TestChildContentType(t *testing.T) {
testenv.MustHaveExec(t)
h := &Handler{
Path: os.Args[0],
Root: "/test.go",
Args: []string{"-test.run=TestBeChildCGIProcess"},
}
var tests = []struct {
name string
body string
wantCT string
}{
{
name: "no body",
wantCT: "text/plain; charset=utf-8",
},
{
name: "html",
body: "<html><head><title>test page</title></head><body>This is a body</body></html>",
wantCT: "text/html; charset=utf-8",
},
{
name: "text",
body: strings.Repeat("gopher", 86),
wantCT: "text/plain; charset=utf-8",
},
{
name: "jpg",
body: "\xFF\xD8\xFF" + strings.Repeat("B", 1024),
wantCT: "image/jpeg",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
expectedMap := map[string]string{"_body": tt.body}
req := fmt.Sprintf("GET /test.go?exact-body=%s HTTP/1.0\nHost: example.com\n\n", url.QueryEscape(tt.body))
replay := runCgiTest(t, h, req, expectedMap)
if got := replay.Header().Get("Content-Type"); got != tt.wantCT {
t.Errorf("got a Content-Type of %q; expected it to start with %q", got, tt.wantCT)
}
})
}
}
// golang.org/issue/7198
func Test500WithNoHeaders(t *testing.T) { want500Test(t, "/immediate-disconnect") }
func Test500WithNoContentType(t *testing.T) { want500Test(t, "/no-content-type") }
func Test500WithEmptyHeaders(t *testing.T) { want500Test(t, "/empty-headers") }
func want500Test(t *testing.T, path string) {
h := &Handler{
Path: os.Args[0],
Root: "/test.go",
Args: []string{"-test.run=TestBeChildCGIProcess"},
}
expectedMap := map[string]string{
"_body": "",
}
replay := runCgiTest(t, h, "GET "+path+" HTTP/1.0\nHost: example.com\n\n", expectedMap)
if replay.Code != 500 {
t.Errorf("Got code %d; want 500", replay.Code)
}
}
type neverEnding byte
func (b neverEnding) Read(p []byte) (n int, err error) {
for i := range p {
p[i] = byte(b)
}
return len(p), nil
}
// Note: not actually a test.
func TestBeChildCGIProcess(t *testing.T) {
if os.Getenv("REQUEST_METHOD") == "" {
// Not in a CGI environment; skipping test.
return
}
switch os.Getenv("REQUEST_URI") {
case "/immediate-disconnect":
os.Exit(0)
case "/no-content-type":
fmt.Printf("Content-Length: 6\n\nHello\n")
os.Exit(0)
case "/empty-headers":
fmt.Printf("\nHello")
os.Exit(0)
}
Serve(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.FormValue("nil-request-body") == "1" {
fmt.Fprintf(rw, "nil-request-body=%v\n", req.Body == nil)
return
}
rw.Header().Set("X-Test-Header", "X-Test-Value")
req.ParseForm()
if req.FormValue("no-body") == "1" {
return
}
if eb, ok := req.Form["exact-body"]; ok {
io.WriteString(rw, eb[0])
return
}
if req.FormValue("write-forever") == "1" {
io.Copy(rw, neverEnding('a'))
for {
time.Sleep(5 * time.Second) // hang forever, until killed
}
}
fmt.Fprintf(rw, "test=Hello CGI-in-CGI\n")
for k, vv := range req.Form {
for _, v := range vv {
fmt.Fprintf(rw, "param-%s=%s\n", k, v)
}
}
for _, kv := range os.Environ() {
fmt.Fprintf(rw, "env-%s\n", kv)
}
}))
os.Exit(0)
}

View File

@ -0,0 +1,17 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build plan9
package cgi
import (
"os"
"strconv"
)
func isProcessRunning(pid int) bool {
_, err := os.Stat("/proc/" + strconv.Itoa(pid))
return err == nil
}

View File

@ -0,0 +1,20 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !plan9
package cgi
import (
"os"
"syscall"
)
func isProcessRunning(pid int) bool {
p, err := os.FindProcess(pid)
if err != nil {
return false
}
return p.Signal(syscall.Signal(0)) == nil
}

Some files were not shown because too many files have changed in this diff Show More