restflow_core/storage/
agent.rs

1use crate::node::agent::AgentNode;
2use anyhow::Result;
3use redb::{Database, ReadableDatabase, ReadableTable, TableDefinition};
4use serde::{Deserialize, Serialize};
5use std::sync::Arc;
6use ts_rs::TS;
7use uuid::Uuid;
8
/// An agent definition as persisted by [`AgentStorage`].
///
/// Instances are serialized with `serde_json` and stored in the `agents` table under
/// their `id`; the type is also exported to TypeScript via `ts-rs`.
#[derive(Serialize, Deserialize, Debug, Clone, TS)]
#[ts(export)]
pub struct StoredAgent {
    /// Unique identifier of the agent; also the key of its record in the table.
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// The agent node configuration.
    pub agent: AgentNode,
    // NOTE(review): ts-rs may render `i64` as TypeScript `bigint`, while serde_json emits a
    // plain JSON number for these fields — confirm the generated TS type matches what
    // clients actually receive.
    /// Creation time in milliseconds since the Unix epoch; `None` if not recorded.
    #[ts(optional)]
    pub created_at: Option<i64>,
    /// Last modification time in milliseconds since the Unix epoch; `None` if not recorded.
    #[ts(optional)]
    pub updated_at: Option<i64>,
}
/// redb table holding agents: the key is the agent id and the value is the
/// JSON-encoded [`StoredAgent`].
const AGENT_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("agents");

