Skip to content

Commit 2d69424

Browse files
authored
added more flexibility to public interface around heartbeating (#78)
* added more flexibility to public inteface around heartbeating * do the heartbeat handling here * improve testing ergonomics
1 parent 557713c commit 2d69424

File tree

4 files changed

+172
-33
lines changed

4 files changed

+172
-33
lines changed

src/context.rs

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ use uuid::Uuid;
77

88
use crate::Durable;
99
use crate::error::{ControlFlow, TaskError, TaskResult};
10+
use std::sync::Arc;
11+
12+
use crate::heartbeat::{HeartbeatHandle, Heartbeater, StepState};
1013
use crate::task::Task;
1114
use crate::types::DurableEventPayload;
1215
use crate::types::{
@@ -63,6 +66,9 @@ where
6366

6467
/// Notifies the worker when the lease is extended via step() or heartbeat().
6568
lease_extender: LeaseExtender,
69+
70+
/// Cloneable heartbeat handle for use in step closures.
71+
heartbeat_handle: HeartbeatHandle,
6672
}
6773

6874
/// Validate that a user-provided step name doesn't use reserved prefix.
@@ -103,6 +109,14 @@ where
103109
cache.insert(row.checkpoint_name, row.state);
104110
}
105111

112+
let heartbeat_handle = HeartbeatHandle::new(
113+
durable.pool().clone(),
114+
durable.queue_name().to_string(),
115+
task.run_id,
116+
claim_timeout,
117+
lease_extender.clone(),
118+
);
119+
106120
Ok(Self {
107121
task_id: task.task_id,
108122
run_id: task.run_id,
@@ -113,6 +127,7 @@ where
113127
checkpoint_cache: cache,
114128
step_counters: HashMap::new(),
115129
lease_extender,
130+
heartbeat_handle,
116131
})
117132
}
118133

@@ -152,9 +167,9 @@ where
152167
/// # Example
153168
///
154169
/// ```ignore
155-
/// let payment_id = ctx.step("charge-payment", ctx.task_id, |task_id, _state| async {
170+
/// let payment_id = ctx.step("charge-payment", ctx.task_id, |task_id, step_state| async {
156171
/// let idempotency_key = format!("{}:charge", task_id);
157-
/// stripe::charge(amount, &idempotency_key).await
172+
/// stripe::charge(amount, &idempotency_key, &step_state.state).await
158173
/// }).await?;
159174
/// ```
160175
#[cfg_attr(
@@ -169,7 +184,7 @@ where
169184
&mut self,
170185
base_name: &str,
171186
params: P,
172-
f: fn(P, State) -> Fut,
187+
f: fn(P, StepState<State>) -> Fut,
173188
) -> TaskResult<T>
174189
where
175190
P: Serialize,
@@ -193,13 +208,14 @@ where
193208
span.record("cached", false);
194209

195210
// Execute the step
196-
let result =
197-
f(params, self.durable.state().clone())
198-
.await
199-
.map_err(|e| TaskError::Step {
200-
base_name: base_name.to_string(),
201-
error: e,
202-
})?;
211+
let step_state = StepState {
212+
state: self.durable.state().clone(),
213+
heartbeater: Arc::new(self.heartbeat_handle.clone()),
214+
};
215+
let result = f(params, step_state).await.map_err(|e| TaskError::Step {
216+
base_name: base_name.to_string(),
217+
error: e,
218+
})?;
203219

204220
// Persist checkpoint (also extends claim lease)
205221
#[cfg(feature = "telemetry")]
@@ -461,6 +477,14 @@ where
461477
})
462478
}
463479

