New git subcommands: lfs (summary/scan_tree), merge_analysis, ref_list/ref_info, and git_status. New project tool: bing_search. Update repo_analysis with expanded field coverage and curl tool.
356 lines
12 KiB
Rust
356 lines
12 KiB
Rust
//! Tool: project_curl — perform HTTP requests (GET/POST/PUT/DELETE/PATCH/HEAD)
|
|
//!
|
|
//! Security measures:
|
|
//! - SSRF protection: rejects URLs and redirect targets whose host is a private/internal IP literal or a known internal hostname (hostnames are not DNS-resolved)
|
|
//! - Sensitive header injection: blocks Host, Authorization, Cookie, Proxy-*
|
|
//! - Connection pooling via a shared reqwest::Client
|
|
|
|
use agent::{ToolContext, ToolDefinition, ToolError, ToolParam, ToolSchema};
|
|
use std::collections::HashMap;
|
|
use std::sync::OnceLock;
|
|
|
|
/// Upper bound on how much of a response body is returned: 1 MiB (1,048,576 bytes).
const MAX_BODY_BYTES: usize = 1024 * 1024;

/// Lower-case header names that callers are not allowed to supply, because
/// they could spoof the target host or smuggle credentials / proxy directives.
const BLOCKED_HEADERS: &[&str] = &[
    "host",
    "authorization",
    "cookie",
    // Proxy-related headers.
    "proxy-authorization",
    "proxy-connection",
    "proxy-authenticate",
];
|
|
|
|
/// Shared reqwest::Client for connection pooling.
|
|
static SHARED_CLIENT: OnceLock<reqwest::Client> = OnceLock::new();
|
|
|
|
fn shared_client() -> &'static reqwest::Client {
|
|
SHARED_CLIENT.get_or_init(|| {
|
|
reqwest::Client::builder()
|
|
.connect_timeout(std::time::Duration::from_secs(10))
|
|
.timeout(std::time::Duration::from_secs(120))
|
|
// Block automatic redirect following so we can validate each hop
|
|
.redirect(reqwest::redirect::Policy::limited(0))
|
|
.build()
|
|
.expect("reqwest client build should not fail")
|
|
})
|
|
}
|
|
|
|
/// Return `true` if `host` names a loopback, private, link-local or otherwise
/// internal destination that outbound requests must not reach.
///
/// `host` is expected in the form produced by `Url::host_str()`:
/// - IPv6 literals arrive wrapped in brackets (`[::1]`).
/// - A trailing root-label dot (`localhost.`) is tolerated.
///
/// IP literals are classified numerically, so every address in a blocked range
/// is caught (e.g. `127.0.0.2`, `[::ffff:10.0.0.1]`), not just a handful of
/// hard-coded strings. Ordinary domain names such as `10.example.com` are not
/// mistaken for IP addresses.
///
/// Limitation: this is a purely syntactic check. It does NOT resolve DNS, so a
/// public hostname whose records point at an internal address is not caught.
fn is_private_host(host: &str) -> bool {
    // `Url::host_str()` brackets IPv6 literals; strip them so the address parses.
    let unbracketed = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);
    // A fully-qualified name may end with the root label's dot.
    let bare = unbracketed.strip_suffix('.').unwrap_or(unbracketed);

    if let Ok(ip) = bare.parse::<std::net::IpAddr>() {
        return match ip {
            std::net::IpAddr::V4(v4) => is_internal_ipv4(v4),
            std::net::IpAddr::V6(v6) => is_internal_ipv6(v6),
        };
    }

    let name = bare.to_ascii_lowercase();
    name == "localhost" || name.ends_with(".localhost") || name == "metadata.google.internal"
}