/// CRUD access to agents stored in a shared redb [`Database`].
pub struct AgentStorage {
    /// Shared handle to the underlying database.
    db: Arc<Database>,
}
25
26impl AgentStorage {
27    pub fn new(db: Arc<Database>) -> Result<Self> {
28        // Ensure agents table exists, create if not
29        let write_txn = db.begin_write()?;
30        write_txn.open_table(AGENT_TABLE)?;
31        write_txn.commit()?;
32
33        Ok(Self { db })
34    }
35    pub fn create_agent(&self, name: String, agent: AgentNode) -> Result<StoredAgent> {
36        let now = std::time::SystemTime::now()
37            .duration_since(std::time::UNIX_EPOCH)?
38            .as_millis() as i64;
39
40        let stored_agent = StoredAgent {
41            id: Uuid::new_v4().to_string(),
42            name,
43            agent,
44            created_at: Some(now),
45            updated_at: Some(now),
46        };
47        let write_txn = self.db.begin_write()?;
48        {
49            let mut table = write_txn.open_table(AGENT_TABLE)?;
50            let json_bytes = serde_json::to_vec(&stored_agent)?;
51            table.insert(stored_agent.id.as_str(), json_bytes.as_slice())?;
52        }
53        write_txn.commit()?;
54
55        Ok(stored_agent)
56    }
57
58    pub fn get_agent(&self, id: String) -> Result<Option<StoredAgent>> {
59        let read_txn = self.db.begin_read()?;
60        let table = read_txn.open_table(AGENT_TABLE)?;
61        if let Some(value) = table.get(id.as_str())? {
62            let agent: StoredAgent = serde_json::from_slice(value.value())?;
63            Ok(Some(agent))
64        } else {
65            Ok(None)
66        }
67    }
68
69    pub fn list_agents(&self) -> Result<Vec<StoredAgent>> {
70        let read_txn = self.db.begin_read()?;
71        let table = read_txn.open_table(AGENT_TABLE)?;
72        let mut agents = Vec::new();
73        for item in table.iter()? {
74            let (_, value) = item?;
75            let agent: StoredAgent = serde_json::from_slice(value.value())?;
76            agents.push(agent);
77        }
78        Ok(agents)
79    }
80
81    pub fn update_agent(
82        &self,
83        id: String,
84        name: Option<String>,
85        agent: Option<AgentNode>,
86    ) -> Result<StoredAgent> {
87        let mut existing_agent = self
88            .get_agent(id.clone())?
89            .ok_or_else(|| anyhow::anyhow!("Agent {} not found", id))?;
90        if let Some(new_name) = name {
91            existing_agent.name = new_name;
92        };
93
94        if let Some(new_agent) = agent {
95            existing_agent.agent = new_agent;
96        };
97
98        let now = std::time::SystemTime::now()
99            .duration_since(std::time::UNIX_EPOCH)?
100            .as_millis() as i64;
101        existing_agent.updated_at = Some(now);
102
103        let write_txn = self.db.begin_write()?;
104        {
105            let mut table = write_txn.open_table(AGENT_TABLE)?;
106            let json_bytes = serde_json::to_vec(&existing_agent)?;
107            table.insert(existing_agent.id.as_str(), json_bytes.as_slice())?;
108        }
109        write_txn.commit()?;
110
111        Ok(existing_agent)
112    }
113
114    pub fn delete_agent(&self, id: String) -> Result<()> {
115        let write_txn = self.db.begin_write()?;
116        {
117            let mut table = write_txn.open_table(AGENT_TABLE)?;
118            table
119                .remove(id.as_str())?
120                .ok_or_else(|| anyhow::anyhow!("Agent {} not found", id))?;
121        }
122        write_txn.commit()?;
123        Ok(())
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds the agent configuration shared by every test.
    fn create_test_agent_node() -> AgentNode {
        use crate::node::agent::ApiKeyConfig;

        AgentNode {
            model: "gpt-4.1".to_string(),
            prompt: Some("You are a helpful assistant".to_string()),
            temperature: Some(0.7),
            api_key_config: Some(ApiKeyConfig::Direct("test_key".to_string())),
            tools: Some(vec!["add".to_string()]),
        }
    }

    /// Opens `AgentStorage` on a fresh database file inside a temporary directory.
    ///
    /// The returned `TempDir` must stay bound (e.g. `_temp_dir`, not `_`) for the whole
    /// test, because dropping it removes the directory that holds the database file.
    fn create_test_storage() -> (tempfile::TempDir, AgentStorage) {
        let temp_dir = tempdir().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let db = Arc::new(Database::create(db_path).unwrap());
        let storage = AgentStorage::new(db).unwrap();
        (temp_dir, storage)
    }

    #[test]
    fn test_insert_and_get_agent() {
        let (_temp_dir, storage) = create_test_storage();

        let stored = storage
            .create_agent("Test Agent".to_string(), create_test_agent_node())
            .unwrap();

        assert!(!stored.id.is_empty());
        assert_eq!(stored.name, "Test Agent");
        // A freshly created record carries both timestamps, set to the same instant.
        assert!(stored.created_at.is_some());
        assert_eq!(stored.created_at, stored.updated_at);

        let agent = storage
            .get_agent(stored.id.clone())
            .unwrap()
            .expect("agent should exist right after creation");
        assert_eq!(agent.id, stored.id);
        assert_eq!(agent.name, "Test Agent");
        assert_eq!(agent.agent.model, "gpt-4.1");
        assert_eq!(agent.created_at, stored.created_at);
    }

    #[test]
    fn test_list_agents() {
        let (_temp_dir, storage) = create_test_storage();
        assert!(storage.list_agents().unwrap().is_empty());

        let expected_names = ["Agent 1", "Agent 2", "Agent 3"];
        for name in expected_names {
            storage
                .create_agent(name.to_string(), create_test_agent_node())
                .unwrap();
        }

        let agents = storage.list_agents().unwrap();
        assert_eq!(agents.len(), 3);

        let names: Vec<&str> = agents.iter().map(|a| a.name.as_str()).collect();
        for expected in expected_names {
            assert!(names.contains(&expected), "missing agent {expected}");
        }
    }

    #[test]
    fn test_update_agent() {
        let (_temp_dir, storage) = create_test_storage();

        let stored = storage
            .create_agent("Original Name".to_string(), create_test_agent_node())
            .unwrap();
        let updated = storage
            .update_agent(stored.id.clone(), Some("Updated Name".to_string()), None)
            .unwrap();

        assert_eq!(updated.id, stored.id);
        assert_eq!(updated.name, "Updated Name");
        assert_eq!(updated.agent.model, "gpt-4.1");
        // An update keeps the creation time and always records a modification time.
        assert_eq!(updated.created_at, stored.created_at);
        assert!(updated.updated_at.is_some());

        let mut new_agent_node = create_test_agent_node();
        new_agent_node.temperature = Some(0.9);

        let updated2 = storage
            .update_agent(stored.id.clone(), None, Some(new_agent_node))
            .unwrap();

        assert_eq!(updated2.name, "Updated Name");
        assert_eq!(updated2.agent.temperature, Some(0.9));

        // The returned value must match what was actually written to the database,
        // and updating must not create a second record.
        let persisted = storage.get_agent(stored.id.clone()).unwrap().unwrap();
        assert_eq!(persisted.name, "Updated Name");
        assert_eq!(persisted.agent.model, "gpt-4.1");
        // Compare with a tolerance: the value went through a JSON float round-trip.
        assert!((persisted.agent.temperature.unwrap() - 0.9).abs() < 1e-6);
        assert_eq!(persisted.created_at, stored.created_at);
        assert_eq!(storage.list_agents().unwrap().len(), 1);
    }

    #[test]
    fn test_delete_agent() {
        let (_temp_dir, storage) = create_test_storage();

        let stored = storage
            .create_agent("To Delete".to_string(), create_test_agent_node())
            .unwrap();
        storage.delete_agent(stored.id.clone()).unwrap();

        let retrieved = storage.get_agent(stored.id.clone()).unwrap();
        assert!(retrieved.is_none());

        let deleted_again = storage.delete_agent(stored.id);
        assert!(deleted_again.is_err());
        assert!(deleted_again.unwrap_err().to_string().contains("not found"));
    }

    #[test]
    fn test_get_nonexistent_agent() {
        let (_temp_dir, storage) = create_test_storage();

        let result = storage.get_agent("nonexistent".to_string()).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_update_nonexistent_agent() {
        let (_temp_dir, storage) = create_test_storage();

        let result = storage.update_agent(
            "nonexistent".to_string(),
            Some("New Name".to_string()),
            None,
        );

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));
        // A failed update must not leave a record behind.
        assert!(storage.list_agents().unwrap().is_empty());
    }
}
268}