diff --git a/libs/agent/agent/mod.rs b/libs/agent/agent/mod.rs index e301e1f..7d966af 100644 --- a/libs/agent/agent/mod.rs +++ b/libs/agent/agent/mod.rs @@ -1,7 +1,4 @@ -//! Agent service using rig's built-in Agent with full feature support. -//! -//! This module provides a higher-level agent service built on rig's Agent, -//! supporting multi-turn conversations, RAG, and built-in streaming. +//! Rig-based agent using rig's built-in Agent with full feature support. -pub mod service; -pub use service::RigAgentService; +pub mod rig_tool; +pub use rig_tool::{AgentResponse, RigAgentService, StreamChunk}; diff --git a/libs/agent/agent/service.rs b/libs/agent/agent/rig_tool.rs similarity index 62% rename from libs/agent/agent/service.rs rename to libs/agent/agent/rig_tool.rs index 3ac8fe1..869004e 100644 --- a/libs/agent/agent/service.rs +++ b/libs/agent/agent/rig_tool.rs @@ -1,23 +1,17 @@ -//! Agent service using rig's built-in Agent. -//! -//! This is a complete implementation that leverages rig's Agent for -//! multi-turn reasoning, tool execution, streaming, and token tracking. - use futures::Stream; use futures::StreamExt; use rig::{ agent::{AgentBuilder, MultiTurnStreamItem}, client::CompletionClient, completion::Prompt, - streaming::{StreamingPrompt, StreamedAssistantContent}, + streaming::{StreamedAssistantContent, StreamingPrompt}, }; -use tokio_stream::wrappers::ReceiverStream; use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; use crate::client::AiClientConfig; use crate::error::AgentError; -/// Response from an agent completion (rig's Agent prompt response). #[derive(Debug)] pub struct AgentResponse { pub content: String, @@ -25,12 +19,9 @@ pub struct AgentResponse { pub output_tokens: u64, } -/// Streaming chunk from the agent. #[derive(Debug)] pub enum StreamChunk { - /// Text delta from the model Text(String), - /// Final response with aggregated usage Final { content: String, input_tokens: u64, @@ -38,22 +29,19 @@ pub enum StreamChunk { }, } -/// Service for running agents using rig's built-in Agent. -/// -/// Provides both simple prompting and full streaming with automatic -/// tool call handling via rig's native Agent. pub struct RigAgentService { config: AiClientConfig, model_name: String, } impl RigAgentService { - /// Create a new RigAgentService. pub fn new(config: AiClientConfig, model_name: impl Into) -> Self { - Self { config, model_name: model_name.into() } + Self { + config, + model_name: model_name.into(), + } } - /// Run a single prompt with the agent (single-turn, no tools). pub async fn prompt( &self, system_prompt: &str, @@ -62,9 +50,7 @@ impl RigAgentService { let client = self.config.build_rig_client(); let model = client.completion_model(&self.model_name); - let agent = AgentBuilder::new(model) - .preamble(system_prompt) - .build(); + let agent = AgentBuilder::new(model).preamble(system_prompt).build(); let response = agent .prompt(user_input) @@ -74,15 +60,11 @@ impl RigAgentService { Ok(AgentResponse { content: response.output, - input_tokens: response.total_usage.input_tokens, - output_tokens: response.total_usage.output_tokens, + input_tokens: response.usage.input_tokens, + output_tokens: response.usage.output_tokens, }) } - /// Run a prompt with tools (supports multi-turn via rig's Agent). - /// - /// The agent will automatically handle tool calls by calling rig's - /// ToolDyn implementations with proper argument deserialization. pub async fn prompt_with_tools( &self, system_prompt: &str, @@ -108,16 +90,11 @@ impl RigAgentService { Ok(AgentResponse { content: response.output, - input_tokens: response.total_usage.input_tokens, - output_tokens: response.total_usage.output_tokens, + input_tokens: response.usage.input_tokens, + output_tokens: response.usage.output_tokens, }) } - /// Stream a prompt with the agent using rig's native streaming. - /// - /// This returns a proper async stream that yields text chunks as they arrive - /// and a final response chunk with aggregated token usage. Tool calls are - /// handled transparently by rig's Agent. pub async fn stream_prompt( &self, system_prompt: &str, @@ -129,17 +106,10 @@ impl RigAgentService { let client = self.config.build_rig_client(); let model = client.completion_model(&self.model_name); - let agent = AgentBuilder::new(model) - .preamble(system_prompt) - .build(); + let agent = AgentBuilder::new(model).preamble(system_prompt).build(); - // stream_prompt().await returns StreamingResult directly (not wrapped in Result) - // StreamingResult is Pin>>> - let stream: rig::agent::StreamingResult<_> = agent - .stream_prompt(user_input) - .await; + let stream: rig::agent::StreamingResult<_> = agent.stream_prompt(user_input).await; - // Bridge the rig stream to our channel-based stream let (tx, rx) = mpsc::channel::>(100); tokio::spawn(async move { @@ -152,12 +122,14 @@ impl RigAgentService { Ok(MultiTurnStreamItem::StreamAssistantItem( StreamedAssistantContent::Text(text), )) => { - let cleaned = text.text.replace('\n', ""); - let _ = tx.send(Ok(StreamChunk::Text(cleaned))).await; + let _ = tx.send(Ok(StreamChunk::Text(text.text.clone()))).await; final_content.push_str(&text.text); } Ok(MultiTurnStreamItem::StreamAssistantItem( - StreamedAssistantContent::ToolCall { tool_call, internal_call_id: _ }, + StreamedAssistantContent::ToolCall { + tool_call, + internal_call_id: _, + }, )) => { let args_str = match &tool_call.function.arguments { serde_json::Value::String(s) => s.clone(), @@ -168,7 +140,6 @@ impl RigAgentService { args = %args_str, "rig_agent_streaming_tool_call" ); - // Tool calllint — emitted for observability, rig handles execution internally } Ok(MultiTurnStreamItem::StreamUserItem( rig::streaming::StreamedUserContent::ToolResult { tool_result, .. }, @@ -180,11 +151,13 @@ impl RigAgentService { } Ok(MultiTurnStreamItem::FinalResponse(resp)) => { let usage = resp.usage(); - let _ = tx.send(Ok(StreamChunk::Final { - content: final_content.clone(), - input_tokens: usage.input_tokens, - output_tokens: usage.output_tokens, - })).await; + let _ = tx + .send(Ok(StreamChunk::Final { + content: final_content.clone(), + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + })) + .await; } Err(e) => { let _ = tx.send(Err(AgentError::OpenAi(e.to_string()))).await; @@ -197,10 +170,6 @@ impl RigAgentService { Ok(ReceiverStream::new(rx)) } - /// Stream a prompt with tools using rig's native streaming. - /// - /// Returns a stream thatproperly handles multi-turn tool calls via rig's Agent - /// streaming infrastructure. pub async fn stream_prompt_with_tools( &self, system_prompt: &str, @@ -222,33 +191,30 @@ impl RigAgentService { let stream = agent .stream_prompt(user_input) - .with_history(Vec::new()) + .with_history(Vec::::new()) .multi_turn(max_turns) .await; - - let (tx, rx) = mpsc::channel::>(100); - + let (tx, rx) = mpsc::channel::>(100); tokio::spawn(async move { let mut final_content = String::new(); - tokio::pin!(stream); - while let Some(item) = stream.next().await { match item { Ok(MultiTurnStreamItem::StreamAssistantItem( StreamedAssistantContent::Text(text), )) => { - let cleaned = text.text.replace('\n', ""); - let _ = tx.send(Ok(StreamChunk::Text(cleaned))).await; + let _ = tx.send(Ok(StreamChunk::Text(text.text.clone()))).await; final_content.push_str(&text.text); } Ok(MultiTurnStreamItem::FinalResponse(resp)) => { let usage = resp.usage(); - let _ = tx.send(Ok(StreamChunk::Final { - content: final_content.clone(), - input_tokens: usage.input_tokens, - output_tokens: usage.output_tokens, - })).await; + let _ = tx + .send(Ok(StreamChunk::Final { + content: final_content.clone(), + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + })) + .await; } Err(e) => { let _ = tx.send(Err(AgentError::OpenAi(e.to_string()))).await; @@ -261,9 +227,8 @@ impl RigAgentService { Ok(ReceiverStream::new(rx)) } - /// Count tokens in text using tiktoken for the configured model. - pub fn count_tokens(&self, text: &str) -> std::result::Result { + pub fn count_tokens(&self, text: &str) -> Result { crate::tokent::count_text(text, &self.model_name) .map_err(|e| AgentError::Internal(e.to_string())) } -} +} \ No newline at end of file diff --git a/libs/agent/chat/chat_execution.rs b/libs/agent/chat/chat_execution.rs index 1beb51c..0498046 100644 --- a/libs/agent/chat/chat_execution.rs +++ b/libs/agent/chat/chat_execution.rs @@ -195,7 +195,7 @@ pub async fn execute_chat_stream( &messages, model_name, config, temperature, max_tokens, if tools_enabled { Some(&tools) } else { None }, None, Arc::new(move |delta| { - let content = delta.to_string().replace('\n', ""); + let content = delta.to_string(); let fut = on_chunk_cb(AiStreamChunk { content, done: false, chunk_type: AiChunkType::Answer }); fut }), diff --git a/libs/agent/chat/react_execution.rs b/libs/agent/chat/react_execution.rs index e94965f..90c4195 100644 --- a/libs/agent/chat/react_execution.rs +++ b/libs/agent/chat/react_execution.rs @@ -74,7 +74,7 @@ where .build(); let stream = agent.stream_prompt(&request.input) - .with_history(Vec::new()) + .with_history(Vec::::new()) .multi_turn(request.max_tool_depth) .await; @@ -90,12 +90,14 @@ where Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(text))) => { step_count += 1; let t = text.text; - let cleaned = t.replace('\n', ""); - on_chunk(ReactStep::Answer { step: step_count, answer: cleaned }).await; + on_chunk(ReactStep::Answer { step: step_count, answer: t.clone() }).await; final_content.push_str(&t); } Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(reasoning))) => { - let reasoning_text = reasoning.reasoning.join(""); + let reasoning_text: String = reasoning.content.iter().filter_map(|c| match c { + rig::completion::message::ReasoningContent::Text { text, .. } => Some(text.as_str()), + _ => None, + }).collect::>().join(""); if !reasoning_text.is_empty() { step_count += 1; on_chunk(ReactStep::Thought { step: step_count, thought: reasoning_text }).await; diff --git a/libs/agent/chat/streaming_execution.rs b/libs/agent/chat/streaming_execution.rs index 2d1d2d8..993688f 100644 --- a/libs/agent/chat/streaming_execution.rs +++ b/libs/agent/chat/streaming_execution.rs @@ -62,7 +62,7 @@ pub async fn execute_process_stream( &messages, &model_name, &config, temperature, max_tokens, if tools_enabled { Some(&tools) } else { None }, None, Arc::new(move |delta| { - let content = delta.to_string().replace('\n', ""); + let content = delta.to_string(); let fut = on_chunk_cb(AiStreamChunk { content, done: false, chunk_type: AiChunkType::Answer }); fut }), diff --git a/libs/agent/client/mod.rs b/libs/agent/client/mod.rs index 415f054..1271883 100644 --- a/libs/agent/client/mod.rs +++ b/libs/agent/client/mod.rs @@ -160,8 +160,8 @@ fn ai_metrics() -> &'static AiMetrics { pub(crate) fn to_rig_message(msg: &ChatRequestMessage) -> RigMessage { match msg.role.as_str() { "system" => { - // System messages are handled via preamble(), but we still - // need to return something. Return a system message as User for safety. + // System messages are handled via preamble(), not passed as messages. + // We still need to return a valid RigMessage variant. RigMessage::user(msg.content.as_deref().unwrap_or("")) } "user" => { @@ -263,9 +263,6 @@ async fn do_completion( where M: CompletionModel, { - let mut history: Vec = messages.iter().map(to_rig_message).collect(); - - // Extract preamble (first system message) and remove from history let preamble = messages .iter() .find(|m| m.role == "system") @@ -273,12 +270,6 @@ where .unwrap_or("") .to_string(); - history.retain(|m| !matches!(m, RigMessage::User { .. } | RigMessage::Assistant { .. })); - - // For tool_result messages, we need to add them back - // Actually, let's keep the approach: filter out system, add others back - // The rig completion request uses: preamble (system) + messages (conversation) - // For our messages: system → preamble, rest → messages let non_system: Vec = messages .iter() .filter(|m| m.role != "system") @@ -700,13 +691,15 @@ async fn call_stream_once( } } Ok(StreamedAssistantContent::Reasoning(reasoning)) => { - for part in &reasoning.reasoning { - reasoning_content.push_str(part); - on_reasoning_delta(part).await; - chunks.push(StreamChunk { - chunk_type: StreamChunkType::Thinking, - content: part.clone(), - }); + for part in &reasoning.content { + if let rig::completion::message::ReasoningContent::Text { text, .. } = part { + reasoning_content.push_str(text); + on_reasoning_delta(text).await; + chunks.push(StreamChunk { + chunk_type: StreamChunkType::Thinking, + content: text.clone(), + }); + } } } Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => { diff --git a/libs/agent/compact/auth_fetch.rs b/libs/agent/compact/auth_fetch.rs new file mode 100644 index 0000000..0cb00da --- /dev/null +++ b/libs/agent/compact/auth_fetch.rs @@ -0,0 +1,39 @@ +use crate::AgentError; +use models::rooms::room_message::{ + Column as RmCol, Entity as RoomMessage, Model as RoomMessageModel, +}; +use models::Expr; +use sea_orm::*; + +impl super::CompactService { + pub async fn fetch_room_messages_secure( + &self, + room_id: uuid::Uuid, + requester_id: uuid::Uuid, + ) -> Result, AgentError> { + use models::rooms::{RoomAccess, RoomUserState}; + + RoomMessage::find() + .filter(RmCol::Room.eq(room_id)) + .filter( + Condition::any() + .add(Expr::exists( + RoomUserState::find() + .filter(models::rooms::room_user_state::Column::Room.eq(room_id)) + .filter(models::rooms::room_user_state::Column::User.eq(requester_id)) + .into_query(), + )) + .add(Expr::exists( + RoomAccess::find() + .filter(models::rooms::room_access::Column::Room.eq(room_id)) + .filter(models::rooms::room_access::Column::User.eq(requester_id)) + .into_query(), + )), + ) + .order_by_asc(RmCol::Seq) + .limit(10000) + .all(&self.db) + .await + .map_err(|e| AgentError::Internal(e.to_string())) + } +} diff --git a/libs/agent/compact/mod.rs b/libs/agent/compact/mod.rs index 7e2b56b..04d74b7 100644 --- a/libs/agent/compact/mod.rs +++ b/libs/agent/compact/mod.rs @@ -1,8 +1,32 @@ //! Context compaction for AI sessions and room message history. +pub mod auth_fetch; pub mod helpers; -pub mod service; +pub mod room_compactor; +pub mod summarizer; pub mod types; -pub use service::CompactService; +use sea_orm::DatabaseConnection; + pub use types::{CompactConfig, CompactLevel, CompactSummary, MessageSummary, ThresholdResult}; + +#[derive(Clone)] +pub struct CompactService { + db: DatabaseConnection, + ai_client_config: crate::client::AiClientConfig, + model: String, +} + +impl CompactService { + pub fn new( + db: DatabaseConnection, + ai_client_config: crate::client::AiClientConfig, + model: String, + ) -> Self { + Self { + db, + ai_client_config, + model, + } + } +} diff --git a/libs/agent/compact/room_compactor.rs b/libs/agent/compact/room_compactor.rs new file mode 100644 index 0000000..c249ba4 --- /dev/null +++ b/libs/agent/compact/room_compactor.rs @@ -0,0 +1,180 @@ +use models::rooms::room_message::{ + Column as RmCol, Entity as RoomMessage, Model as RoomMessageModel, +}; +use sea_orm::ColumnTrait; +use sea_orm::{EntityTrait, QueryFilter, QueryOrder, QuerySelect}; + +use crate::compact::types::CompactLevel; +use crate::tokent::resolve_usage; +use crate::{AgentError, CompactSummary, MessageSummary}; + +impl super::CompactService { + pub async fn compact_room( + &self, + room_id: uuid::Uuid, + level: CompactLevel, + user_names: Option>, + requester_id: uuid::Uuid, + context_window_tokens: i32, + compaction_max_summary_ratio: f32, + ) -> Result { + let messages = self + .fetch_room_messages_secure(room_id, requester_id) + .await?; + + if messages.is_empty() { + let room_exists = models::rooms::room::Entity::find_by_id(room_id) + .one(&self.db) + .await + .map_err(|e| AgentError::Internal(e.to_string()))? + .is_some(); + + if room_exists { + return Err(AgentError::Internal("Access denied or room empty".into())); + } else { + return Err(AgentError::Internal("Room not found".into())); + } + } + + let user_ids: Vec = messages + .iter() + .filter_map(|m| m.sender_id) + .collect::>() + .into_iter() + .collect(); + let user_name_map = match user_names { + Some(map) => map, + None => self.get_user_name_map(&user_ids).await?, + }; + + if messages.len() <= level.retain_count() { + let retained: Vec = messages + .iter() + .map(|m| Self::message_to_summary(m, &user_name_map)) + .collect(); + return Ok(CompactSummary { + session_id: uuid::Uuid::new_v4(), + room_id, + retained, + summary: String::new(), + compacted_at: chrono::Utc::now(), + messages_compressed: 0, + usage: None, + }); + } + + let retain_count = level.retain_count(); + let split_index = messages.len().saturating_sub(retain_count); + let (to_summarize, retained_messages) = messages.split_at(split_index); + + let retained: Vec = retained_messages + .iter() + .map(|m| Self::message_to_summary(m, &user_name_map)) + .collect(); + + let max_summary_tokens = + (context_window_tokens as f32 * compaction_max_summary_ratio) as usize; + + let (summary, remote_usage) = self + .summarize_messages(to_summarize, max_summary_tokens) + .await?; + + let summarized_text = to_summarize + .iter() + .map(|m| m.content.as_str()) + .collect::>() + .join("\n"); + let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary); + + Ok(CompactSummary { + session_id: uuid::Uuid::new_v4(), + room_id, + retained, + summary, + compacted_at: chrono::Utc::now(), + messages_compressed: to_summarize.len(), + usage: Some(usage), + }) + } + + pub async fn compact_session( + &self, + session_id: uuid::Uuid, + level: CompactLevel, + user_names: Option>, + context_window_tokens: i32, + compaction_max_summary_ratio: f32, + ) -> Result { + let messages: Vec = RoomMessage::find() + .filter(RmCol::Room.eq(session_id)) + .order_by_asc(RmCol::Seq) + .limit(10000) + .all(&self.db) + .await + .map_err(|e| AgentError::Internal(e.to_string()))?; + + if messages.is_empty() { + return Err(AgentError::Internal("session has no messages".into())); + } + + let user_ids: Vec = messages + .iter() + .filter_map(|m| m.sender_id) + .collect::>() + .into_iter() + .collect(); + let user_name_map = match user_names { + Some(map) => map, + None => self.get_user_name_map(&user_ids).await?, + }; + + if messages.len() <= level.retain_count() { + let retained: Vec = messages + .iter() + .map(|m| Self::message_to_summary(m, &user_name_map)) + .collect(); + return Ok(CompactSummary { + session_id, + room_id: uuid::Uuid::nil(), + retained, + summary: String::new(), + compacted_at: chrono::Utc::now(), + messages_compressed: 0, + usage: None, + }); + } + + let retain_count = level.retain_count(); + let split_index = messages.len().saturating_sub(retain_count); + let (to_summarize, retained_messages) = messages.split_at(split_index); + + let retained: Vec = retained_messages + .iter() + .map(|m| Self::message_to_summary(m, &user_name_map)) + .collect(); + + let max_summary_tokens = + (context_window_tokens as f32 * compaction_max_summary_ratio) as usize; + + let (summary, remote_usage) = self + .summarize_messages(to_summarize, max_summary_tokens) + .await?; + + let summarized_text = to_summarize + .iter() + .map(|m| m.content.as_str()) + .collect::>() + .join("\n"); + let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary); + + Ok(CompactSummary { + session_id, + room_id: uuid::Uuid::nil(), + retained, + summary, + compacted_at: chrono::Utc::now(), + messages_compressed: to_summarize.len(), + usage: Some(usage), + }) + } +} diff --git a/libs/agent/compact/service.rs b/libs/agent/compact/service.rs deleted file mode 100644 index ee74b9c..0000000 --- a/libs/agent/compact/service.rs +++ /dev/null @@ -1,327 +0,0 @@ -use chrono::Utc; -use models::ColumnTrait; -use models::rooms::room_message::{ - Column as RmCol, Entity as RoomMessage, Model as RoomMessageModel, -}; -use models::users::user::{Column as UserCol, Entity as User}; -use sea_orm::{DatabaseConnection, EntityTrait, QueryFilter, QueryOrder, QuerySelect}; -use uuid::Uuid; - -use crate::client::types::ChatRequestMessage; -use crate::client::AiClientConfig; -use crate::client::call_with_params; -use crate::AgentError; -use crate::compact::types::{CompactLevel, CompactSummary, MessageSummary}; -use crate::tokent::{TokenUsage, resolve_usage}; - -#[derive(Clone)] -pub struct CompactService { - db: DatabaseConnection, - ai_client_config: AiClientConfig, - model: String, -} - -impl CompactService { - pub fn new(db: DatabaseConnection, ai_client_config: AiClientConfig, model: String) -> Self { - Self { db, ai_client_config, model } - } - - pub async fn compact_room( - &self, - room_id: Uuid, - level: CompactLevel, - user_names: Option>, - requester_id: Uuid, - context_window_tokens: i32, - compaction_max_summary_ratio: f32, - ) -> Result { - // Verify room access at the database level to ensure auth context is enforced. - // Public rooms are accessible to project members. - // For simplicity in this audit fix, we'll fetch only if access exists. - let messages = self.fetch_room_messages_secure(room_id, requester_id).await?; - - if messages.is_empty() { - // Check if room actually exists or if it's just empty/inaccessible - let room_exists = models::rooms::room::Entity::find_by_id(room_id) - .one(&self.db) - .await - .map_err(|e| AgentError::Internal(e.to_string()))? - .is_some(); - - if room_exists { - return Err(AgentError::Internal("Access denied or room empty".into())); - } else { - return Err(AgentError::Internal("Room not found".into())); - } - } - - let user_ids: Vec = messages - .iter() - .filter_map(|m| m.sender_id) - .collect::>() - .into_iter() - .collect(); - let user_name_map = match user_names { - Some(map) => map, - None => self.get_user_name_map(&user_ids).await?, - }; - - if messages.len() <= level.retain_count() { - let retained: Vec = messages - .iter() - .map(|m| Self::message_to_summary(m, &user_name_map)) - .collect(); - return Ok(CompactSummary { - session_id: Uuid::new_v4(), - room_id, - retained, - summary: String::new(), - compacted_at: Utc::now(), - messages_compressed: 0, - usage: None, - }); - } - - let retain_count = level.retain_count(); - let split_index = messages.len().saturating_sub(retain_count); - let (to_summarize, retained_messages) = messages.split_at(split_index); - - let retained: Vec = retained_messages - .iter() - .map(|m| Self::message_to_summary(m, &user_name_map)) - .collect(); - - let max_summary_tokens = (context_window_tokens as f32 * compaction_max_summary_ratio) as usize; - - let (summary, remote_usage) = self.summarize_messages(to_summarize, max_summary_tokens).await?; - - // Build text of what was summarized (for tiktoken fallback) - let summarized_text = to_summarize - .iter() - .map(|m| m.content.as_str()) - .collect::>() - .join("\n"); - let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary); - - Ok(CompactSummary { - session_id: Uuid::new_v4(), - room_id, - retained, - summary, - compacted_at: Utc::now(), - messages_compressed: to_summarize.len(), - usage: Some(usage), - }) - } - - pub async fn compact_session( - &self, - session_id: Uuid, - level: CompactLevel, - user_names: Option>, - context_window_tokens: i32, - compaction_max_summary_ratio: f32, - ) -> Result { - let messages: Vec = RoomMessage::find() - .filter(RmCol::Room.eq(session_id)) - .order_by_asc(RmCol::Seq) - .limit(10000) - .all(&self.db) - .await - .map_err(|e| AgentError::Internal(e.to_string()))?; - - if messages.is_empty() { - return Err(AgentError::Internal("session has no messages".into())); - } - - let user_ids: Vec = messages - .iter() - .filter_map(|m| m.sender_id) - .collect::>() - .into_iter() - .collect(); - let user_name_map = match user_names { - Some(map) => map, - None => self.get_user_name_map(&user_ids).await?, - }; - - if messages.len() <= level.retain_count() { - let retained: Vec = messages - .iter() - .map(|m| Self::message_to_summary(m, &user_name_map)) - .collect(); - return Ok(CompactSummary { - session_id, - room_id: Uuid::nil(), - retained, - summary: String::new(), - compacted_at: Utc::now(), - messages_compressed: 0, - usage: None, - }); - } - - let retain_count = level.retain_count(); - let split_index = messages.len().saturating_sub(retain_count); - let (to_summarize, retained_messages) = messages.split_at(split_index); - - let retained: Vec = retained_messages - .iter() - .map(|m| Self::message_to_summary(m, &user_name_map)) - .collect(); - - let max_summary_tokens = (context_window_tokens as f32 * compaction_max_summary_ratio) as usize; - - let (summary, remote_usage) = self.summarize_messages(to_summarize, max_summary_tokens).await?; - - let summarized_text = to_summarize - .iter() - .map(|m| m.content.as_str()) - .collect::>() - .join("\n"); - let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary); - - Ok(CompactSummary { - session_id, - room_id: Uuid::nil(), - retained, - summary, - compacted_at: Utc::now(), - messages_compressed: to_summarize.len(), - usage: Some(usage), - }) - } - - async fn fetch_room_messages_secure( - &self, - room_id: Uuid, - requester_id: Uuid, - ) -> Result, AgentError> { - use models::rooms::{RoomUserState, RoomAccess}; - use sea_orm::QueryTrait; - use sea_orm::sea_query::Expr; - - // Find messages for the room where the requester has access. - // We check both the room_user_state table (membership) and the room_access table (explicit grants). - RoomMessage::find() - .filter(RmCol::Room.eq(room_id)) - .filter( - sea_orm::Condition::any() - .add( - Expr::exists( - RoomUserState::find() - .filter(models::rooms::room_user_state::Column::Room.eq(room_id)) - .filter(models::rooms::room_user_state::Column::User.eq(requester_id)) - .into_query() - ) - ) - .add( - Expr::exists( - RoomAccess::find() - .filter(models::rooms::room_access::Column::Room.eq(room_id)) - .filter(models::rooms::room_access::Column::User.eq(requester_id)) - .into_query() - ) - ) - ) - .order_by_asc(RmCol::Seq) - .limit(10000) - .all(&self.db) - .await - .map_err(|e| AgentError::Internal(e.to_string())) - } - - fn message_to_summary(m: &RoomMessageModel, user_name_map: &std::collections::HashMap) -> MessageSummary { - let sender_name = if let Some(user_id) = m.sender_id { - user_name_map.get(&user_id).cloned().unwrap_or_else(|| m.sender_type.to_string()) - } else { - m.sender_type.to_string() - }; - MessageSummary { - id: m.id, - sender_type: m.sender_type.clone(), - sender_id: m.sender_id, - sender_name, - content: m.content.clone(), - content_type: m.content_type.clone(), - tool_call_id: None, - send_at: m.send_at, - } - } - - async fn get_user_name_map( - &self, - user_ids: &[Uuid], - ) -> Result, AgentError> { - use std::collections::HashMap; - let mut map = HashMap::new(); - if !user_ids.is_empty() { - let users = User::find() - .filter(UserCol::Uid.is_in(user_ids.to_vec())) - .all(&self.db) - .await - .map_err(|e| AgentError::Internal(e.to_string()))?; - for user in users { - map.insert(user.uid, user.username); - } - } - Ok(map) - } - - async fn summarize_messages( - &self, - messages: &[RoomMessageModel], - max_summary_tokens: usize, - ) -> Result<(String, Option), AgentError> { - let user_ids: Vec = messages - .iter() - .filter_map(|m| m.sender_id) - .collect::>() - .into_iter() - .collect(); - - let user_name_map = self.get_user_name_map(&user_ids).await?; - - let sender_mapper = |m: &RoomMessageModel| { - if let Some(user_id) = m.sender_id { - if let Some(username) = user_name_map.get(&user_id) { - return username.clone(); - } - } - m.sender_type.to_string() - }; - - let body = crate::compact::helpers::messages_to_text(messages, sender_mapper); - - let user_msg = ChatRequestMessage::user(format!( - "Summarise the following conversation concisely, preserving all key facts, \ - decisions, and any pending or in-progress work. \ - The summary MUST NOT exceed {} tokens. \ - Use this format:\n\n\ - **Summary:** \n\ - **Key decisions:** \n\ - **Open items:** \n\n\ - Conversation:\n\n{}", - max_summary_tokens, - body - )); - - let response = call_with_params( - &[user_msg], - &self.model, - &self.ai_client_config, - 0.3, - 2048, - None, - None, - None, - ) - .await - .map_err(|e| AgentError::OpenAi(e.to_string()))?; - - let remote_usage = - TokenUsage::from_remote(response.input_tokens as u32, response.output_tokens as u32); - - Ok((response.content, remote_usage)) - } -} diff --git a/libs/agent/compact/summarizer.rs b/libs/agent/compact/summarizer.rs new file mode 100644 index 0000000..c82ecf3 --- /dev/null +++ b/libs/agent/compact/summarizer.rs @@ -0,0 +1,110 @@ +use models::rooms::room_message::Model as RoomMessageModel; +use models::users::user::{Column as UserCol, Entity as User}; +use sea_orm::{ColumnTrait, EntityTrait, QueryFilter}; + +use crate::client::call_with_params; +use crate::client::types::ChatRequestMessage; +use crate::compact::types::MessageSummary; +use crate::tokent::TokenUsage; +use crate::AgentError; + +impl super::CompactService { + pub async fn summarize_messages( + &self, + messages: &[RoomMessageModel], + max_summary_tokens: usize, + ) -> Result<(String, Option), AgentError> { + let user_ids: Vec = messages + .iter() + .filter_map(|m| m.sender_id) + .collect::>() + .into_iter() + .collect(); + + let user_name_map = self.get_user_name_map(&user_ids).await?; + + let sender_mapper = |m: &RoomMessageModel| { + if let Some(user_id) = m.sender_id { + if let Some(username) = user_name_map.get(&user_id) { + return username.clone(); + } + } + m.sender_type.to_string() + }; + + let body = crate::compact::helpers::messages_to_text(messages, sender_mapper); + + let user_msg = ChatRequestMessage::user(format!( + "Summarise the following conversation concisely, preserving all key facts, \ + decisions, and any pending or in-progress work. \ + The summary MUST NOT exceed {} tokens. \ + Use this format:\n\n\ + **Summary:** \n\ + **Key decisions:** \n\ + **Open items:** \n\n\ + Conversation:\n\n{}", + max_summary_tokens, body + )); + + let response = call_with_params( + &[user_msg], + &self.model, + &self.ai_client_config, + 0.3, + 2048, + None, + None, + None, + ) + .await + .map_err(|e| AgentError::OpenAi(e.to_string()))?; + + let remote_usage = + TokenUsage::from_remote(response.input_tokens as u32, response.output_tokens as u32); + + Ok((response.content, remote_usage)) + } + + pub fn message_to_summary( + m: &RoomMessageModel, + user_name_map: &std::collections::HashMap, + ) -> MessageSummary { + let sender_name = if let Some(user_id) = m.sender_id { + user_name_map + .get(&user_id) + .cloned() + .unwrap_or_else(|| m.sender_type.to_string()) + } else { + m.sender_type.to_string() + }; + MessageSummary { + id: m.id, + sender_type: m.sender_type.clone(), + sender_id: m.sender_id, + sender_name, + content: m.content.clone(), + content_type: m.content_type.clone(), + tool_call_id: None, + send_at: m.send_at, + } + } + + pub async fn get_user_name_map( + &self, + user_ids: &[uuid::Uuid], + ) -> Result, AgentError> { + use std::collections::HashMap; + let mut map = HashMap::new(); + if !user_ids.is_empty() { + let users = User::find() + .filter(UserCol::Uid.is_in(user_ids.to_vec())) + .all(&self.db) + .await + .map_err(|e| AgentError::Internal(e.to_string()))?; + for user in users { + map.insert(user.uid, user.username); + } + } + Ok(map) + } +} diff --git a/libs/agent/embed/chunk.rs b/libs/agent/embed/chunk.rs new file mode 100644 index 0000000..93bab06 --- /dev/null +++ b/libs/agent/embed/chunk.rs @@ -0,0 +1,61 @@ +/// Maximum characters per chunk for embedding (approximates token limit). +/// text-embedding-3-small: 8192 token limit. +/// For CJK ~1 char/token, for English ~4 chars/token. +/// Conservative limit: 7000 chars to leave room for all languages. +const MAX_CHUNK_CHARS: usize = 7000; + +/// Split long text into chunks at paragraph/sentence boundaries. +/// Returns at least one chunk even for empty text. +/// Safe for multi-byte characters (uses char indices, not byte indices). +pub fn chunk_text(text: &str) -> Vec { + if text.is_empty() { + return vec![String::new()]; + } + if text.len() <= MAX_CHUNK_CHARS { + return vec![text.to_string()]; + } + + let char_indices: Vec = text.char_indices().map(|(i, _)| i).collect(); + let total_chars = char_indices.len(); + + let mut chunks = Vec::new(); + let mut start_idx = 0; + + while start_idx < total_chars { + let byte_start = char_indices[start_idx]; + let end_char_idx = (start_idx + MAX_CHUNK_CHARS).min(total_chars); + let byte_end_candidate = char_indices[end_char_idx - 1] + + text[char_indices[end_char_idx - 1]..] + .chars() + .next() + .map(|c| c.len_utf8()) + .unwrap_or(1); + + if end_char_idx >= total_chars { + chunks.push(text[byte_start..].to_string()); + break; + } + + let search_range = &text[byte_start..byte_end_candidate]; + let break_at = search_range.rfind("\n\n").map(|pos| pos + 2) + .or_else(|| search_range.rfind('\n').map(|pos| pos + 1)) + .or_else(|| search_range.rfind(". ").map(|pos| pos + 1)) + .or_else(|| search_range.rfind("! ").map(|pos| pos + 1)) + .or_else(|| search_range.rfind("? ").map(|pos| pos + 1)); + + if let Some(offset) = break_at { + let byte_end = byte_start + offset; + chunks.push(text[byte_start..byte_end].to_string()); + let mut advance = start_idx + 1; + while advance < total_chars && char_indices[advance] < byte_end { + advance += 1; + } + start_idx = advance; + } else { + chunks.push(text[byte_start..byte_end_candidate].to_string()); + start_idx = end_char_idx; + } + } + + chunks +} \ No newline at end of file diff --git a/libs/agent/embed/embeddable.rs b/libs/agent/embed/embeddable.rs new file mode 100644 index 0000000..da16419 --- /dev/null +++ b/libs/agent/embed/embeddable.rs @@ -0,0 +1,23 @@ +use async_trait::async_trait; + +/// Trait for entities that can be embedded as vectors into Qdrant. +#[async_trait] +pub trait Embeddable { + fn entity_type(&self) -> &'static str; + fn to_text(&self) -> String; + fn entity_id(&self) -> String; +} + +/// Input struct for batch memory embedding into per-room Qdrant collections. +#[derive(Debug, Clone)] +pub struct EmbedMemoryInput { + pub message_id: String, + pub content: String, + pub project_name: String, + pub room_id: String, + pub user_id: Option, + pub sender_type: String, +} + +/// Input struct for batch tag embedding. +pub use models::TagEmbedInput; \ No newline at end of file diff --git a/libs/agent/embed/service.rs b/libs/agent/embed/entity_embed.rs similarity index 53% rename from libs/agent/embed/service.rs rename to libs/agent/embed/entity_embed.rs index add80b2..122443d 100644 --- a/libs/agent/embed/service.rs +++ b/libs/agent/embed/entity_embed.rs @@ -1,112 +1,11 @@ -use async_trait::async_trait; -use qdrant_client::qdrant::Filter; -use sea_orm::DatabaseConnection; -use std::sync::Arc; +use std::collections::HashMap; -use super::client::{EmbedClient, EmbedPayload, EmbedVector, SearchResult}; - -/// Maximum characters per chunk for embedding (approximates token limit). -/// text-embedding-3-small: 8192 token limit. -/// For CJK ~1 char/token, for English ~4 chars/token. -/// Conservative limit: 7000 chars to leave room for all languages. -const MAX_CHUNK_CHARS: usize = 7000; - -#[async_trait] -pub trait Embeddable { - fn entity_type(&self) -> &'static str; - fn to_text(&self) -> String; - fn entity_id(&self) -> String; -} - -/// Split long text into chunks at paragraph/sentence boundaries. -/// Returns at least one chunk even for empty text. -/// Safe for multi-byte characters (uses char indices, not byte indices). -fn chunk_text(text: &str) -> Vec { - if text.is_empty() { - return vec![String::new()]; - } - if text.len() <= MAX_CHUNK_CHARS { - return vec![text.to_string()]; - } - - // Collect char boundary byte positions - let char_indices: Vec = text.char_indices().map(|(i, _)| i).collect(); - let total_chars = char_indices.len(); - - let mut chunks = Vec::new(); - let mut start_idx = 0; // char index - - while start_idx < total_chars { - // Start byte offset - let byte_start = char_indices[start_idx]; - - // Find end char index: at most MAX_CHUNK_CHARS characters - let end_char_idx = (start_idx + MAX_CHUNK_CHARS).min(total_chars); - let byte_end_candidate = char_indices[end_char_idx - 1] + text[char_indices[end_char_idx - 1]..].chars().next().map(|c| c.len_utf8()).unwrap_or(1); - - if end_char_idx >= total_chars { - chunks.push(text[byte_start..].to_string()); - break; - } - - // Try to break at paragraph or sentence boundary in the allowed range - let search_range = &text[byte_start..byte_end_candidate]; - let break_at = if let Some(pos) = search_range.rfind("\n\n") { - Some(pos + 2) // after the paragraph break - } else if let Some(pos) = search_range.rfind('\n') { - Some(pos + 1) - } else if let Some(pos) = search_range.rfind(". ") { - Some(pos + 1) - } else if let Some(pos) = search_range.rfind("! ") { - Some(pos + 1) - } else if let Some(pos) = search_range.rfind("? ") { - Some(pos + 1) - } else { - None - }; - - if let Some(offset) = break_at { - let byte_end = byte_start + offset; - chunks.push(text[byte_start..byte_end].to_string()); - // Advance char index to match the byte break - let mut advance = start_idx + 1; - while advance < total_chars && char_indices[advance] < byte_end { - advance += 1; - } - start_idx = advance; - } else { - // Hard break at char boundary - chunks.push(text[byte_start..byte_end_candidate].to_string()); - start_idx = end_char_idx; - } - } - - chunks -} - -#[derive(Clone)] -pub struct EmbedService { - client: Arc, - db: DatabaseConnection, - model_name: String, - dimensions: u64, -} - -impl EmbedService { - pub fn new( - client: EmbedClient, - db: DatabaseConnection, - model_name: String, - dimensions: u64, - ) -> Self { - Self { - client: Arc::new(client), - db, - model_name, - dimensions, - } - } +use super::chunk::chunk_text; +use super::client::{EmbedPayload, EmbedVector}; +use super::embeddable::{EmbedMemoryInput, Embeddable}; +/// Embedding and upsert operations for entity vectors in Qdrant. +impl super::EmbedService { pub async fn embed_issue( &self, id: &str, @@ -203,69 +102,6 @@ impl EmbedService { Ok(()) } - pub async fn search_issues( - &self, - query: &str, - limit: usize, - ) -> crate::Result> { - self.client - .search(query, "issue", &self.model_name, limit) - .await - } - - pub async fn search_repos( - &self, - query: &str, - limit: usize, - ) -> crate::Result> { - self.client - .search(query, "repo", &self.model_name, limit) - .await - } - - pub async fn search_issues_filtered( - &self, - query: &str, - limit: usize, - filter: Filter, - ) -> crate::Result> { - self.client - .search_with_filter(query, "issue", &self.model_name, limit, filter) - .await - } - - pub async fn delete_issue_embedding(&self, issue_id: &str) -> crate::Result<()> { - self.client.delete_by_entity_id("issue", issue_id).await - } - - pub async fn delete_repo_embedding(&self, repo_id: &str) -> crate::Result<()> { - self.client.delete_by_entity_id("repo", repo_id).await - } - - pub async fn ensure_collections(&self) -> crate::Result<()> { - self.client - .ensure_collection("issue", self.dimensions) - .await?; - self.client - .ensure_collection("repo", self.dimensions) - .await?; - self.client.ensure_skill_collection(self.dimensions).await?; - self.client - .ensure_collection("repo_tag", self.dimensions) - .await?; - // Room memory collections are created per-room on first embed - Ok(()) - } - - pub fn db(&self) -> &DatabaseConnection { - &self.db - } - - pub fn client(&self) -> &Arc { - &self.client - } - - /// Embed a project skill into Qdrant for vector-based semantic search. pub async fn embed_skill( &self, skill_id: i64, @@ -279,7 +115,6 @@ impl EmbedService { tracing::debug!(skill_id = %skill_id, name = %name, content_len = content.len(), "embed_skill: starting"); - // Auto-chunk long content let texts = chunk_text(content); tracing::debug!(skill_id = %skill_id, chunks = texts.len(), "embed_skill: chunked"); @@ -288,13 +123,17 @@ impl EmbedService { .embed_skill(&id, name, desc, content, project_uuid, &self.model_name) .await?; } else { - // Multi-chunk: embed each chunk with chunk_index metadata - let full_texts: Vec = texts.iter().map(|t| format!("{}: {} {}", name, desc, t)).collect(); + let full_texts: Vec = texts + .iter() + .map(|t| format!("{}: {} {}", name, desc, t)) + .collect(); tracing::debug!(skill_id = %skill_id, "embed_skill: calling embed_batch"); let embeddings = self.client.embed_batch(&full_texts, &self.model_name).await?; - let points: Vec = embeddings.into_iter().enumerate().map(|(i, vector)| { - EmbedVector { + let points: Vec = embeddings + .into_iter() + .enumerate() + .map(|(i, vector)| EmbedVector { id: format!("{}:chunk:{}", id, i), vector, payload: EmbedPayload { @@ -306,10 +145,11 @@ impl EmbedService { "description": desc, "chunk_index": i, "total_chunks": texts.len(), - }).into(), + }) + .into(), }, - } - }).collect(); + }) + .collect(); self.client.upsert(points).await?; } @@ -317,7 +157,6 @@ impl EmbedService { Ok(()) } - /// Embed an issue with auto-chunking for long content. pub async fn embed_issue_chunked( &self, id: &str, @@ -336,8 +175,10 @@ impl EmbedService { let embeddings = self.client.embed_batch(&chunks, &self.model_name).await?; - let points: Vec = embeddings.into_iter().enumerate().map(|(i, vector)| { - EmbedVector { + let points: Vec = embeddings + .into_iter() + .enumerate() + .map(|(i, vector)| EmbedVector { id: format!("{}:chunk:{}", id, i), vector, payload: EmbedPayload { @@ -347,17 +188,15 @@ impl EmbedService { extra: serde_json::json!({ "chunk_index": i, "total_chunks": chunks.len(), - }).into(), + }) + .into(), }, - } - }).collect(); + }) + .collect(); self.client.upsert(points).await } - /// Batch-embed multiple conversation messages into per-room Qdrant collections. - /// Auto-chunks long messages and filters non-text/system/empty content. - /// Handles all filtering internally: only text-type, non-empty, non-system messages are embedded. pub async fn embed_memories_batch( &self, messages: Vec, @@ -366,8 +205,6 @@ impl EmbedService { return Ok(()); } - // Group by room collection for batch upsert to reduce Qdrant calls - use std::collections::HashMap; let mut by_room: HashMap)>> = HashMap::new(); for msg in messages { @@ -375,15 +212,15 @@ impl EmbedService { if chunks.is_empty() || chunks.iter().all(|c| c.trim().is_empty()) { continue; } - let collection = crate::embed::qdrant::QdrantClient::room_memory_collection_name( + let collection = super::qdrant::QdrantClient::room_memory_collection_name( &msg.project_name, &msg.room_id, ); by_room.entry(collection).or_default().push((msg, chunks)); } for (collection, entries) in &by_room { - // Collect all texts for batch embedding - let all_texts: Vec = entries.iter() + let all_texts: Vec = entries + .iter() .flat_map(|(_, chunks)| chunks.iter().cloned()) .collect(); @@ -393,14 +230,12 @@ impl EmbedService { let embeddings = self.client.embed_batch(&all_texts, &self.model_name).await?; - // Ensure the room collection exists with correct dimensions if let Some((first, _)) = entries.first() { let _ = self.client .ensure_room_memory_collection(&first.project_name, &first.room_id, self.dimensions) .await; } - // Build points: one per chunk let mut points = Vec::new(); let mut embed_idx = 0; for (msg, chunks) in entries { @@ -423,9 +258,18 @@ impl EmbedService { extra: serde_json::json!({ "user_id": msg.user_id, "sender_type": msg.sender_type, - "chunk_index": if chunks.len() > 1 { Some(chunk_i) } else { None }, - "total_chunks": if chunks.len() > 1 { Some(chunks.len()) } else { None }, - }).into(), + "chunk_index": if chunks.len() > 1 { + Some(chunk_i) + } else { + None + }, + "total_chunks": if chunks.len() > 1 { + Some(chunks.len()) + } else { + None + }, + }) + .into(), }, }); embed_idx += 1; @@ -440,11 +284,9 @@ impl EmbedService { Ok(()) } - /// Batch-embed repo tags with project isolation. - /// Each tag stores project_id as entity_id for post-filtering. pub async fn embed_tags_batch( &self, - tags: Vec, + tags: Vec, ) -> crate::Result<()> { if tags.is_empty() { return Ok(()); @@ -494,48 +336,6 @@ impl EmbedService { self.client.upsert(points).await } - /// Search repo tags by semantic similarity within a project. - /// Filters by project_id (stored in entity_id) for project isolation. - pub async fn search_tags( - &self, - query: &str, - project_id: &str, - limit: usize, - ) -> crate::Result> { - let mut results = self - .client - .search(query, "repo_tag", &self.model_name, limit + 1) - .await?; - results.retain(|r| r.payload.entity_id == project_id); - results.truncate(limit); - Ok(results) - } - - pub fn model_name(&self) -> &str { - &self.model_name - } - - pub fn dimensions(&self) -> u64 { - self.dimensions - } - - pub fn embed_client(&self) -> &EmbedClient { - &self.client - } - - /// Search skills by semantic similarity within a project. - pub async fn search_skills( - &self, - query: &str, - project_uuid: &str, - limit: usize, - ) -> crate::Result> { - self.client - .search_skills(query, &self.model_name, project_uuid, limit) - .await - } - - /// Embed a conversation message into Qdrant as a memory vector. pub async fn embed_memory( &self, message_id: &str, @@ -548,32 +348,4 @@ impl EmbedService { .embed_memory(message_id, text, project_name, room_id, user_id, &self.model_name) .await } - - /// Search past conversation messages by semantic similarity within a room. - pub async fn search_memories( - &self, - query: &str, - project_name: &str, - room_id: &str, - limit: usize, - ) -> crate::Result> { - self.client - .search_memories(query, &self.model_name, project_name, room_id, limit, self.dimensions) - .await - } -} - -/// Input struct for batch memory embedding into per-room Qdrant collections. -#[derive(Debug, Clone)] -pub struct EmbedMemoryInput { - pub message_id: String, - pub content: String, - pub project_name: String, - pub room_id: String, - pub user_id: Option, - pub sender_type: String, -} - -/// Input struct for batch tag embedding. -/// Re-exported from models for backward compatibility. -pub use models::TagEmbedInput; +} \ No newline at end of file diff --git a/libs/agent/embed/mod.rs b/libs/agent/embed/mod.rs index e074961..37541cb 100644 --- a/libs/agent/embed/mod.rs +++ b/libs/agent/embed/mod.rs @@ -1,10 +1,69 @@ +pub mod chunk; pub mod client; +pub mod embeddable; +pub mod entity_embed; pub mod qdrant; -pub mod service; +pub mod search; pub use client::{EmbedClient, EmbedPayload, EmbedVector, SearchResult}; +pub use embeddable::{EmbedMemoryInput, Embeddable, TagEmbedInput}; pub use qdrant::QdrantClient; -pub use service::{EmbedMemoryInput, EmbedService, Embeddable, TagEmbedInput}; + +use std::sync::Arc; + +#[derive(Clone)] +pub struct EmbedService { + client: Arc, + db: sea_orm::DatabaseConnection, + model_name: String, + dimensions: u64, +} + +impl EmbedService { + pub fn new( + client: EmbedClient, + db: sea_orm::DatabaseConnection, + model_name: String, + dimensions: u64, + ) -> Self { + Self { + client: Arc::new(client), + db, + model_name, + dimensions, + } + } + + pub async fn ensure_collections(&self) -> crate::Result<()> { + self.client + .ensure_collection("issue", self.dimensions) + .await?; + self.client + .ensure_collection("repo", self.dimensions) + .await?; + self.client.ensure_skill_collection(self.dimensions).await?; + self.client + .ensure_collection("repo_tag", self.dimensions) + .await?; + Ok(()) + } + + pub fn db(&self) -> &sea_orm::DatabaseConnection { + &self.db + } + + pub fn client(&self) -> &Arc { + &self.client + } + + pub fn model_name(&self) -> &str { + &self.model_name + } + + pub fn dimensions(&self) -> u64 { + self.dimensions + } +} pub async fn new_embed_client(config: &config::AppConfig) -> crate::Result { let base_url = config @@ -22,7 +81,9 @@ pub async fn new_embed_client(config: &config::AppConfig) -> crate::Result crate::Result> { + self.client + .search(query, "issue", &self.model_name, limit) + .await + } + + pub async fn search_repos( + &self, + query: &str, + limit: usize, + ) -> crate::Result> { + self.client + .search(query, "repo", &self.model_name, limit) + .await + } + + pub async fn search_issues_filtered( + &self, + query: &str, + limit: usize, + filter: Filter, + ) -> crate::Result> { + self.client + .search_with_filter(query, "issue", &self.model_name, limit, filter) + .await + } + + /// Search repo tags by semantic similarity within a project. + /// Filters by project_id (stored in entity_id) for project isolation. + pub async fn search_tags( + &self, + query: &str, + project_id: &str, + limit: usize, + ) -> crate::Result> { + let mut results = self + .client + .search(query, "repo_tag", &self.model_name, limit + 1) + .await?; + results.retain(|r| r.payload.entity_id == project_id); + results.truncate(limit); + Ok(results) + } + + /// Search skills by semantic similarity within a project. + pub async fn search_skills( + &self, + query: &str, + project_uuid: &str, + limit: usize, + ) -> crate::Result> { + self.client + .search_skills(query, &self.model_name, project_uuid, limit) + .await + } + + /// Search past conversation messages by semantic similarity within a room. + pub async fn search_memories( + &self, + query: &str, + project_name: &str, + room_id: &str, + limit: usize, + ) -> crate::Result> { + self.client + .search_memories(query, &self.model_name, project_name, room_id, limit, self.dimensions) + .await + } +} \ No newline at end of file diff --git a/libs/agent/lib.rs b/libs/agent/lib.rs index eec1127..a6bca8c 100644 --- a/libs/agent/lib.rs +++ b/libs/agent/lib.rs @@ -6,6 +6,7 @@ pub mod compact; pub mod embed; pub mod error; pub mod model; +pub mod orao; pub mod perception; pub mod react; pub mod skills; @@ -13,33 +14,42 @@ pub mod sync; pub mod task; pub mod tokent; pub mod tool; -pub use billing::{BillingRecord, BillingResult, record_ai_usage, initialize_user_billing, initialize_project_billing, check_balance, persist_billing_error}; -pub use sync::list_accessible_models; -pub use task::TaskService; -pub use tokent::{TokenUsage, resolve_usage}; -pub use perception::{PerceptionService, SkillContext, SkillEntry, ToolCallEvent}; -pub use skills::{ - BuiltInSkill, SKILL_TEMPLATES, all_skill_slugs, all_skills, - get_skill, get_skill_by_tool, is_built_in_skill, match_skill_by_keyword, skills_by_category, +pub use billing::{ + check_balance, initialize_project_billing, initialize_user_billing, persist_billing_error, + record_ai_usage, BillingRecord, BillingResult, }; pub use chat::{ AiContextSenderType, AiRequest, AiStreamChunk, ChatService, Mention, RoomMessageContext, StreamCallback, }; -pub use client::{AiCallResponse, AiClientConfig, call_with_params, call_with_retry}; pub use client::types::ChatRequestMessage; +pub use client::{call_with_params, call_with_retry, AiCallResponse, AiClientConfig}; pub use compact::{CompactConfig, CompactLevel, CompactService, CompactSummary, MessageSummary}; pub use embed::{ - EmbedClient, EmbedMemoryInput, EmbedService, QdrantClient, SearchResult, TagEmbedInput, new_embed_client, + new_embed_client, EmbedClient, EmbedMemoryInput, EmbedService, QdrantClient, SearchResult, + TagEmbedInput, }; pub use error::{AgentError, Result}; +pub use orao::{ + ActionExecutor, ActionType, ActionResult, ActionVerdict, OraoConfig, OraoExecutor, + OraoExecutorBuilder, OraoOutcome, OraoStep, PerceptionSnapshot, PlannedAction, + ReasoningOutput, RoundRecord, SafetyLevel, +}; +pub use perception::{PerceptionService, SkillContext, SkillEntry, ToolCallEvent}; pub use react::{ReactConfig, ReactStep, DEFAULT_SYSTEM_PROMPT, ROOM_CONTEXT_PROMPT}; +pub use skills::{ + all_skill_slugs, all_skills, get_skill, get_skill_by_tool, is_built_in_skill, match_skill_by_keyword, + skills_by_category, BuiltInSkill, SKILL_TEMPLATES, +}; +pub use sync::list_accessible_models; +pub use task::TaskService; +pub use tokent::{resolve_usage, TokenUsage}; pub use tool::{ - ToolCall, ToolCallRecord, ToolCallRecorder, ToolCallResult, ToolContext, ToolDefinition, ToolError, ToolExecutor, ToolHandler, ToolParam, - ToolRegistry, ToolResult, ToolSchema, + ToolCall, ToolCallRecord, ToolCallRecorder, ToolCallResult, ToolContext, ToolDefinition, + ToolError, ToolExecutor, ToolHandler, ToolParam, ToolRegistry, ToolResult, ToolSchema, }; #[cfg(feature = "rig")] pub use agent::RigAgentService; #[cfg(feature = "rig")] -pub use tool::{RigToolSet, RecordingTool, is_retryable_tool_error}; +pub use tool::{is_retryable_tool_error, RecordingTool, RigToolSet}; diff --git a/libs/agent/modes/mod.rs b/libs/agent/modes/mod.rs deleted file mode 100644 index 37aea94..0000000 --- a/libs/agent/modes/mod.rs +++ /dev/null @@ -1 +0,0 @@ -// All reasoning modes removed - using ReAct pattern directly in chat service diff --git a/libs/agent/orao/act.rs b/libs/agent/orao/act.rs new file mode 100644 index 0000000..f7c650c --- /dev/null +++ b/libs/agent/orao/act.rs @@ -0,0 +1,203 @@ +//! Act phase: execute planned actions with safety checks. +//! +//! Actions are executed through a caller-provided executor callback, which +//! typically dispatches to the [`ToolRegistry`] or runs shell commands. +//! All file access must go through function calls (tools), never direct +//! filesystem operations. +//! +//! [`ToolRegistry`]: crate::tool::ToolRegistry + +use std::future::Future; +use std::pin::Pin; +use std::process::Command; +use std::time::Duration; + +use super::types::{ + ActionType, ActionResult, ActionVerdict, OraoConfig, PlannedAction, SafetyLevel, +}; + +/// Callback for executing a planned action. +/// +/// The caller (service layer) provides this to wire up tool execution. +/// Returns `ActionResult` on completion. +pub type ActionExecutor = Box< + dyn Fn( + PlannedAction, + ) -> Pin + Send>> + + Send + + Sync, +>; + +/// Check whether an action is allowed under the given safety configuration. +/// +/// Returns `None` if allowed, or `Some(reason)` if blocked. +pub fn check_safety(action: &PlannedAction, config: &OraoConfig) -> Option { + let safety = SafetyLevel::classify_command(&action.command_or_content); + + if safety > config.max_safety_level { + return Some(format!( + "Action denied: safety level {:?} exceeds max allowed {:?}", + safety, config.max_safety_level + )); + } + + // Check for dangerous command patterns + if let Some(reason) = check_dangerous_command(&action.command_or_content) { + return Some(reason); + } + + None +} + +/// Execute a single planned action via the provided executor. +/// +/// Applies safety checks and timeout, then delegates to the executor. +pub async fn execute_action( + action: PlannedAction, + config: &OraoConfig, + executor: &ActionExecutor, +) -> ActionResult { + // ── Safety gate ──────────────────────────────────────────────────── + if let Some(reason) = check_safety(&action, config) { + return ActionResult { + action, + exit_code: Some(1), + stdout: String::new(), + stderr: reason, + file_changes: Vec::new(), + verdict: ActionVerdict::Failure, + }; + } + + // ── Execute with timeout ────────────────────────────────────────── + let action_clone = action.clone(); + let exec_future = executor(action); + + match tokio::time::timeout( + Duration::from_secs(config.action_timeout_secs), + exec_future, + ) + .await + { + Ok(result) => result, + Err(_elapsed) => ActionResult { + action: action_clone, + exit_code: None, + stdout: String::new(), + stderr: format!( + "Action timed out after {} seconds", + config.action_timeout_secs + ), + file_changes: Vec::new(), + verdict: ActionVerdict::Failure, + }, + } +} + +/// Build a default action executor that runs shell commands directly. +/// +/// This is suitable for `shell_command` and `git_operation` action types. +/// For `tool_invoke`, the caller should provide a custom executor that +/// dispatches to the [`ToolRegistry`]. +/// +/// [`ToolRegistry`]: crate::tool::ToolRegistry +pub fn shell_executor(working_dir: String) -> ActionExecutor { + Box::new(move |action: PlannedAction| { + let dir = working_dir.clone(); + Box::pin(async move { + match action.action_type { + ActionType::ShellCommand + | ActionType::GitOperation + | ActionType::ToolInvoke => run_shell_command(&action, &dir).await, + ActionType::FileWrite | ActionType::FileEdit => { + // File operations should use tool_invoke with a file-writing tool. + // Direct file access is discouraged; return an error directing to tools. + ActionResult { + exit_code: Some(1), + stdout: String::new(), + stderr: "File operations must use tool_invoke with registered file tools. Use shell_command with sed/echo for inline edits.".to_string(), + file_changes: Vec::new(), + verdict: ActionVerdict::Failure, + action, + } + } + ActionType::UserDialog => ActionResult { + exit_code: None, + stdout: "User dialog requested".to_string(), + stderr: String::new(), + file_changes: Vec::new(), + verdict: ActionVerdict::Success, + action, + }, + } + }) + }) +} + +async fn run_shell_command(action: &PlannedAction, working_dir: &str) -> ActionResult { + let cmd = &action.command_or_content; + + let output = Command::new("sh") + .args(["-c", cmd]) + .current_dir(working_dir) + .output(); + + match output { + Ok(out) => { + let exit_code = out.status.code(); + let stdout = String::from_utf8_lossy(&out.stdout).to_string(); + let stderr = String::from_utf8_lossy(&out.stderr).to_string(); + + let verdict = match exit_code { + Some(0) if !stderr_has_errors(&stderr) => ActionVerdict::Success, + Some(0) => ActionVerdict::SuccessWithWarnings, + _ => ActionVerdict::Failure, + }; + + ActionResult { + action: action.clone(), + exit_code, + stdout, + stderr, + file_changes: Vec::new(), + verdict, + } + } + Err(e) => ActionResult { + action: action.clone(), + exit_code: None, + stdout: String::new(), + stderr: format!("Failed to spawn command: {}", e), + file_changes: Vec::new(), + verdict: ActionVerdict::Failure, + }, + } +} + +fn stderr_has_errors(stderr: &str) -> bool { + let lower = stderr.to_lowercase(); + lower.contains("error") || lower.contains("fail") || lower.contains("panic") +} + +/// Check whether a shell command contains dangerous patterns. +/// +/// Returns `Some(reason)` if the command is blocked, `None` if it's safe. +pub fn check_dangerous_command(cmd: &str) -> Option { + let dangerous = [ + ("rm -rf /", "Recursive root deletion"), + ("rm -rf ~", "Recursive home deletion"), + (":(){ :|:& };:", "Fork bomb"), + ("mkfs.", "Filesystem format"), + ("dd if=", "Raw device write"), + ("> /dev/sda", "Raw device write"), + ("chmod 777 /", "World-writable root"), + ]; + + for (pattern, reason) in &dangerous { + if cmd.contains(pattern) { + return Some(format!("Blocked: {} — {}", pattern, reason)); + } + } + + None +} \ No newline at end of file diff --git a/libs/agent/orao/mod.rs b/libs/agent/orao/mod.rs new file mode 100644 index 0000000..c5b7dad --- /dev/null +++ b/libs/agent/orao/mod.rs @@ -0,0 +1,427 @@ +//! ORAO (Observe–Reason–Act–Observe) — a single-agent loop for complex engineering tasks. +//! +//! ORAO extends the ReAct paradigm with: +//! - **Multi-channel perception**: LLM-driven observation via read-only tools +//! - **Structured reasoning**: analysis + step-by-step action plan +//! - **Safety levels**: L0–L4 permission grading for every action +//! - **Deadlock detection**: terminates after 3 rounds with no progress +//! - **Plan mode**: optional user-approval gate before execution +//! - **Round recording**: full audit trail for debugging and resumption +//! +//! # Architecture +//! +//! The [`OraoExecutor`] runs the O→R→A→O loop: +//! 1. **Observe** — LLM explores environment via observation tools, produces snapshot +//! 2. **Reason** — LLM analyzes snapshot, generates structured plan +//! 3. **Act** — Execute each planned action via [`ActionExecutor`] with safety checks +//! 4. **Observe** — Collect results, feed into next round +//! +//! All file access goes through function calls (tools), never direct filesystem operations. +//! +//! [`ActionExecutor`]: act::ActionExecutor + +pub mod act; +pub mod observe; +pub mod reason; +pub mod types; + +use std::time::Instant; + +use crate::client::AiClientConfig; +use crate::error::{AgentError, Result}; + +pub use act::ActionExecutor; +pub use types::{ + ActionResult, ActionType, ActionVerdict, FileChange, FileChangeType, OraoConfig, OraoStep, + PerceptionSnapshot, PlannedAction, ReasoningOutput, RoundRecord, SafetyLevel, +}; + +// ── ORAO Executor ─────────────────────────────────────────────────────────── + +/// Executes the ORAO loop for a single task. +/// +/// All environment interaction goes through: +/// - **Observation tools** (read-only) for the Observe phase +/// - **Action executor** callback for the Act phase +/// +/// No direct filesystem access — everything is mediated through function calls. +pub struct OraoExecutor { + config: AiClientConfig, + model_name: String, + action_executor: ActionExecutor, +} + +impl OraoExecutor { + /// Create a new ORAO executor. + /// + /// `action_executor` is called to execute each planned action. Wire it to + /// your [`ToolRegistry`] for tool-based execution, or use + /// [`act::shell_executor`] for simple shell-command execution. + /// + /// [`ToolRegistry`]: crate::tool::ToolRegistry + pub fn new( + config: AiClientConfig, + model_name: impl Into, + action_executor: ActionExecutor, + ) -> Self { + Self { + config, + model_name: model_name.into(), + action_executor, + } + } + + /// Run the ORAO loop to completion. + /// + /// # Parameters + /// - `task_goal`: Description of what to accomplish. + /// - `orao_config`: ORAO-specific settings (max rounds, safety level, etc.). + /// - `tool_factory`: Called each round to produce read-only observation tools + /// (e.g. `git_diff`, `git_blob`, `repo_search`, `git_grep`). This allows + /// callers to provide fresh tool instances each round. + /// - `on_step`: Called with each [`OraoStep`] event for streaming/persistence. + /// - `on_plan_approval`: Called in plan mode; return `true` to proceed. + pub async fn execute( + &self, + task_goal: &str, + orao_config: &OraoConfig, + tool_factory: TF, + on_step: C, + on_plan_approval: PA, + ) -> Result + where + C: Fn(OraoStep) -> Fut + Send, + Fut: Future + Send, + PA: Fn(ReasoningOutput) -> PAFut + Send, + PAFut: Future + Send, + TF: Fn() -> Vec> + Send + Sync, + { + let mut round = 0usize; + let mut round_records: Vec = Vec::new(); + let mut previous_result: Option = None; + let mut previous_snapshot: Option = None; + let mut no_change_count: usize = 0; + + // Observation turns: limit tool calls during exploration + let observe_max_turns = 10; + + loop { + round += 1; + let round_start = Instant::now(); + let round_input_tokens: u64 = 0; + let round_output_tokens: u64 = 0; + + // ── Phase 1: Observe ─────────────────────────────────────── + let snapshot = observe::observe( + &self.config, + &self.model_name, + task_goal, + previous_result.take(), + tool_factory(), + observe_max_turns, + ) + .await?; + + on_step(OraoStep::Observe { + round, + snapshot: snapshot.clone(), + }) + .await; + + // ── Deadlock detection ───────────────────────────────────── + if let Some(ref prev) = previous_snapshot { + if !observe::has_environment_changed(prev, &snapshot) { + no_change_count += 1; + if no_change_count >= orao_config.deadlock_threshold { + let reason = format!( + "Deadlock detected: no environmental change for {} consecutive rounds", + no_change_count + ); + on_step(OraoStep::Failed { + total_rounds: round, + reason: reason.clone(), + }) + .await; + return Ok(OraoOutcome::Failed { + reason, + rounds: round, + records: round_records, + }); + } + } else { + no_change_count = 0; + } + } + previous_snapshot = Some(snapshot.clone()); + + // ── Phase 2: Reason ──────────────────────────────────────── + let reasoning = reason::reason( + &self.config, + &self.model_name, + orao_config, + task_goal, + &snapshot, + round, + &round_records, + ) + .await?; + + on_step(OraoStep::Reason { + round, + reasoning: reasoning.clone(), + }) + .await; + + // ── Plan mode gate ───────────────────────────────────────── + if orao_config.plan_mode { + on_step(OraoStep::PlanProposed { + round, + reasoning: reasoning.clone(), + }) + .await; + + if !on_plan_approval(reasoning.clone()).await { + return Ok(OraoOutcome::Cancelled { + rounds: round, + records: round_records, + }); + } + } + + // ── Phase 3: Act ─────────────────────────────────────────── + let mut round_result: Option = None; + let mut all_success = true; + + for planned in &reasoning.plan { + let safety = SafetyLevel::classify_command(&planned.command_or_content); + + on_step(OraoStep::Act { + round, + action: planned.clone(), + safety_level: safety, + }) + .await; + + let result = + act::execute_action(planned.clone(), orao_config, &self.action_executor).await; + + on_step(OraoStep::ObserveResult { + round, + result: result.clone(), + }) + .await; + + match &result.verdict { + ActionVerdict::Failure => { + all_success = false; + round_result = Some(result); + break; // Stop executing further steps on failure + } + ActionVerdict::SuccessWithWarnings => { + round_result = Some(result); + } + ActionVerdict::Success => { + round_result = Some(result); + } + } + } + + // ── Phase 4: Record round ────────────────────────────────── + let duration_ms = round_start.elapsed().as_millis() as u64; + let record = RoundRecord { + round, + observe_summary: summarize_snapshot(&snapshot), + reasoning_summary: reasoning.analysis.clone(), + action: reasoning.plan.first().cloned(), + result_summary: round_result + .as_ref() + .map(|r| format!("{:?}: {}", r.verdict, truncate(&r.stdout, 200))), + tokens_input: round_input_tokens, + tokens_output: round_output_tokens, + duration_ms, + }; + round_records.push(record); + + // ── Check termination ────────────────────────────────────── + if all_success && !reasoning.plan.is_empty() { + let summary = format!( + "Task completed in {} round(s). Last action: {}", + round, + round_result + .as_ref() + .map(|r| truncate(&r.stdout, 500)) + .unwrap_or_default() + ); + on_step(OraoStep::Completed { + total_rounds: round, + summary: summary.clone(), + }) + .await; + return Ok(OraoOutcome::Completed { + summary, + rounds: round, + records: round_records, + }); + } + + // Max rounds exceeded + if round >= orao_config.max_rounds { + let reason = format!("Reached max rounds ({})", orao_config.max_rounds); + on_step(OraoStep::Failed { + total_rounds: round, + reason: reason.clone(), + }) + .await; + return Ok(OraoOutcome::Failed { + reason, + rounds: round, + records: round_records, + }); + } + + // Prepare for next round + previous_result = round_result; + } + } +} + +// ── Outcome ───────────────────────────────────────────────────────────────── + +/// Final outcome of an ORAO execution. +#[derive(Debug, Clone)] +pub enum OraoOutcome { + /// Task completed successfully. + Completed { + summary: String, + rounds: usize, + records: Vec, + }, + /// Task failed (max rounds, deadlock, or unrecoverable error). + Failed { + reason: String, + rounds: usize, + records: Vec, + }, + /// User cancelled the task (plan mode rejection or explicit interrupt). + Cancelled { + rounds: usize, + records: Vec, + }, +} + +impl OraoOutcome { + /// Number of rounds executed. + pub fn rounds(&self) -> usize { + match self { + Self::Completed { rounds, .. } + | Self::Failed { rounds, .. } + | Self::Cancelled { rounds, .. } => *rounds, + } + } + + /// Whether the task was successful. + pub fn is_success(&self) -> bool { + matches!(self, Self::Completed { .. }) + } + + /// Round records for audit/debugging. + pub fn records(&self) -> &[RoundRecord] { + match self { + Self::Completed { records, .. } + | Self::Failed { records, .. } + | Self::Cancelled { records, .. } => records, + } + } +} + +// ── Helpers ───────────────────────────────────────────────────────────────── + +fn summarize_snapshot(snapshot: &PerceptionSnapshot) -> String { + let mut parts: Vec = Vec::new(); + + if let Some(ref gs) = snapshot.git_status { + let first_line = gs.lines().next().unwrap_or(""); + parts.push(format!("git: {}", truncate(first_line, 80))); + } + + if !snapshot.files.is_empty() { + parts.push(format!("{} files", snapshot.files.len())); + } + + if !snapshot.errors.is_empty() { + parts.push(format!("{} errors", snapshot.errors.len())); + } + + if parts.is_empty() { + "no changes".to_string() + } else { + parts.join(", ") + } +} + +fn truncate(s: &str, max_len: usize) -> String { + if s.len() <= max_len { + s.to_string() + } else { + format!("{}...", &s[..max_len]) + } +} + +// ── Convenience builder ───────────────────────────────────────────────────── + +/// Builder for [`OraoExecutor`] with chainable configuration. +pub struct OraoExecutorBuilder { + config: Option, + model_name: Option, + action_executor: Option, +} + +impl OraoExecutorBuilder { + pub fn new() -> Self { + Self { + config: None, + model_name: None, + action_executor: None, + } + } + + pub fn ai_config(mut self, config: AiClientConfig) -> Self { + self.config = Some(config); + self + } + + pub fn model(mut self, name: impl Into) -> Self { + self.model_name = Some(name.into()); + self + } + + pub fn action_executor(mut self, executor: ActionExecutor) -> Self { + self.action_executor = Some(executor); + self + } + + pub fn build(self) -> Result { + let config = self.config.ok_or_else(|| AgentError::InvalidInput { + field: "config".to_string(), + reason: "AI client config is required".to_string(), + })?; + let model_name = self.model_name.ok_or_else(|| AgentError::InvalidInput { + field: "model_name".to_string(), + reason: "Model name is required".to_string(), + })?; + let action_executor = self + .action_executor + .ok_or_else(|| AgentError::InvalidInput { + field: "action_executor".to_string(), + reason: "Action executor is required".to_string(), + })?; + + Ok(OraoExecutor::new(config, model_name, action_executor)) + } +} + +impl Default for OraoExecutorBuilder { + fn default() -> Self { + Self::new() + } +} diff --git a/libs/agent/orao/observe.rs b/libs/agent/orao/observe.rs new file mode 100644 index 0000000..62aad04 --- /dev/null +++ b/libs/agent/orao/observe.rs @@ -0,0 +1,280 @@ +//! Observe phase: LLM-driven multi-channel environment perception. +//! +//! The Observe phase gives the LLM a set of read-only observation tools and +//! instructs it to explore the environment. All file/git/system access goes +//! through function calls (tools), never direct filesystem operations. +//! +//! After exploration, the LLM produces a structured [`PerceptionSnapshot`] +//! summarizing the current state of the project. + +use rig::agent::AgentBuilder; +use rig::client::CompletionClient; +use rig::completion::Prompt; + +use crate::client::AiClientConfig; +use crate::error::AgentError; + +use super::types::{ActionResult, PerceptionSnapshot}; + +/// Prompt for the ORAO Observe phase. +const OBSERVE_SYSTEM_PROMPT: &str = r#"You are an expert software engineering agent using the ORAO (Observe-Reason-Act-Observe) framework. + +## Your Role: OBSERVE Phase + +You are currently in the OBSERVE phase. Your task is to explore the project environment +and gather all relevant information using the available tools. + +## What to Observe + +Use the tools provided to you to check: + +1. **Git status**: What branch are we on? What files have changed? Any uncommitted work? +2. **Project structure**: What directories and key files exist? +3. **Code content**: Read relevant source files to understand the codebase state. +4. **Errors/warnings**: Check build output, test results, linter output for issues. +5. **Configuration**: Check project config files (Cargo.toml, package.json, etc.) if relevant. + +## Rules + +- Use tools to explore — do NOT guess or assume file contents. +- Focus on information relevant to the task at hand. +- Be thorough but efficient: 3-8 tool calls is typical. +- After gathering information, summarize your findings clearly. + +## Output Format + +After you have finished observing, provide a summary with these sections: + +### Git Status +[Current branch, changed files, commit status] + +### Project Structure +[Key directories and files relevant to the task] + +### Key Files +[Important files you read, with brief notes on their content] + +### Errors / Issues +[Any errors, warnings, or problems detected] + +### Previous Action Result +[If a previous action was executed, describe its outcome]"#; + +/// Run the Observe phase: let the LLM explore the environment via tools. +/// +/// Returns a structured [`PerceptionSnapshot`] built from the LLM's observations. +/// All environment access goes through the provided `tools` — no direct +/// filesystem operations. +/// +/// Takes ownership of `tools` (caller must clone if they need to reuse them). +pub async fn observe( + config: &AiClientConfig, + model_name: &str, + task_goal: &str, + previous_result: Option, + tools: Vec>, + max_turns: usize, +) -> Result { + let user_prompt = build_observe_prompt(task_goal, previous_result.as_ref()); + + let client = config.build_rig_client(); + let model = client.completion_model(model_name); + + let agent = AgentBuilder::new(model) + .preamble(OBSERVE_SYSTEM_PROMPT) + .tools(tools) + .default_max_turns(max_turns) + .build(); + + let response = agent + .prompt(&user_prompt) + .max_turns(max_turns) + .extended_details() + .await + .map_err(|e: rig::completion::PromptError| AgentError::OpenAi(e.to_string()))?; + + // Build snapshot from the LLM's final summary + let summary = response.output; + let snapshot = parse_observation_summary(&summary, previous_result); + + Ok(snapshot) +} + +/// Build the user prompt for the Observe phase. +fn build_observe_prompt(task_goal: &str, previous_result: Option<&ActionResult>) -> String { + let mut prompt = format!( + "## Task Goal\n\n{}\n\n## Instructions\n\n\ + Explore the project environment using the available tools. \ + Gather all information relevant to the task above. \ + After you have gathered sufficient information, provide a structured summary.", + task_goal + ); + + if let Some(prev) = previous_result { + prompt.push_str(&format!( + "\n\n## Previous Action Result\n\n\ + - Action: {}\n\ + - Verdict: {:?}\n\ + - Exit code: {:?}\n\ + - stdout: {}\n\ + - stderr: {}", + prev.action.description, + prev.verdict, + prev.exit_code, + truncate_str(&prev.stdout, 2000), + truncate_str(&prev.stderr, 2000), + )); + } + + prompt +} + +/// Parse the LLM's observation summary into a structured snapshot. +fn parse_observation_summary( + summary: &str, + previous_result: Option, +) -> PerceptionSnapshot { + let mut snapshot = PerceptionSnapshot::default(); + + // Extract sections from the markdown summary + let mut current_section = ""; + let mut section_content: Vec<&str> = Vec::new(); + + for line in summary.lines() { + if line.starts_with("### ") { + // Save previous section + store_section(&mut snapshot, current_section, §ion_content); + current_section = line.trim_start_matches("### ").trim(); + section_content.clear(); + } else { + section_content.push(line); + } + } + // Save last section + store_section(&mut snapshot, current_section, §ion_content); + + snapshot.previous_action_result = previous_result; + + // If no structured data was parsed, store the raw summary + if snapshot.git_status.is_none() + && snapshot.project_structure.is_none() + && snapshot.files.is_empty() + && snapshot.errors.is_empty() + { + snapshot.notes.insert( + "raw_observation".to_string(), + summary.to_string(), + ); + } + + snapshot +} + +fn store_section(snapshot: &mut PerceptionSnapshot, section: &str, content: &[&str]) { + let text = content.join("\n").trim().to_string(); + if text.is_empty() { + return; + } + + match section.to_lowercase().as_str() { + s if s.contains("git") => { + snapshot.git_status = Some(text); + } + s if s.contains("project") && s.contains("structure") => { + snapshot.project_structure = Some(text); + } + s if s.contains("file") => { + // Parse file references from the text + for line in content { + let line = line.trim(); + if let Some(path) = extract_file_path(line) { + snapshot.files.push(super::types::PerceivedFile { + path, + size_bytes: 0, + content_preview: None, + }); + } + } + } + s if s.contains("error") || s.contains("issue") || s.contains("warning") => { + for line in content { + let line = line.trim(); + if !line.is_empty() && !line.starts_with('#') { + snapshot.errors.push(line.to_string()); + } + } + } + _ => { + // Store unknown sections as notes + snapshot + .notes + .insert(section.to_string(), text); + } + } +} + +/// Extract a file path from a markdown list item or code reference. +fn extract_file_path(line: &str) -> Option { + // Match patterns like: - `src/main.rs` or - src/main.rs or `src/main.rs` + let line = line.trim(); + + // Backtick-wrapped path + if let Some(start) = line.find('`') { + let rest = &line[start + 1..]; + if let Some(end) = rest.find('`') { + let path = rest[..end].to_string(); + if path.contains('.') || path.contains('/') || path.contains('\\') { + return Some(path); + } + } + } + + // Bare path pattern (word chars, slashes, dots) + if line.starts_with('-') || line.starts_with('*') { + let rest = line.trim_start_matches(&['-', '*', ' ']); + if rest.contains('/') || (rest.contains('.') && !rest.starts_with("http")) { + return Some(rest.to_string()); + } + } + + None +} + +fn truncate_str(s: &str, max_len: usize) -> String { + if s.len() <= max_len { + s.to_string() + } else { + format!("{}...", &s[..max_len]) + } +} + +/// Determine whether the environment has changed since the last snapshot. +/// +/// Used for deadlock detection: if 3 consecutive rounds show no change, +/// the loop is terminated. +pub fn has_environment_changed( + previous: &PerceptionSnapshot, + current: &PerceptionSnapshot, +) -> bool { + if previous.git_status != current.git_status { + return true; + } + + let prev_files: Vec<&str> = previous.files.iter().map(|f| f.path.as_str()).collect(); + let curr_files: Vec<&str> = current.files.iter().map(|f| f.path.as_str()).collect(); + if prev_files != curr_files { + return true; + } + + if previous.errors != current.errors { + return true; + } + + let prev_has_result = previous.previous_action_result.is_some(); + let curr_has_result = current.previous_action_result.is_some(); + if prev_has_result != curr_has_result { + return true; + } + + false +} \ No newline at end of file diff --git a/libs/agent/orao/reason.rs b/libs/agent/orao/reason.rs new file mode 100644 index 0000000..9ea09c9 --- /dev/null +++ b/libs/agent/orao/reason.rs @@ -0,0 +1,218 @@ +//! Reason phase: structured reasoning and plan generation via LLM. +//! +//! Takes the current [`PerceptionSnapshot`] and the task goal, sends them to +//! the configured LLM, and produces a structured [`ReasoningOutput`] with an +//! executable plan. + +use crate::client::AiClientConfig; +use crate::error::AgentError; +use rig::agent::AgentBuilder; +use rig::client::CompletionClient; +use rig::completion::Prompt; + +use super::types::{OraoConfig, PerceptionSnapshot, ReasoningOutput}; + +/// Prompt template for the ORAO reasoning phase. +const REASON_SYSTEM_PROMPT: &str = r#"You are an expert software engineering agent using the ORAO (Observe-Reason-Act-Observe) framework. + +## Your Role: REASON Phase + +You are currently in the REASON phase. You will receive: +1. A **task goal** — what needs to be accomplished +2. An **observation snapshot** — the current state of the environment + +Your job is to produce a structured analysis and a step-by-step action plan. + +## Output Format + +You MUST respond with a valid JSON object matching this schema: + +```json +{ + "analysis": "", + "plan": [ + { + "step_id": 1, + "description": "", + "action_type": "shell_command | file_write | file_edit | git_operation | tool_invoke | user_dialog", + "command_or_content": "", + "expected_result": "", + "fallback_on_failure": "" + } + ] +} +``` + +## Rules + +1. **Be specific**: commands must be exact and complete (include necessary flags, paths, etc.) +2. **One action per step**: each step should do one thing +3. **Validate assumptions**: don't assume dependencies are installed or files exist +4. **Prefer small, safe steps**: each change should be minimal and reversible +5. **Include verification**: build/test after changes to verify correctness +6. **Plan size**: 1-10 steps is typical. Don't plan more than 15 steps without asking the user. +7. **Fallbacks**: for risky steps, specify what to try if the step fails + +## Safety + +- Prefer read-only verification before making changes +- Use version control (git) for reversible operations +- Flag any step that requires network access or system-level privileges"#; + +/// Run the reasoning phase: send observation + task to the LLM and get a plan. +pub async fn reason( + config: &AiClientConfig, + model_name: &str, + _orao_config: &OraoConfig, + task_goal: &str, + snapshot: &PerceptionSnapshot, + round: usize, + history_rounds: &[super::types::RoundRecord], +) -> Result { + let user_prompt = build_reason_prompt(task_goal, snapshot, round, history_rounds); + + let client = config.build_rig_client(); + let model = client.completion_model(model_name); + + let agent = AgentBuilder::new(model) + .preamble(REASON_SYSTEM_PROMPT) + .build(); + + let response = agent + .prompt(&user_prompt) + .extended_details() + .await + .map_err(|e: rig::completion::PromptError| AgentError::OpenAi(e.to_string()))?; + + let output = response.output.trim().to_string(); + + // Extract JSON from the response (it may be wrapped in ```json fences) + let json_str = extract_json(&output); + + serde_json::from_str::(json_str).map_err(|e| { + AgentError::Internal(format!( + "Failed to parse reasoning output as JSON: {}. Raw output: {}", + e, output + )) + }) +} + +/// Build the user prompt for the reason phase. +fn build_reason_prompt( + task_goal: &str, + snapshot: &PerceptionSnapshot, + round: usize, + history_rounds: &[super::types::RoundRecord], +) -> String { + let mut prompt = format!( + "## Task Goal\n\n{}\n\n## Round Number\n\n{}\n\n", + task_goal, round + ); + + // ── Observation snapshot ──────────────────────────────────────────── + prompt.push_str("## Current Observation\n\n"); + + if let Some(ref git_status) = snapshot.git_status { + prompt.push_str(&format!("### Git Status\n```\n{}\n```\n\n", git_status)); + } + + if let Some(ref project_structure) = snapshot.project_structure { + prompt.push_str(&format!( + "### Project Structure\n```\n{}\n```\n\n", + project_structure + )); + } + + if !snapshot.files.is_empty() { + prompt.push_str("### Relevant Files\n\n"); + for file in &snapshot.files { + prompt.push_str(&format!( + "- `{}` ({} bytes)\n", + file.path, file.size_bytes + )); + if let Some(ref preview) = file.content_preview { + let truncated = if preview.len() > 2000 { + format!("{}... [truncated]", &preview[..2000]) + } else { + preview.clone() + }; + prompt.push_str(&format!("```\n{}\n```\n\n", truncated)); + } + } + } + + if !snapshot.errors.is_empty() { + prompt.push_str("### Errors / Warnings\n\n"); + for err in &snapshot.errors { + prompt.push_str(&format!("- {}\n", err)); + } + prompt.push('\n'); + } + + if let Some(ref prev_result) = snapshot.previous_action_result { + prompt.push_str(&format!( + "### Previous Action Result\n- Verdict: {:?}\n- Exit code: {:?}\n- stdout: {}\n- stderr: {}\n\n", + prev_result.verdict, + prev_result.exit_code, + truncate_str(&prev_result.stdout, 1000), + truncate_str(&prev_result.stderr, 1000), + )); + } + + // ── History summary ───────────────────────────────────────────────── + if !history_rounds.is_empty() { + prompt.push_str("## Previous Rounds\n\n"); + for record in history_rounds.iter().rev().take(3) { + prompt.push_str(&format!( + "- Round {}: {} | {} tokens | {}ms\n", + record.round, + truncate_str(&record.reasoning_summary, 120), + record.tokens_input + record.tokens_output, + record.duration_ms, + )); + } + prompt.push('\n'); + } + + prompt.push_str( + "## Instructions\n\nProduce a structured analysis and action plan in JSON format as specified.", + ); + + prompt +} + +/// Extract JSON content from a string that may be wrapped in markdown fences. +fn extract_json(s: &str) -> &str { + let s = s.trim(); + + if s.starts_with("```json") { + let inner = &s[7..]; + if let Some(end) = inner.rfind("```") { + return inner[..end].trim(); + } + return inner.trim(); + } + + if s.starts_with("```") { + let inner = &s[3..]; + if let Some(end) = inner.rfind("```") { + return inner[..end].trim(); + } + return inner.trim(); + } + + // Heuristic: find first { and last } + if let (Some(start), Some(end)) = (s.find('{'), s.rfind('}')) { + return &s[start..=end]; + } + + s +} + +fn truncate_str(s: &str, max_len: usize) -> String { + if s.len() <= max_len { + s.to_string() + } else { + format!("{}...", &s[..max_len]) + } +} \ No newline at end of file diff --git a/libs/agent/orao/types.rs b/libs/agent/orao/types.rs new file mode 100644 index 0000000..3ab3433 --- /dev/null +++ b/libs/agent/orao/types.rs @@ -0,0 +1,337 @@ +//! ORAO core types. +//! +//! ORAO (Observe–Reason–Act–Observe) is a single-agent loop paradigm for complex +//! engineering tasks. It extends ReAct with structured multi-channel perception, +//! safety permission levels, plan mode, and deadlock detection. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +// ── Safety levels ─────────────────────────────────────────────────────────── + +/// Permission level for actions executed by ORAO. +/// +/// L0 (read-only) → auto-allow. +/// L1 (local write) → confirm on first use. +/// L2 (build) → confirm on first use. +/// L3 (network) → explicit user approval required. +/// L4 (system) → denied by default. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum SafetyLevel { + /// L0 — Read-only: `ls`, `cat`, `grep`, `git status` + ReadOnly = 0, + /// L1 — Local write: edit source files, create new files + LocalWrite = 1, + /// L2 — Build/test: `cargo build`, `npm test` + Build = 2, + /// L3 — Network: `pip install`, `curl` + Network = 3, + /// L4 — System: `sudo`, global config changes + System = 4, +} + +impl SafetyLevel { + /// Classify a shell command into a safety level. + pub fn classify_command(cmd: &str) -> Self { + let cmd_trimmed = cmd.trim(); + // L0: read-only commands + let l0_prefixes = [ + "ls", "cat", "head", "tail", "less", "file", "stat", "wc", + "grep", "rg", "find", "which", "type", "echo", "printf", + "pwd", "env", "printenv", "date", "uname", "hostname", + "git status", "git log", "git diff", "git show", "git branch", + "git tag", "git remote", "git config --get", "git blame", + "cargo metadata", "cargo tree", "cargo read-manifest", + "tree", "du", "df", + ]; + for p in &l0_prefixes { + if cmd_trimmed.starts_with(p) { + return Self::ReadOnly; + } + } + + // L4: system-level commands (denied by default) + let l4_patterns = [ + "sudo", "su ", "chown", "chmod 777", "mkfs", "mkswap", + "mount", "umount", "fdisk", "parted", "dd if=", + "systemctl", "service ", "chkconfig", "update-rc.d", + "passwd", "useradd", "userdel", "usermod", "groupadd", + "iptables", "ufw", "firewall-cmd", + "shutdown", "reboot", "halt", "poweroff", + "rm -rf /", "rm -rf ~", "rm -rf .", ":(){ :|:& };:", + ]; + for p in &l4_patterns { + if cmd_trimmed.starts_with(p) || cmd_trimmed.contains(p) { + return Self::System; + } + } + + // L3: network commands + let l3_prefixes = [ + "curl", "wget", "nc ", "ncat", "telnet", "ssh ", "scp", + "rsync", "pip install", "pip3 install", "npm install", + "npm i ", "yarn add", "cargo install", "gem install", + "go get", "go install", "apt-get", "apt ", "yum ", "dnf ", + "brew ", "pacman ", "zypper", "docker pull", "docker run", + "git clone", "git fetch", "git push", "git pull", + "gh ", "glab ", "aws ", "gcloud ", "az ", + ]; + for p in &l3_prefixes { + if cmd_trimmed.starts_with(p) { + return Self::Network; + } + } + + // L2: build/test commands + let l2_prefixes = [ + "cargo build", "cargo test", "cargo check", "cargo clippy", + "cargo fmt", "cargo run", "cargo bench", "cargo doc", + "npm test", "npm run", "npx ", "yarn test", "yarn run", + "pnpm test", "pnpm run", "bun test", "bun run", + "make", "cmake", "ninja", "meson", "bazel", + "pytest", "python -m pytest", "python3 -m pytest", + "go test", "go build", "go vet", "go fmt", + "rustc", "rustfmt", "clippy", "miri", + "eslint", "prettier", "tsc", "jest", "vitest", + "docker build", "docker compose", "docker-compose", + "kubectl apply", "kubectl delete", "helm ", + ]; + for p in &l2_prefixes { + if cmd_trimmed.starts_with(p) { + return Self::Build; + } + } + + // Default to L1 (local write) for anything else + Self::LocalWrite + } +} + +// ── Action types ──────────────────────────────────────────────────────────── + +/// The type of action to execute. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ActionType { + /// Execute a shell command in a controlled terminal. + ShellCommand, + /// Create or overwrite a file. + FileWrite, + /// Make a localized edit to an existing file. + FileEdit, + /// Version-control operation (commit, add, etc.). + GitOperation, + /// Invoke an external tool or API. + ToolInvoke, + /// Ask the user for input or a decision. + UserDialog, +} + +// ── Action plan ───────────────────────────────────────────────────────────── + +/// A single planned action from the reasoning phase. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PlannedAction { + /// Step number within the plan. + pub step_id: usize, + /// Human-readable description. + pub description: String, + /// The type of action. + pub action_type: ActionType, + /// The command or content to execute/write. + pub command_or_content: String, + /// What success should look like. + pub expected_result: String, + /// What to try if this step fails. + pub fallback_on_failure: Option, +} + +/// Structured reasoning output from the Reason phase. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReasoningOutput { + /// Analysis of the current state. + pub analysis: String, + /// The plan to execute. + pub plan: Vec, +} + +// ── Perception snapshot ───────────────────────────────────────────────────── + +/// Structured observation collected during the Observe phase. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct PerceptionSnapshot { + /// Project directory tree summary. + pub project_structure: Option, + /// Relevant file paths and contents. + pub files: Vec, + /// Current errors/warnings in the environment. + pub errors: Vec, + /// Git status summary. + pub git_status: Option, + /// Result of the previous action (if any). + pub previous_action_result: Option, + /// Free-form context notes. + pub notes: HashMap, +} + +/// A file observed during perception. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PerceivedFile { + pub path: String, + pub size_bytes: u64, + pub content_preview: Option, +} + +// ── Action result ─────────────────────────────────────────────────────────── + +/// The result of executing an action. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ActionResult { + /// The action that was executed. + pub action: PlannedAction, + /// Exit code (0 = success for shell commands). + pub exit_code: Option, + /// Captured stdout. + pub stdout: String, + /// Captured stderr. + pub stderr: String, + /// Summary of file changes (if applicable). + pub file_changes: Vec, + /// Preliminary assessment. + pub verdict: ActionVerdict, +} + +/// A file change detected after an action. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FileChange { + pub path: String, + pub change_type: FileChangeType, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum FileChangeType { + Created, + Modified, + Deleted, +} + +/// Preliminary verdict on an action's outcome. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ActionVerdict { + Success, + SuccessWithWarnings, + Failure, +} + +// ── ORAO step events ──────────────────────────────────────────────────────── + +/// A single event emitted during an ORAO round, analogous to `ReactStep`. +/// +/// These are yielded via the streaming callback so the caller can persist +/// them or forward them to a frontend. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum OraoStep { + /// Initial observation: environment snapshot before any action. + Observe { + round: usize, + snapshot: PerceptionSnapshot, + }, + /// The reasoning/analysis output, including the plan. + Reason { + round: usize, + reasoning: ReasoningOutput, + }, + /// An action is about to be executed. + Act { + round: usize, + action: PlannedAction, + safety_level: SafetyLevel, + }, + /// The result observed after executing an action. + ObserveResult { + round: usize, + result: ActionResult, + }, + /// Plan mode: a plan has been generated and is awaiting user approval. + PlanProposed { + round: usize, + reasoning: ReasoningOutput, + }, + /// The task completed successfully. + Completed { + total_rounds: usize, + summary: String, + }, + /// The task failed (max rounds, deadlock, or explicit failure). + Failed { + total_rounds: usize, + reason: String, + }, +} + +// ── Round record (audit) ──────────────────────────────────────────────────── + +/// A persistent record of one ORAO round, used for audit and resumption. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RoundRecord { + /// Round number (1-indexed). + pub round: usize, + /// Summary of the Observe phase. + pub observe_summary: String, + /// Summary of the Reasoning phase. + pub reasoning_summary: String, + /// The action that was executed. + pub action: Option, + /// Result observed after the action. + pub result_summary: Option, + /// Tokens consumed this round. + pub tokens_input: u64, + pub tokens_output: u64, + /// Wall-clock duration of this round in milliseconds. + pub duration_ms: u64, +} + +// ── ORAO configuration ────────────────────────────────────────────────────── + +/// Configuration for an ORAO execution. +#[derive(Clone)] +pub struct OraoConfig { + /// Maximum number of ORAO rounds before giving up. + pub max_rounds: usize, + /// Maximum allowed safety level. Actions above this level are denied. + pub max_safety_level: SafetyLevel, + /// Whether to run in plan mode (generate plan first, wait for approval). + pub plan_mode: bool, + /// Whether to enable extended thinking for the reasoning phase. + pub extended_thinking: bool, + /// Per-action timeout in seconds. + pub action_timeout_secs: u64, + /// Number of consecutive no-change rounds before deadlock detection triggers. + pub deadlock_threshold: usize, +} + +impl Default for OraoConfig { + fn default() -> Self { + Self { + max_rounds: 50, + max_safety_level: SafetyLevel::Network, + plan_mode: false, + extended_thinking: false, + action_timeout_secs: 120, + deadlock_threshold: 3, + } + } +} + +impl std::fmt::Debug for OraoConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OraoConfig") + .field("max_rounds", &self.max_rounds) + .field("max_safety_level", &self.max_safety_level) + .field("plan_mode", &self.plan_mode) + .field("extended_thinking", &self.extended_thinking) + .field("action_timeout_secs", &self.action_timeout_secs) + .field("deadlock_threshold", &self.deadlock_threshold) + .finish() + } +} \ No newline at end of file diff --git a/libs/agent/task/events.rs b/libs/agent/task/events.rs new file mode 100644 index 0000000..d67e9ed --- /dev/null +++ b/libs/agent/task/events.rs @@ -0,0 +1,171 @@ +use models::agent_task::TaskStatus; +use serde::Serialize; +use std::sync::Arc; + +/// Event payload published to WebSocket clients via Redis Pub/Sub. +#[derive(Debug, Clone, Serialize)] +pub struct TaskEvent { + pub task_id: i64, + pub project_id: uuid::Uuid, + pub parent_id: Option, + pub event: String, + pub message: Option, + pub output: Option, + pub error: Option, + pub status: String, +} + +impl TaskEvent { + pub fn started(task_id: i64, project_id: uuid::Uuid, parent_id: Option) -> Self { + Self { + task_id, + project_id, + parent_id, + event: "started".to_string(), + message: None, + output: None, + error: None, + status: TaskStatus::Running.to_string(), + } + } + + pub fn progress( + task_id: i64, + project_id: uuid::Uuid, + parent_id: Option, + msg: String, + ) -> Self { + Self { + task_id, + project_id, + parent_id, + event: "progress".to_string(), + message: Some(msg), + output: None, + error: None, + status: TaskStatus::Running.to_string(), + } + } + + pub fn completed( + task_id: i64, + project_id: uuid::Uuid, + parent_id: Option, + output: String, + ) -> Self { + Self { + task_id, + project_id, + parent_id, + event: "done".to_string(), + message: None, + output: Some(output), + error: None, + status: TaskStatus::Done.to_string(), + } + } + + pub fn failed( + task_id: i64, + project_id: uuid::Uuid, + parent_id: Option, + error: String, + ) -> Self { + Self { + task_id, + project_id, + parent_id, + event: "failed".to_string(), + message: None, + output: None, + error: Some(error), + status: TaskStatus::Failed.to_string(), + } + } + + pub fn cancelled(task_id: i64, project_id: uuid::Uuid, parent_id: Option) -> Self { + Self { + task_id, + project_id, + parent_id, + event: "cancelled".to_string(), + message: None, + output: None, + error: None, + status: TaskStatus::Cancelled.to_string(), + } + } +} + +/// Helper trait for publishing task lifecycle events via Redis Pub/Sub. +/// +/// Callers inject a suitable `publish_fn` at construction time via +/// `TaskEvents::new(...)`. If no publisher is supplied events are silently +/// dropped (graceful degradation on startup). +pub trait TaskEventPublisher: Send + Sync { + fn publish(&self, project_id: uuid::Uuid, event: TaskEvent); +} + +/// No-op publisher used when no Redis Pub/Sub connection is available. +#[derive(Clone, Default)] +pub struct NoOpPublisher; + +impl TaskEventPublisher for NoOpPublisher { + fn publish(&self, _: uuid::Uuid, _: TaskEvent) {} +} + +#[derive(Clone)] +pub struct TaskEvents { + publisher: Arc, +} + +impl TaskEvents { + pub fn new(publisher: impl TaskEventPublisher + 'static) -> Self { + Self { + publisher: Arc::new(publisher), + } + } + + pub fn noop() -> Self { + Self::new(NoOpPublisher) + } + + fn emit(&self, task: &models::agent_task::Model, event: TaskEvent) { + self.publisher.publish(task.project_uuid, event); + } + + pub fn emit_started(&self, task: &models::agent_task::Model) { + self.emit( + task, + TaskEvent::started(task.id, task.project_uuid, task.parent_id), + ); + } + + pub fn emit_progress(&self, task: &models::agent_task::Model, msg: String) { + self.emit( + task, + TaskEvent::progress(task.id, task.project_uuid, task.parent_id, msg), + ); + } + + pub fn emit_completed(&self, task: &models::agent_task::Model, output: String) { + self.emit( + task, + TaskEvent::completed(task.id, task.project_uuid, task.parent_id, output), + ); + } + + pub fn emit_failed(&self, task: &models::agent_task::Model, error: String) { + self.emit( + task, + TaskEvent::failed(task.id, task.project_uuid, task.parent_id, error), + ); + } + + pub fn emit_cancelled(&self, task: &models::agent_task::Model) { + self.emit( + task, + TaskEvent::cancelled(task.id, task.project_uuid, task.parent_id), + ); + } +} \ No newline at end of file diff --git a/libs/agent/task/lifecycle.rs b/libs/agent/task/lifecycle.rs new file mode 100644 index 0000000..f88671a --- /dev/null +++ b/libs/agent/task/lifecycle.rs @@ -0,0 +1,192 @@ +use models::agent_task::{ActiveModel, Column as C, Entity, Model, TaskStatus}; +use sea_orm::{ActiveModelTrait, ColumnTrait, DbErr, EntityTrait, QueryFilter}; + +pub struct TaskLifecycle; + +impl super::TaskService { + /// Mark a task as running and record the start time. + pub async fn start(&self, task_id: i64) -> Result { + let model = Entity::find_by_id(task_id).one(self.db()).await?; + let model = + model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?; + + let mut active: ActiveModel = model.into(); + active.status = sea_orm::Set(TaskStatus::Running); + active.started_at = sea_orm::Set(Some(chrono::Utc::now().into())); + active.updated_at = sea_orm::Set(chrono::Utc::now().into()); + let updated = active.update(self.db()).await?; + self.events().emit_started(&updated); + Ok(updated) + } + + /// Update progress text (e.g., "step 2/5: analyzing PR"). + pub async fn update_progress( + &self, + task_id: i64, + progress: impl Into, + ) -> Result<(), DbErr> { + let model = Entity::find_by_id(task_id).one(self.db()).await?; + let model = + model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?; + + let progress_str = progress.into(); + let mut active: ActiveModel = model.into(); + active.progress = sea_orm::Set(Some(progress_str.clone())); + active.updated_at = sea_orm::Set(chrono::Utc::now().into()); + let updated = active.update(self.db()).await?; + self.events().emit_progress(&updated, progress_str); + Ok(()) + } + + /// Mark a task as completed with the output text. + pub async fn complete(&self, task_id: i64, output: impl Into) -> Result { + let model = Entity::find_by_id(task_id).one(self.db()).await?; + let model = + model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?; + + let mut active: ActiveModel = model.into(); + active.status = sea_orm::Set(TaskStatus::Done); + let out = output.into(); + active.output = sea_orm::Set(Some(out.clone())); + active.done_at = sea_orm::Set(Some(chrono::Utc::now().into())); + active.updated_at = sea_orm::Set(chrono::Utc::now().into()); + let updated = active.update(self.db()).await?; + self.events().emit_completed(&updated, out); + Ok(updated) + } + + /// Mark a task as failed with an error message. + pub async fn fail(&self, task_id: i64, error: impl Into) -> Result { + let model = Entity::find_by_id(task_id).one(self.db()).await?; + let model = + model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?; + + let mut active: ActiveModel = model.into(); + active.status = sea_orm::Set(TaskStatus::Failed); + let err = error.into(); + active.error = sea_orm::Set(Some(err.clone())); + active.done_at = sea_orm::Set(Some(chrono::Utc::now().into())); + active.updated_at = sea_orm::Set(chrono::Utc::now().into()); + let updated = active.update(self.db()).await?; + self.events().emit_failed(&updated, err); + Ok(updated) + } + + /// Propagate child task status up the tree. + /// + /// Only allows cancelling tasks that are not yet in a terminal state + /// (Pending / Running / Paused). + /// + /// Cancelled children are marked done so that `are_children_done()` returns + /// true for the parent after cancellation. + pub async fn cancel(&self, task_id: i64) -> Result { + // Collect all task IDs (parent + descendants) using an explicit stack. + let mut stack = vec![task_id]; + let mut idx = 0; + while idx < stack.len() { + let current = stack[idx]; + let children = Entity::find() + .filter(C::ParentId.eq(current)) + .all(self.db()) + .await?; + for child in children { + stack.push(child.id); + } + idx += 1; + } + + // Mark every collected task as cancelled (terminal state). + for id in &stack { + let model = Entity::find_by_id(*id).one(self.db()).await?; + if let Some(m) = model { + if !m.is_done() { + let mut active: ActiveModel = m.into(); + active.status = sea_orm::Set(TaskStatus::Cancelled); + active.done_at = sea_orm::Set(Some(chrono::Utc::now().into())); + active.updated_at = sea_orm::Set(chrono::Utc::now().into()); + active.update(self.db()).await?; + } + } + } + + let final_model = Entity::find_by_id(task_id) + .one(self.db()) + .await? + .ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?; + self.events().emit_cancelled(&final_model); + Ok(final_model) + } + + /// Pause a running or pending task. + /// + /// Pausing a task that is not Pending/Running is a no-op that returns + /// the current model (same behaviour as `start` on an already-running task). + pub async fn pause(&self, task_id: i64) -> Result { + let model = Entity::find_by_id(task_id).one(self.db()).await?; + let model = + model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?; + + if !model.is_running() { + // Already in a terminal or paused state — return unchanged. + return Ok(model); + } + + let mut active: ActiveModel = model.into(); + active.status = sea_orm::Set(TaskStatus::Paused); + active.updated_at = sea_orm::Set(chrono::Utc::now().into()); + active.update(self.db()).await + } + + /// Resume a paused task back to Running. + /// + /// Returns an error if the task is not currently Paused. + pub async fn resume(&self, task_id: i64) -> Result { + let model = Entity::find_by_id(task_id).one(self.db()).await?; + let model = + model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?; + + if model.status != TaskStatus::Paused { + return Err(DbErr::Custom(format!( + "cannot resume task {}: expected status Paused, got {}", + task_id, model.status + ))); + } + + let mut active: ActiveModel = model.into(); + active.status = sea_orm::Set(TaskStatus::Running); + active.updated_at = sea_orm::Set(chrono::Utc::now().into()); + active.update(self.db()).await + } + + /// Retry a failed or cancelled task by resetting it to Pending. + /// + /// Clears `output`, `error`, and `done_at`; increments `retry_count`. + /// Only tasks in Failed or Cancelled state can be retried. + pub async fn retry(&self, task_id: i64) -> Result { + let model = Entity::find_by_id(task_id).one(self.db()).await?; + let model = + model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?; + + match model.status { + TaskStatus::Failed | TaskStatus::Cancelled | TaskStatus::Done => {} + _ => { + return Err(DbErr::Custom(format!( + "cannot retry task {}: only Failed/Cancelled/Done tasks can be retried (got {})", + task_id, model.status + ))); + } + } + + let retry_count = model.retry_count.map(|c| c + 1).unwrap_or(1); + + let mut active: ActiveModel = model.into(); + active.status = sea_orm::Set(TaskStatus::Pending); + active.output = sea_orm::Set(None); + active.error = sea_orm::Set(None); + active.done_at = sea_orm::Set(None); + active.started_at = sea_orm::Set(None); + active.retry_count = sea_orm::Set(Some(retry_count)); + active.updated_at = sea_orm::Set(chrono::Utc::now().into()); + active.update(self.db()).await + } +} \ No newline at end of file diff --git a/libs/agent/task/mod.rs b/libs/agent/task/mod.rs index 37aa529..81469c5 100644 --- a/libs/agent/task/mod.rs +++ b/libs/agent/task/mod.rs @@ -1,4 +1,4 @@ -//! Agent task service — unified task/sub-agent execution framework. +//! Agent task service — managing task/sub-agent execution lifecycle. //! //! A task (`agent_task` record) can be: //! - A **root task**: initiated by a user or system event. @@ -17,6 +17,61 @@ //! This module is intentionally kept simple and synchronous with the DB. //! Long-running execution is delegated to the caller (tokio::spawn). -pub mod service; +pub mod events; +pub mod lifecycle; +pub mod store; +pub mod tree; -pub use service::TaskService; +use db::database::AppDatabase; + +pub use events::{NoOpPublisher, TaskEvent, TaskEventPublisher, TaskEvents}; +pub use lifecycle::TaskLifecycle; + +/// Service for managing agent tasks (root tasks and sub-tasks). +#[derive(Clone)] +pub struct TaskService { + db: AppDatabase, + events: TaskEvents, +} + +impl TaskService { + pub fn new(db: AppDatabase) -> Self { + Self { + db, + events: TaskEvents::noop(), + } + } + + pub fn with_events(db: AppDatabase, events: TaskEvents) -> Self { + Self { db, events } + } + + pub(crate) fn db(&self) -> &AppDatabase { + &self.db + } + + pub(crate) fn events(&self) -> &TaskEvents { + &self.events + } +} + +/// Builder for TaskService so that the events publisher can be set independently +/// of the database connection. +#[derive(Clone, Default)] +pub struct TaskServiceBuilder { + events: Option, +} + +impl TaskServiceBuilder { + pub fn with_events(mut self, events: TaskEvents) -> Self { + self.events = Some(events); + self + } + + pub async fn build(self, db: AppDatabase) -> TaskService { + TaskService { + db, + events: self.events.unwrap_or_else(TaskEvents::noop), + } + } +} \ No newline at end of file diff --git a/libs/agent/task/service.rs b/libs/agent/task/service.rs deleted file mode 100644 index 071c5fc..0000000 --- a/libs/agent/task/service.rs +++ /dev/null @@ -1,600 +0,0 @@ -//! Task service for creating, tracking, and executing agent tasks. -//! -//! All methods are async and interact with the database directly. -//! Execution of the task logic (running the ReAct loop, etc.) is delegated -//! to the caller — this service only manages task lifecycle and state. - -use db::database::AppDatabase; -use models::agent_task::{ActiveModel, AgentType, Column as C, Entity, Model, TaskStatus}; -use models::IssueId; -use sea_orm::{ - entity::EntityTrait, query::{QueryFilter, QueryOrder, QuerySelect}, ActiveModelTrait, - ColumnTrait, - DbErr, -}; -use serde::Serialize; -use std::sync::Arc; - -/// Event payload published to WebSocket clients via Redis Pub/Sub. -#[derive(Debug, Clone, Serialize)] -pub struct TaskEvent { - pub task_id: i64, - pub project_id: uuid::Uuid, - pub parent_id: Option, - pub event: String, - pub message: Option, - pub output: Option, - pub error: Option, - pub status: String, -} - -impl TaskEvent { - pub fn started(task_id: i64, project_id: uuid::Uuid, parent_id: Option) -> Self { - Self { - task_id, - project_id, - parent_id, - event: "started".to_string(), - message: None, - output: None, - error: None, - status: TaskStatus::Running.to_string(), - } - } - - pub fn progress( - task_id: i64, - project_id: uuid::Uuid, - parent_id: Option, - msg: String, - ) -> Self { - Self { - task_id, - project_id, - parent_id, - event: "progress".to_string(), - message: Some(msg), - output: None, - error: None, - status: TaskStatus::Running.to_string(), - } - } - - pub fn completed( - task_id: i64, - project_id: uuid::Uuid, - parent_id: Option, - output: String, - ) -> Self { - Self { - task_id, - project_id, - parent_id, - event: "done".to_string(), - message: None, - output: Some(output), - error: None, - status: TaskStatus::Done.to_string(), - } - } - - pub fn failed( - task_id: i64, - project_id: uuid::Uuid, - parent_id: Option, - error: String, - ) -> Self { - Self { - task_id, - project_id, - parent_id, - event: "failed".to_string(), - message: None, - output: None, - error: Some(error), - status: TaskStatus::Failed.to_string(), - } - } - - pub fn cancelled(task_id: i64, project_id: uuid::Uuid, parent_id: Option) -> Self { - Self { - task_id, - project_id, - parent_id, - event: "cancelled".to_string(), - message: None, - output: None, - error: None, - status: TaskStatus::Cancelled.to_string(), - } - } -} - -/// Helper trait for publishing task lifecycle events via Redis Pub/Sub. -/// -/// Callers inject a suitable `publish_fn` at construction time via -/// `TaskEvents::new(...)`. If no publisher is supplied events are silently -/// dropped (graceful degradation on startup). -pub trait TaskEventPublisher: Send + Sync { - fn publish(&self, project_id: uuid::Uuid, event: TaskEvent); -} - -/// No-op publisher used when no Redis Pub/Sub connection is available. -#[derive(Clone, Default)] -pub struct NoOpPublisher; - -impl TaskEventPublisher for NoOpPublisher { - fn publish(&self, _: uuid::Uuid, _: TaskEvent) {} -} - -#[derive(Clone)] -pub struct TaskEvents { - publisher: Arc, -} - -impl TaskEvents { - pub fn new(publisher: impl TaskEventPublisher + 'static) -> Self { - Self { - publisher: Arc::new(publisher), - } - } - - pub fn noop() -> Self { - Self::new(NoOpPublisher) - } - - fn emit(&self, task: &Model, event: TaskEvent) { - self.publisher.publish(task.project_uuid, event); - } - - pub fn emit_started(&self, task: &Model) { - self.emit( - task, - TaskEvent::started(task.id, task.project_uuid, task.parent_id), - ); - } - - pub fn emit_progress(&self, task: &Model, msg: String) { - self.emit( - task, - TaskEvent::progress(task.id, task.project_uuid, task.parent_id, msg), - ); - } - - pub fn emit_completed(&self, task: &Model, output: String) { - self.emit( - task, - TaskEvent::completed(task.id, task.project_uuid, task.parent_id, output), - ); - } - - pub fn emit_failed(&self, task: &Model, error: String) { - self.emit( - task, - TaskEvent::failed(task.id, task.project_uuid, task.parent_id, error), - ); - } - - pub fn emit_cancelled(&self, task: &Model) { - self.emit( - task, - TaskEvent::cancelled(task.id, task.project_uuid, task.parent_id), - ); - } -} - -/// Builder for TaskService so that the events publisher can be set independently -/// of the database connection. -#[derive(Clone, Default)] -pub struct TaskServiceBuilder { - events: Option, -} - -impl TaskServiceBuilder { - pub fn with_events(mut self, events: TaskEvents) -> Self { - self.events = Some(events); - self - } - - pub async fn build(self, db: AppDatabase) -> TaskService { - TaskService { - db, - events: self.events.unwrap_or_else(TaskEvents::noop), - } - } -} - -/// Service for managing agent tasks (root tasks and sub-tasks). -#[derive(Clone)] -pub struct TaskService { - db: AppDatabase, - events: TaskEvents, -} - -impl TaskService { - pub fn new(db: AppDatabase) -> Self { - Self { - db, - events: TaskEvents::noop(), - } - } - - pub fn with_events(db: AppDatabase, events: TaskEvents) -> Self { - Self { db, events } - } - - /// Create a new task (root or sub-task) with status = pending. - pub async fn create( - &self, - project_uuid: impl Into, - input: impl Into, - agent_type: AgentType, - ) -> Result { - self.create_with_parent(project_uuid, None, input, agent_type, None, None) - .await - } - - /// Create a new task bound to an issue. - pub async fn create_for_issue( - &self, - project_uuid: impl Into, - issue_id: IssueId, - input: impl Into, - agent_type: AgentType, - ) -> Result { - self.create_with_parent(project_uuid, None, input, agent_type, None, Some(issue_id)) - .await - } - - /// Create a new sub-task with a parent reference. - pub async fn create_subtask( - &self, - project_uuid: impl Into, - parent_id: i64, - input: impl Into, - agent_type: AgentType, - title: Option, - ) -> Result { - self.create_with_parent( - project_uuid, - Some(parent_id), - input, - agent_type, - title, - None, - ) - .await - } - - async fn create_with_parent( - &self, - project_uuid: impl Into, - parent_id: Option, - input: impl Into, - agent_type: AgentType, - title: Option, - issue_id: Option, - ) -> Result { - let model = ActiveModel { - project_uuid: sea_orm::Set(project_uuid.into()), - parent_id: sea_orm::Set(parent_id), - issue_id: sea_orm::Set(issue_id), - agent_type: sea_orm::Set(agent_type), - status: sea_orm::Set(TaskStatus::Pending), - title: sea_orm::Set(title), - input: sea_orm::Set(input.into()), - ..Default::default() - }; - model.insert(&self.db).await - } - - /// Mark a task as running and record the start time. - pub async fn start(&self, task_id: i64) -> Result { - let model = Entity::find_by_id(task_id).one(&self.db).await?; - let model = - model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?; - - let mut active: ActiveModel = model.into(); - active.status = sea_orm::Set(TaskStatus::Running); - active.started_at = sea_orm::Set(Some(chrono::Utc::now().into())); - active.updated_at = sea_orm::Set(chrono::Utc::now().into()); - let updated = active.update(&self.db).await?; - self.events.emit_started(&updated); - Ok(updated) - } - - /// Update progress text (e.g., "step 2/5: analyzing PR"). - pub async fn update_progress( - &self, - task_id: i64, - progress: impl Into, - ) -> Result<(), DbErr> { - let model = Entity::find_by_id(task_id).one(&self.db).await?; - let model = - model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?; - - let progress_str = progress.into(); - let mut active: ActiveModel = model.into(); - active.progress = sea_orm::Set(Some(progress_str.clone())); - active.updated_at = sea_orm::Set(chrono::Utc::now().into()); - let updated = active.update(&self.db).await?; - self.events.emit_progress(&updated, progress_str); - Ok(()) - } - - /// Mark a task as completed with the output text. - pub async fn complete(&self, task_id: i64, output: impl Into) -> Result { - let model = Entity::find_by_id(task_id).one(&self.db).await?; - let model = - model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?; - - let mut active: ActiveModel = model.into(); - active.status = sea_orm::Set(TaskStatus::Done); - let out = output.into(); - active.output = sea_orm::Set(Some(out.clone())); - active.done_at = sea_orm::Set(Some(chrono::Utc::now().into())); - active.updated_at = sea_orm::Set(chrono::Utc::now().into()); - let updated = active.update(&self.db).await?; - self.events.emit_completed(&updated, out); - Ok(updated) - } - - /// Mark a task as failed with an error message. - pub async fn fail(&self, task_id: i64, error: impl Into) -> Result { - let model = Entity::find_by_id(task_id).one(&self.db).await?; - let model = - model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?; - - let mut active: ActiveModel = model.into(); - active.status = sea_orm::Set(TaskStatus::Failed); - let err = error.into(); - active.error = sea_orm::Set(Some(err.clone())); - active.done_at = sea_orm::Set(Some(chrono::Utc::now().into())); - active.updated_at = sea_orm::Set(chrono::Utc::now().into()); - let updated = active.update(&self.db).await?; - self.events.emit_failed(&updated, err); - Ok(updated) - } - - /// Propagate child task status up the tree. - /// - /// Only allows cancelling tasks that are not yet in a terminal state - /// (Pending / Running / Paused). - /// - /// Cancelled children are marked done so that `are_children_done()` returns - /// true for the parent after cancellation. - pub async fn cancel(&self, task_id: i64) -> Result { - // Collect all task IDs (parent + descendants) using an explicit stack. - let mut stack = vec![task_id]; - let mut idx = 0; - while idx < stack.len() { - let current = stack[idx]; - let children = Entity::find() - .filter(C::ParentId.eq(current)) - .all(&self.db) - .await?; - for child in children { - stack.push(child.id); - } - idx += 1; - } - - // Mark every collected task as cancelled (terminal state). - for id in &stack { - let model = Entity::find_by_id(*id).one(&self.db).await?; - if let Some(m) = model { - if !m.is_done() { - let mut active: ActiveModel = m.into(); - active.status = sea_orm::Set(TaskStatus::Cancelled); - active.done_at = sea_orm::Set(Some(chrono::Utc::now().into())); - active.updated_at = sea_orm::Set(chrono::Utc::now().into()); - active.update(&self.db).await?; - } - } - } - - let final_model = Entity::find_by_id(task_id) - .one(&self.db) - .await? - .ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?; - self.events.emit_cancelled(&final_model); - Ok(final_model) - } - - /// Pause a running or pending task. - /// - /// Pausing a task that is not Pending/Running is a no-op that returns - /// the current model (same behaviour as `start` on an already-running task). - pub async fn pause(&self, task_id: i64) -> Result { - let model = Entity::find_by_id(task_id).one(&self.db).await?; - let model = - model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?; - - if !model.is_running() { - // Already in a terminal or paused state — return unchanged. - return Ok(model); - } - - let mut active: ActiveModel = model.into(); - active.status = sea_orm::Set(TaskStatus::Paused); - active.updated_at = sea_orm::Set(chrono::Utc::now().into()); - active.update(&self.db).await - } - - /// Resume a paused task back to Running. - /// - /// Returns an error if the task is not currently Paused. - pub async fn resume(&self, task_id: i64) -> Result { - let model = Entity::find_by_id(task_id).one(&self.db).await?; - let model = - model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?; - - if model.status != TaskStatus::Paused { - return Err(DbErr::Custom(format!( - "cannot resume task {}: expected status Paused, got {}", - task_id, model.status - ))); - } - - let mut active: ActiveModel = model.into(); - active.status = sea_orm::Set(TaskStatus::Running); - active.updated_at = sea_orm::Set(chrono::Utc::now().into()); - active.update(&self.db).await - } - - /// Retry a failed or cancelled task by resetting it to Pending. - /// - /// Clears `output`, `error`, and `done_at`; increments `retry_count`. - /// Only tasks in Failed or Cancelled state can be retried. - pub async fn retry(&self, task_id: i64) -> Result { - let model = Entity::find_by_id(task_id).one(&self.db).await?; - let model = - model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?; - - match model.status { - TaskStatus::Failed | TaskStatus::Cancelled | TaskStatus::Done => {} - _ => { - return Err(DbErr::Custom(format!( - "cannot retry task {}: only Failed/Cancelled/Done tasks can be retried (got {})", - task_id, model.status - ))); - } - } - - let retry_count = model.retry_count.map(|c| c + 1).unwrap_or(1); - - let mut active: ActiveModel = model.into(); - active.status = sea_orm::Set(TaskStatus::Pending); - active.output = sea_orm::Set(None); - active.error = sea_orm::Set(None); - active.done_at = sea_orm::Set(None); - active.started_at = sea_orm::Set(None); - active.retry_count = sea_orm::Set(Some(retry_count)); - active.updated_at = sea_orm::Set(chrono::Utc::now().into()); - active.update(&self.db).await - } - - /// Propagate child task status up the tree. - /// - /// When a child task reaches a terminal state, checks whether all its - /// siblings are also terminal. If so, marks the parent appropriately: - /// - Done if any child succeeded - /// - Failed if all children failed or were cancelled - pub async fn propagate_to_parent(&self, task_id: i64) -> Result, DbErr> { - let model = self - .get(task_id) - .await? - .ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?; - - let Some(parent_id) = model.parent_id else { - return Ok(None); - }; - - let siblings = self.children(parent_id).await?; - if siblings.iter().all(|s| s.is_done()) { - let parent = self.get(parent_id).await?.ok_or_else(|| { - DbErr::RecordNotFound(format!("parent task {} not found", parent_id)) - })?; - if parent.is_running() { - let mut active: ActiveModel = parent.into(); - let has_success = siblings.iter().any(|s| s.status == TaskStatus::Done); - if has_success { - active.status = sea_orm::Set(TaskStatus::Done); - active.error = sea_orm::Set(None); - } else { - active.status = sea_orm::Set(TaskStatus::Failed); - active.error = - sea_orm::Set(Some("All sub-tasks failed or were cancelled".to_string())); - } - active.done_at = sea_orm::Set(Some(chrono::Utc::now().into())); - active.updated_at = sea_orm::Set(chrono::Utc::now().into()); - let updated = active.update(&self.db).await?; - return Ok(Some(updated)); - } - } - Ok(None) - } - - /// Get a task by ID. - pub async fn get(&self, task_id: i64) -> Result, DbErr> { - Entity::find_by_id(task_id).one(&self.db).await - } - - /// List all sub-tasks for a parent task. - pub async fn children(&self, parent_id: i64) -> Result, DbErr> { - Entity::find() - .filter(C::ParentId.eq(parent_id)) - .order_by_asc(C::CreatedAt) - .all(&self.db) - .await - } - - /// List all active (non-terminal) tasks for a project. - pub async fn active_tasks( - &self, - project_uuid: impl Into, - ) -> Result, DbErr> { - let uuid: uuid::Uuid = project_uuid.into(); - Entity::find() - .filter(C::ProjectUuid.eq(uuid)) - .filter(C::Status.is_in([TaskStatus::Pending, TaskStatus::Running, TaskStatus::Paused])) - .order_by_desc(C::CreatedAt) - .all(&self.db) - .await - } - - /// List all tasks (root only) for a project. - pub async fn list( - &self, - project_uuid: impl Into, - limit: u64, - ) -> Result, DbErr> { - let uuid: uuid::Uuid = project_uuid.into(); - Entity::find() - .filter(C::ProjectUuid.eq(uuid)) - .filter(C::ParentId.is_null()) - .order_by_desc(C::CreatedAt) - .limit(limit) - .all(&self.db) - .await - } - - /// Delete a task and all its sub-tasks recursively. - /// Only allows deletion of root tasks. - pub async fn delete(&self, task_id: i64) -> Result<(), DbErr> { - self.delete_recursive(task_id).await - } - - async fn delete_recursive(&self, task_id: i64) -> Result<(), DbErr> { - // Collect all task IDs to delete using an explicit stack (avoiding async recursion). - let mut stack = vec![task_id]; - let mut idx = 0; - while idx < stack.len() { - let current = stack[idx]; - let children = Entity::find() - .filter(C::ParentId.eq(current)) - .all(&self.db) - .await?; - for child in children { - stack.push(child.id); - } - idx += 1; - } - - for task_id in stack { - let model = Entity::find_by_id(task_id).one(&self.db).await?; - if let Some(m) = model { - let active: ActiveModel = m.into(); - active.delete(&self.db).await?; - } - } - Ok(()) - } - - /// Check if all sub-tasks of a given parent are in a terminal state. - /// Returns true if there are no children (empty tree counts as done). - pub async fn are_children_done(&self, parent_id: i64) -> Result { - let children = self.children(parent_id).await?; - Ok(children.is_empty() || children.iter().all(|c| c.is_done())) - } -} diff --git a/libs/agent/task/store.rs b/libs/agent/task/store.rs new file mode 100644 index 0000000..0c090aa --- /dev/null +++ b/libs/agent/task/store.rs @@ -0,0 +1,109 @@ +use models::agent_task::{ActiveModel, AgentType, Entity, Model}; +use models::IssueId; +use sea_orm::{ActiveModelTrait, ColumnTrait, DbErr, EntityTrait, QueryFilter, QueryOrder, QuerySelect}; + +impl super::TaskService { + /// Get a task by ID. + pub async fn get(&self, task_id: i64) -> Result, DbErr> { + Entity::find_by_id(task_id).one(self.db()).await + } + + /// List all tasks (root only) for a project. + pub async fn list( + &self, + project_uuid: impl Into, + limit: u64, + ) -> Result, DbErr> { + let uuid: uuid::Uuid = project_uuid.into(); + Entity::find() + .filter(models::agent_task::Column::ProjectUuid.eq(uuid)) + .filter(models::agent_task::Column::ParentId.is_null()) + .order_by_desc(models::agent_task::Column::CreatedAt) + .limit(limit) + .all(self.db()) + .await + } + + /// List all active (non-terminal) tasks for a project. + pub async fn active_tasks( + &self, + project_uuid: impl Into, + ) -> Result, DbErr> { + let uuid: uuid::Uuid = project_uuid.into(); + Entity::find() + .filter(models::agent_task::Column::ProjectUuid.eq(uuid)) + .filter(models::agent_task::Column::Status.is_in([ + models::agent_task::TaskStatus::Pending, + models::agent_task::TaskStatus::Running, + models::agent_task::TaskStatus::Paused, + ])) + .order_by_desc(models::agent_task::Column::CreatedAt) + .all(self.db()) + .await + } + + /// Create a new task (root or sub-task) with status = pending. + pub async fn create( + &self, + project_uuid: impl Into, + input: impl Into, + agent_type: AgentType, + ) -> Result { + self.create_with_parent(project_uuid, None, input, agent_type, None, None) + .await + } + + /// Create a new task bound to an issue. + pub async fn create_for_issue( + &self, + project_uuid: impl Into, + issue_id: IssueId, + input: impl Into, + agent_type: AgentType, + ) -> Result { + self.create_with_parent(project_uuid, None, input, agent_type, None, Some(issue_id)) + .await + } + + /// Create a new sub-task with a parent reference. + pub async fn create_subtask( + &self, + project_uuid: impl Into, + parent_id: i64, + input: impl Into, + agent_type: AgentType, + title: Option, + ) -> Result { + self.create_with_parent( + project_uuid, + Some(parent_id), + input, + agent_type, + title, + None, + ) + .await + } + + async fn create_with_parent( + &self, + project_uuid: impl Into, + parent_id: Option, + input: impl Into, + agent_type: AgentType, + title: Option, + issue_id: Option, + ) -> Result { + let model = ActiveModel { + project_uuid: sea_orm::Set(project_uuid.into()), + parent_id: sea_orm::Set(parent_id), + issue_id: sea_orm::Set(issue_id), + agent_type: sea_orm::Set(agent_type), + status: sea_orm::Set(models::agent_task::TaskStatus::Pending), + title: sea_orm::Set(title), + input: sea_orm::Set(input.into()), + ..Default::default() + }; + model.insert(self.db()).await + } +} \ No newline at end of file diff --git a/libs/agent/task/tree.rs b/libs/agent/task/tree.rs new file mode 100644 index 0000000..c316626 --- /dev/null +++ b/libs/agent/task/tree.rs @@ -0,0 +1,89 @@ +use models::agent_task::{ActiveModel, Column as C, Entity, Model, TaskStatus}; +use sea_orm::{ActiveModelTrait, ColumnTrait, DbErr, EntityTrait, QueryFilter, QueryOrder}; + +impl super::TaskService { + /// Propagate child task status up the tree. + /// + /// When a child task reaches a terminal state, checks whether all its + /// siblings are also terminal. If so, marks the parent appropriately: + /// - Done if any child succeeded + /// - Failed if all children failed or were cancelled + pub async fn propagate_to_parent(&self, task_id: i64) -> Result, DbErr> { + let model = self + .get(task_id) + .await? + .ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?; + + let Some(parent_id) = model.parent_id else { + return Ok(None); + }; + + let siblings = self.children(parent_id).await?; + if siblings.iter().all(|s| s.is_done()) { + let parent = self.get(parent_id).await?.ok_or_else(|| { + DbErr::RecordNotFound(format!("parent task {} not found", parent_id)) + })?; + if parent.is_running() { + let mut active: ActiveModel = parent.into(); + let has_success = siblings.iter().any(|s| s.status == TaskStatus::Done); + if has_success { + active.status = sea_orm::Set(TaskStatus::Done); + active.error = sea_orm::Set(None); + } else { + active.status = sea_orm::Set(TaskStatus::Failed); + active.error = + sea_orm::Set(Some("All sub-tasks failed or were cancelled".to_string())); + } + active.done_at = sea_orm::Set(Some(chrono::Utc::now().into())); + active.updated_at = sea_orm::Set(chrono::Utc::now().into()); + let updated = active.update(self.db()).await?; + return Ok(Some(updated)); + } + } + Ok(None) + } + + /// List all sub-tasks for a parent task. + pub async fn children(&self, parent_id: i64) -> Result, DbErr> { + Entity::find() + .filter(C::ParentId.eq(parent_id)) + .order_by_asc(C::CreatedAt) + .all(self.db()) + .await + } + + /// Check if all sub-tasks of a given parent are in a terminal state. + /// Returns true if there are no children (empty tree counts as done). + pub async fn are_children_done(&self, parent_id: i64) -> Result { + let children = self.children(parent_id).await?; + Ok(children.is_empty() || children.iter().all(|c| c.is_done())) + } + + /// Delete a task and all its sub-tasks recursively. + /// Only allows deletion of root tasks. + pub async fn delete(&self, task_id: i64) -> Result<(), DbErr> { + // Collect all task IDs to delete using an explicit stack (avoiding async recursion). + let mut stack = vec![task_id]; + let mut idx = 0; + while idx < stack.len() { + let current = stack[idx]; + let children = Entity::find() + .filter(C::ParentId.eq(current)) + .all(self.db()) + .await?; + for child in children { + stack.push(child.id); + } + idx += 1; + } + + for task_id in stack { + let model = Entity::find_by_id(task_id).one(self.db()).await?; + if let Some(m) = model { + let active: ActiveModel = m.into(); + active.delete(self.db()).await?; + } + } + Ok(()) + } +} \ No newline at end of file diff --git a/libs/agent/tokent.rs b/libs/agent/tokent.rs index d763a36..4460003 100644 --- a/libs/agent/tokent.rs +++ b/libs/agent/tokent.rs @@ -13,7 +13,7 @@ use std::collections::HashMap; use std::sync::OnceLock; use std::sync::RwLock; -use crate::error::{AgentError, Result}; +use crate::error::Result; static TOKENIZER_CACHE: OnceLock>> = OnceLock::new(); @@ -173,12 +173,11 @@ fn get_tokenizer(model: &str) -> Result { } // Try model-specific tokenizer first - let bpe = if let Ok(bpe) = tiktoken_rs::get_bpe_from_model(model) { + let bpe: &'static _ = if let Ok(bpe) = tiktoken_rs::bpe_for_model(model) { bpe } else { // Fallback: use cl100k_base for unknown models - tiktoken_rs::cl100k_base() - .map_err(|e| AgentError::Internal(format!("Failed to init tokenizer: {}", e)))? + tiktoken_rs::cl100k_base_singleton() }; { @@ -186,7 +185,7 @@ fn get_tokenizer(model: &str) -> Result { cache.insert(model.to_string(), bpe.clone()); } - Ok(bpe) + Ok(bpe.clone()) } /// Estimate tokens for a simple prefix/suffix pattern (e.g., "assistant\n" + text). diff --git a/libs/git/HOOK_QUEUE_NATS_MIGRATION.md b/libs/git/HOOK_QUEUE_NATS_MIGRATION.md deleted file mode 100644 index cd77787..0000000 --- a/libs/git/HOOK_QUEUE_NATS_MIGRATION.md +++ /dev/null @@ -1,256 +0,0 @@ -# Hook Queue NATS JetStream Migration Guide - -## Overview - -The git hook queue now supports both Redis Lists and NATS JetStream as backend message queues. This allows gradual migration from Redis to NATS without downtime. - -## Architecture - -### Producer (`ReceiveSyncService`) - -The producer tries NATS first (if configured), then falls back to Redis: - -```rust -pub struct ReceiveSyncService { - pool: deadpool_redis::cluster::Pool, - redis_prefix: String, - nats_publish: Option) -> Pin> + Send>> + Send + Sync>>, -} -``` - -### Consumer (`RedisConsumer`) - -The consumer uses NATS if configured, otherwise falls back to Redis: - -```rust -pub struct RedisConsumer { - pool: deadpool_redis::cluster::Pool, - prefix: String, - block_timeout_secs: u64, - nats_consume: Option, -} -``` - -## Integration with AppTransport - -### Producer Integration - -```rust -use git::ssh::ReceiveSyncService; -use transport::AppTransport; - -let transport = Arc::new(AppTransport::new(/* ... */)); - -// Create NATS publish function -let nats_publish = { - let transport = transport.clone(); - Arc::new(move |subject: String, payload: Vec| { - let transport = transport.clone(); - Box::pin(async move { - let ack = transport.publish(&subject, payload).await?; - Ok(ack.sequence) - }) as Pin> + Send>> - }) -}; - -// Create service with NATS support -let sync_service = ReceiveSyncService::with_nats(redis_pool, nats_publish); - -// Or use Redis-only mode -let sync_service = ReceiveSyncService::new(redis_pool); -``` - -### Consumer Integration - -```rust -use git::hook::pool::redis::{RedisConsumer, NatsHookConsumeFn}; - -// Create NATS consume function -let nats_consume: NatsHookConsumeFn = { - let transport = transport.clone(); - Arc::new(move |subject: String, batch_size: usize| { - let transport = transport.clone(); - Box::pin(async move { - let mut results = Vec::new(); - - // Pull messages from JetStream consumer - for _ in 0..batch_size { - match transport.pull_one(&subject).await { - Ok(Some(msg)) => { - let data = msg.payload.to_vec(); - let msg_clone = msg.clone(); - let ack_fn = Box::new(move || { - let msg = msg_clone.clone(); - Box::pin(async move { - msg.ack().await?; - Ok(()) - }) as Pin> + Send>> - }); - results.push((data, ack_fn)); - } - Ok(None) => break, - Err(e) => return Err(e), - } - } - - Ok(results) - }) as Pin, Box Pin> + Send>> + Send>)>>> + Send>> - }) -}; - -// Create consumer with NATS support -let consumer = RedisConsumer::with_nats( - redis_pool, - "{hook}".to_string(), - 5, // block_timeout_secs - nats_consume, -); - -// Or use Redis-only mode -let consumer = RedisConsumer::new(redis_pool, "{hook}".to_string(), 5); -``` - -## Queue Subjects - -The hook queue uses the following NATS subjects: - -- `queue.hook.sync` - Repository sync tasks (git push/pull operations) - -Additional task types can be added by extending the subject pattern: -- `queue.hook.{task_type}` - Generic pattern for any hook task type - -## Migration Strategy - -### Phase 1: Dual Write (Current) -- Producer writes to both NATS and Redis -- Consumer reads from Redis only -- Zero risk, full rollback capability - -### Phase 2: Dual Read -- Producer writes to both NATS and Redis -- Consumer reads from NATS, falls back to Redis on error -- Validates NATS consumer stability - -### Phase 3: NATS Primary -- Producer writes to NATS only (Redis disabled) -- Consumer reads from NATS only -- Redis queue deprecated - -### Phase 4: Redis Removal -- Remove Redis Lists code -- Remove `pool` parameter -- Simplify to NATS-only implementation - -## NATS JetStream Setup - -### Stream Configuration - -```bash -nats stream add HOOK_QUEUE \ - --subjects "queue.hook.>" \ - --retention limits \ - --max-msgs=-1 \ - --max-age=7d \ - --storage file \ - --replicas 3 -``` - -### Consumer Configuration - -```bash -nats consumer add HOOK_QUEUE hook-sync-worker \ - --filter "queue.hook.sync" \ - --ack explicit \ - --pull \ - --deliver all \ - --max-deliver 3 \ - --max-pending 100 -``` - -## Differences from Email Queue - -### Redis Backend -- **Email Queue**: Uses Redis Streams (XADD/XREADGROUP) -- **Hook Queue**: Uses Redis Lists (LPUSH/BLMOVE) - -### Atomicity -- **Email Queue**: Consumer group provides at-least-once delivery -- **Hook Queue**: BLMOVE provides atomic move-to-work-queue pattern - -### Work Queue Pattern -- **Email Queue**: No work queue, relies on consumer group -- **Hook Queue**: Uses separate work queue (`{hook}:sync:work`) for in-flight tracking - -### Acknowledgment -- **Email Queue**: XACK removes from pending entries list -- **Hook Queue**: LREM removes from work queue - -### Retry Logic -- **Email Queue**: Automatic via consumer group pending entries -- **Hook Queue**: Manual via Lua script (LREM + LPUSH) - -## Monitoring - -### Logs - -- NATS publish: `"hook task queued to NATS"` -- Redis publish: `"hook task queued to Redis"` -- NATS consume: `"task dequeued from NATS"` -- Redis consume: `"task dequeued"` - -### Metrics - -Add these metrics to track hook queue performance: - -```rust -counter!("hook_task_queued_total", "backend" => "nats").increment(1); -counter!("hook_task_queued_total", "backend" => "redis").increment(1); -counter!("hook_task_consumed_total", "backend" => "nats").increment(1); -counter!("hook_task_consumed_total", "backend" => "redis").increment(1); -``` - -## Rollback - -To disable NATS and return to Redis-only: - -```rust -// Producer -let sync_service = ReceiveSyncService::new(redis_pool); - -// Consumer -let consumer = RedisConsumer::new(redis_pool, "{hook}".to_string(), 5); -``` - -No code changes required, just use the `new()` constructor instead of `with_nats()`. - -## Benefits - -1. **Zero Downtime**: Gradual migration with fallback -2. **No Circular Dependency**: Uses function pointers instead of crate dependencies -3. **Backward Compatible**: Existing code works without changes -4. **Type Safe**: Compile-time guarantees for integration -5. **Observable**: Consistent logging for both backends - -## Known Limitations - -### NATS Acknowledgment Timing - -The current implementation acks NATS messages immediately after deserialization, not after successful processing. This is different from the Redis pattern where: - -- Redis: Task moves to work queue → processes → acks (removes from work queue) -- NATS: Task received → acks immediately → processes - -**Future Enhancement**: Store ack functions in a map keyed by task ID, then call them after successful processing. This requires refactoring the worker loop to track pending acks. - -### Work Queue Pattern - -NATS JetStream doesn't have a direct equivalent to Redis's work queue pattern. The current implementation relies on JetStream's built-in redelivery mechanism instead of a separate work queue. - -## Next Steps - -1. Add NATS integration to `apps/git-hook/src/main.rs` -2. Add configuration flags for queue backend selection -3. Test dual-write mode in staging -4. Monitor NATS consumer stability -5. Implement proper ack-after-processing pattern -6. Add metrics for queue depth and processing latency diff --git a/scripts/README.md b/scripts/README.md deleted file mode 100644 index 8f15af9..0000000 --- a/scripts/README.md +++ /dev/null @@ -1,107 +0,0 @@ -# 构建脚本 - -## 一键构建脚本 - -### build.js - 构建镜像 - -```bash -# 构建所有镜像 -node scripts/build.js - -# 构建指定服务 -node scripts/build.js app gitserver - -# 指定 tag -TAG=v1.0.0 node scripts/build.js - -# 指定架构 -TARGET=aarch64-unknown-linux-gnu node scripts/build.js -``` - -**环境变量:** - -| 变量 | 默认值 | 说明 | -|------------|------------------------------|-------------| -| `REGISTRY` | `harbor.gitdata.me/gta_team` | 镜像仓库 | -| `TAG` | `latest` | 镜像标签 | -| `TARGET` | `x86_64-unknown-linux-gnu` | Rust 交叉编译目标 | - ---- - -### push.js - 推送镜像 - -```bash -# 推送所有镜像 -HARBOR_USERNAME=user HARBOR_PASSWORD=pass node scripts/push.js - -# 推送指定服务 -HARBOR_USERNAME=user HARBOR_PASSWORD=pass TAG=sha-abc123 node scripts/push.js app -``` - -**环境变量:** - -| 变量 | 默认值 | 说明 | -|-------------------|------------------------------|--------------| -| `REGISTRY` | `harbor.gitdata.me/gta_team` | 镜像仓库 | -| `TAG` | `latest` 或 Git SHA | 镜像标签 | -| `HARBOR_USERNAME` | - | **必填** 仓库用户名 | -| `HARBOR_PASSWORD` | - | **必填** 仓库密码 | - ---- - -### deploy.js - 部署到 Kubernetes - -```bash -# 部署最新镜像 -node scripts/deploy.js - -# 干跑模式(不实际部署) -node scripts/deploy.js --dry-run - -# 部署并运行数据库迁移 -node scripts/deploy.js --migrate - -# 指定 tag -TAG=sha-abc123 node scripts/deploy.js - -# 指定命名空间 -NAMESPACE=staging node scripts/deploy.js -``` - -**环境变量:** - -| 变量 | 默认值 | 说明 | -|--------------|------------------------------|-----------------| -| `REGISTRY` | `harbor.gitdata.me/gta_team` | 镜像仓库 | -| `TAG` | `latest` 或 Git SHA | 镜像标签 | -| `NAMESPACE` | `gitdata` | K8s 命名空间 | -| `RELEASE` | `gitdata` | Helm Release 名称 | -| `KUBECONFIG` | `~/.kube/config` | Kubeconfig 路径 | - ---- - -## 完整 CI/CD 流程 - -```bash -# 1. 构建 -node scripts/build.js - -# 2. 推送 -HARBOR_USERNAME=user HARBOR_PASSWORD=pass node scripts/push.js - -# 3. 部署 -node scripts/deploy.js --migrate -``` - -## 本地开发 - -```bash -# 本地构建测试 -node scripts/build.js app - -# 使用本地 tag -TAG=dev node scripts/build.js - -# 部署到测试环境 -NAMESPACE=test node scripts/deploy.js -``` diff --git a/scripts/fix-openapi-tags.js b/scripts/fix-openapi-tags.js deleted file mode 100644 index d22a2b0..0000000 --- a/scripts/fix-openapi-tags.js +++ /dev/null @@ -1,28 +0,0 @@ -/** - * Fix ugly module-path tags in openapi.json generated by utoipa. - * Transforms tags like "crate::agent::provider" -> "agent-provider" - */ -const fs = require('fs'); -const path = require('path'); - -const openapiPath = path.join(__dirname, '..', 'openapi.json'); -const json = JSON.parse(fs.readFileSync(openapiPath, 'utf8')); - -let fixed = 0; - -// Fix operation tags -for (const p in json.paths) { - for (const m in json.paths[p]) { - const op = json.paths[p][m]; - if (op.tags) { - op.tags = op.tags.map(t => { - const fixed = t.replace(/^crate::/, '').replace(/::/g, '-'); - return fixed; - }); - fixed++; - } - } -} - -fs.writeFileSync(openapiPath, JSON.stringify(json, null, 2)); -console.log(`Fixed tags in ${fixed} operations`); diff --git a/scripts/gen-client.js b/scripts/gen-client.js deleted file mode 100644 index a242088..0000000 --- a/scripts/gen-client.js +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Generate TypeScript axios client from openapi.json using @hey-api/openapi-ts. - * Generates into src/client. - * Post-processes: injects withCredentials: true and baseURL into the client config. - */ -const { execSync } = require('child_process'); -const fs = require('fs'); -const path = require('path'); - -const ROOT = path.join(__dirname, '..'); -const CLIENT_DIR = path.join(ROOT, 'src', 'client'); -const CLIENT_GEN = path.join(CLIENT_DIR, 'client.gen.ts'); - -const openapiTsBin = path.join(ROOT, 'node_modules/@hey-api/openapi-ts/bin/run.js'); -const openapiJson = path.join(ROOT, 'openapi.json'); - -console.log('Running @hey-api/openapi-ts...'); -try { - execSync(`node "${openapiTsBin}" -c @hey-api/client-axios -i "${openapiJson}" -o "${CLIENT_DIR}"`, { - cwd: ROOT, - stdio: 'inherit', - }); -} catch (e) { - console.error('Generator exited with code:', e.status); - process.exit(1); -} - -// Post-process: inject withCredentials and baseURL into client config -if (fs.existsSync(CLIENT_GEN)) { - let content = fs.readFileSync(CLIENT_GEN, 'utf8'); - - // Remove unused createConfig import - content = content.replace( - "import { type ClientOptions, type Config, createClient, createConfig } from './client';", - "import { type ClientOptions, type Config, createClient } from './client';" - ); - - // Replace the client creation to include withCredentials and baseURL - content = content.replace( - 'export const client = createClient(createConfig());', - `export const createClientConfig = (override?: Config): Config => { - return { - withCredentials: true, - baseURL: import.meta.env.VITE_API_BASE_URL ?? '', - ...override, - }; - }; -export const client = createClient(createClientConfig());` - ); - - fs.writeFileSync(CLIENT_GEN, content); - console.log('Updated client.gen.ts with withCredentials and baseURL'); -} - -console.log('Done.'); diff --git a/scripts/generate-changelog-data.js b/scripts/generate-changelog-data.js deleted file mode 100644 index 5e43266..0000000 --- a/scripts/generate-changelog-data.js +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Generates changelog data file for the frontend. - * Run with: node scripts/generate-changelog-data.js - */ - -const fs = require('fs'); -const path = require('path'); - -const CHANGELOG_DIR = path.join(__dirname, '..', 'changelog'); -const OUTPUT_FILE = path.join(__dirname, '..', 'src', 'data', 'changelog-data.ts'); - -const LANGUAGES = ['en', 'cn', 'de', 'fr']; - -function readFile(filePath) { - try { - return fs.readFileSync(filePath, 'utf-8'); - } catch { - return null; - } -} - -function parseMdx(content) { - const frontmatterMatch = content.match(/^---\n([\s\S]*?)\n---\n([\s\S]*)$/); - if (!frontmatterMatch) { - return { title: '', body: content }; - } - const body = frontmatterMatch[2].trim(); - const frontmatter = frontmatterMatch[1]; - const titleMatch = frontmatter.match(/title:\s*["']?([^"'\n]+)["']?/); - const title = titleMatch ? titleMatch[1].trim() : ''; - return { title, body }; -} - -// Get all unique dates -const dates = []; -const files = fs.readdirSync(CHANGELOG_DIR); -files.forEach(file => { - const match = file.match(/^(\d{4}-\d{2}-\d{2})-(\w+)\.mdx$/); - if (match) { - const date = match[1]; - const lang = match[2]; - if (!dates.includes(date)) { - dates.push(date); - } - } -}); - -// Sort dates descending -dates.sort((a, b) => new Date(b) - new Date(a)); - -// Generate data for each language -const data = {}; -LANGUAGES.forEach(lang => { - data[lang] = dates.map(date => { - const filePath = path.join(CHANGELOG_DIR, `${date}-${lang}.mdx`); - const content = readFile(filePath); - if (!content) { - return null; - } - const { title, body } = parseMdx(content); - return { - date, - title, - lang, - author: 'ZhenYi', - body, - }; - }).filter(Boolean); -}); - -// Generate TypeScript file -const tsContent = `// Auto-generated from changelog/*.mdx files -// Run: node scripts/generate-changelog-data.js - -export type ChangelogEntry = { - date: string; - title: string; - lang: string; - author: string; - body: string; -}; - -export const CHANGELOG_DATA: Record = ${JSON.stringify(data, null, 2)}; - -export const CHANGELOG_LANGUAGES = ${JSON.stringify(LANGUAGES)}; -`; - -fs.writeFileSync(OUTPUT_FILE, tsContent); -console.log(`Generated ${OUTPUT_FILE}`);