1use anyhow::Result;
2use redb::{Database, ReadableDatabase, ReadableTable, TableDefinition};
3use std::sync::Arc;
4use tokio::sync::Notify;
5
/// Table of tasks waiting to be claimed.
///
/// Keys are written by `insert_pending` as `{priority:020}:{task_id}`; the
/// zero-padded decimal priority makes string order match numeric order, so the
/// first entry is the one with the smallest priority value. Values are opaque
/// bytes supplied by the caller.
const PENDING: TableDefinition<&str, &[u8]> = TableDefinition::new("pending");
/// Table of claimed tasks, keyed by the bare task id. Written by
/// `atomic_pop_pending` and `move_to_processing`.
const PROCESSING: TableDefinition<&str, &[u8]> = TableDefinition::new("processing");
/// Table of tasks handed to `move_to_completed`, keyed by the bare task id.
const COMPLETED: TableDefinition<&str, &[u8]> = TableDefinition::new("completed");
11
/// Handle to a redb-backed task queue with three tables (pending, processing,
/// completed) and a wake-up signal for consumers.
///
/// Both fields are `Arc`s, so a clone refers to the same database and the same
/// notifier as the value it was cloned from.
#[derive(Clone)]
pub struct TaskQueue {
    /// Shared database handle; every method opens its own transaction on it.
    db: Arc<Database>,
    /// Signalled by `insert_pending` and `notify_task_available`, awaited by
    /// `wait_for_task`.
    notify: Arc<Notify>,
}
18
19impl TaskQueue {
20 pub fn new(db: Arc<Database>) -> Result<Self> {
22 let write_txn = db.begin_write()?;
24 write_txn.open_table(PENDING)?;
25 write_txn.open_table(PROCESSING)?;
26 write_txn.open_table(COMPLETED)?;
27 write_txn.commit()?;
28
29 Ok(Self {
30 db,
31 notify: Arc::new(Notify::new()),
32 })
33 }
34
35 pub fn insert_pending(&self, priority: u64, task_id: &str, data: &[u8]) -> Result<()> {
37 let write_txn = self.db.begin_write()?;
38 {
39 let mut table = write_txn.open_table(PENDING)?;
40 let key = format!("{:020}:{}", priority, task_id);
42 table.insert(key.as_str(), data)?;
43 }
44 write_txn.commit()?;
45 self.notify.notify_one();
46 Ok(())
47 }
48
49 pub fn atomic_pop_pending<F>(&self, on_task: F) -> Result<Option<crate::models::Task>>
53 where
54 F: FnOnce(&mut crate::models::Task),
55 {
56 let write_txn = self.db.begin_write()?;
57
58 let task = {
60 let mut pending = write_txn.open_table(PENDING)?;
61
62 let first_entry = if let Some(first) = pending.first()? {
64 let key_str = first.0.value().to_string();
65 let data = first.1.value().to_vec();
66 Some((key_str, data))
67 } else {
68 None
69 };
70
71 if let Some((key, data)) = first_entry {
73 pending.remove(key.as_str())?;
75
76 let mut task: crate::models::Task = serde_json::from_slice(&data)?;
78 on_task(&mut task);
79
80 let serialized = serde_json::to_vec(&task)?;
82 let mut processing = write_txn.open_table(PROCESSING)?;
83 processing.insert(task.id.as_str(), serialized.as_slice())?;
84
85 Some(task)
86 } else {
87 None
88 }
89 }; if task.is_some() {
92 write_txn.commit()?;
93 } else {
94 write_txn.abort()?;
95 }
96
97 Ok(task)
98 }
99
100 pub fn get_first_pending(&self) -> Result<Option<(u64, Vec<u8>)>> {
103 let read_txn = self.db.begin_read()?;
104 let pending = read_txn.open_table(PENDING)?;
105
106 if let Some((key, value)) = pending.first()? {
107 let key_str = key.value();
109 let priority = key_str
110 .split(':')
111 .next()
112 .and_then(|s| s.parse::<u64>().ok())
113 .ok_or_else(|| anyhow::anyhow!("Invalid pending key format: {}", key_str))?;
114
115 Ok(Some((priority, value.value().to_vec())))
116 } else {
117 Ok(None)
118 }
119 }
120
121 pub fn move_to_processing(&self, priority: u64, task_id: &str, data: &[u8]) -> Result<()> {
124 let write_txn = self.db.begin_write()?;
125
126 {
128 let mut pending = write_txn.open_table(PENDING)?;
129 let key = format!("{:020}:{}", priority, task_id);
130 pending.remove(key.as_str())?;
131 }
132
133 {
135 let mut processing = write_txn.open_table(PROCESSING)?;
136 processing.insert(task_id, data)?;
137 }
138
139 write_txn.commit()?;
140 Ok(())
141 }
142
143 pub fn move_to_completed(&self, task_id: &str, data: &[u8]) -> Result<()> {
145 let write_txn = self.db.begin_write()?;
146
147 {
149 let mut processing = write_txn.open_table(PROCESSING)?;
150 processing.remove(task_id)?;
151 }
152
153 {
155 let mut completed = write_txn.open_table(COMPLETED)?;
156 completed.insert(task_id, data)?;
157 }
158
159 write_txn.commit()?;
160 Ok(())
161 }
162
163 pub fn get_from_processing(&self, task_id: &str) -> Result<Option<Vec<u8>>> {
165 let read_txn = self.db.begin_read()?;
166 let processing = read_txn.open_table(PROCESSING)?;
167
168 if let Some(data) = processing.get(task_id)? {
169 Ok(Some(data.value().to_vec()))
170 } else {
171 Ok(None)
172 }
173 }
174
175 pub fn remove_from_processing(&self, task_id: &str) -> Result<()> {
177 let write_txn = self.db.begin_write()?;
178 {
179 let mut processing = write_txn.open_table(PROCESSING)?;
180 processing.remove(task_id)?;
181 }
182 write_txn.commit()?;
183 Ok(())
184 }
185
186 pub fn get_from_any_table(&self, task_id: &str) -> Result<Option<Vec<u8>>> {
188 let read_txn = self.db.begin_read()?;
189
190 let processing = read_txn.open_table(PROCESSING)?;
192 if let Some(data) = processing.get(task_id)? {
193 return Ok(Some(data.value().to_vec()));
194 }
195
196 let completed = read_txn.open_table(COMPLETED)?;
198 if let Some(data) = completed.get(task_id)? {
199 return Ok(Some(data.value().to_vec()));
200 }
201
202 let pending = read_txn.open_table(PENDING)?;
205 for entry in pending.iter()? {
206 let (_, value) = entry?;
207 let data = value.value();
208
209 if let Ok(task) = serde_json::from_slice::<crate::models::Task>(data)
211 && task.id == task_id
212 {
213 return Ok(Some(data.to_vec()));
214 }
215 }
216
217 Ok(None)
218 }
219
220 pub fn get_all_pending(&self) -> Result<Vec<Vec<u8>>> {
222 let read_txn = self.db.begin_read()?;
223 let pending = read_txn.open_table(PENDING)?;
224 let mut tasks = Vec::new();
225
226 for entry in pending.iter()? {
227 let (_, value) = entry?;
228 tasks.push(value.value().to_vec());
229 }
230
231 Ok(tasks)
232 }
233
234 pub fn get_all_processing(&self) -> Result<Vec<Vec<u8>>> {
236 let read_txn = self.db.begin_read()?;
237 let processing = read_txn.open_table(PROCESSING)?;
238 let mut tasks = Vec::new();
239
240 for entry in processing.iter()? {
241 let (_, value) = entry?;
242 tasks.push(value.value().to_vec());
243 }
244
245 Ok(tasks)
246 }
247
248 pub fn get_all_completed(&self) -> Result<Vec<Vec<u8>>> {
250 let read_txn = self.db.begin_read()?;
251 let completed = read_txn.open_table(COMPLETED)?;
252 let mut tasks = Vec::new();
253
254 for entry in completed.iter()? {
255 let (_, value) = entry?;
256 tasks.push(value.value().to_vec());
257 }
258
259 Ok(tasks)
260 }
261
262 pub async fn wait_for_task(&self) {
264 self.notify.notified().await;
265 }
266
267 pub fn notify_task_available(&self) {
269 self.notify.notify_one();
270 }
271}
272
273#[cfg(test)]
274mod tests {
275 use super::*;
276 use tempfile::tempdir;
277
278 fn setup_test_queue() -> (TaskQueue, tempfile::TempDir) {
279 let temp_dir = tempdir().unwrap();
280 let db_path = temp_dir.path().join("test.db");
281 let db = Arc::new(Database::create(db_path).unwrap());
282 let queue = TaskQueue::new(db).unwrap();
283 (queue, temp_dir)
284 }
285
286 fn create_test_input() -> crate::models::NodeInput {
287 use crate::models::{NodeInput, ManualTriggerInput};
288
289 NodeInput::ManualTrigger(ManualTriggerInput {
290 payload: Some(serde_json::json!({})),
291 })
292 }
293
294 #[test]
295 fn test_insert_and_get_pending() {
296 let (queue, _temp_dir) = setup_test_queue();
297
298 let task_data = b"test task data";
299 queue.insert_pending(100, "task-001", task_data).unwrap();
300
301 let pending = queue.get_first_pending().unwrap();
302 assert!(pending.is_some());
303
304 let (priority, data) = pending.unwrap();
305 assert_eq!(priority, 100);
306 assert_eq!(data, task_data);
307 }
308
309 #[test]
310 fn test_priority_order() {
311 let (queue, _temp_dir) = setup_test_queue();
312
313 queue
315 .insert_pending(300, "task-low", b"low priority")
316 .unwrap();
317 queue
318 .insert_pending(100, "task-high", b"high priority")
319 .unwrap();
320 queue
321 .insert_pending(200, "task-med", b"medium priority")
322 .unwrap();
323
324 let first = queue.get_first_pending().unwrap().unwrap();
326 assert_eq!(first.0, 100);
327 assert_eq!(first.1, b"high priority");
328 }
329
330 #[test]
331 fn test_move_to_processing() {
332 let (queue, _temp_dir) = setup_test_queue();
333
334 let task_data = b"task to process";
335 queue.insert_pending(100, "task-001", task_data).unwrap();
336
337 queue
339 .move_to_processing(100, "task-001", task_data)
340 .unwrap();
341
342 let pending = queue.get_first_pending().unwrap();
344 assert!(pending.is_none());
345
346 let processing = queue.get_from_processing("task-001").unwrap();
348 assert!(processing.is_some());
349 assert_eq!(processing.unwrap(), task_data);
350 }
351
352 #[test]
353 fn test_move_to_completed() {
354 let (queue, _temp_dir) = setup_test_queue();
355
356 let task_data = b"task to complete";
357
358 queue.insert_pending(100, "task-001", task_data).unwrap();
360 queue
361 .move_to_processing(100, "task-001", task_data)
362 .unwrap();
363
364 queue.move_to_completed("task-001", task_data).unwrap();
366
367 let processing = queue.get_from_processing("task-001").unwrap();
369 assert!(processing.is_none());
370
371 let completed = queue.get_all_completed().unwrap();
373 assert_eq!(completed.len(), 1);
374 assert_eq!(completed[0], task_data);
375 }
376
377 #[test]
378 fn test_remove_from_processing() {
379 let (queue, _temp_dir) = setup_test_queue();
380
381 let task_data = b"task to remove";
382 queue.insert_pending(100, "task-001", task_data).unwrap();
383 queue
384 .move_to_processing(100, "task-001", task_data)
385 .unwrap();
386
387 queue.remove_from_processing("task-001").unwrap();
389
390 let processing = queue.get_from_processing("task-001").unwrap();
392 assert!(processing.is_none());
393 }
394
395 #[test]
396 fn test_get_all_pending() {
397 let (queue, _temp_dir) = setup_test_queue();
398
399 queue.insert_pending(100, "task-001", b"task1").unwrap();
400 queue.insert_pending(200, "task-002", b"task2").unwrap();
401 queue.insert_pending(300, "task-003", b"task3").unwrap();
402
403 let pending = queue.get_all_pending().unwrap();
404 assert_eq!(pending.len(), 3);
405 }
406
407 #[test]
408 fn test_get_all_processing() {
409 let (queue, _temp_dir) = setup_test_queue();
410
411 queue.insert_pending(100, "task-001", b"task1").unwrap();
412 queue.move_to_processing(100, "task-001", b"task1").unwrap();
413
414 queue.insert_pending(200, "task-002", b"task2").unwrap();
415 queue.move_to_processing(200, "task-002", b"task2").unwrap();
416
417 let processing = queue.get_all_processing().unwrap();
418 assert_eq!(processing.len(), 2);
419 }
420
421 #[test]
422 fn test_get_all_completed() {
423 let (queue, _temp_dir) = setup_test_queue();
424
425 for i in 1..=3 {
427 let task_id = format!("task-{:03}", i);
428 let task_data = format!("task{}", i).into_bytes();
429
430 queue
431 .insert_pending(i as u64 * 100, &task_id, &task_data)
432 .unwrap();
433 queue
434 .move_to_processing(i as u64 * 100, &task_id, &task_data)
435 .unwrap();
436 queue.move_to_completed(&task_id, &task_data).unwrap();
437 }
438
439 let completed = queue.get_all_completed().unwrap();
440 assert_eq!(completed.len(), 3);
441 }
442
443 #[tokio::test]
444 async fn test_wait_for_task() {
445 let (queue, _temp_dir) = setup_test_queue();
446
447 let queue_clone = queue.clone();
449 let wait_handle = tokio::spawn(async move {
450 tokio::select! {
451 _ = queue_clone.wait_for_task() => true,
452 _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => false,
453 }
454 });
455
456 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
458
459 queue.insert_pending(100, "task-001", b"new task").unwrap();
461
462 let was_notified = wait_handle.await.unwrap();
464 assert!(was_notified);
465 }
466
467 #[test]
468 fn test_get_from_any_table() {
469 let (queue, _temp_dir) = setup_test_queue();
470
471 queue
473 .insert_pending(100, "task-001", b"processing task")
474 .unwrap();
475 queue
476 .move_to_processing(100, "task-001", b"processing task")
477 .unwrap();
478
479 let result = queue.get_from_any_table("task-001").unwrap();
480 assert!(result.is_some());
481 assert_eq!(result.unwrap(), b"processing task");
482
483 queue
485 .move_to_completed("task-001", b"completed task")
486 .unwrap();
487
488 let result = queue.get_from_any_table("task-001").unwrap();
489 assert!(result.is_some());
490 assert_eq!(result.unwrap(), b"completed task");
491
492 let result = queue.get_from_any_table("nonexistent").unwrap();
494 assert!(result.is_none());
495 }
496
497 #[tokio::test]
498 async fn test_concurrent_same_priority_nanosecond() {
499 use crate::engine::context::ExecutionContext;
500 use crate::models::Task;
501
502 let (queue, _temp_dir) = setup_test_queue();
503
504 let mut handles = vec![];
506 for i in 0..10 {
507 let queue_clone = queue.clone();
508 let handle = tokio::spawn(async move {
509 let task = Task::new(
510 format!("exec-{}", i),
511 "wf-1".to_string(),
512 format!("node-{}", i),
513 create_test_input(),
514 ExecutionContext::new(format!("exec-{}", i)),
515 );
516 let priority = task.priority();
517 let task_id = task.id.clone();
518 let serialized = serde_json::to_vec(&task).unwrap();
519 queue_clone
520 .insert_pending(priority, &task_id, &serialized)
521 .unwrap();
522 });
523 handles.push(handle);
524 }
525
526 for handle in handles {
528 handle.await.unwrap();
529 }
530
531 let pending = queue.get_all_pending().unwrap();
533 assert_eq!(pending.len(), 10, "All 10 tasks should be in pending queue");
534 }
535
536 #[test]
537 fn test_get_from_any_table_pending() {
538 use crate::engine::context::ExecutionContext;
539 use crate::models::Task;
540
541 let (queue, _temp_dir) = setup_test_queue();
542
543 let task = Task::new(
545 "exec-1".to_string(),
546 "wf-1".to_string(),
547 "node-1".to_string(),
548 create_test_input(),
549 ExecutionContext::new("exec-1".to_string()),
550 );
551 let task_id = task.id.clone();
552 let priority = task.priority();
553 let serialized = serde_json::to_vec(&task).unwrap();
554 queue
555 .insert_pending(priority, &task_id, &serialized)
556 .unwrap();
557
558 let result = queue.get_from_any_table(&task_id).unwrap();
560 assert!(result.is_some(), "Should find task in pending table");
561
562 let found_task: Task = serde_json::from_slice(&result.unwrap()).unwrap();
564 assert_eq!(found_task.id, task_id);
565 }
566
567 #[tokio::test]
568 async fn test_concurrent_pop_no_duplicate() {
569 use crate::engine::context::ExecutionContext;
570 use crate::models::Task;
571 use std::collections::HashSet;
572
573 let (queue, _temp_dir) = setup_test_queue();
574
575 for i in 0..3 {
577 let task = Task::new(
578 format!("exec-{}", i),
579 "wf-1".to_string(),
580 format!("node-{}", i),
581 create_test_input(),
582 ExecutionContext::new(format!("exec-{}", i)),
583 );
584 let priority = task.priority();
585 let task_id = task.id.clone();
586 let serialized = serde_json::to_vec(&task).unwrap();
587 queue
588 .insert_pending(priority, &task_id, &serialized)
589 .unwrap();
590 }
591
592 let mut handles = vec![];
594 for _ in 0..10 {
595 let q = queue.clone();
596 handles.push(tokio::spawn(async move {
597 q.atomic_pop_pending(|_| {}).ok().flatten()
598 }));
599 }
600
601 let mut results = vec![];
603 for h in handles {
604 if let Some(task) = h.await.unwrap() {
605 results.push(task.id); }
607 }
608
609 assert_eq!(results.len(), 3, "Should pop exactly 3 tasks");
611 let unique: HashSet<_> = results.into_iter().collect();
612 assert_eq!(
613 unique.len(),
614 3,
615 "All task IDs should be unique (no duplicate execution)"
616 );
617 }
618
619 #[test]
620 fn test_composite_key_uniqueness() {
621 use crate::engine::context::ExecutionContext;
622 use crate::models::Task;
623
624 let (queue, _temp_dir) = setup_test_queue();
625
626 let mut tasks = vec![];
628 for i in 0..5 {
629 let task = Task::new(
630 "exec-1".to_string(),
631 "wf-1".to_string(),
632 format!("node-{}", i),
633 create_test_input(),
634 ExecutionContext::new("exec-1".to_string()),
635 );
636 tasks.push(task);
637 }
638
639 let priority = tasks[0].priority();
641 for task in &tasks {
642 let serialized = serde_json::to_vec(task).unwrap();
643 queue
644 .insert_pending(priority, &task.id, &serialized)
645 .unwrap();
646 }
647
648 let pending = queue.get_all_pending().unwrap();
650 assert_eq!(
651 pending.len(),
652 5,
653 "All tasks should be preserved despite same priority"
654 );
655
656 for task in &tasks {
658 let result = queue.get_from_any_table(&task.id).unwrap();
659 assert!(result.is_some(), "Each task should be retrievable by ID");
660 }
661 }
662
663 #[test]
664 fn test_atomic_pop_state_transition() {
665 use crate::engine::context::ExecutionContext;
666 use crate::models::{Task, TaskStatus};
667
668 let (queue, _temp_dir) = setup_test_queue();
669
670 let task = Task::new(
671 "exec-1".to_string(),
672 "wf-1".to_string(),
673 "node-1".to_string(),
674 create_test_input(),
675 ExecutionContext::new("exec-1".to_string()),
676 );
677 let task_id = task.id.clone();
678 let priority = task.priority();
679 let serialized = serde_json::to_vec(&task).unwrap();
680 queue
681 .insert_pending(priority, &task_id, &serialized)
682 .unwrap();
683
684 let popped_task = queue
686 .atomic_pop_pending(|task| task.start())
687 .unwrap()
688 .unwrap();
689 assert_eq!(popped_task.id, task_id);
690
691 assert_eq!(
693 popped_task.status,
694 TaskStatus::Running,
695 "Task should be Running after pop"
696 );
697 assert!(
698 popped_task.started_at.is_some(),
699 "Task should have started_at set"
700 );
701
702 assert_eq!(
704 queue.get_all_pending().unwrap().len(),
705 0,
706 "Pending should be empty after pop"
707 );
708
709 let processing_data = queue.get_from_processing(&task_id).unwrap().unwrap();
711 let processing_task: Task = serde_json::from_slice(&processing_data).unwrap();
712 assert_eq!(
713 processing_task.status,
714 TaskStatus::Running,
715 "Processing task should be Running"
716 );
717 assert!(
718 processing_task.started_at.is_some(),
719 "Processing task should have started_at"
720 );
721
722 let second_pop = queue.atomic_pop_pending(|task| task.start()).unwrap();
724 assert!(second_pop.is_none(), "Second pop should return None");
725 }
726
727 #[test]
728 fn test_atomic_pop_no_dirty_data_on_crash() {
729 use crate::engine::context::ExecutionContext;
730 use crate::models::Task;
731
732 let (queue, _temp_dir) = setup_test_queue();
733
734 let task = Task::new(
736 "exec-1".to_string(),
737 "wf-1".to_string(),
738 "node-1".to_string(),
739 create_test_input(),
740 ExecutionContext::new("exec-1".to_string()),
741 );
742 let task_id = task.id.clone();
743 let priority = task.priority();
744 let serialized = serde_json::to_vec(&task).unwrap();
745 queue
746 .insert_pending(priority, &task_id, &serialized)
747 .unwrap();
748
749 let panic_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
751 queue.atomic_pop_pending(|_task| {
752 panic!("Simulated worker crash in callback!");
753 })
754 }));
755
756 assert!(panic_result.is_err(), "Callback should panic");
757
758 let pending = queue.get_all_pending().unwrap();
760 assert_eq!(
761 pending.len(),
762 1,
763 "Task should still be in pending after panic"
764 );
765
766 let processing = queue.get_all_processing().unwrap();
768 assert_eq!(
769 processing.len(),
770 0,
771 "Processing should be empty (no dirty data)"
772 );
773
774 let retry_task = queue.atomic_pop_pending(|task| task.start()).unwrap();
776 assert!(
777 retry_task.is_some(),
778 "Task should be retrievable after panic"
779 );
780 assert_eq!(retry_task.unwrap().id, task_id);
781 }
782}