gitdataai/libs/agent/tool/context.rs

223 lines
6.4 KiB
Rust

//! Execution context passed to each tool handler.
//!
//! Carries runtime information a tool handler needs: database, cache,
//! request metadata, and the tool registry. Cheap to clone via `Arc`.
use std::sync::Arc;
use db::cache::AppCache;
use db::database::AppDatabase;
use config::AppConfig;
use queue::MessageProducer;
use uuid::Uuid;
use super::registry::ToolRegistry;
/// Per-invocation state handed to every tool handler.
///
/// Everything lives behind one `Arc`, so `clone()` only bumps a reference
/// count. Mutating methods (the `with_*` builders and the `&mut self`
/// helpers) go through `Arc::make_mut`, which copies the inner state first
/// if it is currently shared with another clone.
#[derive(Clone)]
pub struct ToolContext {
    /// Reference-counted backing state.
    inner: Arc<Inner>,
}
#[derive(Clone)]
struct Inner {
pub db: AppDatabase,
pub cache: AppCache,
pub config: AppConfig,
pub room_id: Uuid,
pub sender_id: Option<Uuid>,
pub project_id: Uuid,
pub registry: ToolRegistry,
pub embed_service: Option<crate::embed::EmbedService>,
pub message_producer: Option<MessageProducer>,
/// When in room context, identifies the AI model that is responding.
/// Used by send_message/retract_message to set the correct sender.
pub ai_model_id: Option<Uuid>,
pub ai_model_name: Option<String>,
/// Message IDs sent by the AI in the current ReAct turn.
/// Shared across tool calls so send_message can register IDs
/// and retract_message can validate turn-scoped retraction.
pub sent_in_turn: std::sync::Arc<std::sync::Mutex<Vec<Uuid>>>,
depth: u32,
max_depth: u32,
tool_call_count: usize,
max_tool_calls: usize,
}
impl ToolContext {
    /// Creates a context for a tool invocation in `room_id`, triggered by `sender_id`.
    ///
    /// Defaults: `project_id` is `Uuid::nil()`, the registry is a fresh
    /// `ToolRegistry::new()`, no embed service / message producer / AI model
    /// is attached, a new empty sent-in-turn list is allocated, recursion
    /// depth starts at 0 with a limit of 5, and the tool-call limit is 128.
    pub fn new(
        db: AppDatabase,
        cache: AppCache,
        config: AppConfig,
        room_id: Uuid,
        sender_id: Option<Uuid>,
    ) -> Self {
        Self {
            inner: Arc::new(Inner {
                db,
                cache,
                config,
                room_id,
                sender_id,
                project_id: Uuid::nil(),
                registry: ToolRegistry::new(),
                embed_service: None,
                message_producer: None,
                ai_model_id: None,
                ai_model_name: None,
                sent_in_turn: Arc::new(std::sync::Mutex::new(Vec::new())),
                depth: 0,
                max_depth: 5,
                tool_call_count: 0,
                max_tool_calls: 128,
            }),
        }
    }

    /// Sets the project this context belongs to.
    pub fn with_project(mut self, project_id: Uuid) -> Self {
        Arc::make_mut(&mut self.inner).project_id = project_id;
        self
    }

    /// Replaces the tool registry.
    pub fn with_registry(mut self, registry: ToolRegistry) -> Self {
        Arc::make_mut(&mut self.inner).registry = registry;
        self
    }

    /// Sets the depth at which `recursion_exceeded` returns `true`.
    pub fn with_max_depth(mut self, max_depth: u32) -> Self {
        Arc::make_mut(&mut self.inner).max_depth = max_depth;
        self
    }

    /// Sets the count at which `tool_calls_exceeded` returns `true`.
    pub fn with_max_tool_calls(mut self, max: usize) -> Self {
        Arc::make_mut(&mut self.inner).max_tool_calls = max;
        self
    }

    /// Attaches an embedding service.
    pub fn with_embed_service(mut self, embed_service: crate::embed::EmbedService) -> Self {
        Arc::make_mut(&mut self.inner).embed_service = Some(embed_service);
        self
    }

    /// Attaches a message queue producer.
    pub fn with_message_producer(mut self, producer: MessageProducer) -> Self {
        Arc::make_mut(&mut self.inner).message_producer = Some(producer);
        self
    }

    /// Records which AI model is responding (id and display name).
    pub fn with_ai_model(mut self, model_id: Uuid, model_name: String) -> Self {
        // One make_mut: obtain unique access once, then set both fields.
        let inner = Arc::make_mut(&mut self.inner);
        inner.ai_model_id = Some(model_id);
        inner.ai_model_name = Some(model_name);
        self
    }

