From 8cee76b4927e20e0df798e944c972cd73ca0f4c1 Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Mon, 10 Jun 2024 00:05:44 -0400 Subject: [PATCH 1/4] nonce: support reading a new Claimed state. Add preliminary (backwards compatible) support for a Claimed nonce state. This state will be used in the future to indicate that a specific server has 'claimed' a nonce for future use. In this future system, servers will 'claim' a batch of Available nonces for use in future transactions. This enables servers to cache nonces to help reduce the overhead on acquiring a nonce per-intent. The defense mechanism to abandoned claim is that claims will expire. Healthy servers will refresh their claimed nonces periodically in the background. --- pkg/code/data/nonce/memory/store.go | 32 +++-- pkg/code/data/nonce/nonce.go | 70 +++++++--- pkg/code/data/nonce/postgres/model.go | 150 +++++++++++++-------- pkg/code/data/nonce/postgres/store_test.go | 7 +- pkg/code/data/nonce/tests/tests.go | 119 +++++++++++++++- pkg/code/transaction/nonce.go | 27 ++-- pkg/code/transaction/nonce_test.go | 65 ++++++++- pkg/database/postgres/errors.go | 4 +- 8 files changed, 362 insertions(+), 112 deletions(-) diff --git a/pkg/code/data/nonce/memory/store.go b/pkg/code/data/nonce/memory/store.go index 5193689b..d2c2d1ce 100644 --- a/pkg/code/data/nonce/memory/store.go +++ b/pkg/code/data/nonce/memory/store.go @@ -6,8 +6,8 @@ import ( "sort" "sync" - "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/nonce" + "github.com/code-payments/code-server/pkg/database/query" ) type store struct { @@ -58,27 +58,23 @@ func (s *store) findAddress(address string) *nonce.Record { } func (s *store) findByState(state nonce.State) []*nonce.Record { - res := make([]*nonce.Record, 0) - for _, item := range s.records { - if item.State == state { - res = append(res, item) - } - } - return res + return s.findFn(func(nonce *nonce.Record) bool { + return nonce.State == state + }) } func (s *store) findByStateAndPurpose(state nonce.State, purpose nonce.Purpose) []*nonce.Record { + return s.findFn(func(record *nonce.Record) bool { + return record.State == state && record.Purpose == purpose + }) +} + +func (s *store) findFn(f func(nonce *nonce.Record) bool) []*nonce.Record { res := make([]*nonce.Record, 0) for _, item := range s.records { - if item.State != state { - continue - } - - if item.Purpose != purpose { - continue + if f(item) { + res = append(res, item) } - - res = append(res, item) } return res } @@ -194,7 +190,9 @@ func (s *store) GetRandomAvailableByPurpose(ctx context.Context, purpose nonce.P s.mu.Lock() defer s.mu.Unlock() - items := s.findByStateAndPurpose(nonce.StateAvailable, purpose) + items := s.findFn(func(n *nonce.Record) bool { + return n.Purpose == purpose && n.IsAvailable() + }) if len(items) == 0 { return nil, nonce.ErrNonceNotFound } diff --git a/pkg/code/data/nonce/nonce.go b/pkg/code/data/nonce/nonce.go index 69cec6d6..97f39378 100644 --- a/pkg/code/data/nonce/nonce.go +++ b/pkg/code/data/nonce/nonce.go @@ -3,6 +3,7 @@ package nonce import ( "crypto/ed25519" "errors" + "time" "github.com/mr-tron/base58" ) @@ -20,14 +21,11 @@ const ( StateAvailable // The nonce is available to be used by a payment intent, subscription, or other nonce-related transaction. StateReserved // The nonce is reserved by a payment intent, subscription, or other nonce-related transaction. StateInvalid // The nonce account is invalid (e.g. insufficient funds, etc). + StateClaimed // The nonce is claimed for future use by a process (identified by Node ID). ) -// Split nonce pool across different use cases. This has an added benefit of: -// - Solving for race conditions without distributed locks. -// - Avoiding different use cases from starving each other and ending up in a -// deadlocked state. Concretely, it would be really bad if clients could starve -// internal processes from creating transactions that would allow us to progress -// and submit existing transactions. +// Purpose indicates the intended use purpose of the nonce. By partitioning nonce's by +// purpose, we help prevent various use cases from starving each other. type Purpose uint8 const ( @@ -46,6 +44,17 @@ type Record struct { Purpose Purpose State State + // Contains the NodeId that transitioned the state into StateClaimed. + // + // Should be ignored if State != StateClaimed. + ClaimNodeId string + + // The time at which StateClaimed is no longer valid, and the state should + // be considered StateAvailable. + // + // Should be ignored if State != StateClaimed. + ClaimExpiresAt time.Time + Signature string } @@ -53,15 +62,28 @@ func (r *Record) GetPublicKey() (ed25519.PublicKey, error) { return base58.Decode(r.Address) } +func (r *Record) IsAvailable() bool { + if r.State == StateAvailable { + return true + } + if r.State != StateClaimed { + return false + } + + return time.Now().After(r.ClaimExpiresAt) +} + func (r *Record) Clone() Record { return Record{ - Id: r.Id, - Address: r.Address, - Authority: r.Authority, - Blockhash: r.Blockhash, - Purpose: r.Purpose, - State: r.State, - Signature: r.Signature, + Id: r.Id, + Address: r.Address, + Authority: r.Authority, + Blockhash: r.Blockhash, + Purpose: r.Purpose, + State: r.State, + ClaimNodeId: r.ClaimNodeId, + ClaimExpiresAt: r.ClaimExpiresAt, + Signature: r.Signature, } } @@ -72,21 +94,33 @@ func (r *Record) CopyTo(dst *Record) { dst.Blockhash = r.Blockhash dst.Purpose = r.Purpose dst.State = r.State + dst.ClaimNodeId = r.ClaimNodeId + dst.ClaimExpiresAt = r.ClaimExpiresAt dst.Signature = r.Signature } -func (v *Record) Validate() error { - if len(v.Address) == 0 { +func (r *Record) Validate() error { + if len(r.Address) == 0 { return errors.New("nonce account address is required") } - if len(v.Authority) == 0 { + if len(r.Authority) == 0 { return errors.New("authority address is required") } - if v.Purpose == PurposeUnknown { + if r.Purpose == PurposeUnknown { return errors.New("nonce purpose must be set") } + + if r.State == StateClaimed { + if r.ClaimNodeId == "" { + return errors.New("missing claim node id") + } + if r.ClaimExpiresAt == (time.Time{}) || r.ClaimExpiresAt.IsZero() { + return errors.New("missing claim expiry date") + } + } + return nil } @@ -102,6 +136,8 @@ func (s State) String() string { return "reserved" case StateInvalid: return "invalid" + case StateClaimed: + return "claimed" } return "unknown" diff --git a/pkg/code/data/nonce/postgres/model.go b/pkg/code/data/nonce/postgres/model.go index b7a990b7..87044248 100644 --- a/pkg/code/data/nonce/postgres/model.go +++ b/pkg/code/data/nonce/postgres/model.go @@ -3,6 +3,7 @@ package postgres import ( "context" "database/sql" + "time" "github.com/jmoiron/sqlx" @@ -17,13 +18,15 @@ const ( ) type nonceModel struct { - Id sql.NullInt64 `db:"id"` - Address string `db:"address"` - Authority string `db:"authority"` - Blockhash string `db:"blockhash"` - Purpose uint `db:"purpose"` - State uint `db:"state"` - Signature string `db:"signature"` + Id sql.NullInt64 `db:"id"` + Address string `db:"address"` + Authority string `db:"authority"` + Blockhash string `db:"blockhash"` + Purpose uint `db:"purpose"` + State uint `db:"state"` + Signature string `db:"signature"` + ClaimNodeId string `db:"claim_node_id"` + ClaimExpiresAtMs int64 `db:"claim_expires_at"` } func toNonceModel(obj *nonce.Record) (*nonceModel, error) { @@ -32,33 +35,37 @@ func toNonceModel(obj *nonce.Record) (*nonceModel, error) { } return &nonceModel{ - Id: sql.NullInt64{Int64: int64(obj.Id), Valid: true}, - Address: obj.Address, - Authority: obj.Authority, - Blockhash: obj.Blockhash, - Purpose: uint(obj.Purpose), - State: uint(obj.State), - Signature: obj.Signature, + Id: sql.NullInt64{Int64: int64(obj.Id), Valid: true}, + Address: obj.Address, + Authority: obj.Authority, + Blockhash: obj.Blockhash, + Purpose: uint(obj.Purpose), + State: uint(obj.State), + Signature: obj.Signature, + ClaimNodeId: obj.ClaimNodeId, + ClaimExpiresAtMs: obj.ClaimExpiresAt.UnixMilli(), }, nil } func fromNonceModel(obj *nonceModel) *nonce.Record { return &nonce.Record{ - Id: uint64(obj.Id.Int64), - Address: obj.Address, - Authority: obj.Authority, - Blockhash: obj.Blockhash, - Purpose: nonce.Purpose(obj.Purpose), - State: nonce.State(obj.State), - Signature: obj.Signature, + Id: uint64(obj.Id.Int64), + Address: obj.Address, + Authority: obj.Authority, + Blockhash: obj.Blockhash, + Purpose: nonce.Purpose(obj.Purpose), + State: nonce.State(obj.State), + Signature: obj.Signature, + ClaimNodeId: obj.ClaimNodeId, + ClaimExpiresAt: time.UnixMilli(obj.ClaimExpiresAtMs), } } func (m *nonceModel) dbSave(ctx context.Context, db *sqlx.DB) error { return pgutil.ExecuteInTx(ctx, db, sql.LevelDefault, func(tx *sqlx.Tx) error { query := `INSERT INTO ` + nonceTableName + ` - (address, authority, blockhash, purpose, state, signature) - VALUES ($1,$2,$3,$4,$5,$6) + (address, authority, blockhash, purpose, state, signature, claim_node_id, claim_expires_at) + VALUES ($1,$2,$3,$4,$5,$6,$7,$8) ON CONFLICT (address) DO UPDATE SET blockhash = $3, state = $5, signature = $6 @@ -75,6 +82,8 @@ func (m *nonceModel) dbSave(ctx context.Context, db *sqlx.DB) error { m.Purpose, m.State, m.Signature, + m.ClaimNodeId, + m.ClaimExpiresAtMs, ).StructScan(m) return pgutil.CheckNoRows(err, nonce.ErrInvalidNonce) @@ -162,48 +171,77 @@ func dbGetAllByState(ctx context.Context, db *sqlx.DB, state nonce.State, cursor return res, nil } -// todo: Implementation still isn't perfect, but better than no randomness. It's -// sufficiently efficient, as long as our nonce pool is larger than the max offset. -// todo: We may need to tune the offset based on pool size and environment, but it -// should be sufficiently good enough for now. +// We query a random nonce by first selecting any available candidate from the +// total set, applying an upper limit of 100, and _then_ randomly shuffling the +// results and selecting the first. By bounding the size before ORDER BY random(), +// we avoid having to shuffle large sets of results. +// +// Previously, we would use OFFSET FLOOR(RANDOM() * 100). However, if the pool +// (post filter) size was less than 100, any selection > pool size would result +// in the OFFSET being set to zero. This meant random() disappeared for a subset +// of values. In practice, this would result in a bias, and increased contention. +// +// For example, 50 Available nonce's, 25 Claimed (expired), 25 Reserved. With Offset: +// +// 1. 50% of the time would be a random Available. +// 2. 25% of the time would be a random expired Claimed. +// 3. 25% of the time would be _the first_ Available. +// +// This meant that 25% of the time would not be random. As we pull from the pool, +// this % only increases, further causing contention. +// +// Performance wise, this approach is slightly worse, but the vast majority of the +// time is spent on the scan and filter. Below are two example query plans (from a +// small dataset in an online editor). +// +// QUERY PLAN (OFFSET): +// +// Limit (cost=17.80..35.60 rows=1 width=140) (actual time=0.019..0.019 rows=0 loops=1) +// -> Seq Scan on codewallet__core_nonce (cost=0.00..17.80 rows=1 width=140) (actual time=0.016..0.017 rows=0 loops=1) +// Filter: ((signature IS NOT NULL) AND (purpose = 1) AND ((state = 0) OR ((state = 2) AND (claim_expires_at < 200)))) +// Rows Removed by Filter: 100 +// +// Planning Time: 0.046 ms +// Execution Time: 0.031 ms +// +// QUERY PLAN (ORDER BY): +// +// Limit (cost=17.82..17.83 rows=1 width=148) (actual time=0.018..0.019 rows=0 loops=1) +// -> Sort (cost=17.82..17.83 rows=1 width=148) (actual time=0.018..0.018 rows=0 loops=1) +// Sort Key: (random()) +// Sort Method: quicksort Memory: 25kB +// -> Subquery Scan on sub (cost=0.00..17.81 rows=1 width=148) (actual time=0.015..0.016 rows=0 loops=1) +// -> Limit (cost=0.00..17.80 rows=1 width=140) (actual time=0.015..0.015 rows=0 loops=1) +// -> Seq Scan on codewallet__core_nonce (cost=0.00..17.80 rows=1 width=140) (actual time=0.015..0.015 rows=0 loops=1) +// Filter: ((signature IS NOT NULL) AND (purpose = 1) AND ((state = 0) OR ((state = 2) AND (claim_expires_at < 200)))) +// Rows Removed by Filter: 100 +// +// Planning Time: 0.068 ms +// Execution Time: 0.037 ms +// +// Overall, the Seq Scan and Filter is the bulk of the work, with the ORDER BY RANDOM() +// adding a small (fixed) amount of overhead. The trade-off is negligible time complexity +// for more reliable semantics. func dbGetRandomAvailableByPurpose(ctx context.Context, db *sqlx.DB, purpose nonce.Purpose) (*nonceModel, error) { res := &nonceModel{} + nowMs := time.Now().UnixMilli() // Signature null check is required because some legacy records didn't have this // set and causes this call to fail. This is a result of the field not being // defined at the time of record creation. // // todo: Fix said nonce records - query := `SELECT - id, address, authority, blockhash, purpose, state, signature - FROM ` + nonceTableName + ` - WHERE state = $1 AND purpose = $2 AND signature IS NOT NULL - OFFSET FLOOR(RANDOM() * 100) - LIMIT 1 - ` - fallbackQuery := `SELECT - id, address, authority, blockhash, purpose, state, signature - FROM ` + nonceTableName + ` - WHERE state = $1 AND purpose = $2 AND signature IS NOT NULL + query := ` + SELECT id, address, authority, blockhash, purpose, state, signature FROM ( + SELECT id, address, authority, blockhash, purpose, state, signature + FROM ` + nonceTableName + ` + WHERE ((state = $1) OR (state = $2 AND claim_expires_at < $3)) AND purpose = $4 AND signature IS NOT NULL + LIMIT 100 + ) sub + ORDER BY random() LIMIT 1 ` - err := db.GetContext(ctx, res, query, nonce.StateAvailable, purpose) - if err != nil { - err = pgutil.CheckNoRows(err, nonce.ErrNonceNotFound) - - // No nonces found. Because our query isn't perfect, fall back to a - // strategy that will guarantee to select something if an available - // nonce exists. - if err == nonce.ErrNonceNotFound { - err := db.GetContext(ctx, res, fallbackQuery, nonce.StateAvailable, purpose) - if err != nil { - return nil, pgutil.CheckNoRows(err, nonce.ErrNonceNotFound) - } - return res, nil - } - - return nil, err - } - return res, nil + err := db.GetContext(ctx, res, query, nonce.StateAvailable, nonce.StateClaimed, nowMs, purpose) + return res, pgutil.CheckNoRows(err, nonce.ErrNonceNotFound) } diff --git a/pkg/code/data/nonce/postgres/store_test.go b/pkg/code/data/nonce/postgres/store_test.go index bfc3b46f..0851e347 100644 --- a/pkg/code/data/nonce/postgres/store_test.go +++ b/pkg/code/data/nonce/postgres/store_test.go @@ -22,13 +22,16 @@ const ( CREATE TABLE codewallet__core_nonce( id SERIAL NOT NULL PRIMARY KEY, - address text NOT NULL UNIQUE, + address text NOT NULL UNIQUE, authority text NOT NULL, blockhash text NULL, purpose integer NOT NULL, state integer NOT NULL, - signature text NULL + signature text NULL, + + claim_node_id text, + claim_expires_at bigint ); ` diff --git a/pkg/code/data/nonce/tests/tests.go b/pkg/code/data/nonce/tests/tests.go index 8eee5cff..3f6bbbb1 100644 --- a/pkg/code/data/nonce/tests/tests.go +++ b/pkg/code/data/nonce/tests/tests.go @@ -3,10 +3,13 @@ package tests import ( "context" "fmt" + "strconv" + "strings" "testing" + "time" - "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/nonce" + "github.com/code-payments/code-server/pkg/database/query" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -15,6 +18,7 @@ func RunTests(t *testing.T, s nonce.Store, teardown func()) { for _, tf := range []func(t *testing.T, s nonce.Store){ testRoundTrip, testUpdate, + testUpdateInvalid, testGetAllByState, testGetCount, testGetRandomAvailableByPurpose, @@ -80,6 +84,52 @@ func testUpdate(t *testing.T, s nonce.Store) { assert.EqualValues(t, 1, actual.Id) } +func testUpdateInvalid(t *testing.T, s nonce.Store) { + ctx := context.Background() + + for _, invalid := range []*nonce.Record{ + {}, + { + Address: "test_address", + }, + { + Address: "test_address", + Authority: "test_authority", + }, + { + Address: "test_address", + Authority: "test_authority", + Blockhash: "block_hash", + }, + { + Address: "test_address", + Authority: "test_authority", + Blockhash: "test_blockhash", + Purpose: nonce.PurposeClientTransaction, + State: nonce.StateClaimed, + }, + { + Address: "test_address", + Authority: "test_authority", + Blockhash: "test_blockhash", + Purpose: nonce.PurposeClientTransaction, + State: nonce.StateClaimed, + ClaimNodeId: "my-node", + }, + { + Address: "test_address", + Authority: "test_authority", + Blockhash: "test_blockhash", + Purpose: nonce.PurposeClientTransaction, + State: nonce.StateClaimed, + ClaimExpiresAt: time.Now().Add(time.Hour), + }, + } { + require.Error(t, invalid.Validate()) + assert.Error(t, s.Save(ctx, invalid)) + } +} + func testGetAllByState(t *testing.T, s nonce.Store) { ctx := context.Background() @@ -260,8 +310,9 @@ func testGetRandomAvailableByPurpose(t *testing.T, s nonce.Store) { nonce.StateUnknown, nonce.StateAvailable, nonce.StateReserved, + nonce.StateClaimed, } { - for i := 0; i < 500; i++ { + for i := 0; i < 50; i++ { record := &nonce.Record{ Address: fmt.Sprintf("nonce_%s_%s_%d", purpose, state, i), Authority: "authority", @@ -270,27 +321,83 @@ func testGetRandomAvailableByPurpose(t *testing.T, s nonce.Store) { State: state, Signature: "", } + if state == nonce.StateClaimed { + record.ClaimNodeId = "my-node-id" + + if i < 25 { + record.ClaimExpiresAt = time.Now().Add(-time.Hour) + } else { + record.ClaimExpiresAt = time.Now().Add(time.Hour) + } + } + require.NoError(t, s.Save(ctx, record)) } } } + var sequentialLoads int + var availableState, claimedState int + var lastNonce *nonce.Record selectedByAddress := make(map[string]struct{}) - for i := 0; i < 100; i++ { + for i := 0; i < 1000; i++ { actual, err := s.GetRandomAvailableByPurpose(ctx, nonce.PurposeClientTransaction) require.NoError(t, err) assert.Equal(t, nonce.PurposeClientTransaction, actual.Purpose) - assert.Equal(t, nonce.StateAvailable, actual.State) + assert.True(t, actual.IsAvailable()) + + switch actual.State { + case nonce.StateAvailable: + availableState++ + case nonce.StateClaimed: + claimedState++ + assert.True(t, time.Now().After(actual.ClaimExpiresAt)) + default: + } + + // We test for randomness by ensuring we're not loading nonce's sequentially. + if lastNonce != nil && lastNonce.Purpose == actual.Purpose { + lastID, err := strconv.ParseInt(strings.Split(lastNonce.Address, "_")[4], 10, 64) + require.NoError(t, err) + currentID, _ := strconv.ParseInt(strings.Split(actual.Address, "_")[4], 10, 64) + require.NoError(t, err) + + if currentID == lastID+1 { + sequentialLoads++ + } + } + selectedByAddress[actual.Address] = struct{}{} + lastNonce = actual } - assert.True(t, len(selectedByAddress) > 10) + assert.Greater(t, len(selectedByAddress), 10) + assert.NotZero(t, availableState) + assert.NotZero(t, claimedState) + + // We allocated 50 available nonce's, and 25 expired claim nonces. Given that + // we randomly select out of the first available 100 nonces, we expect a ratio + // of 2:1 Available vs Expired Claimed nonces. + assert.InDelta(t, 2.0, float64(availableState)/float64(claimedState), 0.5) + + assert.Less(t, sequentialLoads, 100) + availableState, claimedState = 0, 0 selectedByAddress = make(map[string]struct{}) for i := 0; i < 100; i++ { actual, err := s.GetRandomAvailableByPurpose(ctx, nonce.PurposeInternalServerProcess) require.NoError(t, err) assert.Equal(t, nonce.PurposeInternalServerProcess, actual.Purpose) - assert.Equal(t, nonce.StateAvailable, actual.State) + assert.True(t, actual.IsAvailable()) + + switch actual.State { + case nonce.StateAvailable: + availableState++ + case nonce.StateClaimed: + claimedState++ + assert.True(t, time.Now().After(actual.ClaimExpiresAt)) + default: + } + selectedByAddress[actual.Address] = struct{}{} } assert.True(t, len(selectedByAddress) > 10) diff --git a/pkg/code/transaction/nonce.go b/pkg/code/transaction/nonce.go index 1f6fc7f3..6bfca0a1 100644 --- a/pkg/code/transaction/nonce.go +++ b/pkg/code/transaction/nonce.go @@ -8,12 +8,12 @@ import ( "github.com/mr-tron/base58" - "github.com/code-payments/code-server/pkg/retry" - "github.com/code-payments/code-server/pkg/solana" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/fulfillment" "github.com/code-payments/code-server/pkg/code/data/nonce" + "github.com/code-payments/code-server/pkg/retry" + "github.com/code-payments/code-server/pkg/solana" ) var ( @@ -66,7 +66,7 @@ func SelectAvailableNonce(ctx context.Context, data code_data.Provider, useCase defer globalNonceLock.Unlock() randomRecord, err := data.GetRandomAvailableNonceByPurpose(ctx, useCase) - if err == nonce.ErrNonceNotFound { + if errors.Is(err, nonce.ErrNonceNotFound) { return ErrNoAvailableNonces } else if err != nil { return err @@ -77,14 +77,14 @@ func SelectAvailableNonce(ctx context.Context, data code_data.Provider, useCase lock = getNonceLock(record.Address) lock.Lock() - // Refetch because the state could have changed by the time we got the lock + // Re-fetch because the state could have changed by the time we got the lock record, err = data.GetNonce(ctx, record.Address) if err != nil { lock.Unlock() return err } - if record.State != nonce.StateAvailable { + if !record.IsAvailable() { // Unlock and try again lock.Unlock() return errors.New("selected nonce that became unavailable") @@ -105,6 +105,8 @@ func SelectAvailableNonce(ctx context.Context, data code_data.Provider, useCase // Reserve the nonce for use with a fulfillment record.State = nonce.StateReserved + record.ClaimNodeId = "" + record.ClaimExpiresAt = time.UnixMilli(0) err = data.SaveNonce(ctx, record) if err != nil { lock.Unlock() @@ -212,12 +214,15 @@ func (n *SelectedNonce) MarkReservedWithSignature(ctx context.Context, sig strin return n.data.SaveNonce(ctx, n.record) } - if n.record.State != nonce.StateAvailable { + if !n.record.IsAvailable() { return errors.New("nonce must be available to reserve") } n.record.State = nonce.StateReserved n.record.Signature = sig + n.record.ClaimNodeId = "" + n.record.ClaimExpiresAt = time.UnixMilli(0) + return n.data.SaveNonce(ctx, n.record) } @@ -250,8 +255,8 @@ func (n *SelectedNonce) UpdateSignature(ctx context.Context, sig string) error { // ReleaseIfNotReserved makes a nonce available if it hasn't been reserved with // a signature. It's recommended to call this in tandem with Unlock when the -// caller knows it's safe to go from the reserved to available state (ie. don't -// use this in uprade flows!). +// caller knows it's safe to go from the reserved to available state (i.e. don't +// use this in upgrade flows!). func (n *SelectedNonce) ReleaseIfNotReserved() error { n.localLock.Lock() defer n.localLock.Unlock() @@ -264,6 +269,12 @@ func (n *SelectedNonce) ReleaseIfNotReserved() error { return nil } + if n.record.State == nonce.StateClaimed { + n.record.State = nonce.StateAvailable + n.record.ClaimNodeId = "" + n.record.ClaimExpiresAt = time.UnixMilli(0) + } + // A nonce is not fully reserved if it's state is reserved, but there is no // assigned signature. if n.record.State == nonce.StateReserved && len(n.record.Signature) == 0 { diff --git a/pkg/code/transaction/nonce_test.go b/pkg/code/transaction/nonce_test.go index ca8d7f28..86c19e70 100644 --- a/pkg/code/transaction/nonce_test.go +++ b/pkg/code/transaction/nonce_test.go @@ -10,14 +10,14 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/code-payments/code-server/pkg/pointer" - "github.com/code-payments/code-server/pkg/solana" - "github.com/code-payments/code-server/pkg/testutil" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/fulfillment" "github.com/code-payments/code-server/pkg/code/data/nonce" "github.com/code-payments/code-server/pkg/code/data/vault" + "github.com/code-payments/code-server/pkg/pointer" + "github.com/code-payments/code-server/pkg/solana" + "github.com/code-payments/code-server/pkg/testutil" ) func TestNonce_SelectAvailableNonce(t *testing.T) { @@ -62,6 +62,32 @@ func TestNonce_SelectAvailableNonce(t *testing.T) { assert.Equal(t, ErrNoAvailableNonces, err) } +func TestNonce_SelectAvailableNonceClaimed(t *testing.T) { + env := setupNonceTestEnv(t) + + // Should ignore (non-expired) claimed nonces. + expiredNonces := map[string]*nonce.Record{} + for i := 0; i < 10; i++ { + n := generateClaimedNonce(t, env, true) + expiredNonces[n.Address] = n + + generateClaimedNonce(t, env, false) + } + + for i := 0; i < 10; i++ { + nonce, err := SelectAvailableNonce(env.ctx, env.data, nonce.PurposeClientTransaction) + require.NoError(t, err) + + _, ok := expiredNonces[nonce.Account.PublicKey().ToBase58()] + require.True(t, ok) + require.True(t, nonce.record.ClaimExpiresAt.Before(time.Now())) + delete(expiredNonces, nonce.Account.PublicKey().ToBase58()) + } + + _, err := SelectAvailableNonce(env.ctx, env.data, nonce.PurposeInternalServerProcess) + require.ErrorIs(t, ErrNoAvailableNonces, err) +} + func TestNonce_SelectNonceFromFulfillmentToUpgrade_HappyPath(t *testing.T) { env := setupNonceTestEnv(t) @@ -238,7 +264,7 @@ func generateAvailableNonce(t *testing.T, env nonceTestEnv, useCase nonce.Purpos Address: nonceAccount.PublicKey().ToBase58(), Authority: common.GetSubsidizer().PublicKey().ToBase58(), Blockhash: base58.Encode(bh[:]), - Purpose: nonce.PurposeClientTransaction, + Purpose: useCase, State: nonce.StateAvailable, } require.NoError(t, env.data.SaveKey(env.ctx, nonceKey)) @@ -246,6 +272,37 @@ func generateAvailableNonce(t *testing.T, env nonceTestEnv, useCase nonce.Purpos return nonceRecord } +func generateClaimedNonce(t *testing.T, env nonceTestEnv, expired bool) *nonce.Record { + nonceAccount := testutil.NewRandomAccount(t) + + var bh solana.Blockhash + rand.Read(bh[:]) + + nonceKey := &vault.Record{ + PublicKey: nonceAccount.PublicKey().ToBase58(), + PrivateKey: nonceAccount.PrivateKey().ToBase58(), + State: vault.StateAvailable, + CreatedAt: time.Now(), + } + nonceRecord := &nonce.Record{ + Address: nonceAccount.PublicKey().ToBase58(), + Authority: common.GetSubsidizer().PublicKey().ToBase58(), + Blockhash: base58.Encode(bh[:]), + Purpose: nonce.PurposeClientTransaction, + State: nonce.StateClaimed, + ClaimNodeId: "my-node-id", + } + if expired { + nonceRecord.ClaimExpiresAt = time.Now().Add(-time.Hour) + } else { + nonceRecord.ClaimExpiresAt = time.Now().Add(time.Hour) + } + + require.NoError(t, env.data.SaveKey(env.ctx, nonceKey)) + require.NoError(t, env.data.SaveNonce(env.ctx, nonceRecord)) + return nonceRecord +} + func generateAvailableNonces(t *testing.T, env nonceTestEnv, useCase nonce.Purpose, count int) []*nonce.Record { var nonces []*nonce.Record for i := 0; i < count; i++ { diff --git a/pkg/database/postgres/errors.go b/pkg/database/postgres/errors.go index 38f8155a..a61e286c 100644 --- a/pkg/database/postgres/errors.go +++ b/pkg/database/postgres/errors.go @@ -2,14 +2,14 @@ package pg import ( "database/sql" + "errors" "github.com/jackc/pgconn" "github.com/jackc/pgerrcode" - "github.com/pkg/errors" ) func CheckNoRows(inErr, outErr error) error { - if inErr == sql.ErrNoRows { + if errors.Is(inErr, sql.ErrNoRows) { return outErr } return inErr From c918fa2a2c12022f10ce48c35531a056b764472a Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Mon, 10 Jun 2024 00:26:41 -0400 Subject: [PATCH 2/4] nonce: explicitly state NULL in new fields. --- pkg/code/data/nonce/postgres/store_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/code/data/nonce/postgres/store_test.go b/pkg/code/data/nonce/postgres/store_test.go index 0851e347..70ef5ee6 100644 --- a/pkg/code/data/nonce/postgres/store_test.go +++ b/pkg/code/data/nonce/postgres/store_test.go @@ -30,8 +30,8 @@ const ( state integer NOT NULL, signature text NULL, - claim_node_id text, - claim_expires_at bigint + claim_node_id text NULL, + claim_expires_at bigint NULL ); ` From 06cb79dc01281196fcc9ae1dbaefb69ddce2dcbe Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Wed, 19 Jun 2024 19:43:49 -0400 Subject: [PATCH 3/4] nonce: add BatchClaimAvailableByPurpose method. Adds the ability to batch claim available nonce (by purpose) for efficient use of claiming a set of nonces. This lays down the ground work for nonce pools. --- pkg/code/data/internal.go | 4 + pkg/code/data/nonce/memory/store.go | 47 ++++++- pkg/code/data/nonce/postgres/model.go | 62 ++++++++- pkg/code/data/nonce/postgres/store.go | 27 +++- pkg/code/data/nonce/store.go | 12 ++ pkg/code/data/nonce/tests/tests.go | 183 ++++++++++++++++++++++++++ 6 files changed, 327 insertions(+), 8 deletions(-) diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index 49479c84..4129cb70 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -180,6 +180,7 @@ type DatabaseData interface { GetNonceCountByStateAndPurpose(ctx context.Context, state nonce.State, purpose nonce.Purpose) (uint64, error) GetAllNonceByState(ctx context.Context, state nonce.State, opts ...query.Option) ([]*nonce.Record, error) GetRandomAvailableNonceByPurpose(ctx context.Context, purpose nonce.Purpose) (*nonce.Record, error) + BatchClaimAvailableByPurpose(ctx context.Context, purpose nonce.Purpose, limit int, nodeID string, minExpireAt, maxExpireAt time.Time) ([]*nonce.Record, error) SaveNonce(ctx context.Context, record *nonce.Record) error // Fulfillment @@ -743,6 +744,9 @@ func (dp *DatabaseProvider) GetAllNonceByState(ctx context.Context, state nonce. func (dp *DatabaseProvider) GetRandomAvailableNonceByPurpose(ctx context.Context, purpose nonce.Purpose) (*nonce.Record, error) { return dp.nonces.GetRandomAvailableByPurpose(ctx, purpose) } +func (dp *DatabaseProvider) BatchClaimAvailableByPurpose(ctx context.Context, purpose nonce.Purpose, limit int, nodeID string, minExpireAt, maxExpireAt time.Time) ([]*nonce.Record, error) { + return dp.nonces.BatchClaimAvailableByPurpose(ctx, purpose, limit, nodeID, minExpireAt, maxExpireAt) +} func (dp *DatabaseProvider) SaveNonce(ctx context.Context, record *nonce.Record) error { return dp.nonces.Save(ctx, record) } diff --git a/pkg/code/data/nonce/memory/store.go b/pkg/code/data/nonce/memory/store.go index d2c2d1ce..e5336dbf 100644 --- a/pkg/code/data/nonce/memory/store.go +++ b/pkg/code/data/nonce/memory/store.go @@ -5,6 +5,7 @@ import ( "math/rand" "sort" "sync" + "time" "github.com/code-payments/code-server/pkg/code/data/nonce" "github.com/code-payments/code-server/pkg/database/query" @@ -60,21 +61,25 @@ func (s *store) findAddress(address string) *nonce.Record { func (s *store) findByState(state nonce.State) []*nonce.Record { return s.findFn(func(nonce *nonce.Record) bool { return nonce.State == state - }) + }, -1) } func (s *store) findByStateAndPurpose(state nonce.State, purpose nonce.Purpose) []*nonce.Record { return s.findFn(func(record *nonce.Record) bool { return record.State == state && record.Purpose == purpose - }) + }, -1) } -func (s *store) findFn(f func(nonce *nonce.Record) bool) []*nonce.Record { +func (s *store) findFn(f func(nonce *nonce.Record) bool, limit int) []*nonce.Record { res := make([]*nonce.Record, 0) for _, item := range s.records { if f(item) { res = append(res, item) } + + if limit >= 0 && len(res) == limit { + break + } } return res } @@ -192,7 +197,7 @@ func (s *store) GetRandomAvailableByPurpose(ctx context.Context, purpose nonce.P items := s.findFn(func(n *nonce.Record) bool { return n.Purpose == purpose && n.IsAvailable() - }) + }, -1) if len(items) == 0 { return nil, nonce.ErrNonceNotFound } @@ -200,3 +205,37 @@ func (s *store) GetRandomAvailableByPurpose(ctx context.Context, purpose nonce.P index := rand.Intn(len(items)) return items[index], nil } + +func (s *store) BatchClaimAvailableByPurpose( + ctx context.Context, + purpose nonce.Purpose, + limit int, + nodeId string, + minExpireAt time.Time, + maxExpireAt time.Time, +) ([]*nonce.Record, error) { + s.mu.Lock() + defer s.mu.Unlock() + + items := s.findFn(func(n *nonce.Record) bool { + return n.Purpose == purpose && n.IsAvailable() + }, limit) + if len(items) == 0 { + return nil, nil + } + + for i, l := 0, len(items); i < l; i++ { + j := rand.Intn(l) + items[i], items[j] = items[j], items[i] + } + for i := 0; i < len(items); i++ { + window := maxExpireAt.Sub(minExpireAt) + expiry := minExpireAt.Add(time.Duration(rand.Intn(int(window)))) + + items[i].State = nonce.StateClaimed + items[i].ClaimNodeId = nodeId + items[i].ClaimExpiresAt = expiry + } + + return items, nil +} diff --git a/pkg/code/data/nonce/postgres/model.go b/pkg/code/data/nonce/postgres/model.go index 87044248..122ef5ab 100644 --- a/pkg/code/data/nonce/postgres/model.go +++ b/pkg/code/data/nonce/postgres/model.go @@ -3,6 +3,7 @@ package postgres import ( "context" "database/sql" + "fmt" "time" "github.com/jmoiron/sqlx" @@ -69,7 +70,7 @@ func (m *nonceModel) dbSave(ctx context.Context, db *sqlx.DB) error { ON CONFLICT (address) DO UPDATE SET blockhash = $3, state = $5, signature = $6 - WHERE ` + nonceTableName + `.address = $1 + WHERE ` + nonceTableName + `.address = $1 RETURNING id, address, authority, blockhash, purpose, state, signature` @@ -130,7 +131,7 @@ func dbGetNonce(ctx context.Context, db *sqlx.DB, address string) (*nonceModel, res := &nonceModel{} query := `SELECT - id, address, authority, blockhash, purpose, state, signature + id, address, authority, blockhash, purpose, state, signature, claim_node_id, claim_expires_at FROM ` + nonceTableName + ` WHERE address = $1 ` @@ -245,3 +246,60 @@ func dbGetRandomAvailableByPurpose(ctx context.Context, db *sqlx.DB, purpose non err := db.GetContext(ctx, res, query, nonce.StateAvailable, nonce.StateClaimed, nowMs, purpose) return res, pgutil.CheckNoRows(err, nonce.ErrNonceNotFound) } + +func dbBatchClaimAvailableByPurpose( + ctx context.Context, + db *sqlx.DB, + purpose nonce.Purpose, + limit int, + nodeId string, + minExpireAt time.Time, + maxExpireAt time.Time, +) ([]*nonceModel, error) { + tx, err := db.Beginx() + if err != nil { + return nil, err + } + + query := ` + WITH selected_nonces AS ( + SELECT id, address, authority, blockhash, purpose, state, signature + FROM ` + nonceTableName + ` + WHERE ((state = $1) OR (state = $2 AND claim_expires_at < $3)) AND purpose = $4 AND signature IS NOT NULL + LIMIT $5 + FOR UPDATE + ) + UPDATE ` + nonceTableName + ` + SET state = $6, + claim_node_id = $7, + claim_expires_at = $8 + FLOOR(RANDOM() * $9) + FROM selected_nonces + WHERE ` + nonceTableName + `.id = selected_nonces.id + RETURNING ` + nonceTableName + `.* + ` + + nonceModels := []*nonceModel{} + err = tx.SelectContext( + ctx, + &nonceModels, + query, + nonce.StateAvailable, + nonce.StateClaimed, + time.Now().UnixMilli(), + purpose, + limit, + nonce.StateClaimed, + nodeId, + minExpireAt.UnixMilli(), + maxExpireAt.Sub(minExpireAt).Milliseconds(), + ) + if err != nil { + if rollBackErr := tx.Rollback(); rollBackErr != nil { + return nil, fmt.Errorf("failed to rollback (cause: %w): %w", err, rollBackErr) + } + + return nil, err + } + + return nonceModels, tx.Commit() +} diff --git a/pkg/code/data/nonce/postgres/store.go b/pkg/code/data/nonce/postgres/store.go index fa44d5fb..bdc48c00 100644 --- a/pkg/code/data/nonce/postgres/store.go +++ b/pkg/code/data/nonce/postgres/store.go @@ -3,10 +3,12 @@ package postgres import ( "context" "database/sql" + "time" - "github.com/code-payments/code-server/pkg/database/query" - "github.com/code-payments/code-server/pkg/code/data/nonce" "github.com/jmoiron/sqlx" + + "github.com/code-payments/code-server/pkg/code/data/nonce" + "github.com/code-payments/code-server/pkg/database/query" ) type store struct { @@ -93,3 +95,24 @@ func (s *store) GetRandomAvailableByPurpose(ctx context.Context, purpose nonce.P } return fromNonceModel(model), nil } + +func (s *store) BatchClaimAvailableByPurpose( + ctx context.Context, + purpose nonce.Purpose, + limit int, + nodeId string, + minExpireAt time.Time, + maxExpireAt time.Time, +) ([]*nonce.Record, error) { + models, err := dbBatchClaimAvailableByPurpose(ctx, s.db, purpose, limit, nodeId, minExpireAt, maxExpireAt) + if err != nil { + return nil, err + } + + nonces := make([]*nonce.Record, len(models)) + for i, model := range models { + nonces[i] = fromNonceModel(model) + } + + return nonces, nil +} diff --git a/pkg/code/data/nonce/store.go b/pkg/code/data/nonce/store.go index 973feb37..b7411793 100644 --- a/pkg/code/data/nonce/store.go +++ b/pkg/code/data/nonce/store.go @@ -2,6 +2,7 @@ package nonce import ( "context" + "time" "github.com/code-payments/code-server/pkg/database/query" ) @@ -35,4 +36,15 @@ type Store interface { // // Returns ErrNotFound if no records are found. GetRandomAvailableByPurpose(ctx context.Context, purpose Purpose) (*Record, error) + + // BatchClaimAvailableByPurpose batch claims up to the specified limit. + // + // The returned nonces will be marked as claimed by the current node, with + // the specified expiry date. + // + // Note: Implementations need not randomize the results/selection. + // The transactional nature of the call means that any contention exists + // on the tx level (which always occurs), and not around fighting over + // individual nonces (which was negated in the GetRandomAvailableByPurpose). + BatchClaimAvailableByPurpose(ctx context.Context, purpose Purpose, limit int, nodeID string, minExpireAt, maxExpireAt time.Time) ([]*Record, error) } diff --git a/pkg/code/data/nonce/tests/tests.go b/pkg/code/data/nonce/tests/tests.go index 3f6bbbb1..9f42e2d7 100644 --- a/pkg/code/data/nonce/tests/tests.go +++ b/pkg/code/data/nonce/tests/tests.go @@ -3,6 +3,8 @@ package tests import ( "context" "fmt" + "math" + "slices" "strconv" "strings" "testing" @@ -22,6 +24,8 @@ func RunTests(t *testing.T, s nonce.Store, teardown func()) { testGetAllByState, testGetCount, testGetRandomAvailableByPurpose, + testBatch, + testBatchClaimExpirationRandomness, } { tf(t, s) teardown() @@ -403,3 +407,182 @@ func testGetRandomAvailableByPurpose(t *testing.T, s nonce.Store) { assert.True(t, len(selectedByAddress) > 10) }) } + +func testBatch(t *testing.T, s nonce.Store) { + t.Run("testBatch", func(t *testing.T) { + ctx := context.Background() + + minExpiry := time.Now().Add(time.Hour).Truncate(time.Millisecond) + maxExpiry := time.Now().Add(2 * time.Hour).Truncate(time.Millisecond) + + nonces, err := s.BatchClaimAvailableByPurpose( + ctx, + nonce.PurposeClientTransaction, + 100, + "my-id", + minExpiry, + maxExpiry, + ) + require.Empty(t, nonces) + require.Nil(t, err) + + // Initialize nonce pool. + for _, purpose := range []nonce.Purpose{ + nonce.PurposeClientTransaction, + nonce.PurposeInternalServerProcess, + } { + for _, state := range []nonce.State{ + nonce.StateUnknown, + nonce.StateAvailable, + nonce.StateReserved, + nonce.StateClaimed, + } { + for i := 0; i < 50; i++ { + record := &nonce.Record{ + Address: fmt.Sprintf("nonce_%s_%s_%d", purpose, state, i), + Authority: "authority", + Blockhash: "bh", + Purpose: purpose, + State: state, + Signature: "", + } + if state == nonce.StateClaimed { + record.ClaimNodeId = "my-node-id" + + if i < 25 { + record.ClaimExpiresAt = time.Now().Add(-time.Hour) + } else { + record.ClaimExpiresAt = time.Now().Add(time.Hour) + } + } + + require.NoError(t, s.Save(ctx, record)) + } + } + } + + // Iteratively grab a subset until there are none left. + // + // Note: The odd amount ensures we try grabbing more than exists. + var claimed []*nonce.Record + for remaining := 75; remaining > 0; { + nonces, err = s.BatchClaimAvailableByPurpose( + ctx, + nonce.PurposeClientTransaction, + 10, + "my-id", + minExpiry, + maxExpiry, + ) + require.NoError(t, err) + require.Len(t, nonces, min(remaining, 10)) + + remaining -= len(nonces) + + for _, n := range nonces { + actual, err := s.Get(ctx, n.Address) + require.NoError(t, err) + require.Equal(t, nonce.StateClaimed, actual.State) + require.Equal(t, "my-id", actual.ClaimNodeId) + require.GreaterOrEqual(t, actual.ClaimExpiresAt, minExpiry) + require.LessOrEqual(t, actual.ClaimExpiresAt, maxExpiry) + require.Equal(t, nonce.PurposeClientTransaction, actual.Purpose) + + claimed = append(claimed, actual) + } + } + + // Ensure no more nonces. + nonces, err = s.BatchClaimAvailableByPurpose(ctx, nonce.PurposeClientTransaction, 10, "my-id", minExpiry, maxExpiry) + require.NoError(t, err) + require.Empty(t, nonces) + + // Release and reclaim + for i := range claimed[:20] { + claimed[i].State = nonce.StateAvailable + s.Save(ctx, claimed[i]) + } + + nonces, err = s.BatchClaimAvailableByPurpose(ctx, nonce.PurposeClientTransaction, 30, "my-id2", minExpiry, maxExpiry) + require.NoError(t, err) + require.Len(t, nonces, 20) + + // We sort the sets so we can trivially compare and ensure + // that it's the same set. + slices.SortFunc(claimed[:20], func(a, b *nonce.Record) int { + return strings.Compare(a.Address, b.Address) + }) + slices.SortFunc(nonces, func(a, b *nonce.Record) int { + return strings.Compare(a.Address, b.Address) + }) + + for i, actual := range nonces { + require.Equal(t, nonce.StateClaimed, actual.State) + require.Equal(t, "my-id2", actual.ClaimNodeId) + require.GreaterOrEqual(t, actual.ClaimExpiresAt, minExpiry) + require.LessOrEqual(t, actual.ClaimExpiresAt, maxExpiry) + require.Equal(t, nonce.PurposeClientTransaction, actual.Purpose) + require.Equal(t, claimed[i].Address, actual.Address) + } + }) +} + +func testBatchClaimExpirationRandomness(t *testing.T, s nonce.Store) { + t.Run("testBatch", func(t *testing.T) { + ctx := context.Background() + + min := time.Now().Add(time.Hour).Truncate(time.Millisecond) + max := time.Now().Add(2 * time.Hour).Truncate(time.Millisecond) + + for i := 0; i < 1000; i++ { + record := &nonce.Record{ + Address: fmt.Sprintf("nonce_%s_%s_%d", nonce.PurposeClientTransaction, nonce.StateAvailable, i), + Authority: "authority", + Blockhash: "bh", + Purpose: nonce.PurposeClientTransaction, + State: nonce.StateAvailable, + Signature: "", + } + + require.NoError(t, s.Save(ctx, record)) + } + + nonces, err := s.BatchClaimAvailableByPurpose( + ctx, + nonce.PurposeClientTransaction, + 1000, + "my-id", + min, + max, + ) + require.NoError(t, err) + require.Len(t, nonces, 1000) + + // To verify that we have a rough random distribution of expirations, + // we bucket the expiration space, and compute the standard deviation. + // + // We then compare against the expected value with a tolerance. + // Specifically, we know there should be 50 nonces per bucket in + // an ideal world, and we allow for a 15% deviation on this. + bins := make([]int64, 20) + expected := float64(len(nonces)) / float64(len(bins)) + for _, n := range nonces { + // Formula: bin = k(val - min) / (max-min+1) + // + // We use '+1' in the divisor to ensure we don't divide by zero. + // In practive, this should produce pretty much no bias since our + // testing ranges are large. + bin := int(n.ClaimExpiresAt.Sub(min).Milliseconds()) * len(bins) / int(max.Sub(min).Milliseconds()+1) + bins[bin]++ + } + + sum := 0.0 + for _, count := range bins { + diff := float64(count) - expected + sum += diff * diff + } + + stdDev := math.Sqrt(sum / float64(len(bins))) + assert.LessOrEqual(t, stdDev, 0.15*expected, "expected: %v, bins %v:", expected, bins) + }) +} From 9007fbe8f30392e8f90dee8758a0e508e3eb7e96 Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Thu, 20 Jun 2024 11:05:05 -0400 Subject: [PATCH 4/4] transaction: add NoncePool implementation (without hooks). --- go.mod | 1 + go.sum | 2 + pkg/code/transaction/nonce_pool.go | 415 ++++++++++++++++++++++++ pkg/code/transaction/nonce_pool_test.go | 193 +++++++++++ 4 files changed, 611 insertions(+) create mode 100644 pkg/code/transaction/nonce_pool.go create mode 100644 pkg/code/transaction/nonce_pool_test.go diff --git a/go.mod b/go.mod index 8701d46c..df3c4eb1 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/jackc/pgx/v4 v4.13.0 github.com/jdgcs/ed25519 v0.0.0-20200408034030-96c10d46cdc3 github.com/jmoiron/sqlx v1.3.4 + github.com/jonboulle/clockwork v0.4.0 github.com/mr-tron/base58 v1.2.0 github.com/newrelic/go-agent/v3 v3.20.1 github.com/newrelic/go-agent/v3/integrations/nrpgx v1.0.0 diff --git a/go.sum b/go.sum index ca244823..28a61ce3 100644 --- a/go.sum +++ b/go.sum @@ -386,6 +386,8 @@ github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht github.com/jmoiron/sqlx v1.3.4 h1:wv+0IJZfL5z0uZoUjlpKgHkgaFSYD+r9CfrXjEXsO7w= github.com/jmoiron/sqlx v1.3.4/go.mod h1:2BljVx/86SuTyjE+aPYlHCTNvZrnJXghYGpNiXLBMCQ= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= +github.com/jonboulle/clockwork v0.4.0 h1:p4Cf1aMWXnXAUh8lVfewRBx1zaTSYKrKMF2g3ST4RZ4= +github.com/jonboulle/clockwork v0.4.0/go.mod h1:xgRqUGwRcjKCO1vbZUEtSLrqKoPSsUpK7fnezOII0kc= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= diff --git a/pkg/code/transaction/nonce_pool.go b/pkg/code/transaction/nonce_pool.go new file mode 100644 index 00000000..6215cea6 --- /dev/null +++ b/pkg/code/transaction/nonce_pool.go @@ -0,0 +1,415 @@ +package transaction + +import ( + "context" + "errors" + "fmt" + "slices" + "sync" + "time" + + "github.com/google/uuid" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" + + code_data "github.com/code-payments/code-server/pkg/code/data" + "github.com/code-payments/code-server/pkg/code/data/nonce" +) + +// NoncePoolOption configures a nonce pool. +type NoncePoolOption func(*noncePoolOpts) + +// WithNoncePoolClock configures the clock for the nonce pool. +func WithNoncePoolClock(clock clockwork.Clock) NoncePoolOption { + return func(npo *noncePoolOpts) { + npo.clock = clock + } +} + +// WithNoncePoolSize configures the desired size of the pool. +// +// The pool will use this size to determine how much to load +// when the pool starts up, or when the pool gets low. It can +// be viewed as 'target memory' in a GC, with the actual pool +// size behaving like a saw-tooth graph. +// +// The pool does not have any mechanism to shrink the pool to +// this size beyond the natural consumption of nonces. +func WithNoncePoolSize(size int) NoncePoolOption { + return func(npo *noncePoolOpts) { + npo.desiredPoolSize = size + } +} + +// WithNoncePoolNodeId configures the node id to use when claiming nonces. +func WithNoncePoolNodeId(id string) NoncePoolOption { + return func(npo *noncePoolOpts) { + npo.nodeId = id + } +} + +// WithNoncePoolMinExpiration configures the lower bound for the +// expiration window of claimed nonces. +func WithNoncePoolMinExpiration(d time.Duration) NoncePoolOption { + return func(npo *noncePoolOpts) { + npo.minExpiration = d + } +} + +// WithNoncePoolMaxExpiration configures the upper bound for the +// expiration window of claimed nonces. +func WithNoncePoolMaxExpiration(d time.Duration) NoncePoolOption { + return func(npo *noncePoolOpts) { + npo.maxExpiration = d + } +} + +// WithNoncePoolRefreshInterval specifies how often the pool should be +// scanning it's free list for refresh candidates. Candidates are claimed +// nonces whose expiration is <= 2/3 of the min expiration. +func WithNoncePoolRefreshInterval(interval time.Duration) NoncePoolOption { + return func(npo *noncePoolOpts) { + npo.refreshInterval = interval + } +} + +// WithNoncePoolRefreshPoolInterval configures the pool to refresh the set of +// nonces at this duration. If the pool size is 1/2 the desired size, nonces +// will be fetched from DB. This condition is also checked (asynchronously) every +// time a nonce is pulled from the pool. +func WithNoncePoolRefreshPoolInterval(interval time.Duration) NoncePoolOption { + return func(npo *noncePoolOpts) { + npo.refreshPoolInterval = interval + } +} + +type noncePoolOpts struct { + clock clockwork.Clock + desiredPoolSize int + + nodeId string + minExpiration time.Duration + maxExpiration time.Duration + + refreshInterval time.Duration + refreshPoolInterval time.Duration +} + +func (opts *noncePoolOpts) validate() error { + if opts.clock == nil { + return errors.New("missing clock") + } + if opts.desiredPoolSize < 10 { + return errors.New("pool size must greater than 10") + } + if opts.nodeId == "" { + return errors.New("missing node id") + } + + if opts.minExpiration < 10*time.Second { + return errors.New("min expiry must be >= 10s") + } + if opts.maxExpiration < 10*time.Second { + return errors.New("max expiry must be >= 10s") + } + if opts.minExpiration > opts.maxExpiration { + return errors.New("min expiry must <= max expiry") + } + + if opts.refreshInterval < time.Second || opts.refreshInterval > opts.minExpiration/2 { + return fmt.Errorf("invalid refresh interval %v, must be between (1, minExpiration/2)", opts.refreshInterval) + } + + if opts.refreshPoolInterval < time.Second { + return fmt.Errorf("invalid refresh pool interval %v, must be greater than 1s", opts.refreshPoolInterval) + } + + return nil +} + +// Nonce represents a handle to a nonce that is owned by a pool. +type Nonce struct { + pool *NoncePool + record *nonce.Record +} + +// MarkReservedWithSignature marks the nonce as reserved with a signature +func (n *Nonce) MarkReservedWithSignature(ctx context.Context, sig string) error { + if len(sig) == 0 { + return errors.New("signature is empty") + } + + if n.record.Signature == sig { + return nil + } + + if len(n.record.Signature) != 0 { + return errors.New("nonce already has a different signature") + } + + // Nonce is reserved without a signature, so update its signature + if n.record.State == nonce.StateReserved { + n.record.Signature = sig + return n.pool.data.SaveNonce(ctx, n.record) + } + + if !n.record.IsAvailable() { + return errors.New("nonce must be available to reserve") + } + + n.record.State = nonce.StateReserved + n.record.Signature = sig + n.record.ClaimNodeId = "" + n.record.ClaimExpiresAt = time.UnixMilli(0) + + return n.pool.data.SaveNonce(ctx, n.record) +} + +// UpdateSignature updates the signature for a reserved nonce. The use case here +// being transactions that share a nonce, and the new transaction being designated +// as the one to submit to the blockchain. +func (n *Nonce) UpdateSignature(ctx context.Context, sig string) error { + if len(sig) == 0 { + return errors.New("signature is empty") + } + + if n.record.Signature == sig { + return nil + } + if n.record.State != nonce.StateReserved { + return errors.New("nonce must be in a reserved state") + } + + n.record.Signature = sig + return n.pool.data.SaveNonce(ctx, n.record) +} + +// ReleaseIfNotReserved releases the nonce back to the pool if +// the nonce has not yet been reserved (or more specifically, is +// still owned by the pool). +func (n *Nonce) ReleaseIfNotReserved() { + if n.record.State != nonce.StateClaimed { + return + } + if n.record.ClaimNodeId != n.pool.opts.nodeId { + return + } + if n.record.ClaimExpiresAt.Before(n.pool.opts.clock.Now()) { + return + } + + n.pool.freeListMu.Lock() + n.pool.freeList = append(n.pool.freeList, n) + n.pool.freeListMu.Unlock() +} + +// NoncePool is a pool of nonces that are cached in memory for +// quick access. The NoncePool will continually monitor the pool +// to ensure sufficient size, as well refresh nonce expiration +// times. +// +// If the pool empties before it can be refilled, ErrNoAvailableNonces +// will be returned. Therefore, the pool should be sufficiently large +// such that the consumption of poolSize/2 nonces is _slower_ than the +// operation to top up the pool. +type NoncePool struct { + log *logrus.Entry + data code_data.Provider + poolType nonce.Purpose + opts noncePoolOpts + + ctx context.Context + cancel context.CancelFunc + + freeListMu sync.RWMutex + freeList []*Nonce + + refreshPoolCh chan struct{} +} + +func NewNoncePool( + data code_data.Provider, + poolType nonce.Purpose, + opts ...NoncePoolOption, +) (*NoncePool, error) { + np := &NoncePool{ + log: logrus.StandardLogger().WithFields(logrus.Fields{ + "type": "transaction/NoncePool", + "pool_type": poolType.String(), + }), + data: data, + poolType: poolType, + opts: noncePoolOpts{ + clock: clockwork.NewRealClock(), + desiredPoolSize: 100, + nodeId: uuid.New().String(), + minExpiration: time.Minute, + maxExpiration: 2 * time.Minute, + refreshInterval: 5 * time.Second, + refreshPoolInterval: 10 * time.Second, + }, + refreshPoolCh: make(chan struct{}, 1), + } + + for _, o := range opts { + o(&np.opts) + } + + if err := np.opts.validate(); err != nil { + return nil, err + } + + np.ctx, np.cancel = context.WithCancel(context.Background()) + + go np.refreshPool() + go np.refreshNonces() + + return np, nil +} + +func (np *NoncePool) GetNonce(ctx context.Context) (*Nonce, error) { + var n *Nonce + + np.freeListMu.Lock() + size := len(np.freeList) + if size > 0 { + n = np.freeList[0] + np.freeList = np.freeList[1:] + } + np.freeListMu.Unlock() + + if n == nil { + return nil, ErrNoAvailableNonces + } + + if (size - 1) < np.opts.desiredPoolSize/2 { + select { + case np.refreshPoolCh <- struct{}{}: + default: + } + } + + return n, nil +} + +func (np *NoncePool) Load(ctx context.Context, limit int) error { + now := np.opts.clock.Now() + records, err := np.data.BatchClaimAvailableByPurpose( + ctx, + np.poolType, + limit, + np.opts.nodeId, + now.Add(np.opts.minExpiration), + now.Add(np.opts.maxExpiration), + ) + if err != nil { + return err + } + if len(records) == 0 { + return ErrNoAvailableNonces + } + + np.freeListMu.Lock() + for i := range records { + np.freeList = append(np.freeList, &Nonce{pool: np, record: records[i]}) + } + np.freeListMu.Unlock() + + return nil +} + +func (np *NoncePool) Close() error { + np.freeListMu.Lock() + defer np.freeListMu.Unlock() + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + remaining := len(np.freeList) + for _, n := range np.freeList { + n.record.State = nonce.StateAvailable + n.record.ClaimNodeId = "" + n.record.ClaimExpiresAt = time.UnixMilli(0) + + if err := np.data.SaveNonce(ctx, n.record); err != nil { + np.log.WithError(err).WithField("nonce", n.record.Address).Warn("Failed to release nonce on shutdown") + } else { + remaining-- + } + } + + if remaining != 0 { + return fmt.Errorf("failed to free all nonces (%d left unfreed)", remaining) + } + + np.cancel() + + return nil +} + +func (np *NoncePool) refreshPool() { + for { + select { + case <-np.ctx.Done(): + return + case <-np.refreshPoolCh: + case <-np.opts.clock.After(np.opts.refreshPoolInterval): + } + + np.freeListMu.Lock() + size := len(np.freeList) + np.freeListMu.Unlock() + + if size >= np.opts.desiredPoolSize { + continue + } + + err := np.Load(np.ctx, np.opts.desiredPoolSize-size) + if err != nil { + np.log.WithError(err).Warn("Failed to refresh nonce pool") + } + } +} + +func (np *NoncePool) refreshNonces() { + for { + select { + case <-np.ctx.Done(): + return + case <-np.opts.clock.After(np.opts.refreshInterval): + } + + now := np.opts.clock.Now() + refreshList := make([]*Nonce, 0) + + np.freeListMu.Lock() + for i := 0; i < len(np.freeList); { + n := np.freeList[i] + if now.Sub(n.record.ClaimExpiresAt) > 2*np.opts.minExpiration/3 { + i++ + continue + } + + refreshList = append(refreshList, n) + np.freeList = slices.Delete(np.freeList, i, i+1) + } + np.freeListMu.Unlock() + + if len(refreshList) == 0 { + continue + } + + for _, n := range refreshList { + n.record.ClaimExpiresAt = n.record.ClaimExpiresAt.Add(np.opts.minExpiration) + err := np.data.SaveNonce(np.ctx, n.record) + if err != nil { + np.log.WithError(err).WithField("nonce", n.record.Address). + Warn("Failed to refresh nonce, abandoning") + } else { + np.freeListMu.Lock() + np.freeList = append(np.freeList, n) + np.freeListMu.Unlock() + } + } + } +} diff --git a/pkg/code/transaction/nonce_pool_test.go b/pkg/code/transaction/nonce_pool_test.go new file mode 100644 index 00000000..8fffdd5e --- /dev/null +++ b/pkg/code/transaction/nonce_pool_test.go @@ -0,0 +1,193 @@ +package transaction + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + + code_data "github.com/code-payments/code-server/pkg/code/data" + "github.com/code-payments/code-server/pkg/code/data/nonce" +) + +func TestNoncePool(t *testing.T) { + for _, tc := range []struct { + name string + tf func(*noncePoolTest) + }{ + {"Happy", testNoncePoolHappy}, + {"RefreshSignal", testNonceRefreshSignal}, + {"Refresh", testNoncePoolRefresh}, + } { + t.Run( + tc.name, + func(t *testing.T) { + nt := newNoncePoolTest(t) + defer nt.pool.Close() + tc.tf(nt) + }, + ) + } +} + +func testNoncePoolHappy(nt *noncePoolTest) { + nt.initializeNonces(100, nonce.PurposeClientTransaction) + nt.initializeNonces(100, nonce.PurposeOnDemandTransaction) + + ctx := context.Background() + + // Since our time hasn't advanced, none of the refresh periods + // should have kicked in, and therefore our pool is empty. + n, err := nt.pool.GetNonce(ctx) + require.ErrorIs(nt.t, err, ErrNoAvailableNonces) + require.Nil(nt.t, n) + + nt.clock.Advance(2 * time.Second) + time.Sleep(100 * time.Millisecond) + + observed := map[string]*Nonce{} + for i := 0; i < 100; i++ { + n, err = nt.pool.GetNonce(ctx) + require.NoError(nt.t, err) + require.NotNil(nt.t, n) + require.NotContains(nt.t, observed, n.record.Address) + observed[n.record.Address] = n + + actual, err := nt.data.GetNonce(ctx, n.record.Address) + require.NoError(nt.t, err) + require.Equal(nt.t, actual.State, nonce.StateClaimed) + require.Equal(nt.t, actual.ClaimNodeId, nt.pool.opts.nodeId) + } + + // Underlying DB pool is empty. + n, err = nt.pool.GetNonce(ctx) + require.ErrorIs(nt.t, err, ErrNoAvailableNonces) + require.Nil(nt.t, n) + + // Releasing back to the pool should allow us to + // re-use the nonces. + for _, v := range observed { + v.ReleaseIfNotReserved() + } + clear(observed) + + for i := 0; i < 100; i++ { + n, err = nt.pool.GetNonce(ctx) + require.NoError(nt.t, err) + require.NotNil(nt.t, n) + require.NotContains(nt.t, observed, n.record.Address) + observed[n.record.Address] = n + } +} + +func testNonceRefreshSignal(nt *noncePoolTest) { + nt.initializeNonces(100, nonce.PurposeClientTransaction) + nt.initializeNonces(100, nonce.PurposeOnDemandTransaction) + + ctx := context.Background() + + require.NoError(nt.t, nt.pool.Load(ctx, 10)) + + for i := 0; i < 10; i++ { + _, err := nt.pool.GetNonce(ctx) + require.NoError(nt.t, err) + } + + // Note: We're sleeping in real time to let the background + // process run, but we haven't advanced the clock, so the + // only way a refresh could occur is via the signal. + time.Sleep(100 * time.Millisecond) + + for i := 0; i < 10; i++ { + _, err := nt.pool.GetNonce(ctx) + require.NoError(nt.t, err) + } +} + +func testNoncePoolRefresh(nt *noncePoolTest) { + nt.initializeNonces(100, nonce.PurposeClientTransaction) + nt.initializeNonces(100, nonce.PurposeOnDemandTransaction) + + ctx := context.Background() + require.NoError(nt.t, nt.pool.Load(ctx, 10)) + + nonceExpirations := map[string]time.Time{} + + nt.pool.freeListMu.Lock() + for _, n := range nt.pool.freeList { + nonceExpirations[n.record.Address] = n.record.ClaimExpiresAt + } + nt.pool.freeListMu.Unlock() + + nt.clock.Advance(3 * time.Minute) + time.Sleep(100 * time.Millisecond) + + refreshedNonceExpirations := map[string]time.Time{} + nt.pool.freeListMu.Lock() + for _, n := range nt.pool.freeList { + refreshedNonceExpirations[n.record.Address] = n.record.ClaimExpiresAt + } + nt.pool.freeListMu.Unlock() + + refreshed := 0 + for nonce, expiration := range refreshedNonceExpirations { + if nonceExpirations[nonce].Before(expiration) { + refreshed++ + } + } + + require.Greater(nt.t, refreshed, 75) + + for n, expiration := range refreshedNonceExpirations { + actual, err := nt.data.GetNonce(ctx, n) + require.NoError(nt.t, err) + + require.Equal(nt.t, actual.State, nonce.StateClaimed) + require.Equal(nt.t, actual.ClaimNodeId, nt.pool.opts.nodeId) + require.True(nt.t, actual.ClaimExpiresAt.Equal(expiration)) + } +} + +type noncePoolTest struct { + t *testing.T + clock clockwork.FakeClock + pool *NoncePool + data code_data.DatabaseData +} + +func newNoncePoolTest(t *testing.T) *noncePoolTest { + clock := clockwork.NewFakeClock() + data := code_data.NewTestDataProvider() + + pool, err := NewNoncePool( + data, + nonce.PurposeClientTransaction, + WithNoncePoolClock(clock), + WithNoncePoolRefreshInterval(time.Second), + WithNoncePoolRefreshPoolInterval(2*time.Second), + ) + require.NoError(t, err) + + return &noncePoolTest{ + t: t, + clock: clock, + data: data, + pool: pool, + } +} + +func (np *noncePoolTest) initializeNonces(amount int, purpose nonce.Purpose) { + for i := 0; i < amount; i++ { + err := np.data.SaveNonce(context.Background(), &nonce.Record{ + Id: uint64(i) + 1, + Address: fmt.Sprintf("addr-%s-%d", purpose.String(), i), + Authority: "my authority!", + Purpose: purpose, + State: nonce.StateAvailable, + }) + require.NoError(np.t, err) + } +}