//! Tool: project_curl — perform HTTP requests (GET/POST/PUT/DELETE) //! //! Security measures: //! - SSRF protection: blocks private IPs and blocks redirects to private IPs //! - Sensitive header injection: blocks Host, Authorization, Cookie, Proxy-* //! - Connection pooling via a shared reqwest::Client use agent::{ToolContext, ToolDefinition, ToolError, ToolParam, ToolSchema}; use std::collections::HashMap; use std::sync::OnceLock; /// Maximum response body size: 1 MB. const MAX_BODY_BYTES: usize = 1 << 20; /// Headers that are blocked from user-supplied values to prevent injection attacks. const BLOCKED_HEADERS: &[&str] = &[ "host", "authorization", "cookie", "proxy-authorization", "proxy-connection", "proxy-authenticate", ]; /// Shared reqwest::Client for connection pooling. static SHARED_CLIENT: OnceLock = OnceLock::new(); fn shared_client() -> &'static reqwest::Client { SHARED_CLIENT.get_or_init(|| { reqwest::Client::builder() .connect_timeout(std::time::Duration::from_secs(10)) .timeout(std::time::Duration::from_secs(120)) // Block automatic redirect following so we can validate each hop .redirect(reqwest::redirect::Policy::limited(0)) .build() .expect("reqwest client build should not fail") }) } /// Check if a host string resolves to or is a private/internal IP. 
///
/// The check is purely lexical: no DNS lookup is performed, so a public name that
/// resolves to an internal address is not detected here.
///
/// `host` may be a DNS name, an IPv4 literal, or an IPv6 literal with or without the
/// surrounding brackets that URL parsers emit (e.g. `[::1]`). Returns `true` when the
/// host must be refused.
fn is_private_host(host: &str) -> bool {
    // A single trailing root dot ("localhost.") names the same host.
    let trimmed = host.strip_suffix('.').unwrap_or(host);
    // IPv6 literals arrive bracketed from `Url::host_str()`; `IpAddr` parses the bare form.
    let bare = trimmed
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(trimmed);

    if let Ok(ip) = bare.parse::<std::net::IpAddr>() {
        return match ip {
            std::net::IpAddr::V4(v4) => is_internal_ipv4(v4),
            std::net::IpAddr::V6(v6) => is_internal_ipv6(v6),
        };
    }

    let lower = bare.to_ascii_lowercase();
    lower == "localhost"
        || lower.ends_with(".localhost")
        || lower == "metadata.google.internal"
        // Kept from the previous textual check so that nothing it rejected (for example
        // names such as `10.0.0.1.example`) becomes reachable after this rewrite.
        || has_private_dotted_prefix(host)
}

/// True for IPv4 addresses that must never be reached from this tool.
fn is_internal_ipv4(ip: std::net::Ipv4Addr) -> bool {
    let [first, second, ..] = ip.octets();
    ip.is_loopback() // 127.0.0.0/8
        || ip.is_private() // 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
        || ip.is_link_local() // 169.254.0.0/16, includes the 169.254.169.254 metadata address
        || ip.is_broadcast() // 255.255.255.255
        || first == 0 // 0.0.0.0/8, includes the unspecified address
        || (first == 100 && (second & 0xC0) == 64) // 100.64.0.0/10 shared address space
}

/// True for IPv6 addresses that must never be reached from this tool.
fn is_internal_ipv6(ip: std::net::Ipv6Addr) -> bool {
    // `::ffff:a.b.c.d` reaches the embedded IPv4 address, so judge it by the IPv4 rules.
    if let Some(v4) = ip.to_ipv4_mapped() {
        return is_internal_ipv4(v4);
    }
    let first_segment = ip.segments()[0];
    ip.is_loopback() // ::1
        || ip.is_unspecified() // ::
        || (first_segment & 0xfe00) == 0xfc00 // fc00::/7 unique local
        || (first_segment & 0xffc0) == 0xfe80 // fe80::/10 link-local
}

