DERO-HE STARGATE Testnet Release42

This commit is contained in:
Captain 2022-02-06 07:06:32 +00:00
parent fb76ca92d4
commit 5be577718b
No known key found for this signature in database
GPG Key ID: 18CDB3ED5E85D2D4
82 changed files with 2510 additions and 726 deletions

View File

@ -277,6 +277,9 @@ func Blockchain_Start(params map[string]interface{}) (*Blockchain, error) {
func (chain *Blockchain) IntegratorAddress() rpc.Address { func (chain *Blockchain) IntegratorAddress() rpc.Address {
return chain.integrator_address return chain.integrator_address
} }
func (chain *Blockchain) SetIntegratorAddress(addr rpc.Address) {
chain.integrator_address = addr
}
// this function is called to read blockchain state from DB // this function is called to read blockchain state from DB
// It is callable at any point in time // It is callable at any point in time
@ -615,6 +618,11 @@ func (chain *Blockchain) Add_Complete_Block(cbl *block.Complete_Block) (err erro
block_logger.Error(fmt.Errorf("Double Registration TX"), "duplicate registration", "txid", cbl.Txs[i].GetHash()) block_logger.Error(fmt.Errorf("Double Registration TX"), "duplicate registration", "txid", cbl.Txs[i].GetHash())
return errormsg.ErrTXDoubleSpend, false return errormsg.ErrTXDoubleSpend, false
} }
tx_hash := cbl.Txs[i].GetHash()
if chain.simulator == false && tx_hash[0] != 0 && tx_hash[1] != 0 {
return fmt.Errorf("Registration TX has not solved PoW"), false
}
reg_map[string(cbl.Txs[i].MinerAddress[:])] = true reg_map[string(cbl.Txs[i].MinerAddress[:])] = true
} }
} }
@ -1112,6 +1120,12 @@ func (chain *Blockchain) Add_TX_To_Pool(tx *transaction.Transaction) error {
return fmt.Errorf("premine tx not mineable") return fmt.Errorf("premine tx not mineable")
} }
if tx.IsRegistration() { // registration tx will not go any forward if tx.IsRegistration() { // registration tx will not go any forward
tx_hash := tx.GetHash()
if chain.simulator == false && tx_hash[0] != 0 && tx_hash[1] != 0 {
return fmt.Errorf("TX doesn't solve Pow")
}
// ggive regpool a chance to register // ggive regpool a chance to register
if ss, err := chain.Store.Balance_store.LoadSnapshot(0); err == nil { if ss, err := chain.Store.Balance_store.LoadSnapshot(0); err == nil {
if balance_tree, err := ss.GetTree(config.BALANCE_TREE); err == nil { if balance_tree, err := ss.GetTree(config.BALANCE_TREE); err == nil {

View File

@ -191,17 +191,19 @@ type DiffProvider interface {
func Get_Difficulty_At_Tips(source DiffProvider, tips []crypto.Hash) *big.Int { func Get_Difficulty_At_Tips(source DiffProvider, tips []crypto.Hash) *big.Int {
var MinimumDifficulty *big.Int var MinimumDifficulty *big.Int
GenesisDifficulty := new(big.Int).SetUint64(1)
if globals.IsMainnet() { if globals.IsMainnet() {
MinimumDifficulty = new(big.Int).SetUint64(config.Settings.MAINNET_MINIMUM_DIFFICULTY) // this must be controllable parameter MinimumDifficulty = new(big.Int).SetUint64(config.Settings.MAINNET_MINIMUM_DIFFICULTY) // this must be controllable parameter
GenesisDifficulty = new(big.Int).SetUint64(config.Settings.MAINNET_BOOTSTRAP_DIFFICULTY)
} else { } else {
MinimumDifficulty = new(big.Int).SetUint64(config.Settings.TESTNET_MINIMUM_DIFFICULTY) // this must be controllable parameter MinimumDifficulty = new(big.Int).SetUint64(config.Settings.TESTNET_MINIMUM_DIFFICULTY) // this must be controllable parameter
GenesisDifficulty = new(big.Int).SetUint64(config.Settings.TESTNET_BOOTSTRAP_DIFFICULTY)
} }
GenesisDifficulty := new(big.Int).SetUint64(1)
if chain, ok := source.(*Blockchain); ok { if chain, ok := source.(*Blockchain); ok {
if chain.simulator == true { if chain.simulator == true {
return GenesisDifficulty return new(big.Int).SetUint64(1)
} }
} }
@ -225,7 +227,7 @@ func Get_Difficulty_At_Tips(source DiffProvider, tips []crypto.Hash) *big.Int {
// until we have atleast 2 blocks, we cannot run the algo // until we have atleast 2 blocks, we cannot run the algo
if height < 3 { if height < 3 {
return MinimumDifficulty return GenesisDifficulty
} }
tip_difficulty := source.Load_Block_Difficulty(tips[0]) tip_difficulty := source.Load_Block_Difficulty(tips[0])

View File

@ -58,12 +58,14 @@ func (chain *Blockchain) install_hardcoded_contracts(cache map[crypto.Hash]*grav
if _, _, err = dvm.ParseSmartContract(source_nameservice); err != nil { if _, _, err = dvm.ParseSmartContract(source_nameservice); err != nil {
logger.Error(err, "error Parsing hard coded sc") logger.Error(err, "error Parsing hard coded sc")
panic(err)
return return
} }
var name crypto.Hash var name crypto.Hash
name[31] = 1 name[31] = 1
if err = chain.install_hardcoded_sc(cache, ss, balance_tree, sc_tree, source_nameservice, name); err != nil { if err = chain.install_hardcoded_sc(cache, ss, balance_tree, sc_tree, source_nameservice, name); err != nil {
panic(err)
return return
} }

View File

@ -30,6 +30,14 @@ type storefs struct {
basedir string basedir string
} }
// TODO we need to enable big support or shift to object store at some point in time
func (s *storefs) getpath(h [32]byte) string {
// if you wish to use 3 level indirection, it will cause 16 million inodes to be used, but system will be faster
//return filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2]))
// currently we are settling on 65536 inodes
return filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]))
}
// the filename stores the following information // the filename stores the following information
// hex block id (64 chars).block._ rewards (decimal) _ difficulty _ cumulative difficulty // hex block id (64 chars).block._ rewards (decimal) _ difficulty _ cumulative difficulty
@ -40,7 +48,7 @@ func (s *storefs) ReadBlock(h [32]byte) ([]byte, error) {
return nil, fmt.Errorf("empty block") return nil, fmt.Errorf("empty block")
} }
dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2])) dir := s.getpath(h)
files, err := os.ReadDir(dir) files, err := os.ReadDir(dir)
if err != nil { if err != nil {
@ -51,7 +59,7 @@ func (s *storefs) ReadBlock(h [32]byte) ([]byte, error) {
for _, file := range files { for _, file := range files {
if strings.HasPrefix(file.Name(), filename_start) { if strings.HasPrefix(file.Name(), filename_start) {
//fmt.Printf("Reading block with filename %s\n", file.Name()) //fmt.Printf("Reading block with filename %s\n", file.Name())
file := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2]), file.Name()) file := filepath.Join(dir, file.Name())
return os.ReadFile(file) return os.ReadFile(file)
} }
} }
@ -61,7 +69,7 @@ func (s *storefs) ReadBlock(h [32]byte) ([]byte, error) {
// on windows, we see an odd behaviour where some files could not be deleted, since they may exist only in cache // on windows, we see an odd behaviour where some files could not be deleted, since they may exist only in cache
func (s *storefs) DeleteBlock(h [32]byte) error { func (s *storefs) DeleteBlock(h [32]byte) error {
dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2])) dir := s.getpath(h)
files, err := os.ReadDir(dir) files, err := os.ReadDir(dir)
if err != nil { if err != nil {
@ -72,7 +80,7 @@ func (s *storefs) DeleteBlock(h [32]byte) error {
var found bool var found bool
for _, file := range files { for _, file := range files {
if strings.HasPrefix(file.Name(), filename_start) { if strings.HasPrefix(file.Name(), filename_start) {
file := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2]), file.Name()) file := filepath.Join(dir, file.Name())
err = os.Remove(file) err = os.Remove(file)
if err != nil { if err != nil {
//return err //return err
@ -88,7 +96,7 @@ func (s *storefs) DeleteBlock(h [32]byte) error {
} }
func (s *storefs) ReadBlockDifficulty(h [32]byte) (*big.Int, error) { func (s *storefs) ReadBlockDifficulty(h [32]byte) (*big.Int, error) {
dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2])) dir := s.getpath(h)
files, err := os.ReadDir(dir) files, err := os.ReadDir(dir)
if err != nil { if err != nil {
@ -122,7 +130,7 @@ func (chain *Blockchain) ReadBlockSnapshotVersion(h [32]byte) (uint64, error) {
return chain.Store.Block_tx_store.ReadBlockSnapshotVersion(h) return chain.Store.Block_tx_store.ReadBlockSnapshotVersion(h)
} }
func (s *storefs) ReadBlockSnapshotVersion(h [32]byte) (uint64, error) { func (s *storefs) ReadBlockSnapshotVersion(h [32]byte) (uint64, error) {
dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2])) dir := s.getpath(h)
files, err := os.ReadDir(dir) // this always returns the sorted list files, err := os.ReadDir(dir) // this always returns the sorted list
if err != nil { if err != nil {
@ -167,7 +175,7 @@ func (chain *Blockchain) ReadBlockHeight(h [32]byte) (uint64, error) {
} }
func (s *storefs) ReadBlockHeight(h [32]byte) (uint64, error) { func (s *storefs) ReadBlockHeight(h [32]byte) (uint64, error) {
dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2])) dir := s.getpath(h)
files, err := os.ReadDir(dir) files, err := os.ReadDir(dir)
if err != nil { if err != nil {
@ -194,7 +202,7 @@ func (s *storefs) ReadBlockHeight(h [32]byte) (uint64, error) {
} }
func (s *storefs) WriteBlock(h [32]byte, data []byte, difficulty *big.Int, ss_version uint64, height uint64) (err error) { func (s *storefs) WriteBlock(h [32]byte, data []byte, difficulty *big.Int, ss_version uint64, height uint64) (err error) {
dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2])) dir := s.getpath(h)
file := filepath.Join(dir, fmt.Sprintf("%x.block_%s_%d_%d", h[:], difficulty.String(), ss_version, height)) file := filepath.Join(dir, fmt.Sprintf("%x.block_%s_%d_%d", h[:], difficulty.String(), ss_version, height))
if err = os.MkdirAll(dir, 0700); err != nil { if err = os.MkdirAll(dir, 0700); err != nil {
return err return err
@ -203,12 +211,13 @@ func (s *storefs) WriteBlock(h [32]byte, data []byte, difficulty *big.Int, ss_ve
} }
func (s *storefs) ReadTX(h [32]byte) ([]byte, error) { func (s *storefs) ReadTX(h [32]byte) ([]byte, error) {
file := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2]), fmt.Sprintf("%x.tx", h[:])) dir := s.getpath(h)
file := filepath.Join(dir, fmt.Sprintf("%x.tx", h[:]))
return ioutil.ReadFile(file) return ioutil.ReadFile(file)
} }
func (s *storefs) WriteTX(h [32]byte, data []byte) (err error) { func (s *storefs) WriteTX(h [32]byte, data []byte) (err error) {
dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2])) dir := s.getpath(h)
file := filepath.Join(dir, fmt.Sprintf("%x.tx", h[:])) file := filepath.Join(dir, fmt.Sprintf("%x.tx", h[:]))
if err = os.MkdirAll(dir, 0700); err != nil { if err = os.MkdirAll(dir, 0700); err != nil {
@ -219,7 +228,7 @@ func (s *storefs) WriteTX(h [32]byte, data []byte) (err error) {
} }
func (s *storefs) DeleteTX(h [32]byte) (err error) { func (s *storefs) DeleteTX(h [32]byte) (err error) {
dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2])) dir := s.getpath(h)
file := filepath.Join(dir, fmt.Sprintf("%x.tx", h[:])) file := filepath.Join(dir, fmt.Sprintf("%x.tx", h[:]))
return os.Remove(file) return os.Remove(file)
} }

View File

@ -70,10 +70,15 @@ func (chain *Blockchain) process_miner_transaction(bl *block.Block, genesis bool
if genesis == true { // process premine ,register genesis block, dev key if genesis == true { // process premine ,register genesis block, dev key
balance := crypto.ConstructElGamal(acckey.G1(), crypto.ElGamal_BASE_G) // init zero balance balance := crypto.ConstructElGamal(acckey.G1(), crypto.ElGamal_BASE_G) // init zero balance
balance = balance.Plus(new(big.Int).SetUint64(tx.Value << 1)) // add premine to users balance homomorphically balance = balance.Plus(new(big.Int).SetUint64(tx.Value)) // add premine to users balance homomorphically
nb := crypto.NonceBalance{NonceHeight: 0, Balance: balance} nb := crypto.NonceBalance{NonceHeight: 0, Balance: balance}
balance_tree.Put(tx.MinerAddress[:], nb.Serialize()) // reserialize and store balance_tree.Put(tx.MinerAddress[:], nb.Serialize()) // reserialize and store
if globals.IsMainnet() {
return
}
// only testnet/simulator will have dummy accounts to test
// we must process premine list and register and give them balance, // we must process premine list and register and give them balance,
premine_count := 0 premine_count := 0
scanner := bufio.NewScanner(strings.NewReader(premine.List)) scanner := bufio.NewScanner(strings.NewReader(premine.List))
@ -119,12 +124,15 @@ func (chain *Blockchain) process_miner_transaction(bl *block.Block, genesis bool
base_reward := CalcBlockReward(uint64(height)) base_reward := CalcBlockReward(uint64(height))
full_reward := base_reward + fees full_reward := base_reward + fees
//full_reward is divided into equal parts for all miner blocks + miner address integrator_reward := full_reward * 167 / 10000
//full_reward is divided into equal parts for all miner blocks
// integrator only gets 1.67 % of block reward
// since perfect division is not possible, ( see money handling) // since perfect division is not possible, ( see money handling)
// any left over change is delivered to main miner who integrated the full block // any left over change is delivered to main miner who integrated the full block
share := full_reward / uint64(len(bl.MiniBlocks)) // one block integrator, this is integer division share := (full_reward - integrator_reward) / uint64(len(bl.MiniBlocks)) // one block integrator, this is integer division
leftover := full_reward - (share * uint64(len(bl.MiniBlocks))) // only integrator will get this leftover := full_reward - integrator_reward - (share * uint64(len(bl.MiniBlocks))) // only integrator will get this
{ // giver integrator his reward { // giver integrator his reward
balance_serialized, err := balance_tree.Get(tx.MinerAddress[:]) balance_serialized, err := balance_tree.Get(tx.MinerAddress[:])
@ -132,7 +140,7 @@ func (chain *Blockchain) process_miner_transaction(bl *block.Block, genesis bool
panic(err) panic(err)
} }
nb := new(crypto.NonceBalance).Deserialize(balance_serialized) nb := new(crypto.NonceBalance).Deserialize(balance_serialized)
nb.Balance = nb.Balance.Plus(new(big.Int).SetUint64(share + leftover)) // add miners reward to miners balance homomorphically nb.Balance = nb.Balance.Plus(new(big.Int).SetUint64(integrator_reward + leftover)) // add miners reward to miners balance homomorphically
balance_tree.Put(tx.MinerAddress[:], nb.Serialize()) // reserialize and store balance_tree.Put(tx.MinerAddress[:], nb.Serialize()) // reserialize and store
} }
@ -230,7 +238,6 @@ func (chain *Blockchain) process_transaction(changed map[crypto.Hash]*graviton.T
nb.NonceHeight = height nb.NonceHeight = height
} }
tree.Put(key_compressed, nb.Serialize()) // reserialize and store tree.Put(key_compressed, nb.Serialize()) // reserialize and store
} }
} }

View File