480+
/// Get a cloneable heartbeat handle for use in step closures or `SimpleTool`s.
481+
///
482+
/// The returned [`HeartbeatHandle`] can be passed into contexts that need to
483+
/// extend the task lease without access to the full `TaskContext`.
484+
pub fn heartbeat_handle(&self) -> HeartbeatHandle {
485+
self.heartbeat_handle.clone()
486+
}
487+
464488
/// Extend the task's lease to prevent timeout.
465489
///
466490
/// Use this for long-running operations that don't naturally checkpoint.
@@ -482,27 +506,7 @@ where
482506
)
483507
)]
484508
pub async fn heartbeat(&self, duration: Option<std::time::Duration>) -> TaskResult<()> {
485-
let extend_by = duration.unwrap_or(self.claim_timeout);
486-
487-
if extend_by < std::time::Duration::from_secs(1) {
488-
return Err(TaskError::Validation {
489-
message: "heartbeat duration must be at least 1 second".to_string(),
490-
});
491-
}
492-
493-
let query = "SELECT durable.extend_claim($1, $2, $3)";
494-
sqlx::query(query)
495-
.bind(self.durable.queue_name())
496-
.bind(self.run_id)
497-
.bind(extend_by.as_secs() as i32)
498-
.execute(self.durable.pool())
499-
.await
500-
.map_err(TaskError::from_sqlx_error)?;
501-
502-
// Notify worker that lease was extended so it can reset timers
503-
self.lease_extender.notify(extend_by);
504-
505-
Ok(())
509+
self.heartbeat_handle.heartbeat(duration).await
506510
}
507511

508512
/// Generate a durable random value in [0, 1).

