Skip to content

Commit 7fc9b01

Browse files
apollo_gateway: make StatefulTransactionValidatorTrait async
1 parent d50b76a commit 7fc9b01

File tree

4 files changed

+176
-270
lines changed

4 files changed

+176
-270
lines changed

crates/apollo_gateway/src/gateway.rs

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use starknet_api::rpc_transaction::{
3232
RpcDeclareTransaction,
3333
RpcTransaction,
3434
};
35-
use tracing::{debug, warn, Span};
35+
use tracing::{debug, warn};
3636

3737
use crate::errors::{
3838
mempool_client_result_to_deprecated_gw_result,
@@ -157,28 +157,10 @@ impl Gateway {
157157
.await
158158
.inspect_err(|e| metric_counters.record_add_tx_failure(e))?;
159159

160-
let curr_span = Span::current();
161-
let mempool_client = self.mempool_client.clone();
162-
let nonce = tokio::task::spawn_blocking(move || {
163-
curr_span.in_scope(|| {
164-
stateful_transaction_validator.extract_state_nonce_and_run_validations(
165-
&executable_tx.clone(),
166-
mempool_client,
167-
tokio::runtime::Handle::current(),
168-
)
169-
})
170-
})
171-
.await
172-
.map_err(|e| {
173-
let err = StarknetError {
174-
code: StarknetErrorCode::UnknownErrorCode(
175-
"StarknetErrorCode.InternalError".to_string(),
176-
),
177-
message: format!("Validation task failed to complete: {e}"),
178-
};
179-
metric_counters.record_add_tx_failure(&err);
180-
err
181-
})??;
160+
let nonce = stateful_transaction_validator
161+
.extract_state_nonce_and_run_validations(&executable_tx, self.mempool_client.clone())
162+
.await
163+
.inspect_err(|e| metric_counters.record_add_tx_failure(e))?;
182164

183165
let gateway_output = create_gateway_output(&internal_tx);
184166

crates/apollo_gateway/src/gateway_test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ async fn add_tx_returns_error_when_run_transaction_validations_fails(
557557

558558
mock_stateful_transaction_validator
559559
.expect_extract_state_nonce_and_run_validations()
560-
.return_once(|_, _, _| Err(expected_error));
560+
.return_once(|_, _| Err(expected_error));
561561

562562
mock_stateful_transaction_validator_factory
563563
.expect_instantiate_validator()

crates/apollo_gateway/src/stateful_transaction_validator.rs

Lines changed: 116 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@ use apollo_mempool_types::communication::SharedMempoolClient;
1111
use apollo_mempool_types::mempool_types::ValidationArgs;
1212
use apollo_proc_macros::sequencer_latency_histogram;
1313
use async_trait::async_trait;
14-
use blockifier::blockifier::stateful_validator::{
15-
StatefulValidator,
16-
StatefulValidatorTrait as BlockifierStatefulValidatorTrait,
17-
};
14+
use blockifier::blockifier::stateful_validator::{StatefulValidator, StatefulValidatorTrait};
1815
use blockifier::blockifier_versioned_constants::VersionedConstants;
1916
use blockifier::bouncer::BouncerConfig;
2017
use blockifier::context::{BlockContext, ChainInfo};
@@ -87,115 +84,96 @@ impl StatefulTransactionValidatorFactoryTrait for StatefulTransactionValidatorFa
8784
self.contract_class_manager.clone(),
8885
Some(GATEWAY_CLASS_CACHE_METRICS),
8986
);
90-
91-
let state = CachedState::new(state_reader_and_contract_manager);
92-
let mut versioned_constants = VersionedConstants::get_versioned_constants(
93-
self.config.versioned_constants_overrides.clone(),
94-
);
95-
// The validation of a transaction is not affected by the casm hash migration.
96-
versioned_constants.enable_casm_hash_migration = false;
97-
98-
let mut block_info = gateway_fixed_block_state_reader.get_block_info().await?;
99-
block_info.block_number = block_info.block_number.unchecked_next();
100-
let block_context = BlockContext::new(
101-
block_info,
87+
Ok(Box::new(StatefulTransactionValidator::new(
88+
self.config.clone(),
10289
self.chain_info.clone(),
103-
versioned_constants,
104-
BouncerConfig::max(),
105-
);
106-
107-
let blockifier_stateful_tx_validator =
108-
BlockifierStatefulValidator::create(state, block_context);
109-
Ok(Box::new(StatefulTransactionValidator {
110-
config: self.config.clone(),
111-
blockifier_stateful_tx_validator,
90+
state_reader_and_contract_manager,
11291
gateway_fixed_block_state_reader,
113-
}))
92+
)))
11493
}
11594
}
11695

