diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index e99eaa3..80e54f2 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -277,6 +277,9 @@ func Blockchain_Start(params map[string]interface{}) (*Blockchain, error) { func (chain *Blockchain) IntegratorAddress() rpc.Address { return chain.integrator_address } +func (chain *Blockchain) SetIntegratorAddress(addr rpc.Address) { + chain.integrator_address = addr +} // this function is called to read blockchain state from DB // It is callable at any point in time @@ -615,6 +618,11 @@ func (chain *Blockchain) Add_Complete_Block(cbl *block.Complete_Block) (err erro block_logger.Error(fmt.Errorf("Double Registration TX"), "duplicate registration", "txid", cbl.Txs[i].GetHash()) return errormsg.ErrTXDoubleSpend, false } + + tx_hash := cbl.Txs[i].GetHash() + if chain.simulator == false && tx_hash[0] != 0 && tx_hash[1] != 0 { + return fmt.Errorf("Registration TX has not solved PoW"), false + } reg_map[string(cbl.Txs[i].MinerAddress[:])] = true } } @@ -1112,6 +1120,12 @@ func (chain *Blockchain) Add_TX_To_Pool(tx *transaction.Transaction) error { return fmt.Errorf("premine tx not mineable") } if tx.IsRegistration() { // registration tx will not go any forward + + tx_hash := tx.GetHash() + if chain.simulator == false && tx_hash[0] != 0 && tx_hash[1] != 0 { + return fmt.Errorf("TX doesn't solve Pow") + } + // ggive regpool a chance to register if ss, err := chain.Store.Balance_store.LoadSnapshot(0); err == nil { if balance_tree, err := ss.GetTree(config.BALANCE_TREE); err == nil { diff --git a/blockchain/difficulty.go b/blockchain/difficulty.go index 63fcb85..b48e608 100644 --- a/blockchain/difficulty.go +++ b/blockchain/difficulty.go @@ -191,17 +191,19 @@ type DiffProvider interface { func Get_Difficulty_At_Tips(source DiffProvider, tips []crypto.Hash) *big.Int { var MinimumDifficulty *big.Int + GenesisDifficulty := new(big.Int).SetUint64(1) if globals.IsMainnet() { MinimumDifficulty = new(big.Int).SetUint64(config.Settings.MAINNET_MINIMUM_DIFFICULTY) // this must be controllable parameter + GenesisDifficulty = new(big.Int).SetUint64(config.Settings.MAINNET_BOOTSTRAP_DIFFICULTY) } else { MinimumDifficulty = new(big.Int).SetUint64(config.Settings.TESTNET_MINIMUM_DIFFICULTY) // this must be controllable parameter + GenesisDifficulty = new(big.Int).SetUint64(config.Settings.TESTNET_BOOTSTRAP_DIFFICULTY) } - GenesisDifficulty := new(big.Int).SetUint64(1) if chain, ok := source.(*Blockchain); ok { if chain.simulator == true { - return GenesisDifficulty + return new(big.Int).SetUint64(1) } } @@ -225,7 +227,7 @@ func Get_Difficulty_At_Tips(source DiffProvider, tips []crypto.Hash) *big.Int { // until we have atleast 2 blocks, we cannot run the algo if height < 3 { - return MinimumDifficulty + return GenesisDifficulty } tip_difficulty := source.Load_Block_Difficulty(tips[0]) diff --git a/blockchain/hardcoded_contracts.go b/blockchain/hardcoded_contracts.go index e0d4e75..a15b5b7 100644 --- a/blockchain/hardcoded_contracts.go +++ b/blockchain/hardcoded_contracts.go @@ -58,12 +58,14 @@ func (chain *Blockchain) install_hardcoded_contracts(cache map[crypto.Hash]*grav if _, _, err = dvm.ParseSmartContract(source_nameservice); err != nil { logger.Error(err, "error Parsing hard coded sc") + panic(err) return } var name crypto.Hash name[31] = 1 if err = chain.install_hardcoded_sc(cache, ss, balance_tree, sc_tree, source_nameservice, name); err != nil { + panic(err) return } diff --git a/blockchain/storefs.go b/blockchain/storefs.go index 57ea3a5..3b437e8 100644 --- a/blockchain/storefs.go +++ b/blockchain/storefs.go @@ -30,6 +30,14 @@ type storefs struct { basedir string } +// TODO we need to enable big support or shift to object store at some point in time +func (s *storefs) getpath(h [32]byte) string { + // if you wish to use 3 level indirection, it will cause 16 million inodes to be used, but system will be faster + //return filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2])) + // currently we are settling on 65536 inodes + return filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1])) +} + // the filename stores the following information // hex block id (64 chars).block._ rewards (decimal) _ difficulty _ cumulative difficulty @@ -40,7 +48,7 @@ func (s *storefs) ReadBlock(h [32]byte) ([]byte, error) { return nil, fmt.Errorf("empty block") } - dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2])) + dir := s.getpath(h) files, err := os.ReadDir(dir) if err != nil { @@ -51,7 +59,7 @@ func (s *storefs) ReadBlock(h [32]byte) ([]byte, error) { for _, file := range files { if strings.HasPrefix(file.Name(), filename_start) { //fmt.Printf("Reading block with filename %s\n", file.Name()) - file := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2]), file.Name()) + file := filepath.Join(dir, file.Name()) return os.ReadFile(file) } } @@ -61,7 +69,7 @@ func (s *storefs) ReadBlock(h [32]byte) ([]byte, error) { // on windows, we see an odd behaviour where some files could not be deleted, since they may exist only in cache func (s *storefs) DeleteBlock(h [32]byte) error { - dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2])) + dir := s.getpath(h) files, err := os.ReadDir(dir) if err != nil { @@ -72,7 +80,7 @@ func (s *storefs) DeleteBlock(h [32]byte) error { var found bool for _, file := range files { if strings.HasPrefix(file.Name(), filename_start) { - file := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2]), file.Name()) + file := filepath.Join(dir, file.Name()) err = os.Remove(file) if err != nil { //return err @@ -88,7 +96,7 @@ func (s *storefs) DeleteBlock(h [32]byte) error { } func (s *storefs) ReadBlockDifficulty(h [32]byte) (*big.Int, error) { - dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2])) + dir := s.getpath(h) files, err := os.ReadDir(dir) if err != nil { @@ -122,7 +130,7 @@ func (chain *Blockchain) ReadBlockSnapshotVersion(h [32]byte) (uint64, error) { return chain.Store.Block_tx_store.ReadBlockSnapshotVersion(h) } func (s *storefs) ReadBlockSnapshotVersion(h [32]byte) (uint64, error) { - dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2])) + dir := s.getpath(h) files, err := os.ReadDir(dir) // this always returns the sorted list if err != nil { @@ -167,7 +175,7 @@ func (chain *Blockchain) ReadBlockHeight(h [32]byte) (uint64, error) { } func (s *storefs) ReadBlockHeight(h [32]byte) (uint64, error) { - dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2])) + dir := s.getpath(h) files, err := os.ReadDir(dir) if err != nil { @@ -194,7 +202,7 @@ func (s *storefs) ReadBlockHeight(h [32]byte) (uint64, error) { } func (s *storefs) WriteBlock(h [32]byte, data []byte, difficulty *big.Int, ss_version uint64, height uint64) (err error) { - dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2])) + dir := s.getpath(h) file := filepath.Join(dir, fmt.Sprintf("%x.block_%s_%d_%d", h[:], difficulty.String(), ss_version, height)) if err = os.MkdirAll(dir, 0700); err != nil { return err @@ -203,12 +211,13 @@ func (s *storefs) WriteBlock(h [32]byte, data []byte, difficulty *big.Int, ss_ve } func (s *storefs) ReadTX(h [32]byte) ([]byte, error) { - file := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2]), fmt.Sprintf("%x.tx", h[:])) + dir := s.getpath(h) + file := filepath.Join(dir, fmt.Sprintf("%x.tx", h[:])) return ioutil.ReadFile(file) } func (s *storefs) WriteTX(h [32]byte, data []byte) (err error) { - dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2])) + dir := s.getpath(h) file := filepath.Join(dir, fmt.Sprintf("%x.tx", h[:])) if err = os.MkdirAll(dir, 0700); err != nil { @@ -219,7 +228,7 @@ func (s *storefs) WriteTX(h [32]byte, data []byte) (err error) { } func (s *storefs) DeleteTX(h [32]byte) (err error) { - dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2])) + dir := s.getpath(h) file := filepath.Join(dir, fmt.Sprintf("%x.tx", h[:])) return os.Remove(file) } diff --git a/blockchain/transaction_execute.go b/blockchain/transaction_execute.go index abaf6c6..d7812ae 100644 --- a/blockchain/transaction_execute.go +++ b/blockchain/transaction_execute.go @@ -70,10 +70,15 @@ func (chain *Blockchain) process_miner_transaction(bl *block.Block, genesis bool if genesis == true { // process premine ,register genesis block, dev key balance := crypto.ConstructElGamal(acckey.G1(), crypto.ElGamal_BASE_G) // init zero balance - balance = balance.Plus(new(big.Int).SetUint64(tx.Value << 1)) // add premine to users balance homomorphically + balance = balance.Plus(new(big.Int).SetUint64(tx.Value)) // add premine to users balance homomorphically nb := crypto.NonceBalance{NonceHeight: 0, Balance: balance} balance_tree.Put(tx.MinerAddress[:], nb.Serialize()) // reserialize and store + if globals.IsMainnet() { + return + } + + // only testnet/simulator will have dummy accounts to test // we must process premine list and register and give them balance, premine_count := 0 scanner := bufio.NewScanner(strings.NewReader(premine.List)) @@ -119,12 +124,15 @@ func (chain *Blockchain) process_miner_transaction(bl *block.Block, genesis bool base_reward := CalcBlockReward(uint64(height)) full_reward := base_reward + fees - //full_reward is divided into equal parts for all miner blocks + miner address + integrator_reward := full_reward * 167 / 10000 + + //full_reward is divided into equal parts for all miner blocks + // integrator only gets 1.67 % of block reward // since perfect division is not possible, ( see money handling) // any left over change is delivered to main miner who integrated the full block - share := full_reward / uint64(len(bl.MiniBlocks)) // one block integrator, this is integer division - leftover := full_reward - (share * uint64(len(bl.MiniBlocks))) // only integrator will get this + share := (full_reward - integrator_reward) / uint64(len(bl.MiniBlocks)) // one block integrator, this is integer division + leftover := full_reward - integrator_reward - (share * uint64(len(bl.MiniBlocks))) // only integrator will get this { // giver integrator his reward balance_serialized, err := balance_tree.Get(tx.MinerAddress[:]) @@ -132,8 +140,8 @@ func (chain *Blockchain) process_miner_transaction(bl *block.Block, genesis bool panic(err) } nb := new(crypto.NonceBalance).Deserialize(balance_serialized) - nb.Balance = nb.Balance.Plus(new(big.Int).SetUint64(share + leftover)) // add miners reward to miners balance homomorphically - balance_tree.Put(tx.MinerAddress[:], nb.Serialize()) // reserialize and store + nb.Balance = nb.Balance.Plus(new(big.Int).SetUint64(integrator_reward + leftover)) // add miners reward to miners balance homomorphically + balance_tree.Put(tx.MinerAddress[:], nb.Serialize()) // reserialize and store } // all the other miniblocks will get their share @@ -230,7 +238,6 @@ func (chain *Blockchain) process_transaction(changed map[crypto.Hash]*graviton.T nb.NonceHeight = height } tree.Put(key_compressed, nb.Serialize()) // reserialize and store - } } diff --git a/cmd/dero-wallet-cli/easymenu_post_open.go b/cmd/dero-wallet-cli/easymenu_post_open.go index 89d672e..a3519a1 100644 --- a/cmd/dero-wallet-cli/easymenu_post_open.go +++ b/cmd/dero-wallet-cli/easymenu_post_open.go @@ -17,15 +17,15 @@ package main import "io" - +import "os" import "time" import "fmt" +import "errors" -//import "io/ioutil" import "strings" -//import "path/filepath" -//import "encoding/hex" +import "path/filepath" +import "encoding/json" import "github.com/chzyer/readline" @@ -64,6 +64,7 @@ func display_easymenu_post_open_command(l *readline.Instance) { io.WriteString(w, "\t\033[1m12\033[0m\tTransfer all balance (send DERO) To Another Wallet\n") io.WriteString(w, "\t\033[1m13\033[0m\tShow transaction history\n") io.WriteString(w, "\t\033[1m14\033[0m\tRescan transaction history\n") + io.WriteString(w, "\t\033[1m15\033[0m\tExport all transaction history in json format\n") } io.WriteString(w, "\n\t\033[1m9\033[0m\tExit menu and start prompt\n") @@ -127,10 +128,21 @@ func handle_easymenu_post_open_command(l *readline.Instance, line string) (proce fmt.Fprintf(l.Stderr(), "Wallet address : "+color_green+"%s"+color_white+" is going to be registered.This is a pre-condition for using the online chain.It will take few seconds to register.\n", wallet.GetAddress()) - reg_tx := wallet.GetRegistrationTX() - // at this point we must send the registration transaction fmt.Fprintf(l.Stderr(), "Wallet address : "+color_green+"%s"+color_white+" is going to be registered.Pls wait till the account is registered.\n", wallet.GetAddress()) + + fmt.Fprintf(l.Stderr(), "This will take a couple of minutes.Please wait....\n") + + var reg_tx *transaction.Transaction + for { + + reg_tx = wallet.GetRegistrationTX() + hash := reg_tx.GetHash() + + if hash[0] == 0 && hash[1] == 0 { + break + } + } fmt.Fprintf(l.Stderr(), "Registration TXID %s\n", reg_tx.GetHash()) err := wallet.SendTransaction(reg_tx) if err != nil { @@ -243,6 +255,7 @@ func handle_easymenu_post_open_command(l *readline.Instance, line string) (proce if a.Arguments.Has(rpc.RPC_COMMENT, rpc.DataString) { // but only it is present logger.Info("Integrated Message", "comment", a.Arguments.Value(rpc.RPC_COMMENT, rpc.DataString)) + arguments = append(arguments, rpc.Argument{rpc.RPC_COMMENT, rpc.DataString, a.Arguments.Value(rpc.RPC_COMMENT, rpc.DataString)}) } } @@ -291,10 +304,12 @@ func handle_easymenu_post_open_command(l *readline.Instance, line string) (proce amount_to_transfer = a.Arguments.Value(rpc.RPC_VALUE_TRANSFER, rpc.DataUint64).(uint64) } else { - amount_str := read_line_with_prompt(l, fmt.Sprintf("Enter amount to transfer in DERO (max TODO): ")) + mbal, _ := wallet.Get_Balance() + amount_str := read_line_with_prompt(l, fmt.Sprintf("Enter amount to transfer in DERO (current balance %s): ", globals.FormatMoney(mbal))) if amount_str == "" { - amount_str = ".00001" + logger.Error(nil, "Cannot transfer 0") + break // invalid amount provided, bail out } amount_to_transfer, err = globals.ParseAmount(amount_str) if err != nil { @@ -315,8 +330,15 @@ func handle_easymenu_post_open_command(l *readline.Instance, line string) (proce // if no arguments, use space by embedding a small comment if len(arguments) == 0 { // allow user to enter Comment + if v, err := ReadUint64(l, "Please enter payment id (or destination port number)", uint64(0)); err == nil { + arguments = append(arguments, rpc.Argument{Name: rpc.RPC_DESTINATION_PORT, DataType: rpc.DataUint64, Value: v}) + } else { + logger.Error(err, fmt.Sprintf("%s could not be parsed (type %s),", "Number", rpc.DataUint64)) + return + } + if v, err := ReadString(l, "Comment", ""); err == nil { - arguments = append(arguments, rpc.Argument{Name: "Comment", DataType: rpc.DataString, Value: v}) + arguments = append(arguments, rpc.Argument{Name: rpc.RPC_COMMENT, DataType: rpc.DataString, Value: v}) } else { logger.Error(fmt.Errorf("%s could not be parsed (type %s),", "Comment", rpc.DataString), "") return @@ -429,6 +451,34 @@ func handle_easymenu_post_open_command(l *readline.Instance, line string) (proce case "14": logger.Info("Rescanning wallet history") rescan_bc(wallet) + case "15": + if !ValidateCurrentPassword(l, wallet) { + logger.Error(fmt.Errorf("Invalid password"), "") + break + } + + if _, err := os.Stat("./history"); errors.Is(err, os.ErrNotExist) { + if err := os.Mkdir("./history", 0700); err != nil { + logger.Error(err, "Error creating directory") + break + } + } + + var zeroscid crypto.Hash + account := wallet.GetAccount() + for k, v := range account.EntriesNative { + filename := filepath.Join("./history", k.String()+".json") + if k == zeroscid { + filename = filepath.Join("./history", "dero.json") + } + if data, err := json.Marshal(v); err != nil { + logger.Error(err, "Error exporting data") + } else if err = os.WriteFile(filename, data, 0600); err != nil { + logger.Error(err, "Error exporting data") + } else { + logger.Info("successfully exported history", "file", filename) + } + } default: processed = false // just loop diff --git a/cmd/dero-wallet-cli/main.go b/cmd/dero-wallet-cli/main.go index 3ff71d5..18524a7 100644 --- a/cmd/dero-wallet-cli/main.go +++ b/cmd/dero-wallet-cli/main.go @@ -80,6 +80,7 @@ Usage: --rpc-server Run rpc server, so wallet is accessible using api --rpc-bind=<127.0.0.1:20209> Wallet binds on this ip address and port --rpc-login= RPC server will grant access based on these credentials + --allow-rpc-password-change RPC server will change password if you send "Pass" header with new password ` var menu_mode bool = true // default display menu mode //var account_valid bool = false // if an account has been opened, do not allow to create new account in this session @@ -119,7 +120,7 @@ func main() { } // init the lookup table one, anyone importing walletapi should init this first, this will take around 1 sec on any recent system - walletapi.Initialize_LookupTable(1, 1<<17) + walletapi.Initialize_LookupTable(1, 1<<21) // We need to initialize readline first, so it changes stderr to ansi processor on windows l, err := readline.NewEx(&readline.Config{ diff --git a/cmd/dero-wallet-cli/prompt.go b/cmd/dero-wallet-cli/prompt.go index 6a22fd8..80de635 100644 --- a/cmd/dero-wallet-cli/prompt.go +++ b/cmd/dero-wallet-cli/prompt.go @@ -25,6 +25,7 @@ import "time" //import "io/ioutil" //import "path/filepath" import "strings" +import "unicode" import "strconv" import "encoding/hex" @@ -39,6 +40,15 @@ import "github.com/deroproject/derohe/cryptography/crypto" var account walletapi.Account +func isASCII(s string) bool { + for _, c := range s { + if c > unicode.MaxASCII { + return false + } + } + return true +} + // handle all commands while in prompt mode func handle_prompt_command(l *readline.Instance, line string) { @@ -126,6 +136,54 @@ func handle_prompt_command(l *readline.Instance, line string) { case "spendkey": // give user his spend key display_spend_key(l, wallet) + case "filesign": // sign a file contents + if !ValidateCurrentPassword(l, wallet) { + logger.Error(err, "Invalid password") + PressAnyKey(l, wallet) + break + } + + filename, err := ReadString(l, "Enter file to sign", "") + if err != nil { + logger.Error(err, "Cannot read input file name") + } + + outputfile := filename + ".sign" + + if filedata, err := os.ReadFile(filename); err != nil { + logger.Error(err, "Cannot read input file") + } else if err := os.WriteFile(outputfile, wallet.SignData(filedata), 0600); err != nil { + logger.Error(err, "Cannot write output file", "file", outputfile) + } else { + logger.Info("successfully signed file. please check", "file", outputfile) + } + + case "fileverify": // verify a file contents + filename, err := ReadString(l, "Enter file to verify signature", "") + if err != nil { + logger.Error(err, "Cannot read input file name") + } + + outputfile := strings.TrimSuffix(filename, ".sign") + + if filedata, err := os.ReadFile(filename); err != nil { + logger.Error(err, "Cannot read input file") + } else if signer, message, err := wallet.CheckSignature(filedata); err != nil { + logger.Error(err, "Signature verify failed", "file", filename) + } else { + logger.Info("Signed by", "address", signer.String()) + + if isASCII(string(message)) { // do not spew garbage + logger.Info("", "message", string(message)) + } + + if os.WriteFile(outputfile, message, 0600); err != nil { + logger.Error(err, "Cannot write output file", "file", outputfile) + } + logger.Info("successfully wrote message to file. please check", "file", outputfile) + + } + case "password": // change wallet password if ConfirmYesNoDefaultNo(l, "Change wallet password (y/N)") && ValidateCurrentPassword(l, wallet) { @@ -525,6 +583,10 @@ func ReadUint64(l *readline.Instance, cprompt string, default_value uint64) (a u error_message := "" color := color_green + if len(line) == 0 { + line = []rune(fmt.Sprintf("%d", default_value)) + } + if len(line) >= 1 { _, err := strconv.ParseUint(string(line), 0, 64) if err != nil { @@ -548,6 +610,9 @@ func ReadUint64(l *readline.Instance, cprompt string, default_value uint64) (a u if err != nil { return } + if len(line) == 0 { + line = []byte(fmt.Sprintf("%d", default_value)) + } a, err = strconv.ParseUint(string(line), 0, 64) l.SetPrompt(cprompt) l.Refresh() @@ -800,6 +865,8 @@ var completer = readline.NewPrefixCompleter( readline.PcItem("balance"), readline.PcItem("integrated_address"), readline.PcItem("get_tx_key"), + readline.PcItem("filesign"), + readline.PcItem("fileverify"), readline.PcItem("menu"), readline.PcItem("rescan_bc"), readline.PcItem("payment_id"), @@ -817,7 +884,6 @@ var completer = readline.NewPrefixCompleter( readline.PcItem("version"), readline.PcItem("transfer"), readline.PcItem("transfer_all"), - readline.PcItem("walletviewkey"), readline.PcItem("bye"), readline.PcItem("exit"), readline.PcItem("quit"), diff --git a/cmd/derod/main.go b/cmd/derod/main.go index 9d79d0f..c9adbb1 100644 --- a/cmd/derod/main.go +++ b/cmd/derod/main.go @@ -281,7 +281,9 @@ func main() { } testnet_string := "" - if !globals.IsMainnet() { + if globals.IsMainnet() { + testnet_string = "\033[31m MAINNET" + } else { testnet_string = "\033[31m TESTNET" } @@ -406,6 +408,19 @@ restart_loop: memoryfile.Close() */ + case command == "setintegratoraddress": + if len(line_parts) != 2 { + logger.Error(fmt.Errorf("This function requires 1 parameters, dero address"), "") + continue + } + if addr, err := rpc.NewAddress(line_parts[1]); err != nil { + logger.Error(err, "invalid address") + continue + } else { + chain.SetIntegratorAddress(*addr) + logger.Info("will use", "integrator_address", chain.IntegratorAddress().String()) + } + case command == "print_bc": logger.Info("printing block chain") @@ -1023,6 +1038,8 @@ func usage(w io.Writer) { io.WriteString(w, "\t\033[1mregpool_print\033[0m\t\tprint regpool contents\n") io.WriteString(w, "\t\033[1mregpool_delete_tx\033[0m\t\tDelete specific tx from regpool\n") io.WriteString(w, "\t\033[1mregpool_flush\033[0m\t\tFlush mempool\n") + io.WriteString(w, "\t\033[1msetintegratoraddress\033[0m\t\tChange current integrated address\n") + io.WriteString(w, "\t\033[1mversion\033[0m\t\tShow version\n") io.WriteString(w, "\t\033[1mexit\033[0m\t\tQuit the daemon\n") io.WriteString(w, "\t\033[1mquit\033[0m\t\tQuit the daemon\n") @@ -1046,6 +1063,7 @@ var completer = readline.NewPrefixCompleter( readline.PcItem("block_export"), readline.PcItem("block_import"), // readline.PcItem("print_tx"), + readline.PcItem("setintegratoraddress"), readline.PcItem("status"), readline.PcItem("sync_info"), readline.PcItem("version"), diff --git a/cmd/derod/rpc/rpc_dero_getencryptedbalance.go b/cmd/derod/rpc/rpc_dero_getencryptedbalance.go index a2a7fa9..bf125ae 100644 --- a/cmd/derod/rpc/rpc_dero_getencryptedbalance.go +++ b/cmd/derod/rpc/rpc_dero_getencryptedbalance.go @@ -35,7 +35,6 @@ func GetEncryptedBalance(ctx context.Context, p rpc.GetEncryptedBalance_Params) defer func() { // safety so if anything wrong happens, we return error if r := recover(); r != nil { err = fmt.Errorf("panic occured. stack trace %s", debug.Stack()) - fmt.Printf("panic stack trace %s params %+v\n", debug.Stack(), p) } }() diff --git a/cmd/simulator/blockchain_sim_test.go b/cmd/simulator/blockchain_sim_test.go index 9f7b8d3..d39bb23 100644 --- a/cmd/simulator/blockchain_sim_test.go +++ b/cmd/simulator/blockchain_sim_test.go @@ -77,7 +77,7 @@ func simulator_chain_start() (*blockchain.Blockchain, *derodrpc.RPCServer, map[s OptionsFirst: true, } - globals.Arguments, err = parser.ParseArgs(command_line_test, []string{"--data-dir", tmpdirectory, "--rpc-bind", rpcport_test}, config.Version.String()) + globals.Arguments, err = parser.ParseArgs(command_line_test, []string{"--data-dir", tmpdirectory, "--rpc-bind", rpcport_test, "--testnet"}, config.Version.String()) if err != nil { //log.Fatalf("Error while parsing options err: %s\n", err) return nil, nil, nil diff --git a/config/config.go b/config/config.go index 211f7d5..16c0ef8 100644 --- a/config/config.go +++ b/config/config.go @@ -43,7 +43,7 @@ const SC_META = "M" // keeps all SCs balance, their state, their OWNER, the const MAX_STORAGE_GAS_ATOMIC_UNITS = 20000 // Minimum FEE calculation constants are here -const FEE_PER_KB = uint64(100) // .00100 dero per kb +const FEE_PER_KB = uint64(20) // .00020 dero per kb // we can easily improve TPS by changing few parameters in this file // the resources compute/network may not be easy for the developing countries @@ -56,8 +56,8 @@ const MIN_RINGSIZE = 2 // >= 2 , ringsize will be accepted const MAX_RINGSIZE = 128 // <= 128, ringsize will be accepted type SettingsStruct struct { - MAINNET_BOOTSTRAP_DIFFICULTY uint64 `env:"MAINNET_BOOTSTRAP_DIFFICULTY" envDefault:"80000000"` - MAINNET_MINIMUM_DIFFICULTY uint64 `env:"MAINNET_MINIMUM_DIFFICULTY" envDefault:"80000000"` + MAINNET_BOOTSTRAP_DIFFICULTY uint64 `env:"MAINNET_BOOTSTRAP_DIFFICULTY" envDefault:"10000000"` // mainnet bootstrap is 10 MH/s + MAINNET_MINIMUM_DIFFICULTY uint64 `env:"MAINNET_MINIMUM_DIFFICULTY" envDefault:"100000"` // mainnet minimum is 100 KH/s TESTNET_BOOTSTRAP_DIFFICULTY uint64 `env:"TESTNET_BOOTSTRAP_DIFFICULTY" envDefault:"10000"` TESTNET_MINIMUM_DIFFICULTY uint64 `env:"TESTNET_MINIMUM_DIFFICULTY" envDefault:"10000"` @@ -89,7 +89,7 @@ type CHAIN_CONFIG struct { } var Mainnet = CHAIN_CONFIG{Name: "mainnet", - Network_ID: uuid.FromBytesOrNil([]byte{0x59, 0xd7, 0xf7, 0xe9, 0xdd, 0x48, 0xd5, 0xfd, 0x13, 0x0a, 0xf6, 0xe0, 0x9a, 0x44, 0x45, 0x0}), + Network_ID: uuid.FromBytesOrNil([]byte{0x59, 0xd7, 0xf7, 0xe9, 0xdd, 0x48, 0xd5, 0xfd, 0x13, 0x0a, 0xf6, 0xe0, 0x9a, 0x44, 0x40, 0x0}), GETWORK_Default_Port: 10100, P2P_Default_Port: 10101, RPC_Default_Port: 10102, @@ -101,13 +101,13 @@ var Mainnet = CHAIN_CONFIG{Name: "mainnet", "00" + // Source is DERO network "00" + // Dest is DERO network "00" + // PREMINE_FLAG - "8fff7f" + // PREMINE_VALUE + "80a8b9ceb024" + // PREMINE_VALUE "1f9bcc1208dee302769931ad378a4c0c4b2c21b0cfb3e752607e12d2b6fa642500", // miners public key } var Testnet = CHAIN_CONFIG{Name: "testnet", // testnet will always have last 3 bytes 0 - Network_ID: uuid.FromBytesOrNil([]byte{0x59, 0xd7, 0xf7, 0xe9, 0xdd, 0x48, 0xd5, 0xfd, 0x13, 0x0a, 0xf6, 0xe0, 0x80, 0x00, 0x00, 0x00}), + Network_ID: uuid.FromBytesOrNil([]byte{0x59, 0xd7, 0xf7, 0xe9, 0xdd, 0x48, 0xd5, 0xfd, 0x13, 0x0a, 0xf6, 0xe0, 0x83, 0x00, 0x00, 0x00}), GETWORK_Default_Port: 10100, P2P_Default_Port: 40401, RPC_Default_Port: 40402, @@ -120,7 +120,7 @@ var Testnet = CHAIN_CONFIG{Name: "testnet", // testnet will always have last 3 b "00" + // Source is DERO network "00" + // Dest is DERO network "00" + // PREMINE_FLAG - "8fff7f" + // PREMINE_VALUE + "80a8b9ceb024" + // PREMINE_VALUE "1f9bcc1208dee302769931ad378a4c0c4b2c21b0cfb3e752607e12d2b6fa642500", // miners public key } diff --git a/config/version.go b/config/version.go index 1010cf6..5b5d12e 100644 --- a/config/version.go +++ b/config/version.go @@ -20,4 +20,4 @@ import "github.com/blang/semver/v4" // right now it has to be manually changed // do we need to include git commitsha?? -var Version = semver.MustParse("3.4.106-0.DEROHE.STARGATE+18012022") +var Version = semver.MustParse("3.4.109-0.DEROHE.STARGATE+18012022") diff --git a/dvm/dvm.go b/dvm/dvm.go index 318b82f..22d92ed 100644 --- a/dvm/dvm.go +++ b/dvm/dvm.go @@ -576,7 +576,7 @@ func (i *DVM_Interpreter) interpret_SmartContract() (err error) { } if i.State.Trace { - fmt.Printf("interpreting line %+v err:'%s'\n", line, err) + fmt.Printf("interpreting line %+v err:'%v'\n", line, err) } if err != nil { err = fmt.Errorf("err while interpreting line %+v err %s\n", line, err) diff --git a/globals/globals.go b/globals/globals.go index 5cef491..28ea5c2 100644 --- a/globals/globals.go +++ b/globals/globals.go @@ -187,8 +187,6 @@ func InitializeLog(console, logfile io.Writer) { func Initialize() { var err error - Arguments["--testnet"] = true // force testnet every where - InitNetwork() // choose socks based proxy if user requested so diff --git a/vendor/github.com/creachadair/jrpc2/.github/workflows/go-presubmit.yml b/vendor/github.com/creachadair/jrpc2/.github/workflows/go-presubmit.yml index 4d9d2c7..4142da2 100644 --- a/vendor/github.com/creachadair/jrpc2/.github/workflows/go-presubmit.yml +++ b/vendor/github.com/creachadair/jrpc2/.github/workflows/go-presubmit.yml @@ -18,7 +18,7 @@ jobs: os: ['ubuntu-latest'] steps: - name: Install Go ${{ matrix.go-version }} - uses: actions/setup-go@v1 + uses: actions/setup-go@v2 with: go-version: ${{ matrix.go-version }} - uses: actions/checkout@v2 diff --git a/vendor/github.com/creachadair/jrpc2/README.md b/vendor/github.com/creachadair/jrpc2/README.md index 757b6c4..4e05523 100644 --- a/vendor/github.com/creachadair/jrpc2/README.md +++ b/vendor/github.com/creachadair/jrpc2/README.md @@ -1,10 +1,9 @@ # jrpc2 [![GoDoc](https://img.shields.io/static/v1?label=godoc&message=reference&color=yellow)](https://pkg.go.dev/github.com/creachadair/jrpc2) -[![Go Report Card](https://goreportcard.com/badge/github.com/creachadair/jrpc2)](https://goreportcard.com/report/github.com/creachadair/jrpc2) -This repository provides Go package that implements a [JSON-RPC 2.0][spec] client and server. -There is also a working [example in the Go playground](https://play.golang.org/p/MSClCk55UzF). +This repository provides a Go module that implements a [JSON-RPC 2.0][spec] client and server. +There is also a working [example in the Go playground](https://go.dev/play/p/fY-Pnvf03Hr). ## Packages diff --git a/vendor/github.com/creachadair/jrpc2/base.go b/vendor/github.com/creachadair/jrpc2/base.go index c595a9c..d99d280 100644 --- a/vendor/github.com/creachadair/jrpc2/base.go +++ b/vendor/github.com/creachadair/jrpc2/base.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package jrpc2 import ( @@ -17,9 +19,12 @@ type Assigner interface { // The implementation can obtain the complete request from ctx using the // jrpc2.InboundRequest function. Assign(ctx context.Context, method string) Handler +} - // Names returns a slice of all known method names for the assigner. The - // resulting slice is ordered lexicographically and contains no duplicates. +// Namer is an optional interface that an Assigner may implement to expose the +// names of its methods to the ServerInfo method. +type Namer interface { + // Names returns all known method names in lexicographic order. Names() []string } @@ -91,11 +96,14 @@ func (r *Request) UnmarshalParams(v interface{}) error { dec := json.NewDecoder(bytes.NewReader(r.params)) dec.DisallowUnknownFields() if err := dec.Decode(v); err != nil { - return Errorf(code.InvalidParams, "invalid parameters: %v", err.Error()) + return errInvalidParams.WithData(err.Error()) } return nil } - return json.Unmarshal(r.params, v) + if err := json.Unmarshal(r.params, v); err != nil { + return errInvalidParams.WithData(err.Error()) + } + return nil } // ParamString returns the encoded request parameters of r as a string. diff --git a/vendor/github.com/creachadair/jrpc2/bench_test.go b/vendor/github.com/creachadair/jrpc2/bench_test.go index 8be5c91..c1e4c1b 100644 --- a/vendor/github.com/creachadair/jrpc2/bench_test.go +++ b/vendor/github.com/creachadair/jrpc2/bench_test.go @@ -1,13 +1,18 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package jrpc2_test import ( "context" + "strconv" + "sync" "testing" "github.com/creachadair/jrpc2" "github.com/creachadair/jrpc2/handler" "github.com/creachadair/jrpc2/jctx" "github.com/creachadair/jrpc2/server" + "github.com/fortytw2/leaktest" ) func BenchmarkRoundTrip(b *testing.B) { @@ -69,6 +74,75 @@ func BenchmarkRoundTrip(b *testing.B) { } } +func BenchmarkLoad(b *testing.B) { + defer leaktest.Check(b)() + + // The load testing service has a no-op method to exercise server overhead. + loc := server.NewLocal(handler.Map{ + "void": handler.Func(func(context.Context, *jrpc2.Request) (interface{}, error) { + return nil, nil + }), + }, nil) + defer loc.Close() + + // Exercise concurrent calls. + ctx := context.Background() + b.Run("Call", func(b *testing.B) { + var wg sync.WaitGroup + for i := 0; i < b.N; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, err := loc.Client.Call(ctx, "void", nil) + if err != nil { + b.Errorf("Call failed: %v", err) + } + }() + } + wg.Wait() + }) + + // Exercise concurrent notifications. + b.Run("Notify", func(b *testing.B) { + var wg sync.WaitGroup + for i := 0; i < b.N; i++ { + wg.Add(1) + go func() { + defer wg.Done() + err := loc.Client.Notify(ctx, "void", nil) + if err != nil { + b.Errorf("Notify failed: %v", err) + } + }() + } + wg.Wait() + }) + + // Exercise concurrent batches of various sizes. + for _, bs := range []int{1, 2, 4, 8, 12, 16, 20, 50} { + batch := make([]jrpc2.Spec, bs) + for j := 0; j < len(batch); j++ { + batch[j].Method = "void" + } + + name := "Batch-" + strconv.Itoa(bs) + b.Run(name, func(b *testing.B) { + var wg sync.WaitGroup + for i := 0; i < b.N; i += bs { + wg.Add(1) + go func() { + defer wg.Done() + _, err := loc.Client.Batch(ctx, batch) + if err != nil { + b.Errorf("Batch failed: %v", err) + } + }() + } + wg.Wait() + }) + } +} + func BenchmarkParseRequests(b *testing.B) { reqs := []struct { desc, input string diff --git a/vendor/github.com/creachadair/jrpc2/channel/bench_test.go b/vendor/github.com/creachadair/jrpc2/channel/bench_test.go index 5233d9b..d76004f 100644 --- a/vendor/github.com/creachadair/jrpc2/channel/bench_test.go +++ b/vendor/github.com/creachadair/jrpc2/channel/bench_test.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package channel_test import ( diff --git a/vendor/github.com/creachadair/jrpc2/channel/channel.go b/vendor/github.com/creachadair/jrpc2/channel/channel.go index 6bae36f..af535b1 100644 --- a/vendor/github.com/creachadair/jrpc2/channel/channel.go +++ b/vendor/github.com/creachadair/jrpc2/channel/channel.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + // Package channel defines a basic communications channel. // // A Channel encodes/transmits and decodes/receives data records over an @@ -61,10 +63,15 @@ type Channel interface { Close() error } -// IsErrClosing reports whether err is the internal error returned by a read -// from a pipe or socket that is closed. This is false for err == nil. +// ErrClosed is a sentinel error that can be returned to indicate an operation +// failed because the channel was closed. +var ErrClosed = errors.New("channel is closed") + +// IsErrClosing reports whether err is a channel-closed error. This is true +// for the internal error returned by a read from a pipe or socket that is +// closed, or an error that wraps ErrClosed. It is false if err == nil. func IsErrClosing(err error) bool { - return err != nil && errors.Is(err, net.ErrClosed) + return err != nil && (errors.Is(err, ErrClosed) || errors.Is(err, net.ErrClosed)) } // A Framing converts a reader and a writer into a Channel with a particular @@ -77,14 +84,12 @@ type direct struct { } func (d direct) Send(msg []byte) (err error) { - cp := make([]byte, len(msg)) - copy(cp, msg) defer func() { if p := recover(); p != nil { err = errors.New("send on closed channel") } }() - d.send <- cp + d.send <- msg return nil } @@ -101,6 +106,10 @@ func (d direct) Close() error { close(d.send); return nil } // Direct returns a pair of synchronous connected channels that pass message // buffers directly in memory without framing or encoding. Sends to client will // be received by server, and vice versa. +// +// Note that buffers passed to direct channels are not copied. If the caller +// needs to use the buffer after sending it on a direct channel, the caller is +// responsible for making a copy. func Direct() (client, server Channel) { c2s := make(chan []byte) s2c := make(chan []byte) diff --git a/vendor/github.com/creachadair/jrpc2/channel/channel_test.go b/vendor/github.com/creachadair/jrpc2/channel/channel_test.go index fed0af3..d540a7a 100644 --- a/vendor/github.com/creachadair/jrpc2/channel/channel_test.go +++ b/vendor/github.com/creachadair/jrpc2/channel/channel_test.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package channel import ( @@ -6,6 +8,8 @@ import ( "strings" "sync" "testing" + + "github.com/fortytw2/leaktest" ) // newPipe creates a pair of connected in-memory channels using the specified @@ -20,6 +24,8 @@ func newPipe(framing Framing) (client, server Channel) { } func testSendRecv(t *testing.T, s, r Channel, msg string) { + defer leaktest.Check(t)() + var wg sync.WaitGroup var sendErr, recvErr error var data []byte diff --git a/vendor/github.com/creachadair/jrpc2/channel/hdr.go b/vendor/github.com/creachadair/jrpc2/channel/hdr.go index 97b3579..5b1328b 100644 --- a/vendor/github.com/creachadair/jrpc2/channel/hdr.go +++ b/vendor/github.com/creachadair/jrpc2/channel/hdr.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package channel import ( diff --git a/vendor/github.com/creachadair/jrpc2/channel/json.go b/vendor/github.com/creachadair/jrpc2/channel/json.go index 13639f0..9257088 100644 --- a/vendor/github.com/creachadair/jrpc2/channel/json.go +++ b/vendor/github.com/creachadair/jrpc2/channel/json.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package channel import ( diff --git a/vendor/github.com/creachadair/jrpc2/channel/split.go b/vendor/github.com/creachadair/jrpc2/channel/split.go index 8217cd4..6f4a874 100644 --- a/vendor/github.com/creachadair/jrpc2/channel/split.go +++ b/vendor/github.com/creachadair/jrpc2/channel/split.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package channel import ( diff --git a/vendor/github.com/creachadair/jrpc2/client.go b/vendor/github.com/creachadair/jrpc2/client.go index fd1fe74..339c9ff 100644 --- a/vendor/github.com/creachadair/jrpc2/client.go +++ b/vendor/github.com/creachadair/jrpc2/client.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package jrpc2 import ( @@ -21,10 +23,11 @@ type Client struct { log func(string, ...interface{}) // write debug logs here enctx encoder snote func(*jmessage) - scall func(*jmessage) []byte + scall func(context.Context, *jmessage) []byte chook func(*Client, *Response) - allow1 bool // tolerate v1 replies with no version marker + cbctx context.Context // terminates when the client is closed + cbcancel func() // cancels cbctx mu sync.Mutex // protects the fields below ch channel.Channel // channel to the server @@ -35,14 +38,17 @@ type Client struct { // NewClient returns a new client that communicates with the server via ch. func NewClient(ch channel.Channel, opts *ClientOptions) *Client { + cbctx, cbcancel := context.WithCancel(context.Background()) c := &Client{ - done: new(sync.WaitGroup), - log: opts.logFunc(), - allow1: opts.allowV1(), - enctx: opts.encodeContext(), - snote: opts.handleNotification(), - scall: opts.handleCallback(), - chook: opts.handleCancel(), + done: new(sync.WaitGroup), + log: opts.logFunc(), + enctx: opts.encodeContext(), + snote: opts.handleNotification(), + scall: opts.handleCallback(), + chook: opts.handleCancel(), + + cbctx: cbctx, + cbcancel: cbcancel, // Lock-protected fields ch: ch, @@ -99,7 +105,7 @@ func (c *Client) accept(ch receiver) error { } // handleRequest handles a callback or notification from the server. The -// caller must hold c.mu, and this blocks until the handler completes. +// caller must hold c.mu. This function does not block for the handler. // Precondition: msg is a request or notification, not a response or error. func (c *Client) handleRequest(msg *jmessage) { if msg.isNotification() { @@ -113,10 +119,22 @@ func (c *Client) handleRequest(msg *jmessage) { } else if c.ch == nil { c.log("Client channel is closed; discarding callback: %v", msg) } else { - bits := c.scall(msg) - if err := c.ch.Send(bits); err != nil { - c.log("Sending reply for callback %v failed: %v", msg, err) - } + // Run the callback handler in its own goroutine. The context will be + // cancelled automatically when the client is closed. + ctx := context.WithValue(c.cbctx, clientKey{}, c) + c.done.Add(1) + go func() { + defer c.done.Done() + bits := c.scall(ctx, msg) + + c.mu.Lock() + defer c.mu.Unlock() + if c.err != nil { + c.log("Discarding callback response: %v", c.err) + } else if err := c.ch.Send(bits); err != nil { + c.log("Sending reply for callback %v failed: %v", msg, err) + } + }() } } @@ -365,7 +383,7 @@ func (c *Client) Notify(ctx context.Context, method string, params interface{}) return err } -// Close shuts down the client, abandoning any pending in-flight requests. +// Close shuts down the client, terminating any pending in-flight requests. func (c *Client) Close() error { c.mu.Lock() c.stop(errClientStopped) @@ -392,20 +410,19 @@ func (c *Client) stop(err error) { } c.ch.Close() + // Unblock and fail any pending callbacks. + c.cbcancel() + // Unblock and fail any pending requests. for _, p := range c.pending { p.cancel() } + c.err = err c.ch = nil } -func (c *Client) versionOK(v string) bool { - if v == "" { - return c.allow1 - } - return v == Version -} +func (c *Client) versionOK(v string) bool { return v == Version } // marshalParams validates and marshals params to JSON for a request. The // value of params must be either nil or encodable as a JSON object or array. @@ -417,7 +434,7 @@ func (c *Client) marshalParams(ctx context.Context, method string, params interf if err != nil { return nil, err } - if len(pbits) == 0 || (pbits[0] != '[' && pbits[0] != '{' && !isNull(pbits)) { + if fb := firstByte(pbits); fb != '[' && fb != '{' && !isNull(pbits) { // JSON-RPC requires that if parameters are provided at all, they are // an array or an object. return nil, &Error{Code: code.InvalidRequest, Message: "invalid parameters: array or object required"} diff --git a/vendor/github.com/creachadair/jrpc2/code/code.go b/vendor/github.com/creachadair/jrpc2/code/code.go index ac3e09f..0973186 100644 --- a/vendor/github.com/creachadair/jrpc2/code/code.go +++ b/vendor/github.com/creachadair/jrpc2/code/code.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + // Package code defines error code values used by the jrpc2 package. package code diff --git a/vendor/github.com/creachadair/jrpc2/code/code_test.go b/vendor/github.com/creachadair/jrpc2/code/code_test.go index 3ac8714..cf0c6ed 100644 --- a/vendor/github.com/creachadair/jrpc2/code/code_test.go +++ b/vendor/github.com/creachadair/jrpc2/code/code_test.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package code_test import ( diff --git a/vendor/github.com/creachadair/jrpc2/ctx.go b/vendor/github.com/creachadair/jrpc2/ctx.go index f172a58..c55ab49 100644 --- a/vendor/github.com/creachadair/jrpc2/ctx.go +++ b/vendor/github.com/creachadair/jrpc2/ctx.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package jrpc2 import ( @@ -32,3 +34,12 @@ type inboundRequestKey struct{} func ServerFromContext(ctx context.Context) *Server { return ctx.Value(serverKey{}).(*Server) } type serverKey struct{} + +// ClientFromContext returns the client associated with the given context. +// This will be populated on the context passed to callback handlers. +// +// A callback handler must not close the client, as the close will deadlock +// waiting for the callback to return. +func ClientFromContext(ctx context.Context) *Client { return ctx.Value(clientKey{}).(*Client) } + +type clientKey struct{} diff --git a/vendor/github.com/creachadair/jrpc2/doc.go b/vendor/github.com/creachadair/jrpc2/doc.go index 013f994..756bd71 100644 --- a/vendor/github.com/creachadair/jrpc2/doc.go +++ b/vendor/github.com/creachadair/jrpc2/doc.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + /* Package jrpc2 implements a server and a client for the JSON-RPC 2.0 protocol defined by http://www.jsonrpc.org/specification. @@ -12,9 +14,20 @@ Handle method with this signature: Handle(ctx Context.Context, req *jrpc2.Request) (interface{}, error) A server finds the handler for a request by looking up its method name in a -jrpc2.Assigner provided when the server is set up. +jrpc2.Assigner provided when the server is set up. A Handler can decode the +request parameters using the UnmarshalParams method on the request: -For example, suppose we want to export this Add function via JSON-RPC: + func (H) Handle(ctx context.Context, req *jrpc2.Request) (interface{}, error) { + var args ArgType + if err := req.UnmarshalParams(&args); err != nil { + return nil, err + } + return usefulStuffWith(args) + } + +The handler package makes it easier to use functions that do not have this +exact type signature as handlers, by using reflection to lift functions into +the Handler interface. For example, suppose we want to export this Add function: // Add returns the sum of a slice of integers. func Add(ctx context.Context, values []int) int { @@ -25,9 +38,8 @@ For example, suppose we want to export this Add function via JSON-RPC: return sum } -The handler package helps adapt existing functions to the Handler interface. -To convert Add to a jrpc2.Handler, call handler.New, which uses reflection to -lift its argument into the jrpc2.Handler interface: +To convert Add to a jrpc2.Handler, call handler.New, which wraps its argument +into the jrpc2.Handler interface via the handler.Func type: h := handler.New(Add) // h is now a jrpc2.Handler that calls Add diff --git a/vendor/github.com/creachadair/jrpc2/error.go b/vendor/github.com/creachadair/jrpc2/error.go index 161429d..c7dc9ed 100644 --- a/vendor/github.com/creachadair/jrpc2/error.go +++ b/vendor/github.com/creachadair/jrpc2/error.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package jrpc2 import ( @@ -44,12 +46,21 @@ var errClientStopped = errors.New("the client has been stopped") // errEmptyMethod is the error reported for an empty request method name. var errEmptyMethod = &Error{Code: code.InvalidRequest, Message: "empty method name"} +// errNoSuchMethod is the error reported for an unknown method name. +var errNoSuchMethod = &Error{Code: code.MethodNotFound, Message: "no such method"} + +// errDuplicateID is the error reported for a duplicated request ID. +var errDuplicateID = &Error{Code: code.InvalidRequest, Message: "duplicate request ID"} + // errInvalidRequest is the error reported for an invalid request object or batch. var errInvalidRequest = &Error{Code: code.ParseError, Message: "invalid request value"} // errEmptyBatch is the error reported for an empty request batch. var errEmptyBatch = &Error{Code: code.InvalidRequest, Message: "empty request batch"} +// errInvalidParams is the error reported for invalid request parameters. +var errInvalidParams = &Error{Code: code.InvalidParams, Message: "invalid parameters"} + // ErrConnClosed is returned by a server's push-to-client methods if they are // called after the client connection is closed. var ErrConnClosed = errors.New("client connection is closed") diff --git a/vendor/github.com/creachadair/jrpc2/examples_test.go b/vendor/github.com/creachadair/jrpc2/examples_test.go index 61c46ed..05552e2 100644 --- a/vendor/github.com/creachadair/jrpc2/examples_test.go +++ b/vendor/github.com/creachadair/jrpc2/examples_test.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package jrpc2_test import ( @@ -22,29 +24,23 @@ type Msg struct { Text string `json:"msg"` } -func startServer() server.Local { - return server.NewLocal(handler.Map{ - "Hello": handler.New(func(ctx context.Context) string { - return "Hello, world!" - }), - "Echo": handler.New(func(_ context.Context, args []json.RawMessage) []json.RawMessage { - return args - }), - "Log": handler.New(func(ctx context.Context, msg Msg) (bool, error) { - fmt.Println("Log:", msg.Text) - return true, nil - }), - }, nil) -} +var local = server.NewLocal(handler.Map{ + "Hello": handler.New(func(ctx context.Context) string { + return "Hello, world!" + }), + "Echo": handler.New(func(_ context.Context, args []json.RawMessage) []json.RawMessage { + return args + }), + "Log": handler.New(func(ctx context.Context, msg Msg) (bool, error) { + fmt.Println("Log:", msg.Text) + return true, nil + }), +}, nil) func ExampleNewServer() { - // Construct a new server with methods "Hello" and "Log". - loc := startServer() - defer loc.Close() - // We can query the server for its current status information, including a // list of its methods. - si := loc.Server.ServerInfo() + si := local.Server.ServerInfo() fmt.Println(strings.Join(si.Methods, "\n")) // Output: @@ -54,10 +50,7 @@ func ExampleNewServer() { } func ExampleClient_Call() { - loc := startServer() - defer loc.Close() - - rsp, err := loc.Client.Call(ctx, "Hello", nil) + rsp, err := local.Client.Call(ctx, "Hello", nil) if err != nil { log.Fatalf("Call: %v", err) } @@ -71,11 +64,8 @@ func ExampleClient_Call() { } func ExampleClient_CallResult() { - loc := startServer() - defer loc.Close() - var msg string - if err := loc.Client.CallResult(ctx, "Hello", nil, &msg); err != nil { + if err := local.Client.CallResult(ctx, "Hello", nil, &msg); err != nil { log.Fatalf("CallResult: %v", err) } fmt.Println(msg) @@ -84,10 +74,7 @@ func ExampleClient_CallResult() { } func ExampleClient_Batch() { - loc := startServer() - defer loc.Close() - - rsps, err := loc.Client.Batch(ctx, []jrpc2.Spec{ + rsps, err := local.Client.Batch(ctx, []jrpc2.Spec{ {Method: "Hello"}, {Method: "Log", Params: Msg{"Sing it!"}, Notify: true}, }) @@ -164,10 +151,7 @@ type strictParams struct { func (strictParams) DisallowUnknownFields() {} func ExampleResponse_UnmarshalResult() { - loc := startServer() - defer loc.Close() - - rsp, err := loc.Client.Call(ctx, "Echo", []string{"alpha", "oscar", "kilo"}) + rsp, err := local.Client.Call(ctx, "Echo", []string{"alpha", "oscar", "kilo"}) if err != nil { log.Fatalf("Call: %v", err) } diff --git a/vendor/github.com/creachadair/jrpc2/go.mod b/vendor/github.com/creachadair/jrpc2/go.mod index a999c4c..e9518bb 100644 --- a/vendor/github.com/creachadair/jrpc2/go.mod +++ b/vendor/github.com/creachadair/jrpc2/go.mod @@ -1,7 +1,8 @@ module github.com/creachadair/jrpc2 require ( - github.com/google/go-cmp v0.5.6 + github.com/fortytw2/leaktest v1.3.0 + github.com/google/go-cmp v0.5.7 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect ) diff --git a/vendor/github.com/creachadair/jrpc2/go.sum b/vendor/github.com/creachadair/jrpc2/go.sum index 632f99d..171fd00 100644 --- a/vendor/github.com/creachadair/jrpc2/go.sum +++ b/vendor/github.com/creachadair/jrpc2/go.sum @@ -1,5 +1,7 @@ -github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= -github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= +github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= +github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= +github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/vendor/github.com/creachadair/jrpc2/handler/example_test.go b/vendor/github.com/creachadair/jrpc2/handler/example_test.go index b477c18..898b23a 100644 --- a/vendor/github.com/creachadair/jrpc2/handler/example_test.go +++ b/vendor/github.com/creachadair/jrpc2/handler/example_test.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package handler_test import ( @@ -67,30 +69,53 @@ func ExampleObj_unmarshal() { // uid=501, name="P. T. Barnum" } -func ExamplePositional() { - fn := func(ctx context.Context, name string, age int, isOld bool) error { - fmt.Printf("%s is %d (is old: %v)\n", name, age, isOld) - return nil - } - call := handler.NewPos(fn, "name", "age", "isOld") +func describe(_ context.Context, name string, age int, isOld bool) error { + fmt.Printf("%s is %d (old: %v)\n", name, age, isOld) + return nil +} - req, err := jrpc2.ParseRequests([]byte(` -{ - "jsonrpc": "2.0", - "id": 1, - "method": "foo", - "params": { - "name": "Dennis", - "age": 37, - "isOld": false - } -}`)) - if err != nil { - log.Fatalf("Parse: %v", err) - } - if _, err := call(context.Background(), req[0]); err != nil { +func ExamplePositional_object() { + call := handler.NewPos(describe, "name", "age", "isOld") + + req := mustParseReq(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "foo", + "params": { + "name": "Dennis", + "age": 37, + "isOld": false + } + }`) + if _, err := call(context.Background(), req); err != nil { log.Fatalf("Call: %v", err) } // Output: - // Dennis is 37 (is old: false) + // Dennis is 37 (old: false) +} + +func ExamplePositional_array() { + call := handler.NewPos(describe, "name", "age", "isOld") + + req := mustParseReq(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "foo", + "params": ["Marvin", 973000, true] + }`) + if _, err := call(context.Background(), req); err != nil { + log.Fatalf("Call: %v", err) + } + // Output: + // Marvin is 973000 (old: true) +} + +func mustParseReq(s string) *jrpc2.Request { + reqs, err := jrpc2.ParseRequests([]byte(s)) + if err != nil { + log.Fatalf("ParseRequests: %v", err) + } else if len(reqs) == 0 { + log.Fatal("ParseRequests: empty result") + } + return reqs[0] } diff --git a/vendor/github.com/creachadair/jrpc2/handler/handler.go b/vendor/github.com/creachadair/jrpc2/handler/handler.go index 6478fe7..4f67188 100644 --- a/vendor/github.com/creachadair/jrpc2/handler/handler.go +++ b/vendor/github.com/creachadair/jrpc2/handler/handler.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + // Package handler provides implementations of the jrpc2.Assigner interface, // and support for adapting functions to the jrpc2.Handler interface. package handler @@ -7,7 +9,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "reflect" "sort" "strings" @@ -31,7 +32,7 @@ type Map map[string]jrpc2.Handler // Assign implements part of the jrpc2.Assigner interface. func (m Map) Assign(_ context.Context, method string) jrpc2.Handler { return m[method] } -// Names implements part of the jrpc2.Assigner interface. +// Names implements the optional jrpc2.Namer extension interface. func (m Map) Names() []string { var names []string for name := range m { @@ -64,7 +65,12 @@ func (m ServiceMap) Assign(ctx context.Context, method string) jrpc2.Handler { func (m ServiceMap) Names() []string { var all []string for svc, assigner := range m { - for _, name := range assigner.Names() { + namer, ok := assigner.(jrpc2.Namer) + if !ok { + all = append(all, svc+".*") + continue + } + for _, name := range namer.Names() { all = append(all, svc+"."+name) } } @@ -116,6 +122,7 @@ type FuncInfo struct { Result reflect.Type // the non-error result type, or nil ReportsError bool // true if the function reports an error strictFields bool // enforce strict field checking + posNames []string // positional field names (requires strictFields) fn interface{} // the original function value } @@ -151,16 +158,20 @@ func (fi *FuncInfo) Wrap() Func { } // If strict field checking is desired, ensure arguments are wrapped. + arg := fi.Argument wrapArg := func(v reflect.Value) interface{} { return v.Interface() } - if fi.strictFields && !fi.Argument.Implements(strictType) { - wrapArg = func(v reflect.Value) interface{} { return &strict{v.Interface()} } + if fi.strictFields && arg != nil && !arg.Implements(strictType) { + names := fi.posNames + wrapArg = func(v reflect.Value) interface{} { + return &strict{v: v.Interface(), posNames: names} + } } // Construct a function to unpack the parameters from the request message, // based on the signature of the user's callback. var newInput func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) - if fi.Argument == nil { + if arg == nil { // Case 1: The function does not want any request parameters. // Nothing needs to be decoded, but verify no parameters were passed. newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) { @@ -170,16 +181,16 @@ func (fi *FuncInfo) Wrap() Func { return []reflect.Value{ctx}, nil } - } else if fi.Argument == reqType { + } else if arg == reqType { // Case 2: The function wants the underlying *jrpc2.Request value. newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) { return []reflect.Value{ctx, reflect.ValueOf(req)}, nil } - } else if fi.Argument.Kind() == reflect.Ptr { + } else if arg.Kind() == reflect.Ptr { // Case 3a: The function wants a pointer to its argument value. newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) { - in := reflect.New(fi.Argument.Elem()) + in := reflect.New(arg.Elem()) if err := req.UnmarshalParams(wrapArg(in)); err != nil { return nil, jrpc2.Errorf(code.InvalidParams, "invalid parameters: %v", err) } @@ -188,7 +199,7 @@ func (fi *FuncInfo) Wrap() Func { } else { // Case 3b: The function wants a bare argument value. newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) { - in := reflect.New(fi.Argument) // we still need a pointer to unmarshal + in := reflect.New(arg) // we still need a pointer to unmarshal if err := req.UnmarshalParams(wrapArg(in)); err != nil { return nil, jrpc2.Errorf(code.InvalidParams, "invalid parameters: %v", err) } @@ -254,16 +265,17 @@ func (fi *FuncInfo) Wrap() Func { // Note that the JSON-RPC standard restricts encoded parameter values to arrays // and objects. Check will accept argument types that do not encode to arrays // or objects, but the wrapper will report an error when decoding the request. -// // The recommended solution is to define a struct type for your parameters. -// For arbitrary single value types, however, another approach is to wrap it in -// a 1-element array, for example: +// +// For a single arbitrary type, another approach is to use a 1-element array: // // func(ctx context.Context, sp [1]string) error { // s := sp[0] // pull the actual argument out of the array // // ... // } // +// For more complex positional signatures, see also handler.Positional. +// func Check(fn interface{}) (*FuncInfo, error) { if fn == nil { return nil, errors.New("nil function") @@ -299,97 +311,48 @@ func Check(fn interface{}) (*FuncInfo, error) { return info, nil } -// Args is a wrapper that decodes an array of positional parameters into -// concrete locations. -// -// Unmarshaling a JSON value into an Args value v succeeds if the JSON encodes -// an array with length len(v), and unmarshaling each subvalue i into the -// corresponding v[i] succeeds. As a special case, if v[i] == nil the -// corresponding value is discarded. -// -// Marshaling an Args value v into JSON succeeds if each element of the slice -// is JSON marshalable, and yields a JSON array of length len(v) containing the -// JSON values corresponding to the elements of v. -// -// Usage example: -// -// func Handler(ctx context.Context, req *jrpc2.Request) (interface{}, error) { -// var x, y int -// var s string -// -// if err := req.UnmarshalParams(&handler.Args{&x, &y, &s}); err != nil { -// return nil, err -// } -// // do useful work with x, y, and s -// } -// -type Args []interface{} - -// UnmarshalJSON supports JSON unmarshaling for a. -func (a Args) UnmarshalJSON(data []byte) error { - var elts []json.RawMessage - if err := json.Unmarshal(data, &elts); err != nil { - return filterJSONError("args", "array", err) - } else if len(elts) != len(a) { - return fmt.Errorf("wrong number of args (got %d, want %d)", len(elts), len(a)) - } - for i, elt := range elts { - if a[i] == nil { - continue - } else if err := json.Unmarshal(elt, a[i]); err != nil { - return fmt.Errorf("decoding argument %d: %w", i+1, err) - } - } - return nil -} - -// MarshalJSON supports JSON marshaling for a. -func (a Args) MarshalJSON() ([]byte, error) { - if len(a) == 0 { - return []byte(`[]`), nil - } - return json.Marshal([]interface{}(a)) -} - -// Obj is a wrapper that maps object fields into concrete locations. -// -// Unmarshaling a JSON text into an Obj value v succeeds if the JSON encodes an -// object, and unmarshaling the value for each key k of the object into v[k] -// succeeds. If k does not exist in v, it is ignored. -// -// Marshaling an Obj into JSON works as for an ordinary map. -type Obj map[string]interface{} - -// UnmarshalJSON supports JSON unmarshaling into o. -func (o Obj) UnmarshalJSON(data []byte) error { - var base map[string]json.RawMessage - if err := json.Unmarshal(data, &base); err != nil { - return filterJSONError("decoding", "object", err) - } - for key, arg := range o { - val, ok := base[key] - if !ok { - continue - } else if err := json.Unmarshal(val, arg); err != nil { - return fmt.Errorf("decoding %q: %v", key, err) - } - } - return nil -} - -func filterJSONError(tag, want string, err error) error { - if t, ok := err.(*json.UnmarshalTypeError); ok { - return fmt.Errorf("%s: cannot decode %s as %s", tag, t.Value, want) - } - return err -} - // strict is a wrapper for an arbitrary value that enforces strict field -// checking when unmarshaling from JSON. -type strict struct{ v interface{} } +// checking when unmarshaling from JSON, and handles translation of array +// format into object format. +type strict struct { + v interface{} + posNames []string +} + +// translate translates the raw JSON data into the correct format for +// unmarshaling into s.v. +// +// If s.posNames is set and data encodes an array, the array is rewritten to an +// equivalent object with field names assigned by the positional names. +// Otherwise, data is returned as-is without error. +func (s *strict) translate(data []byte) ([]byte, error) { + if len(s.posNames) == 0 || firstByte(data) != '[' { + return data, nil // no names, or not an array + } + + // Decode the array wrapper and verify it has the correct length. + var arr []json.RawMessage + if err := json.Unmarshal(data, &arr); err != nil { + return nil, err + } else if len(arr) != len(s.posNames) { + return nil, jrpc2.Errorf(code.InvalidParams, "got %d parameters, want %d", + len(arr), len(s.posNames)) + } + + // Rewrite the array into an object. + obj := make(map[string]json.RawMessage, len(s.posNames)) + for i, name := range s.posNames { + obj[name] = arr[i] + } + return json.Marshal(obj) +} func (s *strict) UnmarshalJSON(data []byte) error { - dec := json.NewDecoder(bytes.NewReader(data)) + actual, err := s.translate(data) + if err != nil { + return err + } + dec := json.NewDecoder(bytes.NewReader(actual)) dec.DisallowUnknownFields() return dec.Decode(s.v) } diff --git a/vendor/github.com/creachadair/jrpc2/handler/handler_test.go b/vendor/github.com/creachadair/jrpc2/handler/handler_test.go index 47fe0cb..6f057e8 100644 --- a/vendor/github.com/creachadair/jrpc2/handler/handler_test.go +++ b/vendor/github.com/creachadair/jrpc2/handler/handler_test.go @@ -1,9 +1,13 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package handler_test import ( "context" "encoding/json" "errors" + "fmt" + "strconv" "testing" "github.com/creachadair/jrpc2" @@ -69,6 +73,49 @@ func TestCheck(t *testing.T) { } } +// Verify that the wrappers constructed by FuncInfo.Wrap can properly decode +// their arguments of different types and structure. +func TestFuncInfo_wrapDecode(t *testing.T) { + tests := []struct { + fn handler.Func + p string + want interface{} + }{ + // A positional handler should decode its argument from an array or an object. + {handler.NewPos(func(_ context.Context, z int) int { return z }, "arg"), + `[25]`, 25}, + {handler.NewPos(func(_ context.Context, z int) int { return z }, "arg"), + `{"arg":109}`, 109}, + + // A type with custom marshaling should be properly handled. + {handler.NewPos(func(_ context.Context, b stringByte) byte { return byte(b) }, "arg"), + `["00111010"]`, byte(0x3a)}, + {handler.NewPos(func(_ context.Context, b stringByte) byte { return byte(b) }, "arg"), + `{"arg":"10011100"}`, byte(0x9c)}, + {handler.New(func(_ context.Context, v fauxStruct) int { return int(v) }), + `{"type":"thing","value":99}`, 99}, + + // Plain JSON should get its argument unmodified. + {handler.New(func(_ context.Context, v json.RawMessage) string { return string(v) }), + `{"x": true, "y": null}`, `{"x": true, "y": null}`}, + + // Npn-positional slice argument. + {handler.New(func(_ context.Context, ss []string) int { return len(ss) }), + `["a", "b", "c"]`, 3}, + } + ctx := context.Background() + for _, test := range tests { + req := mustParseRequest(t, + fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"x","params":%s}`, test.p)) + got, err := test.fn(ctx, req) + if err != nil { + t.Errorf("Call %v failed: %v", test.fn, err) + } else if diff := cmp.Diff(test.want, got); diff != "" { + t.Errorf("Call %v: wrong result (-want, +got)\n%s", test.fn, diff) + } + } +} + // Verify that the Positional function correctly handles its cases. func TestPositional(t *testing.T) { tests := []struct { @@ -132,6 +179,17 @@ func TestNewStrict(t *testing.T) { } } +// Verify that a handler with no argument type does not panic attempting to +// enforce strict field checking. +func TestNewStrict_argumentRegression(t *testing.T) { + defer func() { + if x := recover(); x != nil { + t.Fatalf("NewStrict panic: %v", x) + } + }() + handler.NewStrict(func(context.Context) error { return nil }) +} + // Verify that the handling of pointer-typed arguments does not incorrectly // introduce another pointer indirection. func TestNew_pointerRegression(t *testing.T) { @@ -173,14 +231,18 @@ func TestPositional_decode(t *testing.T) { bad bool }{ {`{"jsonrpc":"2.0","id":1,"method":"add","params":{"first":5,"second":3}}`, 8, false}, - {`{"jsonrpc":"2.0","id":2,"method":"add","params":{"first":5}}`, 5, false}, - {`{"jsonrpc":"2.0","id":3,"method":"add","params":{"second":3}}`, 3, false}, - {`{"jsonrpc":"2.0","id":4,"method":"add","params":{}}`, 0, false}, - {`{"jsonrpc":"2.0","id":5,"method":"add","params":null}`, 0, false}, - {`{"jsonrpc":"2.0","id":6,"method":"add"}`, 0, false}, + {`{"jsonrpc":"2.0","id":2,"method":"add","params":[5,3]}`, 8, false}, + {`{"jsonrpc":"2.0","id":3,"method":"add","params":{"first":5}}`, 5, false}, + {`{"jsonrpc":"2.0","id":4,"method":"add","params":{"second":3}}`, 3, false}, + {`{"jsonrpc":"2.0","id":5,"method":"add","params":{}}`, 0, false}, + {`{"jsonrpc":"2.0","id":6,"method":"add","params":null}`, 0, false}, + {`{"jsonrpc":"2.0","id":7,"method":"add"}`, 0, false}, - {`{"jsonrpc":"2.0","id":6,"method":"add","params":["wrong", "type"]}`, 0, true}, - {`{"jsonrpc":"2.0","id":6,"method":"add","params":{"unknown":"field"}}`, 0, true}, + {`{"jsonrpc":"2.0","id":10,"method":"add","params":["wrong", "type"]}`, 0, true}, + {`{"jsonrpc":"2.0","id":12,"method":"add","params":[15, "wrong-type"]}`, 0, true}, + {`{"jsonrpc":"2.0","id":13,"method":"add","params":{"unknown":"field"}}`, 0, true}, + {`{"jsonrpc":"2.0","id":14,"method":"add","params":[1]}`, 0, true}, // too few + {`{"jsonrpc":"2.0","id":15,"method":"add","params":[1,2,3]}`, 0, true}, // too many } for _, test := range tests { req := mustParseRequest(t, test.input) @@ -386,3 +448,38 @@ func mustParseRequest(t *testing.T, text string) *jrpc2.Request { } return req[0] } + +// stringByte is a byte with a custom JSON encoding. It expects a string of +// decimal digits 1 and 0, e.g., "10011000" == 0x98. +type stringByte byte + +func (s *stringByte) UnmarshalText(text []byte) error { + v, err := strconv.ParseUint(string(text), 2, 8) + if err != nil { + return err + } + *s = stringByte(v) + return nil +} + +// fauxStruct is an integer with a custom JSON encoding. It expects an object: +// +// {"type":"thing","value":} +// +type fauxStruct int + +func (s *fauxStruct) UnmarshalJSON(data []byte) error { + var tmp struct { + T string `json:"type"` + V *int `json:"value"` + } + if err := json.Unmarshal(data, &tmp); err != nil { + return err + } else if tmp.T != "thing" { + return fmt.Errorf("unknown type %q", tmp.T) + } else if tmp.V == nil { + return errors.New("missing value") + } + *s = fauxStruct(*tmp.V) + return nil +} diff --git a/vendor/github.com/creachadair/jrpc2/handler/helpers.go b/vendor/github.com/creachadair/jrpc2/handler/helpers.go new file mode 100644 index 0000000..63da2b2 --- /dev/null +++ b/vendor/github.com/creachadair/jrpc2/handler/helpers.go @@ -0,0 +1,103 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + +package handler + +import ( + "bytes" + "encoding/json" + "fmt" +) + +// Args is a wrapper that decodes an array of positional parameters into +// concrete locations. +// +// Unmarshaling a JSON value into an Args value v succeeds if the JSON encodes +// an array with length len(v), and unmarshaling each subvalue i into the +// corresponding v[i] succeeds. As a special case, if v[i] == nil the +// corresponding value is discarded. +// +// Marshaling an Args value v into JSON succeeds if each element of the slice +// is JSON marshalable, and yields a JSON array of length len(v) containing the +// JSON values corresponding to the elements of v. +// +// Usage example: +// +// func Handler(ctx context.Context, req *jrpc2.Request) (interface{}, error) { +// var x, y int +// var s string +// +// if err := req.UnmarshalParams(&handler.Args{&x, &y, &s}); err != nil { +// return nil, err +// } +// // do useful work with x, y, and s +// } +// +type Args []interface{} + +// UnmarshalJSON supports JSON unmarshaling for a. +func (a Args) UnmarshalJSON(data []byte) error { + var elts []json.RawMessage + if err := json.Unmarshal(data, &elts); err != nil { + return filterJSONError("args", "array", err) + } else if len(elts) != len(a) { + return fmt.Errorf("wrong number of args (got %d, want %d)", len(elts), len(a)) + } + for i, elt := range elts { + if a[i] == nil { + continue + } else if err := json.Unmarshal(elt, a[i]); err != nil { + return fmt.Errorf("decoding argument %d: %w", i+1, err) + } + } + return nil +} + +// MarshalJSON supports JSON marshaling for a. +func (a Args) MarshalJSON() ([]byte, error) { + if len(a) == 0 { + return []byte(`[]`), nil + } + return json.Marshal([]interface{}(a)) +} + +// Obj is a wrapper that maps object fields into concrete locations. +// +// Unmarshaling a JSON text into an Obj value v succeeds if the JSON encodes an +// object, and unmarshaling the value for each key k of the object into v[k] +// succeeds. If k does not exist in v, it is ignored. +// +// Marshaling an Obj into JSON works as for an ordinary map. +type Obj map[string]interface{} + +// UnmarshalJSON supports JSON unmarshaling into o. +func (o Obj) UnmarshalJSON(data []byte) error { + var base map[string]json.RawMessage + if err := json.Unmarshal(data, &base); err != nil { + return filterJSONError("decoding", "object", err) + } + for key, arg := range o { + val, ok := base[key] + if !ok { + continue + } else if err := json.Unmarshal(val, arg); err != nil { + return fmt.Errorf("decoding %q: %v", key, err) + } + } + return nil +} + +func filterJSONError(tag, want string, err error) error { + if t, ok := err.(*json.UnmarshalTypeError); ok { + return fmt.Errorf("%s: cannot decode %s as %s", tag, t.Value, want) + } + return err +} + +// firstByte returns the first non-whitespace byte of data, or 0 if there is none. +func firstByte(data []byte) byte { + clean := bytes.TrimSpace(data) + if len(clean) == 0 { + return 0 + } + return clean[0] +} diff --git a/vendor/github.com/creachadair/jrpc2/handler/positional.go b/vendor/github.com/creachadair/jrpc2/handler/positional.go index d5ee624..1ebf2e5 100644 --- a/vendor/github.com/creachadair/jrpc2/handler/positional.go +++ b/vendor/github.com/creachadair/jrpc2/handler/positional.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package handler import ( @@ -26,9 +28,9 @@ func NewPos(fn interface{}, names ...string) Func { // value of fn must be a function with one of the following type signature // schemes: // -// func(context.Context, X1, x2, ..., Xn) (Y, error) -// func(context.Context, X1, x2, ..., Xn) Y -// func(context.Context, X1, x2, ..., Xn) error +// func(context.Context, X1, X2, ..., Xn) (Y, error) +// func(context.Context, X1, X2, ..., Xn) Y +// func(context.Context, X1, X2, ..., Xn) error // // For JSON-marshalable types X_i and Y. If fn does not have one of these // forms, Positional reports an error. The given names must match the number of @@ -56,6 +58,13 @@ func NewPos(fn interface{}, names ...string) Func { // field keys generate an error. The field names are not required to match the // parameter names declared by the function; it is the names assigned here that // determine which object keys are accepted. +// +// The wrapped function will also accept a JSON array with with (exactly) the +// same number of elements as the positional parameters: +// +// [17, 23] +// +// Unlike the object format, no arguments can be omitted in this format. func Positional(fn interface{}, names ...string) (*FuncInfo, error) { if fn == nil { return nil, errors.New("nil function") @@ -85,6 +94,7 @@ func Positional(fn interface{}, names ...string) (*FuncInfo, error) { fi, err := Check(makeCaller(ft, fv, atype)) if err == nil { fi.strictFields = true + fi.posNames = names } return fi, err } diff --git a/vendor/github.com/creachadair/jrpc2/internal_test.go b/vendor/github.com/creachadair/jrpc2/internal_test.go index 58b0b69..8daa00d 100644 --- a/vendor/github.com/creachadair/jrpc2/internal_test.go +++ b/vendor/github.com/creachadair/jrpc2/internal_test.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package jrpc2 // This file contains tests that need to inspect the internal details of the @@ -12,6 +14,7 @@ import ( "github.com/creachadair/jrpc2/channel" "github.com/creachadair/jrpc2/code" + "github.com/fortytw2/leaktest" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" ) @@ -150,6 +153,8 @@ func (h hmap) Names() []string { return nil } // Verify that if the client context terminates during a request, the client // will terminate and report failure. func TestClient_contextCancellation(t *testing.T) { + defer leaktest.Check(t)() + started := make(chan struct{}) stopped := make(chan struct{}) cpipe, spipe := channel.Direct() @@ -203,6 +208,8 @@ func TestClient_contextCancellation(t *testing.T) { } func TestServer_specialMethods(t *testing.T) { + defer leaktest.Check(t)() + s := NewServer(hmap{ "rpc.nonesuch": methodFunc(func(context.Context, *Request) (interface{}, error) { return "OK", nil @@ -225,6 +232,8 @@ func TestServer_specialMethods(t *testing.T) { // Verify that the option to remove the special behaviour of rpc.* methods can // be correctly disabled by the server options. func TestServer_disableBuiltinHook(t *testing.T) { + defer leaktest.Check(t)() + s := NewServer(hmap{ "rpc.nonesuch": methodFunc(func(context.Context, *Request) (interface{}, error) { return "OK", nil @@ -249,6 +258,8 @@ func TestServer_disableBuiltinHook(t *testing.T) { // request. The Client never sends requests like that, but the server needs to // cope with it correctly. func TestBatchReply(t *testing.T) { + defer leaktest.Check(t)() + cpipe, spipe := channel.Direct() srv := NewServer(hmap{ "test": methodFunc(func(_ context.Context, req *Request) (interface{}, error) { diff --git a/vendor/github.com/creachadair/jrpc2/jctx/example_test.go b/vendor/github.com/creachadair/jrpc2/jctx/example_test.go index a890a1a..25ff5f6 100644 --- a/vendor/github.com/creachadair/jrpc2/jctx/example_test.go +++ b/vendor/github.com/creachadair/jrpc2/jctx/example_test.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package jctx_test import ( diff --git a/vendor/github.com/creachadair/jrpc2/jctx/jctx.go b/vendor/github.com/creachadair/jrpc2/jctx/jctx.go index ff21f24..e1edb22 100644 --- a/vendor/github.com/creachadair/jrpc2/jctx/jctx.go +++ b/vendor/github.com/creachadair/jrpc2/jctx/jctx.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + // Package jctx implements an encoder and decoder for request context values, // allowing context metadata to be propagated through JSON-RPC. // @@ -66,7 +68,7 @@ func Encode(ctx context.Context, method string, params json.RawMessage) (json.Ra v := wireVersion c := wireContext{V: &v, Payload: params} if dl, ok := ctx.Deadline(); ok { - utcdl := dl.In(time.UTC) + utcdl := dl.UTC() c.Deadline = &utcdl } @@ -102,7 +104,7 @@ func Decode(ctx context.Context, method string, req json.RawMessage) (context.Co } if c.Deadline != nil && !c.Deadline.IsZero() { var ignored context.CancelFunc - ctx, ignored = context.WithDeadline(ctx, (*c.Deadline).In(time.UTC)) + ctx, ignored = context.WithDeadline(ctx, (*c.Deadline).UTC()) _ = ignored // the caller cannot use this value } diff --git a/vendor/github.com/creachadair/jrpc2/jctx/jctx_test.go b/vendor/github.com/creachadair/jrpc2/jctx/jctx_test.go index a6c386a..5f6161a 100644 --- a/vendor/github.com/creachadair/jrpc2/jctx/jctx_test.go +++ b/vendor/github.com/creachadair/jrpc2/jctx/jctx_test.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package jctx import ( diff --git a/vendor/github.com/creachadair/jrpc2/jhttp/bridge.go b/vendor/github.com/creachadair/jrpc2/jhttp/bridge.go index 578d50e..622f650 100644 --- a/vendor/github.com/creachadair/jrpc2/jhttp/bridge.go +++ b/vendor/github.com/creachadair/jrpc2/jhttp/bridge.go @@ -1,15 +1,15 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + // Package jhttp implements a bridge from HTTP to JSON-RPC. This permits // requests to be submitted to a JSON-RPC server using HTTP as a transport. package jhttp import ( - "bytes" "context" "encoding/json" "fmt" "io" "net/http" - "strconv" "github.com/creachadair/jrpc2" "github.com/creachadair/jrpc2/server" @@ -17,35 +17,52 @@ import ( // A Bridge is a http.Handler that bridges requests to a JSON-RPC server. // -// The body of the HTTP POST request must contain the complete JSON-RPC request -// message, encoded with Content-Type: application/json. Either a single -// request object or a list of request objects is supported. -// -// If the request completes, whether or not there is an error, the HTTP -// response is 200 (OK) for ordinary requests or 204 (No Response) for -// notifications, and the response body contains the JSON-RPC response. +// By default, the bridge accepts only HTTP POST requests with the complete +// JSON-RPC request message in the body, with Content-Type application/json. +// Either a single request object or a list of request objects is supported. // // If the HTTP request method is not "POST", the bridge reports 405 (Method Not // Allowed). If the Content-Type is not application/json, the bridge reports // 415 (Unsupported Media Type). // +// If a ParseRequest hook is set, these requirements are disabled, and the hook +// is entirely responsible for checking request structure. +// +// If a ParseGETRequest hook is set, HTTP "GET" requests are handled by a +// Getter using that hook; otherwise "GET" requests are handled as above. +// +// If the request completes, whether or not there is an error, the HTTP +// response is 200 (OK) for ordinary requests or 204 (No Response) for +// notifications, and the response body contains the JSON-RPC response. +// // The bridge attaches the inbound HTTP request to the context passed to the // client, allowing an EncodeContext callback to retrieve state from the HTTP // headers. Use jhttp.HTTPRequest to retrieve the request from the context. type Bridge struct { - local server.Local - checkType func(string) bool + local server.Local + parseReq func(*http.Request) ([]*jrpc2.Request, error) + getter *Getter } // ServeHTTP implements the required method of http.Handler. func (b Bridge) ServeHTTP(w http.ResponseWriter, req *http.Request) { - if req.Method != "POST" { - w.WriteHeader(http.StatusMethodNotAllowed) + // If a GET hook is defined, allow GET requests. + if req.Method == "GET" && b.getter != nil { + b.getter.ServeHTTP(w, req) return } - if !b.checkType(req.Header.Get("Content-Type")) { - w.WriteHeader(http.StatusUnsupportedMediaType) - return + + // If no parse hook is defined, insist that the method is POST and the + // content-type is application/json. Setting a hook disables these checks. + if b.parseReq == nil { + if req.Method != "POST" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if req.Header.Get("Content-Type") != "application/json" { + w.WriteHeader(http.StatusUnsupportedMediaType) + return + } } if err := b.serveInternal(w, req); err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -54,11 +71,6 @@ func (b Bridge) ServeHTTP(w http.ResponseWriter, req *http.Request) { } func (b Bridge) serveInternal(w http.ResponseWriter, req *http.Request) error { - body, err := io.ReadAll(req.Body) - if err != nil { - return err - } - // The HTTP request requires a response, but the server will not reply if // all the requests are notifications. Check whether we have any calls // needing a response, and choose whether to wait for a reply based on that. @@ -66,7 +78,7 @@ func (b Bridge) serveInternal(w http.ResponseWriter, req *http.Request) error { // Note that we are forgiving about a missing version marker in a request, // since we can't tell at this point whether the server is willing to accept // messages like that. - jreq, err := jrpc2.ParseRequests(body) + jreq, err := b.parseHTTPRequest(req) if err != nil && err != jrpc2.ErrInvalidVersion { return err } @@ -118,26 +130,44 @@ func (b Bridge) serveInternal(w http.ResponseWriter, req *http.Request) error { rsp.SetID(inboundID[i]) } - // If the original request was a single message, make sure we encode the - // response the same way. - var reply []byte - if len(rsps) == 1 && !bytes.HasPrefix(bytes.TrimSpace(body), []byte("[")) { - reply, err = json.Marshal(rsps[0]) - } else { - reply, err = json.Marshal(rsps) + return b.encodeResponses(rsps, w) +} + +func (b Bridge) parseHTTPRequest(req *http.Request) ([]*jrpc2.Request, error) { + if b.parseReq != nil { + return b.parseReq(req) } + body, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + return jrpc2.ParseRequests(body) +} + +func (b Bridge) encodeResponses(rsps []*jrpc2.Response, w http.ResponseWriter) error { + // If there is only a single reply, send it alone; otherwise encode a batch. + // Per the spec (https://www.jsonrpc.org/specification#batch), this is OK; + // we are not required to respond to a batch with an array: + // + // The Server SHOULD respond with an Array containing the corresponding + // Response objects + // + data, err := marshalResponses(rsps) if err != nil { return err } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Content-Length", strconv.Itoa(len(reply))) - w.Write(reply) + writeJSON(w, http.StatusOK, json.RawMessage(data)) return nil } // Close closes the channel to the server, waits for the server to exit, and // reports its exit status. -func (b Bridge) Close() error { return b.local.Close() } +func (b Bridge) Close() error { + if b.getter != nil { + b.getter.Close() + } + return b.local.Close() +} // NewBridge constructs a new Bridge that starts a server on mux and dispatches // HTTP requests to it. The server will run until the bridge is closed. @@ -149,13 +179,22 @@ func (b Bridge) Close() error { return b.local.Close() } // hooks on the bridge client as usual, but the remote client will not see push // messages from the server. func NewBridge(mux jrpc2.Assigner, opts *BridgeOptions) Bridge { - return Bridge{ + b := Bridge{ local: server.NewLocal(mux, &server.LocalOptions{ Client: opts.clientOptions(), Server: opts.serverOptions(), }), - checkType: opts.checkContentType(), + parseReq: opts.parseRequest(), } + if pget := opts.parseGETRequest(); pget != nil { + g := NewGetter(mux, &GetterOptions{ + Client: opts.clientOptions(), + Server: opts.serverOptions(), + ParseRequest: pget, + }) + b.getter = &g + } + return b } // BridgeOptions are optional settings for a Bridge. A nil pointer is ready for @@ -167,11 +206,22 @@ type BridgeOptions struct { // Options for the bridge server (default nil). Server *jrpc2.ServerOptions - // If non-nil, this function is called to check whether the HTTP request's - // declared content-type is valid. If this function returns false, the - // request is rejected. If nil, the default check requires a content type of - // "application/json". - CheckContentType func(contentType string) bool + // If non-nil, this function is called to parse JSON-RPC requests from the + // HTTP request body. If this function reports an error, the request fails. + // By default, the bridge uses jrpc2.ParseRequests on the HTTP request body. + // + // Setting this hook disables the default requirement that the request + // method be POST and the content-type be application/json. + ParseRequest func(*http.Request) ([]*jrpc2.Request, error) + + // If non-nil, this function is used to parse a JSON-RPC method name and + // parameters from the URL of an HTTP GET request. If this function reports + // an error, the request fails. + // + // If this hook is set, all GET requests are handled by a Getter using this + // parse function, and are not passed to a ParseRequest hook even if one is + // defined. + ParseGETRequest func(*http.Request) (string, interface{}, error) } func (o *BridgeOptions) clientOptions() *jrpc2.ClientOptions { @@ -188,11 +238,18 @@ func (o *BridgeOptions) serverOptions() *jrpc2.ServerOptions { return o.Server } -func (o *BridgeOptions) checkContentType() func(string) bool { - if o == nil || o.CheckContentType == nil { - return func(ctype string) bool { return ctype == "application/json" } +func (o *BridgeOptions) parseRequest() func(*http.Request) ([]*jrpc2.Request, error) { + if o == nil { + return nil } - return o.CheckContentType + return o.ParseRequest +} + +func (o *BridgeOptions) parseGETRequest() func(*http.Request) (string, interface{}, error) { + if o == nil { + return nil + } + return o.ParseGETRequest } type httpReqKey struct{} @@ -206,3 +263,11 @@ func HTTPRequest(ctx context.Context) *http.Request { } return nil } + +// marshalResponses encodes a batch of JSON-RPC responses into JSON. +func marshalResponses(rsps []*jrpc2.Response) ([]byte, error) { + if len(rsps) == 1 { + return json.Marshal(rsps[0]) + } + return json.Marshal(rsps) +} diff --git a/vendor/github.com/creachadair/jrpc2/jhttp/channel.go b/vendor/github.com/creachadair/jrpc2/jhttp/channel.go index 858c595..79a1926 100644 --- a/vendor/github.com/creachadair/jrpc2/jhttp/channel.go +++ b/vendor/github.com/creachadair/jrpc2/jhttp/channel.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package jhttp import ( diff --git a/vendor/github.com/creachadair/jrpc2/jhttp/example_test.go b/vendor/github.com/creachadair/jrpc2/jhttp/example_test.go index da0ccca..7471a67 100644 --- a/vendor/github.com/creachadair/jrpc2/jhttp/example_test.go +++ b/vendor/github.com/creachadair/jrpc2/jhttp/example_test.go @@ -1,46 +1,45 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package jhttp_test import ( "context" "fmt" - "io" "log" - "net/http" "net/http/httptest" "strings" + "github.com/creachadair/jrpc2" "github.com/creachadair/jrpc2/handler" "github.com/creachadair/jrpc2/jhttp" ) func Example() { - // Set up a bridge to demonstrate the API. + // Set up a bridge exporting a simple service. b := jhttp.NewBridge(handler.Map{ - "Test": handler.New(func(ctx context.Context, ss []string) (string, error) { - return strings.Join(ss, " "), nil + "Test": handler.New(func(ctx context.Context, ss []string) string { + return strings.Join(ss, " ") }), }, nil) defer b.Close() + // The bridge can be used as the handler for an HTTP server. hsrv := httptest.NewServer(b) defer hsrv.Close() - rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`{ - "jsonrpc": "2.0", - "id": 10235, - "method": "Test", - "params": ["full", "plate", "and", "packing", "steel"] -}`)) - if err != nil { - log.Fatalf("POST request failed: %v", err) - } - body, err := io.ReadAll(rsp.Body) - rsp.Body.Close() - if err != nil { - log.Fatalf("Reading response body: %v", err) + // Set up a client using an HTTP channel, and use it to call the test + // service exported by the bridge. + ch := jhttp.NewChannel(hsrv.URL, nil) + cli := jrpc2.NewClient(ch, nil) + + var result string + if err := cli.CallResult(context.Background(), "Test", []string{ + "full", "plate", "and", "packing", "steel", + }, &result); err != nil { + log.Fatalf("Call failed: %v", err) } - fmt.Println(string(body)) + fmt.Println("Result:", result) // Output: - // {"jsonrpc":"2.0","id":10235,"result":"full plate and packing steel"} + // Result: full plate and packing steel } diff --git a/vendor/github.com/creachadair/jrpc2/jhttp/getter.go b/vendor/github.com/creachadair/jrpc2/jhttp/getter.go new file mode 100644 index 0000000..b14e533 --- /dev/null +++ b/vendor/github.com/creachadair/jrpc2/jhttp/getter.go @@ -0,0 +1,288 @@ +// Copyright (C) 2021 Michael J. Fromberger. All Rights Reserved. + +package jhttp + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + "strings" + + "github.com/creachadair/jrpc2" + "github.com/creachadair/jrpc2/code" + "github.com/creachadair/jrpc2/server" +) + +// A Getter is a http.Handler that bridges GET requests to a JSON-RPC server. +// +// The JSON-RPC method name and parameters are decoded from the request URL. +// The results from a successful call are encoded as JSON in the response body +// with status 200 (OK). In case of error, the response body is a JSON-RPC +// error object, and the HTTP status is one of the following: +// +// Condition HTTP Status +// ----------------------- ----------------------------------- +// Parsing request 400 (Bad request) +// Method not found 404 (Not found) +// (other errors) 500 (Internal server error) +// +// By default, the URL path identifies the JSON-RPC method, and the URL query +// parameters are converted into a JSON object for the parameters. Leading and +// trailing slashes are stripped from the path, and query values are sent as +// JSON strings. +// +// For example, this URL: +// +// http://site.org:2112/some/method?param1=xyzzy¶m2=apple +// +// would produce the method name "some/method" and this parameter object: +// +// {"param1":"xyzzy", "param2":"apple"} +// +// To override the default behaviour, set a ParseRequest hook in GetterOptions. +// See also the jhttp.ParseQuery function for a more expressive translation. +type Getter struct { + local server.Local + parseReq func(*http.Request) (string, interface{}, error) +} + +// NewGetter constructs a new Getter that starts a server on mux and dispatches +// HTTP requests to it. The server will run until the getter is closed. +// +// Note that a getter is not able to push calls or notifications from the +// server back to the remote client even if enabled. +func NewGetter(mux jrpc2.Assigner, opts *GetterOptions) Getter { + return Getter{ + local: server.NewLocal(mux, &server.LocalOptions{ + Client: opts.clientOptions(), + Server: opts.serverOptions(), + }), + parseReq: opts.parseRequest(), + } +} + +// ServeHTTP implements the required method of http.Handler. +func (g Getter) ServeHTTP(w http.ResponseWriter, req *http.Request) { + method, params, err := g.parseHTTPRequest(req) + if err != nil { + writeJSON(w, http.StatusBadRequest, &jrpc2.Error{ + Code: code.ParseError, + Message: err.Error(), + }) + return + } + + ctx := context.WithValue(req.Context(), httpReqKey{}, req) + var result json.RawMessage + if err := g.local.Client.CallResult(ctx, method, params, &result); err != nil { + var status int + switch code.FromError(err) { + case code.MethodNotFound: + status = http.StatusNotFound + default: + status = http.StatusInternalServerError + } + writeJSON(w, status, err) + return + } + writeJSON(w, http.StatusOK, result) +} + +// Close closes the channel to the server, waits for the server to exit, and +// reports its exit status. +func (g Getter) Close() error { return g.local.Close() } + +func (g Getter) parseHTTPRequest(req *http.Request) (string, interface{}, error) { + if g.parseReq != nil { + return g.parseReq(req) + } + if err := req.ParseForm(); err != nil { + return "", nil, err + } + method := strings.Trim(req.URL.Path, "/") + if method == "" { + return "", nil, errors.New("empty method name") + } + params := make(map[string]string) + for key := range req.Form { + params[key] = req.Form.Get(key) + } + return method, params, nil +} + +// GetterOptions are optional settings for a Getter. A nil pointer is ready for +// use and provides default values as described. +type GetterOptions struct { + // Options for the getter client (default nil). + Client *jrpc2.ClientOptions + + // Options for the getter server (default nil). + Server *jrpc2.ServerOptions + + // If set, this function is called to parse a method name and request + // parameters from an HTTP request. If this is not set, the default handler + // uses the URL path as the method name and the URL query as the method + // parameters. + ParseRequest func(*http.Request) (string, interface{}, error) +} + +func (o *GetterOptions) clientOptions() *jrpc2.ClientOptions { + if o == nil { + return nil + } + return o.Client +} + +func (o *GetterOptions) serverOptions() *jrpc2.ServerOptions { + if o == nil { + return nil + } + return o.Server +} + +func (o *GetterOptions) parseRequest() func(*http.Request) (string, interface{}, error) { + if o == nil { + return nil + } + return o.ParseRequest +} + +func writeJSON(w http.ResponseWriter, code int, obj interface{}) { + bits, err := json.Marshal(obj) + if err != nil { + // Fallback in case of marshaling error. This should not happen, but + // ensures the client gets a loggable reply from a broken server. + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintln(w, err.Error()) + return + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Length", strconv.Itoa(len(bits))) + w.WriteHeader(code) + w.Write(bits) +} + +// ParseQuery parses a request URL and constructs a parameter map from the +// query values encoded in the URL and/or request body. +// +// The method name is the URL path, with leading and trailing slashes trimmed. +// Query values are converted into argument values by these rules: +// +// Double-quoted values are interpreted as JSON string values, with the same +// encoding and escaping rules (UTF-8 with backslash escapes). Examples: +// +// "" +// "foo\nbar" +// "a \"string\" of text" +// +// Values that consist of decimal digits and an optional leading sign are +// treated as either int64 (if there is no decimal point) or float64 values. +// Examples: +// +// 25 +// -16 +// 3.259 +// +// The unquoted strings "true" and "false" are converted to the corresponding +// Boolean values. The unquoted string "null" is converted to nil. +// +// To express arbitrary bytes, use a singly-quoted string encoded in base64. +// For example: +// +// 'aGVsbG8sIHdvcmxk' -- represents "hello, world" +// +// All values not matching any of the above are treated as literal strings. +// +// On success, the result has concrete type map[string]interface{} and the +// method name is not empty. +func ParseQuery(req *http.Request) (string, interface{}, error) { + if err := req.ParseForm(); err != nil { + return "", nil, err + } + method := strings.Trim(req.URL.Path, "/") + if method == "" { + return "", nil, errors.New("empty URL path") + } + if len(req.Form) == 0 { + return method, nil, nil + } + + params := make(map[string]interface{}) + for key := range req.Form { + val := req.Form.Get(key) + if v, ok, err := parseJSONString(val); err != nil { + return "", nil, fmt.Errorf("decoding string %q: %w", key, err) + } else if ok { + params[key] = v + } else if n, ok := parseNumber(val); ok { + params[key] = n + } else if b, ok := parseConstant(val); ok { + params[key] = b + } else if d, ok, err := parseQuoted64(val); err != nil { + return "", nil, fmt.Errorf("decoding bytes %q: %w", key, err) + } else if ok { + params[key] = d + } else { + params[key] = val + } + } + return method, params, nil +} + +func parseJSONString(s string) (string, bool, error) { + if len(s) >= 2 { + if s[0] == '"' && s[len(s)-1] == '"' { + var dec string + err := json.Unmarshal([]byte(s), &dec) + if err != nil { + return "", false, err + } + return dec, true, nil + } else if s[0] == '"' || s[len(s)-1] == '"' { + return "", false, errors.New("missing string quote") + } + } + return "", false, nil +} + +func parseNumber(s string) (interface{}, bool) { + z, err := strconv.ParseInt(s, 10, 64) + if err == nil { + return z, true + } + v, err := strconv.ParseFloat(s, 64) + if err == nil { + return v, true + } + return nil, false +} + +func parseConstant(s string) (interface{}, bool) { + switch s { + case "true": + return true, true + case "false": + return false, true + case "null": + return nil, true + default: + return nil, false + } +} + +func parseQuoted64(s string) ([]byte, bool, error) { + if len(s) >= 2 { + if s[0] == '\'' && s[len(s)-1] == '\'' { + trim := strings.TrimRight(s[1:len(s)-1], "=") // discard base64 padding + dec, err := base64.RawStdEncoding.DecodeString(trim) + return dec, err == nil, err + } else if s[0] == '\'' || s[len(s)-1] == '\'' { + return nil, false, errors.New("missing bytes quote") + } + } + return nil, false, nil +} diff --git a/vendor/github.com/creachadair/jrpc2/jhttp/getter_test.go b/vendor/github.com/creachadair/jrpc2/jhttp/getter_test.go new file mode 100644 index 0000000..98ceb7d --- /dev/null +++ b/vendor/github.com/creachadair/jrpc2/jhttp/getter_test.go @@ -0,0 +1,236 @@ +// Copyright (C) 2021 Michael J. Fromberger. All Rights Reserved. + +package jhttp_test + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + "github.com/creachadair/jrpc2" + "github.com/creachadair/jrpc2/handler" + "github.com/creachadair/jrpc2/jhttp" + "github.com/fortytw2/leaktest" + "github.com/google/go-cmp/cmp" +) + +func TestGetter(t *testing.T) { + defer leaktest.Check(t)() + + mux := handler.Map{ + "concat": handler.NewPos(func(ctx context.Context, a, b string) string { + return a + b + }, "first", "second"), + } + + g := jhttp.NewGetter(mux, &jhttp.GetterOptions{ + Client: &jrpc2.ClientOptions{EncodeContext: checkContext}, + }) + defer checkClose(t, g) + + hsrv := httptest.NewServer(g) + defer hsrv.Close() + url := func(pathQuery string) string { + return hsrv.URL + "/" + pathQuery + } + + t.Run("OK", func(t *testing.T) { + got := mustGet(t, url("concat?second=world&first=hello"), http.StatusOK) + const want = `"helloworld"` + if got != want { + t.Errorf("Response body: got %#q, want %#q", got, want) + } + }) + t.Run("NotFound", func(t *testing.T) { + got := mustGet(t, url("nonesuch"), http.StatusNotFound) + const want = `"code":-32601` // MethodNotFound + if !strings.Contains(got, want) { + t.Errorf("Response body: got %#q, want %#q", got, want) + } + }) + t.Run("BadRequest", func(t *testing.T) { + // N.B. invalid query string + got := mustGet(t, url("concat?x%2"), http.StatusBadRequest) + const want = `"code":-32700` // ParseError + if !strings.Contains(got, want) { + t.Errorf("Response body: got %#q, want %#q", got, want) + } + }) + t.Run("InternalError", func(t *testing.T) { + got := mustGet(t, url("concat?third=c"), http.StatusInternalServerError) + const want = `"code":-32602` // InvalidParams + if !strings.Contains(got, want) { + t.Errorf("Response body: got %#q, want %#q", got, want) + } + }) +} + +func TestParseQuery(t *testing.T) { + tests := []struct { + url string + body string + method string + want interface{} + errText string + }{ + // Error: Missing method name. + {"http://localhost:2112/", "", "", nil, "empty URL path"}, + + // No parameters. + {"http://localhost/foo", "", "foo", nil, ""}, + + // Unbalanced double-quoted string. + {"https://fuzz.ball/foo?bad=%22xyz", "", "foo", nil, "missing string quote"}, + {"https://fuzz.ball/bar?bad=xyz%22", "", "bar", nil, "missing string quote"}, + + // Unbalanced single-quoted string. + {`http://stripe/sister?bad='invalid`, "", "sister", nil, "missing bytes quote"}, + {`http://stripe/sister?bad=invalid'`, "", "sister", nil, "missing bytes quote"}, + + // Invalid byte string. + {`http://green.as/balls?bad='NOT%20VALID'`, "", "balls", nil, "decoding bytes"}, + + // Invalid double-quoted string. + {`http://black.as/sin?bad=%22a%5Cx25%22`, "", "sin", nil, "invalid character"}, + + // Valid: Single-quoted byte string (base64). + {`http://fast.as.hell/and?twice='YXMgcHJldHR5IGFzIHlvdQ=='`, + "", "and", map[string]interface{}{ + "twice": []byte("as pretty as you"), + }, ""}, + + // Valid: Unquoted strings and null. + {`http://head.like/a-hole?black=as&your=null&soul`, + "", "a-hole", map[string]interface{}{ + "black": "as", + "your": nil, + "soul": "", + }, ""}, + + // Valid: Quoted strings, numbers, Booleans. + {`http://foo.com:1999/go/west/?alpha=%22xyz%22&bravo=3&charlie=true&delta=false&echo=3.2`, + "", "go/west", map[string]interface{}{ + "alpha": "xyz", + "bravo": int64(3), + "charlie": true, + "delta": false, + "echo": 3.2, + }, ""}, + + // Valid: Form-encoded query in the request body. + {`http://buz.org:2013/bodyblow`, + "alpha=%22pdq%22&bravo=-19.4&charlie=false", "bodyblow", map[string]interface{}{ + "alpha": "pdq", + "bravo": float64(-19.4), + "charlie": false, + }, ""}, + } + for _, test := range tests { + t.Run("ParseQuery", func(t *testing.T) { + req, err := http.NewRequest("PUT", test.url, strings.NewReader(test.body)) + if err != nil { + t.Fatalf("New request for %q failed: %v", test.url, err) + } + if test.body != "" { + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + + method, params, err := jhttp.ParseQuery(req) + if err != nil { + if test.errText == "" { + t.Fatalf("ParseQuery failed: %v", err) + } else if !strings.Contains(err.Error(), test.errText) { + t.Fatalf("ParseQuery: got error %v, want %q", err, test.errText) + } + } else if test.errText != "" { + t.Fatalf("ParseQuery: got method %q, params %+v, wanted error %q", + method, params, test.errText) + } else { + if method != test.method { + t.Errorf("ParseQuery method: got %q, want %q", method, test.method) + } + if diff := cmp.Diff(test.want, params); diff != "" { + t.Errorf("Wrong params: (-want, +got)\n%s", diff) + } + } + }) + } +} + +func TestGetter_parseRequest(t *testing.T) { + defer leaktest.Check(t)() + + mux := handler.Map{ + "format": handler.NewPos(func(ctx context.Context, a string, b int) string { + return fmt.Sprintf("%s-%d", a, b) + }, "tag", "value"), + } + + g := jhttp.NewGetter(mux, &jhttp.GetterOptions{ + ParseRequest: func(req *http.Request) (string, interface{}, error) { + if err := req.ParseForm(); err != nil { + return "", nil, err + } + tag := req.Form.Get("tag") + val, err := strconv.ParseInt(req.Form.Get("value"), 10, 64) + if err != nil && req.Form.Get("value") != "" { + return "", nil, fmt.Errorf("invalid number: %w", err) + } + return strings.TrimPrefix(req.URL.Path, "/x/"), map[string]interface{}{ + "tag": tag, + "value": val, + }, nil + }, + }) + defer checkClose(t, g) + + hsrv := httptest.NewServer(g) + defer hsrv.Close() + url := func(pathQuery string) string { + return hsrv.URL + "/" + pathQuery + } + + t.Run("OK", func(t *testing.T) { + got := mustGet(t, url("x/format?tag=foo&value=25"), http.StatusOK) + const want = `"foo-25"` + if got != want { + t.Errorf("Response body: got %#q, want %#q", got, want) + } + }) + t.Run("NotFound", func(t *testing.T) { + // N.B. Missing path prefix. + got := mustGet(t, url("format"), http.StatusNotFound) + const want = `"code":-32601` // MethodNotFound + if !strings.Contains(got, want) { + t.Errorf("Response body: got %#q, want %#q", got, want) + } + }) + t.Run("InternalError", func(t *testing.T) { + // N.B. Parameter type does not match on the server side. + got := mustGet(t, url("x/format?tag=foo&value=bar"), http.StatusBadRequest) + const want = `"code":-32700` // ParseError + if !strings.Contains(got, want) { + t.Errorf("Response body: got %#q, want %#q", got, want) + } + }) +} + +func mustGet(t *testing.T, url string, code int) string { + t.Helper() + rsp, err := http.Get(url) + if err != nil { + t.Fatalf("GET request failed: %v", err) + } else if got := rsp.StatusCode; got != code { + t.Errorf("GET response code: got %v, want %v", got, code) + } + body, err := io.ReadAll(rsp.Body) + if err != nil { + t.Errorf("Reading GET body: %v", err) + } + return string(body) +} diff --git a/vendor/github.com/creachadair/jrpc2/jhttp/jhttp_test.go b/vendor/github.com/creachadair/jrpc2/jhttp/jhttp_test.go index 3e28d87..fd77672 100644 --- a/vendor/github.com/creachadair/jrpc2/jhttp/jhttp_test.go +++ b/vendor/github.com/creachadair/jrpc2/jhttp/jhttp_test.go @@ -1,10 +1,11 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package jhttp_test import ( "context" "encoding/json" "errors" - "fmt" "io" "net/http" "net/http/httptest" @@ -14,6 +15,7 @@ import ( "github.com/creachadair/jrpc2" "github.com/creachadair/jrpc2/handler" "github.com/creachadair/jrpc2/jhttp" + "github.com/fortytw2/leaktest" ) var testService = handler.Map{ @@ -33,6 +35,8 @@ func checkContext(ctx context.Context, _ string, p json.RawMessage) (json.RawMes } func TestBridge(t *testing.T) { + defer leaktest.Check(t)() + // Set up a bridge with the test configuration. b := jhttp.NewBridge(testService, &jhttp.BridgeOptions{ Client: &jrpc2.ClientOptions{EncodeContext: checkContext}, @@ -45,49 +49,29 @@ func TestBridge(t *testing.T) { // Verify that a valid POST request succeeds. t.Run("PostOK", func(t *testing.T) { - rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`{ - "jsonrpc": "2.0", - "id": 1, - "method": "Test1", - "params": ["a", "foolish", "consistency", "is", "the", "hobgoblin"] -} -`)) - if err != nil { - t.Fatalf("POST request failed: %v", err) - } else if got, want := rsp.StatusCode, http.StatusOK; got != want { - t.Errorf("POST response code: got %v, want %v", got, want) - } - body, err := io.ReadAll(rsp.Body) - if err != nil { - t.Errorf("Reading POST body: %v", err) - } + got := mustPost(t, hsrv.URL, `{ + "jsonrpc": "2.0", + "id": 1, + "method": "Test1", + "params": ["a", "foolish", "consistency", "is", "the", "hobgoblin"] + }`, http.StatusOK) const want = `{"jsonrpc":"2.0","id":1,"result":6}` - if got := string(body); got != want { + if got != want { t.Errorf("POST body: got %#q, want %#q", got, want) } }) // Verify that the bridge will accept a batch. t.Run("PostBatchOK", func(t *testing.T) { - rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`[ - {"jsonrpc":"2.0", "id": 3, "method": "Test1", "params": ["first"]}, - {"jsonrpc":"2.0", "id": 7, "method": "Test1", "params": ["among", "equals"]} -] -`)) - if err != nil { - t.Fatalf("POST request failed: %v", err) - } else if got, want := rsp.StatusCode, http.StatusOK; got != want { - t.Errorf("POST response code: got %v, want %v", got, want) - } - body, err := io.ReadAll(rsp.Body) - if err != nil { - t.Errorf("Reading POST body: %v", err) - } + got := mustPost(t, hsrv.URL, `[ + {"jsonrpc":"2.0", "id": 3, "method": "Test1", "params": ["first"]}, + {"jsonrpc":"2.0", "id": 7, "method": "Test1", "params": ["among", "equals"]} + ]`, http.StatusOK) const want = `[{"jsonrpc":"2.0","id":3,"result":1},` + `{"jsonrpc":"2.0","id":7,"result":2}]` - if got := string(body); got != want { + if got != want { t.Errorf("POST body: got %#q, want %#q", got, want) } }) @@ -108,62 +92,51 @@ func TestBridge(t *testing.T) { rsp, err := http.Post(hsrv.URL, "text/plain", strings.NewReader(`{}`)) if err != nil { t.Fatalf("POST request failed: %v", err) - } - if got, want := rsp.StatusCode, http.StatusUnsupportedMediaType; got != want { - t.Errorf("POST status: got %v, want %v", got, want) + } else if got, want := rsp.StatusCode, http.StatusUnsupportedMediaType; got != want { + t.Errorf("POST response code: got %v, want %v", got, want) } }) // Verify that a POST that generates a JSON-RPC error succeeds. t.Run("PostErrorReply", func(t *testing.T) { - rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`{ - "id": 1, - "jsonrpc": "2.0" -} -`)) - if err != nil { - t.Fatalf("POST request failed: %v", err) - } else if got, want := rsp.StatusCode, http.StatusOK; got != want { - t.Errorf("POST status: got %v, want %v", got, want) - } - body, err := io.ReadAll(rsp.Body) - if err != nil { - t.Errorf("Reading POST body: %v", err) - } + got := mustPost(t, hsrv.URL, `{ + "id": 1, + "jsonrpc": "2.0" + }`, http.StatusOK) const exp = `{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"empty method name"}}` - if got := string(body); got != exp { + if got != exp { t.Errorf("POST body: got %#q, want %#q", got, exp) } }) // Verify that a notification returns an empty success. t.Run("PostNotification", func(t *testing.T) { - rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`{ - "jsonrpc": "2.0", - "method": "TakeNotice", - "params": [] -}`)) - if err != nil { - t.Fatalf("POST request failed: %v", err) - } else if got, want := rsp.StatusCode, http.StatusNoContent; got != want { - t.Errorf("POST status: got %v, want %v", got, want) - } - body, err := io.ReadAll(rsp.Body) - if err != nil { - t.Errorf("Reading POST body: %v", err) - } - if got := string(body); got != "" { + got := mustPost(t, hsrv.URL, `{ + "jsonrpc": "2.0", + "method": "TakeNotice", + "params": [] + }`, http.StatusNoContent) + if got != "" { t.Errorf("POST body: got %q, want empty", got) } }) } -// Verify that the content-type check hook works. -func TestBridge_contentTypeCheck(t *testing.T) { +// Verify that the request-parsing hook works. +func TestBridge_parseRequest(t *testing.T) { + defer leaktest.Check(t)() + + const reqMessage = `{"jsonrpc":"2.0", "method": "Test2", "id": 100, "params":null}` + const wantReply = `{"jsonrpc":"2.0","id":100,"result":0}` + b := jhttp.NewBridge(testService, &jhttp.BridgeOptions{ - CheckContentType: func(ctype string) bool { - return ctype == "application/octet-stream" + ParseRequest: func(req *http.Request) ([]*jrpc2.Request, error) { + action := req.Header.Get("x-test-header") + if action == "fail" { + return nil, errors.New("parse hook reporting failure") + } + return jrpc2.ParseRequests([]byte(reqMessage)) }, }) defer checkClose(t, b) @@ -171,29 +144,93 @@ func TestBridge_contentTypeCheck(t *testing.T) { hsrv := httptest.NewServer(b) defer hsrv.Close() - const reqTemplate = `{"jsonrpc":"2.0","id":%q,"method":"Test1","params":["a","b","c"]}` - t.Run("ContentTypeOK", func(t *testing.T) { - rsp, err := http.Post(hsrv.URL, "application/octet-stream", - strings.NewReader(fmt.Sprintf(reqTemplate, "ok"))) + t.Run("Succeed", func(t *testing.T) { + // Since a parse hook is set, the method and content-type checks should not occur. + // We send an empty body to be sure the request comes from the hook. + req, err := http.NewRequest("GET", hsrv.URL, strings.NewReader("")) if err != nil { - t.Fatalf("POST request failed: %v", err) + t.Fatalf("NewRequest: %v", err) + } + + rsp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("GET request failed: %v", err) } else if got, want := rsp.StatusCode, http.StatusOK; got != want { - t.Errorf("POST response code: got %v, want %v", got, want) + t.Errorf("GET response code: got %v, want %v", got, want) + } + body, _ := io.ReadAll(rsp.Body) + rsp.Body.Close() + if got := string(body); got != wantReply { + t.Errorf("Response: got %#q, want %#q", got, wantReply) } }) - t.Run("ContentTypeBad", func(t *testing.T) { - rsp, err := http.Post(hsrv.URL, "text/plain", - strings.NewReader(fmt.Sprintf(reqTemplate, "bad"))) + t.Run("Fail", func(t *testing.T) { + req, err := http.NewRequest("POST", hsrv.URL, strings.NewReader("")) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + req.Header.Set("X-Test-Header", "fail") + + rsp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("POST request failed: %v", err) - } else if got, want := rsp.StatusCode, http.StatusUnsupportedMediaType; got != want { + } else if got, want := rsp.StatusCode, http.StatusInternalServerError; got != want { t.Errorf("POST response code: got %v, want %v", got, want) } }) } +func TestBridge_parseGETRequest(t *testing.T) { + defer leaktest.Check(t)() + + mux := handler.Map{ + "str/eq": handler.NewPos(func(ctx context.Context, a, b string) bool { + return a == b + }, "lhs", "rhs"), + } + b := jhttp.NewBridge(mux, &jhttp.BridgeOptions{ + ParseGETRequest: func(req *http.Request) (string, interface{}, error) { + if err := req.ParseForm(); err != nil { + return "", nil, err + } + method := strings.Trim(req.URL.Path, "/") + params := make(map[string]string) + for key := range req.Form { + params[key] = req.Form.Get(key) + } + return method, params, nil + }, + }) + defer checkClose(t, b) + + hsrv := httptest.NewServer(b) + defer hsrv.Close() + url := func(pathQuery string) string { + return hsrv.URL + "/" + pathQuery + } + + t.Run("GET", func(t *testing.T) { + got := mustGet(t, url("str/eq?rhs=fizz&lhs=buzz"), http.StatusOK) + const want = `false` + if got != want { + t.Errorf("Response body: got %#q, want %#q", got, want) + } + }) + t.Run("POST", func(t *testing.T) { + const req = `{"jsonrpc":"2.0", "id":1, "method":"str/eq", "params":["foo","foo"]}` + got := mustPost(t, hsrv.URL, req, http.StatusOK) + + const want = `{"jsonrpc":"2.0","id":1,"result":true}` + if got != want { + t.Errorf("Response body: got %#q, want %#q", got, want) + } + }) +} + func TestChannel(t *testing.T) { + defer leaktest.Check(t)() + b := jhttp.NewBridge(testService, nil) defer checkClose(t, b) hsrv := httptest.NewServer(b) @@ -259,3 +296,18 @@ func checkClose(t *testing.T, c io.Closer) { t.Errorf("Error in Close: %v", err) } } + +func mustPost(t *testing.T, url, req string, code int) string { + t.Helper() + rsp, err := http.Post(url, "application/json", strings.NewReader(req)) + if err != nil { + t.Fatalf("POST request failed: %v", err) + } else if got := rsp.StatusCode; got != code { + t.Errorf("POST response code: got %v, want %v", got, code) + } + body, err := io.ReadAll(rsp.Body) + if err != nil { + t.Errorf("Reading POST body: %v", err) + } + return string(body) +} diff --git a/vendor/github.com/creachadair/jrpc2/jrpc2_test.go b/vendor/github.com/creachadair/jrpc2/jrpc2_test.go index c7a7b93..b2deaa8 100644 --- a/vendor/github.com/creachadair/jrpc2/jrpc2_test.go +++ b/vendor/github.com/creachadair/jrpc2/jrpc2_test.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package jrpc2_test import ( @@ -6,6 +8,7 @@ import ( "errors" "fmt" "sort" + "sync" "sync/atomic" "testing" "time" @@ -16,6 +19,7 @@ import ( "github.com/creachadair/jrpc2/handler" "github.com/creachadair/jrpc2/jctx" "github.com/creachadair/jrpc2/server" + "github.com/fortytw2/leaktest" "github.com/google/go-cmp/cmp" ) @@ -24,8 +28,6 @@ var ( _ code.ErrCoder = (*jrpc2.Error)(nil) ) -var notAuthorized = code.Register(-32095, "request not authorized") - var testOK = handler.New(func(ctx context.Context) (string, error) { return "OK", nil }) @@ -114,6 +116,8 @@ var callTests = []struct { } func TestServerInfo_methodNames(t *testing.T) { + defer leaktest.Check(t)() + loc := server.NewLocal(handler.ServiceMap{ "Test": testService, }, nil) @@ -130,13 +134,12 @@ func TestServerInfo_methodNames(t *testing.T) { } func TestClient_Call(t *testing.T) { + defer leaktest.Check(t)() + loc := server.NewLocal(handler.ServiceMap{ "Test": testService, }, &server.LocalOptions{ - Server: &jrpc2.ServerOptions{ - AllowV1: true, - Concurrency: 16, - }, + Server: &jrpc2.ServerOptions{Concurrency: 16}, }) defer loc.Close() c := loc.Client @@ -164,6 +167,8 @@ func TestClient_Call(t *testing.T) { } func TestClient_CallResult(t *testing.T) { + defer leaktest.Check(t)() + loc := server.NewLocal(handler.ServiceMap{ "Test": testService, }, &server.LocalOptions{ @@ -187,13 +192,12 @@ func TestClient_CallResult(t *testing.T) { } func TestClient_Batch(t *testing.T) { + defer leaktest.Check(t)() + loc := server.NewLocal(handler.ServiceMap{ "Test": testService, }, &server.LocalOptions{ - Server: &jrpc2.ServerOptions{ - AllowV1: true, - Concurrency: 16, - }, + Server: &jrpc2.ServerOptions{Concurrency: 16}, }) defer loc.Close() c := loc.Client @@ -238,6 +242,8 @@ func TestClient_Batch(t *testing.T) { // Verify that notifications respect order of arrival. func TestServer_notificationOrder(t *testing.T) { + defer leaktest.Check(t)() + var last int32 loc := server.NewLocal(handler.Map{ @@ -269,6 +275,8 @@ func TestServer_notificationOrder(t *testing.T) { // Verify that a method that returns only an error (no result payload) is set // up and handled correctly. func TestHandler_errorOnly(t *testing.T) { + defer leaktest.Check(t)() + const errMessage = "not enough strings" loc := server.NewLocal(handler.Map{ "ErrorOnly": handler.New(func(_ context.Context, ss []string) error { @@ -309,9 +317,11 @@ func TestHandler_errorOnly(t *testing.T) { }) } -// Verify that a timeout set on the context is respected by the server and -// propagates back to the client as an error. -func TestServer_contextTimeout(t *testing.T) { +// Verify that a timeout set on the client context is respected and reports +// back to the caller as an error. +func TestClient_contextTimeout(t *testing.T) { + defer leaktest.Check(t)() + loc := server.NewLocal(handler.Map{ "Stall": handler.New(func(ctx context.Context) (bool, error) { t.Log("Stalling...") @@ -325,13 +335,12 @@ func TestServer_contextTimeout(t *testing.T) { }), }, nil) defer loc.Close() - c := loc.Client ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() start := time.Now() - got, err := c.Call(ctx, "Stall", nil) + got, err := loc.Client.Call(ctx, "Stall", nil) if err == nil { t.Errorf("Stall: got %+v, wanted error", got) } else if err != context.DeadlineExceeded { @@ -343,6 +352,8 @@ func TestServer_contextTimeout(t *testing.T) { // Verify that stopping the server terminates in-flight requests. func TestServer_stopCancelsHandlers(t *testing.T) { + defer leaktest.Check(t)() + started := make(chan struct{}) stopped := make(chan error, 1) loc := server.NewLocal(handler.Map{ @@ -379,6 +390,8 @@ func TestServer_stopCancelsHandlers(t *testing.T) { // Test that a handler can cancel an in-flight request. func TestServer_CancelRequest(t *testing.T) { + defer leaktest.Check(t)() + ready := make(chan struct{}) loc := server.NewLocal(handler.Map{ "Stall": handler.New(func(ctx context.Context) error { @@ -430,6 +443,8 @@ func TestServer_CancelRequest(t *testing.T) { // Test that an error with data attached to it is correctly propagated back // from the server to the client, in a value of concrete type *Error. func TestError_withData(t *testing.T) { + defer leaktest.Check(t)() + const errCode = -32000 const errData = `{"caroline":452}` const errMessage = "error thingy" @@ -482,6 +497,8 @@ func TestError_withData(t *testing.T) { // Test that a client correctly reports bad parameters. func TestClient_badCallParams(t *testing.T) { + defer leaktest.Check(t)() + loc := server.NewLocal(handler.Map{ "Test": handler.New(func(_ context.Context, v interface{}) error { return jrpc2.Errorf(129, "this should not be reached") @@ -499,6 +516,8 @@ func TestClient_badCallParams(t *testing.T) { // Verify that metrics are correctly propagated to server info. func TestServer_serverInfoMetrics(t *testing.T) { + defer leaktest.Check(t)() + loc := server.NewLocal(handler.Map{ "Metricize": handler.New(func(ctx context.Context) (bool, error) { m := jrpc2.ServerFromContext(ctx).Metrics() @@ -529,6 +548,9 @@ func TestServer_serverInfoMetrics(t *testing.T) { if _, err := c.Call(ctx, "Metricize", nil); err != nil { t.Fatalf("Call(Metricize) failed: %v", err) } + if got := s.ServerInfo().Counter["rpc.serversActive"]; got != 1 { + t.Errorf("Metric rpc.serversActive: got %d, want 1", got) + } loc.Close() info := s.ServerInfo() @@ -542,6 +564,7 @@ func TestServer_serverInfoMetrics(t *testing.T) { {info.Counter, "zero-sum", 0}, {info.Counter, "rpc.bytesRead", -1}, {info.Counter, "rpc.bytesWritten", -1}, + {info.Counter, "rpc.serversActive", 0}, {info.MaxValue, "max-metric-value", 5}, {info.MaxValue, "rpc.bytesRead", -1}, {info.MaxValue, "rpc.bytesWritten", -1}, @@ -562,6 +585,8 @@ func TestServer_serverInfoMetrics(t *testing.T) { // elicit a correct response from the server. Here we simulate a "different" // client by writing requests directly into the channel. func TestServer_nonLibraryClient(t *testing.T) { + defer leaktest.Check(t)() + srv, cli := channel.Direct() s := jrpc2.NewServer(handler.Map{ "X": testOK, @@ -594,7 +619,7 @@ func TestServer_nonLibraryClient(t *testing.T) { // The method specified doesn't exist. {`{"jsonrpc":"2.0", "id": 3, "method": "NoneSuch"}`, - `{"jsonrpc":"2.0","id":3,"error":{"code":-32601,"message":"no such method \"NoneSuch\""}}`}, + `{"jsonrpc":"2.0","id":3,"error":{"code":-32601,"message":"no such method","data":"NoneSuch"}}`}, // The parameters are of the wrong form. {`{"jsonrpc":"2.0", "id": 4, "method": "X", "params": "bogus"}`, @@ -604,9 +629,13 @@ func TestServer_nonLibraryClient(t *testing.T) { {`{"jsonrpc": "2.0", "id": 6, "method": "X", "params": null}`, `{"jsonrpc":"2.0","id":6,"result":"OK"}`}, - // Correct requests, one with a non-null response, one with a null response. + // Correct requests. {`{"jsonrpc":"2.0","id": 5, "method": "X"}`, `{"jsonrpc":"2.0","id":5,"result":"OK"}`}, {`{"jsonrpc":"2.0","id":21,"method":"Y"}`, `{"jsonrpc":"2.0","id":21,"result":null}`}, + {`{"jsonrpc":"2.0","id":0,"method":"X"}`, `{"jsonrpc":"2.0","id":0,"result":"OK"}`}, + {`{"jsonrpc":"2.0","id":-0,"method":"X"}`, `{"jsonrpc":"2.0","id":-0,"result":"OK"}`}, + {`{"jsonrpc":"2.0","id":-1,"method":"X"}`, `{"jsonrpc":"2.0","id":-1,"result":"OK"}`}, + {`{"jsonrpc":"2.0","id":-600,"method":"Y"}`, `{"jsonrpc":"2.0","id":-600,"result":null}`}, // A batch of correct requests. {`[{"jsonrpc":"2.0", "id":"a1", "method":"X"}, {"jsonrpc":"2.0", "id":"a2", "method": "X"}]`, @@ -625,7 +654,7 @@ func TestServer_nonLibraryClient(t *testing.T) { // A batch of invalid requests returns a batch of errors. {`[{"jsonrpc": "2.0", "id": 6, "method":"bogus"}]`, - `[{"jsonrpc":"2.0","id":6,"error":{"code":-32601,"message":"no such method \"bogus\""}}]`}, + `[{"jsonrpc":"2.0","id":6,"error":{"code":-32601,"message":"no such method","data":"bogus"}}]`}, // Batch requests return batch responses, even for a singleton. {`[{"jsonrpc": "2.0", "id": 7, "method": "X"}]`, `[{"jsonrpc":"2.0","id":7,"result":"OK"}]`}, @@ -677,6 +706,8 @@ func TestServer_nonLibraryClient(t *testing.T) { // Verify that server-side push notifications work. func TestServer_Notify(t *testing.T) { + defer leaktest.Check(t)() + // Set up a server and client with server-side notification support. Here // we're just capturing the name of the notification method, as a sign we // got the right thing. @@ -728,6 +759,8 @@ func TestServer_Notify(t *testing.T) { // Verify that server-side callbacks can time out. func TestServer_callbackTimeout(t *testing.T) { + defer leaktest.Check(t)() + loc := server.NewLocal(handler.Map{ "Test": handler.New(func(ctx context.Context) error { tctx, cancel := context.WithTimeout(ctx, 5*time.Millisecond) @@ -757,6 +790,8 @@ func TestServer_callbackTimeout(t *testing.T) { // Verify that server-side callbacks work. func TestServer_Callback(t *testing.T) { + defer leaktest.Check(t)() + loc := server.NewLocal(handler.Map{ "CallMeMaybe": handler.New(func(ctx context.Context) error { if _, err := jrpc2.ServerFromContext(ctx).Callback(ctx, "succeed", nil); err != nil { @@ -799,6 +834,8 @@ func TestServer_Callback(t *testing.T) { // Verify that a server push after the client closes does not trigger a panic. func TestServer_pushAfterClose(t *testing.T) { + defer leaktest.Check(t)() + loc := server.NewLocal(make(handler.Map), &server.LocalOptions{ Server: &jrpc2.ServerOptions{AllowPush: true}, }) @@ -814,6 +851,8 @@ func TestServer_pushAfterClose(t *testing.T) { // Verify that an OnCancel hook is called when expected. func TestClient_onCancelHook(t *testing.T) { + defer leaktest.Check(t)() + hooked := make(chan struct{}) // closed when hook notification is finished loc := server.NewLocal(handler.Map{ @@ -870,8 +909,162 @@ func TestClient_onCancelHook(t *testing.T) { } } +// Verify that client callback handlers are cancelled when the client stops. +func TestClient_closeEndsCallbacks(t *testing.T) { + defer leaktest.Check(t)() + + ready := make(chan struct{}) + loc := server.NewLocal(handler.Map{ + "Test": handler.New(func(ctx context.Context) error { + // Call back to the client and block indefinitely until it returns. + srv := jrpc2.ServerFromContext(ctx) + _, err := srv.Callback(ctx, "whatever", nil) + return err + }), + }, &server.LocalOptions{ + Server: &jrpc2.ServerOptions{AllowPush: true}, + Client: &jrpc2.ClientOptions{ + OnCallback: handler.New(func(ctx context.Context) error { + // Signal the test that the callback handler is running. When the + // client is closed, it should terminate ctx and allow this to + // return. If that doesn't happen, time out and fail. + close(ready) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(10 * time.Second): + return errors.New("context not cancelled before timeout") + } + }), + }, + }) + go func() { + rsp, err := loc.Client.Call(context.Background(), "Test", nil) + if err == nil { + t.Errorf("Client call: got %+v, wanted error", rsp) + } + }() + <-ready + loc.Client.Close() + loc.Server.Wait() +} + +// Verify that it is possible for multiple callback handlers to execute +// concurrently. +func TestClient_concurrentCallbacks(t *testing.T) { + defer leaktest.Check(t)() + + ready1 := make(chan struct{}) + ready2 := make(chan struct{}) + release := make(chan struct{}) + + loc := server.NewLocal(handler.Map{ + "Test": handler.New(func(ctx context.Context) []string { + srv := jrpc2.ServerFromContext(ctx) + + // Call two callbacks concurrently, wait until they are both running, + // then ungate them and wait for them both to reply. Return their + // responses back to the test for validation. + ss := make([]string, 2) + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + rsp, err := srv.Callback(ctx, "C1", nil) + if err != nil { + t.Errorf("Callback C1 failed: %v", err) + } else { + rsp.UnmarshalResult(&ss[0]) + } + }() + go func() { + defer wg.Done() + rsp, err := srv.Callback(ctx, "C2", nil) + if err != nil { + t.Errorf("Callback C2 failed: %v", err) + } else { + rsp.UnmarshalResult(&ss[1]) + } + }() + <-ready1 // C1 is ready + <-ready2 // C2 is ready + close(release) // allow all callbacks to proceed + wg.Wait() // wait for all callbacks to be done + return ss + }), + }, &server.LocalOptions{ + Server: &jrpc2.ServerOptions{AllowPush: true}, + Client: &jrpc2.ClientOptions{ + OnCallback: handler.Func(func(ctx context.Context, req *jrpc2.Request) (interface{}, error) { + // A trivial callback that reports its method name. + // The name is used to select which invocation we are serving. + switch req.Method() { + case "C1": + close(ready1) + case "C2": + close(ready2) + default: + return nil, fmt.Errorf("unexpected method %q", req.Method()) + } + <-release + return req.Method(), nil + }), + }, + }) + defer loc.Close() + + var got []string + if err := loc.Client.CallResult(context.Background(), "Test", nil, &got); err != nil { + t.Errorf("Call Test failed: %v", err) + } + want := []string{"C1", "C2"} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("Wrong callback results: (-want, +got)\n%s", diff) + } +} + +// Verify that a callback can successfully call "up" into the server. +func TestClient_callbackUpCall(t *testing.T) { + defer leaktest.Check(t)() + + const pingMessage = "kittens!" + + var probe string + loc := server.NewLocal(handler.Map{ + "Test": handler.New(func(ctx context.Context) error { + // Call back to the client, and propagate its response. + srv := jrpc2.ServerFromContext(ctx) + _, err := srv.Callback(ctx, "whatever", nil) + return err + }), + "Ping": handler.New(func(context.Context) string { + // This method is called by the client-side callback. + return pingMessage + }), + }, &server.LocalOptions{ + Server: &jrpc2.ServerOptions{AllowPush: true}, + Client: &jrpc2.ClientOptions{ + OnCallback: handler.New(func(ctx context.Context) error { + // Call back up into the server. + cli := jrpc2.ClientFromContext(ctx) + return cli.CallResult(ctx, "Ping", nil, &probe) + }), + }, + }) + + if _, err := loc.Client.Call(context.Background(), "Test", nil); err != nil { + t.Errorf("Call Test failed: %v", err) + } + loc.Close() + if probe != pingMessage { + t.Errorf("Probe response: got %q, want %q", probe, pingMessage) + } +} + // Verify that the context encoding/decoding hooks work. func TestContextPlumbing(t *testing.T) { + defer leaktest.Check(t)() + want := time.Now().Add(10 * time.Second) ctx, cancel := context.WithDeadline(context.Background(), want) defer cancel() @@ -898,88 +1091,11 @@ func TestContextPlumbing(t *testing.T) { } } -// Verify that the request-checking hook works. -func TestServer_checkRequestHook(t *testing.T) { - const wantResponse = "Hey girl" - const wantToken = "OK" - - loc := server.NewLocal(handler.Map{ - "Test": handler.New(func(ctx context.Context) (string, error) { - return wantResponse, nil - }), - }, &server.LocalOptions{ - // Enable auth checking and context decoding for the server. - Server: &jrpc2.ServerOptions{ - DecodeContext: jctx.Decode, - CheckRequest: func(ctx context.Context, req *jrpc2.Request) error { - var token []byte - switch err := jctx.UnmarshalMetadata(ctx, &token); err { - case nil: - t.Logf("Metadata present: value=%q", string(token)) - case jctx.ErrNoMetadata: - t.Log("Metadata not set") - default: - return err - } - if s := string(token); s != wantToken { - return jrpc2.Errorf(notAuthorized, "not authorized") - } - return nil - }, - }, - - // Enable context encoding for the client. - Client: &jrpc2.ClientOptions{ - EncodeContext: jctx.Encode, - }, - }) - defer loc.Close() - c := loc.Client - - // Call without a token and verify that we get an error. - t.Run("NoToken", func(t *testing.T) { - var rsp string - err := c.CallResult(context.Background(), "Test", nil, &rsp) - if err == nil { - t.Errorf("Call(Test): got %q, wanted error", rsp) - } else if ec := code.FromError(err); ec != notAuthorized { - t.Errorf("Call(Test): got code %v, want %v", ec, notAuthorized) - } - }) - - // Call with a valid token and verify that we get a response. - t.Run("GoodToken", func(t *testing.T) { - ctx, err := jctx.WithMetadata(context.Background(), []byte(wantToken)) - if err != nil { - t.Fatalf("Call(Test): attaching metadata: %v", err) - } - var rsp string - if err := c.CallResult(ctx, "Test", nil, &rsp); err != nil { - t.Errorf("Call(Test): unexpected error: %v", err) - } - if rsp != wantResponse { - t.Errorf("Call(Test): got %q, want %q", rsp, wantResponse) - } - }) - - // Call with an invalid token and verify that we get an error. - t.Run("BadToken", func(t *testing.T) { - ctx, err := jctx.WithMetadata(context.Background(), []byte("BAD")) - if err != nil { - t.Fatalf("Call(Test): attaching metadata: %v", err) - } - var rsp string - if err := c.CallResult(ctx, "Test", nil, &rsp); err == nil { - t.Errorf("Call(Test): got %q, wanted error", rsp) - } else if ec := code.FromError(err); ec != notAuthorized { - t.Errorf("Call(Test): got code %v, want %v", ec, notAuthorized) - } - }) -} - // Verify that calling a wrapped method which takes no parameters, but in which // the caller provided parameters, will correctly report an error. func TestHandler_noParams(t *testing.T) { + defer leaktest.Check(t)() + loc := server.NewLocal(handler.Map{"Test": testOK}, nil) defer loc.Close() @@ -993,6 +1109,8 @@ func TestHandler_noParams(t *testing.T) { // Verify that the rpc.serverInfo handler and client wrapper work together. func TestRPCServerInfo(t *testing.T) { + defer leaktest.Check(t)() + loc := server.NewLocal(handler.Map{"Test": testOK}, nil) defer loc.Close() @@ -1040,6 +1158,8 @@ func TestNetwork(t *testing.T) { // Verify that the context passed to an assigner has the correct structure. func TestHandler_assignContext(t *testing.T) { + defer leaktest.Check(t)() + loc := server.NewLocal(assignFunc(func(ctx context.Context, method string) jrpc2.Handler { req := jrpc2.InboundRequest(ctx) if req == nil { @@ -1065,9 +1185,10 @@ func TestHandler_assignContext(t *testing.T) { type assignFunc func(context.Context, string) jrpc2.Handler func (a assignFunc) Assign(ctx context.Context, m string) jrpc2.Handler { return a(ctx, m) } -func (assignFunc) Names() []string { return nil } func TestServer_WaitStatus(t *testing.T) { + defer leaktest.Check(t)() + check := func(t *testing.T, stat jrpc2.ServerStatus, closed, stopped bool, wantErr error) { t.Helper() if got, want := stat.Success(), wantErr == nil; got != want { @@ -1113,6 +1234,8 @@ func (b buggyChannel) Recv() ([]byte, error) { return []byte(b.data), b.err } func (buggyChannel) Close() error { return nil } func TestRequest_strictFields(t *testing.T) { + defer leaktest.Check(t)() + type other struct { C bool `json:"charlie"` } @@ -1121,56 +1244,97 @@ func TestRequest_strictFields(t *testing.T) { B int `json:"bravo"` other } - type result struct { - X string `json:"xray"` - } loc := server.NewLocal(handler.Map{ - "Test": handler.New(func(ctx context.Context, req *jrpc2.Request) (interface{}, error) { - var ps, qs params - - if err := req.UnmarshalParams(jrpc2.StrictFields(&ps)); err == nil { - t.Errorf("Unmarshal strict: got %+v, want error", ps) + "Strict": handler.New(func(ctx context.Context, req *jrpc2.Request) (string, error) { + var ps params + if err := req.UnmarshalParams(jrpc2.StrictFields(&ps)); err != nil { + return "", err } - - if err := req.UnmarshalParams(&qs); err != nil { - t.Errorf("Unmarshal non-strict (default): unexpected error: %v", err) + return ps.A, nil + }), + "Normal": handler.New(func(ctx context.Context, req *jrpc2.Request) (string, error) { + var ps params + if err := req.UnmarshalParams(&ps); err != nil { + return "", err } - - return map[string]string{ - "xray": "ok", - "gamma": "not ok", - }, nil + return ps.A, nil }), }, nil) defer loc.Close() - ctx := context.Background() - rsp, err := loc.Client.Call(ctx, "Test", handler.Obj{ - "alpha": "foo", - "bravo": 25, - "charlie": true, // exercise embedding - "delta": 31.5, // unknown field - }) + + tests := []struct { + method string + params interface{} + code code.Code + want string + }{ + {"Strict", handler.Obj{"alpha": "aiuto"}, code.NoError, "aiuto"}, + {"Strict", handler.Obj{"alpha": "selva me", "charlie": true}, code.NoError, "selva me"}, + {"Strict", handler.Obj{"alpha": "OK", "nonesuch": true}, code.InvalidParams, ""}, + {"Normal", handler.Obj{"alpha": "OK", "nonesuch": true}, code.NoError, "OK"}, + } + for _, test := range tests { + name := test.method + "/" + if test.code == code.NoError { + name += "OK" + } else { + name += test.code.String() + } + t.Run(name, func(t *testing.T) { + var res string + err := loc.Client.CallResult(ctx, test.method, test.params, &res) + if err == nil && test.code != code.NoError { + t.Errorf("CallResult: got %+v, want error code %v", res, test.code) + } else if err != nil { + if c := code.FromError(err); c != test.code { + t.Errorf("CallResult: got error %v, wanted code %v", err, test.code) + } + } else if res != test.want { + t.Errorf("CallResult: got %#q, want %#q", res, test.want) + } + }) + } +} + +func TestResponse_strictFields(t *testing.T) { + defer leaktest.Check(t)() + + type result struct { + A string `json:"alpha"` + } + loc := server.NewLocal(handler.Map{ + "Test": handler.New(func(ctx context.Context, req *jrpc2.Request) handler.Obj { + return handler.Obj{"alpha": "OK", "bravo": "not OK"} + }), + }, nil) + defer loc.Close() + ctx := context.Background() + + res, err := loc.Client.Call(ctx, "Test", nil) if err != nil { t.Fatalf("Call failed: %v", err) } - t.Run("NonStrictResult", func(t *testing.T) { - var res result - if err := rsp.UnmarshalResult(&res); err != nil { - t.Errorf("UnmarshalResult: %v", err) + t.Run("Normal", func(t *testing.T) { + var got result + if err := res.UnmarshalResult(&got); err != nil { + t.Errorf("UnmarshalResult failed: %v", err) + } else if got.A != "OK" { + t.Errorf("Result: got %#q, want OK", got.A) } }) - - t.Run("StrictResult", func(t *testing.T) { - var res result - if err := rsp.UnmarshalResult(jrpc2.StrictFields(&res)); err == nil { - t.Errorf("UnmarshalResult: got %+v, want error", res) + t.Run("Strict", func(t *testing.T) { + var got result + if err := res.UnmarshalResult(jrpc2.StrictFields(&got)); err == nil { + t.Errorf("UnmarshalResult: got %#v, wanted error", got) } }) } func TestServerFromContext(t *testing.T) { + defer leaktest.Check(t)() + var got *jrpc2.Server loc := server.NewLocal(handler.Map{ "Test": handler.New(func(ctx context.Context) error { @@ -1190,6 +1354,8 @@ func TestServerFromContext(t *testing.T) { } func TestServer_newContext(t *testing.T) { + defer leaktest.Check(t)() + // Prepare a context with a test value attached to it, that the handler can // extract to verify that the base context was plumbed in correctly. type ctxKey string diff --git a/vendor/github.com/creachadair/jrpc2/json.go b/vendor/github.com/creachadair/jrpc2/json.go index bbe34b4..141de86 100644 --- a/vendor/github.com/creachadair/jrpc2/json.go +++ b/vendor/github.com/creachadair/jrpc2/json.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package jrpc2 import ( @@ -70,7 +72,7 @@ func (j *jmessages) parseJSON(data []byte) error { // or array. var msgs []json.RawMessage var batch bool - if len(data) == 0 || data[0] != '[' { + if firstByte(data) != '[' { msgs = append(msgs, nil) if err := json.Unmarshal(data, &msgs[0]); err != nil { return errInvalidRequest @@ -128,7 +130,7 @@ func isValidID(v json.RawMessage) bool { } func (j *jmessage) fail(code code.Code, msg string) { - j.err = Errorf(code, msg) + j.err = &Error{Code: code, Message: msg} } func (j *jmessage) toJSON() ([]byte, error) { @@ -205,7 +207,7 @@ func (j *jmessage) parseJSON(data []byte) error { if !isNull(val) { j.P = val } - if len(j.P) != 0 && j.P[0] != '[' && j.P[0] != '{' { + if fb := firstByte(j.P); fb != 0 && fb != '[' && fb != '{' { j.fail(code.InvalidRequest, "parameters must be array or object") } case "error": @@ -267,6 +269,15 @@ func isNull(msg json.RawMessage) bool { return len(msg) == 4 && msg[0] == 'n' && msg[1] == 'u' && msg[2] == 'l' && msg[3] == 'l' } +// firstByte returns the first non-whitespace byte of data, or 0 if there is none. +func firstByte(data []byte) byte { + clean := bytes.TrimSpace(data) + if len(clean) == 0 { + return 0 + } + return clean[0] +} + // strictFielder is an optional interface that can be implemented by a type to // reject unknown fields when unmarshaling from JSON. If a type does not // implement this interface, unknown fields are ignored. @@ -274,8 +285,8 @@ type strictFielder interface { DisallowUnknownFields() } -// StrictFields wraps a value v to implement the DisallowUnknownFields method, -// requiring unknown fields to be rejected when unmarshaling from JSON. +// StrictFields wraps a value v to require unknown fields to be rejected when +// unmarshaling from JSON. // // For example: // @@ -286,4 +297,8 @@ func StrictFields(v interface{}) interface{} { return &strict{v: v} } type strict struct{ v interface{} } -func (strict) DisallowUnknownFields() {} +func (s *strict) UnmarshalJSON(data []byte) error { + dec := json.NewDecoder(bytes.NewReader(data)) + dec.DisallowUnknownFields() + return dec.Decode(s.v) +} diff --git a/vendor/github.com/creachadair/jrpc2/json_test.go b/vendor/github.com/creachadair/jrpc2/json_test.go index 7ae3f14..b3ffe4f 100644 --- a/vendor/github.com/creachadair/jrpc2/json_test.go +++ b/vendor/github.com/creachadair/jrpc2/json_test.go @@ -1,5 +1,4 @@ -//go:build oldbench -// +build oldbench +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. package jrpc2 diff --git a/vendor/github.com/creachadair/jrpc2/metrics/metrics.go b/vendor/github.com/creachadair/jrpc2/metrics/metrics.go index 6171b69..59573d1 100644 --- a/vendor/github.com/creachadair/jrpc2/metrics/metrics.go +++ b/vendor/github.com/creachadair/jrpc2/metrics/metrics.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + // Package metrics defines a concurrently-accessible metrics collector. // // A *metrics.M value exports methods to track integer counters and maximum diff --git a/vendor/github.com/creachadair/jrpc2/metrics/metrics_test.go b/vendor/github.com/creachadair/jrpc2/metrics/metrics_test.go index cfebd0e..9f61d09 100644 --- a/vendor/github.com/creachadair/jrpc2/metrics/metrics_test.go +++ b/vendor/github.com/creachadair/jrpc2/metrics/metrics_test.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package metrics_test import ( diff --git a/vendor/github.com/creachadair/jrpc2/opts.go b/vendor/github.com/creachadair/jrpc2/opts.go index 33ef88b..0793f8a 100644 --- a/vendor/github.com/creachadair/jrpc2/opts.go +++ b/vendor/github.com/creachadair/jrpc2/opts.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package jrpc2 import ( @@ -23,10 +25,6 @@ type ServerOptions struct { // received and each response or error returned. RPCLog RPCLogger - // Instructs the server to tolerate requests that do not include the - // required "jsonrpc" version marker. - AllowV1 bool - // Instructs the server to allow server callbacks and notifications, a // non-standard extension to the JSON-RPC protocol. If AllowPush is false, // the Notify and Callback methods of the server report errors if called. @@ -55,12 +53,6 @@ type ServerOptions struct { // If unset, context and parameters are used as given. DecodeContext func(context.Context, string, json.RawMessage) (context.Context, json.RawMessage, error) - // If set, this function is called with the context and client request - // (after decoding, if DecodeContext is set) that are to be delivered to the - // handler. If CheckRequest reports a non-nil error, the request fails with - // that error without invoking the handler. - CheckRequest func(ctx context.Context, req *Request) error - // If set, use this value to record server metrics. All servers created // from the same options will share the same metrics collector. If none is // set, an empty collector will be created for each new server. @@ -79,7 +71,6 @@ func (s *ServerOptions) logFunc() func(string, ...interface{}) { return s.Logger.Printf } -func (s *ServerOptions) allowV1() bool { return s != nil && s.AllowV1 } func (s *ServerOptions) allowPush() bool { return s != nil && s.AllowPush } func (s *ServerOptions) allowBuiltin() bool { return s == nil || !s.DisableBuiltin } @@ -106,22 +97,13 @@ func (o *ServerOptions) newContext() func() context.Context { type decoder = func(context.Context, string, json.RawMessage) (context.Context, json.RawMessage, error) -func (s *ServerOptions) decodeContext() (decoder, bool) { +func (s *ServerOptions) decodeContext() decoder { if s == nil || s.DecodeContext == nil { return func(ctx context.Context, method string, params json.RawMessage) (context.Context, json.RawMessage, error) { return ctx, params, nil - }, false + } } - return s.DecodeContext, true -} - -type verifier = func(context.Context, *Request) error - -func (s *ServerOptions) checkRequest() verifier { - if s == nil || s.CheckRequest == nil { - return func(context.Context, *Request) error { return nil } - } - return s.CheckRequest + return s.DecodeContext } func (s *ServerOptions) metrics() *metrics.M { @@ -144,10 +126,6 @@ type ClientOptions struct { // If not nil, send debug text logs here. Logger Logger - // Instructs the client to tolerate responses that do not include the - // required "jsonrpc" version marker. - AllowV1 bool - // If set, this function is called with the context, method name, and // encoded request parameters before the request is sent to the server. // Its return value replaces the request parameters. This allows the client @@ -162,12 +140,17 @@ type ClientOptions struct { OnNotify func(*Request) // If set, this function is called if a request is received from the server. - // If unset, server requests are logged and discarded. At most one - // invocation of this callback will be active at a time. - // Server callbacks are a non-standard extension of JSON-RPC. + // If unset, server requests are logged and discarded. Multiple invocations + // of the callback handler may be active concurrently. + // + // The callback handler can retrieve the client from its context using the + // jrpc2.ClientFromContext function. The context terminates when the client + // is closed. // // If a callback handler panics, the client will recover the panic and // report a system error back to the server describing the error. + // + // Server callbacks are a non-standard extension of JSON-RPC. OnCallback func(context.Context, *Request) (interface{}, error) // If set, this function is called when the context for a request terminates. @@ -186,8 +169,6 @@ func (c *ClientOptions) logFunc() func(string, ...interface{}) { return c.Logger.Printf } -func (c *ClientOptions) allowV1() bool { return c != nil && c.AllowV1 } - type encoder = func(context.Context, string, json.RawMessage) (json.RawMessage, error) func (c *ClientOptions) encodeContext() encoder { @@ -214,15 +195,12 @@ func (c *ClientOptions) handleCancel() func(*Client, *Response) { return c.OnCancel } -func (c *ClientOptions) handleCallback() func(*jmessage) []byte { +func (c *ClientOptions) handleCallback() func(context.Context, *jmessage) []byte { if c == nil || c.OnCallback == nil { return nil } cb := c.OnCallback - return func(req *jmessage) []byte { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - + return func(ctx context.Context, req *jmessage) []byte { // Recover panics from the callback handler to ensure the server gets a // response even if the callback fails without a result. // diff --git a/vendor/github.com/creachadair/jrpc2/queue.go b/vendor/github.com/creachadair/jrpc2/queue.go new file mode 100644 index 0000000..f8a7bfb --- /dev/null +++ b/vendor/github.com/creachadair/jrpc2/queue.go @@ -0,0 +1,64 @@ +// Copyright (C) 2021 Michael J. Fromberger. All Rights Reserved. + +package jrpc2 + +type queue struct { + front, back *entry + free *entry + nelts int +} + +func newQueue() *queue { + sentinel := new(entry) + return &queue{front: sentinel, back: sentinel} +} + +func (q *queue) isEmpty() bool { return q.front.link == nil } +func (q *queue) size() int { return q.nelts } +func (q *queue) reset() { q.front.link = nil; q.back = q.front; q.nelts = 0 } + +func (q *queue) alloc(data jmessages) *entry { + if q.free == nil { + return &entry{data: data} + } + out := q.free + q.free = out.link + out.data = data + out.link = nil + return out +} + +func (q *queue) release(e *entry) { + e.link, q.free = q.free, e + e.data = nil +} + +func (q *queue) each(f func(jmessages)) { + for cur := q.front.link; cur != nil; cur = cur.link { + f(cur.data) + } +} + +func (q *queue) push(m jmessages) { + e := q.alloc(m) + q.back.link = e + q.back = e + q.nelts++ +} + +func (q *queue) pop() jmessages { + out := q.front.link + q.front.link = out.link + if out == q.back { + q.back = q.front + } + q.nelts-- + data := out.data + q.release(out) + return data +} + +type entry struct { + data jmessages + link *entry +} diff --git a/vendor/github.com/creachadair/jrpc2/regression_test.go b/vendor/github.com/creachadair/jrpc2/regression_test.go index 8e23301..5f917f6 100644 --- a/vendor/github.com/creachadair/jrpc2/regression_test.go +++ b/vendor/github.com/creachadair/jrpc2/regression_test.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package jrpc2_test import ( @@ -10,11 +12,15 @@ import ( "github.com/creachadair/jrpc2/channel" "github.com/creachadair/jrpc2/handler" "github.com/creachadair/jrpc2/server" + "github.com/fortytw2/leaktest" + "github.com/google/go-cmp/cmp" ) // Verify that a notification handler will not deadlock with the dispatcher on // holding the server lock. See: https://github.com/creachadair/jrpc2/issues/27 func TestLockRaceRegression(t *testing.T) { + defer leaktest.Check(t)() + hdone := make(chan struct{}) local := server.NewLocal(handler.Map{ // Do some busy-work and then try to get the server lock, in this case @@ -54,6 +60,8 @@ func TestLockRaceRegression(t *testing.T) { // Verify that if a callback handler panics, the client will report an error // back to the server. See https://github.com/creachadair/jrpc2/issues/41. func TestOnCallbackPanicRegression(t *testing.T) { + defer leaktest.Check(t)() + const panicString = "the devil you say" loc := server.NewLocal(handler.Map{ @@ -89,6 +97,8 @@ func TestOnCallbackPanicRegression(t *testing.T) { // Verify that a duplicate request ID that arrives while a task is in flight // does not cause the existing task to be cancelled. func TestDuplicateIDCancellation(t *testing.T) { + defer leaktest.Check(t)() + tctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -132,7 +142,7 @@ func TestDuplicateIDCancellation(t *testing.T) { // Send the duplicate, which should report an error. send(duplicateReq) - expect(`{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"duplicate request id \"1\""}}`) + expect(`{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"duplicate request ID","data":"1"}}`) // Unblock the handler, which should now complete. If the duplicate request // caused the handler to cancel, it will have logged an error to fail the test. @@ -142,3 +152,43 @@ func TestDuplicateIDCancellation(t *testing.T) { cch.Close() srv.Wait() } + +func TestCheckBatchDuplicateID(t *testing.T) { + defer leaktest.Check(t)() + + srv, cli := channel.Direct() + s := jrpc2.NewServer(handler.Map{ + "Test": testOK, + }, nil).Start(srv) + defer func() { + cli.Close() + if err := s.Wait(); err != nil { + t.Errorf("Server wait: unexpected error: %v", err) + } + }() + + // A batch of requests containing two calls with the same ID. + const input = `[ + {"jsonrpc": "2.0", "id": 1, "method": "Test"}, + {"jsonrpc": "2.0", "id": 1, "method": "Test"}, + {"jsonrpc": "2.0", "id": 2, "method": "Test"} +] +` + const errorReply = `{` + + `"jsonrpc":"2.0",` + + `"id":1,` + + `"error":{"code":-32600,"message":"duplicate request ID","data":"1"}` + + `}` + const want = `[` + errorReply + `,` + errorReply + `,` + `{"jsonrpc":"2.0","id":2,"result":"OK"}]` + + if err := cli.Send([]byte(input)); err != nil { + t.Fatalf("Send %d bytes failed: %v", len(input), err) + } + rsp, err := cli.Recv() + if err != nil { + t.Fatalf("Recv failed: %v", err) + } + if diff := cmp.Diff(want, string(rsp)); diff != "" { + t.Errorf("Server response: (-want, +got)\n%s", diff) + } +} diff --git a/vendor/github.com/creachadair/jrpc2/server.go b/vendor/github.com/creachadair/jrpc2/server.go index eaac58f..735557c 100644 --- a/vendor/github.com/creachadair/jrpc2/server.go +++ b/vendor/github.com/creachadair/jrpc2/server.go @@ -1,7 +1,8 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package jrpc2 import ( - "container/list" "context" "encoding/json" "errors" @@ -26,14 +27,11 @@ type Server struct { sem *semaphore.Weighted // bounds concurrent execution (default 1) // Configurable settings - allow1 bool // allow v1 requests with no version marker allowP bool // allow server notifications to the client log func(string, ...interface{}) // write debug logs here rpcLog RPCLogger // log RPC requests and responses here newctx func() context.Context // create a new base request context dectx decoder // decode context from request - ckreq verifier // request checking hook - expctx bool // whether to expect request context metrics *metrics.M // metrics collected during execution start time.Time // when Start was called builtin bool // whether built-in rpc.* methods are enabled @@ -42,8 +40,8 @@ type Server struct { nbar sync.WaitGroup // notification barrier (see the dispatch method) err error // error from a previous operation - work *sync.Cond // for signaling message availability - inq *list.List // inbound requests awaiting processing + work chan struct{} // for signaling message availability + inq *queue // inbound requests awaiting processing ch channel.Channel // the channel to the client // For each request ID currently in-flight, this map carries a cancel @@ -67,28 +65,23 @@ func NewServer(mux Assigner, opts *ServerOptions) *Server { if mux == nil { panic("nil assigner") } - dc, exp := opts.decodeContext() s := &Server{ mux: mux, sem: semaphore.NewWeighted(opts.concurrency()), - allow1: opts.allowV1(), allowP: opts.allowPush(), log: opts.logFunc(), rpcLog: opts.rpcLog(), newctx: opts.newContext(), - dectx: dc, - ckreq: opts.checkRequest(), - expctx: exp, + dectx: opts.decodeContext(), mu: new(sync.Mutex), metrics: opts.metrics(), start: opts.startTime(), builtin: opts.allowBuiltin(), - inq: list.New(), + inq: newQueue(), used: make(map[string]context.CancelFunc), call: make(map[string]*Response), callID: 1, } - s.work = sync.NewCond(s.mu) return s } @@ -107,10 +100,14 @@ func (s *Server) Start(c channel.Channel) *Server { if s.start.IsZero() { s.start = time.Now().In(time.UTC) } + s.metrics.Count("rpc.serversActive", 1) // Reset all the I/O structures and start up the workers. s.err = nil + // Reset the signal channel. + s.work = make(chan struct{}, 1) + // s.wg waits for the maintenance goroutines for receiving input and // processing the request queue. In addition, each request in flight adds a // goroutine to s.wg. At server shutdown, s.wg completes when the @@ -157,6 +154,13 @@ func (s *Server) serve() { } } +func (s *Server) signal() { + select { + case s.work <- struct{}{}: + default: + } +} + // nextRequest blocks until a request batch is available and returns a function // that dispatches it to the appropriate handlers. The result is only an error // if the connection failed; errors reported by the handler are reported to the @@ -166,16 +170,18 @@ func (s *Server) serve() { func (s *Server) nextRequest() (func() error, error) { s.mu.Lock() defer s.mu.Unlock() - for s.ch != nil && s.inq.Len() == 0 { - s.work.Wait() + for s.ch != nil && s.inq.isEmpty() { + s.mu.Unlock() + <-s.work + s.mu.Lock() } - if s.ch == nil && s.inq.Len() == 0 { + if s.ch == nil && s.inq.isEmpty() { return nil, s.err } ch := s.ch // capture - next := s.inq.Remove(s.inq.Front()).(jmessages) - s.log("Dequeued request batch of length %d (qlen=%d)", len(next), s.inq.Len()) + next := s.inq.pop() + s.log("Dequeued request batch of length %d (qlen=%d)", len(next), s.inq.size()) // Construct a dispatcher to run the handlers outside the lock. return s.dispatch(next, ch), nil @@ -207,32 +213,35 @@ func (s *Server) dispatch(next jmessages, ch sender) func() error { // Resolve all the task handlers or record errors. start := time.Now() tasks := s.checkAndAssign(next) - last := len(tasks) - 1 // Ensure all notifications already issued have completed; see #24. - s.waitForBarrier(tasks.numValidNotifications()) + todo, notes := tasks.numToDo() + s.waitForBarrier(notes) return func() error { var wg sync.WaitGroup - for i, t := range tasks { + for _, t := range tasks { if t.err != nil { continue // nothing to do here; this task has already failed } - t := t - wg.Add(1) - run := func() { - defer wg.Done() - if t.hreq.IsNotification() { - defer s.nbar.Done() - } + todo-- + if todo == 0 { t.val, t.err = s.invoke(t.ctx, t.m, t.hreq) + if t.hreq.IsNotification() { + s.nbar.Done() + } + break } - if i < last { - go run() - } else { - run() - } + t := t + wg.Add(1) + go func() { + defer wg.Done() + t.val, t.err = s.invoke(t.ctx, t.m, t.hreq) + if t.hreq.IsNotification() { + s.nbar.Done() + } + }() } // Wait for all the handlers to return, then deliver any responses. @@ -269,16 +278,15 @@ func (s *Server) deliver(rsps jmessages, ch sender, elapsed time.Duration) error // records errors for them as appropriate. The caller must hold s.mu. func (s *Server) checkAndAssign(next jmessages) tasks { var ts tasks + var ids []string + dup := make(map[string]*task) // :: id ⇒ first task in batch with id + + // Phase 1: Filter out responses from push calls and check for duplicate + // request ID.s for _, req := range next { fid := fixID(req.ID) - t := &task{ - hreq: &Request{id: fid, method: req.M, params: req.P}, - batch: req.batch, - } id := string(fid) - if req.err != nil { - t.err = req.err // deferred validation error - } else if !req.isRequestOrNotification() && s.call[id] != nil { + if !req.isRequestOrNotification() && s.call[id] != nil { // This is a result or error for a pending push-call. // // N.B. It is important to check for this before checking for @@ -287,24 +295,51 @@ func (s *Server) checkAndAssign(next jmessages) tasks { delete(s.call, id) rsp.ch <- req continue // don't send a reply for this - } else if id != "" && s.used[id] != nil { - t.err = Errorf(code.InvalidRequest, "duplicate request id %q", id) + } else if req.err != nil { + // keep the existing error } else if !s.versionOK(req.V) { - t.err = ErrInvalidVersion - } else if req.M == "" { + req.err = ErrInvalidVersion + } + + t := &task{ + hreq: &Request{id: fid, method: req.M, params: req.P}, + batch: req.batch, + err: req.err, + } + if old := dup[id]; old != nil { + // A previous task already used this ID, fail both. + old.err = errDuplicateID.WithData(id) + t.err = old.err + } else if id != "" && s.used[id] != nil { + // A task from a previous batch already used this ID, fail this one. + t.err = errDuplicateID.WithData(id) + } else if id != "" { + // This is the first task with this ID in the batch. + dup[id] = t + } + ts = append(ts, t) + ids = append(ids, id) + } + + // Phase 2: Assign method handlers and set up contexts. + for i, t := range ts { + id := ids[i] + if t.err != nil { + // deferred validation error + } else if t.hreq.method == "" { t.err = errEmptyMethod } else if s.setContext(t, id) { - t.m = s.assign(t.ctx, req.M) + t.m = s.assign(t.ctx, t.hreq.method) if t.m == nil { - t.err = Errorf(code.MethodNotFound, "no such method %q", req.M) + t.err = errNoSuchMethod.WithData(t.hreq.method) } } if t.err != nil { - s.log("Request check error for %q (params %q): %v", req.M, string(req.P), t.err) + s.log("Request check error for %q (params %q): %v", + t.hreq.method, string(t.hreq.params), t.err) s.metrics.Count("rpc.errors", 1) } - ts = append(ts, t) } return ts } @@ -319,12 +354,6 @@ func (s *Server) setContext(t *task, id string) bool { return false } - // Check request. - if err := s.ckreq(base, t.hreq); err != nil { - t.err = err - return false - } - t.ctx = context.WithValue(base, inboundRequestKey{}, t.hreq) // Store the cancellation for a request that needs a reply, so that we can @@ -361,12 +390,14 @@ func (s *Server) invoke(base context.Context, h Handler, req *Request) (json.Raw // ServerInfo returns an atomic snapshot of the current server info for s. func (s *Server) ServerInfo() *ServerInfo { info := &ServerInfo{ - Methods: s.mux.Names(), - UsesContext: s.expctx, - StartTime: s.start, - Counter: make(map[string]int64), - MaxValue: make(map[string]int64), - Label: make(map[string]interface{}), + Methods: []string{"*"}, + StartTime: s.start, + Counter: make(map[string]int64), + MaxValue: make(map[string]int64), + Label: make(map[string]interface{}), + } + if n, ok := s.mux.(Namer); ok { + info.Methods = n.Names() } s.metrics.Snapshot(metrics.Snapshot{ Counter: info.Counter, @@ -524,7 +555,7 @@ func (s ServerStatus) Success() bool { return s.Err == nil } func (s *Server) WaitStatus() ServerStatus { s.wg.Wait() // Postcondition check. - if s.inq.Len() != 0 { + if !s.inq.isEmpty() { panic("s.inq is not empty at shutdown") } stat := ServerStatus{Err: s.err} @@ -557,8 +588,8 @@ func (s *Server) stop(err error) { // // TODO(@creachadair): We need better tests for this behaviour. var keep jmessages - for cur := s.inq.Front(); cur != nil; cur = s.inq.Front() { - for _, req := range cur.Value.(jmessages) { + s.inq.each(func(cur jmessages) { + for _, req := range cur { if req.isNotification() { keep = append(keep, req) s.log("Retaining notification %p", req) @@ -566,18 +597,17 @@ func (s *Server) stop(err error) { s.cancel(string(req.ID)) } } - s.inq.Remove(cur) - } + }) + s.inq.reset() for _, elt := range keep { - s.inq.PushBack(jmessages{elt}) + s.inq.push(jmessages{elt}) } - s.work.Broadcast() + close(s.work) // Cancel any in-flight requests that made it out of the queue, and // terminate any pending callback invocations. - for id, rsp := range s.call { - delete(s.call, id) - rsp.cancel() + for _, rsp := range s.call { + rsp.cancel() // the waiter will clean up the map } for id, cancel := range s.used { cancel() @@ -591,6 +621,7 @@ func (s *Server) stop(err error) { s.err = err s.ch = nil + s.metrics.Count("rpc.serversActive", -1) } // read is the main receiver loop, decoding requests from the client and adding @@ -620,9 +651,11 @@ func (s *Server) read(ch receiver) { } else if len(in) == 0 { s.pushError(errEmptyBatch) } else { - s.log("Received request batch of size %d (qlen=%d)", len(in), s.inq.Len()) - s.inq.PushBack(in) - s.work.Broadcast() + s.log("Received request batch of size %d (qlen=%d)", len(in), s.inq.size()) + s.inq.push(in) + if s.inq.size() == 1 { // the queue was empty + s.signal() + } } s.mu.Unlock() } @@ -633,9 +666,6 @@ type ServerInfo struct { // The list of method names exported by this server. Methods []string `json:"methods,omitempty"` - // Whether this server understands context wrappers. - UsesContext bool `json:"usesContext"` - // Metric values defined by the evaluation of methods. Counter map[string]int64 `json:"counters,omitempty"` MaxValue map[string]int64 `json:"maxValue,omitempty"` @@ -694,12 +724,7 @@ func (s *Server) cancel(id string) bool { return ok } -func (s *Server) versionOK(v string) bool { - if v == "" { - return s.allow1 // an empty version is OK if the server allows it - } - return v == Version // ... otherwise it must match the spec -} +func (s *Server) versionOK(v string) bool { return v == Version } // A task represents a pending method invocation received by the server. type task struct { @@ -758,12 +783,15 @@ func (ts tasks) responses(rpcLog RPCLogger) jmessages { return rsps } -// numValidNotifications reports the number of elements in ts that are -// syntactically valid notifications. -func (ts tasks) numValidNotifications() (n int) { +// numToDo reports the number of tasks in ts that need to be executed, and the +// number of those that are notifications. +func (ts tasks) numToDo() (todo, notes int) { for _, t := range ts { - if t.err == nil && t.hreq.IsNotification() { - n++ + if t.err == nil { + todo++ + if t.hreq.IsNotification() { + notes++ + } } } return diff --git a/vendor/github.com/creachadair/jrpc2/server/example_test.go b/vendor/github.com/creachadair/jrpc2/server/example_test.go index 03d9164..4357384 100644 --- a/vendor/github.com/creachadair/jrpc2/server/example_test.go +++ b/vendor/github.com/creachadair/jrpc2/server/example_test.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package server_test import ( diff --git a/vendor/github.com/creachadair/jrpc2/server/local.go b/vendor/github.com/creachadair/jrpc2/server/local.go index d94c5b2..bc73695 100644 --- a/vendor/github.com/creachadair/jrpc2/server/local.go +++ b/vendor/github.com/creachadair/jrpc2/server/local.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package server import ( diff --git a/vendor/github.com/creachadair/jrpc2/server/local_test.go b/vendor/github.com/creachadair/jrpc2/server/local_test.go index f368bd8..c34788a 100644 --- a/vendor/github.com/creachadair/jrpc2/server/local_test.go +++ b/vendor/github.com/creachadair/jrpc2/server/local_test.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package server_test import ( @@ -9,6 +11,7 @@ import ( "github.com/creachadair/jrpc2" "github.com/creachadair/jrpc2/handler" "github.com/creachadair/jrpc2/server" + "github.com/fortytw2/leaktest" ) var doDebug = flag.Bool("debug", false, "Enable server and client debugging logs") @@ -24,6 +27,8 @@ func testOpts(t *testing.T) *server.LocalOptions { } func TestLocal(t *testing.T) { + defer leaktest.Check(t)() + loc := server.NewLocal(make(handler.Map), testOpts(t)) ctx := context.Background() si, err := jrpc2.RPCServerInfo(ctx, loc.Client) @@ -47,6 +52,8 @@ func TestLocal(t *testing.T) { // Test that concurrent callers to a local service do not deadlock. func TestLocalConcurrent(t *testing.T) { + defer leaktest.Check(t)() + loc := server.NewLocal(handler.Map{ "Test": handler.New(func(context.Context) error { return nil }), }, testOpts(t)) diff --git a/vendor/github.com/creachadair/jrpc2/server/loop.go b/vendor/github.com/creachadair/jrpc2/server/loop.go index b76c1f4..1eb4b50 100644 --- a/vendor/github.com/creachadair/jrpc2/server/loop.go +++ b/vendor/github.com/creachadair/jrpc2/server/loop.go @@ -1,7 +1,10 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + // Package server provides support routines for running jrpc2 servers. package server import ( + "context" "net" "sync" @@ -37,8 +40,9 @@ func (static) Finish(jrpc2.Assigner, jrpc2.ServerStatus) {} // An Accepter obtains client connections from an external source and // constructs channels from them. type Accepter interface { - // Accept accepts a connection and returns a new channel for it. - Accept() (channel.Channel, error) + // Accept blocks until a connection is available, or until ctx ends. + // If a connection is found, Accept returns a new channel for it. + Accept(ctx context.Context) (channel.Channel, error) } // NetAccepter adapts a net.Listener to the Accepter interface, using f as the @@ -52,7 +56,20 @@ type netAccepter struct { newChannel channel.Framing } -func (n netAccepter) Accept() (channel.Channel, error) { +func (n netAccepter) Accept(ctx context.Context) (channel.Channel, error) { + // A net.Listener does not obey a context, so simulate it by closing the + // listener if ctx ends. + ok := make(chan struct{}) + defer close(ok) + go func() { + select { + case <-ctx.Done(): + n.Listener.Close() + case <-ok: + return + } + }() + conn, err := n.Listener.Accept() if err != nil { return nil, err @@ -64,10 +81,10 @@ func (n netAccepter) Accept() (channel.Channel, error) { // service instance returned by newService and the given options. Each server // runs in a new goroutine. // -// If the listener reports an error, the loop will terminate and that error -// will be reported to the caller of Loop once any active servers have -// returned. -func Loop(lst Accepter, newService func() Service, opts *LoopOptions) error { +// If lst is closed or otherwise reports an error, the loop will terminate. +// The error will be reported to the caller of Loop once any active servers +// have returned. In addition, if ctx ends, any active servers will be stopped. +func Loop(ctx context.Context, lst Accepter, newService func() Service, opts *LoopOptions) error { serverOpts := opts.serverOpts() log := func(string, ...interface{}) {} if serverOpts != nil && serverOpts.Logger != nil { @@ -76,7 +93,7 @@ func Loop(lst Accepter, newService func() Service, opts *LoopOptions) error { var wg sync.WaitGroup for { - ch, err := lst.Accept() + ch, err := lst.Accept(ctx) if err != nil { if channel.IsErrClosing(err) { err = nil @@ -89,13 +106,20 @@ func Loop(lst Accepter, newService func() Service, opts *LoopOptions) error { wg.Add(1) go func() { defer wg.Done() + svc := newService() assigner, err := svc.Assigner() if err != nil { log("Service initialization failed: %v", err) return } + + sctx, cancel := context.WithCancel(ctx) + defer cancel() + srv := jrpc2.NewServer(assigner, serverOpts).Start(ch) + go func() { <-sctx.Done(); srv.Stop() }() + stat := srv.WaitStatus() svc.Finish(assigner, stat) if stat.Err != nil { diff --git a/vendor/github.com/creachadair/jrpc2/server/loop_test.go b/vendor/github.com/creachadair/jrpc2/server/loop_test.go index cf2361b..6434309 100644 --- a/vendor/github.com/creachadair/jrpc2/server/loop_test.go +++ b/vendor/github.com/creachadair/jrpc2/server/loop_test.go @@ -1,4 +1,6 @@ -package server +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + +package server_test import ( "context" @@ -11,12 +13,14 @@ import ( "github.com/creachadair/jrpc2" "github.com/creachadair/jrpc2/channel" "github.com/creachadair/jrpc2/handler" + "github.com/creachadair/jrpc2/server" + "github.com/fortytw2/leaktest" ) var newChan = channel.Line // A static test service that returns the same thing each time. -var testService = Static(handler.Map{ +var testStatic = server.Static(handler.Map{ "Test": handler.New(func(context.Context) (string, error) { return "OK", nil }), @@ -29,8 +33,8 @@ type testSession struct { nCall int } -func newTestSession(t *testing.T) func() Service { - return func() Service { t.Helper(); return &testSession{t: t} } +func newTestSession(t *testing.T) func() server.Service { + return func() server.Service { t.Helper(); return &testSession{t: t} } } func (t *testSession) Assigner() (jrpc2.Assigner, error) { @@ -81,27 +85,27 @@ func mustDial(t *testing.T, addr string) *jrpc2.Client { return jrpc2.NewClient(newChan(conn, conn), nil) } -func mustServe(t *testing.T, lst net.Listener, newService func() Service) <-chan struct{} { +func mustServe(t *testing.T, ctx context.Context, lst net.Listener, newService func() server.Service) <-chan error { t.Helper() - sc := make(chan struct{}) + acc := server.NetAccepter(lst, newChan) + errc := make(chan error, 1) go func() { - defer close(sc) + defer close(errc) // Start a server loop to accept connections from the clients. This should // exit cleanly once all the clients have finished and the listener closes. - lst := NetAccepter(lst, newChan) - if err := Loop(lst, newService, nil); err != nil { - t.Errorf("Loop: unexpected failure: %v", err) - } + errc <- server.Loop(ctx, acc, newService, nil) }() - return sc + return errc } // Test that sequential clients against the same server work sanely. func TestSeq(t *testing.T) { + defer leaktest.Check(t)() + lst := mustListen(t) addr := lst.Addr().String() - sc := mustServe(t, lst, testService) + errc := mustServe(t, context.Background(), lst, testStatic) for i := 0; i < 5; i++ { cli := mustDial(t, addr) @@ -114,16 +118,82 @@ func TestSeq(t *testing.T) { cli.Close() } lst.Close() - <-sc + if err := <-errc; err != nil { + t.Errorf("Server exit failed: %v", err) + } +} + +// Test that context plumbing works properly. +func TestLoop_cancelContext(t *testing.T) { + defer leaktest.Check(t)() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + lst := mustListen(t) + defer lst.Close() + errc := mustServe(t, ctx, lst, testStatic) + + time.AfterFunc(50*time.Millisecond, cancel) + select { + case err := <-errc: + if err != nil { + t.Errorf("Loop exit reported error: %v", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Loop did not exit in a timely manner after cancellation") + } +} + +// Test that cancelling a loop stops its servers. +func TestLoop_cancelServers(t *testing.T) { + defer leaktest.Check(t)() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ready := make(chan struct{}) + + lst := mustListen(t) + errc := mustServe(t, ctx, lst, server.Static(handler.Map{ + "Test": handler.New(func(ctx context.Context) error { + // Signal readiness then block until cancelled. + // The server will cancel this method when stopped. + close(ready) + <-ctx.Done() + return ctx.Err() + }), + })) + cli := mustDial(t, lst.Addr().String()) + defer cli.Close() + + // Issue a call to the server that will block until the server cancels the + // handler at shutdown. If the server blocks after cancellation, it means it + // is not correctly stopping its active servers. + go cli.Call(context.Background(), "Test", nil) + + <-ready + cancel() // this should stop the loop and the server + + select { + case err := <-errc: + if err != nil { + t.Errorf("Loop result: %v", err) + } + case <-time.After(1 * time.Second): + t.Error("Loop did not exit in a timely manner after cancellation") + } } // Test that concurrent clients against the same server work sanely. func TestLoop(t *testing.T) { + defer leaktest.Check(t)() + tests := []struct { desc string - cons func() Service + cons func() server.Service }{ - {"StaticService", testService}, + {"StaticService", testStatic}, {"SessionStateService", newTestSession(t)}, } const numClients = 5 @@ -133,7 +203,7 @@ func TestLoop(t *testing.T) { t.Run(test.desc, func(t *testing.T) { lst := mustListen(t) addr := lst.Addr().String() - sc := mustServe(t, lst, test.cons) + errc := mustServe(t, context.Background(), lst, test.cons) // Start a bunch of clients, each of which will dial the server and make // some calls at random intervals to tickle the race detector. @@ -162,7 +232,9 @@ func TestLoop(t *testing.T) { // the service loop will stop. wg.Wait() lst.Close() - <-sc + if err := <-errc; err != nil { + t.Errorf("Server exit failed: %v", err) + } }) } } diff --git a/vendor/github.com/creachadair/jrpc2/server/run.go b/vendor/github.com/creachadair/jrpc2/server/run.go index 6f20fbe..90a603e 100644 --- a/vendor/github.com/creachadair/jrpc2/server/run.go +++ b/vendor/github.com/creachadair/jrpc2/server/run.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package server import ( diff --git a/vendor/github.com/creachadair/jrpc2/server/run_test.go b/vendor/github.com/creachadair/jrpc2/server/run_test.go index b8060fe..822feff 100644 --- a/vendor/github.com/creachadair/jrpc2/server/run_test.go +++ b/vendor/github.com/creachadair/jrpc2/server/run_test.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package server_test import ( diff --git a/vendor/github.com/creachadair/jrpc2/special.go b/vendor/github.com/creachadair/jrpc2/special.go index 03e29e3..6ae0f7b 100644 --- a/vendor/github.com/creachadair/jrpc2/special.go +++ b/vendor/github.com/creachadair/jrpc2/special.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + package jrpc2 import ( diff --git a/vendor/github.com/creachadair/jrpc2/tools/examples/adder/adder.go b/vendor/github.com/creachadair/jrpc2/tools/examples/adder/adder.go index d3ce66f..00a1386 100644 --- a/vendor/github.com/creachadair/jrpc2/tools/examples/adder/adder.go +++ b/vendor/github.com/creachadair/jrpc2/tools/examples/adder/adder.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + // Program adder demonstrates a trivial JSON-RPC server that communicates over // the process's stdin and stdout. // @@ -22,7 +24,7 @@ import ( ) // Add will be exported as a method named "Add". -func Add(ctx context.Context, vs ...int) int { +func Add(ctx context.Context, vs []int) int { sum := 0 for _, v := range vs { sum += v diff --git a/vendor/github.com/creachadair/jrpc2/tools/examples/client/client.go b/vendor/github.com/creachadair/jrpc2/tools/examples/client/client.go index da56f35..6a7ec49 100644 --- a/vendor/github.com/creachadair/jrpc2/tools/examples/client/client.go +++ b/vendor/github.com/creachadair/jrpc2/tools/examples/client/client.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + // Program client demonstrates how to set up a JSON-RPC 2.0 client using the // github.com/creachadair/jrpc2 package. // diff --git a/vendor/github.com/creachadair/jrpc2/tools/examples/http/server.go b/vendor/github.com/creachadair/jrpc2/tools/examples/http/server.go index 56e107d..a67bd9a 100644 --- a/vendor/github.com/creachadair/jrpc2/tools/examples/http/server.go +++ b/vendor/github.com/creachadair/jrpc2/tools/examples/http/server.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + // Program http demonstrates how to set up a JSON-RPC 2.0 server using the // github.com/creachadair/jrpc2 package with an HTTP transport. // @@ -45,6 +47,6 @@ func main() { log.Fatal(http.ListenAndServe(*listenAddr, nil)) } -func ping(ctx context.Context, msg ...string) string { +func ping(ctx context.Context, msg []string) string { return "OK: " + strings.Join(msg, "|") } diff --git a/vendor/github.com/creachadair/jrpc2/tools/examples/server/server.go b/vendor/github.com/creachadair/jrpc2/tools/examples/server/server.go index e5db2e0..05cdb73 100644 --- a/vendor/github.com/creachadair/jrpc2/tools/examples/server/server.go +++ b/vendor/github.com/creachadair/jrpc2/tools/examples/server/server.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + // Program server demonstrates how to set up a JSON-RPC 2.0 server using the // github.com/creachadair/jrpc2 package. // @@ -97,7 +99,8 @@ func main() { } log.Printf("Listening at %v...", lst.Addr()) acc := server.NetAccepter(lst, channel.Line) - server.Loop(acc, server.Static(mux), &server.LoopOptions{ + ctx := context.Background() + server.Loop(ctx, acc, server.Static(mux), &server.LoopOptions{ ServerOptions: &jrpc2.ServerOptions{ Logger: jrpc2.StdLogger(nil), Concurrency: *maxTasks, diff --git a/vendor/github.com/creachadair/jrpc2/tools/examples/wshttp/server.go b/vendor/github.com/creachadair/jrpc2/tools/examples/wshttp/server.go index 9b009ef..5102a1c 100644 --- a/vendor/github.com/creachadair/jrpc2/tools/examples/wshttp/server.go +++ b/vendor/github.com/creachadair/jrpc2/tools/examples/wshttp/server.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + // Program wshttp demonstrates how to set up a JSON-RPC 20 server using // the github.com/creachadair/jrpc2 package with a Websocket transport. // @@ -34,19 +36,18 @@ func main() { http.Handle("/rpc", lst) go hs.ListenAndServe() - acc := accepter{ - Listener: lst, - ctx: context.Background(), - } - svc := handler.Map{"Reverse": handler.New(reverse)} + svc := server.Static(handler.Map{ + "Reverse": handler.New(reverse), + }) - log.Printf("Listing at ws://%s/rpc", *listenAddr) - err := server.Loop(acc, server.Static(svc), &server.LoopOptions{ + log.Printf("Listening at ws://%s/rpc", *listenAddr) + ctx := context.Background() + err := server.Loop(ctx, accepter{lst}, svc, &server.LoopOptions{ ServerOptions: &jrpc2.ServerOptions{ Logger: jrpc2.StdLogger(nil), }, }) - hs.Shutdown(acc.ctx) + hs.Shutdown(ctx) if err != nil { log.Fatalf("Loop exited: %v", err) } @@ -60,11 +61,8 @@ func reverse(_ context.Context, ss []string) []string { return ss } -type accepter struct { - *wschannel.Listener - ctx context.Context -} +type accepter struct{ *wschannel.Listener } -func (a accepter) Accept() (channel.Channel, error) { - return a.Listener.Accept(a.ctx) +func (a accepter) Accept(ctx context.Context) (channel.Channel, error) { + return a.Listener.Accept(ctx) } diff --git a/vendor/github.com/creachadair/jrpc2/tools/go.mod b/vendor/github.com/creachadair/jrpc2/tools/go.mod index 465ade4..84f7ba6 100644 --- a/vendor/github.com/creachadair/jrpc2/tools/go.mod +++ b/vendor/github.com/creachadair/jrpc2/tools/go.mod @@ -3,8 +3,8 @@ module github.com/creachadair/jrpc2/tools go 1.17 require ( - github.com/creachadair/jrpc2 v0.30.3 - github.com/creachadair/wschannel v0.0.0-20211118152247-10d58f4f0def + github.com/creachadair/jrpc2 v0.35.4 + github.com/creachadair/wschannel v0.0.0-20220126134344-769774727b29 ) require ( diff --git a/vendor/github.com/creachadair/jrpc2/tools/go.sum b/vendor/github.com/creachadair/jrpc2/tools/go.sum index a9443f1..35f8f0a 100644 --- a/vendor/github.com/creachadair/jrpc2/tools/go.sum +++ b/vendor/github.com/creachadair/jrpc2/tools/go.sum @@ -1,9 +1,11 @@ -github.com/creachadair/jrpc2 v0.30.3 h1:fz8xYfTmIgxJXvr9HAoz0XBOpNklyixE7Hnh6iQP/4o= -github.com/creachadair/jrpc2 v0.30.3/go.mod h1:w+GXZGc+NwsH0xsUOgeLBIIRM0jBOSTXhv28KaWGRZU= -github.com/creachadair/wschannel v0.0.0-20211118152247-10d58f4f0def h1:FV0vHCqItsi0b3LwaEKyxj0su3VKdvbenCOkXnCAXnI= -github.com/creachadair/wschannel v0.0.0-20211118152247-10d58f4f0def/go.mod h1:/9Csuxj8r9h0YXexL0WmkahIhd85BleYWz7nt42ZgDc= -github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= -github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/creachadair/jrpc2 v0.35.4 h1:5ELLV7CMKLfALzkKNsQ//ngZLWDbEmAXgTgkL3JXAcU= +github.com/creachadair/jrpc2 v0.35.4/go.mod h1:a53Cer/NMD1y8P9UB2XbuOLRELKRLDf8u7bRi4v1qsE= +github.com/creachadair/wschannel v0.0.0-20220126134344-769774727b29 h1:EtcZoRTuhqCedRtvfUrzuyrsT53RWNN7xZOE9lljDw0= +github.com/creachadair/wschannel v0.0.0-20220126134344-769774727b29/go.mod h1:xFi56wWYs7X0OlNzbtz/yzLCuN3a8Hf36QALYnAsO0o= +github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= +github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= +github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= +github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= diff --git a/vendor/github.com/creachadair/jrpc2/tools/jcall/jcall.go b/vendor/github.com/creachadair/jrpc2/tools/jcall/jcall.go index 727db25..f4d64a7 100644 --- a/vendor/github.com/creachadair/jrpc2/tools/jcall/jcall.go +++ b/vendor/github.com/creachadair/jrpc2/tools/jcall/jcall.go @@ -1,3 +1,5 @@ +// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. + // Program jcall issues RPC calls to a JSON-RPC server. // // Usage: diff --git a/walletapi/cipher_test.go b/walletapi/cipher_test.go index 99fd997..1642173 100644 --- a/walletapi/cipher_test.go +++ b/walletapi/cipher_test.go @@ -42,3 +42,44 @@ func Test_AEAD_Cipher(t *testing.T) { t.Fatalf("AEAD cipher encryption/decryption failed") } } + +// functional test whether the wrappers are okay +func Test_Signed_Message_Cipher(t *testing.T) { + + wsrc, err := Create_Encrypted_Wallet_From_Recovery_Words_Memory("", "sequence atlas unveil summon pebbles tuesday beer rudely snake rockets different fuselage woven tagged bested dented vegan hover rapid fawns obvious muppet randomly seasons randomly") + if err != nil { + t.Fatalf("Cannot create encrypted wallet, err %s", err) + } + + result := wsrc.SignData([]byte("HELLO")) + + signer, message, err := wsrc.CheckSignature(result) + if err != nil { + t.Fatalf("Cannot check signature, err %s", err) + } + if string(message) != "HELLO" { + t.Fatalf("Message corruption") + } + if signer.String() != wsrc.GetAddress().String() { + t.Fatalf("Address corruption") + } + + // make sure other wallets can also verify the signatures + + w2, err := Create_Encrypted_Wallet_From_Recovery_Words_Memory("", "perfil lujo faja puma favor pedir detalle doble carbón neón paella cuarto ánimo cuento conga correr dental moneda león donar entero logro realidad acceso doble") + if err != nil { + t.Fatalf("Cannot create encrypted wallet, err %s", err) + } + + signer, message, err = w2.CheckSignature(result) + if err != nil { + t.Fatalf("Cannot check signature, err %s", err) + } + if string(message) != "HELLO" { + t.Fatalf("Message corruption") + } + if signer.String() != wsrc.GetAddress().String() { + t.Fatalf("Address corruption") + } + +} diff --git a/walletapi/daemon_communication.go b/walletapi/daemon_communication.go index 18b9f95..85930b7 100644 --- a/walletapi/daemon_communication.go +++ b/walletapi/daemon_communication.go @@ -165,6 +165,7 @@ func test_connectivity() (err error) { if info.Testnet != !globals.IsMainnet() { err = fmt.Errorf("Mainnet/TestNet is different between wallet/daemon.Please run daemon/wallet without --testnet") logger.Error(err, "Mainnet/Testnet mismatch") + fmt.Printf("Mainnet/Testnet mismatch\n") return } @@ -187,8 +188,23 @@ func (w *Wallet_Memory) sync_loop() { } - err := w.Sync_Wallet_Memory_With_Daemon() // sync with the daemon - logger.V(1).Error(err, "wallet syncing err", err) + if IsDaemonOnline() && test_connectivity() != nil { + time.Sleep(timeout) // wait 5 seconds + continue + } + + var zerohash crypto.Hash + if len(w.account.EntriesNative) == 0 { + err := w.Sync_Wallet_Memory_With_Daemon() + logger.V(1).Error(err, "wallet syncing err", err) + } else { + for k := range w.account.EntriesNative { + err := w.Sync_Wallet_Memory_With_Daemon_internal(k) + if k == zerohash && err != nil { + logger.V(1).Error(err, "wallet syncing err", err) + } + } + } time.Sleep(timeout) // wait 5 seconds } @@ -456,6 +472,12 @@ func (w *Wallet_Memory) GetDecryptedBalanceAtTopoHeight(scid crypto.Hash, topohe return 0, 0, err } + if w.account.EntriesNative != nil { + if _, ok := w.account.EntriesNative[scid]; !ok { //if we could obtain something, try tracking + w.account.EntriesNative[scid] = []rpc.Entry{} + } + } + return w.DecodeEncryptedBalance_Memory(encrypted_balance, 0), noncetopo, nil } diff --git a/walletapi/mnemonics/mnemonics.go b/walletapi/mnemonics/mnemonics.go index 91d709f..b4f5440 100644 --- a/walletapi/mnemonics/mnemonics.go +++ b/walletapi/mnemonics/mnemonics.go @@ -78,7 +78,7 @@ func Words_To_Key(words_line string) (language_name string, keybig *big.Int, err //rlog.Tracef(1, "len of words %d", words) // if seed size is not 24 or 25, return err - if len(words) != SEED_LENGTH && len(words) != (SEED_LENGTH+1) { + if /*len(words) != SEED_LENGTH &&*/ len(words) != (SEED_LENGTH + 1) { err = fmt.Errorf("Invalid Seed") return } diff --git a/walletapi/rpcserver/rpc_websocket_server.go b/walletapi/rpcserver/rpc_websocket_server.go index f5c6f79..c2e3d6a 100644 --- a/walletapi/rpcserver/rpc_websocket_server.go +++ b/walletapi/rpcserver/rpc_websocket_server.go @@ -56,6 +56,8 @@ type RPCServer struct { srv *http.Server mux *http.ServeMux logger logr.Logger + user string + password string Exit_Event chan bool // blockchain is shutting down and we must quit ASAP sync.RWMutex } @@ -69,6 +71,13 @@ func RPCServer_Start(wallet *walletapi.Wallet_Disk, title string) (*RPCServer, e r.Exit_Event = make(chan bool) + if globals.Arguments["--rpc-login"] != nil { // this was verified at startup + userpass := globals.Arguments["--rpc-login"].(string) + parts := strings.SplitN(userpass, ":", 2) + r.user = parts[0] + r.password = parts[1] + } + go r.Run(wallet) atomic.AddUint32(&globals.Subsystem_Active, 1) // increment subsystem @@ -91,6 +100,34 @@ func (r *RPCServer) RPCServer_Stop() { atomic.AddUint32(&globals.Subsystem_Active, ^uint32(0)) // this decrement 1 fom subsystem } +// check basic authrizaion +func hasbasicauthfailed(rpcserver *RPCServer, w http.ResponseWriter, r *http.Request) bool { + if rpcserver.user == "" { + return false + } + u, p, ok := r.BasicAuth() + if !ok { + w.WriteHeader(401) + io.WriteString(w, "Authorization Required") + return true + } + if u != rpcserver.user || p != rpcserver.password { + w.WriteHeader(401) + io.WriteString(w, "Authorization Required") + return true + } + + if globals.Arguments["--allow-rpc-password-change"] != nil && globals.Arguments["--allow-rpc-password-change"].(bool) == true { + + if r.Header.Get("Pass") != "" { + rpcserver.password = r.Header.Get("Pass") + } + } + + return false + +} + // setup handlers func (rpcserver *RPCServer) Run(wallet *walletapi.Wallet_Disk) { @@ -131,6 +168,10 @@ func (rpcserver *RPCServer) Run(wallet *walletapi.Wallet_Disk) { var bridge = jhttp.NewBridge(wallet_handler, &jhttp.BridgeOptions{Server: options}) translate_http_to_jsonrpc_and_vice_versa := func(w http.ResponseWriter, r *http.Request) { + + if hasbasicauthfailed(rpcserver, w, r) { + return + } bridge.ServeHTTP(w, r) } @@ -144,6 +185,9 @@ func (rpcserver *RPCServer) Run(wallet *walletapi.Wallet_Disk) { client_connections.Delete(ws_server) } }() + if hasbasicauthfailed(rpcserver, w, r) { + return + } c, err := upgrader.Upgrade(w, r, nil) if err != nil { @@ -167,6 +211,10 @@ func (rpcserver *RPCServer) Run(wallet *walletapi.Wallet_Disk) { rpcserver.mux.HandleFunc("/install_sc", func(w http.ResponseWriter, req *http.Request) { // translate call internally, how to do it using a single json request var p rpc.Transfer_Params + if hasbasicauthfailed(rpcserver, w, req) { + return + } + b, err := ioutil.ReadAll(req.Body) defer req.Body.Close() if err != nil { diff --git a/walletapi/transaction_build.go b/walletapi/transaction_build.go index 7743953..7ea6d5b 100644 --- a/walletapi/transaction_build.go +++ b/walletapi/transaction_build.go @@ -136,7 +136,7 @@ rebuild_tx: value := transfers[t].Amount burn_value := transfers[t].Burn if fees == 0 && asset.SCID.IsZero() && !fees_done { - fees = fees + uint64(len(transfers)+2)*uint64((float64(config.FEE_PER_KB)*float64(w.GetFeeMultiplier()))) + fees = fees + uint64(len(transfers)+2)*uint64((float64(config.FEE_PER_KB)*float64(float32(len(publickeylist)/16)+w.GetFeeMultiplier()))) if data, err := scdata.MarshalBinary(); err != nil { panic(err) } else { diff --git a/walletapi/tx_creation_test.go b/walletapi/tx_creation_test.go index ee6b0f4..d366569 100644 --- a/walletapi/tx_creation_test.go +++ b/walletapi/tx_creation_test.go @@ -80,7 +80,7 @@ func simulator_chain_start() (*blockchain.Blockchain, *derodrpc.RPCServer, map[s OptionsFirst: true, } - globals.Arguments, err = parser.ParseArgs(command_line, []string{"--data-dir", tmpdirectory, "--rpc-bind", rpcport}, config.Version.String()) + globals.Arguments, err = parser.ParseArgs(command_line, []string{"--data-dir", tmpdirectory, "--rpc-bind", rpcport, "--testnet"}, config.Version.String()) if err != nil { //log.Fatalf("Error while parsing options err: %s\n", err) return nil, nil, nil diff --git a/walletapi/wallet.go b/walletapi/wallet.go index 960552b..9b2954d 100644 --- a/walletapi/wallet.go +++ b/walletapi/wallet.go @@ -23,6 +23,7 @@ import "strings" import "math/big" import "crypto/rand" +import "encoding/pem" import "encoding/binary" import "github.com/go-logr/logr" @@ -122,7 +123,7 @@ func (w *Wallet_Memory) InsertReplace(scid crypto.Hash, e rpc.Entry) { // generate keys from using random numbers func Generate_Keys_From_Random() (user *Account, err error) { - user = &Account{Ringsize: 4, FeesMultiplier: 1.5} + user = &Account{Ringsize: 16, FeesMultiplier: 2.0} seed := crypto.RandomScalarBNRed() user.Keys = Generate_Keys_From_Seed(seed) @@ -142,7 +143,7 @@ func Generate_Keys_From_Seed(Seed *crypto.BNRed) (keys _Keys) { // generate user account using recovery seeds func Generate_Account_From_Recovery_Words(words string) (user *Account, err error) { - user = &Account{Ringsize: 4, FeesMultiplier: 1.5} + user = &Account{Ringsize: 16, FeesMultiplier: 2.0} language, seed, err := mnemonics.Words_To_Key(words) if err != nil { return @@ -155,7 +156,7 @@ func Generate_Account_From_Recovery_Words(words string) (user *Account, err erro } func Generate_Account_From_Seed(Seed *crypto.BNRed) (user *Account, err error) { - user = &Account{Ringsize: 4, FeesMultiplier: 1.5} + user = &Account{Ringsize: 16, FeesMultiplier: 2.0} // TODO check whether the seed is invalid user.Keys = Generate_Keys_From_Seed(Seed) @@ -303,6 +304,7 @@ func (w *Wallet_Memory) Get_Payments_TXID(txid string) (entry rpc.Entry) { } // delete most of the data and prepare for rescan +// TODO we must save tokens list and reuse, them, but will be created on-demand when using shows transfers/or rpc apis func (w *Wallet_Memory) Clean() { //w.account.Entries = w.account.Entries[:0] @@ -526,3 +528,70 @@ func FormatMoneyPrecision(amount uint64, precision int) string { result.Quo(float_amount, hard_coded_decimals) return result.Text('f', precision) // 5 is display precision after floating point } + +// this basically does a Schnorr Signature on random information for registration +func (w *Wallet_Memory) SignData(input []byte) []byte { + var tmppoint bn256.G1 + + tmpsecret := crypto.RandomScalar() + tmppoint.ScalarMult(crypto.G, tmpsecret) + + serialize := []byte(fmt.Sprintf("%s%s%x", w.account.Keys.Public.G1().String(), tmppoint.String(), input)) + + c := crypto.ReducedHash(serialize) + s := new(big.Int).Mul(c, w.account.Keys.Secret.BigInt()) // basicaly scalar mul add + s = s.Mod(s, bn256.Order) + s = s.Add(s, tmpsecret) + s = s.Mod(s, bn256.Order) + + p := &pem.Block{Type: "DERO SIGNED MESSAGE"} + p.Headers = map[string]string{} + p.Headers["Address"] = w.GetAddress().String() + p.Headers["C"] = fmt.Sprintf("%x", c) + p.Headers["S"] = fmt.Sprintf("%x", s) + p.Bytes = input + + return pem.EncodeToMemory(p) +} + +func (w *Wallet_Memory) CheckSignature(input []byte) (signer *rpc.Address, message []byte, err error) { + p, _ := pem.Decode(input) + if p == nil { + err = fmt.Errorf("Unknown format") + return + } + + astr := p.Headers["Address"] + cstr := p.Headers["C"] + sstr := p.Headers["S"] + + addr, err := rpc.NewAddress(astr) + if err != nil { + return + } + + c, ok := new(big.Int).SetString(cstr, 16) + if !ok { + err = fmt.Errorf("Unknown C format") + return + } + + s, ok := new(big.Int).SetString(sstr, 16) + if !ok { + err = fmt.Errorf("Unknown S format") + return + } + + tmppoint := new(bn256.G1).Add(new(bn256.G1).ScalarMult(crypto.G, s), new(bn256.G1).ScalarMult(addr.PublicKey.G1(), new(big.Int).Neg(c))) + serialize := []byte(fmt.Sprintf("%s%s%x", addr.PublicKey.G1().String(), tmppoint.String(), p.Bytes)) + + c_calculated := crypto.ReducedHash(serialize) + if c.String() != c_calculated.String() { + err = fmt.Errorf("signature mismatch") + return + } + + signer = addr + message = p.Bytes + return +} diff --git a/walletapi/wallet_memory.go b/walletapi/wallet_memory.go index 0b3582a..c1a299e 100644 --- a/walletapi/wallet_memory.go +++ b/walletapi/wallet_memory.go @@ -328,6 +328,13 @@ func Generate_Key(k KDF, password string) (key []byte) { } } +func (w *Wallet_Memory) GetAccount() *Account { + if w == nil { + return nil + } + return w.account +} + func (w *Wallet_Memory) save_if_disk() { if w == nil || w.wallet_disk == nil { return