Skip to content

Commit 45734a1

Browse files
authored
Validate task parameters in 'spawn_by_name' (#57)
* Validate task parameters in 'spawn_by_name' The 'spawn_by_name' method requires the task to exist in the registry, so we can validate the parameters by attempting to deserialize into the parameter type. This lets us catch some errors before we try to insert the task into the database * Fix test * Fix test_spawn_with_empty_params
1 parent 75bde52 commit 45734a1

File tree

4 files changed

+62
-18
lines changed

4 files changed

+62
-18
lines changed

src/client.rs

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -342,17 +342,7 @@ where
342342
params: JsonValue,
343343
options: SpawnOptions,
344344
) -> DurableResult<SpawnResult> {
345-
// Validate that the task is registered
346-
{
347-
let registry = self.registry.read().await;
348-
if !registry.contains_key(task_name) {
349-
return Err(DurableError::TaskNotRegistered {
350-
task_name: task_name.to_string(),
351-
});
352-
}
353-
}
354-
355-
self.spawn_by_name_internal(&self.pool, task_name, params, options)
345+
self.spawn_by_name_with(&self.pool, task_name, params, options)
356346
.await
357347
}
358348

@@ -432,11 +422,16 @@ where
432422
// Validate that the task is registered
433423
{
434424
let registry = self.registry.read().await;
435-
if !registry.contains_key(task_name) {
425+
let Some(task) = registry.get(task_name) else {
436426
return Err(DurableError::TaskNotRegistered {
437427
task_name: task_name.to_string(),
438428
});
439-
}
429+
};
430+
task.validate_params(params.clone())
431+
.map_err(|e| DurableError::InvalidTaskParams {
432+
task_name: task_name.to_string(),
433+
message: e.to_string(),
434+
})?;
440435
}
441436

442437
self.spawn_by_name_internal(executor, task_name, params, options)

src/error.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,18 @@ pub enum DurableError {
311311
task_name: String,
312312
},
313313

314+
//// Task params validation failed.
315+
///
316+
/// Returned when the task definition in the registry fails to validate the params
317+
/// (before we attempt to spawn the task in Postgres).
318+
#[error("invalid task parameters for '{task_name}': {message}")]
319+
InvalidTaskParams {
320+
/// The name of the task being spawned
321+
task_name: String,
322+
/// The error message from the task.
323+
message: String,
324+
},
325+
314326
/// Header key uses a reserved prefix.
315327
///
316328
/// User-provided headers cannot start with "durable::" as this prefix

src/task.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ where
109109
State: Clone + Send + Sync + 'static,
110110
{
111111
fn name(&self) -> Cow<'static, str>;
112+
/// Called before spawning, to check that the `params` are valid for this task.
113+
fn validate_params(&self, params: JsonValue) -> Result<(), TaskError>;
112114
async fn execute(
113115
&self,
114116
params: JsonValue,
@@ -127,6 +129,12 @@ where
127129
T::name()
128130
}
129131

132+
fn validate_params(&self, params: JsonValue) -> Result<(), TaskError> {
133+
// For now, just deserialize
134+
let _typed_params: T::Params = serde_json::from_value(params)?;
135+
Ok(())
136+
}
137+
130138
async fn execute(
131139
&self,
132140
params: JsonValue,

tests/spawn_test.rs

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
mod common;
44

55
use common::tasks::{EchoParams, EchoTask, FailingParams, FailingTask};
6-
use durable::{CancellationPolicy, Durable, MIGRATOR, RetryStrategy, SpawnOptions};
6+
use durable::{CancellationPolicy, Durable, DurableError, MIGRATOR, RetryStrategy, SpawnOptions};
77
use sqlx::PgPool;
88
use std::collections::HashMap;
99
use std::time::Duration;
@@ -270,6 +270,33 @@ async fn test_spawn_by_name(pool: PgPool) -> sqlx::Result<()> {
270270
Ok(())
271271
}
272272

273+
#[sqlx::test(migrator = "MIGRATOR")]
274+
async fn test_spawn_by_name_invalid_params(pool: PgPool) -> sqlx::Result<()> {
275+
let client = create_client(pool.clone(), "spawn_by_name").await;
276+
client.create_queue(None).await.unwrap();
277+
client.register::<EchoTask>().await.unwrap();
278+
279+
let params = serde_json::json!({
280+
"message": 12345
281+
});
282+
283+
let result = client
284+
.spawn_by_name("echo", params, SpawnOptions::default())
285+
.await
286+
.expect_err("Spawning task by name with invalid params should fail");
287+
288+
let DurableError::InvalidTaskParams { task_name, message } = result else {
289+
panic!("Unexpected error: {}", result);
290+
};
291+
assert_eq!(task_name, "echo");
292+
assert_eq!(
293+
message,
294+
"serialization error: invalid type: integer `12345`, expected a string"
295+
);
296+
297+
Ok(())
298+
}
299+
273300
#[sqlx::test(migrator = "MIGRATOR")]
274301
async fn test_spawn_by_name_with_options(pool: PgPool) -> sqlx::Result<()> {
275302
let client = create_client(pool.clone(), "spawn_by_name_opts").await;
@@ -308,9 +335,10 @@ async fn test_spawn_with_empty_params(pool: PgPool) -> sqlx::Result<()> {
308335
client.create_queue(None).await.unwrap();
309336
client.register::<EchoTask>().await.unwrap();
310337

311-
// Empty object is valid JSON params for EchoTask (message will be missing but that's ok for this test)
338+
// Empty object is not valid JSON params for EchoTask,
339+
// but spawn_by_name_unchecked does not validate the JSON
312340
let result = client
313-
.spawn_by_name("echo", serde_json::json!({}), SpawnOptions::default())
341+
.spawn_by_name_unchecked("echo", serde_json::json!({}), SpawnOptions::default())
314342
.await
315343
.expect("Failed to spawn task with empty params");
316344

@@ -326,7 +354,8 @@ async fn test_spawn_with_complex_params(pool: PgPool) -> sqlx::Result<()> {
326354
client.register::<EchoTask>().await.unwrap();
327355

328356
// Complex nested JSON structure - the params don't need to match the task's Params type
329-
// because spawn_by_name accepts arbitrary JSON
357+
// because spawn_by_name_unchecked does not validate the JSON
358+
// (unlike `spawn_by_name`)
330359
let params = serde_json::json!({
331360
"nested": {
332361
"array": [1, 2, 3],
@@ -341,7 +370,7 @@ async fn test_spawn_with_complex_params(pool: PgPool) -> sqlx::Result<()> {
341370
});
342371

343372
let result = client
344-
.spawn_by_name("echo", params, SpawnOptions::default())
373+
.spawn_by_name_unchecked("echo", params, SpawnOptions::default())
345374
.await
346375
.expect("Failed to spawn task with complex params");
347376

0 commit comments

Comments
 (0)