374 lines
13 KiB
Rust
374 lines
13 KiB
Rust
//! Procedural macro for generating tool definitions from functions.
//!
//! # Example
//!
//! ```ignore
//! use agent_tool_derive::tool;
//!
//! #[tool(description = "Search issues by title")]
//! fn search_issues(
//!     title: String,
//!     status: Option<String>,
//! ) -> Result<Vec<serde_json::Value>, String> {
//!     Ok(vec![])
//! }
//! ```
//!
//! The example is marked `ignore` because its expansion refers to the caller's
//! `agent`, `serde` and `serde_json` crates, which a doctest of this crate may
//! not have available.
//!
//! Generates:
//! - `SearchIssuesParameters`: a struct deriving `serde::Deserialize`
//! - `SEARCH_ISSUES_DEFINITION`: the tool's `ToolDefinition`
//! - `register_search_issues(registry: &mut ToolRegistry)`: a registration helper
|
|
|
|
extern crate proc_macro;
|
|
|
|
use convert_case::{Case, Casing};
|
|
use proc_macro::TokenStream;
|
|
use quote::{format_ident, quote};
|
|
use std::collections::HashMap;
|
|
use syn::punctuated::Punctuated;
|
|
use syn::{
|
|
Expr, ExprLit, Ident, Lit, Meta, ReturnType, Token, Type,
|
|
parse::{Parse, ParseStream},
|
|
};
|
|
|
|
/// Arguments accepted inside `#[tool(...)]`.
///
/// Grammar: `description = "...", params(name = "...", ...), required(name, ...)`,
/// where every part may be omitted.
struct ToolArgs {
    /// Tool-level description from `description = "..."`, if one was given.
    description: Option<String>,
    /// Per-parameter descriptions from `params(...)`, keyed by the parameter
    /// identifier exactly as written in the attribute.
    param_descriptions: HashMap<String, String>,
    /// Parameter names listed in `required(...)`, in source order; empty when the
    /// attribute has no `required(...)` list.
    required: Vec<String>,
}
|
|
|
|
impl Parse for ToolArgs {
|
|
fn parse(input: ParseStream) -> syn::Result<Self> {
|
|
Self::parse_from(input)
|
|
}
|
|
}
|
|
|
|
impl ToolArgs {
|
|
fn new() -> Self {
|
|
Self {
|
|
description: None,
|
|
param_descriptions: HashMap::new(),
|
|
required: Vec::new(),
|
|
}
|
|
}
|
|
|
|
fn parse_from(input: ParseStream) -> syn::Result<Self> {
|
|
let mut this = Self::new();
|
|
if input.is_empty() {
|
|
return Ok(this);
|
|
}
|
|
|
|
let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
|
|
|
|
for meta in meta_list {
|
|
match meta {
|
|
Meta::NameValue(nv) => {
|
|
let ident = nv
|
|
.path
|
|
.get_ident()
|
|
.ok_or_else(|| syn::Error::new_spanned(&nv.path, "expected identifier"))?;
|
|
if ident == "description" {
|
|
if let Expr::Lit(ExprLit {
|
|
lit: Lit::Str(s), ..
|
|
}) = nv.value
|
|
{
|
|
this.description = Some(s.value());
|
|
} else {
|
|
return Err(syn::Error::new_spanned(
|
|
&nv.value,
|
|
"description must be a string literal",
|
|
));
|
|
}
|
|
}
|
|
}
|
|
Meta::List(list) if list.path.is_ident("params") => {
|
|
let inner: Punctuated<Meta, Token![,]> =
|
|
list.parse_args_with(Punctuated::parse_terminated)?;
|
|
for item in inner {
|
|
if let Meta::NameValue(nv) = item {
|
|
let param_name = nv
|
|
.path
|
|
.get_ident()
|
|
.ok_or_else(|| {
|
|
syn::Error::new_spanned(&nv.path, "expected identifier")
|
|
})?
|
|
.to_string();
|
|
if let Expr::Lit(ExprLit {
|
|
lit: Lit::Str(s), ..
|
|
}) = nv.value
|
|
{
|
|
this.param_descriptions.insert(param_name, s.value());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Meta::List(list) if list.path.is_ident("required") => {
|
|
let required_vars: Punctuated<Ident, Token![,]> =
|
|
list.parse_args_with(Punctuated::parse_terminated)?;
|
|
for var in required_vars {
|
|
this.required.push(var.to_string());
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
Ok(this)
|
|
}
|
|
}
|
|
|
|
/// Map a Rust type to its JSON Schema type name.
|
|
fn json_type(ty: &Type) -> proc_macro2::TokenStream {
|
|
use syn::Type as T;
|
|
let segs = match ty {
|
|
T::Path(p) => &p.path.segments,
|
|
_ => return quote! { "type": "object" },
|
|
};
|
|
let last = segs.last().map(|s| &s.ident);
|
|
let args = segs.last().and_then(|s| {
|
|
if let syn::PathArguments::AngleBracketed(a) = &s.arguments {
|
|
Some(&a.args)
|
|
} else {
|
|
None
|
|
}
|
|
});
|
|
|
|
match (last.map(|i| i.to_string()).as_deref(), args) {
|
|
(Some("Vec" | "vec::Vec"), Some(args)) if !args.is_empty() => {
|
|
if let syn::GenericArgument::Type(inner) = &args[0] {
|
|
let inner_type = json_type(inner);
|
|
return quote! {
|
|
{
|
|
"type": "array",
|
|
"items": { #inner_type }
|
|
}
|
|
};
|
|
}
|
|
quote! { "type": "array" }
|
|
}
|
|
(Some("String" | "str" | "char"), _) => quote! { "type": "string" },
|
|
(Some("bool"), _) => quote! { "type": "boolean" },
|
|
(Some("i8" | "i16" | "i32" | "i64" | "isize"), _) => quote! { "type": "integer" },
|
|
(Some("u8" | "u16" | "u32" | "u64" | "usize"), _) => quote! { "type": "integer" },
|
|
(Some("f32" | "f64"), _) => quote! { "type": "number" },
|
|
_ => quote! { "type": "object" },
|
|
}
|
|
}
|
|
|
|
/// Extract return type info from `-> Result<T, E>`.
|
|
fn parse_return_type(
|
|
ret: &ReturnType,
|
|
) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> {
|
|
match ret {
|
|
ReturnType::Type(_, ty) => {
|
|
let ty = &**ty;
|
|
if let Type::Path(p) = ty {
|
|
let last = p
|
|
.path
|
|
.segments
|
|
.last()
|
|
.ok_or_else(|| syn::Error::new_spanned(&p.path, "invalid return type"))?;
|
|
if last.ident == "Result" {
|
|
if let syn::PathArguments::AngleBracketed(a) = &last.arguments {
|
|
let args = &a.args;
|
|
if args.len() == 2 {
|
|
let ok = &args[0];
|
|
let err = &args[1];
|
|
return Ok((quote!(#ok), quote!(#err)));
|
|
}
|
|
}
|
|
return Err(syn::Error::new_spanned(
|
|
&last,
|
|
"Result must have 2 type parameters",
|
|
));
|
|
}
|
|
}
|
|
Err(syn::Error::new_spanned(
|
|
ty,
|
|
"function must return Result<T, E>",
|
|
))
|
|
}
|
|
_ => Err(syn::Error::new_spanned(
|
|
ret,
|
|
"function must have a return type",
|
|
)),
|
|
}
|
|
}
|
|
|
|
/// The `#[tool]` attribute macro.
|
|
///
|
|
/// Usage:
|
|
/// ```
|
|
/// #[tool(description = "Tool description", params(
|
|
/// arg1 = "Description of arg1",
|
|
/// arg2 = "Description of arg2",
|
|
/// ))]
|
|
/// async fn my_tool(arg1: String, arg2: Option<i32>) -> Result<serde_json::Value, String> {
|
|
/// Ok(serde_json::json!({}))
|
|
/// }
|
|
/// ```
|
|
///
|
|
/// Generates:
|
|
/// - `MyToolParameters` struct with serde Deserialize
|
|
/// - `MY_TOOL_DEFINITION: ToolDefinition` constant
|
|
/// - `register_my_tool(registry: &mut ToolRegistry)` helper function
|
|
#[proc_macro_attribute]
|
|
pub fn tool(args: TokenStream, input: TokenStream) -> TokenStream {
|
|
let args = syn::parse_macro_input!(args as ToolArgs);
|
|
let input_fn = syn::parse_macro_input!(input as syn::ItemFn);
|
|
|
|
let fn_name = &input_fn.sig.ident;
|
|
let fn_name_str = fn_name.to_string();
|
|
let vis = &input_fn.vis;
|
|
let is_async = input_fn.sig.asyncness.is_some();
|
|
|
|
// Parse return type: Result<T, E>
|
|
let (_output_type, _error_type) = match parse_return_type(&input_fn.sig.output) {
|
|
Ok(t) => t,
|
|
Err(e) => return e.into_compile_error().into(),
|
|
};
|
|
|
|
// PascalCase struct name
|
|
let struct_name = format_ident!("{}", fn_name_str.to_case(Case::Pascal));
|
|
let params_struct_name = format_ident!("{}Parameters", struct_name);
|
|
let definition_const_name = format_ident!("{}_DEFINITION", fn_name_str.to_uppercase());
|
|
let register_fn_name = format_ident!("register_{}", fn_name_str);
|
|
|
|
// Extract parameters from function signature
|
|
let mut param_names: Vec<Ident> = Vec::new();
|
|
let mut param_types: Vec<Type> = Vec::new();
|
|
let mut param_json_types: Vec<proc_macro2::TokenStream> = Vec::new();
|
|
let mut param_descs: Vec<proc_macro2::TokenStream> = Vec::new();
|
|
|
|
let required_args = args.required.clone();
|
|
|
|
for arg in &input_fn.sig.inputs {
|
|
let syn::FnArg::Typed(pat_type) = arg else {
|
|
continue;
|
|
};
|
|
let syn::Pat::Ident(pat_ident) = &*pat_type.pat else {
|
|
continue;
|
|
};
|
|
let name = &pat_ident.ident;
|
|
let ty = &*pat_type.ty;
|
|
|
|
let name_str = name.to_string();
|
|
let desc = args
|
|
.param_descriptions
|
|
.get(&name_str)
|
|
.map(|s| quote! { #s.to_string() })
|
|
.unwrap_or_else(|| quote! { format!("Parameter {}", #name_str) });
|
|
|
|
param_names.push(format_ident!("{}", name.to_string()));
|
|
param_types.push(ty.clone());
|
|
param_json_types.push(json_type(ty));
|
|
param_descs.push(desc);
|
|
}
|
|
|
|
// Which params are required (not Option<T>)
|
|
let required: Vec<proc_macro2::TokenStream> = if required_args.is_empty() {
|
|
param_names
|
|
.iter()
|
|
.filter(|name| {
|
|
let name_str = name.to_string();
|
|
!args
|
|
.param_descriptions
|
|
.contains_key(&format!("{}_opt", name_str))
|
|
})
|
|
.map(|name| quote! { stringify!(#name) })
|
|
.collect()
|
|
} else {
|
|
required_args.iter().map(|s| quote! { #s }).collect()
|
|
};
|
|
|
|
// Tool description
|
|
let tool_description = args
|
|
.description
|
|
.map(|s| quote! { #s.to_string() })
|
|
.unwrap_or_else(|| quote! { format!("Function {}", #fn_name_str) });
|
|
|
|
// Call invocation (async vs sync)
|
|
let call_args = param_names.iter().map(|n| quote! { args.#n });
|
|
let fn_call = if is_async {
|
|
quote! { #fn_name(#(#call_args),*).await }
|
|
} else {
|
|
quote! { #fn_name(#(#call_args),*) }
|
|
};
|
|
|
|
let expanded = quote! {
|
|
// Parameters struct: deserialized from JSON args by serde
|
|
#[derive(serde::Deserialize)]
|
|
#vis struct #params_struct_name {
|
|
#(#vis #param_names: #param_types,)*
|
|
}
|
|
|
|
// Keep the original function unchanged
|
|
#input_fn
|
|
|
|
// Static ToolDefinition constant — register this with ToolRegistry
|
|
#vis const #definition_const_name: agent::ToolDefinition = agent::ToolDefinition {
|
|
name: #fn_name_str.to_string(),
|
|
description: Some(#tool_description),
|
|
parameters: Some(agent::ToolSchema {
|
|
schema_type: "object".to_string(),
|
|
properties: Some({
|
|
let mut map = std::collections::HashMap::new();
|
|
#({
|
|
map.insert(stringify!(#param_names).to_string(), agent::ToolParam {
|
|
name: stringify!(#param_names).to_string(),
|
|
param_type: {
|
|
let jt = #param_json_types;
|
|
jt.get("type")
|
|
.and_then(|v| v.as_str())
|
|
.unwrap_or("object")
|
|
.to_string()
|
|
},
|
|
description: Some(#param_descs),
|
|
required: true,
|
|
properties: None,
|
|
items: None,
|
|
});
|
|
})*
|
|
map
|
|
}),
|
|
required: Some(vec![#(#required.to_string()),*]),
|
|
}),
|
|
strict: false,
|
|
};
|
|
|
|
/// Registers this tool in the given registry.
|
|
///
|
|
/// Generated by `#[tool]` macro for function `#fn_name_str`.
|
|
#vis fn #register_fn_name(registry: &mut agent::ToolRegistry) {
|
|
let def = #definition_const_name.clone();
|
|
let fn_name = #fn_name_str.to_string();
|
|
registry.register_fn(fn_name, move |_ctx, args| {
|
|
let args: #params_struct_name = match serde_json::from_value(args) {
|
|
Ok(a) => a,
|
|
Err(e) => {
|
|
return std::pin::Pin::new(Box::new(async move {
|
|
Err(agent::ToolError::ParseError(e.to_string()))
|
|
}))
|
|
}
|
|
};
|
|
std::pin::Pin::new(Box::new(async move {
|
|
let result = #fn_call;
|
|
match result {
|
|
Ok(v) => Ok(serde_json::to_value(v).unwrap_or(serde_json::Value::Null)),
|
|
Err(e) => Err(agent::ToolError::ExecutionError(e.to_string())),
|
|
}
|
|
}))
|
|
});
|
|
}
|
|
};
|
|
|
|
// We need to use boxed futures for the return type.
|
|
// Since we can't add runtime dependencies to the proc-macro crate,
|
|
// we emit the .boxed() call and the caller must ensure
|
|
// `use futures::FutureExt;` or equivalent is in scope.
|
|
// The generated code requires: `futures::FutureExt` (for .boxed()).
|
|
|
|
// Re-emit with futures dependency note
|
|
TokenStream::from(expanded)
|
|
}
|