// restflow_core/engine/scheduler.rs

1use crate::engine::context::{ExecutionContext, namespace};
2use crate::engine::graph::WorkflowGraph;
3use crate::models::{Node, Task, TaskStatus, Workflow};
4use crate::storage::{Storage, TaskQueue};
5use anyhow::Result;
6use chrono::Utc;
7use serde_json::Value;
8use std::sync::Arc;
9use uuid::Uuid;
10
/// How long, in seconds, a task may remain in the processing table before
/// `Scheduler::recover_stalled_tasks` treats it as stalled and resets it to
/// pending (five minutes).
const DEFAULT_STALL_TIMEOUT_SECONDS: i64 = 5 * 60;
13
14pub struct Scheduler {
15    queue: TaskQueue,
16    storage: Arc<Storage>,
17}
18
19impl Scheduler {
20    pub fn new(queue: TaskQueue, storage: Arc<Storage>) -> Self {
21        Self { queue, storage }
22    }
23
24    /// Accepts `Arc<Workflow>` to avoid expensive cloning in downstream task queueing
25    pub fn push_task(
26        &self,
27        execution_id: String,
28        node: Node,
29        workflow: Arc<Workflow>,
30        context: ExecutionContext,
31    ) -> Result<String> {
32        // Parse node config as NodeInput (uses serde's tagged enum for O(1) type dispatch)
33        let node_input: crate::models::NodeInput = serde_json::from_value(node.config.clone())
34            .map_err(|e| anyhow::anyhow!("Failed to parse node config as NodeInput: {}", e))?;
35
36        let task = Task::new(
37            execution_id,
38            workflow.id.clone(),
39            node.id.clone(),
40            node_input,
41            context,
42        );
43        let task_id = task.id.clone();
44
45        let _ = task.set_workflow(workflow);
46
47        self.storage.execution_history.record_task_created(
48            &task.workflow_id,
49            &task.execution_id,
50            task.created_at,
51        )?;
52
53        let priority = task.priority();
54        let serialized = serde_json::to_vec(&task)?;
55        self.queue.insert_pending(priority, &task_id, &serialized)?;
56
57        Ok(task_id)
58    }
59
60    pub fn push_single_node(&self, node: Node, _input: Value) -> Result<String> {
61        // For single-node execution, parse NodeInput from node.config
62        // The node.config already contains the tagged union {"type": "...", "data": {...}}
63        let node_input: crate::models::NodeInput = serde_json::from_value(node.config.clone())
64            .map_err(|e| anyhow::anyhow!("Failed to parse node config as NodeInput: {}", e))?;
65
66        let task = Task::for_single_node(node, node_input);
67        let task_id = task.id.clone();
68
69        self.storage.execution_history.record_task_created(
70            &task.workflow_id,
71            &task.execution_id,
72            task.created_at,
73        )?;
74
75        let priority = task.priority();
76        let serialized = serde_json::to_vec(&task)?;
77        self.queue.insert_pending(priority, &task_id, &serialized)?;
78
79        Ok(task_id)
80    }
81
82    pub fn submit_workflow(&self, workflow: Workflow, input: Value) -> Result<String> {
83        let execution_id = Uuid::new_v4().to_string();
84        self.submit_workflow_internal(workflow, input, execution_id)
85    }
86
87    fn submit_workflow_internal(
88        &self,
89        workflow: Workflow,
90        input: Value,
91        execution_id: String,
92    ) -> Result<String> {
93        let workflow = Arc::new(workflow);
94
95        let mut context =
96            ExecutionContext::with_execution_id(workflow.id.clone(), execution_id.clone());
97        context.ensure_secret_storage(&self.storage);
98        context.set(namespace::trigger::PAYLOAD, input);
99
100        let graph = WorkflowGraph::from_workflow(&workflow);
101        let start_nodes = graph.get_nodes_with_no_dependencies();
102
103        if start_nodes.is_empty() {
104            return Err(anyhow::anyhow!(
105                "No start nodes found in workflow {}",
106                workflow.id
107            ));
108        }
109
110        for node_id in start_nodes {
111            if let Some(node) = graph.get_node(&node_id) {
112                self.push_task(
113                    execution_id.clone(),
114                    node.clone(),
115                    workflow.clone(),
116                    context.clone(),
117                )?;
118            }
119        }
120
121        Ok(execution_id)
122    }
123
124    /// Submit a workflow by ID for execution
125    pub fn submit_workflow_by_id(&self, workflow_id: &str, input: Value) -> Result<String> {
126        let workflow = self
127            .storage
128            .workflows
129            .get_workflow(workflow_id)
130            .map_err(|e| anyhow::anyhow!("Failed to load workflow {}: {}", workflow_id, e))?;
131
132        self.submit_workflow(workflow, input)
133    }
134
135    pub fn submit_workflow_by_id_with_execution_id(
136        &self,
137        workflow_id: &str,
138        input: Value,
139        execution_id: String,
140    ) -> Result<String> {
141        let workflow = self
142            .storage
143            .workflows
144            .get_workflow(workflow_id)
145            .map_err(|e| anyhow::anyhow!("Failed to load workflow {}: {}", workflow_id, e))?;
146
147        self.submit_workflow_internal(workflow, input, execution_id)
148    }
149
150    pub async fn pop_task(&self) -> Result<Task> {
151        loop {
152            match self.try_pop_task()? {
153                Some(task) => return Ok(task),
154                None => {
155                    self.queue.wait_for_task().await;
156                }
157            }
158        }
159    }
160
161    /// Uses atomic_pop_pending with callback to ensure atomicity
162    fn try_pop_task(&self) -> Result<Option<Task>> {
163        self.queue.atomic_pop_pending(|task| task.start())
164    }
165
166    pub fn complete_task(&self, task_id: &str, output: crate::models::NodeOutput) -> Result<()> {
167        self.finish_task(task_id, TaskStatus::Completed, Some(output), None)
168    }
169
170    pub fn fail_task(&self, task_id: &str, error: String) -> Result<()> {
171        self.finish_task(task_id, TaskStatus::Failed, None, Some(error))
172    }
173
174    fn finish_task(
175        &self,
176        task_id: &str,
177        status: TaskStatus,
178        output: Option<crate::models::NodeOutput>,
179        error: Option<String>,
180    ) -> Result<()> {
181        if let Some(data) = self.queue.get_from_processing(task_id)? {
182            let mut task: Task = serde_json::from_slice(&data)?;
183
184            match status {
185                TaskStatus::Completed => {
186                    if let Some(output) = output {
187                        task.complete(output);
188                    }
189                }
190                TaskStatus::Failed => {
191                    if let Some(error) = error {
192                        task.fail(error);
193                    }
194                }
195                _ => {}
196            }
197
198            let serialized = serde_json::to_vec(&task)?;
199            self.queue.move_to_completed(task_id, &serialized)?;
200
201            let timestamp_ms = Utc::now().timestamp_millis();
202            match status {
203                TaskStatus::Completed => {
204                    self.storage.execution_history.record_task_completed(
205                        &task.workflow_id,
206                        &task.execution_id,
207                        timestamp_ms,
208                    )?;
209                }
210                TaskStatus::Failed => {
211                    self.storage.execution_history.record_task_failed(
212                        &task.workflow_id,
213                        &task.execution_id,
214                        timestamp_ms,
215                    )?;
216                }
217                _ => {}
218            }
219        }
220
221        Ok(())
222    }
223
224    fn query_all_tasks<F>(&self, filter: F) -> Result<Vec<Task>>
225    where
226        F: Fn(&Task) -> bool,
227    {
228        let mut tasks = Vec::new();
229
230        for data in self.queue.get_all_pending()? {
231            let task: Task = serde_json::from_slice(&data)?;
232            if filter(&task) {
233                tasks.push(task);
234            }
235        }
236
237        for data in self.queue.get_all_processing()? {
238            let task: Task = serde_json::from_slice(&data)?;
239            if filter(&task) {
240                tasks.push(task);
241            }
242        }
243
244        for data in self.queue.get_all_completed()? {
245            let task: Task = serde_json::from_slice(&data)?;
246            if filter(&task) {
247                tasks.push(task);
248            }
249        }
250
251        Ok(tasks)
252    }
253
254    pub fn get_tasks_by_execution(&self, execution_id: &str) -> Result<Vec<Task>> {
255        let mut tasks = self.query_all_tasks(|task| task.execution_id == execution_id)?;
256
257        tasks.sort_by(|a, b| a.created_at.cmp(&b.created_at));
258        Ok(tasks)
259    }
260
261    pub fn get_task(&self, task_id: &str) -> Result<Option<Task>> {
262        if let Some(data) = self.queue.get_from_any_table(task_id)? {
263            let task: Task = serde_json::from_slice(&data)?;
264            Ok(Some(task))
265        } else {
266            Ok(None)
267        }
268    }
269
270    pub fn list_tasks(
271        &self,
272        workflow_id: Option<&str>,
273        status: Option<TaskStatus>,
274    ) -> Result<Vec<Task>> {
275        let mut tasks = self.query_all_tasks(|task| {
276            workflow_id.is_none_or(|id| task.workflow_id == id)
277                && status.as_ref().is_none_or(|s| &task.status == s)
278        })?;
279
280        tasks.sort_by(|a, b| b.created_at.cmp(&a.created_at));
281        Ok(tasks)
282    }
283
284    pub fn recover_stalled_tasks(&self) -> Result<u32> {
285        let mut recovered = 0;
286        let now = chrono::Utc::now().timestamp_millis();
287
288        for data in self.queue.get_all_processing()? {
289            let mut task: Task = serde_json::from_slice(&data)?;
290
291            if let Some(started_at) = task.started_at {
292                let stall_threshold_ms = DEFAULT_STALL_TIMEOUT_SECONDS * 1000;
293                if now - started_at > stall_threshold_ms {
294                    task.status = TaskStatus::Pending;
295                    task.started_at = None;
296
297                    let task_id = task.id.clone();
298                    let priority = task.priority();
299                    let serialized = serde_json::to_vec(&task)?;
300
301                    self.queue.remove_from_processing(&task_id)?;
302                    self.queue.insert_pending(priority, &task_id, &serialized)?;
303
304                    recovered += 1;
305                }
306            }
307        }
308
309        Ok(recovered)
310    }
311
312    pub fn are_dependencies_met(
313        graph: &WorkflowGraph,
314        node_id: &str,
315        context: &ExecutionContext,
316    ) -> bool {
317        graph
318            .get_dependencies(node_id)
319            .iter()
320            .all(|dep| context.get_node(dep).is_some())
321    }
322
323    /// Uses `Arc<Workflow>` to avoid expensive cloning in large workflows
324    pub fn push_downstream_tasks(&self, task: &Task, output: crate::models::NodeOutput) -> Result<()> {
325        let workflow = task.get_workflow(&self.storage)?;
326
327        // Serialize NodeOutput to Value for context storage
328        let output_value = serde_json::to_value(&output)?;
329
330        let mut context = task.context.clone();
331        context.set_node(&task.node_id, output_value);
332
333        let graph = WorkflowGraph::from_workflow(&workflow);
334        let downstream_nodes = graph.get_downstream_nodes(&task.node_id);
335
336        for downstream_id in downstream_nodes {
337            if let Some(downstream_node) = graph.get_node(&downstream_id)
338                && Self::are_dependencies_met(&graph, &downstream_id, &context)
339            {
340                self.push_task(
341                    task.execution_id.clone(),
342                    downstream_node.clone(),
343                    workflow.clone(),
344                    context.clone(),
345                )?;
346            }
347        }
348
349        Ok(())
350    }
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::context::ExecutionContext;
    use crate::models::{Task, TaskStatus};
    use crate::storage::Storage;
    use tempfile::tempdir;

    /// Builds a scheduler over a fresh database in a temporary directory.
    /// The `TempDir` is returned so it stays alive for the whole test.
    fn setup_test_scheduler() -> (Scheduler, tempfile::TempDir) {
        let temp_dir = tempdir().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let storage = Arc::new(Storage::new(db_path.to_str().unwrap()).unwrap());
        let scheduler = Scheduler::new(storage.queue.clone(), storage.clone());
        (scheduler, temp_dir)
    }

    /// A minimal manual-trigger input with an empty payload.
    fn create_test_input() -> crate::models::NodeInput {
        use crate::models::{ManualTriggerInput, NodeInput};

        NodeInput::ManualTrigger(ManualTriggerInput {
            payload: Some(serde_json::json!({})),
        })
    }

    /// A task for execution `exec-1` of workflow `wf-1`, node `node-1`.
    fn create_test_task() -> Task {
        Task::new(
            "exec-1".to_string(),
            "wf-1".to_string(),
            "node-1".to_string(),
            create_test_input(),
            ExecutionContext::new("exec-1".to_string()),
        )
    }

    /// Writes `task` directly into the processing table with priority 0.
    fn put_in_processing(scheduler: &Scheduler, task: &Task) {
        let serialized = serde_json::to_vec(task).unwrap();
        scheduler
            .queue
            .move_to_processing(0, &task.id, &serialized)
            .unwrap();
    }

    /// Writes `task` directly into the pending table at its own priority.
    fn put_in_pending(scheduler: &Scheduler, task: &Task) {
        let serialized = serde_json::to_vec(task).unwrap();
        scheduler
            .queue
            .insert_pending(task.priority(), &task.id, &serialized)
            .unwrap();
    }

    #[test]
    fn test_recover_stalled_tasks() {
        let (scheduler, _temp_dir) = setup_test_scheduler();

        // Started 10 minutes ago: older than the 5-minute stall timeout.
        let ten_minutes_ago = chrono::Utc::now().timestamp_millis() - (10 * 60 * 1000);
        let mut task = create_test_task();
        task.status = TaskStatus::Running;
        task.started_at = Some(ten_minutes_ago);
        put_in_processing(&scheduler, &task);

        let recovered = scheduler.recover_stalled_tasks().unwrap();
        assert_eq!(recovered, 1, "Should recover 1 stalled task");

        let pending_tasks = scheduler.queue.get_all_pending().unwrap();
        assert_eq!(pending_tasks.len(), 1, "Should have 1 pending task");

        let processing_tasks = scheduler.queue.get_all_processing().unwrap();
        assert_eq!(processing_tasks.len(), 0, "Should have 0 processing tasks");
    }

    #[test]
    fn test_recover_ignores_recent_tasks() {
        let (scheduler, _temp_dir) = setup_test_scheduler();

        let mut task = create_test_task();
        task.status = TaskStatus::Running;
        task.started_at = Some(chrono::Utc::now().timestamp_millis());
        put_in_processing(&scheduler, &task);

        let recovered = scheduler.recover_stalled_tasks().unwrap();
        assert_eq!(recovered, 0, "A task that just started must not be recovered");
        assert_eq!(scheduler.queue.get_all_processing().unwrap().len(), 1);
        assert!(scheduler.queue.get_all_pending().unwrap().is_empty());
    }

    #[test]
    fn test_get_pending_task() {
        let (scheduler, _temp_dir) = setup_test_scheduler();

        let task = create_test_task();
        let task_id = task.id.clone();
        put_in_pending(&scheduler, &task);

        let found = scheduler.get_task(&task_id).unwrap();
        assert!(found.is_some(), "Should find task in pending");
        assert_eq!(found.unwrap().id, task_id);

        assert!(
            scheduler.get_task("no-such-task").unwrap().is_none(),
            "Unknown ids should yield None"
        );
    }

    #[test]
    fn test_list_tasks_filters_by_workflow() {
        let (scheduler, _temp_dir) = setup_test_scheduler();

        let task = create_test_task();
        put_in_pending(&scheduler, &task);

        let matching = scheduler.list_tasks(Some("wf-1"), None).unwrap();
        assert_eq!(matching.len(), 1, "Task of wf-1 should be listed");
        assert_eq!(matching[0].id, task.id);

        let other = scheduler.list_tasks(Some("other-wf"), None).unwrap();
        assert!(other.is_empty(), "Tasks of other workflows must be filtered out");
    }

    #[test]
    fn test_submit_workflow() {
        use crate::models::{Node, NodeType};

        let (scheduler, _temp_dir) = setup_test_scheduler();

        // A single-node workflow: the node is its own start node.
        let node = Node {
            id: "start_node".to_string(),
            node_type: NodeType::Agent,
            config: serde_json::json!({
                "type": "Agent",
                "data": {
                    "model": "gpt-4",
                    "prompt": "test prompt",
                    "temperature": null,
                    "api_key_config": null,
                    "tools": null
                }
            }),
            position: None,
        };

        let workflow = Workflow {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            nodes: vec![node],
            edges: vec![],
        };

        let input = serde_json::json!({"test": "data"});
        let execution_id = scheduler.submit_workflow(workflow, input).unwrap();

        assert!(
            uuid::Uuid::parse_str(&execution_id).is_ok(),
            "Execution ID should be a valid UUID, got {execution_id}"
        );

        let pending_tasks = scheduler.queue.get_all_pending().unwrap();
        assert_eq!(pending_tasks.len(), 1, "Should have 1 pending task");

        let task: Task = serde_json::from_slice(&pending_tasks[0]).unwrap();
        assert_eq!(task.execution_id, execution_id);
        assert_eq!(task.node_id, "start_node");

        let tasks = scheduler.get_tasks_by_execution(&execution_id).unwrap();
        assert_eq!(tasks.len(), 1, "Execution should own exactly one task");
    }
}