// File: restflow_core/node/agent.rs

1use crate::tools::{AddTool, GetTimeTool};
2use anyhow::Result;
3use rig::{
4    client::CompletionClient,
5    completion::Prompt,
6    providers::{anthropic, deepseek, openai},
7};
8use serde::{Deserialize, Serialize};
9use tracing::{debug, warn};
10use ts_rs::TS;
11
/// Builds an agent from `$builder` with the tools named in `$self.tools`, then
/// prompts it with `$input`, awaits the reply and propagates any error with `?`.
///
/// This is a macro rather than a function because in rig-core 0.22.0 calling
/// `.tool()` turns an `AgentBuilder` into an `AgentBuilderSimple`. Each tool
/// combination therefore produces a different builder type and needs its own
/// `.build()` call.
///
/// Recognised tool names are `"add"` (`AddTool`) and `"get_current_time"`
/// (`GetTimeTool`). Any other name is reported with `warn!` and skipped.
macro_rules! build_with_tools {
    ($self:expr, $builder:expr, $input:expr) => {{
        let built = match $self.tools.as_deref().filter(|names| !names.is_empty()) {
            None => $builder.build(),
            Some(tool_names) => {
                const KNOWN_TOOLS: [&str; 2] = ["add", "get_current_time"];

                debug!(tools = ?tool_names, "Configuring agent tools");

                // Surface names we cannot map to a tool implementation.
                tool_names
                    .iter()
                    .filter(|name| !KNOWN_TOOLS.contains(&name.as_str()))
                    .for_each(|name| warn!(tool = %name, "Unknown tool specified"));

                let requested = |wanted: &str| tool_names.iter().any(|t| t == wanted);
                let with_add = requested("add");
                let with_time = requested("get_current_time");

                if with_add && with_time {
                    debug!("Adding tools: add, get_current_time");
                    $builder.tool(AddTool).tool(GetTimeTool).build()
                } else if with_add {
                    debug!("Adding tool: add");
                    $builder.tool(AddTool).build()
                } else if with_time {
                    debug!("Adding tool: get_current_time");
                    $builder.tool(GetTimeTool).build()
                } else {
                    $builder.build()
                }
            }
        };

        built.prompt($input).await?
    }};
}
52
53#[derive(Debug, Clone, Serialize, Deserialize, TS)]
54#[ts(export)]
55#[serde(rename_all = "snake_case", tag = "type", content = "value")]
56pub enum ApiKeyConfig {
57    Direct(String),
58    Secret(String), // Reference to secret name in secret manager
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize, TS)]
62#[ts(export)]
63pub struct AgentNode {
64    pub model: String,
65    pub prompt: Option<String>,
66    pub temperature: Option<f64>,
67    pub api_key_config: Option<ApiKeyConfig>,
68    pub tools: Option<Vec<String>>, // Tool names to enable
69}
70
71
72impl AgentNode {
73    pub fn new(
74        model: String,
75        prompt: String,
76        temperature: Option<f64>,
77        api_key_config: Option<ApiKeyConfig>,
78    ) -> Self {
79        Self {
80            model,
81            prompt: Some(prompt),
82            temperature,
83            api_key_config,
84            tools: None,
85        }
86    }
87
88    pub fn from_config(config: &serde_json::Value) -> Result<Self> {
89        let model = config["model"]
90            .as_str()
91            .ok_or_else(|| anyhow::anyhow!("Model missing in config"))?
92            .to_string();
93
94        let prompt = config
95            .get("prompt")
96            .and_then(|v| v.as_str())
97            .map(|s| s.to_string());
98
99        let temperature = config.get("temperature").and_then(|v| v.as_f64());
100
101        let api_key_config = config
102            .get("api_key_config")
103            .map(|v| serde_json::from_value(v.clone()))
104            .transpose()?;
105
106        let tools = config["tools"].as_array().map(|arr| {
107            arr.iter()
108                .filter_map(|v| v.as_str().map(String::from))
109                .collect()
110        });
111
112        Ok(Self {
113            model,
114            prompt,
115            temperature,
116            api_key_config,
117            tools,
118        })
119    }
120
121    pub async fn execute(
122        &self,
123        input: &str,
124        secret_storage: Option<&crate::storage::SecretStorage>,
125    ) -> Result<String> {
126        // Get API key from direct input or secret manager
127        let api_key = match &self.api_key_config {
128            Some(ApiKeyConfig::Direct(key)) => key.clone(),
129            Some(ApiKeyConfig::Secret(secret_name)) => {
130                if let Some(storage) = secret_storage {
131                    storage.get_secret(secret_name)?.ok_or_else(|| {
132                        anyhow::anyhow!("Secret '{}' not found in secret manager", secret_name)
133                    })?
134                } else {
135                    return Err(anyhow::anyhow!(
136                        "Secret manager not available but secret reference is configured"
137                    ));
138                }
139            }
140            None => {
141                return Err(anyhow::anyhow!(
142                    "No API key configured. Please provide api_key_config"
143                ));
144            }
145        };
146
147        let response = match self.model.as_str() {
148            m @ ("o4-mini" | "o3" | "o3-mini" | "gpt-4.1" | "gpt-4.1-mini" | "gpt-4.1-nano"
149            | "gpt-4" | "gpt-4-turbo" | "gpt-3.5-turbo" | "gpt-4o" | "gpt-4o-mini") => {
150                let client = openai::Client::new(&api_key);
151
152                let builder = match m {
153                    // O-series models don't support temperature
154                    "o4-mini" | "o3" | "o3-mini" => {
155                        let mut b = client.agent(m);
156                        if let Some(ref prompt) = self.prompt {
157                            b = b.preamble(prompt);
158                        }
159                        b
160                    }
161                    _ => {
162                        let mut b = client.agent(m);
163                        if let Some(ref prompt) = self.prompt {
164                            b = b.preamble(prompt);
165                        }
166                        if let Some(temp) = self.temperature {
167                            b.temperature(temp)
168                        } else {
169                            b
170                        }
171                    }
172                };
173
174                build_with_tools!(self, builder, input)
175            }
176
177            m @ ("claude-4-opus" | "claude-4-sonnet" | "claude-3.7-sonnet") => {
178                let client = anthropic::Client::new(&api_key);
179
180                let mut builder = match m {
181                    "claude-4-opus" => client.agent(anthropic::CLAUDE_4_OPUS),
182                    "claude-4-sonnet" => client.agent(anthropic::CLAUDE_4_SONNET),
183                    "claude-3.7-sonnet" => client.agent(anthropic::CLAUDE_3_7_SONNET),
184                    _ => unreachable!(), // We already matched these exact models
185                };
186                if let Some(ref prompt) = self.prompt {
187                    builder = builder.preamble(prompt);
188                }
189                let builder = if let Some(temp) = self.temperature {
190                    builder.temperature(temp)
191                } else {
192                    builder
193                };
194
195                build_with_tools!(self, builder, input)
196            }
197
198            m @ ("deepseek-chat" | "deepseek-reasoner") => {
199                let client = deepseek::Client::new(&api_key);
200
201                let mut builder = match m {
202                    "deepseek-chat" => client.agent(deepseek::DEEPSEEK_CHAT),
203                    "deepseek-reasoner" => client.agent(deepseek::DEEPSEEK_REASONER),
204                    _ => unreachable!(), // We already matched these exact models
205                };
206                if let Some(ref prompt) = self.prompt {
207                    builder = builder.preamble(prompt);
208                }
209                let builder = if let Some(temp) = self.temperature {
210                    builder.temperature(temp)
211                } else {
212                    builder
213                };
214
215                build_with_tools!(self, builder, input)
216            }
217
218            _ => {
219                return Err(anyhow::anyhow!(
220                    "Unsupported model: {}. Supported models: o4-mini, o3, o3-mini, gpt-4.1, gpt-4.1-mini, gpt-4.1-nano, gpt-4, gpt-4-turbo, gpt-3.5-turbo, gpt-4o, gpt-4o-mini, claude-4-opus, claude-4-sonnet, claude-3.7-sonnet, deepseek-chat, deepseek-reasoner",
221                    self.model
222                ));
223            }
224        };
225
226        Ok(response)
227    }
228}