diff --git a/core/state/database.go b/core/state/database.go index 0d8acec35aaa..3f92aebdf484 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -193,7 +193,7 @@ func (db *CachingDB) Reader(stateRoot common.Hash) (Reader, error) { } // Set up the trie reader, which is expected to always be available // as the gatekeeper unless the state is corrupted. - tr, err := newTrieReader(stateRoot, db.triedb, db.pointCache) + tr, err := newTrieReader(stateRoot, db.triedb, db, db.pointCache) if err != nil { return nil, err } diff --git a/core/state/reader.go b/core/state/reader.go index 85842adde85f..5905df5206cf 100644 --- a/core/state/reader.go +++ b/core/state/reader.go @@ -48,6 +48,24 @@ type Reader interface { // - The returned storage slot is safe to modify after the call Storage(addr common.Address, slot common.Hash) (common.Hash, error) + // ContractCode returns the code associated with the given code hash. + // + // - It returns an error to indicate code doesn't exist or is empty + // - The returned code is safe to modify after the call + ContractCode(addr common.Address, codeHash common.Hash) ([]byte, error) + + // ContractCodeSize returns the size of the code associated with the given code hash. + // + // - It returns an error to indicate code doesn't exist + ContractCodeSize(addr common.Address, codeHash common.Hash) (int, error) + + // SupportsCodeQuery returns true if the reader supports querying contract code. Right now + // the trie reader supports querying contract code but the state reader doesn't. This is + // technically an arbitrary distinction as code is not stored either in the snapshot or trie + // per se. Important point is that one of them implements the code query to avoid double-lookups + // for non-existent code. + SupportsCodeQuery() bool + // Copy returns a deep-copied state reader. Copy() Reader } @@ -123,6 +141,22 @@ func (r *stateReader) Storage(addr common.Address, key common.Hash) (common.Hash return value, nil } +// ContractCode implements Reader, retrieving the code associated with a particular account. +func (r *stateReader) ContractCode(addr common.Address, codeHash common.Hash) ([]byte, error) { + return nil, errors.New("not supported") +} + +// ContractCodeSize implements Reader, returning the size of the code associated with a particular account. +func (r *stateReader) ContractCodeSize(addr common.Address, codeHash common.Hash) (int, error) { + return 0, errors.New("not supported") +} + +// SupportsCodeQuery implements Reader, returning false as the state reader +// doesn't support querying contract code. +func (r *stateReader) SupportsCodeQuery() bool { + return false +} + // Copy implements Reader, returning a deep-copied snap reader. func (r *stateReader) Copy() Reader { return &stateReader{ @@ -134,17 +168,18 @@ func (r *stateReader) Copy() Reader { // trieReader implements the Reader interface, providing functions to access // state from the referenced trie. type trieReader struct { - root common.Hash // State root which uniquely represent a state - db *triedb.Database // Database for loading trie - buff crypto.KeccakState // Buffer for keccak256 hashing - mainTrie Trie // Main trie, resolved in constructor - subRoots map[common.Address]common.Hash // Set of storage roots, cached when the account is resolved - subTries map[common.Address]Trie // Group of storage tries, cached when it's resolved + root common.Hash // State root which uniquely represent a state + db *triedb.Database // Database for loading trie + contractDB Database // Database for loading code + buff crypto.KeccakState // Buffer for keccak256 hashing + mainTrie Trie // Main trie, resolved in constructor + subRoots map[common.Address]common.Hash // Set of storage roots, cached when the account is resolved + subTries map[common.Address]Trie // Group of storage tries, cached when it's resolved } // trieReader constructs a trie reader of the specific state. An error will be // returned if the associated trie specified by root is not existent. -func newTrieReader(root common.Hash, db *triedb.Database, cache *utils.PointCache) (*trieReader, error) { +func newTrieReader(root common.Hash, db *triedb.Database, contractDB Database, cache *utils.PointCache) (*trieReader, error) { var ( tr Trie err error @@ -158,12 +193,13 @@ func newTrieReader(root common.Hash, db *triedb.Database, cache *utils.PointCach return nil, err } return &trieReader{ - root: root, - db: db, - buff: crypto.NewKeccakState(), - mainTrie: tr, - subRoots: make(map[common.Address]common.Hash), - subTries: make(map[common.Address]Trie), + root: root, + db: db, + contractDB: contractDB, + buff: crypto.NewKeccakState(), + mainTrie: tr, + subRoots: make(map[common.Address]common.Hash), + subTries: make(map[common.Address]Trie), }, nil } @@ -227,6 +263,22 @@ func (r *trieReader) Storage(addr common.Address, key common.Hash) (common.Hash, return value, nil } +// ContractCode implements Reader, retrieving the code associated with a particular account. +func (r *trieReader) ContractCode(addr common.Address, codeHash common.Hash) ([]byte, error) { + return r.contractDB.ContractCode(addr, codeHash) +} + +// ContractCodeSize implements Reader, returning the size of the code associated with a particular account. +func (r *trieReader) ContractCodeSize(addr common.Address, codeHash common.Hash) (int, error) { + return r.contractDB.ContractCodeSize(addr, codeHash) +} + +// SupportsCodeQuery implements Reader, returning true as the trie reader +// supports querying contract code. +func (r *trieReader) SupportsCodeQuery() bool { + return true +} + // Copy implements Reader, returning a deep-copied trie reader. func (r *trieReader) Copy() Reader { tries := make(map[common.Address]Trie) @@ -234,12 +286,13 @@ func (r *trieReader) Copy() Reader { tries[addr] = mustCopyTrie(tr) } return &trieReader{ - root: r.root, - db: r.db, - buff: crypto.NewKeccakState(), - mainTrie: mustCopyTrie(r.mainTrie), - subRoots: maps.Clone(r.subRoots), - subTries: tries, + root: r.root, + db: r.db, + contractDB: r.contractDB, + buff: crypto.NewKeccakState(), + mainTrie: mustCopyTrie(r.mainTrie), + subRoots: maps.Clone(r.subRoots), + subTries: tries, } } @@ -298,6 +351,53 @@ func (r *multiReader) Storage(addr common.Address, slot common.Hash) (common.Has return common.Hash{}, errors.Join(errs...) } +// ContractCode implements Reader, retrieving the code associated with a particular account. +func (r *multiReader) ContractCode(addr common.Address, codeHash common.Hash) ([]byte, error) { + var errs []error + for _, reader := range r.readers { + if reader.SupportsCodeQuery() { + code, err := reader.ContractCode(addr, codeHash) + if err == nil { + return code, nil + } + errs = append(errs, err) + } + } + if len(errs) == 0 { + return nil, errors.New("not found") + } + return nil, errors.Join(errs...) +} + +// ContractCodeSize implements Reader, returning the size of the code associated with a particular account. +func (r *multiReader) ContractCodeSize(addr common.Address, codeHash common.Hash) (int, error) { + var errs []error + for _, reader := range r.readers { + if reader.SupportsCodeQuery() { + size, err := reader.ContractCodeSize(addr, codeHash) + if err == nil { + return size, nil + } + errs = append(errs, err) + } + } + if len(errs) == 0 { + return 0, errors.New("not found") + } + return 0, errors.Join(errs...) +} + +// SupportsCodeQuery implements Reader, returning true if the reader supports querying contract code. +func (r *multiReader) SupportsCodeQuery() bool { + // Return true if one of readers supports querying contract code. + for _, reader := range r.readers { + if reader.SupportsCodeQuery() { + return true + } + } + return false +} + // Copy implementing Reader interface, returning a deep-copied state reader. func (r *multiReader) Copy() Reader { var readers []Reader diff --git a/core/state/state_object.go b/core/state/state_object.go index b659bf7ff208..172d6c5f5ea3 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -510,7 +510,7 @@ func (s *stateObject) Code() []byte { if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) { return nil } - code, err := s.db.db.ContractCode(s.address, common.BytesToHash(s.CodeHash())) + code, err := s.db.reader.ContractCode(s.address, common.BytesToHash(s.CodeHash())) if err != nil { s.db.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err)) } @@ -528,7 +528,7 @@ func (s *stateObject) CodeSize() int { if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) { return 0 } - size, err := s.db.db.ContractCodeSize(s.address, common.BytesToHash(s.CodeHash())) + size, err := s.db.reader.ContractCodeSize(s.address, common.BytesToHash(s.CodeHash())) if err != nil { s.db.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err)) }