/// Textual match on the dotted prefixes `10.`, `172.16.`–`172.31.` and `192.168.`.
fn has_private_dotted_prefix(host: &str) -> bool {
    if host.starts_with("10.") || host.starts_with("192.168.") {
        return true;
    }
    // 172.16.0.0/12: the second label must be exactly the two digits 16..=31.
    match host.strip_prefix("172.").and_then(|rest| rest.split_once('.')) {
        Some((second, _)) => second.len() == 2 && matches!(second.parse::<u8>(), Ok(16..=31)),
        None => false,
    }
}

/// Parse `url_str` and reject it when its host is private/internal.
///
/// Returns the parsed URL on success. A URL without a host component is passed through
/// unchanged.
///
/// # Errors
/// Returns `ToolError::ExecutionError` when `url_str` is not a valid URL or when its
/// host is refused by [`is_private_host`].
fn validate_url_against_ssrf(url_str: &str) -> Result<reqwest::Url, ToolError> {
    let parsed = reqwest::Url::parse(url_str)
        .map_err(|e| ToolError::ExecutionError(format!("Invalid URL: {}", e)))?;

    if parsed.host_str().is_some_and(is_private_host) {
        return Err(ToolError::ExecutionError(
            "Requests to internal/private IPs are not allowed for security reasons".into(),
        ));
    }

    Ok(parsed)
}

