Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 53 additions & 4 deletions crates/ragfs/src/plugins/queuefs/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
use crate::core::errors::{Error, Result};
use chrono::{DateTime, Utc};
use rusqlite::{params, types::ValueRef, Connection, Row};
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::Mutex;
use std::time::{Duration, UNIX_EPOCH};
Expand Down Expand Up @@ -39,7 +39,8 @@ impl Message {
#[derive(Debug, Clone, Serialize, Deserialize)]
struct StoredMessage {
id: String,
data: String,
#[serde(deserialize_with = "deserialize_stored_message_data")]
data: Vec<u8>,
#[serde(default)]
timestamp: Option<serde_json::Value>,
}
Expand All @@ -48,7 +49,7 @@ impl StoredMessage {
fn from_message(msg: &Message) -> Self {
Self {
id: msg.id.clone(),
data: String::from_utf8_lossy(&msg.data).to_string(),
data: msg.data.clone(),
// Prefer unix seconds for compatibility with older queue.db producers.
timestamp: Some(serde_json::Value::Number(unix_secs(msg.timestamp).into())),
}
Expand All @@ -57,12 +58,41 @@ impl StoredMessage {
fn into_message(self) -> Message {
Message {
id: self.id,
data: self.data.into_bytes(),
data: self.data,
timestamp: parse_stored_timestamp(self.timestamp),
}
}
}

fn deserialize_stored_message_data<'de, D>(
deserializer: D,
) -> std::result::Result<Vec<u8>, D::Error>
where
D: Deserializer<'de>,
{
let value = serde_json::Value::deserialize(deserializer)?;
match value {
// Legacy queue.db rows stored message bytes as a JSON string.
serde_json::Value::String(data) => Ok(data.into_bytes()),
// New rows keep arbitrary message bytes lossless as a JSON byte array.
serde_json::Value::Array(bytes) => bytes
.into_iter()
.map(|byte| {
byte.as_u64()
.and_then(|value| u8::try_from(value).ok())
.ok_or_else(|| {
serde::de::Error::custom(
"stored queue message data byte must be an integer in 0..=255",
)
})
})
.collect(),
other => Err(serde::de::Error::custom(format!(
"stored queue message data must be a string or byte array, got {other}"
))),
}
}

fn parse_stored_timestamp(raw: Option<serde_json::Value>) -> SystemTime {
match raw {
Some(serde_json::Value::String(ts)) => DateTime::parse_from_rfc3339(&ts)
Expand Down Expand Up @@ -846,6 +876,25 @@ mod tests {
assert_eq!(second.data, b"message 2");
}

#[test]
fn test_sqlite_backend_preserves_non_utf8_payload_bytes() {
let dir = tempdir().unwrap();
let db_path = dir.path().join("queue.db");
let mut backend =
SQLiteQueueBackend::open(db_path.to_str().unwrap(), SQLiteQueueOptions::default())
.unwrap();

backend.create_queue("test").unwrap();
let payload = vec![0xff, 0x00, 0x80, b'a'];
backend.enqueue("test", Message::new(payload.clone())).unwrap();

let peeked = backend.peek("test").unwrap().unwrap();
assert_eq!(peeked.data, payload);

let dequeued = backend.dequeue("test").unwrap().unwrap();
assert_eq!(dequeued.data, payload);
}

#[test]
fn test_sqlite_backend_recover_stale() {
let dir = tempdir().unwrap();
Expand Down
Loading