241 lines
6.5 KiB
Rust
241 lines
6.5 KiB
Rust
use std::time::Duration;
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::Value;
|
|
use tokio_util::sync::CancellationToken;
|
|
|
|
use super::persistence::AgentRunContext;
|
|
use crate::error::{AiError, AiResult};
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
pub struct AgentRequest {
|
|
pub input: String,
|
|
pub messages: Vec<AgentMessage>,
|
|
pub context: Vec<AgentContextChunk>,
|
|
pub experts: Vec<AgentExpert>,
|
|
pub run_context: Option<AgentRunContext>,
|
|
#[serde(skip)]
|
|
pub prefill_messages: Vec<rig::completion::Message>,
|
|
#[serde(skip)]
|
|
pub cancellation_token: Option<CancellationToken>,
|
|
#[serde(skip)]
|
|
pub timeout: Option<Duration>,
|
|
}
|
|
|
|
impl AgentRequest {
|
|
pub fn new(input: impl Into<String>) -> Self {
|
|
Self {
|
|
input: input.into(),
|
|
messages: Vec::new(),
|
|
context: Vec::new(),
|
|
experts: Vec::new(),
|
|
run_context: None,
|
|
prefill_messages: Vec::new(),
|
|
cancellation_token: None,
|
|
timeout: None,
|
|
}
|
|
}
|
|
|
|
pub fn validate(&self) -> AiResult<()> {
|
|
if self.input.trim().is_empty() {
|
|
return Err(AiError::Config("agent request input is required".to_string()));
|
|
}
|
|
if self.input.len() > 1_000_000 {
|
|
return Err(AiError::Config(
|
|
"agent request input exceeds maximum length (1MB)".to_string(),
|
|
));
|
|
}
|
|
if self.experts.len() > 32 {
|
|
return Err(AiError::Config(
|
|
"agent request experts count exceeds maximum (32)".to_string(),
|
|
));
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub fn with_messages(mut self, messages: Vec<AgentMessage>) -> Self {
|
|
self.messages = messages;
|
|
self
|
|
}
|
|
|
|
pub fn with_context(mut self, context: Vec<AgentContextChunk>) -> Self {
|
|
self.context = context;
|
|
self
|
|
}
|
|
|
|
pub fn add_context(mut self, chunk: AgentContextChunk) -> Self {
|
|
self.context.push(chunk);
|
|
self
|
|
}
|
|
|
|
pub fn with_experts(mut self, experts: Vec<AgentExpert>) -> Self {
|
|
self.experts = experts;
|
|
self
|
|
}
|
|
|
|
pub fn add_expert(mut self, expert: AgentExpert) -> Self {
|
|
self.experts.push(expert);
|
|
self
|
|
}
|
|
|
|
pub fn with_run_context(mut self, run_context: AgentRunContext) -> Self {
|
|
self.run_context = Some(run_context);
|
|
self
|
|
}
|
|
|
|
pub fn with_prefill_messages(mut self, prefill_messages: Vec<rig::completion::Message>) -> Self {
|
|
self.prefill_messages = prefill_messages;
|
|
self
|
|
}
|
|
|
|
pub fn with_cancellation_token(mut self, cancellation_token: CancellationToken) -> Self {
|
|
self.cancellation_token = Some(cancellation_token);
|
|
self
|
|
}
|
|
|
|
pub fn with_timeout(mut self, timeout: Duration) -> Self {
|
|
self.timeout = Some(timeout);
|
|
self
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
pub enum AgentMessage {
|
|
User(String),
|
|
Assistant(String),
|
|
}
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
pub struct AgentExpert {
|
|
pub id: String,
|
|
pub role: String,
|
|
pub task: String,
|
|
pub system_prompt: Option<String>,
|
|
pub context: Vec<AgentContextChunk>,
|
|
/// Override the master agent's temperature for this subagent.
|
|
pub temperature: Option<f64>,
|
|
/// Override the master agent's max_completion_tokens for this subagent.
|
|
pub max_completion_tokens: Option<u64>,
|
|
}
|
|
|
|
impl AgentExpert {
|
|
pub fn new(id: impl Into<String>, role: impl Into<String>, task: impl Into<String>) -> Self {
|
|
Self {
|
|
id: id.into(),
|
|
role: role.into(),
|
|
task: task.into(),
|
|
system_prompt: None,
|
|
context: Vec::new(),
|
|
temperature: None,
|
|
max_completion_tokens: None,
|
|
}
|
|
}
|
|
|
|
pub fn with_system_prompt(mut self, system_prompt: impl Into<String>) -> Self {
|
|
self.system_prompt = Some(system_prompt.into());
|
|
self
|
|
}
|
|
|
|
pub fn with_context(mut self, context: Vec<AgentContextChunk>) -> Self {
|
|
self.context = context;
|
|
self
|
|
}
|
|
|
|
pub fn with_temperature(mut self, temperature: f64) -> Self {
|
|
self.temperature = Some(temperature);
|
|
self
|
|
}
|
|
|
|
pub fn with_max_completion_tokens(mut self, max_tokens: u64) -> Self {
|
|
self.max_completion_tokens = Some(max_tokens);
|
|
self
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
pub struct AgentContextChunk {
|
|
pub id: String,
|
|
pub content: String,
|
|
pub source: Option<String>,
|
|
pub score: Option<f32>,
|
|
pub metadata: Value,
|
|
}
|
|
|
|
impl AgentContextChunk {
|
|
pub fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
|
|
Self {
|
|
id: id.into(),
|
|
content: content.into(),
|
|
source: None,
|
|
score: None,
|
|
metadata: Value::Null,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<crate::rag::RagSearchHit> for AgentContextChunk {
|
|
fn from(hit: crate::rag::RagSearchHit) -> Self {
|
|
Self {
|
|
id: hit.id,
|
|
content: hit.content,
|
|
source: Some(hit.session_id),
|
|
score: Some(hit.score),
|
|
metadata: Value::Object(hit.metadata.into_iter().collect()),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<&AgentExpertOutput> for AgentContextChunk {
|
|
fn from(output: &AgentExpertOutput) -> Self {
|
|
Self {
|
|
id: format!("subagent:{}", output.id),
|
|
content: output.output.clone(),
|
|
source: Some(output.role.clone()),
|
|
score: None,
|
|
metadata: serde_json::json!({
|
|
"kind": "subagent",
|
|
"task": output.task,
|
|
}),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
pub struct AgentResult {
|
|
pub output: String,
|
|
pub steps: Vec<AgentStep>,
|
|
pub expert_outputs: Vec<AgentExpertOutput>,
|
|
pub input_tokens: i64,
|
|
pub output_tokens: i64,
|
|
}
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
pub struct AgentStep {
|
|
pub index: usize,
|
|
pub assistant: Option<String>,
|
|
pub reasoning_content: Option<String>,
|
|
pub tool_calls: Vec<ToolCallRecord>,
|
|
pub reflection: Option<String>,
|
|
}
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
pub struct ToolCallRecord {
|
|
pub id: String,
|
|
pub name: String,
|
|
pub arguments: Value,
|
|
pub output: Option<Value>,
|
|
pub error: Option<String>,
|
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
pub elapsed_ms: Option<i64>,
|
|
}
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
pub struct AgentExpertOutput {
|
|
pub id: String,
|
|
pub role: String,
|
|
pub task: String,
|
|
pub output: String,
|
|
pub input_tokens: i64,
|
|
pub output_tokens: i64,
|
|
}
|