src/heartbeat.rs

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
use std::sync::Arc;
2+
use std::time::Duration;
3+
4+
use async_trait::async_trait;
5+
use uuid::Uuid;
6+
7+
use crate::error::{TaskError, TaskResult};
8+
use crate::worker::LeaseExtender;
9+
10+
/// Trait for extending task leases during long-running operations.
11+
///
12+
/// Implementations allow code running inside a `step()` closure to send heartbeats
13+
/// back to the durable framework, preventing the worker from considering the task
14+
/// dead during long-running operations.
15+
///
16+
/// Two implementations are provided:
17+
/// - [`HeartbeatHandle`] — extends leases via the database (used in durable workers)
18+
/// - [`NoopHeartbeater`] — does nothing (used in tests and non-durable contexts)
19+
#[async_trait]
20+
pub trait Heartbeater: Send + Sync {
21+
/// Extend the task's lease.
22+
///
23+
/// # Arguments
24+
/// * `duration` - Extension duration. If `None`, uses the original claim timeout.
25+
/// Must be at least 1 second when `Some`.
26+
async fn heartbeat(&self, duration: Option<Duration>) -> TaskResult<()>;
27+
}
28+
29+
/// Real heartbeat handle that extends leases via the database.
30+
///
31+
/// Created from a [`TaskContext`](crate::TaskContext) via
32+
/// [`heartbeat_handle()`](crate::TaskContext::heartbeat_handle) and can be
33+
/// passed into step closures or other contexts that need to extend the task lease.
34+
#[derive(Clone)]
35+
pub struct HeartbeatHandle {
36+
pool: sqlx::PgPool,
37+
queue_name: String,
38+
run_id: Uuid,
39+
claim_timeout: Duration,
40+
lease_extender: LeaseExtender,
41+
}
42+
43+
impl HeartbeatHandle {
44+
pub(crate) fn new(
45+
pool: sqlx::PgPool,
46+
queue_name: String,
47+
run_id: Uuid,
48+
claim_timeout: Duration,
49+
lease_extender: LeaseExtender,
50+
) -> Self {
51+
Self {
52+
pool,
53+
queue_name,
54+
run_id,
55+
claim_timeout,
56+
lease_extender,
57+
}
58+
}
59+
}
60+
61+
#[async_trait]
62+
impl Heartbeater for HeartbeatHandle {
63+
async fn heartbeat(&self, duration: Option<Duration>) -> TaskResult<()> {
64+
let extend_by = duration.unwrap_or(self.claim_timeout);
65+
66+
if extend_by < Duration::from_secs(1) {
67+
return Err(TaskError::Validation {
68+
message: "heartbeat duration must be at least 1 second".to_string(),
69+
});
70+
}
71+
72+
let query = "SELECT durable.extend_claim($1, $2, $3)";
73+
sqlx::query(query)
74+
.bind(&self.queue_name)
75+
.bind(self.run_id)
76+
.bind(extend_by.as_secs() as i32)
77+
.execute(&self.pool)
78+
.await
79+
.map_err(TaskError::from_sqlx_error)?;
80+
81+
// Notify worker that lease was extended so it can reset timers
82+
self.lease_extender.notify(extend_by);
83+
84+
Ok(())
85+
}
86+
}
87+
88+
/// No-op heartbeater for testing and non-durable contexts.
89+
///
90+
/// All heartbeat calls succeed immediately without any side effects.
91+
#[derive(Clone, Default)]
92+
pub struct NoopHeartbeater;
93+
94+
#[async_trait]
95+
impl Heartbeater for NoopHeartbeater {
96+
async fn heartbeat(&self, _duration: Option<Duration>) -> TaskResult<()> {
97+
Ok(())
98+
}
99+
}
100+
101+
/// State provided to `step()` closures, wrapping the user's application state
102+
/// alongside a heartbeater for extending the task lease.
103+
///
104+
/// This is passed as the second argument to every `step()` closure, making
105+
/// heartbeating available without the consumer needing to thread it manually.
106+
///
107+
/// # Example
108+
///
109+
/// ```ignore
110+
/// ctx.step("long-operation", params, |params, step_state| async move {
111+
/// for item in &params.items {
112+
/// process(item, &step_state.state).await?;
113+
/// // Extend lease during long-running work
114+
/// let _ = step_state.heartbeater.heartbeat(None).await;
115+
/// }
116+
/// Ok(result)
117+
/// }).await?;
118+
/// ```
119+
///
120+
/// For testing step closures in isolation, construct with [`NoopHeartbeater`]:
121+
///
122+
/// ```ignore
123+
/// let step_state = StepState {
124+
/// state: my_test_state,
125+
/// heartbeater: Arc::new(NoopHeartbeater),
126+
/// };
127+
/// ```
128+
pub struct StepState<State> {
129+
/// The user's application state.
130+
pub state: State,
131+
/// Handle for extending the task lease during long-running operations.
132+
pub heartbeater: Arc<dyn Heartbeater>,
133+
}

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ mod client;
9898
mod context;
9999
mod cron;
100100
mod error;
101+
mod heartbeat;
101102
mod task;
102103
#[cfg(feature = "telemetry")]
103104
pub mod telemetry;
@@ -109,6 +110,7 @@ pub use client::{Durable, DurableBuilder};
109110
pub use context::TaskContext;
110111
pub use cron::{ScheduleFilter, ScheduleInfo, ScheduleOptions, setup_pgcron};
111112
pub use error::{ControlFlow, DurableError, DurableResult, TaskError, TaskResult};
113+
pub use heartbeat::{HeartbeatHandle, Heartbeater, NoopHeartbeater, StepState};
112114
pub use task::{ErasedTask, Task, TaskWrapper};
113115
pub use types::{
114116
CancellationPolicy, ClaimedTask, DurableEventPayload, RetryStrategy, SpawnDefaults,

tests/execution_test.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -745,13 +745,13 @@ impl durable::Task<AppState> for WriteToDbTask {
745745
) -> durable::TaskResult<Self::Output> {
746746
// Use the app state's db pool to write to a table
747747
let row_id: i64 = ctx
748-
.step("insert", params, |params, state| async move {
748+
.step("insert", params, |params, step_state| async move {
749749
let (id,): (i64,) = sqlx::query_as(
750750
"INSERT INTO test_state_table (key, value) VALUES ($1, $2) RETURNING id",
751751
)
752752
.bind(&params.key)
753753
.bind(&params.value)
754-
.fetch_one(&state.db_pool)
754+
.fetch_one(&step_state.state.db_pool)
755755
.await
756756
.map_err(|e| anyhow::anyhow!("DB error: {}", e))?;
757757
Ok(id)

0 commit comments

Comments
 (0)