146 lines
3.5 KiB
Rust
146 lines
3.5 KiB
Rust
use async_trait::async_trait;
|
|
use serde_json::Value;
|
|
|
|
use crate::agent::persistence::AgentRunContext;
|
|
use crate::error::AiResult;
|
|
|
|
/// A hook's verdict on whether a pending tool call may run.
///
/// Returned (wrapped in `Option`) from `AgentHook::pre_tool_call`. `HookChain`
/// treats `Allow` as "no objection" and keeps consulting later hooks; any other
/// variant stops the chain and is handed back to its caller.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolGuardrailDecision {
    /// No objection to the tool call.
    Allow,
    /// The tool call is refused; `reason` carries the hook's explanation.
    Block { reason: String },
    /// The tool call should be approved before it runs; `message` carries the
    /// hook's text for that approval (how approval is obtained is decided by the
    /// caller of the chain — not visible here).
    RequireApproval { message: String },
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct ToolCallOutcome {
|
|
pub name: String,
|
|
pub arguments: Value,
|
|
pub output: Option<Value>,
|
|
pub error: Option<String>,
|
|
pub elapsed_ms: i64,
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct HookMessage {
|
|
pub role: String,
|
|
pub content: Option<String>,
|
|
pub tool_calls: Option<Value>,
|
|
pub tool_call_id: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct HookLlmResponse {
|
|
pub content: Option<String>,
|
|
pub tool_calls: Option<Value>,
|
|
pub input_tokens: u64,
|
|
pub output_tokens: u64,
|
|
pub finish_reason: Option<String>,
|
|
}
|
|
|
|
/// Name and description of a tool, passed to hooks in `AgentHook::pre_llm_call`
/// alongside the messages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HookToolDef {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
}
|
|
|
|
/// Callbacks that [`HookChain`] dispatches to around session start/end, LLM calls,
/// and tool calls.
///
/// Every method except [`AgentHook::name`] has a default body that does nothing and
/// succeeds, so an implementor overrides only the events it cares about.
/// Implementors must be `Send + Sync` (supertrait bounds).
#[async_trait]
pub trait AgentHook: Send + Sync {
    /// Identifier for this hook (NOTE(review): presumably used for logging or
    /// diagnostics — confirm).
    fn name(&self) -> &'static str;

    /// Called when a session starts. Default: no-op returning `Ok(())`.
    async fn on_session_start(&self, _ctx: &AgentRunContext) -> AiResult<()> {
        Ok(())
    }

    /// Called when a session ends. `_success` presumably reports whether the run
    /// succeeded (inferred from the name — confirm against the caller).
    /// Default: no-op returning `Ok(())`.
    async fn on_session_end(&self, _ctx: &AgentRunContext, _success: bool) -> AiResult<()> {
        Ok(())
    }

    /// Called before an LLM call with the conversation messages and tool
    /// definitions (presumably the ones about to be sent — confirm). Both are
    /// shared slices, so the hook can inspect but not modify them.
    /// Default: no-op returning `Ok(())`.
    async fn pre_llm_call(&self, _messages: &[HookMessage], _tools: &[HookToolDef]) -> AiResult<()> {
        Ok(())
    }

    /// Called with the response summary after an LLM call.
    /// Default: no-op returning `Ok(())`.
    async fn post_llm_call(&self, _response: &HookLlmResponse) -> AiResult<()> {
        Ok(())
    }

    /// Guardrail check before a tool is invoked.
    ///
    /// Return `Ok(None)` (the default) for no opinion, or `Ok(Some(decision))`.
    /// `HookChain` treats `Some(ToolGuardrailDecision::Allow)` the same as `None`
    /// and moves on to the next hook; any other decision stops the chain.
    async fn pre_tool_call(
        &self,
        _tool_name: &str,
        _arguments: &Value,
    ) -> AiResult<Option<ToolGuardrailDecision>> {
        Ok(None)
    }

    /// Called with the outcome after a tool call. Default: no-op returning `Ok(())`.
    async fn post_tool_call(&self, _outcome: &ToolCallOutcome) -> AiResult<()> {
        Ok(())
    }
}
|
|
|
|
pub struct HookChain {
|
|
hooks: Vec<Box<dyn AgentHook>>,
|
|
}
|
|
|
|
impl HookChain {
|
|
pub fn new(hooks: Vec<Box<dyn AgentHook>>) -> Self {
|
|
Self { hooks }
|
|
}
|
|
|
|
pub fn empty() -> Self {
|
|
Self { hooks: Vec::new() }
|
|
}
|
|
|
|
pub fn is_empty(&self) -> bool {
|
|
self.hooks.is_empty()
|
|
}
|
|
|
|
pub async fn run_session_start(&self, ctx: &AgentRunContext) -> AiResult<()> {
|
|
for hook in &self.hooks {
|
|
hook.on_session_start(ctx).await?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn run_session_end(&self, ctx: &AgentRunContext, success: bool) -> AiResult<()> {
|
|
for hook in &self.hooks {
|
|
hook.on_session_end(ctx, success).await?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn run_pre_llm_call(&self, messages: &[HookMessage], tools: &[HookToolDef]) -> AiResult<()> {
|
|
for hook in &self.hooks {
|
|
hook.pre_llm_call(messages, tools).await?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn run_post_llm_call(&self, response: &HookLlmResponse) -> AiResult<()> {
|
|
for hook in &self.hooks {
|
|
hook.post_llm_call(response).await?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn run_pre_tool_call(
|
|
&self,
|
|
tool_name: &str,
|
|
arguments: &Value,
|
|
) -> AiResult<Option<ToolGuardrailDecision>> {
|
|
for hook in &self.hooks {
|
|
if let Some(decision) = hook.pre_tool_call(tool_name, arguments).await? {
|
|
if !matches!(decision, ToolGuardrailDecision::Allow) {
|
|
return Ok(Some(decision));
|
|
}
|
|
}
|
|
}
|
|
Ok(None)
|
|
}
|
|
|
|
pub async fn run_post_tool_call(&self, outcome: &ToolCallOutcome) -> AiResult<()> {
|
|
for hook in &self.hooks {
|
|
hook.post_tool_call(outcome).await?;
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|