Skip to content

Commit

Permalink
refactor: use sender in tap context (#507)
Browse files Browse the repository at this point in the history
  • Loading branch information
gusinacio authored Nov 26, 2024
1 parent 40e7612 commit a50e23d
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 78 deletions.
2 changes: 1 addition & 1 deletion crates/service/src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ pub use attestation_signer::{signer_middleware, AttestationState};
pub use deployment::deployment_middleware;
pub use labels::labels_middleware;
pub use prometheus_metrics::PrometheusMetricsMiddlewareLayer;
pub use sender::{sender_middleware, SenderState};
pub use sender::{sender_middleware, Sender, SenderState};
pub use tap_context::context_middleware;
pub use tap_receipt::receipt_middleware;
7 changes: 7 additions & 0 deletions crates/service/src/middleware/tap_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ use thegraph_core::DeploymentId;

use crate::{error::IndexerServiceError, tap::AgoraQuery};

use super::sender::Sender;

/// Graphql query body to be decoded and passed to agora context
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct QueryBody {
Expand All @@ -39,6 +41,7 @@ pub async fn context_middleware(
Err(_) => return Err(IndexerServiceError::DeploymentIdNotFound),
},
};
let sender = request.extensions().get::<Sender>().cloned();

let (mut parts, body) = request.into_parts();
let bytes = to_bytes(body, usize::MAX).await?;
Expand All @@ -56,6 +59,10 @@ pub async fn context_middleware(
query: query_body.query.clone(),
variables,
});

