-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtask.rs
More file actions
174 lines (161 loc) · 5.65 KB
/
task.rs
File metadata and controls
174 lines (161 loc) · 5.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
use async_trait::async_trait;
use serde::{Serialize, de::DeserializeOwned};
use serde_json::Value as JsonValue;
use std::borrow::Cow;
use std::sync::Arc;
use crate::context::TaskContext;
use crate::error::{TaskError, TaskResult};
/// Defines a task with typed parameters and output.
///
/// Implement this trait for your task types. The worker will:
/// 1. Deserialize params from JSON into the `Params` type
/// 2. Call `run()` with the typed params, a [`TaskContext`], and your application state
/// 3. Serialize the result back to JSON for storage
///
/// # Type Parameter
///
/// * `State` - Application state type (e.g., HTTP clients, database pools).
///   Use `()` if you don't need any state.
///
/// # Example
/// ```ignore
/// #[derive(Default)]
/// struct SendEmailTask;
///
/// #[async_trait]
/// impl Task<()> for SendEmailTask {
///     type Params = SendEmailParams;
///     type Output = SendEmailResult;
///
///     fn name(&self) -> Cow<'static, str> { Cow::Borrowed("send-email") }
///
///     async fn run(
///         &self,
///         params: Self::Params,
///         mut ctx: TaskContext<()>,
///         _state: (),
///     ) -> TaskResult<Self::Output> {
///         let result = ctx.step("send", params, |params, _| async move {
///             email_service::send(&params.to, &params.subject, &params.body).await
///         }).await?;
///
///         Ok(SendEmailResult { message_id: result.id })
///     }
/// }
///
/// // With application state:
/// #[derive(Clone)]
/// struct AppState {
///     http_client: reqwest::Client,
/// }
///
/// #[derive(Default)]
/// struct FetchUrlTask;
///
/// #[async_trait]
/// impl Task<AppState> for FetchUrlTask {
///     type Params = String;
///     type Output = String;
///
///     fn name(&self) -> Cow<'static, str> { Cow::Borrowed("fetch-url") }
///
///     async fn run(
///         &self,
///         url: Self::Params,
///         mut ctx: TaskContext<AppState>,
///         state: AppState,
///     ) -> TaskResult<Self::Output> {
///         let body = ctx.step("fetch", url, |url, _| async move {
///             state.http_client.get(&url).send().await
///                 .map_err(|e| anyhow::anyhow!("HTTP error: {}", e))?
///                 .text().await
///                 .map_err(|e| anyhow::anyhow!("HTTP error: {}", e))
///         }).await?;
///         Ok(body)
///     }
/// }
/// ```
#[async_trait]
pub trait Task<State>: Send + Sync + 'static
where
    State: Clone + Send + Sync + 'static,
{
    /// Parameter type (must be JSON-serializable).
    type Params: Serialize + DeserializeOwned + Send;

    /// Output type (must be JSON-serializable).
    type Output: Serialize + DeserializeOwned + Send;

    /// Task name as stored in the database.
    /// Should be unique across your application.
    fn name(&self) -> Cow<'static, str>;

    /// Validate the parameters for this task.
    ///
    /// This is called before spawning the task, to allow us to catch errors early.
    /// By default, this just tries to deserialize the parameters into the
    /// `Self::Params` type and discards the result.
    ///
    /// # Errors
    ///
    /// The default implementation returns the `serde_json` deserialization error,
    /// converted into a [`TaskError`], when `params` does not fit `Self::Params`.
    fn validate_params(&self, params: JsonValue) -> Result<(), TaskError> {
        let _typed_params: Self::Params = serde_json::from_value(params)?;
        Ok(())
    }

    /// Execute the task logic.
    ///
    /// Return `Ok(output)` on success, or `Err(TaskError)` on failure.
    /// Use `?` freely - errors will propagate and the task will be retried
    /// according to its [`RetryStrategy`](crate::RetryStrategy).
    ///
    /// For user errors with structured data, use `TaskError::user(data)` where
    /// data is any serializable value. For simple message errors, use
    /// `TaskError::user_message("message")`.
    ///
    /// The [`TaskContext`] provides methods for checkpointing, sleeping,
    /// and waiting for events. See [`TaskContext`] for details.
    ///
    /// The `state` parameter provides access to application-level resources
    /// like HTTP clients, database pools, etc.
    async fn run(
        &self,
        params: Self::Params,
        ctx: TaskContext<State>,
        state: State,
    ) -> TaskResult<Self::Output>;
}
/// Object-safe, JSON-in/JSON-out view of a [`Task`].
///
/// Each [`Task`] has its own `Params`/`Output` types, so differently-typed tasks
/// cannot share one `HashMap` directly. This trait expresses the same operations
/// in terms of [`JsonValue`] so heterogeneous tasks can be stored together;
/// [`TaskWrapper`] adapts any [`Task`] to it.
#[async_trait]
// NOTE(review): the reason for `allow(dead_code)` is not visible in this file;
// presumably some method goes unused in some configurations — confirm.
#[allow(dead_code)]
pub trait ErasedTask<State: Clone + Send + Sync + 'static>: Send + Sync {
    /// Name of the task (for [`TaskWrapper`], forwarded from [`Task::name`]).
    fn name(&self) -> Cow<'static, str>;

    /// Called before spawning, to check that `params` are acceptable for this task.
    fn validate_params(&self, params: JsonValue) -> Result<(), TaskError>;

    /// Runs the task with JSON-encoded `params`, yielding its output as JSON.
    async fn execute(
        &self,
        params: JsonValue,
        ctx: TaskContext<State>,
        state: State,
    ) -> Result<JsonValue, TaskError>;
}
/// Adapter that lets any [`Task`] be used as an [`ErasedTask`].
///
/// The task lives behind an [`Arc`]. This lets tasks of different concrete types
/// sit side by side in a registry while remaining executable.
pub struct TaskWrapper<T>(pub Arc<T>);

impl<T> TaskWrapper<T> {
    /// Moves `task` into a fresh [`Arc`] and returns the adapter around it.
    pub fn new(task: T) -> Self {
        let shared: Arc<T> = Arc::from(task);
        TaskWrapper(shared)
    }
}
#[async_trait]
impl<T, State> ErasedTask<State> for TaskWrapper<T>
where
T: Task<State>,
State: Clone + Send + Sync + 'static,
{
fn name(&self) -> Cow<'static, str> {
self.0.name()
}
fn validate_params(&self, params: JsonValue) -> Result<(), TaskError> {
self.0.validate_params(params)
}
async fn execute(
&self,
params: JsonValue,
ctx: TaskContext<State>,
state: State,
) -> Result<JsonValue, TaskError> {
let typed_params: T::Params = serde_json::from_value(params)?;
let result = self.0.run(typed_params, ctx, state).await?;
Ok(serde_json::to_value(&result)?)
}
}
/// Registry of type-erased tasks, keyed by task name
/// (presumably the value returned by [`ErasedTask::name`] — confirm at the insertion site).
///
/// Values are `Arc`-wrapped, so handing a registered task to a caller is a refcount bump.
pub type TaskRegistry<State> = std::collections::HashMap<
    Cow<'static, str>,
    Arc<dyn ErasedTask<State>>,
>;