    /// Replaces the sent-in-turn list, e.g. to share one list across contexts.
    pub fn with_sent_in_turn(mut self, sent: Arc<std::sync::Mutex<Vec<Uuid>>>) -> Self {
        Arc::make_mut(&mut self.inner).sent_in_turn = sent;
        self
    }

    /// Locks the sent-in-turn list, recovering the guard if the mutex is poisoned.
    ///
    /// The list is only pushed to or scanned, so a panic in another lock
    /// holder cannot leave it half-updated; recovering is safe.
    fn lock_sent_in_turn(&self) -> std::sync::MutexGuard<'_, Vec<Uuid>> {
        self.inner
            .sent_in_turn
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Register a message ID as sent in the current turn (called by send_message).
    ///
    /// A poisoned lock is recovered instead of silently dropping the ID;
    /// otherwise `is_sent_in_turn` would later report `false` for a message
    /// that was in fact sent.
    pub fn register_sent_message(&self, id: Uuid) {
        self.lock_sent_in_turn().push(id);
    }

    /// Check if a message ID was sent in the current turn (called by retract_message).
    pub fn is_sent_in_turn(&self, id: Uuid) -> bool {
        self.lock_sent_in_turn().contains(&id)
    }

    /// Embedding service, if one was attached.
    pub fn embed_service(&self) -> Option<&crate::embed::EmbedService> {
        self.inner.embed_service.as_ref()
    }

    /// Message queue producer for publishing room events (messages, retractions, etc.).
    pub fn message_producer(&self) -> Option<&MessageProducer> {
        self.inner.message_producer.as_ref()
    }

    /// `true` once the current depth has reached the configured maximum.
    pub fn recursion_exceeded(&self) -> bool {
        self.inner.depth >= self.inner.max_depth
    }

    /// `true` once the tool-call count has reached the configured maximum.
    pub fn tool_calls_exceeded(&self) -> bool {
        self.inner.tool_call_count >= self.inner.max_tool_calls
    }

    /// Current recursion depth.
    pub fn depth(&self) -> u32 {
        self.inner.depth
    }

    /// Current tool call count.
    pub fn tool_call_count(&self) -> usize {
        self.inner.tool_call_count
    }

    /// Increments the tool call count.
    ///
    /// Goes through `Arc::make_mut`: if this context shares its `Arc` with
    /// another clone, the inner state is copied first, so the new count is
    /// visible only through `self`, not through other clones.
    pub(crate) fn increment_tool_calls(&mut self) {
        Arc::make_mut(&mut self.inner).tool_call_count += 1;
    }

    /// Returns a child context for a recursive tool call (depth + 1).
    ///
    /// The child gets its own copy of the inner state, so its tool-call count
    /// starts at the parent's value and diverges from there. The sent-in-turn
    /// list is an `Arc` and therefore stays shared with the parent.
    pub(crate) fn child_context(&self) -> Self {
        let mut inner = (*self.inner).clone();
        inner.depth += 1;
        Self {
            inner: Arc::new(inner),
        }
    }

    /// Database connection.
    pub fn db(&self) -> &AppDatabase {
        &self.inner.db
    }

    /// Redis cache.
    pub fn cache(&self) -> &AppCache {
        &self.inner.cache
    }

    /// Application config.
    pub fn config(&self) -> &AppConfig {
        &self.inner.config
    }

    /// Room where the original message was sent.
    pub fn room_id(&self) -> Uuid {
        self.inner.room_id
    }

    /// User who sent the original message.
    pub fn sender_id(&self) -> Option<Uuid> {
        self.inner.sender_id
    }

    /// AI model ID when in room context (the AI that is responding).
    pub fn ai_model_id(&self) -> Option<Uuid> {
        self.inner.ai_model_id
    }

    /// AI model display name when in room context.
    pub fn ai_model_name(&self) -> Option<String> {
        self.inner.ai_model_name.clone()
    }

    /// Project context for the room.
    pub fn project_id(&self) -> Uuid {
        self.inner.project_id
    }

    /// Tool registry for this request.
    pub fn registry(&self) -> &ToolRegistry {
        &self.inner.registry
    }

    /// Mutable access to registry for adding tools.
    ///
    /// Uses `Arc::make_mut`: if this context is shared with a clone, the
    /// inner state (registry included) is copied first, and changes apply
    /// only to `self`.
    pub fn registry_mut(&mut self) -> &mut ToolRegistry {
        &mut Arc::make_mut(&mut self.inner).registry
    }
}