11796
#[cfg_attr(test, mockall::automock)]
97+
#[async_trait]
11898
pub trait StatefulTransactionValidatorTrait: Send {
119-
fn extract_state_nonce_and_run_validations(
99+
async fn extract_state_nonce_and_run_validations(
120100
&mut self,
121101
executable_tx: &ExecutableTransaction,
122102
mempool_client: SharedMempoolClient,
123-
runtime: tokio::runtime::Handle,
124103
) -> StatefulTransactionValidatorResult<Nonce>;
125104
}
126105

127-
pub struct StatefulTransactionValidator<B: BlockifierStatefulValidatorTrait> {
106+
pub struct StatefulTransactionValidator {
128107
config: StatefulTransactionValidatorConfig,
129-
blockifier_stateful_tx_validator: B,
108+
chain_info: ChainInfo,
109+
// Consumed when running the CPU-heavy blockifier validation.
110+
// TODO(Itamar): The whole `StatefulTransactionValidator` is never used after
111+
// `state_reader_and_contract_manager` is taken. Make it non-optional and discard the
112+
// instance after use.
113+
state_reader_and_contract_manager:
114+
Option<StateReaderAndContractManager<Box<dyn GatewayStateReaderWithCompiledClasses>>>,
130115
gateway_fixed_block_state_reader: Box<dyn GatewayFixedBlockStateReader>,
131116
}
132117

133-
impl<B: BlockifierStatefulValidatorTrait + Send> StatefulTransactionValidatorTrait
134-
for StatefulTransactionValidator<B>
135-
{
136-
fn extract_state_nonce_and_run_validations(
118+
#[async_trait]
119+
impl StatefulTransactionValidatorTrait for StatefulTransactionValidator {
120+
async fn extract_state_nonce_and_run_validations(
137121
&mut self,
138122
executable_tx: &ExecutableTransaction,
139123
mempool_client: SharedMempoolClient,
140-
runtime: tokio::runtime::Handle,
141124
) -> StatefulTransactionValidatorResult<Nonce> {
142-
let address = executable_tx.contract_address();
143-
let account_nonce = runtime
144-
.block_on(self.gateway_fixed_block_state_reader.get_nonce(address))
145-
.map_err(|e| {
146-
// TODO(noamsp): Fix this. Need to map the errors better.
147-
StarknetError::internal_with_signature_logging(
148-
format!("Failed to get nonce for sender address {address}"),
149-
&executable_tx.signature(),
150-
e,
151-
)
152-
})?;
153-
self.run_transaction_validations(executable_tx, account_nonce, mempool_client, runtime)?;
125+
let account_nonce = self.extract_nonce(executable_tx).await?;
126+
let skip_validate =
127+
self.run_pre_validation_checks(executable_tx, account_nonce, mempool_client).await?;
128+
self.run_validate_entry_point(executable_tx, skip_validate).await?;
154129
Ok(account_nonce)
155130
}
156131
}
157132

158-
impl<B: BlockifierStatefulValidatorTrait> StatefulTransactionValidator<B> {
159-
fn run_transaction_validations(
133+
impl StatefulTransactionValidator {
134+
fn new(
135+
config: StatefulTransactionValidatorConfig,
136+
chain_info: ChainInfo,
137+
state_reader_and_contract_manager: StateReaderAndContractManager<
138+
Box<dyn GatewayStateReaderWithCompiledClasses>,
139+
>,
140+
gateway_fixed_block_state_reader: Box<dyn GatewayFixedBlockStateReader>,
141+
) -> Self {
142+
Self {
143+
config,
144+
chain_info,
145+
state_reader_and_contract_manager: Some(state_reader_and_contract_manager),
146+
gateway_fixed_block_state_reader,
147+
}
148+
}
149+
150+
fn take_state_reader_and_contract_manager(
160151
&mut self,
161-
executable_tx: &ExecutableTransaction,
162-
account_nonce: Nonce,
163-
mempool_client: SharedMempoolClient,
164-
runtime: tokio::runtime::Handle,
165-
) -> StatefulTransactionValidatorResult<()> {
166-
self.validate_state_preconditions(executable_tx, account_nonce)?;
167-
runtime.block_on(validate_by_mempool(
168-
executable_tx,
169-
account_nonce,
170-
mempool_client.clone(),
171-
))?;
172-
self.run_validate_entry_point(executable_tx, account_nonce, mempool_client, runtime)?;
173-
Ok(())
152+
) -> StateReaderAndContractManager<Box<dyn GatewayStateReaderWithCompiledClasses>> {
153+
self.state_reader_and_contract_manager.take().expect("Validator was already consumed")
174154
}
175155

176-
fn validate_state_preconditions(
156+
async fn validate_state_preconditions(
177157
&self,
178158
executable_tx: &ExecutableTransaction,
179159
account_nonce: Nonce,
180160
) -> StatefulTransactionValidatorResult<()> {
181-
self.validate_resource_bounds(executable_tx)?;
161+
self.validate_resource_bounds(executable_tx).await?;
182162
self.validate_nonce(executable_tx, account_nonce)?;
183163
Ok(())
184164
}
185165

186-
fn validate_resource_bounds(
166+
async fn validate_resource_bounds(
187167
&self,
188168
executable_tx: &ExecutableTransaction,
189169
) -> StatefulTransactionValidatorResult<()> {
190170
// Skip this validation during the systems bootstrap phase.
191171
if self.config.validate_resource_bounds {
192172
// TODO(Arni): getnext_l2_gas_price from the block header.
193-
// TODO(Itamar): Replace usage of `blockifier_stateful_tx_validator.block_info()` with
194-
// the GW fixed-block provider and then remove `block_info()` from
195-
// blockifier::{StatefulValidatorTrait, StatefulValidator}.
196173
let previous_block_l2_gas_price = self
197-
.blockifier_stateful_tx_validator
198-
.block_info()
174+
.gateway_fixed_block_state_reader
175+
.get_block_info()
176+
.await?
199177
.gas_prices
200178
.strk_gas_prices
201179
.l2_gas_price;
@@ -265,23 +243,51 @@ impl<B: BlockifierStatefulValidatorTrait> StatefulTransactionValidator<B> {
265243
}
266244

267245
#[sequencer_latency_histogram(GATEWAY_VALIDATE_TX_LATENCY, true)]
268-
fn run_validate_entry_point(
246+
async fn run_validate_entry_point(
269247
&mut self,
270248
executable_tx: &ExecutableTransaction,
271-
account_nonce: Nonce,
272-
mempool_client: SharedMempoolClient,
273-
runtime: tokio::runtime::Handle,
249+
skip_validate: bool,
274250
) -> StatefulTransactionValidatorResult<()> {
275-
let skip_validate =
276-
skip_stateful_validations(executable_tx, account_nonce, mempool_client, runtime)?;
277251
let only_query = false;
278252
let charge_fee = enforce_fee(executable_tx, only_query);
279253
let strict_nonce_check = false;
280254
let execution_flags =
281255
ExecutionFlags { only_query, charge_fee, validate: !skip_validate, strict_nonce_check };
282256

283257
let account_tx = AccountTransaction { tx: executable_tx.clone(), execution_flags };
284-
self.blockifier_stateful_tx_validator.validate(account_tx).map_err(|e| StarknetError {
258+
259+
// Build block context here.
260+
let mut versioned_constants = VersionedConstants::get_versioned_constants(
261+
self.config.versioned_constants_overrides.clone(),
262+
);
263+
// The validation of a transaction is not affected by the casm hash migration.
264+
versioned_constants.enable_casm_hash_migration = false;
265+
266+
let mut block_info = self.gateway_fixed_block_state_reader.get_block_info().await?;
267+
block_info.block_number = block_info.block_number.unchecked_next();
268+
let block_context = BlockContext::new(
269+
block_info,
270+
self.chain_info.clone(),
271+
versioned_constants,
272+
BouncerConfig::max(),
273+
);
274+
275+
// Move state into the blocking task and run CPU-heavy validation.
276+
let state_reader_and_contract_manager = self.take_state_reader_and_contract_manager();
277+
278+
tokio::task::spawn_blocking(move || {
279+
let state = CachedState::new(state_reader_and_contract_manager);
280+
let mut blockifier = BlockifierStatefulValidator::create(state, block_context);
281+
blockifier.validate(account_tx)
282+
})
283+
.await
284+
.map_err(|e| StarknetError {
285+
code: StarknetErrorCode::UnknownErrorCode(
286+
"StarknetErrorCode.InternalError".to_string(),
287+
),
288+
message: format!("Blocking task join error: {e}"),
289+
})?
290+
.map_err(|e| StarknetError {
285291
code: StarknetErrorCode::KnownErrorCode(KnownStarknetErrorCode::ValidateFailure),
286292
message: e.to_string(),
287293
})?;
@@ -321,6 +327,35 @@ impl<B: BlockifierStatefulValidatorTrait> StatefulTransactionValidator<B> {
321327
}
322328
Ok(())
323329
}
330+
331+
async fn extract_nonce(
332+
&self,
333+
executable_tx: &ExecutableTransaction,
334+
) -> StatefulTransactionValidatorResult<Nonce> {
335+
let address = executable_tx.contract_address();
336+
let account_nonce =
337+
self.gateway_fixed_block_state_reader.get_nonce(address).await.map_err(|e| {
338+
StarknetError::internal_with_signature_logging(
339+
format!("Failed to get nonce for sender address {address}"),
340+
&executable_tx.signature(),
341+
e,
342+
)
343+
})?;
344+
Ok(account_nonce)
345+
}
346+
347+
async fn run_pre_validation_checks(
348+
&self,
349+
executable_tx: &ExecutableTransaction,
350+
account_nonce: Nonce,
351+
mempool_client: SharedMempoolClient,
352+
) -> StatefulTransactionValidatorResult<bool> {
353+
self.validate_state_preconditions(executable_tx, account_nonce).await?;
354+
validate_by_mempool(executable_tx, account_nonce, mempool_client.clone()).await?;
355+
let skip_validate =
356+
skip_stateful_validations(executable_tx, account_nonce, mempool_client.clone()).await?;
357+
Ok(skip_validate)
358+
}
324359
}
325360

326361
/// Perform transaction validation by the mempool.
@@ -339,11 +374,10 @@ async fn validate_by_mempool(
339374
/// Check if validation of an invoke transaction should be skipped due to deploy_account not being
340375
/// processed yet. This feature is used to improve UX for users sending deploy_account + invoke at
341376
/// once.
342-
fn skip_stateful_validations(
377+
async fn skip_stateful_validations(
343378
tx: &ExecutableTransaction,
344379
account_nonce: Nonce,
345380
mempool_client: SharedMempoolClient,
346-
runtime: tokio::runtime::Handle,
347381
) -> StatefulTransactionValidatorResult<bool> {
348382
if let ExecutableTransaction::Invoke(ExecutableInvokeTransaction { tx, .. }) = tx {
349383
// check if the transaction nonce is 1, meaning it is post deploy_account, and the
@@ -355,8 +389,9 @@ fn skip_stateful_validations(
355389
// to check if the account exists in the mempool since it means that either it has a
356390
// deploy_account transaction or transactions with future nonces that passed
357391
// validations.
358-
return runtime
359-
.block_on(mempool_client.account_tx_in_pool_or_recent_block(tx.sender_address()))
392+
return mempool_client
393+
.account_tx_in_pool_or_recent_block(tx.sender_address())
394+
.await
360395
.map_err(|err| mempool_client_err_to_deprecated_gw_err(&tx.signature(), err))
361396
.inspect(|exists| {
362397
if *exists {

0 commit comments

Comments
 (0)