@@ -7,6 +7,9 @@ use uuid::Uuid;
77
88use crate :: Durable ;
99use crate :: error:: { ControlFlow , TaskError , TaskResult } ;
10+ use std:: sync:: Arc ;
11+
12+ use crate :: heartbeat:: { HeartbeatHandle , Heartbeater , StepState } ;
1013use crate :: task:: Task ;
1114use crate :: types:: DurableEventPayload ;
1215use crate :: types:: {
6366
6467 /// Notifies the worker when the lease is extended via step() or heartbeat().
6568 lease_extender : LeaseExtender ,
69+
70+ /// Cloneable heartbeat handle for use in step closures.
71+ heartbeat_handle : HeartbeatHandle ,
6672}
6773
6874/// Validate that a user-provided step name doesn't use reserved prefix.
@@ -103,6 +109,14 @@ where
103109 cache. insert ( row. checkpoint_name , row. state ) ;
104110 }
105111
112+ let heartbeat_handle = HeartbeatHandle :: new (
113+ durable. pool ( ) . clone ( ) ,
114+ durable. queue_name ( ) . to_string ( ) ,
115+ task. run_id ,
116+ claim_timeout,
117+ lease_extender. clone ( ) ,
118+ ) ;
119+
106120 Ok ( Self {
107121 task_id : task. task_id ,
108122 run_id : task. run_id ,
@@ -113,6 +127,7 @@ where
113127 checkpoint_cache : cache,
114128 step_counters : HashMap :: new ( ) ,
115129 lease_extender,
130+ heartbeat_handle,
116131 } )
117132 }
118133
@@ -152,9 +167,9 @@ where
152167 /// # Example
153168 ///
154169 /// ```ignore
155- /// let payment_id = ctx.step("charge-payment", ctx.task_id, |task_id, _state | async {
170+ /// let payment_id = ctx.step("charge-payment", ctx.task_id, |task_id, step_state | async {
156171 /// let idempotency_key = format!("{}:charge", task_id);
157- /// stripe::charge(amount, &idempotency_key).await
172+ /// stripe::charge(amount, &idempotency_key, &step_state.state ).await
158173 /// }).await?;
159174 /// ```
160175 #[ cfg_attr(
@@ -169,7 +184,7 @@ where
169184 & mut self ,
170185 base_name : & str ,
171186 params : P ,
172- f : fn ( P , State ) -> Fut ,
187+ f : fn ( P , StepState < State > ) -> Fut ,
173188 ) -> TaskResult < T >
174189 where
175190 P : Serialize ,
@@ -193,13 +208,14 @@ where
193208 span. record ( "cached" , false ) ;
194209
195210 // Execute the step
196- let result =
197- f ( params, self . durable . state ( ) . clone ( ) )
198- . await
199- . map_err ( |e| TaskError :: Step {
200- base_name : base_name. to_string ( ) ,
201- error : e,
202- } ) ?;
211+ let step_state = StepState {
212+ state : self . durable . state ( ) . clone ( ) ,
213+ heartbeater : Arc :: new ( self . heartbeat_handle . clone ( ) ) ,
214+ } ;
215+ let result = f ( params, step_state) . await . map_err ( |e| TaskError :: Step {
216+ base_name : base_name. to_string ( ) ,
217+ error : e,
218+ } ) ?;
203219
204220 // Persist checkpoint (also extends claim lease)
205221 #[ cfg( feature = "telemetry" ) ]
@@ -461,6 +477,14 @@ where
461477 } )
462478 }
463479
480+ /// Get a cloneable heartbeat handle for use in step closures or `SimpleTool`s.
481+ ///
482+ /// The returned [`HeartbeatHandle`] can be passed into contexts that need to
483+ /// extend the task lease without access to the full `TaskContext`.
484+ pub fn heartbeat_handle ( & self ) -> HeartbeatHandle {
485+ self . heartbeat_handle . clone ( )
486+ }
487+
464488 /// Extend the task's lease to prevent timeout.
465489 ///
466490 /// Use this for long-running operations that don't naturally checkpoint.
@@ -482,27 +506,7 @@ where
482506 )
483507 ) ]
484508 pub async fn heartbeat ( & self , duration : Option < std:: time:: Duration > ) -> TaskResult < ( ) > {
485- let extend_by = duration. unwrap_or ( self . claim_timeout ) ;
486-
487- if extend_by < std:: time:: Duration :: from_secs ( 1 ) {
488- return Err ( TaskError :: Validation {
489- message : "heartbeat duration must be at least 1 second" . to_string ( ) ,
490- } ) ;
491- }
492-
493- let query = "SELECT durable.extend_claim($1, $2, $3)" ;
494- sqlx:: query ( query)
495- . bind ( self . durable . queue_name ( ) )
496- . bind ( self . run_id )
497- . bind ( extend_by. as_secs ( ) as i32 )
498- . execute ( self . durable . pool ( ) )
499- . await
500- . map_err ( TaskError :: from_sqlx_error) ?;
501-
502- // Notify worker that lease was extended so it can reset timers
503- self . lease_extender . notify ( extend_by) ;
504-
505- Ok ( ( ) )
509+ self . heartbeat_handle . heartbeat ( duration) . await
506510 }
507511
508512 /// Generate a durable random value in [0, 1).
0 commit comments