/// Perform an HTTP request and return the response body and metadata.
/// Supports GET, POST, PUT, DELETE methods. Useful for fetching web pages,
/// calling external APIs, or downloading resources.
pub async fn curl_exec( _ctx: ToolContext, args: serde_json::Value, ) -> Result { let url_str = args .get("url") .and_then(|v| v.as_str()) .ok_or_else(|| ToolError::ExecutionError("url is required".into()))?; // SSRF protection: validate initial URL validate_url_against_ssrf(url_str)?; let method = args .get("method") .and_then(|v| v.as_str()) .unwrap_or("GET") .to_uppercase(); let body = args.get("body").and_then(|v| v.as_str()).map(String::from); let headers: Vec<(String, String)> = args .get("headers") .and_then(|v| v.as_object()) .map(|obj| { obj.iter() .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string()))) .collect() }) .unwrap_or_default(); // Block sensitive headers that could be used for injection attacks for (key, _) in &headers { if BLOCKED_HEADERS.contains(&key.to_lowercase().as_str()) { return Err(ToolError::ExecutionError( format!("Header '{}' is not allowed for security reasons", key), )); } } let timeout_secs = args .get("timeout") .and_then(|v| v.as_u64()) .unwrap_or(30) .min(120); let client = shared_client(); // Build a per-request client with the specific timeout by using the shared // client's connection pool but overriding timeout per request via request builder. // Since reqwest::Client::builder().redirect(Policy::limited(0)) disables auto-redirects, // we manually follow up to 5 redirects with SSRF validation on each hop. let mut current_url = url_str.to_string(); let mut redirect_count = 0u32; const MAX_REDIRECTS: u32 = 5; loop { let mut request = match method.as_str() { "GET" => client.get(¤t_url), "POST" => client.post(¤t_url), "PUT" => client.put(¤t_url), "DELETE" => client.delete(¤t_url), "PATCH" => client.patch(¤t_url), "HEAD" => client.head(¤t_url), _ => { return Err(ToolError::ExecutionError(format!( "Unsupported HTTP method: {}. 
Use GET, POST, PUT, DELETE, PATCH, or HEAD.", method ))) } }; request = request.timeout(std::time::Duration::from_secs(timeout_secs)); for (key, value) in &headers { request = request.header(key, value); } // Set default Content-Type for POST/PUT/PATCH if not provided and body exists if body.is_some() && !headers.iter().any(|(k, _)| k.to_lowercase() == "content-type") { request = request.header("Content-Type", "application/json"); } if let Some(ref b) = body { request = request.body(b.clone()); } let response = request .send() .await .map_err(|e| ToolError::ExecutionError(format!("HTTP request failed: {}", e)))?; let status = response.status().as_u16(); // Handle redirects manually with SSRF validation if status >= 300 && status < 400 { redirect_count += 1; if redirect_count > MAX_REDIRECTS { return Err(ToolError::ExecutionError( format!("Too many redirects (max {})", MAX_REDIRECTS), )); } let location = response.headers() .get("location") .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()); let location = match location { Some(l) => l, None => return Err(ToolError::ExecutionError("Redirect with no Location header".into())), }; // Resolve relative redirect against current URL let base = reqwest::Url::parse(¤t_url) .map_err(|e| ToolError::ExecutionError(format!("Invalid current URL: {}", e)))?; let next_url = base.join(&location) .map_err(|e| ToolError::ExecutionError(format!("Invalid redirect URL: {}", e)))?; // Validate redirect target against SSRF rules if let Some(host) = next_url.host_str() { if is_private_host(host) { return Err(ToolError::ExecutionError( "Redirect to internal/private IP is not allowed".into(), )); } } current_url = next_url.to_string(); continue; } let status_text = response.status().canonical_reason().unwrap_or(""); let response_headers: std::collections::HashMap = response .headers() .iter() .map(|(k, v)| { ( k.to_string(), v.to_str().unwrap_or("").to_string(), ) }) .collect(); let content_type = response .headers() .get("content-type") 
.and_then(|v| v.to_str().ok()) .unwrap_or("") .to_string(); let is_text = content_type.starts_with("text/") || content_type.contains("json") || content_type.contains("xml") || content_type.contains("javascript"); let body_bytes = response .bytes() .await .map_err(|e| ToolError::ExecutionError(format!("Failed to read response body: {}", e)))?; let body_len = body_bytes.len(); let truncated = body_len > MAX_BODY_BYTES; let body_text = if truncated { String::from("[Response truncated — exceeds 1 MB limit]") } else if is_text { String::from_utf8_lossy(&body_bytes).to_string() } else { format!( "[Binary body, {} bytes, Content-Type: {}]", body_len, content_type ) }; return Ok(serde_json::json!({ "url": current_url, "method": method, "status": status, "status_text": status_text, "headers": response_headers, "body": body_text, "truncated": truncated, "size_bytes": body_len, })); } } // ─── tool definition ───────────────────────────────────────────────────────── pub fn tool_definition() -> ToolDefinition { let mut p = HashMap::new(); p.insert("url".into(), ToolParam { name: "url".into(), param_type: "string".into(), description: Some("Full URL to request (required).".into()), required: true, properties: None, items: None, }); p.insert("method".into(), ToolParam { name: "method".into(), param_type: "string".into(), description: Some("HTTP method: GET (default), POST, PUT, DELETE, PATCH, HEAD.".into()), required: false, properties: None, items: None, }); p.insert("body".into(), ToolParam { name: "body".into(), param_type: "string".into(), description: Some("Request body. Defaults to 'application/json' Content-Type if provided. Optional.".into()), required: false, properties: None, items: None, }); p.insert("headers".into(), ToolParam { name: "headers".into(), param_type: "object".into(), description: Some("HTTP headers as key-value pairs. 
Optional.".into()), required: false, properties: None, items: None, }); p.insert("timeout".into(), ToolParam { name: "timeout".into(), param_type: "integer".into(), description: Some("Request timeout in seconds (default 30, max 120). Optional.".into()), required: false, properties: None, items: None, }); ToolDefinition::new("project_curl") .description( "Perform an HTTP request to any URL. Supports GET, POST, PUT, DELETE, PATCH, HEAD. \ Returns status code, headers, and response body. \ Response body is truncated at 1 MB. Binary responses are described as text metadata. \ Useful for fetching web pages, calling APIs, or downloading resources.", ) .parameters(ToolSchema { schema_type: "object".into(), properties: Some(p), required: Some(vec!["url".into()]), }) }