//! Procedural macro for generating tool definitions from functions. //! //! # Example //! //! ``` //! use agent_tool_derive::tool; //! //! #[tool(description = "Search issues by title")] //! fn search_issues( //! title: String, //! status: Option, //! ) -> Result, String> { //! Ok(vec![]) //! } //! ``` //! //! Generates: //! - A `SearchIssuesParameters` struct (serde Deserialize) //! - A `SEARCH_ISSUES_DEFINITION: ToolDefinition` constant //! - A `register_search_issues(registry: &mut ToolRegistry)` helper extern crate proc_macro; use convert_case::{Case, Casing}; use proc_macro::TokenStream; use quote::{format_ident, quote}; use std::collections::HashMap; use syn::punctuated::Punctuated; use syn::{ Expr, ExprLit, Ident, Lit, Meta, ReturnType, Token, Type, parse::{Parse, ParseStream}, }; /// Parse the attribute arguments: `description = "...", params(...), required(...)` struct ToolArgs { description: Option, param_descriptions: HashMap, required: Vec, } impl Parse for ToolArgs { fn parse(input: ParseStream) -> syn::Result { Self::parse_from(input) } } impl ToolArgs { fn new() -> Self { Self { description: None, param_descriptions: HashMap::new(), required: Vec::new(), } } fn parse_from(input: ParseStream) -> syn::Result { let mut this = Self::new(); if input.is_empty() { return Ok(this); } let meta_list: Punctuated = Punctuated::parse_terminated(input)?; for meta in meta_list { match meta { Meta::NameValue(nv) => { let ident = nv .path .get_ident() .ok_or_else(|| syn::Error::new_spanned(&nv.path, "expected identifier"))?; if ident == "description" { if let Expr::Lit(ExprLit { lit: Lit::Str(s), .. }) = nv.value { this.description = Some(s.value()); } else { return Err(syn::Error::new_spanned( &nv.value, "description must be a string literal", )); } } } Meta::List(list) if list.path.is_ident("params") => { let inner: Punctuated = list.parse_args_with(Punctuated::parse_terminated)?; for item in inner { if let Meta::NameValue(nv) = item { let param_name = nv .path .get_ident() .ok_or_else(|| { syn::Error::new_spanned(&nv.path, "expected identifier") })? .to_string(); if let Expr::Lit(ExprLit { lit: Lit::Str(s), .. }) = nv.value { this.param_descriptions.insert(param_name, s.value()); } } } } Meta::List(list) if list.path.is_ident("required") => { let required_vars: Punctuated = list.parse_args_with(Punctuated::parse_terminated)?; for var in required_vars { this.required.push(var.to_string()); } } _ => {} } } Ok(this) } } /// Map a Rust type to its JSON Schema type name. fn json_type(ty: &Type) -> proc_macro2::TokenStream { use syn::Type as T; let segs = match ty { T::Path(p) => &p.path.segments, _ => return quote! { "type": "object" }, }; let last = segs.last().map(|s| &s.ident); let args = segs.last().and_then(|s| { if let syn::PathArguments::AngleBracketed(a) = &s.arguments { Some(&a.args) } else { None } }); match (last.map(|i| i.to_string()).as_deref(), args) { (Some("Vec" | "vec::Vec"), Some(args)) if !args.is_empty() => { if let syn::GenericArgument::Type(inner) = &args[0] { let inner_type = json_type(inner); return quote! { { "type": "array", "items": { #inner_type } } }; } quote! { "type": "array" } } (Some("String" | "str" | "char"), _) => quote! { "type": "string" }, (Some("bool"), _) => quote! { "type": "boolean" }, (Some("i8" | "i16" | "i32" | "i64" | "isize"), _) => quote! { "type": "integer" }, (Some("u8" | "u16" | "u32" | "u64" | "usize"), _) => quote! { "type": "integer" }, (Some("f32" | "f64"), _) => quote! { "type": "number" }, _ => quote! { "type": "object" }, } } /// Extract return type info from `-> Result`. fn parse_return_type( ret: &ReturnType, ) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> { match ret { ReturnType::Type(_, ty) => { let ty = &**ty; if let Type::Path(p) = ty { let last = p .path .segments .last() .ok_or_else(|| syn::Error::new_spanned(&p.path, "invalid return type"))?; if last.ident == "Result" { if let syn::PathArguments::AngleBracketed(a) = &last.arguments { let args = &a.args; if args.len() == 2 { let ok = &args[0]; let err = &args[1]; return Ok((quote!(#ok), quote!(#err))); } } return Err(syn::Error::new_spanned( &last, "Result must have 2 type parameters", )); } } Err(syn::Error::new_spanned( ty, "function must return Result", )) } _ => Err(syn::Error::new_spanned( ret, "function must have a return type", )), } } /// The `#[tool]` attribute macro. /// /// Usage: /// ``` /// #[tool(description = "Tool description", params( /// arg1 = "Description of arg1", /// arg2 = "Description of arg2", /// ))] /// async fn my_tool(arg1: String, arg2: Option) -> Result { /// Ok(serde_json::json!({})) /// } /// ``` /// /// Generates: /// - `MyToolParameters` struct with serde Deserialize /// - `MY_TOOL_DEFINITION: ToolDefinition` constant /// - `register_my_tool(registry: &mut ToolRegistry)` helper function #[proc_macro_attribute] pub fn tool(args: TokenStream, input: TokenStream) -> TokenStream { let args = syn::parse_macro_input!(args as ToolArgs); let input_fn = syn::parse_macro_input!(input as syn::ItemFn); let fn_name = &input_fn.sig.ident; let fn_name_str = fn_name.to_string(); let vis = &input_fn.vis; let is_async = input_fn.sig.asyncness.is_some(); // Parse return type: Result let (_output_type, _error_type) = match parse_return_type(&input_fn.sig.output) { Ok(t) => t, Err(e) => return e.into_compile_error().into(), }; // PascalCase struct name let struct_name = format_ident!("{}", fn_name_str.to_case(Case::Pascal)); let params_struct_name = format_ident!("{}Parameters", struct_name); let definition_const_name = format_ident!("{}_DEFINITION", fn_name_str.to_uppercase()); let register_fn_name = format_ident!("register_{}", fn_name_str); // Extract parameters from function signature let mut param_names: Vec = Vec::new(); let mut param_types: Vec = Vec::new(); let mut param_json_types: Vec = Vec::new(); let mut param_descs: Vec = Vec::new(); let required_args = args.required.clone(); for arg in &input_fn.sig.inputs { let syn::FnArg::Typed(pat_type) = arg else { continue; }; let syn::Pat::Ident(pat_ident) = &*pat_type.pat else { continue; }; let name = &pat_ident.ident; let ty = &*pat_type.ty; let name_str = name.to_string(); let desc = args .param_descriptions .get(&name_str) .map(|s| quote! { #s.to_string() }) .unwrap_or_else(|| quote! { format!("Parameter {}", #name_str) }); param_names.push(format_ident!("{}", name.to_string())); param_types.push(ty.clone()); param_json_types.push(json_type(ty)); param_descs.push(desc); } // Which params are required (not Option) let required: Vec = if required_args.is_empty() { param_names .iter() .filter(|name| { let name_str = name.to_string(); !args .param_descriptions .contains_key(&format!("{}_opt", name_str)) }) .map(|name| quote! { stringify!(#name) }) .collect() } else { required_args.iter().map(|s| quote! { #s }).collect() }; // Tool description let tool_description = args .description .map(|s| quote! { #s.to_string() }) .unwrap_or_else(|| quote! { format!("Function {}", #fn_name_str) }); // Call invocation (async vs sync) let call_args = param_names.iter().map(|n| quote! { args.#n }); let fn_call = if is_async { quote! { #fn_name(#(#call_args),*).await } } else { quote! { #fn_name(#(#call_args),*) } }; let expanded = quote! { // Parameters struct: deserialized from JSON args by serde #[derive(serde::Deserialize)] #vis struct #params_struct_name { #(#vis #param_names: #param_types,)* } // Keep the original function unchanged #input_fn // Static ToolDefinition constant — register this with ToolRegistry #vis const #definition_const_name: agent::ToolDefinition = agent::ToolDefinition { name: #fn_name_str.to_string(), description: Some(#tool_description), parameters: Some(agent::ToolSchema { schema_type: "object".to_string(), properties: Some({ let mut map = std::collections::HashMap::new(); #({ map.insert(stringify!(#param_names).to_string(), agent::ToolParam { name: stringify!(#param_names).to_string(), param_type: { let jt = #param_json_types; jt.get("type") .and_then(|v| v.as_str()) .unwrap_or("object") .to_string() }, description: Some(#param_descs), required: true, properties: None, items: None, }); })* map }), required: Some(vec![#(#required.to_string()),*]), }), strict: false, }; /// Registers this tool in the given registry. /// /// Generated by `#[tool]` macro for function `#fn_name_str`. #vis fn #register_fn_name(registry: &mut agent::ToolRegistry) { let def = #definition_const_name.clone(); let fn_name = #fn_name_str.to_string(); registry.register_fn(fn_name, move |_ctx, args| { let args: #params_struct_name = match serde_json::from_value(args) { Ok(a) => a, Err(e) => { return std::pin::Pin::new(Box::new(async move { Err(agent::ToolError::ParseError(e.to_string())) })) } }; std::pin::Pin::new(Box::new(async move { let result = #fn_call; match result { Ok(v) => Ok(serde_json::to_value(v).unwrap_or(serde_json::Value::Null)), Err(e) => Err(agent::ToolError::ExecutionError(e.to_string())), } })) }); } }; // We need to use boxed futures for the return type. // Since we can't add runtime dependencies to the proc-macro crate, // we emit the .boxed() call and the caller must ensure // `use futures::FutureExt;` or equivalent is in scope. // The generated code requires: `futures::FutureExt` (for .boxed()). // Re-emit with futures dependency note TokenStream::from(expanded) }