228 lines
7.4 KiB
Rust
228 lines
7.4 KiB
Rust
//! Tool: project_arxiv_search — search arXiv papers by query
|
|
|
|
use agent::{ToolContext, ToolDefinition, ToolError, ToolParam, ToolSchema};
|
|
use serde::Deserialize;
|
|
use std::collections::HashMap;
|
|
|
|
/// Result count used when the caller does not supply `max_results`.
const DEFAULT_MAX_RESULTS: usize = 10;

/// Ceiling applied to any caller-supplied `max_results`.
const MAX_MAX_RESULTS: usize = 50;

/// Base URL of the arXiv query API; responses are Atom feeds.
const ARXIV_API: &str = "http://export.arxiv.org/api/query";
|
|
|
|
/// arXiv Atom feed entry fields we care about.
|
|
#[derive(Debug, Deserialize)]
|
|
struct ArxivEntry {
|
|
#[serde(rename = "id")]
|
|
entry_id: String,
|
|
#[serde(rename = "title")]
|
|
title: String,
|
|
#[serde(rename = "summary")]
|
|
summary: String,
|
|
#[serde(default, rename = "author")]
|
|
author: Vec<ArxivAuthor>,
|
|
#[serde(rename = "published")]
|
|
published: String,
|
|
#[serde(default, rename = "link")]
|
|
link: Vec<ArxivLink>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ArxivAuthor {
|
|
#[serde(rename = "name")]
|
|
name: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ArxivLink {
|
|
#[serde(rename = "type", default)]
|
|
link_type: String,
|
|
#[serde(rename = "href", default)]
|
|
href: String,
|
|
#[serde(rename = "title", default)]
|
|
title: String,
|
|
#[serde(rename = "rel", default)]
|
|
rel: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ArxivFeed {
|
|
#[serde(default, rename = "entry")]
|
|
entry: Vec<ArxivEntry>,
|
|
}
|
|
|
|
/// Search arXiv papers by query string.
|
|
///
|
|
/// Returns up to `max_results` papers (default 10, max 50) matching the query.
|
|
/// Each result includes arXiv ID, title, authors, abstract, published date, and PDF URL.
|
|
pub async fn arxiv_search_exec(
|
|
_ctx: ToolContext,
|
|
args: serde_json::Value,
|
|
) -> Result<serde_json::Value, ToolError> {
|
|
let query = args
|
|
.get("query")
|
|
.and_then(|v| v.as_str())
|
|
.ok_or_else(|| ToolError::ExecutionError("query is required".into()))?;
|
|
|
|
let max_results = args
|
|
.get("max_results")
|
|
.and_then(|v| v.as_u64())
|
|
.unwrap_or(DEFAULT_MAX_RESULTS as u64)
|
|
.min(MAX_MAX_RESULTS as u64) as usize;
|
|
|
|
let start = args
|
|
.get("start")
|
|
.and_then(|v| v.as_u64())
|
|
.unwrap_or(0) as usize;
|
|
|
|
// Build arXiv API query URL
|
|
// Encode query for URL
|
|
let encoded_query = urlencoding_encode(query);
|
|
let url = format!(
|
|
"{}?search_query=all:{}&start={}&max_results={}&sortBy=relevance&sortOrder=descending",
|
|
ARXIV_API, encoded_query, start, max_results
|
|
);
|
|
|
|
let response = reqwest::get(&url)
|
|
.await
|
|
.map_err(|e| ToolError::ExecutionError(format!("HTTP request failed: {}", e)))?;
|
|
|
|
if !response.status().is_success() {
|
|
return Err(ToolError::ExecutionError(format!(
|
|
"arXiv API returned status {}",
|
|
response.status()
|
|
)));
|
|
}
|
|
|
|
let body = response
|
|
.text()
|
|
.await
|
|
.map_err(|e| ToolError::ExecutionError(format!("Failed to read response: {}", e)))?;
|
|
|
|
let feed: ArxivFeed = quick_xml::de::from_str(&body)
|
|
.map_err(|e| ToolError::ExecutionError(format!("Failed to parse Atom feed: {}", e)))?;
|
|
|
|
let results: Vec<serde_json::Value> = feed
|
|
.entry
|
|
.into_iter()
|
|
.map(|entry| {
|
|
// Extract PDF link
|
|
let pdf_url = entry
|
|
.link
|
|
.iter()
|
|
.find(|l| l.link_type == "application/pdf")
|
|
.map(|l| l.href.clone())
|
|
.or_else(|| {
|
|
entry
|
|
.link
|
|
.iter()
|
|
.find(|l| l.rel == "alternate" && l.link_type.is_empty())
|
|
.map(|l| l.href.replace("/abs/", "/pdf/"))
|
|
})
|
|
.unwrap_or_default();
|
|
|
|
// Extract arXiv ID from entry id URL
|
|
// e.g. http://arxiv.org/abs/2312.12345v1 -> 2312.12345v1
|
|
let arxiv_id = entry
|
|
.entry_id
|
|
.rsplit('/')
|
|
.next()
|
|
.unwrap_or(&entry.entry_id)
|
|
.trim();
|
|
|
|
// Whitespace-normalize title and abstract
|
|
let title = normalize_whitespace(&entry.title);
|
|
let summary = normalize_whitespace(&entry.summary);
|
|
let author_str = if entry.author.is_empty() {
|
|
"Unknown".to_string()
|
|
} else {
|
|
entry.author.iter().map(|a| a.name.as_str()).collect::<Vec<_>>().join(", ")
|
|
};
|
|
|
|
serde_json::json!({
|
|
"arxiv_id": arxiv_id,
|
|
"title": title,
|
|
"authors": author_str,
|
|
"abstract": summary,
|
|
"published": entry.published,
|
|
"pdf_url": pdf_url,
|
|
"abs_url": entry.entry_id,
|
|
})
|
|
})
|
|
.collect();
|
|
|
|
Ok(serde_json::json!({
|
|
"count": results.len(),
|
|
"query": query,
|
|
"results": results,
|
|
}))
|
|
}
|
|
|
|
// ─── helpers ───────────────────────────────────────────────────────────────────
|
|
|
|
/// Percent-encode `s` for use inside a URL query component.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` are copied through unchanged; every
/// other byte of the UTF-8 representation is written as `%XX` with uppercase hex digits.
fn urlencoding_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    // Worst case: every input byte expands to three output characters.
    let mut encoded = String::with_capacity(s.len() * 3);
    for b in s.bytes() {
        match b {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
                encoded.push(char::from(b));
            }
            _ => {
                // Table lookup instead of `format!`, which would allocate a
                // temporary `String` for every escaped byte.
                encoded.push('%');
                encoded.push(char::from(HEX[usize::from(b >> 4)]));
                encoded.push(char::from(HEX[usize::from(b & 0x0F)]));
            }
        }
    }
    encoded
}
|
|
|
|
/// Collapse every run of whitespace in `s` into a single ASCII space and strip
/// leading and trailing whitespace.
///
/// Returns an empty string when `s` is empty or contains only whitespace.
fn normalize_whitespace(s: &str) -> String {
    // `split_whitespace` skips leading/trailing whitespace and treats any run of
    // Unicode whitespace as one separator, so joining the pieces with a single
    // space yields the normalized text.
    s.split_whitespace().collect::<Vec<_>>().join(" ")
}
|
|
|
|
// ─── tool definition ─────────────────────────────────────────────────────────
|
|
|
|
pub fn tool_definition() -> ToolDefinition {
|
|
let mut p = HashMap::new();
|
|
p.insert("query".into(), ToolParam {
|
|
name: "query".into(), param_type: "string".into(),
|
|
description: Some("Search query (required). Supports arXiv search syntax, e.g. 'ti:transformer AND au:bengio'.".into()),
|
|
required: true, properties: None, items: None,
|
|
});
|
|
p.insert("max_results".into(), ToolParam {
|
|
name: "max_results".into(), param_type: "integer".into(),
|
|
description: Some("Maximum number of results to return (default 10, max 50). Optional.".into()),
|
|
required: false, properties: None, items: None,
|
|
});
|
|
p.insert("start".into(), ToolParam {
|
|
name: "start".into(), param_type: "integer".into(),
|
|
description: Some("Offset for pagination. Optional.".into()),
|
|
required: false, properties: None, items: None,
|
|
});
|
|
ToolDefinition::new("project_arxiv_search")
|
|
.description(
|
|
"Search arXiv papers by keyword or phrase. \
|
|
Returns paper titles, authors, abstracts, arXiv IDs, and PDF URLs. \
|
|
Useful for finding academic papers relevant to the project or a task.",
|
|
)
|
|
.parameters(ToolSchema {
|
|
schema_type: "object".into(),
|
|
properties: Some(p),
|
|
required: Some(vec!["query".into()]),
|
|
})
|
|
}
|