diff --git a/client.go b/client.go index e947dab..dda7a9b 100644 --- a/client.go +++ b/client.go @@ -25,6 +25,7 @@ type Client interface { PendingNonce(context.Context) (uint64, error) NewTransactor() (*bind.TransactOpts, error) Execute(context.Context, func(context.Context, *TransactOpts) (*types.Transaction, error)) (Transaction, error) + ExecuteAndWait(context.Context, func(context.Context) (Transaction, error)) error NativeTokenBalance(context.Context) (decimal.Decimal, error) CurrentBlockNumber(context.Context) (uint64, error) CurrentBlock(context.Context) (*types.Block, error) @@ -130,7 +131,15 @@ func (c *client) Execute(ctx context.Context, action func(context.Context, *Tran return nil, errors.Wrap(err, "executing waitable action") } log.Debug().Msgf("//TX// tx started / hash: %s / gasprice: %s / gaslimit: %d", tx.Hash().Hex(), decimal.NewFromBigInt(tx.GasPrice(), -9), tx.Gas()) - return NewTransaction(localctx, c, c.pendingTransactionCheckPeriod, tx), nil + return NewTransaction(c, c.pendingTransactionCheckPeriod, tx), nil +} + +func (c *client) ExecuteAndWait(ctx context.Context, action func(context.Context) (Transaction, error)) error { + tx, err := action(ctx) + if err != nil { + return errors.Wrap(err, "executing waitable action") + } + return tx.Wait(ctx) } func (c *client) GetTransactionsForAddressInBlock(ctx context.Context, a string, b int64) { diff --git a/transaction.go b/transaction.go index 8d2abe6..97bc8b6 100644 --- a/transaction.go +++ b/transaction.go @@ -11,36 +11,34 @@ import ( ) type Transaction interface { - Wait() error + Wait(ctx context.Context) error } type transaction struct { *types.Transaction - ctx context.Context client Client pendingCheckPeriod time.Duration } -func NewTransaction(ctx context.Context, client Client, pendingCheckPeriod time.Duration, tx *types.Transaction) Transaction { +func NewTransaction(client Client, pendingCheckPeriod time.Duration, tx *types.Transaction) Transaction { return &transaction{ Transaction: tx, - ctx: ctx, client: client, pendingCheckPeriod: pendingCheckPeriod, } } -func (t *transaction) Wait() error { +func (t *transaction) Wait(ctx context.Context) error { hash := t.Hash() notfoundmax := 10 notfoundcount := 0 ticker := time.NewTicker(t.pendingCheckPeriod) for { select { - case <-t.ctx.Done(): + case <-ctx.Done(): return errors.New("context canceled") case <-ticker.C: - _, pending, err := t.client.TransactionByHash(t.ctx, hash) + _, pending, err := t.client.TransactionByHash(ctx, hash) if err != nil { if err == ethereum.NotFound && notfoundcount < notfoundmax { notfoundcount++