/// Return `true` for IPv4 addresses in ranges that must never be requested.
fn is_internal_ipv4(ip: std::net::Ipv4Addr) -> bool {
    let [first, second, ..] = ip.octets();
    ip.is_loopback() // 127.0.0.0/8
        || ip.is_private() // 10/8, 172.16/12, 192.168/16
        || ip.is_link_local() // 169.254/16, incl. 169.254.169.254
        || ip.is_broadcast() // 255.255.255.255
        || first == 0 // 0.0.0.0/8 ("this network", incl. 0.0.0.0)
        || (first == 100 && (second & 0xC0) == 64) // 100.64.0.0/10 shared address space
}

/// Return `true` for IPv6 addresses in ranges that must never be requested,
/// including IPv4 addresses embedded as `::a.b.c.d` or `::ffff:a.b.c.d`.
fn is_internal_ipv6(ip: std::net::Ipv6Addr) -> bool {
    // Embedded IPv4 is judged by the IPv4 rules. This also covers `::1` and
    // `::`, which `to_ipv4` maps to 0.0.0.1 / 0.0.0.0.
    if let Some(v4) = ip.to_ipv4() {
        return is_internal_ipv4(v4);
    }
    let first = ip.segments()[0];
    (first & 0xfe00) == 0xfc00 // fc00::/7 unique local
        || (first & 0xffc0) == 0xfe80 // fe80::/10 link-local
        || (first & 0xffc0) == 0xfec0 // fec0::/10 deprecated site-local
}
|
|
|
|
/// Validate URL and any redirect hops against SSRF rules.
|
|
fn validate_url_against_ssrf(url_str: &str) -> Result<reqwest::Url, ToolError> {
|
|
let parsed = reqwest::Url::parse(url_str)
|
|
.map_err(|e| ToolError::ExecutionError(format!("Invalid URL: {}", e)))?;
|
|
if let Some(host) = parsed.host_str() {
|
|
if is_private_host(host) {
|
|
return Err(ToolError::ExecutionError(
|
|
"Requests to internal/private IPs are not allowed for security reasons".into(),
|
|
));
|
|
}
|
|
}
|
|
Ok(parsed)
|
|
}
|
|
|
|
/// Perform an HTTP request and return the response body and metadata.
|
|
/// Supports GET, POST, PUT, DELETE methods. Useful for fetching web pages,
|
|
/// calling external APIs, or downloading resources.
|
|
pub async fn curl_exec(
|
|
_ctx: ToolContext,
|
|
args: serde_json::Value,
|
|
) -> Result<serde_json::Value, ToolError> {
|
|
let url_str = args
|
|
.get("url")
|
|
.and_then(|v| v.as_str())
|
|
.ok_or_else(|| ToolError::ExecutionError("url is required".into()))?;
|
|
|
|
// SSRF protection: validate initial URL
|
|
validate_url_against_ssrf(url_str)?;
|
|
|
|
let method = args
|
|
.get("method")
|
|
.and_then(|v| v.as_str())
|
|
.unwrap_or("GET")
|
|
.to_uppercase();
|
|
|
|
let body = args.get("body").and_then(|v| v.as_str()).map(String::from);
|
|
|
|
let headers: Vec<(String, String)> = args
|
|
.get("headers")
|
|
.and_then(|v| v.as_object())
|
|
.map(|obj| {
|
|
obj.iter()
|
|
.filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
|
|
.collect()
|
|
})
|
|
.unwrap_or_default();
|
|
|
|
// Block sensitive headers that could be used for injection attacks
|
|
for (key, _) in &headers {
|
|
if BLOCKED_HEADERS.contains(&key.to_lowercase().as_str()) {
|
|
return Err(ToolError::ExecutionError(format!(
|
|
"Header '{}' is not allowed for security reasons",
|
|
key
|
|
)));
|
|
}
|
|
}
|
|
|
|
let timeout_secs = args
|
|
.get("timeout")
|
|
.and_then(|v| v.as_u64())
|
|
.unwrap_or(30)
|
|
.min(120);
|
|
|
|
let client = shared_client();
|
|
// Build a per-request client with the specific timeout by using the shared
|
|
// client's connection pool but overriding timeout per request via request builder.
|
|
// Since reqwest::Client::builder().redirect(Policy::limited(0)) disables auto-redirects,
|
|
// we manually follow up to 5 redirects with SSRF validation on each hop.
|
|
|
|
let mut current_url = url_str.to_string();
|
|
let mut redirect_count = 0u32;
|
|
const MAX_REDIRECTS: u32 = 5;
|
|
|
|
loop {
|
|
let mut request = match method.as_str() {
|
|
"GET" => client.get(¤t_url),
|
|
"POST" => client.post(¤t_url),
|
|
"PUT" => client.put(¤t_url),
|
|
"DELETE" => client.delete(¤t_url),
|
|
"PATCH" => client.patch(¤t_url),
|
|
"HEAD" => client.head(¤t_url),
|
|
_ => {
|
|
return Err(ToolError::ExecutionError(format!(
|
|
"Unsupported HTTP method: {}. Use GET, POST, PUT, DELETE, PATCH, or HEAD.",
|
|
method
|
|
)));
|
|
}
|
|
};
|
|
|
|
request = request.timeout(std::time::Duration::from_secs(timeout_secs));
|
|
|
|
for (key, value) in &headers {
|
|
request = request.header(key, value);
|
|
}
|
|
|
|
// Set default Content-Type for POST/PUT/PATCH if not provided and body exists
|
|
if body.is_some()
|
|
&& !headers
|
|
.iter()
|
|
.any(|(k, _)| k.to_lowercase() == "content-type")
|
|
{
|
|
request = request.header("Content-Type", "application/json");
|
|
}
|
|
|
|
if let Some(ref b) = body {
|
|
request = request.body(b.clone());
|
|
}
|
|
|
|
let response = request
|
|
.send()
|
|
.await
|
|
.map_err(|e| ToolError::ExecutionError(format!("HTTP request failed: {}", e)))?;
|
|
|
|
let status = response.status().as_u16();
|
|
|
|
// Handle redirects manually with SSRF validation
|
|
if status >= 300 && status < 400 {
|
|
redirect_count += 1;
|
|
if redirect_count > MAX_REDIRECTS {
|
|
return Err(ToolError::ExecutionError(format!(
|
|
"Too many redirects (max {})",
|
|
MAX_REDIRECTS
|
|
)));
|
|
}
|
|
let location = response
|
|
.headers()
|
|
.get("location")
|
|
.and_then(|v| v.to_str().ok())
|
|
.map(|s| s.to_string());
|
|
let location = match location {
|
|
Some(l) => l,
|
|
None => {
|
|
return Err(ToolError::ExecutionError(
|
|
"Redirect with no Location header".into(),
|
|
));
|
|
}
|
|
};
|
|
// Resolve relative redirect against current URL
|
|
let base = reqwest::Url::parse(¤t_url)
|
|
.map_err(|e| ToolError::ExecutionError(format!("Invalid current URL: {}", e)))?;
|
|
let next_url = base
|
|
.join(&location)
|
|
.map_err(|e| ToolError::ExecutionError(format!("Invalid redirect URL: {}", e)))?;
|
|
// Validate redirect target against SSRF rules
|
|
if let Some(host) = next_url.host_str() {
|
|
if is_private_host(host) {
|
|
return Err(ToolError::ExecutionError(
|
|
"Redirect to internal/private IP is not allowed".into(),
|
|
));
|
|
}
|
|
}
|
|
current_url = next_url.to_string();
|
|
continue;
|
|
}
|
|
|
|
let status_text = response.status().canonical_reason().unwrap_or("");
|
|
|
|
let response_headers: std::collections::HashMap<String, String> = response
|
|
.headers()
|
|
.iter()
|
|
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("<binary>").to_string()))
|
|
.collect();
|
|
|
|
let content_type = response
|
|
.headers()
|
|
.get("content-type")
|
|
.and_then(|v| v.to_str().ok())
|
|
.unwrap_or("")
|
|
.to_string();
|
|
|
|
let is_text = content_type.starts_with("text/")
|
|
|| content_type.contains("json")
|
|
|| content_type.contains("xml")
|
|
|| content_type.contains("javascript");
|
|
|
|
let body_bytes = response.bytes().await.map_err(|e| {
|
|
ToolError::ExecutionError(format!("Failed to read response body: {}", e))
|
|
})?;
|
|
|
|
let body_len = body_bytes.len();
|
|
let truncated = body_len > MAX_BODY_BYTES;
|
|
let body_text = if truncated {
|
|
String::from("[Response truncated — exceeds 1 MB limit]")
|
|
} else if is_text {
|
|
String::from_utf8_lossy(&body_bytes).to_string()
|
|
} else {
|
|
format!(
|
|
"[Binary body, {} bytes, Content-Type: {}]",
|
|
body_len, content_type
|
|
)
|
|
};
|
|
|
|
return Ok(serde_json::json!({
|
|
"url": current_url,
|
|
"method": method,
|
|
"status": status,
|
|
"status_text": status_text,
|
|
"headers": response_headers,
|
|
"body": body_text,
|
|
"truncated": truncated,
|
|
"size_bytes": body_len,
|
|
}));
|
|
}
|
|
}
|
|
|
|
// ─── tool definition ─────────────────────────────────────────────────────────
|
|
|
|
fn tool_definition_with_name(name: &str) -> ToolDefinition {
|
|
let mut p = HashMap::new();
|
|
p.insert(
|
|
"url".into(),
|
|
ToolParam {
|
|
name: "url".into(),
|
|
param_type: "string".into(),
|
|
description: Some("Full URL to request (required).".into()),
|
|
required: true,
|
|
properties: None,
|
|
items: None,
|
|
},
|
|
);
|
|
p.insert(
|
|
"method".into(),
|
|
ToolParam {
|
|
name: "method".into(),
|
|
param_type: "string".into(),
|
|
description: Some("HTTP method: GET (default), POST, PUT, DELETE, PATCH, HEAD.".into()),
|
|
required: false,
|
|
properties: None,
|
|
items: None,
|
|
},
|
|
);
|
|
p.insert(
|
|
"body".into(),
|
|
ToolParam {
|
|
name: "body".into(),
|
|
param_type: "string".into(),
|
|
description: Some(
|
|
"Request body. Defaults to 'application/json' Content-Type if provided. Optional."
|
|
.into(),
|
|
),
|
|
required: false,
|
|
properties: None,
|
|
items: None,
|
|
},
|
|
);
|
|
p.insert(
|
|
"headers".into(),
|
|
ToolParam {
|
|
name: "headers".into(),
|
|
param_type: "object".into(),
|
|
description: Some("HTTP headers as key-value pairs. Optional.".into()),
|
|
required: false,
|
|
properties: None,
|
|
items: None,
|
|
},
|
|
);
|
|
p.insert(
|
|
"timeout".into(),
|
|
ToolParam {
|
|
name: "timeout".into(),
|
|
param_type: "integer".into(),
|
|
description: Some("Request timeout in seconds (default 30, max 120). Optional.".into()),
|
|
required: false,
|
|
properties: None,
|
|
items: None,
|
|
},
|
|
);
|
|
ToolDefinition::new(name)
|
|
.description(
|
|
"Perform an HTTP request to any URL. Supports GET, POST, PUT, DELETE, PATCH, HEAD. \
|
|
Returns status code, headers, and response body. \
|
|
Response body is truncated at 1 MB. Binary responses are described as text metadata. \
|
|
Useful for fetching web pages, calling APIs, or downloading resources.",
|
|
)
|
|
.parameters(ToolSchema {
|
|
schema_type: "object".into(),
|
|
properties: Some(p),
|
|
required: Some(vec!["url".into()]),
|
|
})
|
|
}
|
|
|
|
pub fn tool_definition() -> ToolDefinition {
|
|
tool_definition_with_name("project_curl")
|
|
}
|
|
|
|
pub fn alias_tool_definition() -> ToolDefinition {
|
|
tool_definition_with_name("curl_exec")
|
|
}
|