use crate::engine::context::{ExecutionContext, namespace};
use crate::engine::graph::WorkflowGraph;
use crate::models::{Node, Task, TaskStatus, Workflow};
use crate::storage::{Storage, TaskQueue};
use anyhow::Result;
use chrono::Utc;
use serde_json::Value;
use std::sync::Arc;
use uuid::Uuid;

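/// How long a task may sit in the processing queue before
/// `recover_stalled_tasks` treats it as stalled and re-queues it.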
const DEFAULT_STALL_TIMEOUT_SECONDS: i64 = 300;

pub struct Scheduler {
    queue: TaskQueue,
    storage: Arc<Storage>,
}

impl Scheduler {
    pub fn new(queue: TaskQueue, storage: Arc<Storage>) -> Self {
        Self { queue, storage }
    }

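    /// Builds a `Task` for `node` from its config, records the task in the
    /// execution history, and inserts it into the pending queue.
    /// Returns the new task's id.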
    pub fn push_task(
        &self,
        execution_id: String,
        node: Node,
        workflow: Arc<Workflow>,
        context: ExecutionContext,
    ) -> Result<String> {
        let node_input: crate::models::NodeInput = serde_json::from_value(node.config.clone())
            .map_err(|e| anyhow::anyhow!("Failed to parse node config as NodeInput: {}", e))?;

        let task = Task::new(
            execution_id,
            workflow.id.clone(),
            node.id.clone(),
            node_input,
            context,
        );
        let task_id = task.id.clone();

        let _ = task.set_workflow(workflow);

        self.storage.execution_history.record_task_created(
            &task.workflow_id,
            &task.execution_id,
            task.created_at,
        )?;

        let priority = task.priority();
        let serialized = serde_json::to_vec(&task)?;
        self.queue.insert_pending(priority, &task_id, &serialized)?;

        Ok(task_id)
    }

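    /// Enqueues a one-off task for a single node outside a workflow run.
    /// The `_input` value is currently ignored; the task input is parsed from
    /// the node's config.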
    pub fn push_single_node(&self, node: Node, _input: Value) -> Result<String> {
        let node_input: crate::models::NodeInput = serde_json::from_value(node.config.clone())
            .map_err(|e| anyhow::anyhow!("Failed to parse node config as NodeInput: {}", e))?;

        let task = Task::for_single_node(node, node_input);
        let task_id = task.id.clone();

        self.storage.execution_history.record_task_created(
            &task.workflow_id,
            &task.execution_id,
            task.created_at,
        )?;

        let priority = task.priority();
        let serialized = serde_json::to_vec(&task)?;
        self.queue.insert_pending(priority, &task_id, &serialized)?;

        Ok(task_id)
    }

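    /// Submits a workflow for execution under a freshly generated execution id,
    /// enqueuing one task per start node. Returns the execution id.
    ///
    /// ```ignore
    /// // Minimal sketch, mirroring the test setup at the bottom of this file;
    /// // `workflow` is assumed to contain at least one node with no dependencies.
    /// let storage = Arc::new(Storage::new("scheduler.db")?);
    /// let scheduler = Scheduler::new(storage.queue.clone(), storage.clone());
    /// let execution_id = scheduler.submit_workflow(workflow, serde_json::json!({}))?;
    /// ```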
    pub fn submit_workflow(&self, workflow: Workflow, input: Value) -> Result<String> {
        let execution_id = Uuid::new_v4().to_string();
        self.submit_workflow_internal(workflow, input, execution_id)
    }

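    /// Shared submission path: seeds the execution context with the trigger
    /// payload and pushes a task for every node that has no dependencies.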
    fn submit_workflow_internal(
        &self,
        workflow: Workflow,
        input: Value,
        execution_id: String,
    ) -> Result<String> {
        let workflow = Arc::new(workflow);

        let mut context =
            ExecutionContext::with_execution_id(workflow.id.clone(), execution_id.clone());
        context.ensure_secret_storage(&self.storage);
        context.set(namespace::trigger::PAYLOAD, input);

        let graph = WorkflowGraph::from_workflow(&workflow);
        let start_nodes = graph.get_nodes_with_no_dependencies();

        if start_nodes.is_empty() {
            return Err(anyhow::anyhow!(
                "No start nodes found in workflow {}",
                workflow.id
            ));
        }

        for node_id in start_nodes {
            if let Some(node) = graph.get_node(&node_id) {
                self.push_task(
                    execution_id.clone(),
                    node.clone(),
                    workflow.clone(),
                    context.clone(),
                )?;
            }
        }

        Ok(execution_id)
    }

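    /// Loads a workflow from storage by id and submits it for execution.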
    pub fn submit_workflow_by_id(&self, workflow_id: &str, input: Value) -> Result<String> {
        let workflow = self
            .storage
            .workflows
            .get_workflow(workflow_id)
            .map_err(|e| anyhow::anyhow!("Failed to load workflow {}: {}", workflow_id, e))?;

        self.submit_workflow(workflow, input)
    }

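    /// Like `submit_workflow_by_id`, but runs under a caller-supplied execution id.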
    pub fn submit_workflow_by_id_with_execution_id(
        &self,
        workflow_id: &str,
        input: Value,
        execution_id: String,
    ) -> Result<String> {
        let workflow = self
            .storage
            .workflows
            .get_workflow(workflow_id)
            .map_err(|e| anyhow::anyhow!("Failed to load workflow {}: {}", workflow_id, e))?;

        self.submit_workflow_internal(workflow, input, execution_id)
    }

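    /// Pops the next pending task, waiting asynchronously until one is available.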
    pub async fn pop_task(&self) -> Result<Task> {
        loop {
            match self.try_pop_task()? {
                Some(task) => return Ok(task),
                None => {
                    self.queue.wait_for_task().await;
                }
            }
        }
    }

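    /// Non-blocking variant: atomically claims a pending task and marks it
    /// started, if one is available.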
    fn try_pop_task(&self) -> Result<Option<Task>> {
        self.queue.atomic_pop_pending(|task| task.start())
    }

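    /// Marks a processing task as completed with the given output.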
    pub fn complete_task(&self, task_id: &str, output: crate::models::NodeOutput) -> Result<()> {
        self.finish_task(task_id, TaskStatus::Completed, Some(output), None)
    }

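    /// Marks a processing task as failed with the given error message.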
    pub fn fail_task(&self, task_id: &str, error: String) -> Result<()> {
        self.finish_task(task_id, TaskStatus::Failed, None, Some(error))
    }

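    /// Common completion path: updates the task, moves it from processing to
    /// completed, and records the outcome in the execution history.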
    fn finish_task(
        &self,
        task_id: &str,
        status: TaskStatus,
        output: Option<crate::models::NodeOutput>,
        error: Option<String>,
    ) -> Result<()> {
        if let Some(data) = self.queue.get_from_processing(task_id)? {
            let mut task: Task = serde_json::from_slice(&data)?;

            match status {
                TaskStatus::Completed => {
                    if let Some(output) = output {
                        task.complete(output);
                    }
                }
                TaskStatus::Failed => {
                    if let Some(error) = error {
                        task.fail(error);
                    }
                }
                _ => {}
            }

            let serialized = serde_json::to_vec(&task)?;
            self.queue.move_to_completed(task_id, &serialized)?;

            let timestamp_ms = Utc::now().timestamp_millis();
            match status {
                TaskStatus::Completed => {
                    self.storage.execution_history.record_task_completed(
                        &task.workflow_id,
                        &task.execution_id,
                        timestamp_ms,
                    )?;
                }
                TaskStatus::Failed => {
                    self.storage.execution_history.record_task_failed(
                        &task.workflow_id,
                        &task.execution_id,
                        timestamp_ms,
                    )?;
                }
                _ => {}
            }
        }

        Ok(())
    }

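    /// Scans the pending, processing, and completed queues and returns every
    /// task matching `filter`.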
    fn query_all_tasks<F>(&self, filter: F) -> Result<Vec<Task>>
    where
        F: Fn(&Task) -> bool,
    {
        let mut tasks = Vec::new();

        for data in self.queue.get_all_pending()? {
            let task: Task = serde_json::from_slice(&data)?;
            if filter(&task) {
                tasks.push(task);
            }
        }

        for data in self.queue.get_all_processing()? {
            let task: Task = serde_json::from_slice(&data)?;
            if filter(&task) {
                tasks.push(task);
            }
        }

        for data in self.queue.get_all_completed()? {
            let task: Task = serde_json::from_slice(&data)?;
            if filter(&task) {
                tasks.push(task);
            }
        }

        Ok(tasks)
    }

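    /// Returns all tasks belonging to an execution, oldest first.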
    pub fn get_tasks_by_execution(&self, execution_id: &str) -> Result<Vec<Task>> {
        let mut tasks = self.query_all_tasks(|task| task.execution_id == execution_id)?;

        tasks.sort_by(|a, b| a.created_at.cmp(&b.created_at));
        Ok(tasks)
    }

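    /// Looks up a task by id across all queue tables.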
    pub fn get_task(&self, task_id: &str) -> Result<Option<Task>> {
        if let Some(data) = self.queue.get_from_any_table(task_id)? {
            let task: Task = serde_json::from_slice(&data)?;
            Ok(Some(task))
        } else {
            Ok(None)
        }
    }

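    /// Lists tasks, optionally filtered by workflow id and status, newest first.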
    pub fn list_tasks(
        &self,
        workflow_id: Option<&str>,
        status: Option<TaskStatus>,
    ) -> Result<Vec<Task>> {
        let mut tasks = self.query_all_tasks(|task| {
            workflow_id.is_none_or(|id| task.workflow_id == id)
                && status.as_ref().is_none_or(|s| &task.status == s)
        })?;

        tasks.sort_by(|a, b| b.created_at.cmp(&a.created_at));
        Ok(tasks)
    }

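    /// Re-queues tasks that have been processing for longer than
    /// `DEFAULT_STALL_TIMEOUT_SECONDS`. Returns the number of tasks recovered.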
    pub fn recover_stalled_tasks(&self) -> Result<u32> {
        let mut recovered = 0;
        let now = chrono::Utc::now().timestamp_millis();

        for data in self.queue.get_all_processing()? {
            let mut task: Task = serde_json::from_slice(&data)?;

            if let Some(started_at) = task.started_at {
                let stall_threshold_ms = DEFAULT_STALL_TIMEOUT_SECONDS * 1000;
                if now - started_at > stall_threshold_ms {
                    task.status = TaskStatus::Pending;
                    task.started_at = None;

                    let task_id = task.id.clone();
                    let priority = task.priority();
                    let serialized = serde_json::to_vec(&task)?;

                    self.queue.remove_from_processing(&task_id)?;
                    self.queue.insert_pending(priority, &task_id, &serialized)?;

                    recovered += 1;
                }
            }
        }

        Ok(recovered)
    }

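    /// Returns true when every dependency of `node_id` has recorded output in `context`.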
    pub fn are_dependencies_met(
        graph: &WorkflowGraph,
        node_id: &str,
        context: &ExecutionContext,
    ) -> bool {
        graph
            .get_dependencies(node_id)
            .iter()
            .all(|dep| context.get_node(dep).is_some())
    }

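    /// Records `output` for the finished task's node in its context, then enqueues
    /// every downstream node whose dependencies are now satisfied.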
    pub fn push_downstream_tasks(
        &self,
        task: &Task,
        output: crate::models::NodeOutput,
    ) -> Result<()> {
        let workflow = task.get_workflow(&self.storage)?;

        let output_value = serde_json::to_value(&output)?;

        let mut context = task.context.clone();
        context.set_node(&task.node_id, output_value);

        let graph = WorkflowGraph::from_workflow(&workflow);
        let downstream_nodes = graph.get_downstream_nodes(&task.node_id);

        for downstream_id in downstream_nodes {
            if let Some(downstream_node) = graph.get_node(&downstream_id)
                && Self::are_dependencies_met(&graph, &downstream_id, &context)
            {
                self.push_task(
                    task.execution_id.clone(),
                    downstream_node.clone(),
                    workflow.clone(),
                    context.clone(),
                )?;
            }
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::context::ExecutionContext;
    use crate::models::{Task, TaskStatus};
    use crate::storage::Storage;
    use tempfile::tempdir;

    fn setup_test_scheduler() -> (Scheduler, tempfile::TempDir) {
        let temp_dir = tempdir().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let storage = Arc::new(Storage::new(db_path.to_str().unwrap()).unwrap());
        let scheduler = Scheduler::new(storage.queue.clone(), storage.clone());
        (scheduler, temp_dir)
    }

    fn create_test_input() -> crate::models::NodeInput {
        use crate::models::{ManualTriggerInput, NodeInput};

        NodeInput::ManualTrigger(ManualTriggerInput {
            payload: Some(serde_json::json!({})),
        })
    }

    #[test]
    fn test_recover_stalled_tasks() {
        let (scheduler, _temp_dir) = setup_test_scheduler();

        let ten_minutes_ago = chrono::Utc::now().timestamp_millis() - (10 * 60 * 1000);
        let mut task = Task::new(
            "exec-1".to_string(),
            "wf-1".to_string(),
            "node-1".to_string(),
            create_test_input(),
            ExecutionContext::new("exec-1".to_string()),
        );
        task.status = TaskStatus::Running;
        task.started_at = Some(ten_minutes_ago);

        let serialized = serde_json::to_vec(&task).unwrap();
        scheduler
            .queue
            .move_to_processing(0, &task.id, &serialized)
            .unwrap();

        let recovered = scheduler.recover_stalled_tasks().unwrap();
        assert_eq!(recovered, 1, "Should recover 1 stalled task");

        let pending_tasks = scheduler.queue.get_all_pending().unwrap();
        assert_eq!(pending_tasks.len(), 1, "Should have 1 pending task");

        let processing_tasks = scheduler.queue.get_all_processing().unwrap();
        assert_eq!(processing_tasks.len(), 0, "Should have 0 processing tasks");
    }

    #[test]
    fn test_get_pending_task() {
        let (scheduler, _temp_dir) = setup_test_scheduler();

        let task = Task::new(
            "exec-1".to_string(),
            "wf-1".to_string(),
            "node-1".to_string(),
            create_test_input(),
            ExecutionContext::new("exec-1".to_string()),
        );
        let task_id = task.id.clone();

        let priority = task.priority();
        let serialized = serde_json::to_vec(&task).unwrap();
        scheduler
            .queue
            .insert_pending(priority, &task_id, &serialized)
            .unwrap();

        let found = scheduler.get_task(&task_id).unwrap();
        assert!(found.is_some(), "Should find task in pending");
        assert_eq!(found.unwrap().id, task_id);
    }

    #[test]
    fn test_submit_workflow() {
        use crate::models::{Node, NodeType};

        let (scheduler, _temp_dir) = setup_test_scheduler();

        let node = Node {
            id: "start_node".to_string(),
            node_type: NodeType::Agent,
            config: serde_json::json!({
                "type": "Agent",
                "data": {
                    "model": "gpt-4",
                    "prompt": "test prompt",
                    "temperature": null,
                    "api_key_config": null,
                    "tools": null
                }
            }),
            position: None,
        };

        let workflow = Workflow {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            nodes: vec![node],
            edges: vec![],
        };

        let input = serde_json::json!({"test": "data"});
        let execution_id = scheduler.submit_workflow(workflow, input).unwrap();

        assert!(!execution_id.is_empty(), "Execution ID should not be empty");

        let pending_tasks = scheduler.queue.get_all_pending().unwrap();
        assert_eq!(pending_tasks.len(), 1, "Should have 1 pending task");

        let task: Task = serde_json::from_slice(&pending_tasks[0]).unwrap();
        assert_eq!(task.execution_id, execution_id);
        assert_eq!(task.node_id, "start_node");
    }
}