if let Some(sender) = sender {
ctx.insert(sender);
}
parts.extensions.insert(Arc::new(ctx));
let request = Request::from_parts(parts, bytes.into());
Ok(next.run(request).await)
Expand Down
1 change: 0 additions & 1 deletion crates/service/src/service/indexer_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@ pub async fn run(options: IndexerServiceOptions) -> Result<(), anyhow::Error> {
database,
allocations.clone(),
escrow_accounts.clone(),
domain_separator.clone(),
timestamp_error_tolerance,
receipt_max_value,
)
Expand Down
8 changes: 2 additions & 6 deletions crates/service/src/tap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,14 @@ impl IndexerTapContext {
pgpool: PgPool,
indexer_allocations: Receiver<HashMap<Address, Allocation>>,
escrow_accounts: Receiver<EscrowAccounts>,
domain_separator: Eip712Domain,
timestamp_error_tolerance: Duration,
receipt_max_value: u128,
) -> Vec<ReceiptCheck> {
vec![
Arc::new(AllocationEligible::new(indexer_allocations)),
Arc::new(SenderBalanceCheck::new(
escrow_accounts.clone(),
domain_separator.clone(),
)),
Arc::new(SenderBalanceCheck::new(escrow_accounts)),
Arc::new(TimestampCheck::new(timestamp_error_tolerance)),
Arc::new(DenyListCheck::new(pgpool.clone(), escrow_accounts, domain_separator).await),
Arc::new(DenyListCheck::new(pgpool.clone()).await),
Arc::new(ReceiptMaxValueCheck::new(receipt_max_value)),
Arc::new(MinimumValue::new(pgpool, Duration::from_secs(GRACE_PERIOD)).await),
]
Expand Down
63 changes: 19 additions & 44 deletions crates/service/src/tap/checks/deny_list_check.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
// SPDX-License-Identifier: Apache-2.0

use alloy::dyn_abi::Eip712Domain;
use crate::middleware::Sender;
use alloy::primitives::Address;
use indexer_monitor::EscrowAccounts;
use sqlx::postgres::PgListener;
use sqlx::PgPool;
use std::collections::HashSet;
Expand All @@ -15,23 +14,16 @@ use tap_core::receipt::{
state::Checking,
ReceiptWithState,
};
use tokio::sync::watch::Receiver;
use tracing::error;

pub struct DenyListCheck {
escrow_accounts: Receiver<EscrowAccounts>,
domain_separator: Eip712Domain,
sender_denylist: Arc<RwLock<HashSet<Address>>>,
_sender_denylist_watcher_handle: Arc<tokio::task::JoinHandle<()>>,
sender_denylist_watcher_cancel_token: tokio_util::sync::CancellationToken,
}

impl DenyListCheck {
pub async fn new(
pgpool: PgPool,
escrow_accounts: Receiver<EscrowAccounts>,
domain_separator: Eip712Domain,
) -> Self {
pub async fn new(pgpool: PgPool) -> Self {
// Listen to pg_notify events. We start it before updating the sender_denylist so that we
// don't miss any updates. PG will buffer the notifications until we start consuming them.
let mut pglistener = PgListener::connect_with(&pgpool.clone()).await.unwrap();
Expand All @@ -57,8 +49,6 @@ impl DenyListCheck {
sender_denylist_watcher_cancel_token.clone(),
)));
Self {
domain_separator,
escrow_accounts,
sender_denylist,
_sender_denylist_watcher_handle: sender_denylist_watcher_handle,
sender_denylist_watcher_cancel_token,
Expand Down Expand Up @@ -152,29 +142,19 @@ impl DenyListCheck {
impl Check for DenyListCheck {
async fn check(
&self,
_: &tap_core::receipt::Context,
receipt: &ReceiptWithState<Checking>,
ctx: &tap_core::receipt::Context,
_: &ReceiptWithState<Checking>,
) -> CheckResult {
let receipt_signer = receipt
.signed_receipt()
.recover_signer(&self.domain_separator)
.map_err(|e| {
error!("Failed to recover receipt signer: {}", e);
anyhow::anyhow!(e)
})
.map_err(CheckError::Failed)?;
let escrow_accounts_snapshot = self.escrow_accounts.borrow();

let receipt_sender = escrow_accounts_snapshot
.get_sender_for_signer(&receipt_signer)
.map_err(|e| CheckError::Failed(e.into()))?;
let Sender(receipt_sender) = ctx
.get::<Sender>()
.ok_or(CheckError::Failed(anyhow::anyhow!("Could not find sender")))?;

// Check that the sender is not denylisted
if self
.sender_denylist
.read()
.unwrap()
.contains(&receipt_sender)
.contains(receipt_sender)
{
return Err(CheckError::Failed(anyhow::anyhow!(
"Received a receipt from a denylisted sender: {}",
Expand All @@ -200,26 +180,16 @@ mod tests {

use alloy::hex::ToHexExt;
use tap_core::receipt::{Context, ReceiptWithState};
use tokio::sync::watch;

use test_assets::{
self, create_signed_receipt, ESCROW_ACCOUNTS_BALANCES, ESCROW_ACCOUNTS_SENDERS_TO_SIGNERS,
TAP_EIP712_DOMAIN, TAP_SENDER,
};
use test_assets::{self, create_signed_receipt, TAP_SENDER};

use super::*;

const ALLOCATION_ID: &str = "0xdeadbeefcafebabedeadbeefcafebabedeadbeef";

async fn new_deny_list_check(pgpool: PgPool) -> DenyListCheck {
// Mock escrow accounts
let escrow_accounts_rx = watch::channel(EscrowAccounts::new(
ESCROW_ACCOUNTS_BALANCES.to_owned(),
ESCROW_ACCOUNTS_SENDERS_TO_SIGNERS.to_owned(),
))
.1;

DenyListCheck::new(pgpool, escrow_accounts_rx, TAP_EIP712_DOMAIN.to_owned()).await
DenyListCheck::new(pgpool).await
}

#[sqlx::test(migrations = "../../migrations")]
Expand All @@ -244,9 +214,12 @@ mod tests {

let checking_receipt = ReceiptWithState::new(signed_receipt);

let mut ctx = Context::new();
ctx.insert(Sender(TAP_SENDER.1));

// Check that the receipt is rejected
assert!(deny_list_check
.check(&Context::new(), &checking_receipt)
.check(&ctx, &checking_receipt)
.await
.is_err());
}
Expand All @@ -262,8 +235,10 @@ mod tests {
// Check that the receipt is valid
let checking_receipt = ReceiptWithState::new(signed_receipt);

let mut ctx = Context::new();
ctx.insert(Sender(TAP_SENDER.1));
deny_list_check
.check(&Context::new(), &checking_receipt)
.check(&ctx, &checking_receipt)
.await
.unwrap();

Expand All @@ -282,7 +257,7 @@ mod tests {
// Check that the receipt is rejected
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(deny_list_check
.check(&Context::new(), &checking_receipt)
.check(&ctx, &checking_receipt)
.await
.is_err());

Expand All @@ -301,7 +276,7 @@ mod tests {
// Check that the receipt is valid again
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
deny_list_check
.check(&Context::new(), &checking_receipt)
.check(&ctx, &checking_receipt)
.await
.unwrap();
}
Expand Down
37 changes: 11 additions & 26 deletions crates/service/src/tap/checks/sender_balance_check.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
// SPDX-License-Identifier: Apache-2.0

use alloy::dyn_abi::Eip712Domain;
use alloy::primitives::U256;
use anyhow::anyhow;
use indexer_monitor::EscrowAccounts;
Expand All @@ -11,55 +10,41 @@ use tap_core::receipt::{
ReceiptWithState,
};
use tokio::sync::watch::Receiver;
use tracing::error;

use crate::middleware::Sender;

pub struct SenderBalanceCheck {
escrow_accounts: Receiver<EscrowAccounts>,

domain_separator: Eip712Domain,
}

impl SenderBalanceCheck {
pub fn new(escrow_accounts: Receiver<EscrowAccounts>, domain_separator: Eip712Domain) -> Self {
Self {
escrow_accounts,
domain_separator,
}
pub fn new(escrow_accounts: Receiver<EscrowAccounts>) -> Self {
Self { escrow_accounts }
}
}

#[async_trait::async_trait]
impl Check for SenderBalanceCheck {
async fn check(
&self,
_: &tap_core::receipt::Context,
receipt: &ReceiptWithState<Checking>,
ctx: &tap_core::receipt::Context,
_: &ReceiptWithState<Checking>,
) -> CheckResult {
let escrow_accounts_snapshot = self.escrow_accounts.borrow();

let receipt_signer = receipt
.signed_receipt()
.recover_signer(&self.domain_separator)
.inspect_err(|e| {
error!("Failed to recover receipt signer: {}", e);
})
.map_err(|e| CheckError::Failed(e.into()))?;

// We bail if the receipt signer does not have a corresponding sender in the escrow
// accounts.
let receipt_sender = escrow_accounts_snapshot
.get_sender_for_signer(&receipt_signer)
.map_err(|e| CheckError::Failed(e.into()))?;
let Sender(receipt_sender) = ctx
.get::<Sender>()
.ok_or(CheckError::Failed(anyhow::anyhow!("Could not find sender")))?;

// Check that the sender has a non-zero balance -- more advanced accounting is done in
// `tap-agent`.
if !escrow_accounts_snapshot
.get_balance_for_sender(&receipt_sender)
.get_balance_for_sender(receipt_sender)
.map_or(false, |balance| balance > U256::ZERO)
{
return Err(CheckError::Failed(anyhow!(
"Receipt sender `{}` does not have a sufficient balance",
receipt_signer,
receipt_sender,
)));
}
Ok(())
Expand Down

0 comments on commit a50e23d

Please sign in to comment.