@ -17,15 +17,15 @@
package main package main
import "io" import "io"
import "os"
import "time" import "time"
import "fmt" import "fmt"
import "errors"
//import "io/ioutil"
import "strings" import "strings"
//import "path/filepath" import "path/filepath"
//import "encoding/hex" import "encoding/json"
import "github.com/chzyer/readline" import "github.com/chzyer/readline"
@ -64,6 +64,7 @@ func display_easymenu_post_open_command(l *readline.Instance) {
io.WriteString(w, "\t\033[1m12\033[0m\tTransfer all balance (send DERO) To Another Wallet\n") io.WriteString(w, "\t\033[1m12\033[0m\tTransfer all balance (send DERO) To Another Wallet\n")
io.WriteString(w, "\t\033[1m13\033[0m\tShow transaction history\n") io.WriteString(w, "\t\033[1m13\033[0m\tShow transaction history\n")
io.WriteString(w, "\t\033[1m14\033[0m\tRescan transaction history\n") io.WriteString(w, "\t\033[1m14\033[0m\tRescan transaction history\n")
io.WriteString(w, "\t\033[1m15\033[0m\tExport all transaction history in json format\n")
} }
io.WriteString(w, "\n\t\033[1m9\033[0m\tExit menu and start prompt\n") io.WriteString(w, "\n\t\033[1m9\033[0m\tExit menu and start prompt\n")
@ -127,10 +128,21 @@ func handle_easymenu_post_open_command(l *readline.Instance, line string) (proce
fmt.Fprintf(l.Stderr(), "Wallet address : "+color_green+"%s"+color_white+" is going to be registered.This is a pre-condition for using the online chain.It will take few seconds to register.\n", wallet.GetAddress()) fmt.Fprintf(l.Stderr(), "Wallet address : "+color_green+"%s"+color_white+" is going to be registered.This is a pre-condition for using the online chain.It will take few seconds to register.\n", wallet.GetAddress())
reg_tx := wallet.GetRegistrationTX()
// at this point we must send the registration transaction // at this point we must send the registration transaction
fmt.Fprintf(l.Stderr(), "Wallet address : "+color_green+"%s"+color_white+" is going to be registered.Pls wait till the account is registered.\n", wallet.GetAddress()) fmt.Fprintf(l.Stderr(), "Wallet address : "+color_green+"%s"+color_white+" is going to be registered.Pls wait till the account is registered.\n", wallet.GetAddress())
fmt.Fprintf(l.Stderr(), "This will take a couple of minutes.Please wait....\n")
var reg_tx *transaction.Transaction
for {
reg_tx = wallet.GetRegistrationTX()
hash := reg_tx.GetHash()
if hash[0] == 0 && hash[1] == 0 {
break
}
}
fmt.Fprintf(l.Stderr(), "Registration TXID %s\n", reg_tx.GetHash()) fmt.Fprintf(l.Stderr(), "Registration TXID %s\n", reg_tx.GetHash())
err := wallet.SendTransaction(reg_tx) err := wallet.SendTransaction(reg_tx)
if err != nil { if err != nil {
@ -243,6 +255,7 @@ func handle_easymenu_post_open_command(l *readline.Instance, line string) (proce
if a.Arguments.Has(rpc.RPC_COMMENT, rpc.DataString) { // but only it is present if a.Arguments.Has(rpc.RPC_COMMENT, rpc.DataString) { // but only it is present
logger.Info("Integrated Message", "comment", a.Arguments.Value(rpc.RPC_COMMENT, rpc.DataString)) logger.Info("Integrated Message", "comment", a.Arguments.Value(rpc.RPC_COMMENT, rpc.DataString))
arguments = append(arguments, rpc.Argument{rpc.RPC_COMMENT, rpc.DataString, a.Arguments.Value(rpc.RPC_COMMENT, rpc.DataString)})
} }
} }
@ -291,10 +304,12 @@ func handle_easymenu_post_open_command(l *readline.Instance, line string) (proce
amount_to_transfer = a.Arguments.Value(rpc.RPC_VALUE_TRANSFER, rpc.DataUint64).(uint64) amount_to_transfer = a.Arguments.Value(rpc.RPC_VALUE_TRANSFER, rpc.DataUint64).(uint64)
} else { } else {
amount_str := read_line_with_prompt(l, fmt.Sprintf("Enter amount to transfer in DERO (max TODO): ")) mbal, _ := wallet.Get_Balance()
amount_str := read_line_with_prompt(l, fmt.Sprintf("Enter amount to transfer in DERO (current balance %s): ", globals.FormatMoney(mbal)))
if amount_str == "" { if amount_str == "" {
amount_str = ".00001" logger.Error(nil, "Cannot transfer 0")
break // invalid amount provided, bail out
} }
amount_to_transfer, err = globals.ParseAmount(amount_str) amount_to_transfer, err = globals.ParseAmount(amount_str)
if err != nil { if err != nil {
@ -315,8 +330,15 @@ func handle_easymenu_post_open_command(l *readline.Instance, line string) (proce
// if no arguments, use space by embedding a small comment // if no arguments, use space by embedding a small comment
if len(arguments) == 0 { // allow user to enter Comment if len(arguments) == 0 { // allow user to enter Comment
if v, err := ReadUint64(l, "Please enter payment id (or destination port number)", uint64(0)); err == nil {
arguments = append(arguments, rpc.Argument{Name: rpc.RPC_DESTINATION_PORT, DataType: rpc.DataUint64, Value: v})
} else {
logger.Error(err, fmt.Sprintf("%s could not be parsed (type %s),", "Number", rpc.DataUint64))
return
}
if v, err := ReadString(l, "Comment", ""); err == nil { if v, err := ReadString(l, "Comment", ""); err == nil {
arguments = append(arguments, rpc.Argument{Name: "Comment", DataType: rpc.DataString, Value: v}) arguments = append(arguments, rpc.Argument{Name: rpc.RPC_COMMENT, DataType: rpc.DataString, Value: v})
} else { } else {
logger.Error(fmt.Errorf("%s could not be parsed (type %s),", "Comment", rpc.DataString), "") logger.Error(fmt.Errorf("%s could not be parsed (type %s),", "Comment", rpc.DataString), "")
return return
@ -429,6 +451,34 @@ func handle_easymenu_post_open_command(l *readline.Instance, line string) (proce
case "14": case "14":
logger.Info("Rescanning wallet history") logger.Info("Rescanning wallet history")
rescan_bc(wallet) rescan_bc(wallet)
case "15":
if !ValidateCurrentPassword(l, wallet) {
logger.Error(fmt.Errorf("Invalid password"), "")
break
}
if _, err := os.Stat("./history"); errors.Is(err, os.ErrNotExist) {
if err := os.Mkdir("./history", 0700); err != nil {
logger.Error(err, "Error creating directory")
break
}
}
var zeroscid crypto.Hash
account := wallet.GetAccount()
for k, v := range account.EntriesNative {
filename := filepath.Join("./history", k.String()+".json")
if k == zeroscid {
filename = filepath.Join("./history", "dero.json")
}
if data, err := json.Marshal(v); err != nil {
logger.Error(err, "Error exporting data")
} else if err = os.WriteFile(filename, data, 0600); err != nil {
logger.Error(err, "Error exporting data")
} else {
logger.Info("successfully exported history", "file", filename)
}
}
default: default:
processed = false // just loop processed = false // just loop

View File

@ -80,6 +80,7 @@ Usage:
--rpc-server Run rpc server, so wallet is accessible using api --rpc-server Run rpc server, so wallet is accessible using api
--rpc-bind=<127.0.0.1:20209> Wallet binds on this ip address and port --rpc-bind=<127.0.0.1:20209> Wallet binds on this ip address and port
--rpc-login=<username:password> RPC server will grant access based on these credentials --rpc-login=<username:password> RPC server will grant access based on these credentials
--allow-rpc-password-change RPC server will change password if you send "Pass" header with new password
` `
var menu_mode bool = true // default display menu mode var menu_mode bool = true // default display menu mode
//var account_valid bool = false // if an account has been opened, do not allow to create new account in this session //var account_valid bool = false // if an account has been opened, do not allow to create new account in this session
@ -119,7 +120,7 @@ func main() {
} }
// init the lookup table one, anyone importing walletapi should init this first, this will take around 1 sec on any recent system // init the lookup table one, anyone importing walletapi should init this first, this will take around 1 sec on any recent system
walletapi.Initialize_LookupTable(1, 1<<17) walletapi.Initialize_LookupTable(1, 1<<21)
// We need to initialize readline first, so it changes stderr to ansi processor on windows // We need to initialize readline first, so it changes stderr to ansi processor on windows
l, err := readline.NewEx(&readline.Config{ l, err := readline.NewEx(&readline.Config{

View File

@ -25,6 +25,7 @@ import "time"
//import "io/ioutil" //import "io/ioutil"
//import "path/filepath" //import "path/filepath"
import "strings" import "strings"
import "unicode"
import "strconv" import "strconv"
import "encoding/hex" import "encoding/hex"
@ -39,6 +40,15 @@ import "github.com/deroproject/derohe/cryptography/crypto"
var account walletapi.Account var account walletapi.Account
func isASCII(s string) bool {
for _, c := range s {
if c > unicode.MaxASCII {
return false
}
}
return true
}
// handle all commands while in prompt mode // handle all commands while in prompt mode
func handle_prompt_command(l *readline.Instance, line string) { func handle_prompt_command(l *readline.Instance, line string) {
@ -126,6 +136,54 @@ func handle_prompt_command(l *readline.Instance, line string) {
case "spendkey": // give user his spend key case "spendkey": // give user his spend key
display_spend_key(l, wallet) display_spend_key(l, wallet)
case "filesign": // sign a file contents
if !ValidateCurrentPassword(l, wallet) {
logger.Error(err, "Invalid password")
PressAnyKey(l, wallet)
break
}
filename, err := ReadString(l, "Enter file to sign", "")
if err != nil {
logger.Error(err, "Cannot read input file name")
}
outputfile := filename + ".sign"
if filedata, err := os.ReadFile(filename); err != nil {
logger.Error(err, "Cannot read input file")
} else if err := os.WriteFile(outputfile, wallet.SignData(filedata), 0600); err != nil {
logger.Error(err, "Cannot write output file", "file", outputfile)
} else {
logger.Info("successfully signed file. please check", "file", outputfile)
}
case "fileverify": // verify a file contents
filename, err := ReadString(l, "Enter file to verify signature", "")
if err != nil {
logger.Error(err, "Cannot read input file name")
}
outputfile := strings.TrimSuffix(filename, ".sign")
if filedata, err := os.ReadFile(filename); err != nil {
logger.Error(err, "Cannot read input file")
} else if signer, message, err := wallet.CheckSignature(filedata); err != nil {
logger.Error(err, "Signature verify failed", "file", filename)
} else {
logger.Info("Signed by", "address", signer.String())
if isASCII(string(message)) { // do not spew garbage
logger.Info("", "message", string(message))
}
if os.WriteFile(outputfile, message, 0600); err != nil {
logger.Error(err, "Cannot write output file", "file", outputfile)
}
logger.Info("successfully wrote message to file. please check", "file", outputfile)
}
case "password": // change wallet password case "password": // change wallet password
if ConfirmYesNoDefaultNo(l, "Change wallet password (y/N)") && if ConfirmYesNoDefaultNo(l, "Change wallet password (y/N)") &&
ValidateCurrentPassword(l, wallet) { ValidateCurrentPassword(l, wallet) {
@ -525,6 +583,10 @@ func ReadUint64(l *readline.Instance, cprompt string, default_value uint64) (a u
error_message := "" error_message := ""
color := color_green color := color_green
if len(line) == 0 {
line = []rune(fmt.Sprintf("%d", default_value))
}
if len(line) >= 1 { if len(line) >= 1 {
_, err := strconv.ParseUint(string(line), 0, 64) _, err := strconv.ParseUint(string(line), 0, 64)
if err != nil { if err != nil {
@ -548,6 +610,9 @@ func ReadUint64(l *readline.Instance, cprompt string, default_value uint64) (a u
if err != nil { if err != nil {
return return
} }
if len(line) == 0 {
line = []byte(fmt.Sprintf("%d", default_value))
}
a, err = strconv.ParseUint(string(line), 0, 64) a, err = strconv.ParseUint(string(line), 0, 64)
l.SetPrompt(cprompt) l.SetPrompt(cprompt)
l.Refresh() l.Refresh()
@ -800,6 +865,8 @@ var completer = readline.NewPrefixCompleter(
readline.PcItem("balance"), readline.PcItem("balance"),
readline.PcItem("integrated_address"), readline.PcItem("integrated_address"),
readline.PcItem("get_tx_key"), readline.PcItem("get_tx_key"),
readline.PcItem("filesign"),
readline.PcItem("fileverify"),
readline.PcItem("menu"), readline.PcItem("menu"),
readline.PcItem("rescan_bc"), readline.PcItem("rescan_bc"),
readline.PcItem("payment_id"), readline.PcItem("payment_id"),
@ -817,7 +884,6 @@ var completer = readline.NewPrefixCompleter(
readline.PcItem("version"), readline.PcItem("version"),
readline.PcItem("transfer"), readline.PcItem("transfer"),
readline.PcItem("transfer_all"), readline.PcItem("transfer_all"),
readline.PcItem("walletviewkey"),
readline.PcItem("bye"), readline.PcItem("bye"),
readline.PcItem("exit"), readline.PcItem("exit"),
readline.PcItem("quit"), readline.PcItem("quit"),

View File

@ -281,7 +281,9 @@ func main() {
} }
testnet_string := "" testnet_string := ""
if !globals.IsMainnet() { if globals.IsMainnet() {
testnet_string = "\033[31m MAINNET"
} else {
testnet_string = "\033[31m TESTNET" testnet_string = "\033[31m TESTNET"
} }
@ -406,6 +408,19 @@ restart_loop:
memoryfile.Close() memoryfile.Close()
*/ */
case command == "setintegratoraddress":
if len(line_parts) != 2 {
logger.Error(fmt.Errorf("This function requires 1 parameters, dero address"), "")
continue
}
if addr, err := rpc.NewAddress(line_parts[1]); err != nil {
logger.Error(err, "invalid address")
continue
} else {
chain.SetIntegratorAddress(*addr)
logger.Info("will use", "integrator_address", chain.IntegratorAddress().String())
}
case command == "print_bc": case command == "print_bc":
logger.Info("printing block chain") logger.Info("printing block chain")
@ -1023,6 +1038,8 @@ func usage(w io.Writer) {
io.WriteString(w, "\t\033[1mregpool_print\033[0m\t\tprint regpool contents\n") io.WriteString(w, "\t\033[1mregpool_print\033[0m\t\tprint regpool contents\n")
io.WriteString(w, "\t\033[1mregpool_delete_tx\033[0m\t\tDelete specific tx from regpool\n") io.WriteString(w, "\t\033[1mregpool_delete_tx\033[0m\t\tDelete specific tx from regpool\n")
io.WriteString(w, "\t\033[1mregpool_flush\033[0m\t\tFlush mempool\n") io.WriteString(w, "\t\033[1mregpool_flush\033[0m\t\tFlush mempool\n")
io.WriteString(w, "\t\033[1msetintegratoraddress\033[0m\t\tChange current integrated address\n")
io.WriteString(w, "\t\033[1mversion\033[0m\t\tShow version\n") io.WriteString(w, "\t\033[1mversion\033[0m\t\tShow version\n")
io.WriteString(w, "\t\033[1mexit\033[0m\t\tQuit the daemon\n") io.WriteString(w, "\t\033[1mexit\033[0m\t\tQuit the daemon\n")
io.WriteString(w, "\t\033[1mquit\033[0m\t\tQuit the daemon\n") io.WriteString(w, "\t\033[1mquit\033[0m\t\tQuit the daemon\n")
@ -1046,6 +1063,7 @@ var completer = readline.NewPrefixCompleter(
readline.PcItem("block_export"), readline.PcItem("block_export"),
readline.PcItem("block_import"), readline.PcItem("block_import"),
// readline.PcItem("print_tx"), // readline.PcItem("print_tx"),
readline.PcItem("setintegratoraddress"),
readline.PcItem("status"), readline.PcItem("status"),
readline.PcItem("sync_info"), readline.PcItem("sync_info"),
readline.PcItem("version"), readline.PcItem("version"),

View File

@ -35,7 +35,6 @@ func GetEncryptedBalance(ctx context.Context, p rpc.GetEncryptedBalance_Params)
defer func() { // safety so if anything wrong happens, we return error defer func() { // safety so if anything wrong happens, we return error
if r := recover(); r != nil { if r := recover(); r != nil {
err = fmt.Errorf("panic occured. stack trace %s", debug.Stack()) err = fmt.Errorf("panic occured. stack trace %s", debug.Stack())
fmt.Printf("panic stack trace %s params %+v\n", debug.Stack(), p)
} }
}() }()

View File

@ -77,7 +77,7 @@ func simulator_chain_start() (*blockchain.Blockchain, *derodrpc.RPCServer, map[s
OptionsFirst: true, OptionsFirst: true,
} }
globals.Arguments, err = parser.ParseArgs(command_line_test, []string{"--data-dir", tmpdirectory, "--rpc-bind", rpcport_test}, config.Version.String()) globals.Arguments, err = parser.ParseArgs(command_line_test, []string{"--data-dir", tmpdirectory, "--rpc-bind", rpcport_test, "--testnet"}, config.Version.String())
if err != nil { if err != nil {
//log.Fatalf("Error while parsing options err: %s\n", err) //log.Fatalf("Error while parsing options err: %s\n", err)
return nil, nil, nil return nil, nil, nil

View File

@ -43,7 +43,7 @@ const SC_META = "M" // keeps all SCs balance, their state, their OWNER, the
const MAX_STORAGE_GAS_ATOMIC_UNITS = 20000 const MAX_STORAGE_GAS_ATOMIC_UNITS = 20000
// Minimum FEE calculation constants are here // Minimum FEE calculation constants are here
const FEE_PER_KB = uint64(100) // .00100 dero per kb const FEE_PER_KB = uint64(20) // .00020 dero per kb
// we can easily improve TPS by changing few parameters in this file // we can easily improve TPS by changing few parameters in this file
// the resources compute/network may not be easy for the developing countries // the resources compute/network may not be easy for the developing countries
@ -56,8 +56,8 @@ const MIN_RINGSIZE = 2 // >= 2 , ringsize will be accepted
const MAX_RINGSIZE = 128 // <= 128, ringsize will be accepted const MAX_RINGSIZE = 128 // <= 128, ringsize will be accepted
type SettingsStruct struct { type SettingsStruct struct {
MAINNET_BOOTSTRAP_DIFFICULTY uint64 `env:"MAINNET_BOOTSTRAP_DIFFICULTY" envDefault:"80000000"` MAINNET_BOOTSTRAP_DIFFICULTY uint64 `env:"MAINNET_BOOTSTRAP_DIFFICULTY" envDefault:"10000000"` // mainnet bootstrap is 10 MH/s
MAINNET_MINIMUM_DIFFICULTY uint64 `env:"MAINNET_MINIMUM_DIFFICULTY" envDefault:"80000000"` MAINNET_MINIMUM_DIFFICULTY uint64 `env:"MAINNET_MINIMUM_DIFFICULTY" envDefault:"100000"` // mainnet minimum is 100 KH/s
TESTNET_BOOTSTRAP_DIFFICULTY uint64 `env:"TESTNET_BOOTSTRAP_DIFFICULTY" envDefault:"10000"` TESTNET_BOOTSTRAP_DIFFICULTY uint64 `env:"TESTNET_BOOTSTRAP_DIFFICULTY" envDefault:"10000"`
TESTNET_MINIMUM_DIFFICULTY uint64 `env:"TESTNET_MINIMUM_DIFFICULTY" envDefault:"10000"` TESTNET_MINIMUM_DIFFICULTY uint64 `env:"TESTNET_MINIMUM_DIFFICULTY" envDefault:"10000"`
@ -89,7 +89,7 @@ type CHAIN_CONFIG struct {
} }
var Mainnet = CHAIN_CONFIG{Name: "mainnet", var Mainnet = CHAIN_CONFIG{Name: "mainnet",
Network_ID: uuid.FromBytesOrNil([]byte{0x59, 0xd7, 0xf7, 0xe9, 0xdd, 0x48, 0xd5, 0xfd, 0x13, 0x0a, 0xf6, 0xe0, 0x9a, 0x44, 0x45, 0x0}), Network_ID: uuid.FromBytesOrNil([]byte{0x59, 0xd7, 0xf7, 0xe9, 0xdd, 0x48, 0xd5, 0xfd, 0x13, 0x0a, 0xf6, 0xe0, 0x9a, 0x44, 0x40, 0x0}),
GETWORK_Default_Port: 10100, GETWORK_Default_Port: 10100,
P2P_Default_Port: 10101, P2P_Default_Port: 10101,
RPC_Default_Port: 10102, RPC_Default_Port: 10102,
@ -101,13 +101,13 @@ var Mainnet = CHAIN_CONFIG{Name: "mainnet",
"00" + // Source is DERO network "00" + // Source is DERO network
"00" + // Dest is DERO network "00" + // Dest is DERO network
"00" + // PREMINE_FLAG "00" + // PREMINE_FLAG
"8fff7f" + // PREMINE_VALUE "80a8b9ceb024" + // PREMINE_VALUE
"1f9bcc1208dee302769931ad378a4c0c4b2c21b0cfb3e752607e12d2b6fa642500", // miners public key "1f9bcc1208dee302769931ad378a4c0c4b2c21b0cfb3e752607e12d2b6fa642500", // miners public key
} }
var Testnet = CHAIN_CONFIG{Name: "testnet", // testnet will always have last 3 bytes 0 var Testnet = CHAIN_CONFIG{Name: "testnet", // testnet will always have last 3 bytes 0
Network_ID: uuid.FromBytesOrNil([]byte{0x59, 0xd7, 0xf7, 0xe9, 0xdd, 0x48, 0xd5, 0xfd, 0x13, 0x0a, 0xf6, 0xe0, 0x80, 0x00, 0x00, 0x00}), Network_ID: uuid.FromBytesOrNil([]byte{0x59, 0xd7, 0xf7, 0xe9, 0xdd, 0x48, 0xd5, 0xfd, 0x13, 0x0a, 0xf6, 0xe0, 0x83, 0x00, 0x00, 0x00}),
GETWORK_Default_Port: 10100, GETWORK_Default_Port: 10100,
P2P_Default_Port: 40401, P2P_Default_Port: 40401,
RPC_Default_Port: 40402, RPC_Default_Port: 40402,
@ -120,7 +120,7 @@ var Testnet = CHAIN_CONFIG{Name: "testnet", // testnet will always have last 3 b
"00" + // Source is DERO network "00" + // Source is DERO network
"00" + // Dest is DERO network "00" + // Dest is DERO network
"00" + // PREMINE_FLAG "00" + // PREMINE_FLAG
"8fff7f" + // PREMINE_VALUE "80a8b9ceb024" + // PREMINE_VALUE
"1f9bcc1208dee302769931ad378a4c0c4b2c21b0cfb3e752607e12d2b6fa642500", // miners public key "1f9bcc1208dee302769931ad378a4c0c4b2c21b0cfb3e752607e12d2b6fa642500", // miners public key
} }

View File

@ -20,4 +20,4 @@ import "github.com/blang/semver/v4"
// right now it has to be manually changed // right now it has to be manually changed
// do we need to include git commitsha?? // do we need to include git commitsha??
var Version = semver.MustParse("3.4.106-0.DEROHE.STARGATE+18012022") var Version = semver.MustParse("3.4.109-0.DEROHE.STARGATE+18012022")

View File

@ -576,7 +576,7 @@ func (i *DVM_Interpreter) interpret_SmartContract() (err error) {
} }
if i.State.Trace { if i.State.Trace {
fmt.Printf("interpreting line %+v err:'%s'\n", line, err) fmt.Printf("interpreting line %+v err:'%v'\n", line, err)
} }
if err != nil { if err != nil {
err = fmt.Errorf("err while interpreting line %+v err %s\n", line, err) err = fmt.Errorf("err while interpreting line %+v err %s\n", line, err)

View File

@ -187,8 +187,6 @@ func InitializeLog(console, logfile io.Writer) {
func Initialize() { func Initialize() {
var err error var err error
Arguments["--testnet"] = true // force testnet every where
InitNetwork() InitNetwork()
// choose socks based proxy if user requested so // choose socks based proxy if user requested so

View File

@ -18,7 +18,7 @@ jobs:
os: ['ubuntu-latest'] os: ['ubuntu-latest']
steps: steps:
- name: Install Go ${{ matrix.go-version }} - name: Install Go ${{ matrix.go-version }}
uses: actions/setup-go@v1 uses: actions/setup-go@v2
with: with:
go-version: ${{ matrix.go-version }} go-version: ${{ matrix.go-version }}
- uses: actions/checkout@v2 - uses: actions/checkout@v2

View File

@ -1,10 +1,9 @@
# jrpc2 # jrpc2
[![GoDoc](https://img.shields.io/static/v1?label=godoc&message=reference&color=yellow)](https://pkg.go.dev/github.com/creachadair/jrpc2) [![GoDoc](https://img.shields.io/static/v1?label=godoc&message=reference&color=yellow)](https://pkg.go.dev/github.com/creachadair/jrpc2)
[![Go Report Card](https://goreportcard.com/badge/github.com/creachadair/jrpc2)](https://goreportcard.com/report/github.com/creachadair/jrpc2)
This repository provides Go package that implements a [JSON-RPC 2.0][spec] client and server. This repository provides a Go module that implements a [JSON-RPC 2.0][spec] client and server.
There is also a working [example in the Go playground](https://play.golang.org/p/MSClCk55UzF). There is also a working [example in the Go playground](https://go.dev/play/p/fY-Pnvf03Hr).
## Packages ## Packages

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package jrpc2 package jrpc2
import ( import (
@ -17,9 +19,12 @@ type Assigner interface {
// The implementation can obtain the complete request from ctx using the // The implementation can obtain the complete request from ctx using the
// jrpc2.InboundRequest function. // jrpc2.InboundRequest function.
Assign(ctx context.Context, method string) Handler Assign(ctx context.Context, method string) Handler
}
// Names returns a slice of all known method names for the assigner. The // Namer is an optional interface that an Assigner may implement to expose the
// resulting slice is ordered lexicographically and contains no duplicates. // names of its methods to the ServerInfo method.
type Namer interface {
// Names returns all known method names in lexicographic order.
Names() []string Names() []string
} }
@ -91,11 +96,14 @@ func (r *Request) UnmarshalParams(v interface{}) error {
dec := json.NewDecoder(bytes.NewReader(r.params)) dec := json.NewDecoder(bytes.NewReader(r.params))
dec.DisallowUnknownFields() dec.DisallowUnknownFields()
if err := dec.Decode(v); err != nil { if err := dec.Decode(v); err != nil {
return Errorf(code.InvalidParams, "invalid parameters: %v", err.Error()) return errInvalidParams.WithData(err.Error())
} }
return nil return nil
} }
return json.Unmarshal(r.params, v) if err := json.Unmarshal(r.params, v); err != nil {
return errInvalidParams.WithData(err.Error())
}
return nil
} }
// ParamString returns the encoded request parameters of r as a string. // ParamString returns the encoded request parameters of r as a string.

View File

@ -1,13 +1,18 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package jrpc2_test package jrpc2_test
import ( import (
"context" "context"
"strconv"
"sync"
"testing" "testing"
"github.com/creachadair/jrpc2" "github.com/creachadair/jrpc2"
"github.com/creachadair/jrpc2/handler" "github.com/creachadair/jrpc2/handler"
"github.com/creachadair/jrpc2/jctx" "github.com/creachadair/jrpc2/jctx"
"github.com/creachadair/jrpc2/server" "github.com/creachadair/jrpc2/server"
"github.com/fortytw2/leaktest"
) )
func BenchmarkRoundTrip(b *testing.B) { func BenchmarkRoundTrip(b *testing.B) {
@ -69,6 +74,75 @@ func BenchmarkRoundTrip(b *testing.B) {
} }
} }
func BenchmarkLoad(b *testing.B) {
defer leaktest.Check(b)()
// The load testing service has a no-op method to exercise server overhead.
loc := server.NewLocal(handler.Map{
"void": handler.Func(func(context.Context, *jrpc2.Request) (interface{}, error) {
return nil, nil
}),
}, nil)
defer loc.Close()
// Exercise concurrent calls.
ctx := context.Background()
b.Run("Call", func(b *testing.B) {
var wg sync.WaitGroup
for i := 0; i < b.N; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, err := loc.Client.Call(ctx, "void", nil)
if err != nil {
b.Errorf("Call failed: %v", err)
}
}()
}
wg.Wait()
})
// Exercise concurrent notifications.
b.Run("Notify", func(b *testing.B) {
var wg sync.WaitGroup
for i := 0; i < b.N; i++ {
wg.Add(1)
go func() {
defer wg.Done()
err := loc.Client.Notify(ctx, "void", nil)
if err != nil {
b.Errorf("Notify failed: %v", err)
}
}()
}
wg.Wait()
})
// Exercise concurrent batches of various sizes.
for _, bs := range []int{1, 2, 4, 8, 12, 16, 20, 50} {
batch := make([]jrpc2.Spec, bs)
for j := 0; j < len(batch); j++ {
batch[j].Method = "void"
}
name := "Batch-" + strconv.Itoa(bs)
b.Run(name, func(b *testing.B) {
var wg sync.WaitGroup
for i := 0; i < b.N; i += bs {
wg.Add(1)
go func() {
defer wg.Done()
_, err := loc.Client.Batch(ctx, batch)
if err != nil {
b.Errorf("Batch failed: %v", err)
}
}()
}
wg.Wait()
})
}
}
func BenchmarkParseRequests(b *testing.B) { func BenchmarkParseRequests(b *testing.B) {
reqs := []struct { reqs := []struct {
desc, input string desc, input string

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package channel_test package channel_test
import ( import (

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
// Package channel defines a basic communications channel. // Package channel defines a basic communications channel.
// //
// A Channel encodes/transmits and decodes/receives data records over an // A Channel encodes/transmits and decodes/receives data records over an
@ -61,10 +63,15 @@ type Channel interface {
Close() error Close() error
} }
// IsErrClosing reports whether err is the internal error returned by a read // ErrClosed is a sentinel error that can be returned to indicate an operation
// from a pipe or socket that is closed. This is false for err == nil. // failed because the channel was closed.
var ErrClosed = errors.New("channel is closed")
// IsErrClosing reports whether err is a channel-closed error. This is true
// for the internal error returned by a read from a pipe or socket that is
// closed, or an error that wraps ErrClosed. It is false if err == nil.
func IsErrClosing(err error) bool { func IsErrClosing(err error) bool {
return err != nil && errors.Is(err, net.ErrClosed) return err != nil && (errors.Is(err, ErrClosed) || errors.Is(err, net.ErrClosed))
} }
// A Framing converts a reader and a writer into a Channel with a particular // A Framing converts a reader and a writer into a Channel with a particular
@ -77,14 +84,12 @@ type direct struct {
} }
func (d direct) Send(msg []byte) (err error) { func (d direct) Send(msg []byte) (err error) {
cp := make([]byte, len(msg))
copy(cp, msg)
defer func() { defer func() {
if p := recover(); p != nil { if p := recover(); p != nil {
err = errors.New("send on closed channel") err = errors.New("send on closed channel")
} }
}() }()
d.send <- cp d.send <- msg
return nil return nil
} }
@ -101,6 +106,10 @@ func (d direct) Close() error { close(d.send); return nil }
// Direct returns a pair of synchronous connected channels that pass message // Direct returns a pair of synchronous connected channels that pass message
// buffers directly in memory without framing or encoding. Sends to client will // buffers directly in memory without framing or encoding. Sends to client will
// be received by server, and vice versa. // be received by server, and vice versa.
//
// Note that buffers passed to direct channels are not copied. If the caller
// needs to use the buffer after sending it on a direct channel, the caller is
// responsible for making a copy.
func Direct() (client, server Channel) { func Direct() (client, server Channel) {
c2s := make(chan []byte) c2s := make(chan []byte)
s2c := make(chan []byte) s2c := make(chan []byte)

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package channel package channel
import ( import (
@ -6,6 +8,8 @@ import (
"strings" "strings"
"sync" "sync"
"testing" "testing"
"github.com/fortytw2/leaktest"
) )
// newPipe creates a pair of connected in-memory channels using the specified // newPipe creates a pair of connected in-memory channels using the specified
@ -20,6 +24,8 @@ func newPipe(framing Framing) (client, server Channel) {
} }
func testSendRecv(t *testing.T, s, r Channel, msg string) { func testSendRecv(t *testing.T, s, r Channel, msg string) {
defer leaktest.Check(t)()
var wg sync.WaitGroup var wg sync.WaitGroup
var sendErr, recvErr error var sendErr, recvErr error
var data []byte var data []byte

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package channel package channel
import ( import (

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package channel package channel
import ( import (

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package channel package channel
import ( import (

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package jrpc2 package jrpc2
import ( import (
@ -21,10 +23,11 @@ type Client struct {
log func(string, ...interface{}) // write debug logs here log func(string, ...interface{}) // write debug logs here
enctx encoder enctx encoder
snote func(*jmessage) snote func(*jmessage)
scall func(*jmessage) []byte scall func(context.Context, *jmessage) []byte
chook func(*Client, *Response) chook func(*Client, *Response)
allow1 bool // tolerate v1 replies with no version marker cbctx context.Context // terminates when the client is closed
cbcancel func() // cancels cbctx
mu sync.Mutex // protects the fields below mu sync.Mutex // protects the fields below
ch channel.Channel // channel to the server ch channel.Channel // channel to the server
@ -35,15 +38,18 @@ type Client struct {
// NewClient returns a new client that communicates with the server via ch. // NewClient returns a new client that communicates with the server via ch.
func NewClient(ch channel.Channel, opts *ClientOptions) *Client { func NewClient(ch channel.Channel, opts *ClientOptions) *Client {
cbctx, cbcancel := context.WithCancel(context.Background())
c := &Client{ c := &Client{
done: new(sync.WaitGroup), done: new(sync.WaitGroup),
log: opts.logFunc(), log: opts.logFunc(),
allow1: opts.allowV1(),
enctx: opts.encodeContext(), enctx: opts.encodeContext(),
snote: opts.handleNotification(), snote: opts.handleNotification(),
scall: opts.handleCallback(), scall: opts.handleCallback(),
chook: opts.handleCancel(), chook: opts.handleCancel(),
cbctx: cbctx,
cbcancel: cbcancel,
// Lock-protected fields // Lock-protected fields
ch: ch, ch: ch,
pending: make(map[string]*Response), pending: make(map[string]*Response),
@ -99,7 +105,7 @@ func (c *Client) accept(ch receiver) error {
} }
// handleRequest handles a callback or notification from the server. The // handleRequest handles a callback or notification from the server. The
// caller must hold c.mu, and this blocks until the handler completes. // caller must hold c.mu. This function does not block for the handler.
// Precondition: msg is a request or notification, not a response or error. // Precondition: msg is a request or notification, not a response or error.
func (c *Client) handleRequest(msg *jmessage) { func (c *Client) handleRequest(msg *jmessage) {
if msg.isNotification() { if msg.isNotification() {
@ -113,10 +119,22 @@ func (c *Client) handleRequest(msg *jmessage) {
} else if c.ch == nil { } else if c.ch == nil {
c.log("Client channel is closed; discarding callback: %v", msg) c.log("Client channel is closed; discarding callback: %v", msg)
} else { } else {
bits := c.scall(msg) // Run the callback handler in its own goroutine. The context will be
if err := c.ch.Send(bits); err != nil { // cancelled automatically when the client is closed.
ctx := context.WithValue(c.cbctx, clientKey{}, c)
c.done.Add(1)
go func() {
defer c.done.Done()
bits := c.scall(ctx, msg)
c.mu.Lock()
defer c.mu.Unlock()
if c.err != nil {
c.log("Discarding callback response: %v", c.err)
} else if err := c.ch.Send(bits); err != nil {
c.log("Sending reply for callback %v failed: %v", msg, err) c.log("Sending reply for callback %v failed: %v", msg, err)
} }
}()
} }
} }
@ -365,7 +383,7 @@ func (c *Client) Notify(ctx context.Context, method string, params interface{})
return err return err
} }
// Close shuts down the client, abandoning any pending in-flight requests. // Close shuts down the client, terminating any pending in-flight requests.
func (c *Client) Close() error { func (c *Client) Close() error {
c.mu.Lock() c.mu.Lock()
c.stop(errClientStopped) c.stop(errClientStopped)
@ -392,20 +410,19 @@ func (c *Client) stop(err error) {
} }
c.ch.Close() c.ch.Close()
// Unblock and fail any pending callbacks.
c.cbcancel()
// Unblock and fail any pending requests. // Unblock and fail any pending requests.
for _, p := range c.pending { for _, p := range c.pending {
p.cancel() p.cancel()
} }
c.err = err c.err = err
c.ch = nil c.ch = nil
} }
func (c *Client) versionOK(v string) bool { func (c *Client) versionOK(v string) bool { return v == Version }
if v == "" {
return c.allow1
}
return v == Version
}
// marshalParams validates and marshals params to JSON for a request. The // marshalParams validates and marshals params to JSON for a request. The
// value of params must be either nil or encodable as a JSON object or array. // value of params must be either nil or encodable as a JSON object or array.
@ -417,7 +434,7 @@ func (c *Client) marshalParams(ctx context.Context, method string, params interf
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(pbits) == 0 || (pbits[0] != '[' && pbits[0] != '{' && !isNull(pbits)) { if fb := firstByte(pbits); fb != '[' && fb != '{' && !isNull(pbits) {
// JSON-RPC requires that if parameters are provided at all, they are // JSON-RPC requires that if parameters are provided at all, they are
// an array or an object. // an array or an object.
return nil, &Error{Code: code.InvalidRequest, Message: "invalid parameters: array or object required"} return nil, &Error{Code: code.InvalidRequest, Message: "invalid parameters: array or object required"}

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
// Package code defines error code values used by the jrpc2 package. // Package code defines error code values used by the jrpc2 package.
package code package code

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package code_test package code_test
import ( import (

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package jrpc2 package jrpc2
import ( import (
@ -32,3 +34,12 @@ type inboundRequestKey struct{}
func ServerFromContext(ctx context.Context) *Server { return ctx.Value(serverKey{}).(*Server) } func ServerFromContext(ctx context.Context) *Server { return ctx.Value(serverKey{}).(*Server) }
type serverKey struct{} type serverKey struct{}
// ClientFromContext returns the client associated with the given context.
// This will be populated on the context passed to callback handlers.
//
// A callback handler must not close the client, as the close will deadlock
// waiting for the callback to return.
func ClientFromContext(ctx context.Context) *Client { return ctx.Value(clientKey{}).(*Client) }
type clientKey struct{}

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
/* /*
Package jrpc2 implements a server and a client for the JSON-RPC 2.0 protocol Package jrpc2 implements a server and a client for the JSON-RPC 2.0 protocol
defined by http://www.jsonrpc.org/specification. defined by http://www.jsonrpc.org/specification.
@ -12,9 +14,20 @@ Handle method with this signature:
Handle(ctx Context.Context, req *jrpc2.Request) (interface{}, error) Handle(ctx Context.Context, req *jrpc2.Request) (interface{}, error)
A server finds the handler for a request by looking up its method name in a A server finds the handler for a request by looking up its method name in a
jrpc2.Assigner provided when the server is set up. jrpc2.Assigner provided when the server is set up. A Handler can decode the
request parameters using the UnmarshalParams method on the request:
For example, suppose we want to export this Add function via JSON-RPC: func (H) Handle(ctx context.Context, req *jrpc2.Request) (interface{}, error) {
var args ArgType
if err := req.UnmarshalParams(&args); err != nil {
return nil, err
}
return usefulStuffWith(args)
}
The handler package makes it easier to use functions that do not have this
exact type signature as handlers, by using reflection to lift functions into
the Handler interface. For example, suppose we want to export this Add function:
// Add returns the sum of a slice of integers. // Add returns the sum of a slice of integers.
func Add(ctx context.Context, values []int) int { func Add(ctx context.Context, values []int) int {
@ -25,9 +38,8 @@ For example, suppose we want to export this Add function via JSON-RPC:
return sum return sum
} }
The handler package helps adapt existing functions to the Handler interface. To convert Add to a jrpc2.Handler, call handler.New, which wraps its argument
To convert Add to a jrpc2.Handler, call handler.New, which uses reflection to into the jrpc2.Handler interface via the handler.Func type:
lift its argument into the jrpc2.Handler interface:
h := handler.New(Add) // h is now a jrpc2.Handler that calls Add h := handler.New(Add) // h is now a jrpc2.Handler that calls Add

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package jrpc2 package jrpc2
import ( import (
@ -44,12 +46,21 @@ var errClientStopped = errors.New("the client has been stopped")
// errEmptyMethod is the error reported for an empty request method name. // errEmptyMethod is the error reported for an empty request method name.
var errEmptyMethod = &Error{Code: code.InvalidRequest, Message: "empty method name"} var errEmptyMethod = &Error{Code: code.InvalidRequest, Message: "empty method name"}
// errNoSuchMethod is the error reported for an unknown method name.
var errNoSuchMethod = &Error{Code: code.MethodNotFound, Message: "no such method"}
// errDuplicateID is the error reported for a duplicated request ID.
var errDuplicateID = &Error{Code: code.InvalidRequest, Message: "duplicate request ID"}
// errInvalidRequest is the error reported for an invalid request object or batch. // errInvalidRequest is the error reported for an invalid request object or batch.
var errInvalidRequest = &Error{Code: code.ParseError, Message: "invalid request value"} var errInvalidRequest = &Error{Code: code.ParseError, Message: "invalid request value"}
// errEmptyBatch is the error reported for an empty request batch. // errEmptyBatch is the error reported for an empty request batch.
var errEmptyBatch = &Error{Code: code.InvalidRequest, Message: "empty request batch"} var errEmptyBatch = &Error{Code: code.InvalidRequest, Message: "empty request batch"}
// errInvalidParams is the error reported for invalid request parameters.
var errInvalidParams = &Error{Code: code.InvalidParams, Message: "invalid parameters"}
// ErrConnClosed is returned by a server's push-to-client methods if they are // ErrConnClosed is returned by a server's push-to-client methods if they are
// called after the client connection is closed. // called after the client connection is closed.
var ErrConnClosed = errors.New("client connection is closed") var ErrConnClosed = errors.New("client connection is closed")

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package jrpc2_test package jrpc2_test
import ( import (
@ -22,8 +24,7 @@ type Msg struct {
Text string `json:"msg"` Text string `json:"msg"`
} }
func startServer() server.Local { var local = server.NewLocal(handler.Map{
return server.NewLocal(handler.Map{
"Hello": handler.New(func(ctx context.Context) string { "Hello": handler.New(func(ctx context.Context) string {
return "Hello, world!" return "Hello, world!"
}), }),
@ -34,17 +35,12 @@ func startServer() server.Local {
fmt.Println("Log:", msg.Text) fmt.Println("Log:", msg.Text)
return true, nil return true, nil
}), }),
}, nil) }, nil)
}
func ExampleNewServer() { func ExampleNewServer() {
// Construct a new server with methods "Hello" and "Log".
loc := startServer()
defer loc.Close()
// We can query the server for its current status information, including a // We can query the server for its current status information, including a
// list of its methods. // list of its methods.
si := loc.Server.ServerInfo() si := local.Server.ServerInfo()
fmt.Println(strings.Join(si.Methods, "\n")) fmt.Println(strings.Join(si.Methods, "\n"))
// Output: // Output:
@ -54,10 +50,7 @@ func ExampleNewServer() {
} }
func ExampleClient_Call() { func ExampleClient_Call() {
loc := startServer() rsp, err := local.Client.Call(ctx, "Hello", nil)
defer loc.Close()
rsp, err := loc.Client.Call(ctx, "Hello", nil)
if err != nil { if err != nil {
log.Fatalf("Call: %v", err) log.Fatalf("Call: %v", err)
} }
@ -71,11 +64,8 @@ func ExampleClient_Call() {
} }
func ExampleClient_CallResult() { func ExampleClient_CallResult() {
loc := startServer()
defer loc.Close()
var msg string var msg string
if err := loc.Client.CallResult(ctx, "Hello", nil, &msg); err != nil { if err := local.Client.CallResult(ctx, "Hello", nil, &msg); err != nil {
log.Fatalf("CallResult: %v", err) log.Fatalf("CallResult: %v", err)
} }
fmt.Println(msg) fmt.Println(msg)
@ -84,10 +74,7 @@ func ExampleClient_CallResult() {
} }
func ExampleClient_Batch() { func ExampleClient_Batch() {
loc := startServer() rsps, err := local.Client.Batch(ctx, []jrpc2.Spec{
defer loc.Close()
rsps, err := loc.Client.Batch(ctx, []jrpc2.Spec{
{Method: "Hello"}, {Method: "Hello"},
{Method: "Log", Params: Msg{"Sing it!"}, Notify: true}, {Method: "Log", Params: Msg{"Sing it!"}, Notify: true},
}) })
@ -164,10 +151,7 @@ type strictParams struct {
func (strictParams) DisallowUnknownFields() {} func (strictParams) DisallowUnknownFields() {}
func ExampleResponse_UnmarshalResult() { func ExampleResponse_UnmarshalResult() {
loc := startServer() rsp, err := local.Client.Call(ctx, "Echo", []string{"alpha", "oscar", "kilo"})
defer loc.Close()
rsp, err := loc.Client.Call(ctx, "Echo", []string{"alpha", "oscar", "kilo"})
if err != nil { if err != nil {
log.Fatalf("Call: %v", err) log.Fatalf("Call: %v", err)
} }

View File

@ -1,7 +1,8 @@
module github.com/creachadair/jrpc2 module github.com/creachadair/jrpc2
require ( require (
github.com/google/go-cmp v0.5.6 github.com/fortytw2/leaktest v1.3.0
github.com/google/go-cmp v0.5.7
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
) )

View File

@ -1,5 +1,7 @@
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g=
github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o=
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package handler_test package handler_test
import ( import (
@ -67,15 +69,15 @@ func ExampleObj_unmarshal() {
// uid=501, name="P. T. Barnum" // uid=501, name="P. T. Barnum"
} }
func ExamplePositional() { func describe(_ context.Context, name string, age int, isOld bool) error {
fn := func(ctx context.Context, name string, age int, isOld bool) error { fmt.Printf("%s is %d (old: %v)\n", name, age, isOld)
fmt.Printf("%s is %d (is old: %v)\n", name, age, isOld)
return nil return nil
} }
call := handler.NewPos(fn, "name", "age", "isOld")
req, err := jrpc2.ParseRequests([]byte(` func ExamplePositional_object() {
{ call := handler.NewPos(describe, "name", "age", "isOld")
req := mustParseReq(`{
"jsonrpc": "2.0", "jsonrpc": "2.0",
"id": 1, "id": 1,
"method": "foo", "method": "foo",
@ -84,13 +86,36 @@ func ExamplePositional() {
"age": 37, "age": 37,
"isOld": false "isOld": false
} }
}`)) }`)
if err != nil { if _, err := call(context.Background(), req); err != nil {
log.Fatalf("Parse: %v", err)
}
if _, err := call(context.Background(), req[0]); err != nil {
log.Fatalf("Call: %v", err) log.Fatalf("Call: %v", err)
} }
// Output: // Output:
// Dennis is 37 (is old: false) // Dennis is 37 (old: false)
}
func ExamplePositional_array() {
call := handler.NewPos(describe, "name", "age", "isOld")
req := mustParseReq(`{
"jsonrpc": "2.0",
"id": 1,
"method": "foo",
"params": ["Marvin", 973000, true]
}`)
if _, err := call(context.Background(), req); err != nil {
log.Fatalf("Call: %v", err)
}
// Output:
// Marvin is 973000 (old: true)
}
func mustParseReq(s string) *jrpc2.Request {
reqs, err := jrpc2.ParseRequests([]byte(s))
if err != nil {
log.Fatalf("ParseRequests: %v", err)
} else if len(reqs) == 0 {
log.Fatal("ParseRequests: empty result")
}
return reqs[0]
} }

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
// Package handler provides implementations of the jrpc2.Assigner interface, // Package handler provides implementations of the jrpc2.Assigner interface,
// and support for adapting functions to the jrpc2.Handler interface. // and support for adapting functions to the jrpc2.Handler interface.
package handler package handler
@ -7,7 +9,6 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"reflect" "reflect"
"sort" "sort"
"strings" "strings"
@ -31,7 +32,7 @@ type Map map[string]jrpc2.Handler
// Assign implements part of the jrpc2.Assigner interface. // Assign implements part of the jrpc2.Assigner interface.
func (m Map) Assign(_ context.Context, method string) jrpc2.Handler { return m[method] } func (m Map) Assign(_ context.Context, method string) jrpc2.Handler { return m[method] }
// Names implements part of the jrpc2.Assigner interface. // Names implements the optional jrpc2.Namer extension interface.
func (m Map) Names() []string { func (m Map) Names() []string {
var names []string var names []string
for name := range m { for name := range m {
@ -64,7 +65,12 @@ func (m ServiceMap) Assign(ctx context.Context, method string) jrpc2.Handler {
func (m ServiceMap) Names() []string { func (m ServiceMap) Names() []string {
var all []string var all []string
for svc, assigner := range m { for svc, assigner := range m {
for _, name := range assigner.Names() { namer, ok := assigner.(jrpc2.Namer)
if !ok {
all = append(all, svc+".*")
continue
}
for _, name := range namer.Names() {
all = append(all, svc+"."+name) all = append(all, svc+"."+name)
} }
} }
@ -116,6 +122,7 @@ type FuncInfo struct {
Result reflect.Type // the non-error result type, or nil Result reflect.Type // the non-error result type, or nil
ReportsError bool // true if the function reports an error ReportsError bool // true if the function reports an error
strictFields bool // enforce strict field checking strictFields bool // enforce strict field checking
posNames []string // positional field names (requires strictFields)
fn interface{} // the original function value fn interface{} // the original function value
} }
@ -151,16 +158,20 @@ func (fi *FuncInfo) Wrap() Func {
} }
// If strict field checking is desired, ensure arguments are wrapped. // If strict field checking is desired, ensure arguments are wrapped.
arg := fi.Argument
wrapArg := func(v reflect.Value) interface{} { return v.Interface() } wrapArg := func(v reflect.Value) interface{} { return v.Interface() }
if fi.strictFields && !fi.Argument.Implements(strictType) { if fi.strictFields && arg != nil && !arg.Implements(strictType) {
wrapArg = func(v reflect.Value) interface{} { return &strict{v.Interface()} } names := fi.posNames
wrapArg = func(v reflect.Value) interface{} {
return &strict{v: v.Interface(), posNames: names}
}
} }
// Construct a function to unpack the parameters from the request message, // Construct a function to unpack the parameters from the request message,
// based on the signature of the user's callback. // based on the signature of the user's callback.
var newInput func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) var newInput func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error)
if fi.Argument == nil { if arg == nil {
// Case 1: The function does not want any request parameters. // Case 1: The function does not want any request parameters.
// Nothing needs to be decoded, but verify no parameters were passed. // Nothing needs to be decoded, but verify no parameters were passed.
newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) { newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) {
@ -170,16 +181,16 @@ func (fi *FuncInfo) Wrap() Func {
return []reflect.Value{ctx}, nil return []reflect.Value{ctx}, nil
} }
} else if fi.Argument == reqType { } else if arg == reqType {
// Case 2: The function wants the underlying *jrpc2.Request value. // Case 2: The function wants the underlying *jrpc2.Request value.
newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) { newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) {
return []reflect.Value{ctx, reflect.ValueOf(req)}, nil return []reflect.Value{ctx, reflect.ValueOf(req)}, nil
} }
} else if fi.Argument.Kind() == reflect.Ptr { } else if arg.Kind() == reflect.Ptr {
// Case 3a: The function wants a pointer to its argument value. // Case 3a: The function wants a pointer to its argument value.
newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) { newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) {
in := reflect.New(fi.Argument.Elem()) in := reflect.New(arg.Elem())
if err := req.UnmarshalParams(wrapArg(in)); err != nil { if err := req.UnmarshalParams(wrapArg(in)); err != nil {
return nil, jrpc2.Errorf(code.InvalidParams, "invalid parameters: %v", err) return nil, jrpc2.Errorf(code.InvalidParams, "invalid parameters: %v", err)
} }
@ -188,7 +199,7 @@ func (fi *FuncInfo) Wrap() Func {
} else { } else {
// Case 3b: The function wants a bare argument value. // Case 3b: The function wants a bare argument value.
newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) { newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) {
in := reflect.New(fi.Argument) // we still need a pointer to unmarshal in := reflect.New(arg) // we still need a pointer to unmarshal
if err := req.UnmarshalParams(wrapArg(in)); err != nil { if err := req.UnmarshalParams(wrapArg(in)); err != nil {
return nil, jrpc2.Errorf(code.InvalidParams, "invalid parameters: %v", err) return nil, jrpc2.Errorf(code.InvalidParams, "invalid parameters: %v", err)
} }
@ -254,16 +265,17 @@ func (fi *FuncInfo) Wrap() Func {
// Note that the JSON-RPC standard restricts encoded parameter values to arrays // Note that the JSON-RPC standard restricts encoded parameter values to arrays
// and objects. Check will accept argument types that do not encode to arrays // and objects. Check will accept argument types that do not encode to arrays
// or objects, but the wrapper will report an error when decoding the request. // or objects, but the wrapper will report an error when decoding the request.
//
// The recommended solution is to define a struct type for your parameters. // The recommended solution is to define a struct type for your parameters.
// For arbitrary single value types, however, another approach is to wrap it in //
// a 1-element array, for example: // For a single arbitrary type, another approach is to use a 1-element array:
// //
// func(ctx context.Context, sp [1]string) error { // func(ctx context.Context, sp [1]string) error {
// s := sp[0] // pull the actual argument out of the array // s := sp[0] // pull the actual argument out of the array
// // ... // // ...
// } // }
// //
// For more complex positional signatures, see also handler.Positional.
//
func Check(fn interface{}) (*FuncInfo, error) { func Check(fn interface{}) (*FuncInfo, error) {
if fn == nil { if fn == nil {
return nil, errors.New("nil function") return nil, errors.New("nil function")
@ -299,97 +311,48 @@ func Check(fn interface{}) (*FuncInfo, error) {
return info, nil return info, nil
} }
// Args is a wrapper that decodes an array of positional parameters into
// concrete locations.
//
// Unmarshaling a JSON value into an Args value v succeeds if the JSON encodes
// an array with length len(v), and unmarshaling each subvalue i into the
// corresponding v[i] succeeds. As a special case, if v[i] == nil the
// corresponding value is discarded.
//
// Marshaling an Args value v into JSON succeeds if each element of the slice
// is JSON marshalable, and yields a JSON array of length len(v) containing the
// JSON values corresponding to the elements of v.
//
// Usage example:
//
// func Handler(ctx context.Context, req *jrpc2.Request) (interface{}, error) {
// var x, y int
// var s string
//
// if err := req.UnmarshalParams(&handler.Args{&x, &y, &s}); err != nil {
// return nil, err
// }
// // do useful work with x, y, and s
// }
//
type Args []interface{}
// UnmarshalJSON supports JSON unmarshaling for a.
func (a Args) UnmarshalJSON(data []byte) error {
var elts []json.RawMessage
if err := json.Unmarshal(data, &elts); err != nil {
return filterJSONError("args", "array", err)
} else if len(elts) != len(a) {
return fmt.Errorf("wrong number of args (got %d, want %d)", len(elts), len(a))
}
for i, elt := range elts {
if a[i] == nil {
continue
} else if err := json.Unmarshal(elt, a[i]); err != nil {
return fmt.Errorf("decoding argument %d: %w", i+1, err)
}
}
return nil
}
// MarshalJSON supports JSON marshaling for a.
func (a Args) MarshalJSON() ([]byte, error) {
if len(a) == 0 {
return []byte(`[]`), nil
}
return json.Marshal([]interface{}(a))
}
// Obj is a wrapper that maps object fields into concrete locations.
//
// Unmarshaling a JSON text into an Obj value v succeeds if the JSON encodes an
// object, and unmarshaling the value for each key k of the object into v[k]
// succeeds. If k does not exist in v, it is ignored.
//
// Marshaling an Obj into JSON works as for an ordinary map.
type Obj map[string]interface{}
// UnmarshalJSON supports JSON unmarshaling into o.
func (o Obj) UnmarshalJSON(data []byte) error {
var base map[string]json.RawMessage
if err := json.Unmarshal(data, &base); err != nil {
return filterJSONError("decoding", "object", err)
}
for key, arg := range o {
val, ok := base[key]
if !ok {
continue
} else if err := json.Unmarshal(val, arg); err != nil {
return fmt.Errorf("decoding %q: %v", key, err)
}
}
return nil
}
func filterJSONError(tag, want string, err error) error {
if t, ok := err.(*json.UnmarshalTypeError); ok {
return fmt.Errorf("%s: cannot decode %s as %s", tag, t.Value, want)
}
return err
}
// strict is a wrapper for an arbitrary value that enforces strict field // strict is a wrapper for an arbitrary value that enforces strict field
// checking when unmarshaling from JSON. // checking when unmarshaling from JSON, and handles translation of array
type strict struct{ v interface{} } // format into object format.
type strict struct {
v interface{}
posNames []string
}
// translate translates the raw JSON data into the correct format for
// unmarshaling into s.v.
//
// If s.posNames is set and data encodes an array, the array is rewritten to an
// equivalent object with field names assigned by the positional names.
// Otherwise, data is returned as-is without error.
func (s *strict) translate(data []byte) ([]byte, error) {
if len(s.posNames) == 0 || firstByte(data) != '[' {
return data, nil // no names, or not an array
}
// Decode the array wrapper and verify it has the correct length.
var arr []json.RawMessage
if err := json.Unmarshal(data, &arr); err != nil {
return nil, err
} else if len(arr) != len(s.posNames) {
return nil, jrpc2.Errorf(code.InvalidParams, "got %d parameters, want %d",
len(arr), len(s.posNames))
}
// Rewrite the array into an object.
obj := make(map[string]json.RawMessage, len(s.posNames))
for i, name := range s.posNames {
obj[name] = arr[i]
}
return json.Marshal(obj)
}
func (s *strict) UnmarshalJSON(data []byte) error { func (s *strict) UnmarshalJSON(data []byte) error {
dec := json.NewDecoder(bytes.NewReader(data)) actual, err := s.translate(data)
if err != nil {
return err
}
dec := json.NewDecoder(bytes.NewReader(actual))
dec.DisallowUnknownFields() dec.DisallowUnknownFields()
return dec.Decode(s.v) return dec.Decode(s.v)
} }

View File

@ -1,9 +1,13 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package handler_test package handler_test
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"strconv"
"testing" "testing"
"github.com/creachadair/jrpc2" "github.com/creachadair/jrpc2"
@ -69,6 +73,49 @@ func TestCheck(t *testing.T) {
} }
} }
// Verify that the wrappers constructed by FuncInfo.Wrap can properly decode
// their arguments of different types and structure.
func TestFuncInfo_wrapDecode(t *testing.T) {
tests := []struct {
fn handler.Func
p string
want interface{}
}{
// A positional handler should decode its argument from an array or an object.
{handler.NewPos(func(_ context.Context, z int) int { return z }, "arg"),
`[25]`, 25},
{handler.NewPos(func(_ context.Context, z int) int { return z }, "arg"),
`{"arg":109}`, 109},
// A type with custom marshaling should be properly handled.
{handler.NewPos(func(_ context.Context, b stringByte) byte { return byte(b) }, "arg"),
`["00111010"]`, byte(0x3a)},
{handler.NewPos(func(_ context.Context, b stringByte) byte { return byte(b) }, "arg"),
`{"arg":"10011100"}`, byte(0x9c)},
{handler.New(func(_ context.Context, v fauxStruct) int { return int(v) }),
`{"type":"thing","value":99}`, 99},
// Plain JSON should get its argument unmodified.
{handler.New(func(_ context.Context, v json.RawMessage) string { return string(v) }),
`{"x": true, "y": null}`, `{"x": true, "y": null}`},
// Npn-positional slice argument.
{handler.New(func(_ context.Context, ss []string) int { return len(ss) }),
`["a", "b", "c"]`, 3},
}
ctx := context.Background()
for _, test := range tests {
req := mustParseRequest(t,
fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"x","params":%s}`, test.p))
got, err := test.fn(ctx, req)
if err != nil {
t.Errorf("Call %v failed: %v", test.fn, err)
} else if diff := cmp.Diff(test.want, got); diff != "" {
t.Errorf("Call %v: wrong result (-want, +got)\n%s", test.fn, diff)
}
}
}
// Verify that the Positional function correctly handles its cases. // Verify that the Positional function correctly handles its cases.
func TestPositional(t *testing.T) { func TestPositional(t *testing.T) {
tests := []struct { tests := []struct {
@ -132,6 +179,17 @@ func TestNewStrict(t *testing.T) {
} }
} }
// Verify that a handler with no argument type does not panic attempting to
// enforce strict field checking.
func TestNewStrict_argumentRegression(t *testing.T) {
defer func() {
if x := recover(); x != nil {
t.Fatalf("NewStrict panic: %v", x)
}
}()
handler.NewStrict(func(context.Context) error { return nil })
}
// Verify that the handling of pointer-typed arguments does not incorrectly // Verify that the handling of pointer-typed arguments does not incorrectly
// introduce another pointer indirection. // introduce another pointer indirection.
func TestNew_pointerRegression(t *testing.T) { func TestNew_pointerRegression(t *testing.T) {
@ -173,14 +231,18 @@ func TestPositional_decode(t *testing.T) {
bad bool bad bool
}{ }{
{`{"jsonrpc":"2.0","id":1,"method":"add","params":{"first":5,"second":3}}`, 8, false}, {`{"jsonrpc":"2.0","id":1,"method":"add","params":{"first":5,"second":3}}`, 8, false},
{`{"jsonrpc":"2.0","id":2,"method":"add","params":{"first":5}}`, 5, false}, {`{"jsonrpc":"2.0","id":2,"method":"add","params":[5,3]}`, 8, false},
{`{"jsonrpc":"2.0","id":3,"method":"add","params":{"second":3}}`, 3, false}, {`{"jsonrpc":"2.0","id":3,"method":"add","params":{"first":5}}`, 5, false},
{`{"jsonrpc":"2.0","id":4,"method":"add","params":{}}`, 0, false}, {`{"jsonrpc":"2.0","id":4,"method":"add","params":{"second":3}}`, 3, false},
{`{"jsonrpc":"2.0","id":5,"method":"add","params":null}`, 0, false}, {`{"jsonrpc":"2.0","id":5,"method":"add","params":{}}`, 0, false},
{`{"jsonrpc":"2.0","id":6,"method":"add"}`, 0, false}, {`{"jsonrpc":"2.0","id":6,"method":"add","params":null}`, 0, false},
{`{"jsonrpc":"2.0","id":7,"method":"add"}`, 0, false},
{`{"jsonrpc":"2.0","id":6,"method":"add","params":["wrong", "type"]}`, 0, true}, {`{"jsonrpc":"2.0","id":10,"method":"add","params":["wrong", "type"]}`, 0, true},
{`{"jsonrpc":"2.0","id":6,"method":"add","params":{"unknown":"field"}}`, 0, true}, {`{"jsonrpc":"2.0","id":12,"method":"add","params":[15, "wrong-type"]}`, 0, true},
{`{"jsonrpc":"2.0","id":13,"method":"add","params":{"unknown":"field"}}`, 0, true},
{`{"jsonrpc":"2.0","id":14,"method":"add","params":[1]}`, 0, true}, // too few
{`{"jsonrpc":"2.0","id":15,"method":"add","params":[1,2,3]}`, 0, true}, // too many
} }
for _, test := range tests { for _, test := range tests {
req := mustParseRequest(t, test.input) req := mustParseRequest(t, test.input)
@ -386,3 +448,38 @@ func mustParseRequest(t *testing.T, text string) *jrpc2.Request {
} }
return req[0] return req[0]
} }
// stringByte is a byte with a custom JSON encoding. It expects a string of
// decimal digits 1 and 0, e.g., "10011000" == 0x98.
type stringByte byte
func (s *stringByte) UnmarshalText(text []byte) error {
v, err := strconv.ParseUint(string(text), 2, 8)
if err != nil {
return err
}
*s = stringByte(v)
return nil
}
// fauxStruct is an integer with a custom JSON encoding. It expects an object:
//
// {"type":"thing","value":<integer>}
//
type fauxStruct int
func (s *fauxStruct) UnmarshalJSON(data []byte) error {
var tmp struct {
T string `json:"type"`
V *int `json:"value"`
}
if err := json.Unmarshal(data, &tmp); err != nil {
return err
} else if tmp.T != "thing" {
return fmt.Errorf("unknown type %q", tmp.T)
} else if tmp.V == nil {
return errors.New("missing value")
}
*s = fauxStruct(*tmp.V)
return nil
}

103
vendor/github.com/creachadair/jrpc2/handler/helpers.go generated vendored Normal file
View File

@ -0,0 +1,103 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package handler
import (
"bytes"
"encoding/json"
"fmt"
)
// Args is a wrapper that decodes an array of positional parameters into
// concrete locations.
//
// Unmarshaling a JSON value into an Args value v succeeds if the JSON encodes
// an array with length len(v), and unmarshaling each subvalue i into the
// corresponding v[i] succeeds. As a special case, if v[i] == nil the
// corresponding value is discarded.
//
// Marshaling an Args value v into JSON succeeds if each element of the slice
// is JSON marshalable, and yields a JSON array of length len(v) containing the
// JSON values corresponding to the elements of v.
//
// Usage example:
//
// func Handler(ctx context.Context, req *jrpc2.Request) (interface{}, error) {
// var x, y int
// var s string
//
// if err := req.UnmarshalParams(&handler.Args{&x, &y, &s}); err != nil {
// return nil, err
// }
// // do useful work with x, y, and s
// }
//
type Args []interface{}
// UnmarshalJSON supports JSON unmarshaling for a.
func (a Args) UnmarshalJSON(data []byte) error {
var elts []json.RawMessage
if err := json.Unmarshal(data, &elts); err != nil {
return filterJSONError("args", "array", err)
} else if len(elts) != len(a) {
return fmt.Errorf("wrong number of args (got %d, want %d)", len(elts), len(a))
}
for i, elt := range elts {
if a[i] == nil {
continue
} else if err := json.Unmarshal(elt, a[i]); err != nil {
return fmt.Errorf("decoding argument %d: %w", i+1, err)
}
}
return nil
}
// MarshalJSON supports JSON marshaling for a.
func (a Args) MarshalJSON() ([]byte, error) {
if len(a) == 0 {
return []byte(`[]`), nil
}
return json.Marshal([]interface{}(a))
}
// Obj is a wrapper that maps object fields into concrete locations.
//
// Unmarshaling a JSON text into an Obj value v succeeds if the JSON encodes an
// object, and unmarshaling the value for each key k of the object into v[k]
// succeeds. If k does not exist in v, it is ignored.
//
// Marshaling an Obj into JSON works as for an ordinary map.
type Obj map[string]interface{}
// UnmarshalJSON supports JSON unmarshaling into o.
func (o Obj) UnmarshalJSON(data []byte) error {
var base map[string]json.RawMessage
if err := json.Unmarshal(data, &base); err != nil {
return filterJSONError("decoding", "object", err)
}
for key, arg := range o {
val, ok := base[key]
if !ok {
continue
} else if err := json.Unmarshal(val, arg); err != nil {
return fmt.Errorf("decoding %q: %v", key, err)
}
}
return nil
}
func filterJSONError(tag, want string, err error) error {
if t, ok := err.(*json.UnmarshalTypeError); ok {
return fmt.Errorf("%s: cannot decode %s as %s", tag, t.Value, want)
}
return err
}
// firstByte returns the first non-whitespace byte of data, or 0 if there is none.
func firstByte(data []byte) byte {
clean := bytes.TrimSpace(data)
if len(clean) == 0 {
return 0
}
return clean[0]
}

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package handler package handler
import ( import (
@ -26,9 +28,9 @@ func NewPos(fn interface{}, names ...string) Func {
// value of fn must be a function with one of the following type signature // value of fn must be a function with one of the following type signature
// schemes: // schemes:
// //
// func(context.Context, X1, x2, ..., Xn) (Y, error) // func(context.Context, X1, X2, ..., Xn) (Y, error)
// func(context.Context, X1, x2, ..., Xn) Y // func(context.Context, X1, X2, ..., Xn) Y
// func(context.Context, X1, x2, ..., Xn) error // func(context.Context, X1, X2, ..., Xn) error
// //
// For JSON-marshalable types X_i and Y. If fn does not have one of these // For JSON-marshalable types X_i and Y. If fn does not have one of these
// forms, Positional reports an error. The given names must match the number of // forms, Positional reports an error. The given names must match the number of
@ -56,6 +58,13 @@ func NewPos(fn interface{}, names ...string) Func {
// field keys generate an error. The field names are not required to match the // field keys generate an error. The field names are not required to match the
// parameter names declared by the function; it is the names assigned here that // parameter names declared by the function; it is the names assigned here that
// determine which object keys are accepted. // determine which object keys are accepted.
//
// The wrapped function will also accept a JSON array with with (exactly) the
// same number of elements as the positional parameters:
//
// [17, 23]
//
// Unlike the object format, no arguments can be omitted in this format.
func Positional(fn interface{}, names ...string) (*FuncInfo, error) { func Positional(fn interface{}, names ...string) (*FuncInfo, error) {
if fn == nil { if fn == nil {
return nil, errors.New("nil function") return nil, errors.New("nil function")
@ -85,6 +94,7 @@ func Positional(fn interface{}, names ...string) (*FuncInfo, error) {
fi, err := Check(makeCaller(ft, fv, atype)) fi, err := Check(makeCaller(ft, fv, atype))
if err == nil { if err == nil {
fi.strictFields = true fi.strictFields = true
fi.posNames = names
} }
return fi, err return fi, err
} }

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package jrpc2 package jrpc2
// This file contains tests that need to inspect the internal details of the // This file contains tests that need to inspect the internal details of the
@ -12,6 +14,7 @@ import (
"github.com/creachadair/jrpc2/channel" "github.com/creachadair/jrpc2/channel"
"github.com/creachadair/jrpc2/code" "github.com/creachadair/jrpc2/code"
"github.com/fortytw2/leaktest"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
) )
@ -150,6 +153,8 @@ func (h hmap) Names() []string { return nil }
// Verify that if the client context terminates during a request, the client // Verify that if the client context terminates during a request, the client
// will terminate and report failure. // will terminate and report failure.
func TestClient_contextCancellation(t *testing.T) { func TestClient_contextCancellation(t *testing.T) {
defer leaktest.Check(t)()
started := make(chan struct{}) started := make(chan struct{})
stopped := make(chan struct{}) stopped := make(chan struct{})
cpipe, spipe := channel.Direct() cpipe, spipe := channel.Direct()
@ -203,6 +208,8 @@ func TestClient_contextCancellation(t *testing.T) {
} }
func TestServer_specialMethods(t *testing.T) { func TestServer_specialMethods(t *testing.T) {
defer leaktest.Check(t)()
s := NewServer(hmap{ s := NewServer(hmap{
"rpc.nonesuch": methodFunc(func(context.Context, *Request) (interface{}, error) { "rpc.nonesuch": methodFunc(func(context.Context, *Request) (interface{}, error) {
return "OK", nil return "OK", nil
@ -225,6 +232,8 @@ func TestServer_specialMethods(t *testing.T) {
// Verify that the option to remove the special behaviour of rpc.* methods can // Verify that the option to remove the special behaviour of rpc.* methods can
// be correctly disabled by the server options. // be correctly disabled by the server options.
func TestServer_disableBuiltinHook(t *testing.T) { func TestServer_disableBuiltinHook(t *testing.T) {
defer leaktest.Check(t)()
s := NewServer(hmap{ s := NewServer(hmap{
"rpc.nonesuch": methodFunc(func(context.Context, *Request) (interface{}, error) { "rpc.nonesuch": methodFunc(func(context.Context, *Request) (interface{}, error) {
return "OK", nil return "OK", nil
@ -249,6 +258,8 @@ func TestServer_disableBuiltinHook(t *testing.T) {
// request. The Client never sends requests like that, but the server needs to // request. The Client never sends requests like that, but the server needs to
// cope with it correctly. // cope with it correctly.
func TestBatchReply(t *testing.T) { func TestBatchReply(t *testing.T) {
defer leaktest.Check(t)()
cpipe, spipe := channel.Direct() cpipe, spipe := channel.Direct()
srv := NewServer(hmap{ srv := NewServer(hmap{
"test": methodFunc(func(_ context.Context, req *Request) (interface{}, error) { "test": methodFunc(func(_ context.Context, req *Request) (interface{}, error) {

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package jctx_test package jctx_test
import ( import (

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
// Package jctx implements an encoder and decoder for request context values, // Package jctx implements an encoder and decoder for request context values,
// allowing context metadata to be propagated through JSON-RPC. // allowing context metadata to be propagated through JSON-RPC.
// //
@ -66,7 +68,7 @@ func Encode(ctx context.Context, method string, params json.RawMessage) (json.Ra
v := wireVersion v := wireVersion
c := wireContext{V: &v, Payload: params} c := wireContext{V: &v, Payload: params}
if dl, ok := ctx.Deadline(); ok { if dl, ok := ctx.Deadline(); ok {
utcdl := dl.In(time.UTC) utcdl := dl.UTC()
c.Deadline = &utcdl c.Deadline = &utcdl
} }
@ -102,7 +104,7 @@ func Decode(ctx context.Context, method string, req json.RawMessage) (context.Co
} }
if c.Deadline != nil && !c.Deadline.IsZero() { if c.Deadline != nil && !c.Deadline.IsZero() {
var ignored context.CancelFunc var ignored context.CancelFunc
ctx, ignored = context.WithDeadline(ctx, (*c.Deadline).In(time.UTC)) ctx, ignored = context.WithDeadline(ctx, (*c.Deadline).UTC())
_ = ignored // the caller cannot use this value _ = ignored // the caller cannot use this value
} }

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package jctx package jctx
import ( import (

View File

@ -1,15 +1,15 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
// Package jhttp implements a bridge from HTTP to JSON-RPC. This permits // Package jhttp implements a bridge from HTTP to JSON-RPC. This permits
// requests to be submitted to a JSON-RPC server using HTTP as a transport. // requests to be submitted to a JSON-RPC server using HTTP as a transport.
package jhttp package jhttp
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"strconv"
"github.com/creachadair/jrpc2" "github.com/creachadair/jrpc2"
"github.com/creachadair/jrpc2/server" "github.com/creachadair/jrpc2/server"
@ -17,36 +17,53 @@ import (
// A Bridge is a http.Handler that bridges requests to a JSON-RPC server. // A Bridge is a http.Handler that bridges requests to a JSON-RPC server.
// //
// The body of the HTTP POST request must contain the complete JSON-RPC request // By default, the bridge accepts only HTTP POST requests with the complete
// message, encoded with Content-Type: application/json. Either a single // JSON-RPC request message in the body, with Content-Type application/json.
// request object or a list of request objects is supported. // Either a single request object or a list of request objects is supported.
//
// If the request completes, whether or not there is an error, the HTTP
// response is 200 (OK) for ordinary requests or 204 (No Response) for
// notifications, and the response body contains the JSON-RPC response.
// //
// If the HTTP request method is not "POST", the bridge reports 405 (Method Not // If the HTTP request method is not "POST", the bridge reports 405 (Method Not
// Allowed). If the Content-Type is not application/json, the bridge reports // Allowed). If the Content-Type is not application/json, the bridge reports
// 415 (Unsupported Media Type). // 415 (Unsupported Media Type).
// //
// If a ParseRequest hook is set, these requirements are disabled, and the hook
// is entirely responsible for checking request structure.
//
// If a ParseGETRequest hook is set, HTTP "GET" requests are handled by a
// Getter using that hook; otherwise "GET" requests are handled as above.
//
// If the request completes, whether or not there is an error, the HTTP
// response is 200 (OK) for ordinary requests or 204 (No Response) for
// notifications, and the response body contains the JSON-RPC response.
//
// The bridge attaches the inbound HTTP request to the context passed to the // The bridge attaches the inbound HTTP request to the context passed to the
// client, allowing an EncodeContext callback to retrieve state from the HTTP // client, allowing an EncodeContext callback to retrieve state from the HTTP
// headers. Use jhttp.HTTPRequest to retrieve the request from the context. // headers. Use jhttp.HTTPRequest to retrieve the request from the context.
type Bridge struct { type Bridge struct {
local server.Local local server.Local
checkType func(string) bool parseReq func(*http.Request) ([]*jrpc2.Request, error)
getter *Getter
} }
// ServeHTTP implements the required method of http.Handler. // ServeHTTP implements the required method of http.Handler.
func (b Bridge) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (b Bridge) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// If a GET hook is defined, allow GET requests.
if req.Method == "GET" && b.getter != nil {
b.getter.ServeHTTP(w, req)
return
}
// If no parse hook is defined, insist that the method is POST and the
// content-type is application/json. Setting a hook disables these checks.
if b.parseReq == nil {
if req.Method != "POST" { if req.Method != "POST" {
w.WriteHeader(http.StatusMethodNotAllowed) w.WriteHeader(http.StatusMethodNotAllowed)
return return
} }
if !b.checkType(req.Header.Get("Content-Type")) { if req.Header.Get("Content-Type") != "application/json" {
w.WriteHeader(http.StatusUnsupportedMediaType) w.WriteHeader(http.StatusUnsupportedMediaType)
return return
} }
}
if err := b.serveInternal(w, req); err != nil { if err := b.serveInternal(w, req); err != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintln(w, err.Error()) fmt.Fprintln(w, err.Error())
@ -54,11 +71,6 @@ func (b Bridge) ServeHTTP(w http.ResponseWriter, req *http.Request) {
} }
func (b Bridge) serveInternal(w http.ResponseWriter, req *http.Request) error { func (b Bridge) serveInternal(w http.ResponseWriter, req *http.Request) error {
body, err := io.ReadAll(req.Body)
if err != nil {
return err
}
// The HTTP request requires a response, but the server will not reply if // The HTTP request requires a response, but the server will not reply if
// all the requests are notifications. Check whether we have any calls // all the requests are notifications. Check whether we have any calls
// needing a response, and choose whether to wait for a reply based on that. // needing a response, and choose whether to wait for a reply based on that.
@ -66,7 +78,7 @@ func (b Bridge) serveInternal(w http.ResponseWriter, req *http.Request) error {
// Note that we are forgiving about a missing version marker in a request, // Note that we are forgiving about a missing version marker in a request,
// since we can't tell at this point whether the server is willing to accept // since we can't tell at this point whether the server is willing to accept
// messages like that. // messages like that.
jreq, err := jrpc2.ParseRequests(body) jreq, err := b.parseHTTPRequest(req)
if err != nil && err != jrpc2.ErrInvalidVersion { if err != nil && err != jrpc2.ErrInvalidVersion {
return err return err
} }
@ -118,26 +130,44 @@ func (b Bridge) serveInternal(w http.ResponseWriter, req *http.Request) error {
rsp.SetID(inboundID[i]) rsp.SetID(inboundID[i])
} }
// If the original request was a single message, make sure we encode the return b.encodeResponses(rsps, w)
// response the same way. }
var reply []byte
if len(rsps) == 1 && !bytes.HasPrefix(bytes.TrimSpace(body), []byte("[")) { func (b Bridge) parseHTTPRequest(req *http.Request) ([]*jrpc2.Request, error) {
reply, err = json.Marshal(rsps[0]) if b.parseReq != nil {
} else { return b.parseReq(req)
reply, err = json.Marshal(rsps)
} }
body, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
return jrpc2.ParseRequests(body)
}
func (b Bridge) encodeResponses(rsps []*jrpc2.Response, w http.ResponseWriter) error {
// If there is only a single reply, send it alone; otherwise encode a batch.
// Per the spec (https://www.jsonrpc.org/specification#batch), this is OK;
// we are not required to respond to a batch with an array:
//
// The Server SHOULD respond with an Array containing the corresponding
// Response objects
//
data, err := marshalResponses(rsps)
if err != nil { if err != nil {
return err return err
} }
w.Header().Set("Content-Type", "application/json") writeJSON(w, http.StatusOK, json.RawMessage(data))
w.Header().Set("Content-Length", strconv.Itoa(len(reply)))
w.Write(reply)
return nil return nil
} }
// Close closes the channel to the server, waits for the server to exit, and // Close closes the channel to the server, waits for the server to exit, and
// reports its exit status. // reports its exit status.
func (b Bridge) Close() error { return b.local.Close() } func (b Bridge) Close() error {
if b.getter != nil {
b.getter.Close()
}
return b.local.Close()
}
// NewBridge constructs a new Bridge that starts a server on mux and dispatches // NewBridge constructs a new Bridge that starts a server on mux and dispatches
// HTTP requests to it. The server will run until the bridge is closed. // HTTP requests to it. The server will run until the bridge is closed.
@ -149,13 +179,22 @@ func (b Bridge) Close() error { return b.local.Close() }
// hooks on the bridge client as usual, but the remote client will not see push // hooks on the bridge client as usual, but the remote client will not see push
// messages from the server. // messages from the server.
func NewBridge(mux jrpc2.Assigner, opts *BridgeOptions) Bridge { func NewBridge(mux jrpc2.Assigner, opts *BridgeOptions) Bridge {
return Bridge{ b := Bridge{
local: server.NewLocal(mux, &server.LocalOptions{ local: server.NewLocal(mux, &server.LocalOptions{
Client: opts.clientOptions(), Client: opts.clientOptions(),
Server: opts.serverOptions(), Server: opts.serverOptions(),
}), }),
checkType: opts.checkContentType(), parseReq: opts.parseRequest(),
} }
if pget := opts.parseGETRequest(); pget != nil {
g := NewGetter(mux, &GetterOptions{
Client: opts.clientOptions(),
Server: opts.serverOptions(),
ParseRequest: pget,
})
b.getter = &g
}
return b
} }
// BridgeOptions are optional settings for a Bridge. A nil pointer is ready for // BridgeOptions are optional settings for a Bridge. A nil pointer is ready for
@ -167,11 +206,22 @@ type BridgeOptions struct {
// Options for the bridge server (default nil). // Options for the bridge server (default nil).
Server *jrpc2.ServerOptions Server *jrpc2.ServerOptions
// If non-nil, this function is called to check whether the HTTP request's // If non-nil, this function is called to parse JSON-RPC requests from the
// declared content-type is valid. If this function returns false, the // HTTP request body. If this function reports an error, the request fails.
// request is rejected. If nil, the default check requires a content type of // By default, the bridge uses jrpc2.ParseRequests on the HTTP request body.
// "application/json". //
CheckContentType func(contentType string) bool // Setting this hook disables the default requirement that the request
// method be POST and the content-type be application/json.
ParseRequest func(*http.Request) ([]*jrpc2.Request, error)
// If non-nil, this function is used to parse a JSON-RPC method name and
// parameters from the URL of an HTTP GET request. If this function reports
// an error, the request fails.
//
// If this hook is set, all GET requests are handled by a Getter using this
// parse function, and are not passed to a ParseRequest hook even if one is
// defined.
ParseGETRequest func(*http.Request) (string, interface{}, error)
} }
func (o *BridgeOptions) clientOptions() *jrpc2.ClientOptions { func (o *BridgeOptions) clientOptions() *jrpc2.ClientOptions {
@ -188,11 +238,18 @@ func (o *BridgeOptions) serverOptions() *jrpc2.ServerOptions {
return o.Server return o.Server
} }
func (o *BridgeOptions) checkContentType() func(string) bool { func (o *BridgeOptions) parseRequest() func(*http.Request) ([]*jrpc2.Request, error) {
if o == nil || o.CheckContentType == nil { if o == nil {
return func(ctype string) bool { return ctype == "application/json" } return nil
} }
return o.CheckContentType return o.ParseRequest
}
func (o *BridgeOptions) parseGETRequest() func(*http.Request) (string, interface{}, error) {
if o == nil {
return nil
}
return o.ParseGETRequest
} }
type httpReqKey struct{} type httpReqKey struct{}
@ -206,3 +263,11 @@ func HTTPRequest(ctx context.Context) *http.Request {
} }
return nil return nil
} }
// marshalResponses encodes a batch of JSON-RPC responses into JSON.
func marshalResponses(rsps []*jrpc2.Response) ([]byte, error) {
if len(rsps) == 1 {
return json.Marshal(rsps[0])
}
return json.Marshal(rsps)
}

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package jhttp package jhttp
import ( import (

View File

@ -1,46 +1,45 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package jhttp_test package jhttp_test
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"log" "log"
"net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"github.com/creachadair/jrpc2"
"github.com/creachadair/jrpc2/handler" "github.com/creachadair/jrpc2/handler"
"github.com/creachadair/jrpc2/jhttp" "github.com/creachadair/jrpc2/jhttp"
) )
func Example() { func Example() {
// Set up a bridge to demonstrate the API. // Set up a bridge exporting a simple service.
b := jhttp.NewBridge(handler.Map{ b := jhttp.NewBridge(handler.Map{
"Test": handler.New(func(ctx context.Context, ss []string) (string, error) { "Test": handler.New(func(ctx context.Context, ss []string) string {
return strings.Join(ss, " "), nil return strings.Join(ss, " ")
}), }),
}, nil) }, nil)
defer b.Close() defer b.Close()
// The bridge can be used as the handler for an HTTP server.
hsrv := httptest.NewServer(b) hsrv := httptest.NewServer(b)
defer hsrv.Close() defer hsrv.Close()
rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`{ // Set up a client using an HTTP channel, and use it to call the test
"jsonrpc": "2.0", // service exported by the bridge.
"id": 10235, ch := jhttp.NewChannel(hsrv.URL, nil)
"method": "Test", cli := jrpc2.NewClient(ch, nil)
"params": ["full", "plate", "and", "packing", "steel"]
}`)) var result string
if err != nil { if err := cli.CallResult(context.Background(), "Test", []string{
log.Fatalf("POST request failed: %v", err) "full", "plate", "and", "packing", "steel",
} }, &result); err != nil {
body, err := io.ReadAll(rsp.Body) log.Fatalf("Call failed: %v", err)
rsp.Body.Close()
if err != nil {
log.Fatalf("Reading response body: %v", err)
} }
fmt.Println(string(body)) fmt.Println("Result:", result)
// Output: // Output:
// {"jsonrpc":"2.0","id":10235,"result":"full plate and packing steel"} // Result: full plate and packing steel
} }

288
vendor/github.com/creachadair/jrpc2/jhttp/getter.go generated vendored Normal file
View File

@ -0,0 +1,288 @@
// Copyright (C) 2021 Michael J. Fromberger. All Rights Reserved.
package jhttp
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"github.com/creachadair/jrpc2"
"github.com/creachadair/jrpc2/code"
"github.com/creachadair/jrpc2/server"
)
// A Getter is a http.Handler that bridges GET requests to a JSON-RPC server.
//
// The JSON-RPC method name and parameters are decoded from the request URL.
// The results from a successful call are encoded as JSON in the response body
// with status 200 (OK). In case of error, the response body is a JSON-RPC
// error object, and the HTTP status is one of the following:
//
// Condition HTTP Status
// ----------------------- -----------------------------------
// Parsing request 400 (Bad request)
// Method not found 404 (Not found)
// (other errors) 500 (Internal server error)
//
// By default, the URL path identifies the JSON-RPC method, and the URL query
// parameters are converted into a JSON object for the parameters. Leading and
// trailing slashes are stripped from the path, and query values are sent as
// JSON strings.
//
// For example, this URL:
//
// http://site.org:2112/some/method?param1=xyzzy&param2=apple
//
// would produce the method name "some/method" and this parameter object:
//
// {"param1":"xyzzy", "param2":"apple"}
//
// To override the default behaviour, set a ParseRequest hook in GetterOptions.
// See also the jhttp.ParseQuery function for a more expressive translation.
type Getter struct {
local server.Local
parseReq func(*http.Request) (string, interface{}, error)
}
// NewGetter constructs a new Getter that starts a server on mux and dispatches
// HTTP requests to it. The server will run until the getter is closed.
//
// Note that a getter is not able to push calls or notifications from the
// server back to the remote client even if enabled.
func NewGetter(mux jrpc2.Assigner, opts *GetterOptions) Getter {
return Getter{
local: server.NewLocal(mux, &server.LocalOptions{
Client: opts.clientOptions(),
Server: opts.serverOptions(),
}),
parseReq: opts.parseRequest(),
}
}
// ServeHTTP implements the required method of http.Handler.
func (g Getter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
method, params, err := g.parseHTTPRequest(req)
if err != nil {
writeJSON(w, http.StatusBadRequest, &jrpc2.Error{
Code: code.ParseError,
Message: err.Error(),
})
return
}
ctx := context.WithValue(req.Context(), httpReqKey{}, req)
var result json.RawMessage
if err := g.local.Client.CallResult(ctx, method, params, &result); err != nil {
var status int
switch code.FromError(err) {
case code.MethodNotFound:
status = http.StatusNotFound
default:
status = http.StatusInternalServerError
}
writeJSON(w, status, err)
return
}
writeJSON(w, http.StatusOK, result)
}
// Close closes the channel to the server, waits for the server to exit, and
// reports its exit status.
func (g Getter) Close() error { return g.local.Close() }
func (g Getter) parseHTTPRequest(req *http.Request) (string, interface{}, error) {
if g.parseReq != nil {
return g.parseReq(req)
}
if err := req.ParseForm(); err != nil {
return "", nil, err
}
method := strings.Trim(req.URL.Path, "/")
if method == "" {
return "", nil, errors.New("empty method name")
}
params := make(map[string]string)
for key := range req.Form {
params[key] = req.Form.Get(key)
}
return method, params, nil
}
// GetterOptions are optional settings for a Getter. A nil pointer is ready for
// use and provides default values as described.
type GetterOptions struct {
// Options for the getter client (default nil).
Client *jrpc2.ClientOptions
// Options for the getter server (default nil).
Server *jrpc2.ServerOptions
// If set, this function is called to parse a method name and request
// parameters from an HTTP request. If this is not set, the default handler
// uses the URL path as the method name and the URL query as the method
// parameters.
ParseRequest func(*http.Request) (string, interface{}, error)
}
func (o *GetterOptions) clientOptions() *jrpc2.ClientOptions {
if o == nil {
return nil
}
return o.Client
}
func (o *GetterOptions) serverOptions() *jrpc2.ServerOptions {
if o == nil {
return nil
}
return o.Server
}
func (o *GetterOptions) parseRequest() func(*http.Request) (string, interface{}, error) {
if o == nil {
return nil
}
return o.ParseRequest
}
func writeJSON(w http.ResponseWriter, code int, obj interface{}) {
bits, err := json.Marshal(obj)
if err != nil {
// Fallback in case of marshaling error. This should not happen, but
// ensures the client gets a loggable reply from a broken server.
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintln(w, err.Error())
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", strconv.Itoa(len(bits)))
w.WriteHeader(code)
w.Write(bits)
}
// ParseQuery parses a request URL and constructs a parameter map from the
// query values encoded in the URL and/or request body.
//
// The method name is the URL path, with leading and trailing slashes trimmed.
// Query values are converted into argument values by these rules:
//
// Double-quoted values are interpreted as JSON string values, with the same
// encoding and escaping rules (UTF-8 with backslash escapes). Examples:
//
// ""
// "foo\nbar"
// "a \"string\" of text"
//
// Values that consist of decimal digits and an optional leading sign are
// treated as either int64 (if there is no decimal point) or float64 values.
// Examples:
//
// 25
// -16
// 3.259
//
// The unquoted strings "true" and "false" are converted to the corresponding
// Boolean values. The unquoted string "null" is converted to nil.
//
// To express arbitrary bytes, use a singly-quoted string encoded in base64.
// For example:
//
// 'aGVsbG8sIHdvcmxk' -- represents "hello, world"
//
// All values not matching any of the above are treated as literal strings.
//
// On success, the result has concrete type map[string]interface{} and the
// method name is not empty.
func ParseQuery(req *http.Request) (string, interface{}, error) {
if err := req.ParseForm(); err != nil {
return "", nil, err
}
method := strings.Trim(req.URL.Path, "/")
if method == "" {
return "", nil, errors.New("empty URL path")
}
if len(req.Form) == 0 {
return method, nil, nil
}
params := make(map[string]interface{})
for key := range req.Form {
val := req.Form.Get(key)
if v, ok, err := parseJSONString(val); err != nil {
return "", nil, fmt.Errorf("decoding string %q: %w", key, err)
} else if ok {
params[key] = v
} else if n, ok := parseNumber(val); ok {
params[key] = n
} else if b, ok := parseConstant(val); ok {
params[key] = b
} else if d, ok, err := parseQuoted64(val); err != nil {
return "", nil, fmt.Errorf("decoding bytes %q: %w", key, err)
} else if ok {
params[key] = d
} else {
params[key] = val
}
}
return method, params, nil
}
func parseJSONString(s string) (string, bool, error) {
if len(s) >= 2 {
if s[0] == '"' && s[len(s)-1] == '"' {
var dec string
err := json.Unmarshal([]byte(s), &dec)
if err != nil {
return "", false, err
}
return dec, true, nil
} else if s[0] == '"' || s[len(s)-1] == '"' {
return "", false, errors.New("missing string quote")
}
}
return "", false, nil
}
func parseNumber(s string) (interface{}, bool) {
z, err := strconv.ParseInt(s, 10, 64)
if err == nil {
return z, true
}
v, err := strconv.ParseFloat(s, 64)
if err == nil {
return v, true
}
return nil, false
}
func parseConstant(s string) (interface{}, bool) {
switch s {
case "true":
return true, true
case "false":
return false, true
case "null":
return nil, true
default:
return nil, false
}
}
func parseQuoted64(s string) ([]byte, bool, error) {
if len(s) >= 2 {
if s[0] == '\'' && s[len(s)-1] == '\'' {
trim := strings.TrimRight(s[1:len(s)-1], "=") // discard base64 padding
dec, err := base64.RawStdEncoding.DecodeString(trim)
return dec, err == nil, err
} else if s[0] == '\'' || s[len(s)-1] == '\'' {
return nil, false, errors.New("missing bytes quote")
}
}
return nil, false, nil
}

View File

@ -0,0 +1,236 @@
// Copyright (C) 2021 Michael J. Fromberger. All Rights Reserved.
package jhttp_test
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
"github.com/creachadair/jrpc2"
"github.com/creachadair/jrpc2/handler"
"github.com/creachadair/jrpc2/jhttp"
"github.com/fortytw2/leaktest"
"github.com/google/go-cmp/cmp"
)
func TestGetter(t *testing.T) {
defer leaktest.Check(t)()
mux := handler.Map{
"concat": handler.NewPos(func(ctx context.Context, a, b string) string {
return a + b
}, "first", "second"),
}
g := jhttp.NewGetter(mux, &jhttp.GetterOptions{
Client: &jrpc2.ClientOptions{EncodeContext: checkContext},
})
defer checkClose(t, g)
hsrv := httptest.NewServer(g)
defer hsrv.Close()
url := func(pathQuery string) string {
return hsrv.URL + "/" + pathQuery
}
t.Run("OK", func(t *testing.T) {
got := mustGet(t, url("concat?second=world&first=hello"), http.StatusOK)
const want = `"helloworld"`
if got != want {
t.Errorf("Response body: got %#q, want %#q", got, want)
}
})
t.Run("NotFound", func(t *testing.T) {
got := mustGet(t, url("nonesuch"), http.StatusNotFound)
const want = `"code":-32601` // MethodNotFound
if !strings.Contains(got, want) {
t.Errorf("Response body: got %#q, want %#q", got, want)
}
})
t.Run("BadRequest", func(t *testing.T) {
// N.B. invalid query string
got := mustGet(t, url("concat?x%2"), http.StatusBadRequest)
const want = `"code":-32700` // ParseError
if !strings.Contains(got, want) {
t.Errorf("Response body: got %#q, want %#q", got, want)
}
})
t.Run("InternalError", func(t *testing.T) {
got := mustGet(t, url("concat?third=c"), http.StatusInternalServerError)
const want = `"code":-32602` // InvalidParams
if !strings.Contains(got, want) {
t.Errorf("Response body: got %#q, want %#q", got, want)
}
})
}
func TestParseQuery(t *testing.T) {
tests := []struct {
url string
body string
method string
want interface{}
errText string
}{
// Error: Missing method name.
{"http://localhost:2112/", "", "", nil, "empty URL path"},
// No parameters.
{"http://localhost/foo", "", "foo", nil, ""},
// Unbalanced double-quoted string.
{"https://fuzz.ball/foo?bad=%22xyz", "", "foo", nil, "missing string quote"},
{"https://fuzz.ball/bar?bad=xyz%22", "", "bar", nil, "missing string quote"},
// Unbalanced single-quoted string.
{`http://stripe/sister?bad='invalid`, "", "sister", nil, "missing bytes quote"},
{`http://stripe/sister?bad=invalid'`, "", "sister", nil, "missing bytes quote"},
// Invalid byte string.
{`http://green.as/balls?bad='NOT%20VALID'`, "", "balls", nil, "decoding bytes"},
// Invalid double-quoted string.
{`http://black.as/sin?bad=%22a%5Cx25%22`, "", "sin", nil, "invalid character"},
// Valid: Single-quoted byte string (base64).
{`http://fast.as.hell/and?twice='YXMgcHJldHR5IGFzIHlvdQ=='`,
"", "and", map[string]interface{}{
"twice": []byte("as pretty as you"),
}, ""},
// Valid: Unquoted strings and null.
{`http://head.like/a-hole?black=as&your=null&soul`,
"", "a-hole", map[string]interface{}{
"black": "as",
"your": nil,
"soul": "",
}, ""},
// Valid: Quoted strings, numbers, Booleans.
{`http://foo.com:1999/go/west/?alpha=%22xyz%22&bravo=3&charlie=true&delta=false&echo=3.2`,
"", "go/west", map[string]interface{}{
"alpha": "xyz",
"bravo": int64(3),
"charlie": true,
"delta": false,
"echo": 3.2,
}, ""},
// Valid: Form-encoded query in the request body.
{`http://buz.org:2013/bodyblow`,
"alpha=%22pdq%22&bravo=-19.4&charlie=false", "bodyblow", map[string]interface{}{
"alpha": "pdq",
"bravo": float64(-19.4),
"charlie": false,
}, ""},
}
for _, test := range tests {
t.Run("ParseQuery", func(t *testing.T) {
req, err := http.NewRequest("PUT", test.url, strings.NewReader(test.body))
if err != nil {
t.Fatalf("New request for %q failed: %v", test.url, err)
}
if test.body != "" {
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
}
method, params, err := jhttp.ParseQuery(req)
if err != nil {
if test.errText == "" {
t.Fatalf("ParseQuery failed: %v", err)
} else if !strings.Contains(err.Error(), test.errText) {
t.Fatalf("ParseQuery: got error %v, want %q", err, test.errText)
}
} else if test.errText != "" {
t.Fatalf("ParseQuery: got method %q, params %+v, wanted error %q",
method, params, test.errText)
} else {
if method != test.method {
t.Errorf("ParseQuery method: got %q, want %q", method, test.method)
}
if diff := cmp.Diff(test.want, params); diff != "" {
t.Errorf("Wrong params: (-want, +got)\n%s", diff)
}
}
})
}
}
func TestGetter_parseRequest(t *testing.T) {
defer leaktest.Check(t)()
mux := handler.Map{
"format": handler.NewPos(func(ctx context.Context, a string, b int) string {
return fmt.Sprintf("%s-%d", a, b)
}, "tag", "value"),
}
g := jhttp.NewGetter(mux, &jhttp.GetterOptions{
ParseRequest: func(req *http.Request) (string, interface{}, error) {
if err := req.ParseForm(); err != nil {
return "", nil, err
}
tag := req.Form.Get("tag")
val, err := strconv.ParseInt(req.Form.Get("value"), 10, 64)
if err != nil && req.Form.Get("value") != "" {
return "", nil, fmt.Errorf("invalid number: %w", err)
}
return strings.TrimPrefix(req.URL.Path, "/x/"), map[string]interface{}{
"tag": tag,
"value": val,
}, nil
},
})
defer checkClose(t, g)
hsrv := httptest.NewServer(g)
defer hsrv.Close()
url := func(pathQuery string) string {
return hsrv.URL + "/" + pathQuery
}
t.Run("OK", func(t *testing.T) {
got := mustGet(t, url("x/format?tag=foo&value=25"), http.StatusOK)
const want = `"foo-25"`
if got != want {
t.Errorf("Response body: got %#q, want %#q", got, want)
}
})
t.Run("NotFound", func(t *testing.T) {
// N.B. Missing path prefix.
got := mustGet(t, url("format"), http.StatusNotFound)
const want = `"code":-32601` // MethodNotFound
if !strings.Contains(got, want) {
t.Errorf("Response body: got %#q, want %#q", got, want)
}
})
t.Run("InternalError", func(t *testing.T) {
// N.B. Parameter type does not match on the server side.
got := mustGet(t, url("x/format?tag=foo&value=bar"), http.StatusBadRequest)
const want = `"code":-32700` // ParseError
if !strings.Contains(got, want) {
t.Errorf("Response body: got %#q, want %#q", got, want)
}
})
}
func mustGet(t *testing.T, url string, code int) string {
t.Helper()
rsp, err := http.Get(url)
if err != nil {
t.Fatalf("GET request failed: %v", err)
} else if got := rsp.StatusCode; got != code {
t.Errorf("GET response code: got %v, want %v", got, code)
}
body, err := io.ReadAll(rsp.Body)
if err != nil {
t.Errorf("Reading GET body: %v", err)
}
return string(body)
}

View File

@ -1,10 +1,11 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package jhttp_test package jhttp_test
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -14,6 +15,7 @@ import (
"github.com/creachadair/jrpc2" "github.com/creachadair/jrpc2"
"github.com/creachadair/jrpc2/handler" "github.com/creachadair/jrpc2/handler"
"github.com/creachadair/jrpc2/jhttp" "github.com/creachadair/jrpc2/jhttp"
"github.com/fortytw2/leaktest"
) )
var testService = handler.Map{ var testService = handler.Map{
@ -33,6 +35,8 @@ func checkContext(ctx context.Context, _ string, p json.RawMessage) (json.RawMes
} }
func TestBridge(t *testing.T) { func TestBridge(t *testing.T) {
defer leaktest.Check(t)()
// Set up a bridge with the test configuration. // Set up a bridge with the test configuration.
b := jhttp.NewBridge(testService, &jhttp.BridgeOptions{ b := jhttp.NewBridge(testService, &jhttp.BridgeOptions{
Client: &jrpc2.ClientOptions{EncodeContext: checkContext}, Client: &jrpc2.ClientOptions{EncodeContext: checkContext},
@ -45,49 +49,29 @@ func TestBridge(t *testing.T) {
// Verify that a valid POST request succeeds. // Verify that a valid POST request succeeds.
t.Run("PostOK", func(t *testing.T) { t.Run("PostOK", func(t *testing.T) {
rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`{ got := mustPost(t, hsrv.URL, `{
"jsonrpc": "2.0", "jsonrpc": "2.0",
"id": 1, "id": 1,
"method": "Test1", "method": "Test1",
"params": ["a", "foolish", "consistency", "is", "the", "hobgoblin"] "params": ["a", "foolish", "consistency", "is", "the", "hobgoblin"]
} }`, http.StatusOK)
`))
if err != nil {
t.Fatalf("POST request failed: %v", err)
} else if got, want := rsp.StatusCode, http.StatusOK; got != want {
t.Errorf("POST response code: got %v, want %v", got, want)
}
body, err := io.ReadAll(rsp.Body)
if err != nil {
t.Errorf("Reading POST body: %v", err)
}
const want = `{"jsonrpc":"2.0","id":1,"result":6}` const want = `{"jsonrpc":"2.0","id":1,"result":6}`
if got := string(body); got != want { if got != want {
t.Errorf("POST body: got %#q, want %#q", got, want) t.Errorf("POST body: got %#q, want %#q", got, want)
} }
}) })
// Verify that the bridge will accept a batch. // Verify that the bridge will accept a batch.
t.Run("PostBatchOK", func(t *testing.T) { t.Run("PostBatchOK", func(t *testing.T) {
rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`[ got := mustPost(t, hsrv.URL, `[
{"jsonrpc":"2.0", "id": 3, "method": "Test1", "params": ["first"]}, {"jsonrpc":"2.0", "id": 3, "method": "Test1", "params": ["first"]},
{"jsonrpc":"2.0", "id": 7, "method": "Test1", "params": ["among", "equals"]} {"jsonrpc":"2.0", "id": 7, "method": "Test1", "params": ["among", "equals"]}
] ]`, http.StatusOK)
`))
if err != nil {
t.Fatalf("POST request failed: %v", err)
} else if got, want := rsp.StatusCode, http.StatusOK; got != want {
t.Errorf("POST response code: got %v, want %v", got, want)
}
body, err := io.ReadAll(rsp.Body)
if err != nil {
t.Errorf("Reading POST body: %v", err)
}
const want = `[{"jsonrpc":"2.0","id":3,"result":1},` + const want = `[{"jsonrpc":"2.0","id":3,"result":1},` +
`{"jsonrpc":"2.0","id":7,"result":2}]` `{"jsonrpc":"2.0","id":7,"result":2}]`
if got := string(body); got != want { if got != want {
t.Errorf("POST body: got %#q, want %#q", got, want) t.Errorf("POST body: got %#q, want %#q", got, want)
} }
}) })
@ -108,62 +92,51 @@ func TestBridge(t *testing.T) {
rsp, err := http.Post(hsrv.URL, "text/plain", strings.NewReader(`{}`)) rsp, err := http.Post(hsrv.URL, "text/plain", strings.NewReader(`{}`))
if err != nil { if err != nil {
t.Fatalf("POST request failed: %v", err) t.Fatalf("POST request failed: %v", err)
} } else if got, want := rsp.StatusCode, http.StatusUnsupportedMediaType; got != want {
if got, want := rsp.StatusCode, http.StatusUnsupportedMediaType; got != want { t.Errorf("POST response code: got %v, want %v", got, want)
t.Errorf("POST status: got %v, want %v", got, want)
} }
}) })
// Verify that a POST that generates a JSON-RPC error succeeds. // Verify that a POST that generates a JSON-RPC error succeeds.
t.Run("PostErrorReply", func(t *testing.T) { t.Run("PostErrorReply", func(t *testing.T) {
rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`{ got := mustPost(t, hsrv.URL, `{
"id": 1, "id": 1,
"jsonrpc": "2.0" "jsonrpc": "2.0"
} }`, http.StatusOK)
`))
if err != nil {
t.Fatalf("POST request failed: %v", err)
} else if got, want := rsp.StatusCode, http.StatusOK; got != want {
t.Errorf("POST status: got %v, want %v", got, want)
}
body, err := io.ReadAll(rsp.Body)
if err != nil {
t.Errorf("Reading POST body: %v", err)
}
const exp = `{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"empty method name"}}` const exp = `{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"empty method name"}}`
if got := string(body); got != exp { if got != exp {
t.Errorf("POST body: got %#q, want %#q", got, exp) t.Errorf("POST body: got %#q, want %#q", got, exp)
} }
}) })
// Verify that a notification returns an empty success. // Verify that a notification returns an empty success.
t.Run("PostNotification", func(t *testing.T) { t.Run("PostNotification", func(t *testing.T) {
rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`{ got := mustPost(t, hsrv.URL, `{
"jsonrpc": "2.0", "jsonrpc": "2.0",
"method": "TakeNotice", "method": "TakeNotice",
"params": [] "params": []
}`)) }`, http.StatusNoContent)
if err != nil { if got != "" {
t.Fatalf("POST request failed: %v", err)
} else if got, want := rsp.StatusCode, http.StatusNoContent; got != want {
t.Errorf("POST status: got %v, want %v", got, want)
}
body, err := io.ReadAll(rsp.Body)
if err != nil {
t.Errorf("Reading POST body: %v", err)
}
if got := string(body); got != "" {
t.Errorf("POST body: got %q, want empty", got) t.Errorf("POST body: got %q, want empty", got)
} }
}) })
} }
// Verify that the content-type check hook works. // Verify that the request-parsing hook works.
func TestBridge_contentTypeCheck(t *testing.T) { func TestBridge_parseRequest(t *testing.T) {
defer leaktest.Check(t)()
const reqMessage = `{"jsonrpc":"2.0", "method": "Test2", "id": 100, "params":null}`
const wantReply = `{"jsonrpc":"2.0","id":100,"result":0}`
b := jhttp.NewBridge(testService, &jhttp.BridgeOptions{ b := jhttp.NewBridge(testService, &jhttp.BridgeOptions{
CheckContentType: func(ctype string) bool { ParseRequest: func(req *http.Request) ([]*jrpc2.Request, error) {
return ctype == "application/octet-stream" action := req.Header.Get("x-test-header")
if action == "fail" {
return nil, errors.New("parse hook reporting failure")
}
return jrpc2.ParseRequests([]byte(reqMessage))
}, },
}) })
defer checkClose(t, b) defer checkClose(t, b)
@ -171,29 +144,93 @@ func TestBridge_contentTypeCheck(t *testing.T) {
hsrv := httptest.NewServer(b) hsrv := httptest.NewServer(b)
defer hsrv.Close() defer hsrv.Close()
const reqTemplate = `{"jsonrpc":"2.0","id":%q,"method":"Test1","params":["a","b","c"]}` t.Run("Succeed", func(t *testing.T) {
t.Run("ContentTypeOK", func(t *testing.T) { // Since a parse hook is set, the method and content-type checks should not occur.
rsp, err := http.Post(hsrv.URL, "application/octet-stream", // We send an empty body to be sure the request comes from the hook.
strings.NewReader(fmt.Sprintf(reqTemplate, "ok"))) req, err := http.NewRequest("GET", hsrv.URL, strings.NewReader(""))
if err != nil { if err != nil {
t.Fatalf("POST request failed: %v", err) t.Fatalf("NewRequest: %v", err)
}
rsp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("GET request failed: %v", err)
} else if got, want := rsp.StatusCode, http.StatusOK; got != want { } else if got, want := rsp.StatusCode, http.StatusOK; got != want {
t.Errorf("POST response code: got %v, want %v", got, want) t.Errorf("GET response code: got %v, want %v", got, want)
}
body, _ := io.ReadAll(rsp.Body)
rsp.Body.Close()
if got := string(body); got != wantReply {
t.Errorf("Response: got %#q, want %#q", got, wantReply)
} }
}) })
t.Run("ContentTypeBad", func(t *testing.T) { t.Run("Fail", func(t *testing.T) {
rsp, err := http.Post(hsrv.URL, "text/plain", req, err := http.NewRequest("POST", hsrv.URL, strings.NewReader(""))
strings.NewReader(fmt.Sprintf(reqTemplate, "bad"))) if err != nil {
t.Fatalf("NewRequest: %v", err)
}
req.Header.Set("X-Test-Header", "fail")
rsp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
t.Fatalf("POST request failed: %v", err) t.Fatalf("POST request failed: %v", err)
} else if got, want := rsp.StatusCode, http.StatusUnsupportedMediaType; got != want { } else if got, want := rsp.StatusCode, http.StatusInternalServerError; got != want {
t.Errorf("POST response code: got %v, want %v", got, want) t.Errorf("POST response code: got %v, want %v", got, want)
} }
}) })
} }
func TestBridge_parseGETRequest(t *testing.T) {
defer leaktest.Check(t)()
mux := handler.Map{
"str/eq": handler.NewPos(func(ctx context.Context, a, b string) bool {
return a == b
}, "lhs", "rhs"),
}
b := jhttp.NewBridge(mux, &jhttp.BridgeOptions{
ParseGETRequest: func(req *http.Request) (string, interface{}, error) {
if err := req.ParseForm(); err != nil {
return "", nil, err
}
method := strings.Trim(req.URL.Path, "/")
params := make(map[string]string)
for key := range req.Form {
params[key] = req.Form.Get(key)
}
return method, params, nil
},
})
defer checkClose(t, b)
hsrv := httptest.NewServer(b)
defer hsrv.Close()
url := func(pathQuery string) string {
return hsrv.URL + "/" + pathQuery
}
t.Run("GET", func(t *testing.T) {
got := mustGet(t, url("str/eq?rhs=fizz&lhs=buzz"), http.StatusOK)
const want = `false`
if got != want {
t.Errorf("Response body: got %#q, want %#q", got, want)
}
})
t.Run("POST", func(t *testing.T) {
const req = `{"jsonrpc":"2.0", "id":1, "method":"str/eq", "params":["foo","foo"]}`
got := mustPost(t, hsrv.URL, req, http.StatusOK)
const want = `{"jsonrpc":"2.0","id":1,"result":true}`
if got != want {
t.Errorf("Response body: got %#q, want %#q", got, want)
}
})
}
func TestChannel(t *testing.T) { func TestChannel(t *testing.T) {
defer leaktest.Check(t)()
b := jhttp.NewBridge(testService, nil) b := jhttp.NewBridge(testService, nil)
defer checkClose(t, b) defer checkClose(t, b)
hsrv := httptest.NewServer(b) hsrv := httptest.NewServer(b)
@ -259,3 +296,18 @@ func checkClose(t *testing.T, c io.Closer) {
t.Errorf("Error in Close: %v", err) t.Errorf("Error in Close: %v", err)
} }
} }
func mustPost(t *testing.T, url, req string, code int) string {
t.Helper()
rsp, err := http.Post(url, "application/json", strings.NewReader(req))
if err != nil {
t.Fatalf("POST request failed: %v", err)
} else if got := rsp.StatusCode; got != code {
t.Errorf("POST response code: got %v, want %v", got, code)
}
body, err := io.ReadAll(rsp.Body)
if err != nil {
t.Errorf("Reading POST body: %v", err)
}
return string(body)
}

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package jrpc2_test package jrpc2_test
import ( import (
@ -6,6 +8,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"sort" "sort"
"sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
@ -16,6 +19,7 @@ import (
"github.com/creachadair/jrpc2/handler" "github.com/creachadair/jrpc2/handler"
"github.com/creachadair/jrpc2/jctx" "github.com/creachadair/jrpc2/jctx"
"github.com/creachadair/jrpc2/server" "github.com/creachadair/jrpc2/server"
"github.com/fortytw2/leaktest"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
) )
@ -24,8 +28,6 @@ var (
_ code.ErrCoder = (*jrpc2.Error)(nil) _ code.ErrCoder = (*jrpc2.Error)(nil)
) )
var notAuthorized = code.Register(-32095, "request not authorized")
var testOK = handler.New(func(ctx context.Context) (string, error) { var testOK = handler.New(func(ctx context.Context) (string, error) {
return "OK", nil return "OK", nil
}) })
@ -114,6 +116,8 @@ var callTests = []struct {
} }
func TestServerInfo_methodNames(t *testing.T) { func TestServerInfo_methodNames(t *testing.T) {
defer leaktest.Check(t)()
loc := server.NewLocal(handler.ServiceMap{ loc := server.NewLocal(handler.ServiceMap{
"Test": testService, "Test": testService,
}, nil) }, nil)
@ -130,13 +134,12 @@ func TestServerInfo_methodNames(t *testing.T) {
} }
func TestClient_Call(t *testing.T) { func TestClient_Call(t *testing.T) {
defer leaktest.Check(t)()
loc := server.NewLocal(handler.ServiceMap{ loc := server.NewLocal(handler.ServiceMap{
"Test": testService, "Test": testService,
}, &server.LocalOptions{ }, &server.LocalOptions{
Server: &jrpc2.ServerOptions{ Server: &jrpc2.ServerOptions{Concurrency: 16},
AllowV1: true,
Concurrency: 16,
},
}) })
defer loc.Close() defer loc.Close()
c := loc.Client c := loc.Client
@ -164,6 +167,8 @@ func TestClient_Call(t *testing.T) {
} }
func TestClient_CallResult(t *testing.T) { func TestClient_CallResult(t *testing.T) {
defer leaktest.Check(t)()
loc := server.NewLocal(handler.ServiceMap{ loc := server.NewLocal(handler.ServiceMap{
"Test": testService, "Test": testService,
}, &server.LocalOptions{ }, &server.LocalOptions{
@ -187,13 +192,12 @@ func TestClient_CallResult(t *testing.T) {
} }
func TestClient_Batch(t *testing.T) { func TestClient_Batch(t *testing.T) {
defer leaktest.Check(t)()
loc := server.NewLocal(handler.ServiceMap{ loc := server.NewLocal(handler.ServiceMap{
"Test": testService, "Test": testService,
}, &server.LocalOptions{ }, &server.LocalOptions{
Server: &jrpc2.ServerOptions{ Server: &jrpc2.ServerOptions{Concurrency: 16},
AllowV1: true,
Concurrency: 16,
},
}) })
defer loc.Close() defer loc.Close()
c := loc.Client c := loc.Client
@ -238,6 +242,8 @@ func TestClient_Batch(t *testing.T) {
// Verify that notifications respect order of arrival. // Verify that notifications respect order of arrival.
func TestServer_notificationOrder(t *testing.T) { func TestServer_notificationOrder(t *testing.T) {
defer leaktest.Check(t)()
var last int32 var last int32
loc := server.NewLocal(handler.Map{ loc := server.NewLocal(handler.Map{
@ -269,6 +275,8 @@ func TestServer_notificationOrder(t *testing.T) {
// Verify that a method that returns only an error (no result payload) is set // Verify that a method that returns only an error (no result payload) is set
// up and handled correctly. // up and handled correctly.
func TestHandler_errorOnly(t *testing.T) { func TestHandler_errorOnly(t *testing.T) {
defer leaktest.Check(t)()
const errMessage = "not enough strings" const errMessage = "not enough strings"
loc := server.NewLocal(handler.Map{ loc := server.NewLocal(handler.Map{
"ErrorOnly": handler.New(func(_ context.Context, ss []string) error { "ErrorOnly": handler.New(func(_ context.Context, ss []string) error {
@ -309,9 +317,11 @@ func TestHandler_errorOnly(t *testing.T) {
}) })
} }
// Verify that a timeout set on the context is respected by the server and // Verify that a timeout set on the client context is respected and reports
// propagates back to the client as an error. // back to the caller as an error.
func TestServer_contextTimeout(t *testing.T) { func TestClient_contextTimeout(t *testing.T) {
defer leaktest.Check(t)()
loc := server.NewLocal(handler.Map{ loc := server.NewLocal(handler.Map{
"Stall": handler.New(func(ctx context.Context) (bool, error) { "Stall": handler.New(func(ctx context.Context) (bool, error) {
t.Log("Stalling...") t.Log("Stalling...")
@ -325,13 +335,12 @@ func TestServer_contextTimeout(t *testing.T) {
}), }),
}, nil) }, nil)
defer loc.Close() defer loc.Close()
c := loc.Client
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel() defer cancel()
start := time.Now() start := time.Now()
got, err := c.Call(ctx, "Stall", nil) got, err := loc.Client.Call(ctx, "Stall", nil)
if err == nil { if err == nil {
t.Errorf("Stall: got %+v, wanted error", got) t.Errorf("Stall: got %+v, wanted error", got)
} else if err != context.DeadlineExceeded { } else if err != context.DeadlineExceeded {
@ -343,6 +352,8 @@ func TestServer_contextTimeout(t *testing.T) {
// Verify that stopping the server terminates in-flight requests. // Verify that stopping the server terminates in-flight requests.
func TestServer_stopCancelsHandlers(t *testing.T) { func TestServer_stopCancelsHandlers(t *testing.T) {
defer leaktest.Check(t)()
started := make(chan struct{}) started := make(chan struct{})
stopped := make(chan error, 1) stopped := make(chan error, 1)
loc := server.NewLocal(handler.Map{ loc := server.NewLocal(handler.Map{
@ -379,6 +390,8 @@ func TestServer_stopCancelsHandlers(t *testing.T) {
// Test that a handler can cancel an in-flight request. // Test that a handler can cancel an in-flight request.
func TestServer_CancelRequest(t *testing.T) { func TestServer_CancelRequest(t *testing.T) {
defer leaktest.Check(t)()
ready := make(chan struct{}) ready := make(chan struct{})
loc := server.NewLocal(handler.Map{ loc := server.NewLocal(handler.Map{
"Stall": handler.New(func(ctx context.Context) error { "Stall": handler.New(func(ctx context.Context) error {
@ -430,6 +443,8 @@ func TestServer_CancelRequest(t *testing.T) {
// Test that an error with data attached to it is correctly propagated back // Test that an error with data attached to it is correctly propagated back
// from the server to the client, in a value of concrete type *Error. // from the server to the client, in a value of concrete type *Error.
func TestError_withData(t *testing.T) { func TestError_withData(t *testing.T) {
defer leaktest.Check(t)()
const errCode = -32000 const errCode = -32000
const errData = `{"caroline":452}` const errData = `{"caroline":452}`
const errMessage = "error thingy" const errMessage = "error thingy"
@ -482,6 +497,8 @@ func TestError_withData(t *testing.T) {
// Test that a client correctly reports bad parameters. // Test that a client correctly reports bad parameters.
func TestClient_badCallParams(t *testing.T) { func TestClient_badCallParams(t *testing.T) {
defer leaktest.Check(t)()
loc := server.NewLocal(handler.Map{ loc := server.NewLocal(handler.Map{
"Test": handler.New(func(_ context.Context, v interface{}) error { "Test": handler.New(func(_ context.Context, v interface{}) error {
return jrpc2.Errorf(129, "this should not be reached") return jrpc2.Errorf(129, "this should not be reached")
@ -499,6 +516,8 @@ func TestClient_badCallParams(t *testing.T) {
// Verify that metrics are correctly propagated to server info. // Verify that metrics are correctly propagated to server info.
func TestServer_serverInfoMetrics(t *testing.T) { func TestServer_serverInfoMetrics(t *testing.T) {
defer leaktest.Check(t)()
loc := server.NewLocal(handler.Map{ loc := server.NewLocal(handler.Map{
"Metricize": handler.New(func(ctx context.Context) (bool, error) { "Metricize": handler.New(func(ctx context.Context) (bool, error) {
m := jrpc2.ServerFromContext(ctx).Metrics() m := jrpc2.ServerFromContext(ctx).Metrics()
@ -529,6 +548,9 @@ func TestServer_serverInfoMetrics(t *testing.T) {
if _, err := c.Call(ctx, "Metricize", nil); err != nil { if _, err := c.Call(ctx, "Metricize", nil); err != nil {
t.Fatalf("Call(Metricize) failed: %v", err) t.Fatalf("Call(Metricize) failed: %v", err)
} }
if got := s.ServerInfo().Counter["rpc.serversActive"]; got != 1 {
t.Errorf("Metric rpc.serversActive: got %d, want 1", got)
}
loc.Close() loc.Close()
info := s.ServerInfo() info := s.ServerInfo()
@ -542,6 +564,7 @@ func TestServer_serverInfoMetrics(t *testing.T) {
{info.Counter, "zero-sum", 0}, {info.Counter, "zero-sum", 0},
{info.Counter, "rpc.bytesRead", -1}, {info.Counter, "rpc.bytesRead", -1},
{info.Counter, "rpc.bytesWritten", -1}, {info.Counter, "rpc.bytesWritten", -1},
{info.Counter, "rpc.serversActive", 0},
{info.MaxValue, "max-metric-value", 5}, {info.MaxValue, "max-metric-value", 5},
{info.MaxValue, "rpc.bytesRead", -1}, {info.MaxValue, "rpc.bytesRead", -1},
{info.MaxValue, "rpc.bytesWritten", -1}, {info.MaxValue, "rpc.bytesWritten", -1},
@ -562,6 +585,8 @@ func TestServer_serverInfoMetrics(t *testing.T) {
// elicit a correct response from the server. Here we simulate a "different" // elicit a correct response from the server. Here we simulate a "different"
// client by writing requests directly into the channel. // client by writing requests directly into the channel.
func TestServer_nonLibraryClient(t *testing.T) { func TestServer_nonLibraryClient(t *testing.T) {
defer leaktest.Check(t)()
srv, cli := channel.Direct() srv, cli := channel.Direct()
s := jrpc2.NewServer(handler.Map{ s := jrpc2.NewServer(handler.Map{
"X": testOK, "X": testOK,
@ -594,7 +619,7 @@ func TestServer_nonLibraryClient(t *testing.T) {
// The method specified doesn't exist. // The method specified doesn't exist.
{`{"jsonrpc":"2.0", "id": 3, "method": "NoneSuch"}`, {`{"jsonrpc":"2.0", "id": 3, "method": "NoneSuch"}`,
`{"jsonrpc":"2.0","id":3,"error":{"code":-32601,"message":"no such method \"NoneSuch\""}}`}, `{"jsonrpc":"2.0","id":3,"error":{"code":-32601,"message":"no such method","data":"NoneSuch"}}`},
// The parameters are of the wrong form. // The parameters are of the wrong form.
{`{"jsonrpc":"2.0", "id": 4, "method": "X", "params": "bogus"}`, {`{"jsonrpc":"2.0", "id": 4, "method": "X", "params": "bogus"}`,
@ -604,9 +629,13 @@ func TestServer_nonLibraryClient(t *testing.T) {
{`{"jsonrpc": "2.0", "id": 6, "method": "X", "params": null}`, {`{"jsonrpc": "2.0", "id": 6, "method": "X", "params": null}`,
`{"jsonrpc":"2.0","id":6,"result":"OK"}`}, `{"jsonrpc":"2.0","id":6,"result":"OK"}`},
// Correct requests, one with a non-null response, one with a null response. // Correct requests.
{`{"jsonrpc":"2.0","id": 5, "method": "X"}`, `{"jsonrpc":"2.0","id":5,"result":"OK"}`}, {`{"jsonrpc":"2.0","id": 5, "method": "X"}`, `{"jsonrpc":"2.0","id":5,"result":"OK"}`},
{`{"jsonrpc":"2.0","id":21,"method":"Y"}`, `{"jsonrpc":"2.0","id":21,"result":null}`}, {`{"jsonrpc":"2.0","id":21,"method":"Y"}`, `{"jsonrpc":"2.0","id":21,"result":null}`},
{`{"jsonrpc":"2.0","id":0,"method":"X"}`, `{"jsonrpc":"2.0","id":0,"result":"OK"}`},
{`{"jsonrpc":"2.0","id":-0,"method":"X"}`, `{"jsonrpc":"2.0","id":-0,"result":"OK"}`},
{`{"jsonrpc":"2.0","id":-1,"method":"X"}`, `{"jsonrpc":"2.0","id":-1,"result":"OK"}`},
{`{"jsonrpc":"2.0","id":-600,"method":"Y"}`, `{"jsonrpc":"2.0","id":-600,"result":null}`},
// A batch of correct requests. // A batch of correct requests.
{`[{"jsonrpc":"2.0", "id":"a1", "method":"X"}, {"jsonrpc":"2.0", "id":"a2", "method": "X"}]`, {`[{"jsonrpc":"2.0", "id":"a1", "method":"X"}, {"jsonrpc":"2.0", "id":"a2", "method": "X"}]`,
@ -625,7 +654,7 @@ func TestServer_nonLibraryClient(t *testing.T) {
// A batch of invalid requests returns a batch of errors. // A batch of invalid requests returns a batch of errors.
{`[{"jsonrpc": "2.0", "id": 6, "method":"bogus"}]`, {`[{"jsonrpc": "2.0", "id": 6, "method":"bogus"}]`,
`[{"jsonrpc":"2.0","id":6,"error":{"code":-32601,"message":"no such method \"bogus\""}}]`}, `[{"jsonrpc":"2.0","id":6,"error":{"code":-32601,"message":"no such method","data":"bogus"}}]`},
// Batch requests return batch responses, even for a singleton. // Batch requests return batch responses, even for a singleton.
{`[{"jsonrpc": "2.0", "id": 7, "method": "X"}]`, `[{"jsonrpc":"2.0","id":7,"result":"OK"}]`}, {`[{"jsonrpc": "2.0", "id": 7, "method": "X"}]`, `[{"jsonrpc":"2.0","id":7,"result":"OK"}]`},
@ -677,6 +706,8 @@ func TestServer_nonLibraryClient(t *testing.T) {
// Verify that server-side push notifications work. // Verify that server-side push notifications work.
func TestServer_Notify(t *testing.T) { func TestServer_Notify(t *testing.T) {
defer leaktest.Check(t)()
// Set up a server and client with server-side notification support. Here // Set up a server and client with server-side notification support. Here
// we're just capturing the name of the notification method, as a sign we // we're just capturing the name of the notification method, as a sign we
// got the right thing. // got the right thing.
@ -728,6 +759,8 @@ func TestServer_Notify(t *testing.T) {
// Verify that server-side callbacks can time out. // Verify that server-side callbacks can time out.
func TestServer_callbackTimeout(t *testing.T) { func TestServer_callbackTimeout(t *testing.T) {
defer leaktest.Check(t)()
loc := server.NewLocal(handler.Map{ loc := server.NewLocal(handler.Map{
"Test": handler.New(func(ctx context.Context) error { "Test": handler.New(func(ctx context.Context) error {
tctx, cancel := context.WithTimeout(ctx, 5*time.Millisecond) tctx, cancel := context.WithTimeout(ctx, 5*time.Millisecond)
@ -757,6 +790,8 @@ func TestServer_callbackTimeout(t *testing.T) {
// Verify that server-side callbacks work. // Verify that server-side callbacks work.
func TestServer_Callback(t *testing.T) { func TestServer_Callback(t *testing.T) {
defer leaktest.Check(t)()
loc := server.NewLocal(handler.Map{ loc := server.NewLocal(handler.Map{
"CallMeMaybe": handler.New(func(ctx context.Context) error { "CallMeMaybe": handler.New(func(ctx context.Context) error {
if _, err := jrpc2.ServerFromContext(ctx).Callback(ctx, "succeed", nil); err != nil { if _, err := jrpc2.ServerFromContext(ctx).Callback(ctx, "succeed", nil); err != nil {
@ -799,6 +834,8 @@ func TestServer_Callback(t *testing.T) {
// Verify that a server push after the client closes does not trigger a panic. // Verify that a server push after the client closes does not trigger a panic.
func TestServer_pushAfterClose(t *testing.T) { func TestServer_pushAfterClose(t *testing.T) {
defer leaktest.Check(t)()
loc := server.NewLocal(make(handler.Map), &server.LocalOptions{ loc := server.NewLocal(make(handler.Map), &server.LocalOptions{
Server: &jrpc2.ServerOptions{AllowPush: true}, Server: &jrpc2.ServerOptions{AllowPush: true},
}) })
@ -814,6 +851,8 @@ func TestServer_pushAfterClose(t *testing.T) {
// Verify that an OnCancel hook is called when expected. // Verify that an OnCancel hook is called when expected.
func TestClient_onCancelHook(t *testing.T) { func TestClient_onCancelHook(t *testing.T) {
defer leaktest.Check(t)()
hooked := make(chan struct{}) // closed when hook notification is finished hooked := make(chan struct{}) // closed when hook notification is finished
loc := server.NewLocal(handler.Map{ loc := server.NewLocal(handler.Map{
@ -870,8 +909,162 @@ func TestClient_onCancelHook(t *testing.T) {
} }
} }
// Verify that client callback handlers are cancelled when the client stops.
func TestClient_closeEndsCallbacks(t *testing.T) {
defer leaktest.Check(t)()
ready := make(chan struct{})
loc := server.NewLocal(handler.Map{
"Test": handler.New(func(ctx context.Context) error {
// Call back to the client and block indefinitely until it returns.
srv := jrpc2.ServerFromContext(ctx)
_, err := srv.Callback(ctx, "whatever", nil)
return err
}),
}, &server.LocalOptions{
Server: &jrpc2.ServerOptions{AllowPush: true},
Client: &jrpc2.ClientOptions{
OnCallback: handler.New(func(ctx context.Context) error {
// Signal the test that the callback handler is running. When the
// client is closed, it should terminate ctx and allow this to
// return. If that doesn't happen, time out and fail.
close(ready)
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(10 * time.Second):
return errors.New("context not cancelled before timeout")
}
}),
},
})
go func() {
rsp, err := loc.Client.Call(context.Background(), "Test", nil)
if err == nil {
t.Errorf("Client call: got %+v, wanted error", rsp)
}
}()
<-ready
loc.Client.Close()
loc.Server.Wait()
}
// Verify that it is possible for multiple callback handlers to execute
// concurrently.
func TestClient_concurrentCallbacks(t *testing.T) {
defer leaktest.Check(t)()
ready1 := make(chan struct{})
ready2 := make(chan struct{})
release := make(chan struct{})
loc := server.NewLocal(handler.Map{
"Test": handler.New(func(ctx context.Context) []string {
srv := jrpc2.ServerFromContext(ctx)
// Call two callbacks concurrently, wait until they are both running,
// then ungate them and wait for them both to reply. Return their
// responses back to the test for validation.
ss := make([]string, 2)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
rsp, err := srv.Callback(ctx, "C1", nil)
if err != nil {
t.Errorf("Callback C1 failed: %v", err)
} else {
rsp.UnmarshalResult(&ss[0])
}
}()
go func() {
defer wg.Done()
rsp, err := srv.Callback(ctx, "C2", nil)
if err != nil {
t.Errorf("Callback C2 failed: %v", err)
} else {
rsp.UnmarshalResult(&ss[1])
}
}()
<-ready1 // C1 is ready
<-ready2 // C2 is ready
close(release) // allow all callbacks to proceed
wg.Wait() // wait for all callbacks to be done
return ss
}),
}, &server.LocalOptions{
Server: &jrpc2.ServerOptions{AllowPush: true},
Client: &jrpc2.ClientOptions{
OnCallback: handler.Func(func(ctx context.Context, req *jrpc2.Request) (interface{}, error) {
// A trivial callback that reports its method name.
// The name is used to select which invocation we are serving.
switch req.Method() {
case "C1":
close(ready1)
case "C2":
close(ready2)
default:
return nil, fmt.Errorf("unexpected method %q", req.Method())
}
<-release
return req.Method(), nil
}),
},
})
defer loc.Close()
var got []string
if err := loc.Client.CallResult(context.Background(), "Test", nil, &got); err != nil {
t.Errorf("Call Test failed: %v", err)
}
want := []string{"C1", "C2"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("Wrong callback results: (-want, +got)\n%s", diff)
}
}
// Verify that a callback can successfully call "up" into the server.
func TestClient_callbackUpCall(t *testing.T) {
defer leaktest.Check(t)()
const pingMessage = "kittens!"
var probe string
loc := server.NewLocal(handler.Map{
"Test": handler.New(func(ctx context.Context) error {
// Call back to the client, and propagate its response.
srv := jrpc2.ServerFromContext(ctx)
_, err := srv.Callback(ctx, "whatever", nil)
return err
}),
"Ping": handler.New(func(context.Context) string {
// This method is called by the client-side callback.
return pingMessage
}),
}, &server.LocalOptions{
Server: &jrpc2.ServerOptions{AllowPush: true},
Client: &jrpc2.ClientOptions{
OnCallback: handler.New(func(ctx context.Context) error {
// Call back up into the server.
cli := jrpc2.ClientFromContext(ctx)
return cli.CallResult(ctx, "Ping", nil, &probe)
}),
},
})
if _, err := loc.Client.Call(context.Background(), "Test", nil); err != nil {
t.Errorf("Call Test failed: %v", err)
}
loc.Close()
if probe != pingMessage {
t.Errorf("Probe response: got %q, want %q", probe, pingMessage)
}
}
// Verify that the context encoding/decoding hooks work. // Verify that the context encoding/decoding hooks work.
func TestContextPlumbing(t *testing.T) { func TestContextPlumbing(t *testing.T) {
defer leaktest.Check(t)()
want := time.Now().Add(10 * time.Second) want := time.Now().Add(10 * time.Second)
ctx, cancel := context.WithDeadline(context.Background(), want) ctx, cancel := context.WithDeadline(context.Background(), want)
defer cancel() defer cancel()
@ -898,88 +1091,11 @@ func TestContextPlumbing(t *testing.T) {
} }
} }
// Verify that the request-checking hook works.
func TestServer_checkRequestHook(t *testing.T) {
const wantResponse = "Hey girl"
const wantToken = "OK"
loc := server.NewLocal(handler.Map{
"Test": handler.New(func(ctx context.Context) (string, error) {
return wantResponse, nil
}),
}, &server.LocalOptions{
// Enable auth checking and context decoding for the server.
Server: &jrpc2.ServerOptions{
DecodeContext: jctx.Decode,
CheckRequest: func(ctx context.Context, req *jrpc2.Request) error {
var token []byte
switch err := jctx.UnmarshalMetadata(ctx, &token); err {
case nil:
t.Logf("Metadata present: value=%q", string(token))
case jctx.ErrNoMetadata:
t.Log("Metadata not set")
default:
return err
}
if s := string(token); s != wantToken {
return jrpc2.Errorf(notAuthorized, "not authorized")
}
return nil
},
},
// Enable context encoding for the client.
Client: &jrpc2.ClientOptions{
EncodeContext: jctx.Encode,
},
})
defer loc.Close()
c := loc.Client
// Call without a token and verify that we get an error.
t.Run("NoToken", func(t *testing.T) {
var rsp string
err := c.CallResult(context.Background(), "Test", nil, &rsp)
if err == nil {
t.Errorf("Call(Test): got %q, wanted error", rsp)
} else if ec := code.FromError(err); ec != notAuthorized {
t.Errorf("Call(Test): got code %v, want %v", ec, notAuthorized)
}
})
// Call with a valid token and verify that we get a response.
t.Run("GoodToken", func(t *testing.T) {
ctx, err := jctx.WithMetadata(context.Background(), []byte(wantToken))
if err != nil {
t.Fatalf("Call(Test): attaching metadata: %v", err)
}
var rsp string
if err := c.CallResult(ctx, "Test", nil, &rsp); err != nil {
t.Errorf("Call(Test): unexpected error: %v", err)
}
if rsp != wantResponse {
t.Errorf("Call(Test): got %q, want %q", rsp, wantResponse)
}
})
// Call with an invalid token and verify that we get an error.
t.Run("BadToken", func(t *testing.T) {
ctx, err := jctx.WithMetadata(context.Background(), []byte("BAD"))
if err != nil {
t.Fatalf("Call(Test): attaching metadata: %v", err)
}
var rsp string
if err := c.CallResult(ctx, "Test", nil, &rsp); err == nil {
t.Errorf("Call(Test): got %q, wanted error", rsp)
} else if ec := code.FromError(err); ec != notAuthorized {
t.Errorf("Call(Test): got code %v, want %v", ec, notAuthorized)
}
})
}
// Verify that calling a wrapped method which takes no parameters, but in which // Verify that calling a wrapped method which takes no parameters, but in which
// the caller provided parameters, will correctly report an error. // the caller provided parameters, will correctly report an error.
func TestHandler_noParams(t *testing.T) { func TestHandler_noParams(t *testing.T) {
defer leaktest.Check(t)()
loc := server.NewLocal(handler.Map{"Test": testOK}, nil) loc := server.NewLocal(handler.Map{"Test": testOK}, nil)
defer loc.Close() defer loc.Close()
@ -993,6 +1109,8 @@ func TestHandler_noParams(t *testing.T) {
// Verify that the rpc.serverInfo handler and client wrapper work together. // Verify that the rpc.serverInfo handler and client wrapper work together.
func TestRPCServerInfo(t *testing.T) { func TestRPCServerInfo(t *testing.T) {
defer leaktest.Check(t)()
loc := server.NewLocal(handler.Map{"Test": testOK}, nil) loc := server.NewLocal(handler.Map{"Test": testOK}, nil)
defer loc.Close() defer loc.Close()
@ -1040,6 +1158,8 @@ func TestNetwork(t *testing.T) {
// Verify that the context passed to an assigner has the correct structure. // Verify that the context passed to an assigner has the correct structure.
func TestHandler_assignContext(t *testing.T) { func TestHandler_assignContext(t *testing.T) {
defer leaktest.Check(t)()
loc := server.NewLocal(assignFunc(func(ctx context.Context, method string) jrpc2.Handler { loc := server.NewLocal(assignFunc(func(ctx context.Context, method string) jrpc2.Handler {
req := jrpc2.InboundRequest(ctx) req := jrpc2.InboundRequest(ctx)
if req == nil { if req == nil {
@ -1065,9 +1185,10 @@ func TestHandler_assignContext(t *testing.T) {
type assignFunc func(context.Context, string) jrpc2.Handler type assignFunc func(context.Context, string) jrpc2.Handler
func (a assignFunc) Assign(ctx context.Context, m string) jrpc2.Handler { return a(ctx, m) } func (a assignFunc) Assign(ctx context.Context, m string) jrpc2.Handler { return a(ctx, m) }
func (assignFunc) Names() []string { return nil }
func TestServer_WaitStatus(t *testing.T) { func TestServer_WaitStatus(t *testing.T) {
defer leaktest.Check(t)()
check := func(t *testing.T, stat jrpc2.ServerStatus, closed, stopped bool, wantErr error) { check := func(t *testing.T, stat jrpc2.ServerStatus, closed, stopped bool, wantErr error) {
t.Helper() t.Helper()
if got, want := stat.Success(), wantErr == nil; got != want { if got, want := stat.Success(), wantErr == nil; got != want {
@ -1113,6 +1234,8 @@ func (b buggyChannel) Recv() ([]byte, error) { return []byte(b.data), b.err }
func (buggyChannel) Close() error { return nil } func (buggyChannel) Close() error { return nil }
func TestRequest_strictFields(t *testing.T) { func TestRequest_strictFields(t *testing.T) {
defer leaktest.Check(t)()
type other struct { type other struct {
C bool `json:"charlie"` C bool `json:"charlie"`
} }
@ -1121,56 +1244,97 @@ func TestRequest_strictFields(t *testing.T) {
B int `json:"bravo"` B int `json:"bravo"`
other other
} }
type result struct {
X string `json:"xray"`
}
loc := server.NewLocal(handler.Map{ loc := server.NewLocal(handler.Map{
"Test": handler.New(func(ctx context.Context, req *jrpc2.Request) (interface{}, error) { "Strict": handler.New(func(ctx context.Context, req *jrpc2.Request) (string, error) {
var ps, qs params var ps params
if err := req.UnmarshalParams(jrpc2.StrictFields(&ps)); err != nil {
if err := req.UnmarshalParams(jrpc2.StrictFields(&ps)); err == nil { return "", err
t.Errorf("Unmarshal strict: got %+v, want error", ps)
} }
return ps.A, nil
if err := req.UnmarshalParams(&qs); err != nil { }),
t.Errorf("Unmarshal non-strict (default): unexpected error: %v", err) "Normal": handler.New(func(ctx context.Context, req *jrpc2.Request) (string, error) {
var ps params
if err := req.UnmarshalParams(&ps); err != nil {
return "", err
} }
return ps.A, nil
return map[string]string{
"xray": "ok",
"gamma": "not ok",
}, nil
}), }),
}, nil) }, nil)
defer loc.Close() defer loc.Close()
ctx := context.Background() ctx := context.Background()
rsp, err := loc.Client.Call(ctx, "Test", handler.Obj{
"alpha": "foo", tests := []struct {
"bravo": 25, method string
"charlie": true, // exercise embedding params interface{}
"delta": 31.5, // unknown field code code.Code
want string
}{
{"Strict", handler.Obj{"alpha": "aiuto"}, code.NoError, "aiuto"},
{"Strict", handler.Obj{"alpha": "selva me", "charlie": true}, code.NoError, "selva me"},
{"Strict", handler.Obj{"alpha": "OK", "nonesuch": true}, code.InvalidParams, ""},
{"Normal", handler.Obj{"alpha": "OK", "nonesuch": true}, code.NoError, "OK"},
}
for _, test := range tests {
name := test.method + "/"
if test.code == code.NoError {
name += "OK"
} else {
name += test.code.String()
}
t.Run(name, func(t *testing.T) {
var res string
err := loc.Client.CallResult(ctx, test.method, test.params, &res)
if err == nil && test.code != code.NoError {
t.Errorf("CallResult: got %+v, want error code %v", res, test.code)
} else if err != nil {
if c := code.FromError(err); c != test.code {
t.Errorf("CallResult: got error %v, wanted code %v", err, test.code)
}
} else if res != test.want {
t.Errorf("CallResult: got %#q, want %#q", res, test.want)
}
}) })
}
}
func TestResponse_strictFields(t *testing.T) {
defer leaktest.Check(t)()
type result struct {
A string `json:"alpha"`
}
loc := server.NewLocal(handler.Map{
"Test": handler.New(func(ctx context.Context, req *jrpc2.Request) handler.Obj {
return handler.Obj{"alpha": "OK", "bravo": "not OK"}
}),
}, nil)
defer loc.Close()
ctx := context.Background()
res, err := loc.Client.Call(ctx, "Test", nil)
if err != nil { if err != nil {
t.Fatalf("Call failed: %v", err) t.Fatalf("Call failed: %v", err)
} }
t.Run("NonStrictResult", func(t *testing.T) { t.Run("Normal", func(t *testing.T) {
var res result var got result
if err := rsp.UnmarshalResult(&res); err != nil { if err := res.UnmarshalResult(&got); err != nil {
t.Errorf("UnmarshalResult: %v", err) t.Errorf("UnmarshalResult failed: %v", err)
} else if got.A != "OK" {
t.Errorf("Result: got %#q, want OK", got.A)
} }
}) })
t.Run("Strict", func(t *testing.T) {
t.Run("StrictResult", func(t *testing.T) { var got result
var res result if err := res.UnmarshalResult(jrpc2.StrictFields(&got)); err == nil {
if err := rsp.UnmarshalResult(jrpc2.StrictFields(&res)); err == nil { t.Errorf("UnmarshalResult: got %#v, wanted error", got)
t.Errorf("UnmarshalResult: got %+v, want error", res)
} }
}) })
} }
func TestServerFromContext(t *testing.T) { func TestServerFromContext(t *testing.T) {
defer leaktest.Check(t)()
var got *jrpc2.Server var got *jrpc2.Server
loc := server.NewLocal(handler.Map{ loc := server.NewLocal(handler.Map{
"Test": handler.New(func(ctx context.Context) error { "Test": handler.New(func(ctx context.Context) error {
@ -1190,6 +1354,8 @@ func TestServerFromContext(t *testing.T) {
} }
func TestServer_newContext(t *testing.T) { func TestServer_newContext(t *testing.T) {
defer leaktest.Check(t)()
// Prepare a context with a test value attached to it, that the handler can // Prepare a context with a test value attached to it, that the handler can
// extract to verify that the base context was plumbed in correctly. // extract to verify that the base context was plumbed in correctly.
type ctxKey string type ctxKey string

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package jrpc2 package jrpc2
import ( import (
@ -70,7 +72,7 @@ func (j *jmessages) parseJSON(data []byte) error {
// or array. // or array.
var msgs []json.RawMessage var msgs []json.RawMessage
var batch bool var batch bool
if len(data) == 0 || data[0] != '[' { if firstByte(data) != '[' {
msgs = append(msgs, nil) msgs = append(msgs, nil)
if err := json.Unmarshal(data, &msgs[0]); err != nil { if err := json.Unmarshal(data, &msgs[0]); err != nil {
return errInvalidRequest return errInvalidRequest
@ -128,7 +130,7 @@ func isValidID(v json.RawMessage) bool {
} }
func (j *jmessage) fail(code code.Code, msg string) { func (j *jmessage) fail(code code.Code, msg string) {
j.err = Errorf(code, msg) j.err = &Error{Code: code, Message: msg}
} }
func (j *jmessage) toJSON() ([]byte, error) { func (j *jmessage) toJSON() ([]byte, error) {
@ -205,7 +207,7 @@ func (j *jmessage) parseJSON(data []byte) error {
if !isNull(val) { if !isNull(val) {
j.P = val j.P = val
} }
if len(j.P) != 0 && j.P[0] != '[' && j.P[0] != '{' { if fb := firstByte(j.P); fb != 0 && fb != '[' && fb != '{' {
j.fail(code.InvalidRequest, "parameters must be array or object") j.fail(code.InvalidRequest, "parameters must be array or object")
} }
case "error": case "error":
@ -267,6 +269,15 @@ func isNull(msg json.RawMessage) bool {
return len(msg) == 4 && msg[0] == 'n' && msg[1] == 'u' && msg[2] == 'l' && msg[3] == 'l' return len(msg) == 4 && msg[0] == 'n' && msg[1] == 'u' && msg[2] == 'l' && msg[3] == 'l'
} }
// firstByte returns the first non-whitespace byte of data, or 0 if there is none.
func firstByte(data []byte) byte {
clean := bytes.TrimSpace(data)
if len(clean) == 0 {
return 0
}
return clean[0]
}
// strictFielder is an optional interface that can be implemented by a type to // strictFielder is an optional interface that can be implemented by a type to
// reject unknown fields when unmarshaling from JSON. If a type does not // reject unknown fields when unmarshaling from JSON. If a type does not
// implement this interface, unknown fields are ignored. // implement this interface, unknown fields are ignored.
@ -274,8 +285,8 @@ type strictFielder interface {
DisallowUnknownFields() DisallowUnknownFields()
} }
// StrictFields wraps a value v to implement the DisallowUnknownFields method, // StrictFields wraps a value v to require unknown fields to be rejected when
// requiring unknown fields to be rejected when unmarshaling from JSON. // unmarshaling from JSON.
// //
// For example: // For example:
// //
@ -286,4 +297,8 @@ func StrictFields(v interface{}) interface{} { return &strict{v: v} }
type strict struct{ v interface{} } type strict struct{ v interface{} }
func (strict) DisallowUnknownFields() {} func (s *strict) UnmarshalJSON(data []byte) error {
dec := json.NewDecoder(bytes.NewReader(data))
dec.DisallowUnknownFields()
return dec.Decode(s.v)
}

View File

@ -1,5 +1,4 @@
//go:build oldbench // Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
// +build oldbench
package jrpc2 package jrpc2

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
// Package metrics defines a concurrently-accessible metrics collector. // Package metrics defines a concurrently-accessible metrics collector.
// //
// A *metrics.M value exports methods to track integer counters and maximum // A *metrics.M value exports methods to track integer counters and maximum

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package metrics_test package metrics_test
import ( import (

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package jrpc2 package jrpc2
import ( import (
@ -23,10 +25,6 @@ type ServerOptions struct {
// received and each response or error returned. // received and each response or error returned.
RPCLog RPCLogger RPCLog RPCLogger
// Instructs the server to tolerate requests that do not include the
// required "jsonrpc" version marker.
AllowV1 bool
// Instructs the server to allow server callbacks and notifications, a // Instructs the server to allow server callbacks and notifications, a
// non-standard extension to the JSON-RPC protocol. If AllowPush is false, // non-standard extension to the JSON-RPC protocol. If AllowPush is false,
// the Notify and Callback methods of the server report errors if called. // the Notify and Callback methods of the server report errors if called.
@ -55,12 +53,6 @@ type ServerOptions struct {
// If unset, context and parameters are used as given. // If unset, context and parameters are used as given.
DecodeContext func(context.Context, string, json.RawMessage) (context.Context, json.RawMessage, error) DecodeContext func(context.Context, string, json.RawMessage) (context.Context, json.RawMessage, error)
// If set, this function is called with the context and client request
// (after decoding, if DecodeContext is set) that are to be delivered to the
// handler. If CheckRequest reports a non-nil error, the request fails with
// that error without invoking the handler.
CheckRequest func(ctx context.Context, req *Request) error
// If set, use this value to record server metrics. All servers created // If set, use this value to record server metrics. All servers created
// from the same options will share the same metrics collector. If none is // from the same options will share the same metrics collector. If none is
// set, an empty collector will be created for each new server. // set, an empty collector will be created for each new server.
@ -79,7 +71,6 @@ func (s *ServerOptions) logFunc() func(string, ...interface{}) {
return s.Logger.Printf return s.Logger.Printf
} }
func (s *ServerOptions) allowV1() bool { return s != nil && s.AllowV1 }
func (s *ServerOptions) allowPush() bool { return s != nil && s.AllowPush } func (s *ServerOptions) allowPush() bool { return s != nil && s.AllowPush }
func (s *ServerOptions) allowBuiltin() bool { return s == nil || !s.DisableBuiltin } func (s *ServerOptions) allowBuiltin() bool { return s == nil || !s.DisableBuiltin }
@ -106,22 +97,13 @@ func (o *ServerOptions) newContext() func() context.Context {
type decoder = func(context.Context, string, json.RawMessage) (context.Context, json.RawMessage, error) type decoder = func(context.Context, string, json.RawMessage) (context.Context, json.RawMessage, error)
func (s *ServerOptions) decodeContext() (decoder, bool) { func (s *ServerOptions) decodeContext() decoder {
if s == nil || s.DecodeContext == nil { if s == nil || s.DecodeContext == nil {
return func(ctx context.Context, method string, params json.RawMessage) (context.Context, json.RawMessage, error) { return func(ctx context.Context, method string, params json.RawMessage) (context.Context, json.RawMessage, error) {
return ctx, params, nil return ctx, params, nil
}, false
} }
return s.DecodeContext, true
}
type verifier = func(context.Context, *Request) error
func (s *ServerOptions) checkRequest() verifier {
if s == nil || s.CheckRequest == nil {
return func(context.Context, *Request) error { return nil }
} }
return s.CheckRequest return s.DecodeContext
} }
func (s *ServerOptions) metrics() *metrics.M { func (s *ServerOptions) metrics() *metrics.M {
@ -144,10 +126,6 @@ type ClientOptions struct {
// If not nil, send debug text logs here. // If not nil, send debug text logs here.
Logger Logger Logger Logger
// Instructs the client to tolerate responses that do not include the
// required "jsonrpc" version marker.
AllowV1 bool
// If set, this function is called with the context, method name, and // If set, this function is called with the context, method name, and
// encoded request parameters before the request is sent to the server. // encoded request parameters before the request is sent to the server.
// Its return value replaces the request parameters. This allows the client // Its return value replaces the request parameters. This allows the client
@ -162,12 +140,17 @@ type ClientOptions struct {
OnNotify func(*Request) OnNotify func(*Request)
// If set, this function is called if a request is received from the server. // If set, this function is called if a request is received from the server.
// If unset, server requests are logged and discarded. At most one // If unset, server requests are logged and discarded. Multiple invocations
// invocation of this callback will be active at a time. // of the callback handler may be active concurrently.
// Server callbacks are a non-standard extension of JSON-RPC. //
// The callback handler can retrieve the client from its context using the
// jrpc2.ClientFromContext function. The context terminates when the client
// is closed.
// //
// If a callback handler panics, the client will recover the panic and // If a callback handler panics, the client will recover the panic and
// report a system error back to the server describing the error. // report a system error back to the server describing the error.
//
// Server callbacks are a non-standard extension of JSON-RPC.
OnCallback func(context.Context, *Request) (interface{}, error) OnCallback func(context.Context, *Request) (interface{}, error)
// If set, this function is called when the context for a request terminates. // If set, this function is called when the context for a request terminates.
@ -186,8 +169,6 @@ func (c *ClientOptions) logFunc() func(string, ...interface{}) {
return c.Logger.Printf return c.Logger.Printf
} }
func (c *ClientOptions) allowV1() bool { return c != nil && c.AllowV1 }
type encoder = func(context.Context, string, json.RawMessage) (json.RawMessage, error) type encoder = func(context.Context, string, json.RawMessage) (json.RawMessage, error)
func (c *ClientOptions) encodeContext() encoder { func (c *ClientOptions) encodeContext() encoder {
@ -214,15 +195,12 @@ func (c *ClientOptions) handleCancel() func(*Client, *Response) {
return c.OnCancel return c.OnCancel
} }
func (c *ClientOptions) handleCallback() func(*jmessage) []byte { func (c *ClientOptions) handleCallback() func(context.Context, *jmessage) []byte {
if c == nil || c.OnCallback == nil { if c == nil || c.OnCallback == nil {
return nil return nil
} }
cb := c.OnCallback cb := c.OnCallback
return func(req *jmessage) []byte { return func(ctx context.Context, req *jmessage) []byte {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Recover panics from the callback handler to ensure the server gets a // Recover panics from the callback handler to ensure the server gets a
// response even if the callback fails without a result. // response even if the callback fails without a result.
// //

64
vendor/github.com/creachadair/jrpc2/queue.go generated vendored Normal file
View File

@ -0,0 +1,64 @@
// Copyright (C) 2021 Michael J. Fromberger. All Rights Reserved.
package jrpc2
type queue struct {
front, back *entry
free *entry
nelts int
}
func newQueue() *queue {
sentinel := new(entry)
return &queue{front: sentinel, back: sentinel}
}
func (q *queue) isEmpty() bool { return q.front.link == nil }
func (q *queue) size() int { return q.nelts }
func (q *queue) reset() { q.front.link = nil; q.back = q.front; q.nelts = 0 }
func (q *queue) alloc(data jmessages) *entry {
if q.free == nil {
return &entry{data: data}
}
out := q.free
q.free = out.link
out.data = data
out.link = nil
return out
}
func (q *queue) release(e *entry) {
e.link, q.free = q.free, e
e.data = nil
}
func (q *queue) each(f func(jmessages)) {
for cur := q.front.link; cur != nil; cur = cur.link {
f(cur.data)
}
}
func (q *queue) push(m jmessages) {
e := q.alloc(m)
q.back.link = e
q.back = e
q.nelts++
}
func (q *queue) pop() jmessages {
out := q.front.link
q.front.link = out.link
if out == q.back {
q.back = q.front
}
q.nelts--
data := out.data
q.release(out)
return data
}
type entry struct {
data jmessages
link *entry
}

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package jrpc2_test package jrpc2_test
import ( import (
@ -10,11 +12,15 @@ import (
"github.com/creachadair/jrpc2/channel" "github.com/creachadair/jrpc2/channel"
"github.com/creachadair/jrpc2/handler" "github.com/creachadair/jrpc2/handler"
"github.com/creachadair/jrpc2/server" "github.com/creachadair/jrpc2/server"
"github.com/fortytw2/leaktest"
"github.com/google/go-cmp/cmp"
) )
// Verify that a notification handler will not deadlock with the dispatcher on // Verify that a notification handler will not deadlock with the dispatcher on
// holding the server lock. See: https://github.com/creachadair/jrpc2/issues/27 // holding the server lock. See: https://github.com/creachadair/jrpc2/issues/27
func TestLockRaceRegression(t *testing.T) { func TestLockRaceRegression(t *testing.T) {
defer leaktest.Check(t)()
hdone := make(chan struct{}) hdone := make(chan struct{})
local := server.NewLocal(handler.Map{ local := server.NewLocal(handler.Map{
// Do some busy-work and then try to get the server lock, in this case // Do some busy-work and then try to get the server lock, in this case
@ -54,6 +60,8 @@ func TestLockRaceRegression(t *testing.T) {
// Verify that if a callback handler panics, the client will report an error // Verify that if a callback handler panics, the client will report an error
// back to the server. See https://github.com/creachadair/jrpc2/issues/41. // back to the server. See https://github.com/creachadair/jrpc2/issues/41.
func TestOnCallbackPanicRegression(t *testing.T) { func TestOnCallbackPanicRegression(t *testing.T) {
defer leaktest.Check(t)()
const panicString = "the devil you say" const panicString = "the devil you say"
loc := server.NewLocal(handler.Map{ loc := server.NewLocal(handler.Map{
@ -89,6 +97,8 @@ func TestOnCallbackPanicRegression(t *testing.T) {
// Verify that a duplicate request ID that arrives while a task is in flight // Verify that a duplicate request ID that arrives while a task is in flight
// does not cause the existing task to be cancelled. // does not cause the existing task to be cancelled.
func TestDuplicateIDCancellation(t *testing.T) { func TestDuplicateIDCancellation(t *testing.T) {
defer leaktest.Check(t)()
tctx, cancel := context.WithCancel(context.Background()) tctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -132,7 +142,7 @@ func TestDuplicateIDCancellation(t *testing.T) {
// Send the duplicate, which should report an error. // Send the duplicate, which should report an error.
send(duplicateReq) send(duplicateReq)
expect(`{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"duplicate request id \"1\""}}`) expect(`{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"duplicate request ID","data":"1"}}`)
// Unblock the handler, which should now complete. If the duplicate request // Unblock the handler, which should now complete. If the duplicate request
// caused the handler to cancel, it will have logged an error to fail the test. // caused the handler to cancel, it will have logged an error to fail the test.
@ -142,3 +152,43 @@ func TestDuplicateIDCancellation(t *testing.T) {
cch.Close() cch.Close()
srv.Wait() srv.Wait()
} }
func TestCheckBatchDuplicateID(t *testing.T) {
defer leaktest.Check(t)()
srv, cli := channel.Direct()
s := jrpc2.NewServer(handler.Map{
"Test": testOK,
}, nil).Start(srv)
defer func() {
cli.Close()
if err := s.Wait(); err != nil {
t.Errorf("Server wait: unexpected error: %v", err)
}
}()
// A batch of requests containing two calls with the same ID.
const input = `[
{"jsonrpc": "2.0", "id": 1, "method": "Test"},
{"jsonrpc": "2.0", "id": 1, "method": "Test"},
{"jsonrpc": "2.0", "id": 2, "method": "Test"}
]
`
const errorReply = `{` +
`"jsonrpc":"2.0",` +
`"id":1,` +
`"error":{"code":-32600,"message":"duplicate request ID","data":"1"}` +
`}`
const want = `[` + errorReply + `,` + errorReply + `,` + `{"jsonrpc":"2.0","id":2,"result":"OK"}]`
if err := cli.Send([]byte(input)); err != nil {
t.Fatalf("Send %d bytes failed: %v", len(input), err)
}
rsp, err := cli.Recv()
if err != nil {
t.Fatalf("Recv failed: %v", err)
}
if diff := cmp.Diff(want, string(rsp)); diff != "" {
t.Errorf("Server response: (-want, +got)\n%s", diff)
}
}

View File

@ -1,7 +1,8 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package jrpc2 package jrpc2
import ( import (
"container/list"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
@ -26,14 +27,11 @@ type Server struct {
sem *semaphore.Weighted // bounds concurrent execution (default 1) sem *semaphore.Weighted // bounds concurrent execution (default 1)
// Configurable settings // Configurable settings
allow1 bool // allow v1 requests with no version marker
allowP bool // allow server notifications to the client allowP bool // allow server notifications to the client
log func(string, ...interface{}) // write debug logs here log func(string, ...interface{}) // write debug logs here
rpcLog RPCLogger // log RPC requests and responses here rpcLog RPCLogger // log RPC requests and responses here
newctx func() context.Context // create a new base request context newctx func() context.Context // create a new base request context
dectx decoder // decode context from request dectx decoder // decode context from request
ckreq verifier // request checking hook
expctx bool // whether to expect request context
metrics *metrics.M // metrics collected during execution metrics *metrics.M // metrics collected during execution
start time.Time // when Start was called start time.Time // when Start was called
builtin bool // whether built-in rpc.* methods are enabled builtin bool // whether built-in rpc.* methods are enabled
@ -42,8 +40,8 @@ type Server struct {
nbar sync.WaitGroup // notification barrier (see the dispatch method) nbar sync.WaitGroup // notification barrier (see the dispatch method)
err error // error from a previous operation err error // error from a previous operation
work *sync.Cond // for signaling message availability work chan struct{} // for signaling message availability
inq *list.List // inbound requests awaiting processing inq *queue // inbound requests awaiting processing
ch channel.Channel // the channel to the client ch channel.Channel // the channel to the client
// For each request ID currently in-flight, this map carries a cancel // For each request ID currently in-flight, this map carries a cancel
@ -67,28 +65,23 @@ func NewServer(mux Assigner, opts *ServerOptions) *Server {
if mux == nil { if mux == nil {
panic("nil assigner") panic("nil assigner")
} }
dc, exp := opts.decodeContext()
s := &Server{ s := &Server{
mux: mux, mux: mux,
sem: semaphore.NewWeighted(opts.concurrency()), sem: semaphore.NewWeighted(opts.concurrency()),
allow1: opts.allowV1(),
allowP: opts.allowPush(), allowP: opts.allowPush(),
log: opts.logFunc(), log: opts.logFunc(),
rpcLog: opts.rpcLog(), rpcLog: opts.rpcLog(),
newctx: opts.newContext(), newctx: opts.newContext(),
dectx: dc, dectx: opts.decodeContext(),
ckreq: opts.checkRequest(),
expctx: exp,
mu: new(sync.Mutex), mu: new(sync.Mutex),
metrics: opts.metrics(), metrics: opts.metrics(),
start: opts.startTime(), start: opts.startTime(),
builtin: opts.allowBuiltin(), builtin: opts.allowBuiltin(),
inq: list.New(), inq: newQueue(),
used: make(map[string]context.CancelFunc), used: make(map[string]context.CancelFunc),
call: make(map[string]*Response), call: make(map[string]*Response),
callID: 1, callID: 1,
} }
s.work = sync.NewCond(s.mu)
return s return s
} }
@ -107,10 +100,14 @@ func (s *Server) Start(c channel.Channel) *Server {
if s.start.IsZero() { if s.start.IsZero() {
s.start = time.Now().In(time.UTC) s.start = time.Now().In(time.UTC)
} }
s.metrics.Count("rpc.serversActive", 1)
// Reset all the I/O structures and start up the workers. // Reset all the I/O structures and start up the workers.
s.err = nil s.err = nil
// Reset the signal channel.
s.work = make(chan struct{}, 1)
// s.wg waits for the maintenance goroutines for receiving input and // s.wg waits for the maintenance goroutines for receiving input and
// processing the request queue. In addition, each request in flight adds a // processing the request queue. In addition, each request in flight adds a
// goroutine to s.wg. At server shutdown, s.wg completes when the // goroutine to s.wg. At server shutdown, s.wg completes when the
@ -157,6 +154,13 @@ func (s *Server) serve() {
} }
} }
func (s *Server) signal() {
select {
case s.work <- struct{}{}:
default:
}
}
// nextRequest blocks until a request batch is available and returns a function // nextRequest blocks until a request batch is available and returns a function
// that dispatches it to the appropriate handlers. The result is only an error // that dispatches it to the appropriate handlers. The result is only an error
// if the connection failed; errors reported by the handler are reported to the // if the connection failed; errors reported by the handler are reported to the
@ -166,16 +170,18 @@ func (s *Server) serve() {
func (s *Server) nextRequest() (func() error, error) { func (s *Server) nextRequest() (func() error, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
for s.ch != nil && s.inq.Len() == 0 { for s.ch != nil && s.inq.isEmpty() {
s.work.Wait() s.mu.Unlock()
<-s.work
s.mu.Lock()
} }
if s.ch == nil && s.inq.Len() == 0 { if s.ch == nil && s.inq.isEmpty() {
return nil, s.err return nil, s.err
} }
ch := s.ch // capture ch := s.ch // capture
next := s.inq.Remove(s.inq.Front()).(jmessages) next := s.inq.pop()
s.log("Dequeued request batch of length %d (qlen=%d)", len(next), s.inq.Len()) s.log("Dequeued request batch of length %d (qlen=%d)", len(next), s.inq.size())
// Construct a dispatcher to run the handlers outside the lock. // Construct a dispatcher to run the handlers outside the lock.
return s.dispatch(next, ch), nil return s.dispatch(next, ch), nil
@ -207,32 +213,35 @@ func (s *Server) dispatch(next jmessages, ch sender) func() error {
// Resolve all the task handlers or record errors. // Resolve all the task handlers or record errors.
start := time.Now() start := time.Now()
tasks := s.checkAndAssign(next) tasks := s.checkAndAssign(next)
last := len(tasks) - 1
// Ensure all notifications already issued have completed; see #24. // Ensure all notifications already issued have completed; see #24.
s.waitForBarrier(tasks.numValidNotifications()) todo, notes := tasks.numToDo()
s.waitForBarrier(notes)
return func() error { return func() error {
var wg sync.WaitGroup var wg sync.WaitGroup
for i, t := range tasks { for _, t := range tasks {
if t.err != nil { if t.err != nil {
continue // nothing to do here; this task has already failed continue // nothing to do here; this task has already failed
} }
t := t
wg.Add(1) todo--
run := func() { if todo == 0 {
defer wg.Done()
if t.hreq.IsNotification() {
defer s.nbar.Done()
}
t.val, t.err = s.invoke(t.ctx, t.m, t.hreq) t.val, t.err = s.invoke(t.ctx, t.m, t.hreq)
if t.hreq.IsNotification() {
s.nbar.Done()
} }
if i < last { break
go run()
} else {
run()
} }
t := t
wg.Add(1)
go func() {
defer wg.Done()
t.val, t.err = s.invoke(t.ctx, t.m, t.hreq)
if t.hreq.IsNotification() {
s.nbar.Done()
}
}()
} }
// Wait for all the handlers to return, then deliver any responses. // Wait for all the handlers to return, then deliver any responses.
@ -269,16 +278,15 @@ func (s *Server) deliver(rsps jmessages, ch sender, elapsed time.Duration) error
// records errors for them as appropriate. The caller must hold s.mu. // records errors for them as appropriate. The caller must hold s.mu.
func (s *Server) checkAndAssign(next jmessages) tasks { func (s *Server) checkAndAssign(next jmessages) tasks {
var ts tasks var ts tasks
var ids []string
dup := make(map[string]*task) // :: id ⇒ first task in batch with id
// Phase 1: Filter out responses from push calls and check for duplicate
// request ID.s
for _, req := range next { for _, req := range next {
fid := fixID(req.ID) fid := fixID(req.ID)
t := &task{
hreq: &Request{id: fid, method: req.M, params: req.P},
batch: req.batch,
}
id := string(fid) id := string(fid)
if req.err != nil { if !req.isRequestOrNotification() && s.call[id] != nil {
t.err = req.err // deferred validation error
} else if !req.isRequestOrNotification() && s.call[id] != nil {
// This is a result or error for a pending push-call. // This is a result or error for a pending push-call.
// //
// N.B. It is important to check for this before checking for // N.B. It is important to check for this before checking for
@ -287,24 +295,51 @@ func (s *Server) checkAndAssign(next jmessages) tasks {
delete(s.call, id) delete(s.call, id)
rsp.ch <- req rsp.ch <- req
continue // don't send a reply for this continue // don't send a reply for this
} else if id != "" && s.used[id] != nil { } else if req.err != nil {
t.err = Errorf(code.InvalidRequest, "duplicate request id %q", id) // keep the existing error
} else if !s.versionOK(req.V) { } else if !s.versionOK(req.V) {
t.err = ErrInvalidVersion req.err = ErrInvalidVersion
} else if req.M == "" { }
t := &task{
hreq: &Request{id: fid, method: req.M, params: req.P},
batch: req.batch,
err: req.err,
}
if old := dup[id]; old != nil {
// A previous task already used this ID, fail both.
old.err = errDuplicateID.WithData(id)
t.err = old.err
} else if id != "" && s.used[id] != nil {
// A task from a previous batch already used this ID, fail this one.
t.err = errDuplicateID.WithData(id)
} else if id != "" {
// This is the first task with this ID in the batch.
dup[id] = t
}
ts = append(ts, t)
ids = append(ids, id)
}
// Phase 2: Assign method handlers and set up contexts.
for i, t := range ts {
id := ids[i]
if t.err != nil {
// deferred validation error
} else if t.hreq.method == "" {
t.err = errEmptyMethod t.err = errEmptyMethod
} else if s.setContext(t, id) { } else if s.setContext(t, id) {
t.m = s.assign(t.ctx, req.M) t.m = s.assign(t.ctx, t.hreq.method)
if t.m == nil { if t.m == nil {
t.err = Errorf(code.MethodNotFound, "no such method %q", req.M) t.err = errNoSuchMethod.WithData(t.hreq.method)
} }
} }
if t.err != nil { if t.err != nil {
s.log("Request check error for %q (params %q): %v", req.M, string(req.P), t.err) s.log("Request check error for %q (params %q): %v",
t.hreq.method, string(t.hreq.params), t.err)
s.metrics.Count("rpc.errors", 1) s.metrics.Count("rpc.errors", 1)
} }
ts = append(ts, t)
} }
return ts return ts
} }
@ -319,12 +354,6 @@ func (s *Server) setContext(t *task, id string) bool {
return false return false
} }
// Check request.
if err := s.ckreq(base, t.hreq); err != nil {
t.err = err
return false
}
t.ctx = context.WithValue(base, inboundRequestKey{}, t.hreq) t.ctx = context.WithValue(base, inboundRequestKey{}, t.hreq)
// Store the cancellation for a request that needs a reply, so that we can // Store the cancellation for a request that needs a reply, so that we can
@ -361,13 +390,15 @@ func (s *Server) invoke(base context.Context, h Handler, req *Request) (json.Raw
// ServerInfo returns an atomic snapshot of the current server info for s. // ServerInfo returns an atomic snapshot of the current server info for s.
func (s *Server) ServerInfo() *ServerInfo { func (s *Server) ServerInfo() *ServerInfo {
info := &ServerInfo{ info := &ServerInfo{
Methods: s.mux.Names(), Methods: []string{"*"},
UsesContext: s.expctx,
StartTime: s.start, StartTime: s.start,
Counter: make(map[string]int64), Counter: make(map[string]int64),
MaxValue: make(map[string]int64), MaxValue: make(map[string]int64),
Label: make(map[string]interface{}), Label: make(map[string]interface{}),
} }
if n, ok := s.mux.(Namer); ok {
info.Methods = n.Names()
}
s.metrics.Snapshot(metrics.Snapshot{ s.metrics.Snapshot(metrics.Snapshot{
Counter: info.Counter, Counter: info.Counter,
MaxValue: info.MaxValue, MaxValue: info.MaxValue,
@ -524,7 +555,7 @@ func (s ServerStatus) Success() bool { return s.Err == nil }
func (s *Server) WaitStatus() ServerStatus { func (s *Server) WaitStatus() ServerStatus {
s.wg.Wait() s.wg.Wait()
// Postcondition check. // Postcondition check.
if s.inq.Len() != 0 { if !s.inq.isEmpty() {
panic("s.inq is not empty at shutdown") panic("s.inq is not empty at shutdown")
} }
stat := ServerStatus{Err: s.err} stat := ServerStatus{Err: s.err}
@ -557,8 +588,8 @@ func (s *Server) stop(err error) {
// //
// TODO(@creachadair): We need better tests for this behaviour. // TODO(@creachadair): We need better tests for this behaviour.
var keep jmessages var keep jmessages
for cur := s.inq.Front(); cur != nil; cur = s.inq.Front() { s.inq.each(func(cur jmessages) {
for _, req := range cur.Value.(jmessages) { for _, req := range cur {
if req.isNotification() { if req.isNotification() {
keep = append(keep, req) keep = append(keep, req)
s.log("Retaining notification %p", req) s.log("Retaining notification %p", req)
@ -566,18 +597,17 @@ func (s *Server) stop(err error) {
s.cancel(string(req.ID)) s.cancel(string(req.ID))
} }
} }
s.inq.Remove(cur) })
} s.inq.reset()
for _, elt := range keep { for _, elt := range keep {
s.inq.PushBack(jmessages{elt}) s.inq.push(jmessages{elt})
} }
s.work.Broadcast() close(s.work)
// Cancel any in-flight requests that made it out of the queue, and // Cancel any in-flight requests that made it out of the queue, and
// terminate any pending callback invocations. // terminate any pending callback invocations.
for id, rsp := range s.call { for _, rsp := range s.call {
delete(s.call, id) rsp.cancel() // the waiter will clean up the map
rsp.cancel()
} }
for id, cancel := range s.used { for id, cancel := range s.used {
cancel() cancel()
@ -591,6 +621,7 @@ func (s *Server) stop(err error) {
s.err = err s.err = err
s.ch = nil s.ch = nil
s.metrics.Count("rpc.serversActive", -1)
} }
// read is the main receiver loop, decoding requests from the client and adding // read is the main receiver loop, decoding requests from the client and adding
@ -620,9 +651,11 @@ func (s *Server) read(ch receiver) {
} else if len(in) == 0 { } else if len(in) == 0 {
s.pushError(errEmptyBatch) s.pushError(errEmptyBatch)
} else { } else {
s.log("Received request batch of size %d (qlen=%d)", len(in), s.inq.Len()) s.log("Received request batch of size %d (qlen=%d)", len(in), s.inq.size())
s.inq.PushBack(in) s.inq.push(in)
s.work.Broadcast() if s.inq.size() == 1 { // the queue was empty
s.signal()
}
} }
s.mu.Unlock() s.mu.Unlock()
} }
@ -633,9 +666,6 @@ type ServerInfo struct {
// The list of method names exported by this server. // The list of method names exported by this server.
Methods []string `json:"methods,omitempty"` Methods []string `json:"methods,omitempty"`
// Whether this server understands context wrappers.
UsesContext bool `json:"usesContext"`
// Metric values defined by the evaluation of methods. // Metric values defined by the evaluation of methods.
Counter map[string]int64 `json:"counters,omitempty"` Counter map[string]int64 `json:"counters,omitempty"`
MaxValue map[string]int64 `json:"maxValue,omitempty"` MaxValue map[string]int64 `json:"maxValue,omitempty"`
@ -694,12 +724,7 @@ func (s *Server) cancel(id string) bool {
return ok return ok
} }
func (s *Server) versionOK(v string) bool { func (s *Server) versionOK(v string) bool { return v == Version }
if v == "" {
return s.allow1 // an empty version is OK if the server allows it
}
return v == Version // ... otherwise it must match the spec
}
// A task represents a pending method invocation received by the server. // A task represents a pending method invocation received by the server.
type task struct { type task struct {
@ -758,12 +783,15 @@ func (ts tasks) responses(rpcLog RPCLogger) jmessages {
return rsps return rsps
} }
// numValidNotifications reports the number of elements in ts that are // numToDo reports the number of tasks in ts that need to be executed, and the
// syntactically valid notifications. // number of those that are notifications.
func (ts tasks) numValidNotifications() (n int) { func (ts tasks) numToDo() (todo, notes int) {
for _, t := range ts { for _, t := range ts {
if t.err == nil && t.hreq.IsNotification() { if t.err == nil {
n++ todo++
if t.hreq.IsNotification() {
notes++
}
} }
} }
return return

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package server_test package server_test
import ( import (

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package server package server
import ( import (

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package server_test package server_test
import ( import (
@ -9,6 +11,7 @@ import (
"github.com/creachadair/jrpc2" "github.com/creachadair/jrpc2"
"github.com/creachadair/jrpc2/handler" "github.com/creachadair/jrpc2/handler"
"github.com/creachadair/jrpc2/server" "github.com/creachadair/jrpc2/server"
"github.com/fortytw2/leaktest"
) )
var doDebug = flag.Bool("debug", false, "Enable server and client debugging logs") var doDebug = flag.Bool("debug", false, "Enable server and client debugging logs")
@ -24,6 +27,8 @@ func testOpts(t *testing.T) *server.LocalOptions {
} }
func TestLocal(t *testing.T) { func TestLocal(t *testing.T) {
defer leaktest.Check(t)()
loc := server.NewLocal(make(handler.Map), testOpts(t)) loc := server.NewLocal(make(handler.Map), testOpts(t))
ctx := context.Background() ctx := context.Background()
si, err := jrpc2.RPCServerInfo(ctx, loc.Client) si, err := jrpc2.RPCServerInfo(ctx, loc.Client)
@ -47,6 +52,8 @@ func TestLocal(t *testing.T) {
// Test that concurrent callers to a local service do not deadlock. // Test that concurrent callers to a local service do not deadlock.
func TestLocalConcurrent(t *testing.T) { func TestLocalConcurrent(t *testing.T) {
defer leaktest.Check(t)()
loc := server.NewLocal(handler.Map{ loc := server.NewLocal(handler.Map{
"Test": handler.New(func(context.Context) error { return nil }), "Test": handler.New(func(context.Context) error { return nil }),
}, testOpts(t)) }, testOpts(t))

View File

@ -1,7 +1,10 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
// Package server provides support routines for running jrpc2 servers. // Package server provides support routines for running jrpc2 servers.
package server package server
import ( import (
"context"
"net" "net"
"sync" "sync"
@ -37,8 +40,9 @@ func (static) Finish(jrpc2.Assigner, jrpc2.ServerStatus) {}
// An Accepter obtains client connections from an external source and // An Accepter obtains client connections from an external source and
// constructs channels from them. // constructs channels from them.
type Accepter interface { type Accepter interface {
// Accept accepts a connection and returns a new channel for it. // Accept blocks until a connection is available, or until ctx ends.
Accept() (channel.Channel, error) // If a connection is found, Accept returns a new channel for it.
Accept(ctx context.Context) (channel.Channel, error)
} }
// NetAccepter adapts a net.Listener to the Accepter interface, using f as the // NetAccepter adapts a net.Listener to the Accepter interface, using f as the
@ -52,7 +56,20 @@ type netAccepter struct {
newChannel channel.Framing newChannel channel.Framing
} }
func (n netAccepter) Accept() (channel.Channel, error) { func (n netAccepter) Accept(ctx context.Context) (channel.Channel, error) {
// A net.Listener does not obey a context, so simulate it by closing the
// listener if ctx ends.
ok := make(chan struct{})
defer close(ok)
go func() {
select {
case <-ctx.Done():
n.Listener.Close()
case <-ok:
return
}
}()
conn, err := n.Listener.Accept() conn, err := n.Listener.Accept()
if err != nil { if err != nil {
return nil, err return nil, err
@ -64,10 +81,10 @@ func (n netAccepter) Accept() (channel.Channel, error) {
// service instance returned by newService and the given options. Each server // service instance returned by newService and the given options. Each server
// runs in a new goroutine. // runs in a new goroutine.
// //
// If the listener reports an error, the loop will terminate and that error // If lst is closed or otherwise reports an error, the loop will terminate.
// will be reported to the caller of Loop once any active servers have // The error will be reported to the caller of Loop once any active servers
// returned. // have returned. In addition, if ctx ends, any active servers will be stopped.
func Loop(lst Accepter, newService func() Service, opts *LoopOptions) error { func Loop(ctx context.Context, lst Accepter, newService func() Service, opts *LoopOptions) error {
serverOpts := opts.serverOpts() serverOpts := opts.serverOpts()
log := func(string, ...interface{}) {} log := func(string, ...interface{}) {}
if serverOpts != nil && serverOpts.Logger != nil { if serverOpts != nil && serverOpts.Logger != nil {
@ -76,7 +93,7 @@ func Loop(lst Accepter, newService func() Service, opts *LoopOptions) error {
var wg sync.WaitGroup var wg sync.WaitGroup
for { for {
ch, err := lst.Accept() ch, err := lst.Accept(ctx)
if err != nil { if err != nil {
if channel.IsErrClosing(err) { if channel.IsErrClosing(err) {
err = nil err = nil
@ -89,13 +106,20 @@ func Loop(lst Accepter, newService func() Service, opts *LoopOptions) error {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
svc := newService() svc := newService()
assigner, err := svc.Assigner() assigner, err := svc.Assigner()
if err != nil { if err != nil {
log("Service initialization failed: %v", err) log("Service initialization failed: %v", err)
return return
} }
sctx, cancel := context.WithCancel(ctx)
defer cancel()
srv := jrpc2.NewServer(assigner, serverOpts).Start(ch) srv := jrpc2.NewServer(assigner, serverOpts).Start(ch)
go func() { <-sctx.Done(); srv.Stop() }()
stat := srv.WaitStatus() stat := srv.WaitStatus()
svc.Finish(assigner, stat) svc.Finish(assigner, stat)
if stat.Err != nil { if stat.Err != nil {

View File

@ -1,4 +1,6 @@
package server // Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package server_test
import ( import (
"context" "context"
@ -11,12 +13,14 @@ import (
"github.com/creachadair/jrpc2" "github.com/creachadair/jrpc2"
"github.com/creachadair/jrpc2/channel" "github.com/creachadair/jrpc2/channel"
"github.com/creachadair/jrpc2/handler" "github.com/creachadair/jrpc2/handler"
"github.com/creachadair/jrpc2/server"
"github.com/fortytw2/leaktest"
) )
var newChan = channel.Line var newChan = channel.Line
// A static test service that returns the same thing each time. // A static test service that returns the same thing each time.
var testService = Static(handler.Map{ var testStatic = server.Static(handler.Map{
"Test": handler.New(func(context.Context) (string, error) { "Test": handler.New(func(context.Context) (string, error) {
return "OK", nil return "OK", nil
}), }),
@ -29,8 +33,8 @@ type testSession struct {
nCall int nCall int
} }
func newTestSession(t *testing.T) func() Service { func newTestSession(t *testing.T) func() server.Service {
return func() Service { t.Helper(); return &testSession{t: t} } return func() server.Service { t.Helper(); return &testSession{t: t} }
} }
func (t *testSession) Assigner() (jrpc2.Assigner, error) { func (t *testSession) Assigner() (jrpc2.Assigner, error) {
@ -81,27 +85,27 @@ func mustDial(t *testing.T, addr string) *jrpc2.Client {
return jrpc2.NewClient(newChan(conn, conn), nil) return jrpc2.NewClient(newChan(conn, conn), nil)
} }
func mustServe(t *testing.T, lst net.Listener, newService func() Service) <-chan struct{} { func mustServe(t *testing.T, ctx context.Context, lst net.Listener, newService func() server.Service) <-chan error {
t.Helper() t.Helper()
sc := make(chan struct{}) acc := server.NetAccepter(lst, newChan)
errc := make(chan error, 1)
go func() { go func() {
defer close(sc) defer close(errc)
// Start a server loop to accept connections from the clients. This should // Start a server loop to accept connections from the clients. This should
// exit cleanly once all the clients have finished and the listener closes. // exit cleanly once all the clients have finished and the listener closes.
lst := NetAccepter(lst, newChan) errc <- server.Loop(ctx, acc, newService, nil)
if err := Loop(lst, newService, nil); err != nil {
t.Errorf("Loop: unexpected failure: %v", err)
}
}() }()
return sc return errc
} }
// Test that sequential clients against the same server work sanely. // Test that sequential clients against the same server work sanely.
func TestSeq(t *testing.T) { func TestSeq(t *testing.T) {
defer leaktest.Check(t)()
lst := mustListen(t) lst := mustListen(t)
addr := lst.Addr().String() addr := lst.Addr().String()
sc := mustServe(t, lst, testService) errc := mustServe(t, context.Background(), lst, testStatic)
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
cli := mustDial(t, addr) cli := mustDial(t, addr)
@ -114,16 +118,82 @@ func TestSeq(t *testing.T) {
cli.Close() cli.Close()
} }
lst.Close() lst.Close()
<-sc if err := <-errc; err != nil {
t.Errorf("Server exit failed: %v", err)
}
}
// Test that context plumbing works properly.
func TestLoop_cancelContext(t *testing.T) {
defer leaktest.Check(t)()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
lst := mustListen(t)
defer lst.Close()
errc := mustServe(t, ctx, lst, testStatic)
time.AfterFunc(50*time.Millisecond, cancel)
select {
case err := <-errc:
if err != nil {
t.Errorf("Loop exit reported error: %v", err)
}
case <-time.After(1 * time.Second):
t.Fatalf("Loop did not exit in a timely manner after cancellation")
}
}
// Test that cancelling a loop stops its servers.
func TestLoop_cancelServers(t *testing.T) {
defer leaktest.Check(t)()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ready := make(chan struct{})
lst := mustListen(t)
errc := mustServe(t, ctx, lst, server.Static(handler.Map{
"Test": handler.New(func(ctx context.Context) error {
// Signal readiness then block until cancelled.
// The server will cancel this method when stopped.
close(ready)
<-ctx.Done()
return ctx.Err()
}),
}))
cli := mustDial(t, lst.Addr().String())
defer cli.Close()
// Issue a call to the server that will block until the server cancels the
// handler at shutdown. If the server blocks after cancellation, it means it
// is not correctly stopping its active servers.
go cli.Call(context.Background(), "Test", nil)
<-ready
cancel() // this should stop the loop and the server
select {
case err := <-errc:
if err != nil {
t.Errorf("Loop result: %v", err)
}
case <-time.After(1 * time.Second):
t.Error("Loop did not exit in a timely manner after cancellation")
}
} }
// Test that concurrent clients against the same server work sanely. // Test that concurrent clients against the same server work sanely.
func TestLoop(t *testing.T) { func TestLoop(t *testing.T) {
defer leaktest.Check(t)()
tests := []struct { tests := []struct {
desc string desc string
cons func() Service cons func() server.Service
}{ }{
{"StaticService", testService}, {"StaticService", testStatic},
{"SessionStateService", newTestSession(t)}, {"SessionStateService", newTestSession(t)},
} }
const numClients = 5 const numClients = 5
@ -133,7 +203,7 @@ func TestLoop(t *testing.T) {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
lst := mustListen(t) lst := mustListen(t)
addr := lst.Addr().String() addr := lst.Addr().String()
sc := mustServe(t, lst, test.cons) errc := mustServe(t, context.Background(), lst, test.cons)
// Start a bunch of clients, each of which will dial the server and make // Start a bunch of clients, each of which will dial the server and make
// some calls at random intervals to tickle the race detector. // some calls at random intervals to tickle the race detector.
@ -162,7 +232,9 @@ func TestLoop(t *testing.T) {
// the service loop will stop. // the service loop will stop.
wg.Wait() wg.Wait()
lst.Close() lst.Close()
<-sc if err := <-errc; err != nil {
t.Errorf("Server exit failed: %v", err)
}
}) })
} }
} }

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package server package server
import ( import (

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package server_test package server_test
import ( import (

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package jrpc2 package jrpc2
import ( import (

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
// Program adder demonstrates a trivial JSON-RPC server that communicates over // Program adder demonstrates a trivial JSON-RPC server that communicates over
// the process's stdin and stdout. // the process's stdin and stdout.
// //
@ -22,7 +24,7 @@ import (
) )
// Add will be exported as a method named "Add". // Add will be exported as a method named "Add".
func Add(ctx context.Context, vs ...int) int { func Add(ctx context.Context, vs []int) int {
sum := 0 sum := 0
for _, v := range vs { for _, v := range vs {
sum += v sum += v

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
// Program client demonstrates how to set up a JSON-RPC 2.0 client using the // Program client demonstrates how to set up a JSON-RPC 2.0 client using the
// github.com/creachadair/jrpc2 package. // github.com/creachadair/jrpc2 package.
// //

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
// Program http demonstrates how to set up a JSON-RPC 2.0 server using the // Program http demonstrates how to set up a JSON-RPC 2.0 server using the
// github.com/creachadair/jrpc2 package with an HTTP transport. // github.com/creachadair/jrpc2 package with an HTTP transport.
// //
@ -45,6 +47,6 @@ func main() {
log.Fatal(http.ListenAndServe(*listenAddr, nil)) log.Fatal(http.ListenAndServe(*listenAddr, nil))
} }
func ping(ctx context.Context, msg ...string) string { func ping(ctx context.Context, msg []string) string {
return "OK: " + strings.Join(msg, "|") return "OK: " + strings.Join(msg, "|")
} }

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
// Program server demonstrates how to set up a JSON-RPC 2.0 server using the // Program server demonstrates how to set up a JSON-RPC 2.0 server using the
// github.com/creachadair/jrpc2 package. // github.com/creachadair/jrpc2 package.
// //
@ -97,7 +99,8 @@ func main() {
} }
log.Printf("Listening at %v...", lst.Addr()) log.Printf("Listening at %v...", lst.Addr())
acc := server.NetAccepter(lst, channel.Line) acc := server.NetAccepter(lst, channel.Line)
server.Loop(acc, server.Static(mux), &server.LoopOptions{ ctx := context.Background()
server.Loop(ctx, acc, server.Static(mux), &server.LoopOptions{
ServerOptions: &jrpc2.ServerOptions{ ServerOptions: &jrpc2.ServerOptions{
Logger: jrpc2.StdLogger(nil), Logger: jrpc2.StdLogger(nil),
Concurrency: *maxTasks, Concurrency: *maxTasks,

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
// Program wshttp demonstrates how to set up a JSON-RPC 20 server using // Program wshttp demonstrates how to set up a JSON-RPC 20 server using
// the github.com/creachadair/jrpc2 package with a Websocket transport. // the github.com/creachadair/jrpc2 package with a Websocket transport.
// //
@ -34,19 +36,18 @@ func main() {
http.Handle("/rpc", lst) http.Handle("/rpc", lst)
go hs.ListenAndServe() go hs.ListenAndServe()
acc := accepter{ svc := server.Static(handler.Map{
Listener: lst, "Reverse": handler.New(reverse),
ctx: context.Background(), })
}
svc := handler.Map{"Reverse": handler.New(reverse)}
log.Printf("Listing at ws://%s/rpc", *listenAddr) log.Printf("Listening at ws://%s/rpc", *listenAddr)
err := server.Loop(acc, server.Static(svc), &server.LoopOptions{ ctx := context.Background()
err := server.Loop(ctx, accepter{lst}, svc, &server.LoopOptions{
ServerOptions: &jrpc2.ServerOptions{ ServerOptions: &jrpc2.ServerOptions{
Logger: jrpc2.StdLogger(nil), Logger: jrpc2.StdLogger(nil),
}, },
}) })
hs.Shutdown(acc.ctx) hs.Shutdown(ctx)
if err != nil { if err != nil {
log.Fatalf("Loop exited: %v", err) log.Fatalf("Loop exited: %v", err)
} }
@ -60,11 +61,8 @@ func reverse(_ context.Context, ss []string) []string {
return ss return ss
} }
type accepter struct { type accepter struct{ *wschannel.Listener }
*wschannel.Listener
ctx context.Context
}
func (a accepter) Accept() (channel.Channel, error) { func (a accepter) Accept(ctx context.Context) (channel.Channel, error) {
return a.Listener.Accept(a.ctx) return a.Listener.Accept(ctx)
} }

View File

@ -3,8 +3,8 @@ module github.com/creachadair/jrpc2/tools
go 1.17 go 1.17
require ( require (
github.com/creachadair/jrpc2 v0.30.3 github.com/creachadair/jrpc2 v0.35.4
github.com/creachadair/wschannel v0.0.0-20211118152247-10d58f4f0def github.com/creachadair/wschannel v0.0.0-20220126134344-769774727b29
) )
require ( require (

View File

@ -1,9 +1,11 @@
github.com/creachadair/jrpc2 v0.30.3 h1:fz8xYfTmIgxJXvr9HAoz0XBOpNklyixE7Hnh6iQP/4o= github.com/creachadair/jrpc2 v0.35.4 h1:5ELLV7CMKLfALzkKNsQ//ngZLWDbEmAXgTgkL3JXAcU=
github.com/creachadair/jrpc2 v0.30.3/go.mod h1:w+GXZGc+NwsH0xsUOgeLBIIRM0jBOSTXhv28KaWGRZU= github.com/creachadair/jrpc2 v0.35.4/go.mod h1:a53Cer/NMD1y8P9UB2XbuOLRELKRLDf8u7bRi4v1qsE=
github.com/creachadair/wschannel v0.0.0-20211118152247-10d58f4f0def h1:FV0vHCqItsi0b3LwaEKyxj0su3VKdvbenCOkXnCAXnI= github.com/creachadair/wschannel v0.0.0-20220126134344-769774727b29 h1:EtcZoRTuhqCedRtvfUrzuyrsT53RWNN7xZOE9lljDw0=
github.com/creachadair/wschannel v0.0.0-20211118152247-10d58f4f0def/go.mod h1:/9Csuxj8r9h0YXexL0WmkahIhd85BleYWz7nt42ZgDc= github.com/creachadair/wschannel v0.0.0-20220126134344-769774727b29/go.mod h1:xFi56wWYs7X0OlNzbtz/yzLCuN3a8Hf36QALYnAsO0o=
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g=
github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o=
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=

View File

@ -1,3 +1,5 @@
// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
// Program jcall issues RPC calls to a JSON-RPC server. // Program jcall issues RPC calls to a JSON-RPC server.
// //
// Usage: // Usage:

View File

@ -42,3 +42,44 @@ func Test_AEAD_Cipher(t *testing.T) {
t.Fatalf("AEAD cipher encryption/decryption failed") t.Fatalf("AEAD cipher encryption/decryption failed")
} }
} }
// functional test whether the wrappers are okay
func Test_Signed_Message_Cipher(t *testing.T) {
wsrc, err := Create_Encrypted_Wallet_From_Recovery_Words_Memory("", "sequence atlas unveil summon pebbles tuesday beer rudely snake rockets different fuselage woven tagged bested dented vegan hover rapid fawns obvious muppet randomly seasons randomly")
if err != nil {
t.Fatalf("Cannot create encrypted wallet, err %s", err)
}
result := wsrc.SignData([]byte("HELLO"))
signer, message, err := wsrc.CheckSignature(result)
if err != nil {
t.Fatalf("Cannot check signature, err %s", err)
}
if string(message) != "HELLO" {
t.Fatalf("Message corruption")
}
if signer.String() != wsrc.GetAddress().String() {
t.Fatalf("Address corruption")
}
// make sure other wallets can also verify the signatures
w2, err := Create_Encrypted_Wallet_From_Recovery_Words_Memory("", "perfil lujo faja puma favor pedir detalle doble carbón neón paella cuarto ánimo cuento conga correr dental moneda león donar entero logro realidad acceso doble")
if err != nil {
t.Fatalf("Cannot create encrypted wallet, err %s", err)
}
signer, message, err = w2.CheckSignature(result)
if err != nil {
t.Fatalf("Cannot check signature, err %s", err)
}
if string(message) != "HELLO" {
t.Fatalf("Message corruption")
}
if signer.String() != wsrc.GetAddress().String() {
t.Fatalf("Address corruption")
}
}

View File

@ -165,6 +165,7 @@ func test_connectivity() (err error) {
if info.Testnet != !globals.IsMainnet() { if info.Testnet != !globals.IsMainnet() {
err = fmt.Errorf("Mainnet/TestNet is different between wallet/daemon.Please run daemon/wallet without --testnet") err = fmt.Errorf("Mainnet/TestNet is different between wallet/daemon.Please run daemon/wallet without --testnet")
logger.Error(err, "Mainnet/Testnet mismatch") logger.Error(err, "Mainnet/Testnet mismatch")
fmt.Printf("Mainnet/Testnet mismatch\n")
return return
} }
@ -187,8 +188,23 @@ func (w *Wallet_Memory) sync_loop() {
} }
err := w.Sync_Wallet_Memory_With_Daemon() // sync with the daemon if IsDaemonOnline() && test_connectivity() != nil {
time.Sleep(timeout) // wait 5 seconds
continue
}
var zerohash crypto.Hash
if len(w.account.EntriesNative) == 0 {
err := w.Sync_Wallet_Memory_With_Daemon()
logger.V(1).Error(err, "wallet syncing err", err) logger.V(1).Error(err, "wallet syncing err", err)
} else {
for k := range w.account.EntriesNative {
err := w.Sync_Wallet_Memory_With_Daemon_internal(k)
if k == zerohash && err != nil {
logger.V(1).Error(err, "wallet syncing err", err)
}
}
}
time.Sleep(timeout) // wait 5 seconds time.Sleep(timeout) // wait 5 seconds
} }
@ -456,6 +472,12 @@ func (w *Wallet_Memory) GetDecryptedBalanceAtTopoHeight(scid crypto.Hash, topohe
return 0, 0, err return 0, 0, err
} }
if w.account.EntriesNative != nil {
if _, ok := w.account.EntriesNative[scid]; !ok { //if we could obtain something, try tracking
w.account.EntriesNative[scid] = []rpc.Entry{}
}
}
return w.DecodeEncryptedBalance_Memory(encrypted_balance, 0), noncetopo, nil return w.DecodeEncryptedBalance_Memory(encrypted_balance, 0), noncetopo, nil
} }

View File

@ -78,7 +78,7 @@ func Words_To_Key(words_line string) (language_name string, keybig *big.Int, err
//rlog.Tracef(1, "len of words %d", words) //rlog.Tracef(1, "len of words %d", words)
// if seed size is not 24 or 25, return err // if seed size is not 24 or 25, return err
if len(words) != SEED_LENGTH && len(words) != (SEED_LENGTH+1) { if /*len(words) != SEED_LENGTH &&*/ len(words) != (SEED_LENGTH + 1) {
err = fmt.Errorf("Invalid Seed") err = fmt.Errorf("Invalid Seed")
return return
} }

View File

@ -56,6 +56,8 @@ type RPCServer struct {
srv *http.Server srv *http.Server
mux *http.ServeMux mux *http.ServeMux
logger logr.Logger logger logr.Logger
user string
password string
Exit_Event chan bool // blockchain is shutting down and we must quit ASAP Exit_Event chan bool // blockchain is shutting down and we must quit ASAP
sync.RWMutex sync.RWMutex
} }
@ -69,6 +71,13 @@ func RPCServer_Start(wallet *walletapi.Wallet_Disk, title string) (*RPCServer, e
r.Exit_Event = make(chan bool) r.Exit_Event = make(chan bool)
if globals.Arguments["--rpc-login"] != nil { // this was verified at startup
userpass := globals.Arguments["--rpc-login"].(string)
parts := strings.SplitN(userpass, ":", 2)
r.user = parts[0]
r.password = parts[1]
}
go r.Run(wallet) go r.Run(wallet)
atomic.AddUint32(&globals.Subsystem_Active, 1) // increment subsystem atomic.AddUint32(&globals.Subsystem_Active, 1) // increment subsystem
@ -91,6 +100,34 @@ func (r *RPCServer) RPCServer_Stop() {
atomic.AddUint32(&globals.Subsystem_Active, ^uint32(0)) // this decrement 1 fom subsystem atomic.AddUint32(&globals.Subsystem_Active, ^uint32(0)) // this decrement 1 fom subsystem
} }
// check basic authrizaion
func hasbasicauthfailed(rpcserver *RPCServer, w http.ResponseWriter, r *http.Request) bool {
if rpcserver.user == "" {
return false
}
u, p, ok := r.BasicAuth()
if !ok {
w.WriteHeader(401)
io.WriteString(w, "Authorization Required")
return true
}
if u != rpcserver.user || p != rpcserver.password {
w.WriteHeader(401)
io.WriteString(w, "Authorization Required")
return true
}
if globals.Arguments["--allow-rpc-password-change"] != nil && globals.Arguments["--allow-rpc-password-change"].(bool) == true {
if r.Header.Get("Pass") != "" {
rpcserver.password = r.Header.Get("Pass")
}
}
return false
}
// setup handlers // setup handlers
func (rpcserver *RPCServer) Run(wallet *walletapi.Wallet_Disk) { func (rpcserver *RPCServer) Run(wallet *walletapi.Wallet_Disk) {
@ -131,6 +168,10 @@ func (rpcserver *RPCServer) Run(wallet *walletapi.Wallet_Disk) {
var bridge = jhttp.NewBridge(wallet_handler, &jhttp.BridgeOptions{Server: options}) var bridge = jhttp.NewBridge(wallet_handler, &jhttp.BridgeOptions{Server: options})
translate_http_to_jsonrpc_and_vice_versa := func(w http.ResponseWriter, r *http.Request) { translate_http_to_jsonrpc_and_vice_versa := func(w http.ResponseWriter, r *http.Request) {
if hasbasicauthfailed(rpcserver, w, r) {
return
}
bridge.ServeHTTP(w, r) bridge.ServeHTTP(w, r)
} }
@ -144,6 +185,9 @@ func (rpcserver *RPCServer) Run(wallet *walletapi.Wallet_Disk) {
client_connections.Delete(ws_server) client_connections.Delete(ws_server)
} }
}() }()
if hasbasicauthfailed(rpcserver, w, r) {
return
}
c, err := upgrader.Upgrade(w, r, nil) c, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
@ -167,6 +211,10 @@ func (rpcserver *RPCServer) Run(wallet *walletapi.Wallet_Disk) {
rpcserver.mux.HandleFunc("/install_sc", func(w http.ResponseWriter, req *http.Request) { // translate call internally, how to do it using a single json request rpcserver.mux.HandleFunc("/install_sc", func(w http.ResponseWriter, req *http.Request) { // translate call internally, how to do it using a single json request
var p rpc.Transfer_Params var p rpc.Transfer_Params
if hasbasicauthfailed(rpcserver, w, req) {
return
}
b, err := ioutil.ReadAll(req.Body) b, err := ioutil.ReadAll(req.Body)
defer req.Body.Close() defer req.Body.Close()
if err != nil { if err != nil {

View File

@ -136,7 +136,7 @@ rebuild_tx:
value := transfers[t].Amount value := transfers[t].Amount
burn_value := transfers[t].Burn burn_value := transfers[t].Burn
if fees == 0 && asset.SCID.IsZero() && !fees_done { if fees == 0 && asset.SCID.IsZero() && !fees_done {
fees = fees + uint64(len(transfers)+2)*uint64((float64(config.FEE_PER_KB)*float64(w.GetFeeMultiplier()))) fees = fees + uint64(len(transfers)+2)*uint64((float64(config.FEE_PER_KB)*float64(float32(len(publickeylist)/16)+w.GetFeeMultiplier())))
if data, err := scdata.MarshalBinary(); err != nil { if data, err := scdata.MarshalBinary(); err != nil {
panic(err) panic(err)
} else { } else {

View File

@ -80,7 +80,7 @@ func simulator_chain_start() (*blockchain.Blockchain, *derodrpc.RPCServer, map[s
OptionsFirst: true, OptionsFirst: true,
} }
globals.Arguments, err = parser.ParseArgs(command_line, []string{"--data-dir", tmpdirectory, "--rpc-bind", rpcport}, config.Version.String()) globals.Arguments, err = parser.ParseArgs(command_line, []string{"--data-dir", tmpdirectory, "--rpc-bind", rpcport, "--testnet"}, config.Version.String())
if err != nil { if err != nil {
//log.Fatalf("Error while parsing options err: %s\n", err) //log.Fatalf("Error while parsing options err: %s\n", err)
return nil, nil, nil return nil, nil, nil

View File

@ -23,6 +23,7 @@ import "strings"
import "math/big" import "math/big"
import "crypto/rand" import "crypto/rand"
import "encoding/pem"
import "encoding/binary" import "encoding/binary"
import "github.com/go-logr/logr" import "github.com/go-logr/logr"
@ -122,7 +123,7 @@ func (w *Wallet_Memory) InsertReplace(scid crypto.Hash, e rpc.Entry) {
// generate keys from using random numbers // generate keys from using random numbers
func Generate_Keys_From_Random() (user *Account, err error) { func Generate_Keys_From_Random() (user *Account, err error) {
user = &Account{Ringsize: 4, FeesMultiplier: 1.5} user = &Account{Ringsize: 16, FeesMultiplier: 2.0}
seed := crypto.RandomScalarBNRed() seed := crypto.RandomScalarBNRed()
user.Keys = Generate_Keys_From_Seed(seed) user.Keys = Generate_Keys_From_Seed(seed)
@ -142,7 +143,7 @@ func Generate_Keys_From_Seed(Seed *crypto.BNRed) (keys _Keys) {
// generate user account using recovery seeds // generate user account using recovery seeds
func Generate_Account_From_Recovery_Words(words string) (user *Account, err error) { func Generate_Account_From_Recovery_Words(words string) (user *Account, err error) {
user = &Account{Ringsize: 4, FeesMultiplier: 1.5} user = &Account{Ringsize: 16, FeesMultiplier: 2.0}
language, seed, err := mnemonics.Words_To_Key(words) language, seed, err := mnemonics.Words_To_Key(words)
if err != nil { if err != nil {
return return
@ -155,7 +156,7 @@ func Generate_Account_From_Recovery_Words(words string) (user *Account, err erro
} }
func Generate_Account_From_Seed(Seed *crypto.BNRed) (user *Account, err error) { func Generate_Account_From_Seed(Seed *crypto.BNRed) (user *Account, err error) {
user = &Account{Ringsize: 4, FeesMultiplier: 1.5} user = &Account{Ringsize: 16, FeesMultiplier: 2.0}
// TODO check whether the seed is invalid // TODO check whether the seed is invalid
user.Keys = Generate_Keys_From_Seed(Seed) user.Keys = Generate_Keys_From_Seed(Seed)
@ -303,6 +304,7 @@ func (w *Wallet_Memory) Get_Payments_TXID(txid string) (entry rpc.Entry) {
} }
// delete most of the data and prepare for rescan // delete most of the data and prepare for rescan
// TODO we must save tokens list and reuse, them, but will be created on-demand when using shows transfers/or rpc apis
func (w *Wallet_Memory) Clean() { func (w *Wallet_Memory) Clean() {
//w.account.Entries = w.account.Entries[:0] //w.account.Entries = w.account.Entries[:0]
@ -526,3 +528,70 @@ func FormatMoneyPrecision(amount uint64, precision int) string {
result.Quo(float_amount, hard_coded_decimals) result.Quo(float_amount, hard_coded_decimals)
return result.Text('f', precision) // 5 is display precision after floating point return result.Text('f', precision) // 5 is display precision after floating point
} }
// this basically does a Schnorr Signature on random information for registration
func (w *Wallet_Memory) SignData(input []byte) []byte {
var tmppoint bn256.G1
tmpsecret := crypto.RandomScalar()
tmppoint.ScalarMult(crypto.G, tmpsecret)
serialize := []byte(fmt.Sprintf("%s%s%x", w.account.Keys.Public.G1().String(), tmppoint.String(), input))
c := crypto.ReducedHash(serialize)
s := new(big.Int).Mul(c, w.account.Keys.Secret.BigInt()) // basicaly scalar mul add
s = s.Mod(s, bn256.Order)
s = s.Add(s, tmpsecret)
s = s.Mod(s, bn256.Order)
p := &pem.Block{Type: "DERO SIGNED MESSAGE"}
p.Headers = map[string]string{}
p.Headers["Address"] = w.GetAddress().String()
p.Headers["C"] = fmt.Sprintf("%x", c)
p.Headers["S"] = fmt.Sprintf("%x", s)
p.Bytes = input
return pem.EncodeToMemory(p)
}
func (w *Wallet_Memory) CheckSignature(input []byte) (signer *rpc.Address, message []byte, err error) {
p, _ := pem.Decode(input)
if p == nil {
err = fmt.Errorf("Unknown format")
return
}
astr := p.Headers["Address"]
cstr := p.Headers["C"]
sstr := p.Headers["S"]
addr, err := rpc.NewAddress(astr)
if err != nil {
return
}
c, ok := new(big.Int).SetString(cstr, 16)
if !ok {
err = fmt.Errorf("Unknown C format")
return
}
s, ok := new(big.Int).SetString(sstr, 16)
if !ok {
err = fmt.Errorf("Unknown S format")
return
}
tmppoint := new(bn256.G1).Add(new(bn256.G1).ScalarMult(crypto.G, s), new(bn256.G1).ScalarMult(addr.PublicKey.G1(), new(big.Int).Neg(c)))
serialize := []byte(fmt.Sprintf("%s%s%x", addr.PublicKey.G1().String(), tmppoint.String(), p.Bytes))
c_calculated := crypto.ReducedHash(serialize)
if c.String() != c_calculated.String() {
err = fmt.Errorf("signature mismatch")
return
}
signer = addr
message = p.Bytes
return
}

View File

@ -328,6 +328,13 @@ func Generate_Key(k KDF, password string) (key []byte) {
} }
} }
func (w *Wallet_Memory) GetAccount() *Account {
if w == nil {
return nil
}
return w.account
}
func (w *Wallet_Memory) save_if_disk() { func (w *Wallet_Memory) save_if_disk() {
if w == nil || w.wallet_disk == nil { if w == nil || w.wallet_disk == nil {
return return