gitdataai/libs/agent-tool-derive/src/lib.rs
2026-04-14 19:02:01 +08:00

374 lines
13 KiB
Rust

//! Procedural macro for generating tool definitions from functions.
//!
//! # Example
//!
//! ```
//! use agent_tool_derive::tool;
//!
//! #[tool(description = "Search issues by title")]
//! fn search_issues(
//! title: String,
//! status: Option<String>,
//! ) -> Result<Vec<serde_json::Value>, String> {
//! Ok(vec![])
//! }
//! ```
//!
//! Generates:
//! - A `SearchIssuesParameters` struct (serde Deserialize)
//! - A `SEARCH_ISSUES_DEFINITION: ToolDefinition` constant
//! - A `register_search_issues(registry: &mut ToolRegistry)` helper
extern crate proc_macro;
use convert_case::{Case, Casing};
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use std::collections::HashMap;
use syn::punctuated::Punctuated;
use syn::{
Expr, ExprLit, Ident, Lit, Meta, ReturnType, Token, Type,
parse::{Parse, ParseStream},
};
/// Parsed arguments of the `#[tool(...)]` attribute:
/// `description = "...", params(...), required(...)`.
struct ToolArgs {
    /// Value of `description = "..."`, or `None` when the key is absent.
    description: Option<String>,
    /// Entries of `params(name = "text", ...)`, keyed by the name on the left-hand side.
    param_descriptions: HashMap<String, String>,
    /// Identifiers listed in `required(...)`, in the order they were written.
    required: Vec<String>,
}
impl Parse for ToolArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
Self::parse_from(input)
}
}
impl ToolArgs {
fn new() -> Self {
Self {
description: None,
param_descriptions: HashMap::new(),
required: Vec::new(),
}
}
fn parse_from(input: ParseStream) -> syn::Result<Self> {
let mut this = Self::new();
if input.is_empty() {
return Ok(this);
}
let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
for meta in meta_list {
match meta {
Meta::NameValue(nv) => {
let ident = nv
.path
.get_ident()
.ok_or_else(|| syn::Error::new_spanned(&nv.path, "expected identifier"))?;
if ident == "description" {
if let Expr::Lit(ExprLit {
lit: Lit::Str(s), ..
}) = nv.value
{
this.description = Some(s.value());
} else {
return Err(syn::Error::new_spanned(
&nv.value,
"description must be a string literal",
));
}
}
}
Meta::List(list) if list.path.is_ident("params") => {
let inner: Punctuated<Meta, Token![,]> =
list.parse_args_with(Punctuated::parse_terminated)?;
for item in inner {
if let Meta::NameValue(nv) = item {
let param_name = nv
.path
.get_ident()
.ok_or_else(|| {
syn::Error::new_spanned(&nv.path, "expected identifier")
})?
.to_string();
if let Expr::Lit(ExprLit {
lit: Lit::Str(s), ..
}) = nv.value
{
this.param_descriptions.insert(param_name, s.value());
}
}
}
}
Meta::List(list) if list.path.is_ident("required") => {
let required_vars: Punctuated<Ident, Token![,]> =
list.parse_args_with(Punctuated::parse_terminated)?;
for var in required_vars {
this.required.push(var.to_string());
}
}
_ => {}
}
}
Ok(this)
}
}
/// Map a Rust type to its JSON Schema type name.
fn json_type(ty: &Type) -> proc_macro2::TokenStream {
use syn::Type as T;
let segs = match ty {
T::Path(p) => &p.path.segments,
_ => return quote! { "type": "object" },
};
let last = segs.last().map(|s| &s.ident);
let args = segs.last().and_then(|s| {
if let syn::PathArguments::AngleBracketed(a) = &s.arguments {
Some(&a.args)
} else {
None
}
});
match (last.map(|i| i.to_string()).as_deref(), args) {
(Some("Vec" | "vec::Vec"), Some(args)) if !args.is_empty() => {
if let syn::GenericArgument::Type(inner) = &args[0] {
let inner_type = json_type(inner);
return quote! {
{
"type": "array",
"items": { #inner_type }
}
};
}
quote! { "type": "array" }
}
(Some("String" | "str" | "char"), _) => quote! { "type": "string" },
(Some("bool"), _) => quote! { "type": "boolean" },
(Some("i8" | "i16" | "i32" | "i64" | "isize"), _) => quote! { "type": "integer" },
(Some("u8" | "u16" | "u32" | "u64" | "usize"), _) => quote! { "type": "integer" },
(Some("f32" | "f64"), _) => quote! { "type": "number" },
_ => quote! { "type": "object" },
}
}
/// Extract return type info from `-> Result<T, E>`.
fn parse_return_type(
ret: &ReturnType,
) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> {
match ret {
ReturnType::Type(_, ty) => {
let ty = &**ty;
if let Type::Path(p) = ty {
let last = p
.path
.segments
.last()
.ok_or_else(|| syn::Error::new_spanned(&p.path, "invalid return type"))?;
if last.ident == "Result" {
if let syn::PathArguments::AngleBracketed(a) = &last.arguments {
let args = &a.args;
if args.len() == 2 {
let ok = &args[0];
let err = &args[1];
return Ok((quote!(#ok), quote!(#err)));
}
}
return Err(syn::Error::new_spanned(
&last,
"Result must have 2 type parameters",
));
}
}
Err(syn::Error::new_spanned(
ty,
"function must return Result<T, E>",
))
}
_ => Err(syn::Error::new_spanned(
ret,
"function must have a return type",
)),
}
}
/// The `#[tool]` attribute macro.
///
/// Usage:
/// ```
/// #[tool(description = "Tool description", params(
/// arg1 = "Description of arg1",
/// arg2 = "Description of arg2",
/// ))]
/// async fn my_tool(arg1: String, arg2: Option<i32>) -> Result<serde_json::Value, String> {
/// Ok(serde_json::json!({}))
/// }
/// ```
///
/// Generates:
/// - `MyToolParameters` struct with serde Deserialize
/// - `MY_TOOL_DEFINITION: ToolDefinition` constant
/// - `register_my_tool(registry: &mut ToolRegistry)` helper function
#[proc_macro_attribute]
pub fn tool(args: TokenStream, input: TokenStream) -> TokenStream {
let args = syn::parse_macro_input!(args as ToolArgs);
let input_fn = syn::parse_macro_input!(input as syn::ItemFn);
let fn_name = &input_fn.sig.ident;
let fn_name_str = fn_name.to_string();
let vis = &input_fn.vis;
let is_async = input_fn.sig.asyncness.is_some();
// Parse return type: Result<T, E>
let (_output_type, _error_type) = match parse_return_type(&input_fn.sig.output) {
Ok(t) => t,
Err(e) => return e.into_compile_error().into(),
};
// PascalCase struct name
let struct_name = format_ident!("{}", fn_name_str.to_case(Case::Pascal));
let params_struct_name = format_ident!("{}Parameters", struct_name);
let definition_const_name = format_ident!("{}_DEFINITION", fn_name_str.to_uppercase());
let register_fn_name = format_ident!("register_{}", fn_name_str);
// Extract parameters from function signature
let mut param_names: Vec<Ident> = Vec::new();
let mut param_types: Vec<Type> = Vec::new();
let mut param_json_types: Vec<proc_macro2::TokenStream> = Vec::new();
let mut param_descs: Vec<proc_macro2::TokenStream> = Vec::new();
let required_args = args.required.clone();
for arg in &input_fn.sig.inputs {
let syn::FnArg::Typed(pat_type) = arg else {
continue;
};
let syn::Pat::Ident(pat_ident) = &*pat_type.pat else {
continue;
};
let name = &pat_ident.ident;
let ty = &*pat_type.ty;
let name_str = name.to_string();
let desc = args
.param_descriptions
.get(&name_str)
.map(|s| quote! { #s.to_string() })
.unwrap_or_else(|| quote! { format!("Parameter {}", #name_str) });
param_names.push(format_ident!("{}", name.to_string()));
param_types.push(ty.clone());
param_json_types.push(json_type(ty));
param_descs.push(desc);
}
// Which params are required (not Option<T>)
let required: Vec<proc_macro2::TokenStream> = if required_args.is_empty() {
param_names
.iter()
.filter(|name| {
let name_str = name.to_string();
!args
.param_descriptions
.contains_key(&format!("{}_opt", name_str))
})
.map(|name| quote! { stringify!(#name) })
.collect()
} else {
required_args.iter().map(|s| quote! { #s }).collect()
};
// Tool description
let tool_description = args
.description
.map(|s| quote! { #s.to_string() })
.unwrap_or_else(|| quote! { format!("Function {}", #fn_name_str) });
// Call invocation (async vs sync)
let call_args = param_names.iter().map(|n| quote! { args.#n });
let fn_call = if is_async {
quote! { #fn_name(#(#call_args),*).await }
} else {
quote! { #fn_name(#(#call_args),*) }
};
let expanded = quote! {
// Parameters struct: deserialized from JSON args by serde
#[derive(serde::Deserialize)]
#vis struct #params_struct_name {
#(#vis #param_names: #param_types,)*
}
// Keep the original function unchanged
#input_fn
// Static ToolDefinition constant — register this with ToolRegistry
#vis const #definition_const_name: agent::ToolDefinition = agent::ToolDefinition {
name: #fn_name_str.to_string(),
description: Some(#tool_description),
parameters: Some(agent::ToolSchema {
schema_type: "object".to_string(),
properties: Some({
let mut map = std::collections::HashMap::new();
#({
map.insert(stringify!(#param_names).to_string(), agent::ToolParam {
name: stringify!(#param_names).to_string(),
param_type: {
let jt = #param_json_types;
jt.get("type")
.and_then(|v| v.as_str())
.unwrap_or("object")
.to_string()
},
description: Some(#param_descs),
required: true,
properties: None,
items: None,
});
})*
map
}),
required: Some(vec![#(#required.to_string()),*]),
}),
strict: false,
};
/// Registers this tool in the given registry.
///
/// Generated by `#[tool]` macro for function `#fn_name_str`.
#vis fn #register_fn_name(registry: &mut agent::ToolRegistry) {
let def = #definition_const_name.clone();
let fn_name = #fn_name_str.to_string();
registry.register_fn(fn_name, move |_ctx, args| {
let args: #params_struct_name = match serde_json::from_value(args) {
Ok(a) => a,
Err(e) => {
return std::pin::Pin::new(Box::new(async move {
Err(agent::ToolError::ParseError(e.to_string()))
}))
}
};
std::pin::Pin::new(Box::new(async move {
let result = #fn_call;
match result {
Ok(v) => Ok(serde_json::to_value(v).unwrap_or(serde_json::Value::Null)),
Err(e) => Err(agent::ToolError::ExecutionError(e.to_string())),
}
}))
});
}
};
// We need to use boxed futures for the return type.
// Since we can't add runtime dependencies to the proc-macro crate,
// we emit the .boxed() call and the caller must ensure
// `use futures::FutureExt;` or equivalent is in scope.
// The generated code requires: `futures::FutureExt` (for .boxed()).
// Re-emit with futures dependency note
TokenStream::from(expanded)
}