diff --git a/Cargo.lock b/Cargo.lock index 5f9b06e..41d7503 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2021,6 +2021,7 @@ dependencies = [ "db", "futures", "hmac 0.13.0", + "lazy_static", "model", "redis", "serde", diff --git a/lib/ai/agent/agent.rs b/lib/ai/agent/agent.rs index 5902500..62612cc 100644 --- a/lib/ai/agent/agent.rs +++ b/lib/ai/agent/agent.rs @@ -7,13 +7,16 @@ use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; use tracing::{info, warn}; +use super::RigStreamChunk; use super::config::AgentConfig; use super::helpers::{build_input_string, check_token_budget, estimate_tokens}; -use super::hooks::{HookChain, HookLlmResponse, HookMessage, HookToolDef, ToolCallOutcome, ToolGuardrailDecision}; +use super::hooks::{ + HookChain, HookLlmResponse, HookMessage, HookToolDef, ToolCallOutcome, + ToolGuardrailDecision, +}; use super::persistence::ActiveAgentRun; use super::request::{AgentRequest, AgentResult, AgentStep, ToolCallRecord}; use super::subagent::run_experts; -use super::RigStreamChunk; use crate::client::AiClient; use crate::error::{AiError, AiResult}; @@ -48,9 +51,7 @@ impl RigAgent { tools: Vec>, ) -> AiResult { let (mut rx, handle) = self.run(request, tools); - tokio::spawn(async move { - while rx.recv().await.is_some() {} - }); + tokio::spawn(async move { while rx.recv().await.is_some() {} }); let result = handle.await.map_err(|_| { AiError::Response("agent task panicked".to_string()) })?; @@ -152,15 +153,24 @@ async fn execute_agent_run( // ---- SubAgent execution ---- let expert_outputs = if !request.experts.is_empty() { let run = ActiveAgentRun { - conversation_id: request.run_context.as_ref().and_then(|c| c.conversation_id), + conversation_id: request + .run_context + .as_ref() + .and_then(|c| c.conversation_id), message_id: None, - invocation_id: request.run_context.as_ref().and_then(|c| c.invocation_id), + invocation_id: request + .run_context + .as_ref() + .and_then(|c| c.invocation_id), session_id: request.run_context.as_ref().and_then(|c| c.session_id), user_id: request.run_context.as_ref().and_then(|c| c.user_id), started_at: std::time::Instant::now(), current_step: 0, }; - let realtime = request.run_context.as_ref().and_then(|c| c.realtime.as_ref()); + let realtime = request + .run_context + .as_ref() + .and_then(|c| c.realtime.as_ref()); // Notify frontend that subagents are starting. for expert in &request.experts { @@ -173,7 +183,15 @@ async fn execute_agent_run( .await; } - match run_experts(&ai_client, &agent_config, &request.experts, realtime, &run).await { + match run_experts( + &ai_client, + &agent_config, + &request.experts, + realtime, + &run, + ) + .await + { Ok(outputs) => { for out in &outputs { let _ = tx @@ -252,7 +270,10 @@ async fn execute_agent_run( Err(_elapsed) => { let _ = tx .send(RigStreamChunk::Failed { - error: format!("agent timed out after {}s", dur.as_secs()), + error: format!( + "agent timed out after {}s", + dur.as_secs() + ), }) .await; return Err(AiError::Timeout { @@ -284,7 +305,11 @@ async fn execute_agent_run( } if let Some(limit) = max_total_tokens - && check_token_budget(estimated_input_tokens, accumulated_output_chars, limit) + && check_token_budget( + estimated_input_tokens, + accumulated_output_chars, + limit, + ) { let _ = tx .send(RigStreamChunk::Failed { @@ -317,7 +342,8 @@ async fn execute_agent_run( )) => { for part in &reasoning.content { if let rig::completion::message::ReasoningContent::Text { - text, .. + text, + .. } = part { accumulated_output_chars += text.chars().count(); @@ -334,7 +360,8 @@ async fn execute_agent_run( } Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( rig::streaming::StreamedAssistantContent::ReasoningDelta { - reasoning, .. + reasoning, + .. }, )) => { accumulated_output_chars += reasoning.chars().count(); @@ -363,7 +390,9 @@ async fn execute_agent_run( let tool_args: serde_json::Value = serde_json::from_str(&args).unwrap_or_default(); - if let Ok(Some(decision)) = hooks.run_pre_tool_call(&tool_name, &tool_args).await { + if let Ok(Some(decision)) = + hooks.run_pre_tool_call(&tool_name, &tool_args).await + { match decision { ToolGuardrailDecision::Allow => {} ToolGuardrailDecision::Block { reason } => { @@ -390,7 +419,9 @@ async fn execute_agent_run( .send(RigStreamChunk::ToolCallFinished { tool_call_id: tool_call.id.clone(), tool_name: tool_name.clone(), - output: format!("awaiting approval: {message}"), + output: format!( + "awaiting approval: {message}" + ), error: None, }) .await; @@ -399,7 +430,9 @@ async fn execute_agent_run( name: tool_name.clone(), arguments: tool_args.clone(), output: None, - error: Some(format!("requires approval: {message}")), + error: Some(format!( + "requires approval: {message}" + )), elapsed_ms: None, }); continue; @@ -424,16 +457,22 @@ async fn execute_agent_run( }); } Ok(rig::agent::MultiTurnStreamItem::StreamUserItem( - rig::streaming::StreamedUserContent::ToolResult { tool_result, .. }, + rig::streaming::StreamedUserContent::ToolResult { + tool_result, + .. + }, )) => { - let content = - super::helpers::tool_result_content_to_string(&tool_result.content); + let content = super::helpers::tool_result_content_to_string( + &tool_result.content, + ); accumulated_output_chars += content.chars().count(); if let Some(last) = current_step_tool_calls.last_mut() && last.id == tool_result.id { - last.output = Some(serde_json::from_str(&content).unwrap_or_default()); + last.output = Some( + serde_json::from_str(&content).unwrap_or_default(), + ); } let tool_name = current_step_tool_calls @@ -464,15 +503,21 @@ async fn execute_agent_run( Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => { let usage = resp.usage(); - if !current_step_tool_calls.is_empty() || !current_step_assistant.is_empty() { + if !current_step_tool_calls.is_empty() + || !current_step_assistant.is_empty() + { let reasoning = (!current_step_reasoning.is_empty()) .then_some(std::mem::take(&mut current_step_reasoning)); steps.push(AgentStep { index: steps.len(), assistant: (!current_step_assistant.is_empty()) - .then_some(std::mem::take(&mut current_step_assistant)), + .then_some(std::mem::take( + &mut current_step_assistant, + )), reasoning_content: reasoning, - tool_calls: std::mem::take(&mut current_step_tool_calls), + tool_calls: std::mem::take( + &mut current_step_tool_calls, + ), reflection: None, }); } @@ -533,7 +578,9 @@ async fn execute_agent_run( } } - Err(AiError::Response("agent stream ended without final response".to_string())) + Err(AiError::Response( + "agent stream ended without final response".to_string(), + )) } impl Clone for HookChain { diff --git a/lib/ai/agent/compression.rs b/lib/ai/agent/compression.rs index 331141c..f2d801f 100644 --- a/lib/ai/agent/compression.rs +++ b/lib/ai/agent/compression.rs @@ -65,7 +65,10 @@ impl CompressionStrategy { self } - pub fn with_custom_instructions(mut self, instructions: impl Into) -> Self { + pub fn with_custom_instructions( + mut self, + instructions: impl Into, + ) -> Self { self.custom_instructions = Some(instructions.into()); self } @@ -91,7 +94,11 @@ pub struct CompactionResult { } impl CompactionResult { - pub fn new(summary: String, messages_compacted: usize, tokens_saved: i64) -> Self { + pub fn new( + summary: String, + messages_compacted: usize, + tokens_saved: i64, + ) -> Self { Self { summary, messages_compacted, @@ -115,7 +122,12 @@ pub fn build_compression_prompt( existing_summary: Option<&str>, messages_text: &str, ) -> String { - build_compression_prompt_with_options(existing_summary, messages_text, None, 1500) + build_compression_prompt_with_options( + existing_summary, + messages_text, + None, + 1500, + ) } /// Build the compaction prompt with custom instructions and word limit. diff --git a/lib/ai/agent/config.rs b/lib/ai/agent/config.rs index d38c3a5..17da92c 100644 --- a/lib/ai/agent/config.rs +++ b/lib/ai/agent/config.rs @@ -132,7 +132,10 @@ impl AgentConfig { self } - pub fn with_max_completion_tokens(mut self, max_completion_tokens: Option) -> Self { + pub fn with_max_completion_tokens( + mut self, + max_completion_tokens: Option, + ) -> Self { self.max_completion_tokens = max_completion_tokens; self } @@ -142,19 +145,31 @@ impl AgentConfig { self } - pub fn with_toolset_policy(mut self, enabled: Vec, disabled: Vec) -> Self { + pub fn with_toolset_policy( + mut self, + enabled: Vec, + disabled: Vec, + ) -> Self { self.enabled_toolsets = enabled; self.disabled_toolsets = disabled; self } - pub fn with_tool_policy(mut self, allowed_tools: Vec, denied_tools: Vec) -> Self { + pub fn with_tool_policy( + mut self, + allowed_tools: Vec, + denied_tools: Vec, + ) -> Self { self.allowed_tools = allowed_tools; self.denied_tools = denied_tools; self } - pub fn with_retry(mut self, max_attempts: usize, base_delay_ms: u64) -> Self { + pub fn with_retry( + mut self, + max_attempts: usize, + base_delay_ms: u64, + ) -> Self { self.retry_max_attempts = max_attempts; self.retry_base_delay_ms = base_delay_ms; self @@ -165,7 +180,10 @@ impl AgentConfig { self } - pub fn with_fallback_model(mut self, fallback_model: impl Into) -> Self { + pub fn with_fallback_model( + mut self, + fallback_model: impl Into, + ) -> Self { self.fallback_model = Some(fallback_model.into()); self } diff --git a/lib/ai/agent/error_classifier.rs b/lib/ai/agent/error_classifier.rs index 18718d3..a738837 100644 --- a/lib/ai/agent/error_classifier.rs +++ b/lib/ai/agent/error_classifier.rs @@ -44,7 +44,8 @@ impl RetryPolicy { let half = (ms as f64 * 0.25) as u64; let lo = ms.saturating_sub(half); let hi = ms.saturating_add(half); - let mix = ((attempt as u64).wrapping_mul(1_103_515_245)) % (hi - lo + 1); + let mix = + ((attempt as u64).wrapping_mul(1_103_515_245)) % (hi - lo + 1); lo + mix } else { ms @@ -58,17 +59,26 @@ impl RetryPolicy { /// /// Inspects both the HTTP status code (when available) and the error message /// content to determine the most appropriate category. -pub fn classify_error(error: &AiError, http_status: Option) -> ErrorCategory { +pub fn classify_error( + error: &AiError, + http_status: Option, +) -> ErrorCategory { // HTTP status-based classification takes precedence let from_status = match http_status { Some(429) => Some(ErrorCategory::Retryable { reason: "rate limited (HTTP 429)".to_string(), }), Some(401) | Some(403) => Some(ErrorCategory::FallbackModel { - reason: format!("authentication failed (HTTP {})", http_status.unwrap()), + reason: format!( + "authentication failed (HTTP {})", + http_status.unwrap() + ), }), Some(502) | Some(503) => Some(ErrorCategory::Overloaded { - reason: format!("provider unavailable (HTTP {})", http_status.unwrap()), + reason: format!( + "provider unavailable (HTTP {})", + http_status.unwrap() + ), }), Some(504) => Some(ErrorCategory::Timeout), Some(413) => Some(ErrorCategory::ContextWindowExceeded { @@ -90,7 +100,9 @@ pub fn classify_error(error: &AiError, http_status: Option) -> ErrorCategor // Message-based classification match error { AiError::Timeout { .. } => ErrorCategory::Timeout, - AiError::TokenBudgetExceeded { .. } => ErrorCategory::TokenBudgetExceeded, + AiError::TokenBudgetExceeded { .. } => { + ErrorCategory::TokenBudgetExceeded + } AiError::Api(msg) => classify_api_message(msg), AiError::Response(msg) => classify_response_message(msg), AiError::ModelRetriesExhausted { .. } => ErrorCategory::Fatal { @@ -107,7 +119,10 @@ fn classify_api_message(msg: &str) -> ErrorCategory { let lower = msg.to_lowercase(); // Rate limiting - if lower.contains("rate") || lower.contains("too many requests") || lower.contains("throttl") { + if lower.contains("rate") + || lower.contains("too many requests") + || lower.contains("throttl") + { return ErrorCategory::Retryable { reason: msg.to_string(), }; @@ -213,15 +228,15 @@ pub fn retry_policy_for( exponential: false, switch_to_fallback: false, }, - ErrorCategory::TokenBudgetExceeded | ErrorCategory::Cancelled | ErrorCategory::Fatal { .. } => { - RetryPolicy { - max_attempts: 0, - base_delay: Duration::from_millis(0), - jitter: false, - exponential: false, - switch_to_fallback: false, - } - } + ErrorCategory::TokenBudgetExceeded + | ErrorCategory::Cancelled + | ErrorCategory::Fatal { .. } => RetryPolicy { + max_attempts: 0, + base_delay: Duration::from_millis(0), + jitter: false, + exponential: false, + switch_to_fallback: false, + }, } } diff --git a/lib/ai/agent/events.rs b/lib/ai/agent/events.rs index 681c9df..00cd371 100644 --- a/lib/ai/agent/events.rs +++ b/lib/ai/agent/events.rs @@ -140,7 +140,9 @@ impl EventSink { } /// Subscribe to events, returns a receiver. - pub fn subscribe(&mut self) -> tokio::sync::mpsc::UnboundedReceiver { + pub fn subscribe( + &mut self, + ) -> tokio::sync::mpsc::UnboundedReceiver { let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); self.senders.push(tx); rx diff --git a/lib/ai/agent/helpers.rs b/lib/ai/agent/helpers.rs index 3effec2..a0e5cf3 100644 --- a/lib/ai/agent/helpers.rs +++ b/lib/ai/agent/helpers.rs @@ -69,7 +69,9 @@ where match f().await { Ok(result) => return Ok(result), Err(e) if is_retryable(&e) && attempt + 1 < max_attempts => { - let delay = Duration::from_millis(base_delay_ms * 2u64.pow(attempt as u32)); + let delay = Duration::from_millis( + base_delay_ms * 2u64.pow(attempt as u32), + ); tracing::warn!( error = %e, attempt = attempt + 1, @@ -94,12 +96,16 @@ where fn is_retryable(error: &AiError) -> bool { matches!( error, - AiError::Api(_) | AiError::Response(_) | AiError::ModelRetriesExhausted { .. } + AiError::Api(_) + | AiError::Response(_) + | AiError::ModelRetriesExhausted { .. } ) } pub fn tool_result_content_to_string( - content: &rig::one_or_many::OneOrMany, + content: &rig::one_or_many::OneOrMany< + rig::completion::message::ToolResultContent, + >, ) -> String { use rig::completion::message::ToolResultContent; content diff --git a/lib/ai/agent/hooks.rs b/lib/ai/agent/hooks.rs index 19888d5..5fa2d28 100644 --- a/lib/ai/agent/hooks.rs +++ b/lib/ai/agent/hooks.rs @@ -51,11 +51,19 @@ pub trait AgentHook: Send + Sync { Ok(()) } - async fn on_session_end(&self, _ctx: &AgentRunContext, _success: bool) -> AiResult<()> { + async fn on_session_end( + &self, + _ctx: &AgentRunContext, + _success: bool, + ) -> AiResult<()> { Ok(()) } - async fn pre_llm_call(&self, _messages: &[HookMessage], _tools: &[HookToolDef]) -> AiResult<()> { + async fn pre_llm_call( + &self, + _messages: &[HookMessage], + _tools: &[HookToolDef], + ) -> AiResult<()> { Ok(()) } @@ -93,28 +101,42 @@ impl HookChain { self.hooks.is_empty() } - pub async fn run_session_start(&self, ctx: &AgentRunContext) -> AiResult<()> { + pub async fn run_session_start( + &self, + ctx: &AgentRunContext, + ) -> AiResult<()> { for hook in &self.hooks { hook.on_session_start(ctx).await?; } Ok(()) } - pub async fn run_session_end(&self, ctx: &AgentRunContext, success: bool) -> AiResult<()> { + pub async fn run_session_end( + &self, + ctx: &AgentRunContext, + success: bool, + ) -> AiResult<()> { for hook in &self.hooks { hook.on_session_end(ctx, success).await?; } Ok(()) } - pub async fn run_pre_llm_call(&self, messages: &[HookMessage], tools: &[HookToolDef]) -> AiResult<()> { + pub async fn run_pre_llm_call( + &self, + messages: &[HookMessage], + tools: &[HookToolDef], + ) -> AiResult<()> { for hook in &self.hooks { hook.pre_llm_call(messages, tools).await?; } Ok(()) } - pub async fn run_post_llm_call(&self, response: &HookLlmResponse) -> AiResult<()> { + pub async fn run_post_llm_call( + &self, + response: &HookLlmResponse, + ) -> AiResult<()> { for hook in &self.hooks { hook.post_llm_call(response).await?; } @@ -127,7 +149,9 @@ impl HookChain { arguments: &Value, ) -> AiResult> { for hook in &self.hooks { - if let Some(decision) = hook.pre_tool_call(tool_name, arguments).await? { + if let Some(decision) = + hook.pre_tool_call(tool_name, arguments).await? + { if !matches!(decision, ToolGuardrailDecision::Allow) { return Ok(Some(decision)); } @@ -136,7 +160,10 @@ impl HookChain { Ok(None) } - pub async fn run_post_tool_call(&self, outcome: &ToolCallOutcome) -> AiResult<()> { + pub async fn run_post_tool_call( + &self, + outcome: &ToolCallOutcome, + ) -> AiResult<()> { for hook in &self.hooks { hook.post_tool_call(outcome).await?; } diff --git a/lib/ai/agent/loop.rs b/lib/ai/agent/loop.rs index 5d36d0f..c3d952e 100644 --- a/lib/ai/agent/loop.rs +++ b/lib/ai/agent/loop.rs @@ -11,16 +11,19 @@ use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; use tracing::{info, warn}; +use super::RigStreamChunk; use super::config::AgentConfig; use super::error_classifier::{ classify_error, retry_policy_for, should_switch_to_fallback, }; use super::events::{AgentEvent, EventSink}; use super::helpers::{build_input_string, estimate_tokens}; -use super::hooks::{HookChain, HookLlmResponse, HookMessage, ToolCallOutcome, ToolGuardrailDecision}; +use super::hooks::{ + HookChain, HookLlmResponse, HookMessage, ToolCallOutcome, + ToolGuardrailDecision, +}; use super::iteration_budget::IterationBudget; use super::request::{AgentRequest, AgentResult, AgentStep, ToolCallRecord}; -use super::RigStreamChunk; use crate::client::AiClient; use crate::error::{AiError, AiResult}; @@ -50,13 +53,13 @@ pub type FollowUpFn = Arc< >; /// Callback to decide whether the agent should stop after a turn. -pub type ShouldStopFn = Arc< - dyn Fn(&TurnContext) -> bool + Send + Sync, ->; +pub type ShouldStopFn = Arc bool + Send + Sync>; /// Callback to prepare/modify state before the next turn. pub type PrepareNextTurnFn = Arc< - dyn Fn(&TurnContext) -> Pin> + Send>> + dyn Fn( + &TurnContext, + ) -> Pin> + Send>> + Send + Sync, >; @@ -144,7 +147,10 @@ pub struct EnhancedAgent { } impl EnhancedAgent { - pub fn new(client: AiClient, loop_config: AgentLoopConfig) -> AiResult { + pub fn new( + client: AiClient, + loop_config: AgentLoopConfig, + ) -> AiResult { loop_config.config.validate()?; Ok(Self { client, @@ -270,7 +276,11 @@ async fn run_enhanced_loop( loop { // Check cancellation if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) { - let _ = tx.send(RigStreamChunk::Failed { error: "cancelled".to_string() }).await; + let _ = tx + .send(RigStreamChunk::Failed { + error: "cancelled".to_string(), + }) + .await; if let Some(sink) = &event_sink { sink.emit(AgentEvent::ErrorClassified { category: "cancelled".to_string(), @@ -279,7 +289,9 @@ async fn run_enhanced_loop( retry_delay_ms: None, }); } - return Err(AiError::Response("agent run cancelled".to_string())); + return Err(AiError::Response( + "agent run cancelled".to_string(), + )); } // Inject steering messages if any @@ -298,10 +310,12 @@ async fn run_enhanced_loop( if let Some(sink) = &event_sink { sink.emit(AgentEvent::TurnStart { turn_index }); } - let _ = tx.send(RigStreamChunk::TextDelta { - index: 0, - content: String::new(), // placeholder for turn boundary detection - }).await; + let _ = tx + .send(RigStreamChunk::TextDelta { + index: 0, + content: String::new(), // placeholder for turn boundary detection + }) + .await; // Run one LLM turn with retry let turn_result = run_single_turn( @@ -325,7 +339,9 @@ async fn run_enhanced_loop( // Collect step let tool_call_count = turn_output.tool_calls.len(); - if !turn_output.tool_calls.is_empty() || !turn_output.assistant_text.is_empty() { + if !turn_output.tool_calls.is_empty() + || !turn_output.assistant_text.is_empty() + { all_steps.push(AgentStep { index: all_steps.len(), assistant: (!turn_output.assistant_text.is_empty()) @@ -340,7 +356,9 @@ async fn run_enhanced_loop( if let Some(sink) = &event_sink { sink.emit(AgentEvent::TurnEnd { turn_index, - assistant_text: Some(turn_output.assistant_text.clone()), + assistant_text: Some( + turn_output.assistant_text.clone(), + ), tool_call_count, }); } @@ -357,7 +375,10 @@ async fn run_enhanced_loop( if let Some(stop_fn) = &should_stop { if stop_fn(&turn_ctx) { - info!(turn_index, "agent stopped by should_stop callback"); + info!( + turn_index, + "agent stopped by should_stop callback" + ); break; } } @@ -378,7 +399,8 @@ async fn run_enhanced_loop( if let Some(temp) = update.temperature { config.temperature = Some(temp); } - if let Some(max_tok) = update.max_completion_tokens { + if let Some(max_tok) = update.max_completion_tokens + { config.max_completion_tokens = Some(max_tok); } } @@ -397,14 +419,21 @@ async fn run_enhanced_loop( Err(e) => { // Error classification and retry with fallback let category = classify_error(&e, None); - let policy = retry_policy_for(&category, config.retry_max_attempts, config.retry_base_delay_ms); + let policy = retry_policy_for( + &category, + config.retry_max_attempts, + config.retry_base_delay_ms, + ); if let Some(sink) = &event_sink { sink.emit(AgentEvent::ErrorClassified { category: format!("{category:?}"), message: e.to_string(), - will_retry: policy.switch_to_fallback || policy.max_attempts > 0, - retry_delay_ms: Some(policy.base_delay.as_millis() as u64), + will_retry: policy.switch_to_fallback + || policy.max_attempts > 0, + retry_delay_ms: Some( + policy.base_delay.as_millis() as u64 + ), }); } @@ -443,16 +472,20 @@ async fn run_enhanced_loop( match retry_result { Ok(turn_output) => { - total_input_tokens += turn_output.input_tokens; - total_output_tokens += turn_output.output_tokens; + total_input_tokens += + turn_output.input_tokens; + total_output_tokens += + turn_output.output_tokens; let tc_count = turn_output.tool_calls.len(); let has_tools = tc_count > 0; - let has_text = !turn_output.assistant_text.is_empty(); + let has_text = + !turn_output.assistant_text.is_empty(); let assistant = turn_output.assistant_text; if has_tools || has_text { all_steps.push(AgentStep { index: all_steps.len(), - assistant: has_text.then_some(assistant.clone()), + assistant: has_text + .then_some(assistant.clone()), reasoning_content: None, tool_calls: turn_output.tool_calls, reflection: None, @@ -472,7 +505,9 @@ async fn run_enhanced_loop( }) .await; if let Some(ctx) = &request.run_context { - let _ = hooks.run_session_end(ctx, false).await; + let _ = hooks + .run_session_end(ctx, false) + .await; } return Err(retry_err); } @@ -582,7 +617,9 @@ async fn run_single_turn( tx: &mpsc::Sender, ) -> AiResult { if !budget.consume() { - return Err(AiError::Response("iteration budget exhausted".to_string())); + return Err(AiError::Response( + "iteration budget exhausted".to_string(), + )); } let model = client.completion_model(&config.model); @@ -674,7 +711,11 @@ async fn run_single_turn( rig::streaming::StreamedAssistantContent::Reasoning(reasoning), )) => { for part in &reasoning.content { - if let rig::completion::message::ReasoningContent::Text { text, .. } = part { + if let rig::completion::message::ReasoningContent::Text { + text, + .. + } = part + { _accumulated_output_chars += text.chars().count(); if let Some(sink) = &event_sink { sink.emit(AgentEvent::MessageThinkingDelta { @@ -693,7 +734,10 @@ async fn run_single_turn( } } Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( - rig::streaming::StreamedAssistantContent::ReasoningDelta { reasoning, .. }, + rig::streaming::StreamedAssistantContent::ReasoningDelta { + reasoning, + .. + }, )) => { _accumulated_output_chars += reasoning.chars().count(); if let Some(sink) = &event_sink { @@ -711,7 +755,10 @@ async fn run_single_turn( delta_index += 1; } Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( - rig::streaming::StreamedAssistantContent::ToolCall { tool_call, .. }, + rig::streaming::StreamedAssistantContent::ToolCall { + tool_call, + .. + }, )) => { let args = match &tool_call.function.arguments { serde_json::Value::String(s) => s.clone(), @@ -724,7 +771,9 @@ async fn run_single_turn( serde_json::from_str(&args).unwrap_or_default(); // Pre-tool-call guardrail hook - if let Ok(Some(decision)) = hooks.run_pre_tool_call(&tool_name, &tool_args).await { + if let Ok(Some(decision)) = + hooks.run_pre_tool_call(&tool_name, &tool_args).await + { match decision { ToolGuardrailDecision::Allow => {} ToolGuardrailDecision::Block { reason } => { @@ -761,7 +810,9 @@ async fn run_single_turn( name: tool_name.clone(), arguments: tool_args, output: None, - error: Some(format!("requires approval: {message}")), + error: Some(format!( + "requires approval: {message}" + )), elapsed_ms: None, }); continue; @@ -794,10 +845,14 @@ async fn run_single_turn( }); } Ok(rig::agent::MultiTurnStreamItem::StreamUserItem( - rig::streaming::StreamedUserContent::ToolResult { tool_result, .. }, + rig::streaming::StreamedUserContent::ToolResult { + tool_result, + .. + }, )) => { - let content = - super::helpers::tool_result_content_to_string(&tool_result.content); + let content = super::helpers::tool_result_content_to_string( + &tool_result.content, + ); _accumulated_output_chars += content.chars().count(); let tool_name = tool_calls @@ -808,14 +863,18 @@ async fn run_single_turn( if let Some(last) = tool_calls.last_mut() && last.id == tool_result.id { - last.output = Some(serde_json::from_str(&content).unwrap_or_default()); + last.output = Some( + serde_json::from_str(&content).unwrap_or_default(), + ); } if let Some(sink) = &event_sink { sink.emit(AgentEvent::ToolExecutionEnd { tool_call_id: tool_result.id.clone(), tool_name: tool_name.clone(), - output: Some(serde_json::Value::String(content.clone())), + output: Some(serde_json::Value::String( + content.clone(), + )), error: None, elapsed_ms: 0, }); @@ -872,5 +931,3 @@ async fn run_single_turn( output_tokens, }) } - - diff --git a/lib/ai/agent/prompt.rs b/lib/ai/agent/prompt.rs index 9e18379..74b9823 100644 --- a/lib/ai/agent/prompt.rs +++ b/lib/ai/agent/prompt.rs @@ -37,13 +37,12 @@ impl RigAgent { } let agent = builder.build(); - let response = agent - .prompt(&ui) - .extended_details() - .await - .map_err(|e: rig::completion::PromptError| { - AiError::Api(e.to_string()) - })?; + let response = + agent.prompt(&ui).extended_details().await.map_err( + |e: rig::completion::PromptError| { + AiError::Api(e.to_string()) + }, + )?; Ok(( response.output, diff --git a/lib/ai/agent/prompt_builder.rs b/lib/ai/agent/prompt_builder.rs index fc4a1ba..e69da25 100644 --- a/lib/ai/agent/prompt_builder.rs +++ b/lib/ai/agent/prompt_builder.rs @@ -63,8 +63,13 @@ impl SystemPromptBuilder { } /// Add a one-line tool description snippet. - pub fn tool_snippet(mut self, tool_name: impl Into, description: impl Into) -> Self { - self.tool_snippets.push((tool_name.into(), description.into())); + pub fn tool_snippet( + mut self, + tool_name: impl Into, + description: impl Into, + ) -> Self { + self.tool_snippets + .push((tool_name.into(), description.into())); self } @@ -75,7 +80,11 @@ impl SystemPromptBuilder { } /// Add a project context file (e.g., AGENTS.md content). - pub fn project_context(mut self, path: impl Into, content: impl Into) -> Self { + pub fn project_context( + mut self, + path: impl Into, + content: impl Into, + ) -> Self { self.project_contexts.push((path.into(), content.into())); self } @@ -87,13 +96,20 @@ impl SystemPromptBuilder { } /// Set a variable for {{key}} substitution. - pub fn variable(mut self, key: impl Into, value: impl Into) -> Self { + pub fn variable( + mut self, + key: impl Into, + value: impl Into, + ) -> Self { self.variables.insert(key.into(), value.into()); self } /// Set multiple variables from an iterator. - pub fn variables(mut self, vars: impl IntoIterator) -> Self { + pub fn variables( + mut self, + vars: impl IntoIterator, + ) -> Self { self.variables.extend(vars); self } @@ -105,7 +121,11 @@ impl SystemPromptBuilder { } /// Add a custom named section to the prompt. - pub fn custom_section(mut self, name: impl Into, content: impl Into) -> Self { + pub fn custom_section( + mut self, + name: impl Into, + content: impl Into, + ) -> Self { self.custom_sections.push((name.into(), content.into())); self } @@ -142,7 +162,8 @@ impl SystemPromptBuilder { // 4. Project context files if !self.project_contexts.is_empty() { let mut section = String::from("\n\n\n"); - section.push_str("Project-specific instructions and guidelines:\n\n"); + section + .push_str("Project-specific instructions and guidelines:\n\n"); for (path, content) in &self.project_contexts { section.push_str(&format!("\n{content}\n\n\n")); } diff --git a/lib/ai/agent/request.rs b/lib/ai/agent/request.rs index e9ed174..ac209ee 100644 --- a/lib/ai/agent/request.rs +++ b/lib/ai/agent/request.rs @@ -38,7 +38,9 @@ impl AgentRequest { pub fn validate(&self) -> AiResult<()> { if self.input.trim().is_empty() { - return Err(AiError::Config("agent request input is required".to_string())); + return Err(AiError::Config( + "agent request input is required".to_string(), + )); } if self.input.len() > 1_000_000 { return Err(AiError::Config( @@ -83,12 +85,18 @@ impl AgentRequest { self } - pub fn with_prefill_messages(mut self, prefill_messages: Vec) -> Self { + pub fn with_prefill_messages( + mut self, + prefill_messages: Vec, + ) -> Self { self.prefill_messages = prefill_messages; self } - pub fn with_cancellation_token(mut self, cancellation_token: CancellationToken) -> Self { + pub fn with_cancellation_token( + mut self, + cancellation_token: CancellationToken, + ) -> Self { self.cancellation_token = Some(cancellation_token); self } @@ -119,7 +127,11 @@ pub struct AgentExpert { } impl AgentExpert { - pub fn new(id: impl Into, role: impl Into, task: impl Into) -> Self { + pub fn new( + id: impl Into, + role: impl Into, + task: impl Into, + ) -> Self { Self { id: id.into(), role: role.into(), @@ -131,7 +143,10 @@ impl AgentExpert { } } - pub fn with_system_prompt(mut self, system_prompt: impl Into) -> Self { + pub fn with_system_prompt( + mut self, + system_prompt: impl Into, + ) -> Self { self.system_prompt = Some(system_prompt.into()); self } diff --git a/lib/ai/agent/session.rs b/lib/ai/agent/session.rs index 972a980..8ee1eb7 100644 --- a/lib/ai/agent/session.rs +++ b/lib/ai/agent/session.rs @@ -145,7 +145,10 @@ impl SessionEntry { } /// Create a user message entry. - pub fn user_message(parent_id: Option, content: impl Into) -> Self { + pub fn user_message( + parent_id: Option, + content: impl Into, + ) -> Self { Self::Message { id: Uuid::new_v4(), parent_id, @@ -328,7 +331,13 @@ impl Session { pub fn active_messages(&self) -> Vec<&SessionEntry> { self.active_branch() .into_iter() - .filter(|e| matches!(e, SessionEntry::Message { .. } | SessionEntry::Compaction { .. })) + .filter(|e| { + matches!( + e, + SessionEntry::Message { .. } + | SessionEntry::Compaction { .. } + ) + }) .collect() } @@ -342,11 +351,8 @@ impl Session { /// Get all leaf entries (entries with no children). pub fn leaves(&self) -> Vec<&SessionEntry> { - let parent_ids: std::collections::HashSet = self - .entries - .iter() - .filter_map(|e| e.parent_id()) - .collect(); + let parent_ids: std::collections::HashSet = + self.entries.iter().filter_map(|e| e.parent_id()).collect(); self.entries .iter() @@ -367,7 +373,9 @@ impl Session { .iter() .position(|e| e.id() == fork_entry_id) .ok_or_else(|| { - AiError::Config(format!("fork entry {fork_entry_id} not found in session")) + AiError::Config(format!( + "fork entry {fork_entry_id} not found in session" + )) })?; let mut new_session = Session::new(); @@ -445,9 +453,11 @@ fn iso_now() -> String { // Simple ISO 8601 format (UTC) let days = secs / 86400; let years = (days * 400) / 146097; - let remaining_days = days - (years * 365 + years / 4 - years / 100 + years / 400); + let remaining_days = + days - (years * 365 + years / 4 - years / 100 + years / 400); let month_days = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]; - let is_leap = (years % 4 == 0 && years % 100 != 0) || years % 400 == 0; + let is_leap = + (years % 4 == 0 && years % 100 != 0) || years % 400 == 0; let mut month = 0usize; let mut day_acc = remaining_days as i64; for (i, &md) in month_days.iter().enumerate() { @@ -487,7 +497,8 @@ mod tests { let msg1_id = msg1.id(); session.push(msg1); - let msg2 = SessionEntry::assistant_message(Some(msg1_id), "Hi there!", None); + let msg2 = + SessionEntry::assistant_message(Some(msg1_id), "Hi there!", None); session.push(msg2); assert_eq!(session.entry_count(), 2); @@ -502,7 +513,8 @@ mod tests { let msg1_id = msg1.id(); session.push(msg1); - let msg2 = SessionEntry::assistant_message(Some(msg1_id), "Reply 1", None); + let msg2 = + SessionEntry::assistant_message(Some(msg1_id), "Reply 1", None); let msg2_id = msg2.id(); session.push(msg2); @@ -524,8 +536,10 @@ mod tests { session.push(msg1); // Two children branching from root - let msg2a = SessionEntry::assistant_message(Some(msg1_id), "Branch A", None); - let msg2b = SessionEntry::assistant_message(Some(msg1_id), "Branch B", None); + let msg2a = + SessionEntry::assistant_message(Some(msg1_id), "Branch A", None); + let msg2b = + SessionEntry::assistant_message(Some(msg1_id), "Branch B", None); session.push(msg2a); session.push(msg2b); diff --git a/lib/ai/agent/subagent.rs b/lib/ai/agent/subagent.rs index f23758d..5439a78 100644 --- a/lib/ai/agent/subagent.rs +++ b/lib/ai/agent/subagent.rs @@ -31,13 +31,24 @@ pub async fn run_experts( } Err(error) => { warn!(subagent_id = %expert.id, role = %expert.role, error = %error, "subagent failed"); - let _ = publish_subagent_failed(realtime, run, expert, &error.to_string()).await; + let _ = publish_subagent_failed( + realtime, + run, + expert, + &error.to_string(), + ) + .await; failed_count += 1; } } } - debug!(total = experts.len(), ok = outputs.len(), failed = failed_count, "experts done"); + debug!( + total = experts.len(), + ok = outputs.len(), + failed = failed_count, + "experts done" + ); Ok(outputs) } @@ -53,7 +64,9 @@ async fn run_single( let rig_client = client.llm_client().clone(); let model_name = config.model.clone(); let temperature = expert.temperature.or(config.temperature); - let max_completion_tokens = expert.max_completion_tokens.or(config.max_completion_tokens); + let max_completion_tokens = expert + .max_completion_tokens + .or(config.max_completion_tokens); let retry_attempts = config.retry_max_attempts; let retry_delay_ms = config.retry_base_delay_ms; @@ -66,10 +79,8 @@ async fn run_single( let task = build_expert_task(expert); - let (output, input_tokens_usage, output_tokens_usage) = with_retry( - retry_attempts, - retry_delay_ms, - || { + let (output, input_tokens_usage, output_tokens_usage) = + with_retry(retry_attempts, retry_delay_ms, || { let rig_client = rig_client.clone(); let model_name = model_name.clone(); let prompt = prompt.clone(); @@ -85,13 +96,12 @@ async fn run_single( } let agent = builder.build(); - let response = agent - .prompt(&task) - .extended_details() - .await - .map_err(|e: rig::completion::PromptError| { - AiError::Api(e.to_string()) - })?; + let response = + agent.prompt(&task).extended_details().await.map_err( + |e: rig::completion::PromptError| { + AiError::Api(e.to_string()) + }, + )?; Ok(( response.output, @@ -99,9 +109,8 @@ async fn run_single( response.usage.output_tokens, )) } - }, - ) - .await?; + }) + .await?; let input_tokens = input_tokens_usage as i64; let output_tokens = if output_tokens_usage > 0 { @@ -150,17 +159,19 @@ async fn publish_subagent_started( config: &AgentConfig, expert: &AgentExpert, ) -> AiResult<()> { - AgentRuntime::default().publish( - realtime, - &AgentStreamEvent::SubagentStarted { - conversation_id: run.conversation_id, - message_id: run.message_id, - subagent_id: expert.id.clone(), - role: expert.role.clone(), - task: expert.task.clone(), - model: config.model.clone(), - }, - ).await + AgentRuntime::default() + .publish( + realtime, + &AgentStreamEvent::SubagentStarted { + conversation_id: run.conversation_id, + message_id: run.message_id, + subagent_id: expert.id.clone(), + role: expert.role.clone(), + task: expert.task.clone(), + model: config.model.clone(), + }, + ) + .await } async fn publish_subagent_completed( @@ -169,20 +180,22 @@ async fn publish_subagent_completed( config: &AgentConfig, output: &AgentExpertOutput, ) -> AiResult<()> { - AgentRuntime::default().publish( - realtime, - &AgentStreamEvent::SubagentCompleted { - conversation_id: run.conversation_id, - message_id: run.message_id, - subagent_id: output.id.clone(), - role: output.role.clone(), - task: output.task.clone(), - output: output.output.clone(), - input_tokens: output.input_tokens, - output_tokens: output.output_tokens, - model: config.model.clone(), - }, - ).await + AgentRuntime::default() + .publish( + realtime, + &AgentStreamEvent::SubagentCompleted { + conversation_id: run.conversation_id, + message_id: run.message_id, + subagent_id: output.id.clone(), + role: output.role.clone(), + task: output.task.clone(), + output: output.output.clone(), + input_tokens: output.input_tokens, + output_tokens: output.output_tokens, + model: config.model.clone(), + }, + ) + .await } async fn publish_subagent_failed( @@ -191,13 +204,15 @@ async fn publish_subagent_failed( expert: &AgentExpert, error: &str, ) -> AiResult<()> { - AgentRuntime::default().publish( - realtime, - &AgentStreamEvent::SubagentFailed { - conversation_id: run.conversation_id, - message_id: run.message_id, - subagent_id: expert.id.clone(), - error: error.to_string(), - }, - ).await + AgentRuntime::default() + .publish( + realtime, + &AgentStreamEvent::SubagentFailed { + conversation_id: run.conversation_id, + message_id: run.message_id, + subagent_id: expert.id.clone(), + error: error.to_string(), + }, + ) + .await } diff --git a/lib/ai/agent/tool.rs b/lib/ai/agent/tool.rs index 08d9c75..4d99507 100644 --- a/lib/ai/agent/tool.rs +++ b/lib/ai/agent/tool.rs @@ -23,7 +23,10 @@ impl RigTool where C: Clone + Send + Sync + 'static, { - pub fn new(tool: Arc>, context: Arc>) -> Self { + pub fn new( + tool: Arc>, + context: Arc>, + ) -> Self { let name = tool.name().to_string(); let description = tool.description().to_string(); let schema = tool.schema(); @@ -49,7 +52,8 @@ where fn definition<'a>( &'a self, _prompt: String, - ) -> Pin + Send + 'a>> { + ) -> Pin + Send + 'a>> + { let name = self.name.clone(); let description = self.description.clone(); let params = self.schema.clone(); @@ -67,23 +71,28 @@ where &'a self, args: String, ) -> Pin< - Box> + Send + 'a>, + Box< + dyn std::future::Future< + Output = Result, + > + Send + + 'a, + >, > { let tool = self.tool.clone(); let context = self.context.clone(); Box::pin(async move { - let args_value: Value = - serde_json::from_str(&args).map_err(rig::tool::ToolError::JsonError)?; + let args_value: Value = serde_json::from_str(&args) + .map_err(rig::tool::ToolError::JsonError)?; let mut ctx = context.lock().await; match tool.call(&mut *ctx, args_value).await { Ok(value) => serde_json::to_string(&value) .map_err(rig::tool::ToolError::JsonError), - Err(ai_err) => Err(rig::tool::ToolError::ToolCallError(Box::new( - std::io::Error::other(ai_err.to_string()), - ))), + Err(ai_err) => Err(rig::tool::ToolError::ToolCallError( + Box::new(std::io::Error::other(ai_err.to_string())), + )), } }) } @@ -112,10 +121,14 @@ where register: &crate::tool::register::ToolRegister, context: Arc>, ) -> Self { - let mut tools: Vec> = Vec::with_capacity(register.len()); + let mut tools: Vec> = + Vec::with_capacity(register.len()); for tool_arc in ®ister.tools { - tools.push(Box::new(RigTool::new(tool_arc.clone(), context.clone()))); + tools.push(Box::new(RigTool::new( + tool_arc.clone(), + context.clone(), + ))); } Self { diff --git a/lib/ai/client.rs b/lib/ai/client.rs index 25a2678..0ce401d 100644 --- a/lib/ai/client.rs +++ b/lib/ai/client.rs @@ -24,7 +24,10 @@ pub struct EndpointConfig { } impl EndpointConfig { - pub fn new(base_url: impl Into, api_key: impl Into) -> AiResult { + pub fn new( + base_url: impl Into, + api_key: impl Into, + ) -> AiResult { let config = Self { base_url: base_url.into(), api_key: api_key.into(), @@ -51,7 +54,11 @@ impl EndpointConfig { .api_key(&self.api_key) .base_url(self.base_url.trim()) .build() - .map_err(|e| AiError::Config(format!("failed to build rig OpenAI client: {e}"))) + .map_err(|e| { + AiError::Config(format!( + "failed to build rig OpenAI client: {e}" + )) + }) } } diff --git a/lib/ai/embed/client.rs b/lib/ai/embed/client.rs index f250cf4..6700dfb 100644 --- a/lib/ai/embed/client.rs +++ b/lib/ai/embed/client.rs @@ -1,7 +1,10 @@ use rig::client::EmbeddingsClient; use rig::embeddings::EmbeddingModel; -use crate::{client::AiClient, error::{AiError, AiResult}}; +use crate::{ + client::AiClient, + error::{AiError, AiResult}, +}; #[derive(Clone)] pub struct EmbedClient { @@ -23,23 +26,32 @@ impl EmbedClient { pub async fn embed_text(&self, text: String) -> AiResult> { let model = self.embedding_model(); - let mut embeddings = model.embed_texts(vec![text]) + let mut embeddings = model + .embed_texts(vec![text]) .await .map_err(|e| AiError::Api(e.to_string()))?; - embeddings.pop() + embeddings + .pop() .map(|e| e.vec.into_iter().map(|v| v as f32).collect()) - .ok_or_else(|| AiError::Response("no embedding returned".to_string())) + .ok_or_else(|| { + AiError::Response("no embedding returned".to_string()) + }) } - pub async fn embed_texts(&self, texts: Vec) -> AiResult>> { + pub async fn embed_texts( + &self, + texts: Vec, + ) -> AiResult>> { if texts.is_empty() { return Ok(Vec::new()); } let model = self.embedding_model(); - let embeddings = model.embed_texts(texts) + let embeddings = model + .embed_texts(texts) .await .map_err(|e| AiError::Api(e.to_string()))?; - Ok(embeddings.into_iter() + Ok(embeddings + .into_iter() .map(|e| e.vec.into_iter().map(|v| v as f32).collect()) .collect()) } @@ -55,11 +67,15 @@ impl EmbedClient { let mut embeddings: Vec> = Vec::with_capacity(texts.len()); for chunk in texts.chunks(batch_size) { let model = self.embedding_model(); - let chunk_embeddings = model.embed_texts(chunk.to_vec()) + let chunk_embeddings = model + .embed_texts(chunk.to_vec()) .await .map_err(|e| AiError::Api(e.to_string()))?; - embeddings.extend(chunk_embeddings.into_iter() - .map(|e| e.vec.into_iter().map(|v| v as f32).collect())); + embeddings.extend( + chunk_embeddings + .into_iter() + .map(|e| e.vec.into_iter().map(|v| v as f32).collect()), + ); } Ok(embeddings) } diff --git a/lib/ai/error.rs b/lib/ai/error.rs index f41f408..4efb1c5 100644 --- a/lib/ai/error.rs +++ b/lib/ai/error.rs @@ -24,10 +24,7 @@ pub enum AiError { Response(String), #[error("model retries exhausted after {attempts} attempts: {last_error}")] - ModelRetriesExhausted { - attempts: usize, - last_error: String, - }, + ModelRetriesExhausted { attempts: usize, last_error: String }, #[error("agent timeout after {seconds}s")] Timeout { seconds: u64 }, diff --git a/lib/ai/memory/mod.rs b/lib/ai/memory/mod.rs index 45a5c31..6e98c20 100644 --- a/lib/ai/memory/mod.rs +++ b/lib/ai/memory/mod.rs @@ -34,10 +34,7 @@ pub trait MemoryProvider: Send + Sync { ) -> AiResult> { Ok(Vec::new()) } - async fn build_context_block( - &self, - _session_id: Uuid, - ) -> AiResult { + async fn build_context_block(&self, _session_id: Uuid) -> AiResult { Ok(String::new()) } async fn setup(&self) -> AiResult<()> { diff --git a/lib/ai/rag/client.rs b/lib/ai/rag/client.rs index c4da99b..12c7785 100644 --- a/lib/ai/rag/client.rs +++ b/lib/ai/rag/client.rs @@ -42,10 +42,7 @@ impl RagClient { }) } - pub fn connect( - ai_client: &AiClient, - config: RagConfig, - ) -> AiResult { + pub fn connect(ai_client: &AiClient, config: RagConfig) -> AiResult { config.validate()?; let mut builder = Qdrant::from_url(config.url.trim()).timeout(config.timeout); @@ -132,10 +129,8 @@ impl RagClient { validate_session_id(session_id)?; validate_documents(&documents)?; - let texts: Vec = documents - .iter() - .map(|d| d.content.clone()) - .collect(); + let texts: Vec = + documents.iter().map(|d| d.content.clone()).collect(); let vectors = self .embedder .embed_texts_chunked(texts, self.config.upsert_batch_size) diff --git a/lib/ai/rag/payload.rs b/lib/ai/rag/payload.rs index d8bb6e9..8afcee5 100644 --- a/lib/ai/rag/payload.rs +++ b/lib/ai/rag/payload.rs @@ -20,8 +20,8 @@ pub(super) fn point_id(session_id: &str, document_id: &str) -> u64 { let uuid = Uuid::new_v5(&ns, key.as_bytes()); let bytes = uuid.as_bytes(); u64::from_be_bytes([ - bytes[0], bytes[1], bytes[2], bytes[3], - bytes[4], bytes[5], bytes[6], bytes[7], + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], + bytes[7], ]) } diff --git a/lib/ai/sync.rs b/lib/ai/sync.rs index 4f7dae5..cab3ee1 100644 --- a/lib/ai/sync.rs +++ b/lib/ai/sync.rs @@ -75,7 +75,9 @@ static HTTP_CLIENT: LazyLock = LazyLock::new(|| { } } #[allow(clippy::expect_used)] - builder.build().expect("failed to build reqwest HTTP client — check system TLS configuration") + builder.build().expect( + "failed to build reqwest HTTP client — check system TLS configuration", + ) }); pub async fn list_models( config: &EndpointConfig, @@ -102,12 +104,14 @@ pub async fn list_models( AiError::Response(format!("failed to list models: {}", e)) })?; - let body = resp - .text() - .await - .map_err(|e| AiError::Response(format!("failed to read models body: {}", e)))?; + let body = resp.text().await.map_err(|e| { + AiError::Response(format!("failed to read models body: {}", e)) + })?; if let Ok(parsed) = serde_json::from_str::(&body) { - debug!(count = parsed.data.len(), "parsed models in standard format"); + debug!( + count = parsed.data.len(), + "parsed models in standard format" + ); return Ok(parsed.data); } if let Ok(parsed) = serde_json::from_str::>(&body) { diff --git a/lib/ai/tool/toolset.rs b/lib/ai/tool/toolset.rs index 7e9da8e..6a199d3 100644 --- a/lib/ai/tool/toolset.rs +++ b/lib/ai/tool/toolset.rs @@ -27,7 +27,10 @@ impl Toolset { self } - pub fn with_tools(mut self, tool_names: impl IntoIterator>) -> Self { + pub fn with_tools( + mut self, + tool_names: impl IntoIterator>, + ) -> Self { self.tools.extend(tool_names.into_iter().map(Into::into)); self } @@ -36,7 +39,8 @@ impl Toolset { mut self, env_vars: impl IntoIterator>, ) -> Self { - self.requires_env.extend(env_vars.into_iter().map(Into::into)); + self.requires_env + .extend(env_vars.into_iter().map(Into::into)); self } pub fn is_available(&self) -> bool { diff --git a/lib/api/src/agent/conversation.rs b/lib/api/src/agent/conversation.rs index 1663a77..f4797a4 100644 --- a/lib/api/src/agent/conversation.rs +++ b/lib/api/src/agent/conversation.rs @@ -1,7 +1,8 @@ use actix_web::{HttpResponse, web, web::ServiceConfig}; use service::AppService; use service::agent::conversation::{ - ConversationResponse, ConversationWithSessionResponse, CreateConversation, MessageResponse, UpdateConversation, + ConversationResponse, ConversationWithSessionResponse, CreateConversation, + MessageResponse, UpdateConversation, }; use service::agent::types::{AgentRunRequest, AgentRunResponse}; use session::Session; @@ -53,8 +54,14 @@ pub async fn list_conversations( service: web::Data, path: web::Path, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; - ok_json(service.agent_conversation_list(user_id, path.into_inner()).await?) + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json( + service + .agent_conversation_list(user_id, path.into_inner()) + .await?, + ) } #[utoipa::path( @@ -70,10 +77,16 @@ pub async fn create_conversation( path: web::Path, body: web::Json, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; ok_json( service - .agent_conversation_create(user_id, path.into_inner(), body.into_inner()) + .agent_conversation_create( + user_id, + path.into_inner(), + body.into_inner(), + ) .await?, ) } @@ -94,7 +107,9 @@ pub async fn list_all_conversations( service: web::Data, query: web::Query, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; ok_json( service .agent_conversation_list_all(user_id, query.wk.as_deref()) @@ -113,8 +128,14 @@ pub async fn get_conversation( service: web::Data, path: web::Path, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; - ok_json(service.agent_conversation_get(user_id, path.into_inner()).await?) + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json( + service + .agent_conversation_get(user_id, path.into_inner()) + .await?, + ) } #[utoipa::path( @@ -130,10 +151,16 @@ pub async fn update_conversation( path: web::Path, body: web::Json, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; ok_json( service - .agent_conversation_update(user_id, path.into_inner(), body.into_inner()) + .agent_conversation_update( + user_id, + path.into_inner(), + body.into_inner(), + ) .await?, ) } @@ -149,8 +176,12 @@ pub async fn delete_conversation( service: web::Data, path: web::Path, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; - service.agent_conversation_delete(user_id, path.into_inner()).await?; + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; + service + .agent_conversation_delete(user_id, path.into_inner()) + .await?; Ok(HttpResponse::Ok().json(serde_json::json!({ "deleted": true }))) } @@ -165,8 +196,14 @@ pub async fn archive_conversation( service: web::Data, path: web::Path, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; - ok_json(service.agent_conversation_archive(user_id, path.into_inner()).await?) + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json( + service + .agent_conversation_archive(user_id, path.into_inner()) + .await?, + ) } #[utoipa::path( @@ -180,8 +217,14 @@ pub async fn unarchive_conversation( service: web::Data, path: web::Path, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; - ok_json(service.agent_conversation_unarchive(user_id, path.into_inner()).await?) + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json( + service + .agent_conversation_unarchive(user_id, path.into_inner()) + .await?, + ) } #[utoipa::path( get, path = "/api/v1/agent/conversations/{id}/messages", @@ -195,10 +238,17 @@ pub async fn list_messages( path: web::Path, query: web::Query, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; ok_json( service - .agent_message_list(user_id, path.into_inner(), query.limit, query.before) + .agent_message_list( + user_id, + path.into_inner(), + query.limit, + query.before, + ) .await?, ) } @@ -221,7 +271,9 @@ pub async fn send_message( path: web::Path, body: web::Json, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; let conversation_id = path.into_inner(); let mut req = body.into_inner(); req.conversation_id = Some(conversation_id); @@ -240,7 +292,9 @@ pub async fn stream_agent( path: web::Path, body: web::Json, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; let conversation_id = path.into_inner(); let mut req = body.into_inner(); req.conversation_id = Some(conversation_id); @@ -282,7 +336,9 @@ pub async fn fork_conversation( path: web::Path, body: web::Json, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; ok_json( service .agent_conversation_fork( diff --git a/lib/api/src/agent/session.rs b/lib/api/src/agent/session.rs index db67e26..672bd49 100644 --- a/lib/api/src/agent/session.rs +++ b/lib/api/src/agent/session.rs @@ -15,8 +15,7 @@ pub fn configure(cfg: &mut ServiceConfig) { .route(web::post().to(create_session)), ) .service( - web::resource("/sessions/search") - .route(web::get().to(search_sessions)), + web::resource("/sessions/search").route(web::get().to(search_sessions)), ) .service( web::resource("/sessions/{id}") @@ -38,7 +37,9 @@ pub async fn list_sessions( session: Session, service: web::Data, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; ok_json(service.agent_session_list(user_id).await?) } #[utoipa::path( @@ -52,8 +53,14 @@ pub async fn create_session( service: web::Data, body: web::Json, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; - ok_json(service.agent_session_create(user_id, body.into_inner()).await?) + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json( + service + .agent_session_create(user_id, body.into_inner()) + .await?, + ) } #[utoipa::path( get, path = "/api/v1/agent/sessions/{id}", @@ -66,8 +73,14 @@ pub async fn get_session( service: web::Data, path: web::Path, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; - ok_json(service.agent_session_get(user_id, path.into_inner()).await?) + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json( + service + .agent_session_get(user_id, path.into_inner()) + .await?, + ) } #[utoipa::path( patch, path = "/api/v1/agent/sessions/{id}", @@ -82,8 +95,14 @@ pub async fn update_session( path: web::Path, body: web::Json, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; - ok_json(service.agent_session_update(user_id, path.into_inner(), body.into_inner()).await?) + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json( + service + .agent_session_update(user_id, path.into_inner(), body.into_inner()) + .await?, + ) } #[utoipa::path( delete, path = "/api/v1/agent/sessions/{id}", @@ -96,8 +115,12 @@ pub async fn delete_session( service: web::Data, path: web::Path, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; - service.agent_session_delete(user_id, path.into_inner()).await?; + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; + service + .agent_session_delete(user_id, path.into_inner()) + .await?; Ok(HttpResponse::Ok().json(serde_json::json!({ "deleted": true }))) } @@ -122,7 +145,9 @@ pub async fn search_sessions( service: web::Data, query: web::Query, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; ok_json( service .agent_session_search(user_id, &query.q, query.limit) @@ -148,7 +173,9 @@ pub async fn update_session_toolsets( path: web::Path, body: web::Json, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; ok_json( service .agent_session_update_toolsets( diff --git a/lib/api/src/auth/mod.rs b/lib/api/src/auth/mod.rs index 7067c69..c5a87a1 100644 --- a/lib/api/src/auth/mod.rs +++ b/lib/api/src/auth/mod.rs @@ -37,7 +37,9 @@ pub fn configure(cfg: &mut ServiceConfig) { web::post().to(reset_pass::reset_password_verify), )), ) - .service(web::resource("/public-key").route(web::get().to(rsa::rsa))) + .service( + web::resource("/public-key").route(web::get().to(rsa::rsa)), + ) .service( web::scope("/2fa") .service( diff --git a/lib/api/src/channel/mod.rs b/lib/api/src/channel/mod.rs index 47b70e6..efe0846 100644 --- a/lib/api/src/channel/mod.rs +++ b/lib/api/src/channel/mod.rs @@ -1,9 +1,9 @@ pub mod rest; -pub mod rest_ai; pub mod rest_interact; pub mod rest_member; pub mod rest_message; pub mod rest_room; +pub mod rest_user; pub mod rest_voice; pub mod token; @@ -65,8 +65,9 @@ pub fn configure(cfg: &mut ServiceConfig, bus: ChannelBus) { .route(actix_web::web::post().to(rest_room::access_grant)), ) .service( - actix_web::web::resource("/workspaces/{workspace_id}/members") - .route(actix_web::web::get().to(rest_member::list_workspace_members)), + actix_web::web::resource("/workspaces/{workspace_id}/members").route( + actix_web::web::get().to(rest_member::list_workspace_members), + ), ) .service( actix_web::web::resource("/rooms/{room_id}/members/{user_id}") @@ -184,21 +185,8 @@ pub fn configure(cfg: &mut ServiceConfig, bus: ChannelBus) { .route(actix_web::web::post().to(rest_voice::screen_share)), ); cfg.service( - actix_web::web::resource("/rooms/{room_id}/ai/stop") - .route(actix_web::web::post().to(rest_ai::ai_stop)), - ) - .service( - actix_web::web::resource("/rooms/{room_id}/ai") - .route(actix_web::web::get().to(rest_ai::ai_list)) - .route(actix_web::web::post().to(rest_ai::ai_add)), - ) - .service( - actix_web::web::resource("/rooms/{room_id}/ai/{agent_session_id}") - .route(actix_web::web::delete().to(rest_ai::ai_remove)), - ) - .service( actix_web::web::resource("/users/summary/{username}") - .route(actix_web::web::get().to(rest_ai::user_summary)), + .route(actix_web::web::get().to(rest_user::user_summary)), ); cfg.service( actix_web::web::resource("/token") diff --git a/lib/api/src/channel/rest.rs b/lib/api/src/channel/rest.rs index 2f37e56..8a04519 100644 --- a/lib/api/src/channel/rest.rs +++ b/lib/api/src/channel/rest.rs @@ -6,13 +6,13 @@ use uuid::Uuid; use crate::error::ApiError; -pub(crate) fn extract_user(req: &HttpRequest) -> Result { +pub fn extract_user(req: &HttpRequest) -> Result { req.get_session() .user() .ok_or_else(|| ApiError(service::error::AppError::Unauthorized)) } -pub(crate) fn channel_err(e: ChannelError) -> ApiError { +pub fn channel_err(e: ChannelError) -> ApiError { ApiError(match e { ChannelError::Unauthorized | ChannelError::TokenInvalidOrExpired => { service::error::AppError::Unauthorized @@ -61,14 +61,14 @@ pub(crate) fn channel_err(e: ChannelError) -> ApiError { }) } -pub(crate) fn ok_json(event: Option) -> HttpResponse { +pub fn ok_json(event: Option) -> HttpResponse { match event { Some(e) => HttpResponse::Ok().json(e), None => HttpResponse::NoContent().finish(), } } -pub(crate) fn created_json(event: Option) -> HttpResponse { +pub fn created_json(event: Option) -> HttpResponse { match event { Some(e) => HttpResponse::Created().json(e), None => HttpResponse::NoContent().finish(), diff --git a/lib/api/src/channel/rest_ai.rs b/lib/api/src/channel/rest_ai.rs deleted file mode 100644 index a29b5cf..0000000 --- a/lib/api/src/channel/rest_ai.rs +++ /dev/null @@ -1,120 +0,0 @@ -use actix_web::{HttpRequest, HttpResponse, web}; -use channel::ChannelBus; -use channel::http::{WsHandler, WsInMessage}; -use serde::Deserialize; -use uuid::Uuid; - -use super::rest::{channel_err, created_json, extract_user, ok_json}; -use crate::error::ApiError; - -#[derive(Debug, Deserialize, utoipa::ToSchema)] -pub struct AiAddRequest { - pub agent_session: Uuid, -} - -#[utoipa::path( - get, - path = "/api/v1/ws/rooms/{room_id}/ai", - responses((status = 200, description = "AI agents in room")), - tag = "channel", -)] -pub async fn ai_list( - req: HttpRequest, - room_id: web::Path, - bus: web::Data, -) -> Result { - let user_id = extract_user(&req)?; - let msg = WsInMessage::AiList { - room: room_id.into_inner(), - }; - let result = WsHandler::handle(&bus, user_id, msg) - .await - .map_err(channel_err)?; - Ok(ok_json(result)) -} - -#[utoipa::path( - post, - path = "/api/v1/ws/rooms/{room_id}/ai", - request_body = AiAddRequest, - responses((status = 201, description = "AI agent added to room")), - tag = "channel", -)] -pub async fn ai_add( - req: HttpRequest, - room_id: web::Path, - body: web::Json, - bus: web::Data, -) -> Result { - let user_id = extract_user(&req)?; - let msg = WsInMessage::AiUpsert { - room: room_id.into_inner(), - model: body.agent_session, - }; - let result = WsHandler::handle(&bus, user_id, msg) - .await - .map_err(channel_err)?; - Ok(created_json(result)) -} - -#[utoipa::path( - delete, - path = "/api/v1/ws/rooms/{room_id}/ai/{agent_session_id}", - responses((status = 200, description = "AI agent removed from room")), - tag = "channel", -)] -pub async fn ai_remove( - req: HttpRequest, - path: web::Path<(Uuid, Uuid)>, - bus: web::Data, -) -> Result { - let user_id = extract_user(&req)?; - let (room, agent_id) = path.into_inner(); - let msg = WsInMessage::AiDelete { room, agent_id }; - let result = WsHandler::handle(&bus, user_id, msg) - .await - .map_err(channel_err)?; - Ok(ok_json(result)) -} - -#[utoipa::path( - post, - path = "/api/v1/ws/rooms/{room_id}/ai/stop", - responses((status = 204, description = "AI agent stopped")), - tag = "channel", -)] -pub async fn ai_stop( - req: HttpRequest, - room_id: web::Path, - bus: web::Data, -) -> Result { - let user_id = extract_user(&req)?; - let msg = WsInMessage::AiStop { - room: room_id.into_inner(), - }; - let result = WsHandler::handle(&bus, user_id, msg) - .await - .map_err(channel_err)?; - Ok(ok_json(result)) -} - -#[utoipa::path( - get, - path = "/api/v1/ws/users/summary/{username}", - responses((status = 200, description = "User summary")), - tag = "channel", -)] -pub async fn user_summary( - req: HttpRequest, - username: web::Path, - bus: web::Data, -) -> Result { - let user_id = extract_user(&req)?; - let msg = WsInMessage::UserSummary { - username: username.into_inner(), - }; - let result = WsHandler::handle(&bus, user_id, msg) - .await - .map_err(channel_err)?; - Ok(ok_json(result)) -} diff --git a/lib/api/src/channel/rest_member.rs b/lib/api/src/channel/rest_member.rs index 4a224f0..c6f63b1 100644 --- a/lib/api/src/channel/rest_member.rs +++ b/lib/api/src/channel/rest_member.rs @@ -360,7 +360,10 @@ pub async fn list_workspace_members( let _user_id = extract_user(&req)?; let workspace = workspace_id.into_inner(); - let members = bus.list_workspace_members(workspace).await.map_err(channel_err)?; + let members = bus + .list_workspace_members(workspace) + .await + .map_err(channel_err)?; let result: Vec = members .into_iter() .map(|(id, username, display_name, avatar_url)| RoomMember { diff --git a/lib/api/src/channel/rest_room.rs b/lib/api/src/channel/rest_room.rs index 2c969fd..09a7a34 100644 --- a/lib/api/src/channel/rest_room.rs +++ b/lib/api/src/channel/rest_room.rs @@ -13,6 +13,7 @@ pub struct RoomCreateRequest { pub room_name: String, pub public: bool, pub category: Option, + pub ai_enabled: Option, } #[derive(Debug, Deserialize, utoipa::ToSchema)] @@ -20,6 +21,7 @@ pub struct RoomUpdateRequest { pub room_name: Option, pub public: Option, pub category: Option, + pub ai_enabled: Option, } #[derive(Debug, Deserialize, utoipa::ToSchema)] @@ -49,10 +51,9 @@ pub async fn list_rooms( bus: web::Data, ) -> Result { let user_id = extract_user(&req)?; - let rooms = bus.list_user_rooms(user_id) - .await - .map_err(channel_err)?; - let categories = bus.list_user_categories(user_id) + let rooms = bus.list_user_rooms(user_id).await.map_err(channel_err)?; + let categories = bus + .list_user_categories(user_id) .await .map_err(channel_err)?; let workspace_id = if let Some(r) = rooms.first() { @@ -148,6 +149,7 @@ pub async fn room_create( room_name: body.room_name.clone(), public: body.public, category: body.category, + ai_enabled: body.ai_enabled, }; let result = WsHandler::handle(&bus, user_id, msg) .await @@ -174,6 +176,7 @@ pub async fn room_update( room_name: body.room_name.clone(), public: body.public, category: body.category, + ai_enabled: body.ai_enabled, }; let result = WsHandler::handle(&bus, user_id, msg) .await diff --git a/lib/api/src/channel/rest_user.rs b/lib/api/src/channel/rest_user.rs new file mode 100644 index 0000000..8517f81 --- /dev/null +++ b/lib/api/src/channel/rest_user.rs @@ -0,0 +1,27 @@ +use actix_web::{HttpRequest, HttpResponse, web}; +use channel::ChannelBus; +use channel::http::{WsHandler, WsInMessage}; + +use super::rest::{channel_err, extract_user, ok_json}; +use crate::error::ApiError; + +#[utoipa::path( + get, + path = "/api/v1/ws/users/summary/{username}", + responses((status = 200, description = "User summary")), + tag = "channel", +)] +pub async fn user_summary( + req: HttpRequest, + username: web::Path, + bus: web::Data, +) -> Result { + let _user_id = extract_user(&req)?; + let msg = WsInMessage::UserSummary { + username: username.into_inner(), + }; + let result = WsHandler::handle(&bus, _user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} diff --git a/lib/api/src/git/archive.rs b/lib/api/src/git/archive.rs index dd4427f..6861e31 100644 --- a/lib/api/src/git/archive.rs +++ b/lib/api/src/git/archive.rs @@ -41,7 +41,11 @@ pub async fn archive( ) -> Result { let WkRepoPath { wk, repo } = path.into_inner(); match query.format.as_str() { - "zip" => ok_json(service.git_archive_zip(&session, &wk, &repo, None).await?), - _ => ok_json(service.git_archive_tar(&session, &wk, &repo, None).await?), + "zip" => { + ok_json(service.git_archive_zip(&session, &wk, &repo, None).await?) + } + _ => { + ok_json(service.git_archive_tar(&session, &wk, &repo, None).await?) + } } } diff --git a/lib/api/src/git/blame.rs b/lib/api/src/git/blame.rs index 46d6bb1..f31120d 100644 --- a/lib/api/src/git/blame.rs +++ b/lib/api/src/git/blame.rs @@ -41,8 +41,13 @@ pub async fn blame_file( (Some(start), Some(end)) => { let data: dto::BlameFileResponseDto = service .git_blame_hunk( - &session, &wk, &repo, query.path.clone(), - query.rev.clone(), start, end, + &session, + &wk, + &repo, + query.path.clone(), + query.rev.clone(), + start, + end, ) .await? .into(); @@ -51,8 +56,12 @@ pub async fn blame_file( _ => { let data: dto::BlameFileResponseDto = service .git_blame_file( - &session, &wk, &repo, query.path.clone(), - query.rev.clone(), None, + &session, + &wk, + &repo, + query.path.clone(), + query.rev.clone(), + None, ) .await? .into(); diff --git a/lib/api/src/git/branch.rs b/lib/api/src/git/branch.rs index 4d8063d..49d6ca8 100644 --- a/lib/api/src/git/branch.rs +++ b/lib/api/src/git/branch.rs @@ -65,10 +65,8 @@ pub async fn list_branches( return ok_json(data); } if query.default_only { - let data: dto::BranchHeadResponseDto = service - .git_branch_head(&session, &wk, &repo) - .await? - .into(); + let data: dto::BranchHeadResponseDto = + service.git_branch_head(&session, &wk, &repo).await?.into(); return ok_json(data); } let data: dto::BranchListResponseDto = service @@ -180,7 +178,13 @@ pub async fn ahead_behind( ) -> Result { let WkRepoBranchPath { wk, repo, name } = path.into_inner(); let data: dto::BranchAheadBehindResponseDto = service - .git_branch_ahead_behind(&session, &wk, &repo, name, query.remote_branch.clone()) + .git_branch_ahead_behind( + &session, + &wk, + &repo, + name, + query.remote_branch.clone(), + ) .await? .into(); ok_json(data) diff --git a/lib/api/src/git/commit.rs b/lib/api/src/git/commit.rs index 099fbaa..f7b8bb2 100644 --- a/lib/api/src/git/commit.rs +++ b/lib/api/src/git/commit.rs @@ -64,10 +64,8 @@ pub async fn list_commits( return ok_json(data); } if query.refs { - let data: dto::CommitRefsResponseDto = service - .git_commit_refs(&session, &wk, &repo) - .await? - .into(); + let data: dto::CommitRefsResponseDto = + service.git_commit_refs(&session, &wk, &repo).await?.into(); return ok_json(data); } if query.summary { @@ -98,7 +96,9 @@ pub async fn commit_history( let WkRepoPath { wk, repo } = path.into_inner(); let data: dto::CommitHistoryResponseDto = service .git_commit_history( - &session, &wk, &repo, + &session, + &wk, + &repo, query.limit.unwrap_or(20), query.skip.unwrap_or(0), query.sort.unwrap_or(0), diff --git a/lib/api/src/git/commit_status.rs b/lib/api/src/git/commit_status.rs index cdf3c55..0787a4d 100644 --- a/lib/api/src/git/commit_status.rs +++ b/lib/api/src/git/commit_status.rs @@ -40,9 +40,13 @@ pub async fn list_statuses( service: web::Data, path: web::Path, ) -> Result { - ok_json(service.git_commit_status_list_by_name( - &session, &path.wk, &path.repo, &path.sha, - ).await?) + ok_json( + service + .git_commit_status_list_by_name( + &session, &path.wk, &path.repo, &path.sha, + ) + .await?, + ) } #[utoipa::path( @@ -56,9 +60,13 @@ pub async fn combined_status( service: web::Data, path: web::Path, ) -> Result { - ok_json(service.git_commit_status_combined_by_name( - &session, &path.wk, &path.repo, &path.sha, - ).await?) + ok_json( + service + .git_commit_status_combined_by_name( + &session, &path.wk, &path.repo, &path.sha, + ) + .await?, + ) } #[utoipa::path( @@ -74,8 +82,19 @@ pub async fn create_status( path: web::Path, body: web::Json, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; - ok_created(service.git_commit_status_create_by_name( - &session, user_id, &path.wk, &path.repo, &path.sha, body.into_inner(), - ).await?) + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_created( + service + .git_commit_status_create_by_name( + &session, + user_id, + &path.wk, + &path.repo, + &path.sha, + body.into_inner(), + ) + .await?, + ) } diff --git a/lib/api/src/git/compare.rs b/lib/api/src/git/compare.rs index 9a3af8c..7113dc9 100644 --- a/lib/api/src/git/compare.rs +++ b/lib/api/src/git/compare.rs @@ -33,11 +33,15 @@ pub async fn compare( ) -> Result { let (wk, repo_name, basehead) = path.into_inner(); - let (base, head) = basehead - .split_once("...") - .ok_or_else(|| ApiError(service::error::AppError::BadRequest( + let (base, head) = basehead.split_once("...").ok_or_else(|| { + ApiError(service::error::AppError::BadRequest( "basehead must be in format 'base...head'".to_string(), - )))?; + )) + })?; - ok_json(service.git_compare(&session, &wk, &repo_name, base, head).await?) + ok_json( + service + .git_compare(&session, &wk, &repo_name, base, head) + .await?, + ) } diff --git a/lib/api/src/git/contents.rs b/lib/api/src/git/contents.rs index abb0145..98682bb 100644 --- a/lib/api/src/git/contents.rs +++ b/lib/api/src/git/contents.rs @@ -34,9 +34,17 @@ pub async fn get_contents( query: web::Query, ) -> Result { let (wk, repo_name, file_path) = info.into_inner(); - ok_json(service.git_contents_get_by_name( - &session, &wk, &repo_name, &file_path, query.r#ref.as_deref(), - ).await?) + ok_json( + service + .git_contents_get_by_name( + &session, + &wk, + &repo_name, + &file_path, + query.r#ref.as_deref(), + ) + .await?, + ) } #[utoipa::path( @@ -53,9 +61,15 @@ pub async fn create_contents( body: web::Json, ) -> Result { let (wk, repo_name, file_path) = info.into_inner(); - let resp = service.git_contents_create_by_name( - &session, &wk, &repo_name, &file_path, body.into_inner(), - ).await?; + let resp = service + .git_contents_create_by_name( + &session, + &wk, + &repo_name, + &file_path, + body.into_inner(), + ) + .await?; Ok(HttpResponse::Created().json(resp)) } @@ -73,9 +87,17 @@ pub async fn update_contents( body: web::Json, ) -> Result { let (wk, repo_name, file_path) = info.into_inner(); - ok_json(service.git_contents_update_by_name( - &session, &wk, &repo_name, &file_path, body.into_inner(), - ).await?) + ok_json( + service + .git_contents_update_by_name( + &session, + &wk, + &repo_name, + &file_path, + body.into_inner(), + ) + .await?, + ) } #[derive(Deserialize, utoipa::IntoParams)] @@ -98,8 +120,16 @@ pub async fn delete_contents( query: web::Query, ) -> Result { let (wk, repo_name, file_path) = info.into_inner(); - service.git_contents_delete_by_name( - &session, &wk, &repo_name, &file_path, &query.message, &query.sha, query.branch.as_deref(), - ).await?; + service + .git_contents_delete_by_name( + &session, + &wk, + &repo_name, + &file_path, + &query.message, + &query.sha, + query.branch.as_deref(), + ) + .await?; Ok(HttpResponse::NoContent().finish()) } diff --git a/lib/api/src/git/diff.rs b/lib/api/src/git/diff.rs index 0a8d60f..fee113d 100644 --- a/lib/api/src/git/diff.rs +++ b/lib/api/src/git/diff.rs @@ -44,18 +44,34 @@ pub async fn diff( query: web::Query, ) -> Result { let WkRepoPath { wk, repo } = path.into_inner(); - if let (Some(old_tree), Some(new_tree)) = (&query.old_tree, &query.new_tree) { + if let (Some(old_tree), Some(new_tree)) = (&query.old_tree, &query.new_tree) + { let proto_resp = service - .git_diff_tree_to_tree(&session, &wk, &repo, old_tree.clone(), new_tree.clone(), None) + .git_diff_tree_to_tree( + &session, + &wk, + &repo, + old_tree.clone(), + new_tree.clone(), + None, + ) .await?; - let data: dto::DiffResultDto = proto_resp.result.unwrap_or_default().into(); + let data: dto::DiffResultDto = + proto_resp.result.unwrap_or_default().into(); return ok_json(data); } if let Some(tree_oid) = &query.tree_oid { let proto_resp = service - .git_diff_index_to_tree(&session, &wk, &repo, tree_oid.clone(), None) + .git_diff_index_to_tree( + &session, + &wk, + &repo, + tree_oid.clone(), + None, + ) .await?; - let data: dto::DiffResultDto = proto_resp.result.unwrap_or_default().into(); + let data: dto::DiffResultDto = + proto_resp.result.unwrap_or_default().into(); return ok_json(data); } let old_oid = query.old_oid.clone().unwrap_or_default(); @@ -66,21 +82,29 @@ pub async fn diff( let proto_resp = service .git_diff_stats(&session, &wk, &repo, old_oid, new_oid, None) .await?; - let data: dto::DiffStatsDto = proto_resp.result.and_then(|r| r.stats).unwrap_or_default().into(); + let data: dto::DiffStatsDto = proto_resp + .result + .and_then(|r| r.stats) + .unwrap_or_default() + .into(); ok_json(data) } "side-by-side" => { let proto_resp = service - .git_diff_patch_side_by_side(&session, &wk, &repo, old_oid, new_oid, None) + .git_diff_patch_side_by_side( + &session, &wk, &repo, old_oid, new_oid, None, + ) .await?; - let data: dto::SideBySideDiffResultDto = proto_resp.result.unwrap_or_default().into(); + let data: dto::SideBySideDiffResultDto = + proto_resp.result.unwrap_or_default().into(); ok_json(data) } _ => { let proto_resp = service .git_diff_patch(&session, &wk, &repo, old_oid, new_oid, None) .await?; - let data: dto::DiffResultDto = proto_resp.result.unwrap_or_default().into(); + let data: dto::DiffResultDto = + proto_resp.result.unwrap_or_default().into(); ok_json(data) } } diff --git a/lib/api/src/git/dto.rs b/lib/api/src/git/dto.rs index f12aa73..692977a 100644 --- a/lib/api/src/git/dto.rs +++ b/lib/api/src/git/dto.rs @@ -1,7 +1,7 @@ use base64::Engine; +use git::rpc::proto as p; use serde::{Deserialize, Serialize}; use utoipa::ToSchema; -use git::rpc::proto as p; fn oid_val(oid: Option) -> String { oid.map(|o| o.value).unwrap_or_default() @@ -430,7 +430,9 @@ impl From for BranchSummaryResponseDto { impl From for BranchHeadResponseDto { fn from(r: p::BranchHeadResponse) -> Self { - BranchHeadResponseDto { head_name: r.head_name } + BranchHeadResponseDto { + head_name: r.head_name, + } } } @@ -531,7 +533,9 @@ impl From for CommitRefsResponseDto { impl From for CommitPrefixResponseDto { fn from(r: p::CommitPrefixResponse) -> Self { - CommitPrefixResponseDto { oid: oid_opt(r.oid) } + CommitPrefixResponseDto { + oid: oid_opt(r.oid), + } } } @@ -543,13 +547,17 @@ impl From for CommitExistsResponseDto { impl From for CherryPickResponseDto { fn from(r: p::CherryPickResponse) -> Self { - CherryPickResponseDto { oid: oid_opt(r.oid) } + CherryPickResponseDto { + oid: oid_opt(r.oid), + } } } impl From for CherryPickResponseDto { fn from(r: p::CherryPickSequenceResponse) -> Self { - CherryPickResponseDto { oid: oid_opt(r.oid) } + CherryPickResponseDto { + oid: oid_opt(r.oid), + } } } @@ -646,7 +654,9 @@ impl From for BlobExistsResponseDto { impl From for BlobIsBinaryResponseDto { fn from(r: p::BlobIsBinaryResponse) -> Self { - BlobIsBinaryResponseDto { is_binary: r.is_binary } + BlobIsBinaryResponseDto { + is_binary: r.is_binary, + } } } @@ -680,31 +690,41 @@ impl From for TagListResponseDto { impl From for TagInfoResponseDto { fn from(r: p::TagInfoResponse) -> Self { - TagInfoResponseDto { tag: r.tag.map(Into::into) } + TagInfoResponseDto { + tag: r.tag.map(Into::into), + } } } impl From for TagSummaryDto { fn from(s: p::TagSummary) -> Self { - TagSummaryDto { total_count: s.total_count } + TagSummaryDto { + total_count: s.total_count, + } } } impl From for TagSummaryResponseDto { fn from(r: p::TagSummaryResponse) -> Self { - TagSummaryResponseDto { summary: r.summary.map(Into::into) } + TagSummaryResponseDto { + summary: r.summary.map(Into::into), + } } } impl From for TagInitResponseDto { fn from(r: p::TagInitResponse) -> Self { - TagInitResponseDto { oid: oid_opt(r.oid) } + TagInitResponseDto { + oid: oid_opt(r.oid), + } } } impl From for TagUpdateMessageResponseDto { fn from(r: p::TagUpdateMessageResponse) -> Self { - TagUpdateMessageResponseDto { oid: oid_opt(r.oid) } + TagUpdateMessageResponseDto { + oid: oid_opt(r.oid), + } } } @@ -843,10 +863,16 @@ impl From for DiffResultDto { impl From for SideBySideChangeTypeDto { fn from(t: p::SideBySideChangeType) -> Self { match t { - p::SideBySideChangeType::Unchanged => SideBySideChangeTypeDto::Unchanged, + p::SideBySideChangeType::Unchanged => { + SideBySideChangeTypeDto::Unchanged + } p::SideBySideChangeType::Added => SideBySideChangeTypeDto::Added, - p::SideBySideChangeType::Removed => SideBySideChangeTypeDto::Removed, - p::SideBySideChangeType::Modified => SideBySideChangeTypeDto::Modified, + p::SideBySideChangeType::Removed => { + SideBySideChangeTypeDto::Removed + } + p::SideBySideChangeType::Modified => { + SideBySideChangeTypeDto::Modified + } p::SideBySideChangeType::Empty => SideBySideChangeTypeDto::Empty, } } diff --git a/lib/api/src/git/init.rs b/lib/api/src/git/init.rs index eb13bfe..0347b97 100644 --- a/lib/api/src/git/init.rs +++ b/lib/api/src/git/init.rs @@ -1,6 +1,9 @@ use actix_web::{HttpResponse, web}; use serde::{Deserialize, Serialize}; -use service::{AppService, git::init::{CloneRepo, CreateRepo}}; +use service::{ + AppService, + git::init::{CloneRepo, CreateRepo}, +}; use session::Session; use crate::error::ApiError; diff --git a/lib/api/src/git/mod.rs b/lib/api/src/git/mod.rs index c6703ff..1af45f4 100644 --- a/lib/api/src/git/mod.rs +++ b/lib/api/src/git/mod.rs @@ -31,8 +31,7 @@ pub fn configure(cfg: &mut ServiceConfig) { .route(web::get().to(repo::list_repos)), ); cfg.service( - web::resource("/clone") - .route(web::post().to(init::clone_repo)), + web::resource("/clone").route(web::post().to(init::clone_repo)), ); cfg.service( web::resource("/{repo}") @@ -132,8 +131,7 @@ pub fn configure(cfg: &mut ServiceConfig) { .route(web::get().to(blob::blob_info)), ) .service( - web::resource("/blame") - .route(web::get().to(blame::blame_file)), + web::resource("/blame").route(web::get().to(blame::blame_file)), ) .service( web::resource("/trees/{oid}") @@ -147,10 +145,7 @@ pub fn configure(cfg: &mut ServiceConfig) { web::resource("/commits/{oid}/tree") .route(web::get().to(tree::tree_entry_by_path_from_commit)), ) - .service( - web::resource("/diff") - .route(web::get().to(diff::diff)), - ) + .service(web::resource("/diff").route(web::get().to(diff::diff))) .service( web::resource("/diff/branches") .route(web::get().to(readme::diff_branches)), @@ -195,8 +190,7 @@ pub fn configure(cfg: &mut ServiceConfig) { .route(web::get().to(readme::get_readme)), ) .service( - web::resource("/refs") - .route(web::get().to(refs::list_refs)), + web::resource("/refs").route(web::get().to(refs::list_refs)), ), ); cfg.service( diff --git a/lib/api/src/git/refs.rs b/lib/api/src/git/refs.rs index 96f95d8..4bdf1cf 100644 --- a/lib/api/src/git/refs.rs +++ b/lib/api/src/git/refs.rs @@ -34,8 +34,14 @@ pub async fn list_refs( query: web::Query, ) -> Result { if let Some(ref_name) = &query.r#ref { - let r = service.git_ref_get_by_name(&session, &path.wk, &path.repo, ref_name).await?; + let r = service + .git_ref_get_by_name(&session, &path.wk, &path.repo, ref_name) + .await?; return ok_json(vec![r]); } - ok_json(service.git_ref_list_by_name(&session, &path.wk, &path.repo).await?) + ok_json( + service + .git_ref_list_by_name(&session, &path.wk, &path.repo) + .await?, + ) } diff --git a/lib/api/src/git/release.rs b/lib/api/src/git/release.rs index 060788a..9de77ce 100644 --- a/lib/api/src/git/release.rs +++ b/lib/api/src/git/release.rs @@ -43,8 +43,14 @@ pub async fn list_releases( service: web::Data, path: web::Path, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; - ok_json(service.git_release_list_by_name(&session, user_id, &path.wk, &path.repo).await?) + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json( + service + .git_release_list_by_name(&session, user_id, &path.wk, &path.repo) + .await?, + ) } #[utoipa::path( @@ -58,8 +64,16 @@ pub async fn get_release( service: web::Data, path: web::Path, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; - ok_json(service.git_release_get_by_name(&session, user_id, &path.wk, &path.repo, path.id).await?) + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json( + service + .git_release_get_by_name( + &session, user_id, &path.wk, &path.repo, path.id, + ) + .await?, + ) } #[utoipa::path( @@ -74,8 +88,16 @@ pub async fn get_release_by_tag( path: web::Path<(String, String, String)>, ) -> Result { let (wk, repo_name, tag) = path.into_inner(); - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; - ok_json(service.git_release_get_by_tag_name(&session, user_id, &wk, &repo_name, &tag).await?) + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json( + service + .git_release_get_by_tag_name( + &session, user_id, &wk, &repo_name, &tag, + ) + .await?, + ) } #[utoipa::path( @@ -91,8 +113,20 @@ pub async fn create_release( path: web::Path, body: web::Json, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; - ok_created(service.git_release_create_by_name(&session, user_id, &path.wk, &path.repo, body.into_inner()).await?) + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_created( + service + .git_release_create_by_name( + &session, + user_id, + &path.wk, + &path.repo, + body.into_inner(), + ) + .await?, + ) } #[utoipa::path( @@ -108,8 +142,21 @@ pub async fn update_release( path: web::Path, body: web::Json, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; - ok_json(service.git_release_update_by_name(&session, user_id, &path.wk, &path.repo, path.id, body.into_inner()).await?) + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json( + service + .git_release_update_by_name( + &session, + user_id, + &path.wk, + &path.repo, + path.id, + body.into_inner(), + ) + .await?, + ) } #[utoipa::path( @@ -123,8 +170,14 @@ pub async fn delete_release( service: web::Data, path: web::Path, ) -> Result { - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; - service.git_release_delete_by_name(&session, user_id, &path.wk, &path.repo, path.id).await?; + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; + service + .git_release_delete_by_name( + &session, user_id, &path.wk, &path.repo, path.id, + ) + .await?; ok_empty() } @@ -140,7 +193,13 @@ pub async fn delete_release_by_tag( path: web::Path<(String, String, String)>, ) -> Result { let (wk, repo_name, tag) = path.into_inner(); - let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; - service.git_release_delete_by_tag_name(&session, user_id, &wk, &repo_name, &tag).await?; + let user_id = session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized))?; + service + .git_release_delete_by_tag_name( + &session, user_id, &wk, &repo_name, &tag, + ) + .await?; ok_empty() } diff --git a/lib/api/src/git/tag.rs b/lib/api/src/git/tag.rs index 66e96b5..febbd9b 100644 --- a/lib/api/src/git/tag.rs +++ b/lib/api/src/git/tag.rs @@ -54,10 +54,8 @@ pub async fn list_tags( ) -> Result { let WkRepoPath { wk, repo } = path.into_inner(); if query.summary { - let data: dto::TagSummaryResponseDto = service - .git_tag_summary(&session, &wk, &repo) - .await? - .into(); + let data: dto::TagSummaryResponseDto = + service.git_tag_summary(&session, &wk, &repo).await?.into(); return ok_json(data); } let data: dto::TagListResponseDto = service @@ -117,9 +115,7 @@ pub async fn delete_tag( ) -> Result { let WkRepoTagPath { wk, repo, name } = path.into_inner(); let params = git::rpc::proto::TagDeleteParams { name }; - let _ = service - .git_tag_delete(&session, &wk, &repo, params) - .await?; + let _ = service.git_tag_delete(&session, &wk, &repo, params).await?; ok_json(serde_json::json!({})) } #[utoipa::path( @@ -146,9 +142,7 @@ pub async fn update_tag( new_name: new_name.to_string(), force: false, }; - let _ = service - .git_tag_rename(&session, &wk, &repo, params) - .await?; + let _ = service.git_tag_rename(&session, &wk, &repo, params).await?; return ok_json(serde_json::json!({})); } let message = body diff --git a/lib/api/src/git/tree.rs b/lib/api/src/git/tree.rs index ec25ffd..f45aae9 100644 --- a/lib/api/src/git/tree.rs +++ b/lib/api/src/git/tree.rs @@ -86,7 +86,13 @@ pub async fn tree_entry_by_path( ) -> Result { let WkRepoTreeSubPath { wk, repo, tree_oid } = path.into_inner(); let data: dto::TreeEntryByPathResponseDto = service - .git_tree_entry_by_path(&session, &wk, &repo, tree_oid, query.path.clone()) + .git_tree_entry_by_path( + &session, + &wk, + &repo, + tree_oid, + query.path.clone(), + ) .await? .into(); ok_json(data) @@ -112,7 +118,13 @@ pub async fn tree_entry_by_path_from_commit( ) -> Result { let WkRepoCommitPath { wk, repo, oid } = path.into_inner(); let data: dto::TreeEntryByPathResponseDto = service - .git_tree_entry_by_path_from_commit(&session, &wk, &repo, oid, query.path.clone()) + .git_tree_entry_by_path_from_commit( + &session, + &wk, + &repo, + oid, + query.path.clone(), + ) .await? .into(); ok_json(data) diff --git a/lib/api/src/lib.rs b/lib/api/src/lib.rs index df87920..dbd048d 100644 --- a/lib/api/src/lib.rs +++ b/lib/api/src/lib.rs @@ -31,29 +31,28 @@ pub fn configure(cfg: &mut ServiceConfig, channel_bus: channel::ChannelBus) { .service( web::scope("/repos") .configure(git::configure) - .configure(pull_request::configure) + .configure(pull_request::configure), ) .service( web::scope("/issues") - .configure(issues::configure) + .configure(issues::configure), ) .service( web::scope("/labels") - .configure(issues::configure_labels) + .configure(issues::configure_labels), ) .service( web::scope("/milestones") - .configure(issues::configure_milestones) - ) - ) + .configure(issues::configure_milestones), + ), + ), ) .service( web::scope("/ws") .configure(|cfg| channel::configure(cfg, channel_bus)), ) .service( - web::resource("/search") - .route(web::get().to(search::search)), - ) + web::resource("/search").route(web::get().to(search::search)), + ), ); } diff --git a/lib/api/src/openapi.rs b/lib/api/src/openapi.rs index da4109c..93797d3 100644 --- a/lib/api/src/openapi.rs +++ b/lib/api/src/openapi.rs @@ -297,11 +297,7 @@ use utoipa::openapi::security::{ crate::channel::rest_voice::voice_mute, crate::channel::rest_voice::voice_deaf, crate::channel::rest_voice::screen_share, - crate::channel::rest_ai::ai_list, - crate::channel::rest_ai::ai_add, - crate::channel::rest_ai::ai_remove, - crate::channel::rest_ai::ai_stop, - crate::channel::rest_ai::user_summary, + crate::channel::rest_user::user_summary, crate::search::search, ), modifiers(&SecurityAddon) diff --git a/lib/api/src/search.rs b/lib/api/src/search.rs index 2dca0cc..c342d5b 100644 --- a/lib/api/src/search.rs +++ b/lib/api/src/search.rs @@ -2,8 +2,8 @@ use actix_web::{HttpResponse, web}; use serde::Serialize; use utoipa::ToSchema; -use crate::error::ApiError; use crate::channel::ChannelBus; +use crate::error::ApiError; use service::AppService; use session::Session; @@ -177,14 +177,9 @@ async fn search_rooms( user_id: uuid::Uuid, q: &str, ) -> Result, ApiError> { - let rooms = bus - .list_user_rooms(user_id) - .await - .map_err(|e| { - ApiError(service::error::AppError::InternalServerError( - e.to_string(), - )) - })?; + let rooms = bus.list_user_rooms(user_id).await.map_err(|e| { + ApiError(service::error::AppError::InternalServerError(e.to_string())) + })?; let all: Vec = rooms .into_iter() diff --git a/lib/api/src/user/profile.rs b/lib/api/src/user/profile.rs index 87e287e..7f330d4 100644 --- a/lib/api/src/user/profile.rs +++ b/lib/api/src/user/profile.rs @@ -2,7 +2,9 @@ use actix_web::{HttpRequest, HttpResponse, web}; use serde::Serialize; use service::{ AppService, - user::profile::{AvatarUploadResponse, UpdateUserProfileConfig, UserProfileConfig}, + user::profile::{ + AvatarUploadResponse, UpdateUserProfileConfig, UserProfileConfig, + }, }; use session::Session; diff --git a/lib/api/src/workspace/group.rs b/lib/api/src/workspace/group.rs index f7a0d54..82c2991 100644 --- a/lib/api/src/workspace/group.rs +++ b/lib/api/src/workspace/group.rs @@ -106,12 +106,7 @@ pub async fn update_group( ) -> Result { let GroupPath { wk, group_name } = path.into_inner(); let data = service - .workspace_update_group( - &session, - &wk, - &group_name, - params.into_inner(), - ) + .workspace_update_group(&session, &wk, &group_name, params.into_inner()) .await?; ok_json(data) } diff --git a/lib/api/src/workspace/member.rs b/lib/api/src/workspace/member.rs index 2845607..e5bfaa8 100644 --- a/lib/api/src/workspace/member.rs +++ b/lib/api/src/workspace/member.rs @@ -102,12 +102,7 @@ pub async fn update_member( ) -> Result { let MemberPath { wk, username } = path.into_inner(); let data = service - .workspace_update_member( - &session, - &wk, - &username, - params.into_inner(), - ) + .workspace_update_member(&session, &wk, &username, params.into_inner()) .await?; ok_json(data) } diff --git a/lib/api/src/workspace/mod.rs b/lib/api/src/workspace/mod.rs index 0c442eb..d75ffad 100644 --- a/lib/api/src/workspace/mod.rs +++ b/lib/api/src/workspace/mod.rs @@ -6,12 +6,10 @@ pub mod workspace; use actix_web::{web, web::ServiceConfig}; pub fn configure(cfg: &mut ServiceConfig) { cfg.service( - web::resource("") - .route(web::post().to(workspace::create_workspace)), + web::resource("").route(web::post().to(workspace::create_workspace)), ); cfg.service( - web::resource("/my") - .route(web::get().to(workspace::my_workspaces)), + web::resource("/my").route(web::get().to(workspace::my_workspaces)), ); cfg.service( web::resource("/join/my-applies") @@ -64,12 +62,10 @@ pub fn configure_wk(cfg: &mut ServiceConfig) { .route(web::put().to(join::update_join_strategy)), ); cfg.service( - web::resource("/join/apply") - .route(web::post().to(join::apply_join)), + web::resource("/join/apply").route(web::post().to(join::apply_join)), ); cfg.service( - web::resource("/join/cancel") - .route(web::post().to(join::cancel_join)), + web::resource("/join/cancel").route(web::post().to(join::cancel_join)), ); cfg.service( web::resource("/join/applies") diff --git a/lib/cache/local.rs b/lib/cache/local.rs index dd3d6ad..f80cf50 100644 --- a/lib/cache/local.rs +++ b/lib/cache/local.rs @@ -26,7 +26,7 @@ impl Default for LocalCacheConfig { #[derive(Clone)] pub struct MokaCache { - pub(crate) inner: Cache, Arc<[u8]>>, + pub inner: Cache, Arc<[u8]>>, } impl MokaCache { diff --git a/lib/channel/Cargo.toml b/lib/channel/Cargo.toml index f52ccec..b0a156c 100644 --- a/lib/channel/Cargo.toml +++ b/lib/channel/Cargo.toml @@ -35,5 +35,6 @@ tokio = { workspace = true, features = ["sync", "time"] } tokio-util = { workspace = true } tracing = { workspace = true } uuid = { workspace = true, features = ["serde", "v7"] } +lazy_static = "1.5.0" [lints] workspace = true diff --git a/lib/channel/bus.rs b/lib/channel/bus.rs index 4f65536..00f0b9b 100644 --- a/lib/channel/bus.rs +++ b/lib/channel/bus.rs @@ -33,24 +33,31 @@ const ROOM_MESSAGE_EVENT: &str = "room.message"; #[derive(Clone)] pub struct ChannelBus { - pub(crate) inner: Arc, + pub inner: Arc, } -pub(crate) struct Inner { - pub(crate) db: AppDatabase, - pub(crate) cache: AppCache, - pub(crate) io: SocketIo, - pub(crate) config: ChannelBusConfig, - pub(crate) online: RwLock>>, - pub(crate) user_sync_locks: DashMap>>, - pub(crate) typing_states: DashMap<(Uuid, Uuid), (crate::event::UserInfo, crate::event::RoomInfo, tokio_util::sync::CancellationToken)>, - pub(crate) seq: SeqAllocator, - pub(crate) dedup: DeduplicationManager, - pub(crate) metrics: ChannelMetrics, - pub(crate) reconnect: ReconnectManager, - pub(crate) rate_limiter: RateLimiter, - pub(crate) csrf: CsrfProtection, - pub(crate) circuit_breaker: CircuitBreaker, +pub struct Inner { + pub db: AppDatabase, + pub cache: AppCache, + pub io: SocketIo, + pub config: ChannelBusConfig, + pub online: RwLock>>, + pub user_sync_locks: DashMap>>, + pub typing_states: DashMap< + (Uuid, Uuid), + ( + crate::event::UserInfo, + crate::event::RoomInfo, + tokio_util::sync::CancellationToken, + ), + >, + pub seq: SeqAllocator, + pub dedup: DeduplicationManager, + pub metrics: ChannelMetrics, + pub reconnect: ReconnectManager, + pub rate_limiter: RateLimiter, + pub csrf: CsrfProtection, + pub circuit_breaker: CircuitBreaker, } #[derive(Debug, Deserialize)] @@ -79,7 +86,7 @@ impl ChannelBus { &self, room: Uuid, ) -> ChannelResult { - let row = db::sqlx::query_as::<_, (String,)>( + let row = db::sqlx::query_as::<_, (String,)>( "SELECT name FROM room WHERE id = $1", ) .bind(room) @@ -132,7 +139,8 @@ impl ChannelBus { pub async fn lookup_users( &self, users: &[Uuid], - ) -> ChannelResult> { + ) -> ChannelResult> + { if users.is_empty() { return Ok(std::collections::HashMap::new()); } @@ -585,7 +593,9 @@ impl ChannelBus { Err(_) => None, }; let event = match sender { - Some(s) => ChannelEvent::message_created_with_sender(message, s), + Some(s) => { + ChannelEvent::message_created_with_sender(message, s) + } None => ChannelEvent::message_created(message), }; socket.emit(ROOM_MESSAGE_EVENT, event).await?; diff --git a/lib/channel/circuit_breaker.rs b/lib/channel/circuit_breaker.rs index e9dc288..9fab2dd 100644 --- a/lib/channel/circuit_breaker.rs +++ b/lib/channel/circuit_breaker.rs @@ -71,17 +71,15 @@ impl CircuitBreaker { let slot_reserved = { let mut state = self.inner.state.lock().await; match state.status { - STATUS_OPEN => { - match state.last_failure_time { - Some(t) if t.elapsed() > self.inner.config.timeout => { - state.status = STATUS_HALF_OPEN; - state.half_open_calls = 1; - state.success_count = 0; - true - } - _ => false, + STATUS_OPEN => match state.last_failure_time { + Some(t) if t.elapsed() > self.inner.config.timeout => { + state.status = STATUS_HALF_OPEN; + state.half_open_calls = 1; + state.success_count = 0; + true } - } + _ => false, + }, STATUS_HALF_OPEN => { if state.half_open_calls < self.inner.config.half_open_max_calls diff --git a/lib/channel/event/ai.rs b/lib/channel/event/ai.rs deleted file mode 100644 index 97fc4e2..0000000 --- a/lib/channel/event/ai.rs +++ /dev/null @@ -1,44 +0,0 @@ -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; -use uuid::Uuid; - -use crate::event::{AgentInfo, RoomInfo}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AiAgentJoinedService { - pub room: RoomInfo, - pub agent: AgentInfo, - pub joined_at: DateTime, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AiAgentLeftService { - pub room: RoomInfo, - pub agent: AgentInfo, - pub left_at: DateTime, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RoomAiEntry { - pub agent_session: Uuid, - pub name: String, - pub agent_kind: String, - pub model_version: Option, - pub enabled: bool, - pub auto_reply: bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RoomAiListService { - pub room: RoomInfo, - pub agents: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AiAgentStatusChangedService { - pub room: RoomInfo, - pub agent: AgentInfo, - pub old_status: String, - pub new_status: String, - pub changed_at: DateTime, -} diff --git a/lib/channel/event/attachment.rs b/lib/channel/event/attachment.rs index 57f863e..9aea82f 100644 --- a/lib/channel/event/attachment.rs +++ b/lib/channel/event/attachment.rs @@ -1,6 +1,6 @@ +use crate::event::{RoomInfo, UserInfo}; use chrono::{DateTime, Utc}; use uuid::Uuid; -use crate::event::{RoomInfo, UserInfo}; use serde::{Deserialize, Serialize}; diff --git a/lib/channel/event/category.rs b/lib/channel/event/category.rs index c9c4260..0eb2c63 100644 --- a/lib/channel/event/category.rs +++ b/lib/channel/event/category.rs @@ -1,6 +1,6 @@ +use crate::event::{UserInfo, WorkspaceInfo}; use chrono::{DateTime, Utc}; use uuid::Uuid; -use crate::event::{UserInfo, WorkspaceInfo}; use serde::{Deserialize, Serialize}; diff --git a/lib/channel/event/common.rs b/lib/channel/event/common.rs index ab7a7ea..28c7507 100644 --- a/lib/channel/event/common.rs +++ b/lib/channel/event/common.rs @@ -72,21 +72,3 @@ impl WorkspaceInfo { } } } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AgentInfo { - pub id: Uuid, - pub name: String, - pub agent_type: String, - pub model_name: Option, -} - -impl AgentInfo { - pub fn unknown(id: Uuid) -> Self { - Self { - id, - name: String::new(), - agent_type: String::new(), - model_name: None, - } - } -} diff --git a/lib/channel/event/dm.rs b/lib/channel/event/dm.rs deleted file mode 100644 index cbad446..0000000 --- a/lib/channel/event/dm.rs +++ /dev/null @@ -1,66 +0,0 @@ -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; -use uuid::Uuid; - -use crate::event::{RoomInfo, UserInfo}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum DmEventType { - Created, - Closed, - Reopened, -} - -impl DmEventType { - pub fn as_str(&self) -> &str { - match self { - Self::Created => "dm.created", - Self::Closed => "dm.closed", - Self::Reopened => "dm.reopened", - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type")] -pub enum DmEvent { - #[serde(rename = "dm.created")] - Created(DmCreatedService), - #[serde(rename = "dm.closed")] - Closed(DmClosedService), - #[serde(rename = "dm.reopened")] - Reopened(DmReopenedService), -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DmCreatedService { - pub room: RoomInfo, - pub initiator: UserInfo, - pub recipient: UserInfo, - pub created_at: DateTime, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DmClosedService { - pub room: RoomInfo, - pub closed_by: UserInfo, - pub closed_at: DateTime, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DmReopenedService { - pub room: RoomInfo, - pub reopened_by: UserInfo, - pub reopened_at: DateTime, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DmCreateClient { - pub recipient: Uuid, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DmCloseClient { - pub room: Uuid, -} diff --git a/lib/channel/event/mod.rs b/lib/channel/event/mod.rs index 6aea10c..b346cfe 100644 --- a/lib/channel/event/mod.rs +++ b/lib/channel/event/mod.rs @@ -1,10 +1,8 @@ -pub mod ai; pub mod attachment; pub mod ban; pub mod category; pub mod common; pub mod conversation; -pub mod dm; pub mod draft; pub mod forward; pub mod invite; @@ -22,7 +20,7 @@ pub mod thread; pub mod voice; pub mod workspace; -pub use common::{AgentInfo, RoomInfo, UserInfo, WorkspaceInfo}; +pub use common::{RoomInfo, UserInfo, WorkspaceInfo}; use model::room::RoomMessageModel; use serde::{Deserialize, Serialize}; @@ -37,8 +35,6 @@ pub enum ChannelEventType { ReactionCreated, ReactionDeleted, MessageRead, - DmCreated, - DmClosed, ConversationUpdated, Custom(String), } @@ -52,8 +48,6 @@ impl ChannelEventType { Self::ReactionCreated => "reaction.created", Self::ReactionDeleted => "reaction.deleted", Self::MessageRead => "message.read", - Self::DmCreated => "dm.created", - Self::DmClosed => "dm.closed", Self::ConversationUpdated => "conversation.updated", Self::Custom(value) => value, } diff --git a/lib/channel/event/pin.rs b/lib/channel/event/pin.rs index dcf3c0d..da5ee40 100644 --- a/lib/channel/event/pin.rs +++ b/lib/channel/event/pin.rs @@ -1,6 +1,6 @@ use chrono::{DateTime, Utc}; -use uuid::Uuid; use serde::{Deserialize, Serialize}; +use uuid::Uuid; use crate::event::{RoomInfo, UserInfo}; diff --git a/lib/channel/event/reaction.rs b/lib/channel/event/reaction.rs index a618777..2e2cd21 100644 --- a/lib/channel/event/reaction.rs +++ b/lib/channel/event/reaction.rs @@ -1,6 +1,6 @@ use chrono::{DateTime, Utc}; -use uuid::Uuid; use serde::{Deserialize, Serialize}; +use uuid::Uuid; use crate::event::{RoomInfo, UserInfo}; diff --git a/lib/channel/event/rooms.rs b/lib/channel/event/rooms.rs index 34b5cbf..7c9d355 100644 --- a/lib/channel/event/rooms.rs +++ b/lib/channel/event/rooms.rs @@ -13,7 +13,6 @@ pub enum RoomEventType { TopicUpdated, SettingsUpdated, Moved, - AiUpdated, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -27,8 +26,6 @@ pub enum RoomEvent { Renamed(RoomRenamedService), #[serde(rename = "room.moved")] Moved(RoomMovedService), - #[serde(rename = "room.ai_updated")] - AiUpdated(RoomAiUpdatedService), #[serde(rename = "room.topic_updated")] TopicUpdated(RoomTopicUpdatedService), #[serde(rename = "room.settings_updated")] @@ -73,18 +70,6 @@ pub struct RoomMovedService { pub moved_at: DateTime, } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RoomAiUpdatedService { - pub room: RoomInfo, - pub workspace: WorkspaceInfo, - pub model: Uuid, - pub model_name: String, - pub version: i64, - pub agent_type: String, - pub updated_by: UserInfo, - pub updated_at: DateTime, -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RoomTopicUpdatedService { pub room: RoomInfo, @@ -112,6 +97,7 @@ pub struct RoomCreateClient { pub room_name: String, pub public: bool, pub category: Option, + pub ai_enabled: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -123,6 +109,7 @@ pub struct RoomUpdateClient { pub slowmode_seconds: Option, pub nsfw: Option, pub default_auto_archive_duration: Option, + pub ai_enabled: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/lib/channel/event/search.rs b/lib/channel/event/search.rs index 3240117..886fe6d 100644 --- a/lib/channel/event/search.rs +++ b/lib/channel/event/search.rs @@ -1,6 +1,6 @@ use chrono::{DateTime, Utc}; -use uuid::Uuid; use serde::{Deserialize, Serialize}; +use uuid::Uuid; use crate::event::{RoomInfo, UserInfo, message::MessageNewService}; diff --git a/lib/channel/http/handler/ai.rs b/lib/channel/http/handler/ai.rs deleted file mode 100644 index 1e0be2e..0000000 --- a/lib/channel/http/handler/ai.rs +++ /dev/null @@ -1,159 +0,0 @@ -use chrono::Utc; -use uuid::Uuid; - -use crate::event::{AgentInfo, RoomInfo, ai}; -use crate::{ChannelBus, ChannelError, ChannelResult}; - -use super::WsOutEvent; -use super::WsHandler; - -impl WsHandler { - pub(super) async fn ai_list( - bus: &ChannelBus, - user_id: Uuid, - room: Uuid, - ) -> ChannelResult> { - Self::ensure_room_access(bus, user_id, room).await?; - let rows = db::sqlx::query_as::<_, (Uuid, Option, Option, Option, bool, bool)>( - "SELECT ra.agent_session, s.name, s.agent_kind, s.model_version, ra.enabled, ra.auto_reply \ - FROM room_ai ra \ - LEFT JOIN agent_session s ON s.id = ra.agent_session AND s.deleted_at IS NULL \ - WHERE ra.room = $1", - ) - .bind(room) - .fetch_all(bus.inner.db.reader()) - .await?; - - let agents = rows - .into_iter() - .filter_map(|(agent_session, name, agent_kind, model_version, enabled, auto_reply)| { - name.map(|n| ai::RoomAiEntry { - agent_session, - name: n, - agent_kind: agent_kind.unwrap_or_default(), - model_version, - enabled, - auto_reply, - }) - }) - .collect(); - - let ai_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - Ok(Some(WsOutEvent::AiAgentList { - room: ai_room.clone(), - data: ai::RoomAiListService { - room: ai_room, - agents, - }, - })) - } - - pub(super) async fn ai_upsert( - bus: &ChannelBus, - user_id: Uuid, - room: Uuid, - model: Uuid, - ) -> ChannelResult> { - Self::ensure_room_access(bus, user_id, room).await?; - let session = db::sqlx::query_as::<_, model::agent::AgentSessionModel>( - "SELECT id, \"user\", wk, name, description, agent_kind, model_version, \ - system_prompt, temperature, max_output_tokens, enabled, created_by, \ - created_at, updated_at, deleted_at \ - FROM agent_session WHERE id = $1 AND deleted_at IS NULL", - ) - .bind(model) - .fetch_one(bus.inner.db.reader()) - .await - .map_err(|e| match e { - db::sqlx::Error::RowNotFound => ChannelError::RoomNotFound, - other => ChannelError::Database(other), - })?; - db::sqlx::query_as::<_, model::room::RoomAiModel>( - "INSERT INTO room_ai (room, agent_session, enabled, auto_reply, created_by, created_at, updated_at) \ - VALUES ($1, $2, true, false, $3, now(), now()) \ - ON CONFLICT (room, agent_session) DO UPDATE SET enabled = true, updated_at = now() \ - RETURNING room, agent_session, enabled, auto_reply, created_by, created_at, updated_at", - ) - .bind(room) - .bind(model) - .bind(user_id) - .fetch_one(bus.inner.db.writer()) - .await?; - let ai_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let data = ai::AiAgentJoinedService { - room: ai_room, - agent: AgentInfo { - id: model, - name: session.name.clone(), - agent_type: session.agent_kind.clone(), - model_name: None, - }, - joined_at: Utc::now(), - }; - bus.publish_room_event(room, "ai.agent_joined", &data) - .await?; - - Ok(Some(WsOutEvent::AiAgentJoined { room: data.room.clone(), data })) - } - - pub(super) async fn ai_delete( - bus: &ChannelBus, - user_id: Uuid, - room: Uuid, - agent_id: Uuid, - ) -> ChannelResult> { - Self::ensure_room_access(bus, user_id, room).await?; - let session = db::sqlx::query_as::<_, model::agent::AgentSessionModel>( - "SELECT id, \"user\", wk, name, description, agent_kind, model_version, \ - system_prompt, temperature, max_output_tokens, enabled, created_by, \ - created_at, updated_at, deleted_at \ - FROM agent_session WHERE id = $1 AND deleted_at IS NULL", - ) - .bind(agent_id) - .fetch_optional(bus.inner.db.reader()) - .await?; - - let result = db::sqlx::query( - "DELETE FROM room_ai WHERE room = $1 AND agent_session = $2", - ) - .bind(room) - .bind(agent_id) - .execute(bus.inner.db.writer()) - .await?; - - if result.rows_affected() == 0 { - return Err(ChannelError::RoomNotFound); - } - let ai_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let agent_info = session.map(|s| AgentInfo { - id: s.id, - name: s.name, - agent_type: s.agent_kind, - model_name: None, - }).unwrap_or_else(|| AgentInfo::unknown(agent_id)); - - let data = ai::AiAgentLeftService { - room: ai_room, - agent: agent_info, - left_at: Utc::now(), - }; - bus.publish_room_event(room, "ai.agent_left", &data).await?; - - Ok(Some(WsOutEvent::AiAgentLeft { room: data.room.clone(), data })) - } - - pub(super) async fn ai_stop( - bus: &ChannelBus, - user_id: Uuid, - room: Uuid, - ) -> ChannelResult> { - Self::ensure_room_access(bus, user_id, room).await?; - bus.publish_room_event( - room, - "ai.stop", - &serde_json::json!({"stopped_by": user_id}), - ) - .await?; - Ok(None) - } -} diff --git a/lib/channel/http/handler/ban.rs b/lib/channel/http/handler/ban.rs index 975a7a7..c6a7aee 100644 --- a/lib/channel/http/handler/ban.rs +++ b/lib/channel/http/handler/ban.rs @@ -4,8 +4,8 @@ use uuid::Uuid; use crate::event::{UserInfo, WorkspaceInfo, ban}; use crate::{ChannelBus, ChannelResult}; -use super::WsOutEvent; use super::WsHandler; +use super::WsOutEvent; impl WsHandler { pub(super) async fn ban_create( @@ -36,9 +36,18 @@ impl WsHandler { }); bus.inner.cache.set(&ban_key, &ban_data).await?; let data = ban::BannedService { - workspace: bus.lookup_workspace(workspace).await.unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)), - user: bus.lookup_user(user).await.unwrap_or_else(|_| UserInfo::unknown(user)), - banned_by: bus.lookup_user(_user_id).await.unwrap_or_else(|_| UserInfo::unknown(_user_id)), + workspace: bus + .lookup_workspace(workspace) + .await + .unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)), + user: bus + .lookup_user(user) + .await + .unwrap_or_else(|_| UserInfo::unknown(user)), + banned_by: bus + .lookup_user(_user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(_user_id)), reason, expires_at: _expires_at, banned_at: Utc::now(), @@ -64,9 +73,18 @@ impl WsHandler { let ban_key = format!("ban:{}:{}:{}", workspace, _user_id, user); bus.inner.cache.remove(&ban_key).await?; let data = ban::UnbannedService { - workspace: bus.lookup_workspace(workspace).await.unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)), - user: bus.lookup_user(user).await.unwrap_or_else(|_| UserInfo::unknown(user)), - unbanned_by: bus.lookup_user(_user_id).await.unwrap_or_else(|_| UserInfo::unknown(_user_id)), + workspace: bus + .lookup_workspace(workspace) + .await + .unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)), + user: bus + .lookup_user(user) + .await + .unwrap_or_else(|_| UserInfo::unknown(user)), + unbanned_by: bus + .lookup_user(_user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(_user_id)), unbanned_at: Utc::now(), }; bus.workspace_changed(workspace).await?; diff --git a/lib/channel/http/handler/category.rs b/lib/channel/http/handler/category.rs index 017609f..5a3fa8a 100644 --- a/lib/channel/http/handler/category.rs +++ b/lib/channel/http/handler/category.rs @@ -5,8 +5,8 @@ use crate::event::{UserInfo, WorkspaceInfo, category}; use crate::{ChannelBus, ChannelError, ChannelResult}; use super::MAX_CATEGORY_NAME_LEN; -use super::WsOutEvent; use super::WsHandler; +use super::WsOutEvent; impl WsHandler { pub(super) async fn category_create( @@ -17,7 +17,9 @@ impl WsHandler { position: Option, ) -> ChannelResult> { if name.is_empty() || name.len() > MAX_CATEGORY_NAME_LEN { - return Err(ChannelError::Validation("invalid category name".into())); + return Err(ChannelError::Validation( + "invalid category name".into(), + )); } Self::ensure_workspace_member(bus, user_id, workspace).await?; let row = db::sqlx::query_as::<_, model::room::RoomCategoryModel>( @@ -95,7 +97,10 @@ impl WsHandler { updated_at: Utc::now(), }; bus.workspace_changed(old.wk).await?; - Ok(Some(WsOutEvent::CategoryUpdated { workspace: data.project.clone(), data })) + Ok(Some(WsOutEvent::CategoryUpdated { + workspace: data.project.clone(), + data, + })) } pub(super) async fn category_delete( @@ -133,6 +138,9 @@ impl WsHandler { deleted_at: Utc::now(), }; bus.workspace_changed(row.wk).await?; - Ok(Some(WsOutEvent::CategoryDeleted { workspace: cd_workspace, data })) + Ok(Some(WsOutEvent::CategoryDeleted { + workspace: cd_workspace, + data, + })) } } diff --git a/lib/channel/http/handler/conversation.rs b/lib/channel/http/handler/conversation.rs index 47ce248..1f281be 100644 --- a/lib/channel/http/handler/conversation.rs +++ b/lib/channel/http/handler/conversation.rs @@ -4,8 +4,8 @@ use uuid::Uuid; use crate::event::{RoomInfo, UserInfo, conversation}; use crate::{ChannelBus, ChannelResult}; -use super::WsOutEvent; use super::WsHandler; +use super::WsOutEvent; impl WsHandler { pub(super) async fn conversation_pin( @@ -29,10 +29,14 @@ impl WsHandler { .execute(bus.inner.db.writer()) .await?; - let room_info = - bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let user_info = - bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let room_info = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); + let user_info = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); if pin { let data = conversation::ConversationPinnedService { @@ -40,7 +44,8 @@ impl WsHandler { room: room_info.clone(), pinned_at: now, }; - bus.emit_to_user(user_id, "conversation.pinned", &data).await?; + bus.emit_to_user(user_id, "conversation.pinned", &data) + .await?; Ok(Some(WsOutEvent::ConversationPinned { room: room_info, data, @@ -51,7 +56,8 @@ impl WsHandler { room: room_info.clone(), unpinned_at: now, }; - bus.emit_to_user(user_id, "conversation.unpinned", &data).await?; + bus.emit_to_user(user_id, "conversation.unpinned", &data) + .await?; Ok(Some(WsOutEvent::ConversationUnpinned { room: room_info, data, @@ -80,10 +86,14 @@ impl WsHandler { .execute(bus.inner.db.writer()) .await?; - let room_info = - bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let user_info = - bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let room_info = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); + let user_info = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); if mute { let data = conversation::ConversationMutedService { @@ -91,7 +101,8 @@ impl WsHandler { room: room_info.clone(), muted_at: now, }; - bus.emit_to_user(user_id, "conversation.muted", &data).await?; + bus.emit_to_user(user_id, "conversation.muted", &data) + .await?; Ok(Some(WsOutEvent::ConversationMuted { room: room_info, data, @@ -102,7 +113,8 @@ impl WsHandler { room: room_info.clone(), unmuted_at: now, }; - bus.emit_to_user(user_id, "conversation.unmuted", &data).await?; + bus.emit_to_user(user_id, "conversation.unmuted", &data) + .await?; Ok(Some(WsOutEvent::ConversationUnmuted { room: room_info, data, @@ -116,7 +128,8 @@ impl WsHandler { notify_level: String, ) -> ChannelResult> { Self::ensure_room_access(bus, user_id, room).await?; - let valid = matches!(notify_level.as_str(), "all" | "mentions" | "none"); + let valid = + matches!(notify_level.as_str(), "all" | "mentions" | "none"); if !valid { return Err(crate::ChannelError::Internal( "notify_level must be 'all', 'mentions', or 'none'".to_string(), @@ -147,10 +160,14 @@ impl WsHandler { .execute(bus.inner.db.writer()) .await?; - let room_info = - bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let user_info = - bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let room_info = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); + let user_info = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = conversation::ConversationNotifyLevelChangedService { user: user_info, @@ -208,7 +225,16 @@ impl WsHandler { let summaries: Vec = rows .into_iter() .map( - |(id, name, room_type, is_pinned, is_muted, notify_level, last_read_seq, max_seq)| { + |( + id, + name, + room_type, + is_pinned, + is_muted, + notify_level, + last_read_seq, + max_seq, + )| { let unread = (max_seq - last_read_seq).max(0); conversation::ConversationSummary { room: id, diff --git a/lib/channel/http/handler/dm.rs b/lib/channel/http/handler/dm.rs deleted file mode 100644 index 680a5ee..0000000 --- a/lib/channel/http/handler/dm.rs +++ /dev/null @@ -1,248 +0,0 @@ -use chrono::Utc; -use uuid::Uuid; - -use crate::event::{RoomInfo, UserInfo, dm}; -use crate::{ChannelBus, ChannelResult}; - -use super::WsOutEvent; -use super::WsHandler; - -impl WsHandler { - pub(super) async fn dm_create( - bus: &ChannelBus, - user_id: Uuid, - recipient: Uuid, - ) -> ChannelResult> { - if user_id == recipient { - return Err(crate::ChannelError::Internal( - "cannot create DM with yourself".to_string(), - )); - } - let recipient_exists: Option<(Uuid,)> = db::sqlx::query_as( - "SELECT id FROM \"user\" WHERE id = $1", - ) - .bind(recipient) - .fetch_optional(bus.inner.db.reader()) - .await?; - if recipient_exists.is_none() { - return Err(crate::ChannelError::UserNotFound); - } - let (initiator, other) = if user_id < recipient { - (user_id, recipient) - } else { - (recipient, user_id) - }; - let existing: Option<(Uuid, Uuid, bool)> = db::sqlx::query_as( - "SELECT room, initiator, is_closed FROM dm_conversation \ - WHERE initiator = $1 AND recipient = $2", - ) - .bind(initiator) - .bind(other) - .fetch_optional(bus.inner.db.reader()) - .await?; - - let now = Utc::now(); - - let (room_id, is_reopen) = if let Some((room, _, is_closed)) = existing { - if is_closed { - db::sqlx::query( - "UPDATE dm_conversation SET is_closed = false, closed_at = NULL, \ - updated_at = now() WHERE initiator = $1 AND recipient = $2", - ) - .bind(initiator) - .bind(other) - .execute(bus.inner.db.writer()) - .await?; - db::sqlx::query( - "UPDATE room SET is_archived = false, updated_at = now() WHERE id = $1", - ) - .bind(room) - .execute(bus.inner.db.writer()) - .await?; - - (room, true) - } else { - (room, false) - } - } else { - let shared_wk: Option<(Uuid,)> = db::sqlx::query_as( - "SELECT wm1.wk FROM wk_member wm1 \ - INNER JOIN wk_member wm2 ON wm2.wk = wm1.wk \ - WHERE wm1.\"user\" = $1 AND wm1.leave_at IS NULL \ - AND wm2.\"user\" = $2 AND wm2.leave_at IS NULL \ - LIMIT 1", - ) - .bind(user_id) - .bind(recipient) - .fetch_optional(bus.inner.db.reader()) - .await?; - - let wk = shared_wk.map(|r| r.0).unwrap_or_else(|| { - Uuid::nil() - }); - let room_id = Uuid::new_v4(); - db::sqlx::query( - "INSERT INTO room (id, wk, name, topic, room_type, position, is_private, \ - created_by, created_at, updated_at) \ - VALUES ($1, $2, $3, NULL, 'DM', 0, true, $4, now(), now())", - ) - .bind(room_id) - .bind(wk) - .bind(format!("dm-{}", &room_id.to_string()[..8])) - .bind(user_id) - .execute(bus.inner.db.writer()) - .await?; - db::sqlx::query( - "INSERT INTO dm_conversation (room, initiator, recipient, created_at, updated_at) \ - VALUES ($1, $2, $3, now(), now()) \ - ON CONFLICT (initiator, recipient) DO NOTHING", - ) - .bind(room_id) - .bind(initiator) - .bind(other) - .execute(bus.inner.db.writer()) - .await?; - for uid in &[user_id, recipient] { - db::sqlx::query( - "INSERT INTO room_permission_overwrite \ - (room, target_type, target_id, allow_mask, deny_mask, created_at) \ - VALUES ($1, 'user', $2, 0, 0, now()) \ - ON CONFLICT DO NOTHING", - ) - .bind(room_id) - .bind(uid) - .execute(bus.inner.db.writer()) - .await?; - } - - (room_id, false) - }; - let _ = crate::rooms::refresh_user_rooms_cache( - &bus.inner.db, - &bus.inner.cache, - &bus.inner.config, - user_id, - ) - .await; - let _ = crate::rooms::refresh_user_rooms_cache( - &bus.inner.db, - &bus.inner.cache, - &bus.inner.config, - recipient, - ) - .await; - - let room_info = - bus.lookup_room(room_id).await.unwrap_or_else(|_| RoomInfo::unknown(room_id)); - let initiator_info = - bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); - let recipient_info = - bus.lookup_user(recipient).await.unwrap_or_else(|_| UserInfo::unknown(recipient)); - - if is_reopen { - let data = dm::DmReopenedService { - room: room_info.clone(), - reopened_by: initiator_info, - reopened_at: now, - }; - bus.emit_to_user(user_id, "dm.reopened", &data).await?; - bus.emit_to_user(recipient, "dm.reopened", &data).await?; - Ok(Some(WsOutEvent::DmReopened { - room: room_info, - data, - })) - } else { - let data = dm::DmCreatedService { - room: room_info.clone(), - initiator: initiator_info, - recipient: recipient_info, - created_at: now, - }; - bus.emit_to_user(user_id, "dm.created", &data).await?; - bus.emit_to_user(recipient, "dm.created", &data).await?; - Ok(Some(WsOutEvent::DmCreated { - room: room_info, - data, - })) - } - } - pub(super) async fn dm_close( - bus: &ChannelBus, - user_id: Uuid, - room: Uuid, - ) -> ChannelResult> { - let now = Utc::now(); - - let result = db::sqlx::query( - "UPDATE dm_conversation SET is_closed = true, closed_at = $1, updated_at = $1 \ - WHERE room = $2 AND (initiator = $3 OR recipient = $3) AND is_closed = false", - ) - .bind(now) - .bind(room) - .bind(user_id) - .execute(bus.inner.db.writer()) - .await?; - - if result.rows_affected() == 0 { - return Ok(None); - } - - let room_info = - bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let closed_by = - bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); - - let data = dm::DmClosedService { - room: room_info.clone(), - closed_by, - closed_at: now, - }; - bus.publish_room_event(room, "dm.closed", &data).await?; - Ok(Some(WsOutEvent::DmClosed { - room: room_info, - data, - })) - } - pub(super) async fn dm_list( - bus: &ChannelBus, - user_id: Uuid, - ) -> ChannelResult> { - let rows = db::sqlx::query_as::<_, (Uuid, Uuid, Uuid, chrono::DateTime)>( - "SELECT dc.room, dc.initiator, dc.recipient, dc.created_at \ - FROM dm_conversation dc \ - INNER JOIN room r ON r.id = dc.room \ - WHERE (dc.initiator = $1 OR dc.recipient = $1) \ - AND dc.is_closed = false \ - AND r.deleted_at IS NULL \ - ORDER BY dc.updated_at DESC", - ) - .bind(user_id) - .fetch_all(bus.inner.db.reader()) - .await?; - - let mut results = Vec::with_capacity(rows.len()); - for (room_id, initiator_id, recipient_id, created_at) in rows { - let room_info = bus - .lookup_room(room_id) - .await - .unwrap_or_else(|_| RoomInfo::unknown(room_id)); - let initiator_info = bus - .lookup_user(initiator_id) - .await - .unwrap_or_else(|_| UserInfo::unknown(initiator_id)); - let recipient_info = bus - .lookup_user(recipient_id) - .await - .unwrap_or_else(|_| UserInfo::unknown(recipient_id)); - - results.push(dm::DmCreatedService { - room: room_info, - initiator: initiator_info, - recipient: recipient_info, - created_at, - }); - } - - Ok(Some(WsOutEvent::DmList { data: results })) - } -} diff --git a/lib/channel/http/handler/draft.rs b/lib/channel/http/handler/draft.rs index 74e432b..0d65e77 100644 --- a/lib/channel/http/handler/draft.rs +++ b/lib/channel/http/handler/draft.rs @@ -4,9 +4,9 @@ use uuid::Uuid; use crate::event::{RoomInfo, UserInfo, draft}; use crate::{ChannelBus, ChannelError, ChannelResult}; -use super::{MAX_TEXT_LEN}; -use super::WsOutEvent; +use super::MAX_TEXT_LEN; use super::WsHandler; +use super::WsOutEvent; impl WsHandler { pub(super) async fn draft_save( @@ -21,15 +21,24 @@ impl WsHandler { } let key = format!("draft:{}:{}", user_id, room); bus.inner.cache.set(&key, &content).await?; - let ds_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let ds_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let ds_room = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); + let ds_user = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = draft::DraftSavedService { user: ds_user, room: ds_room, content, saved_at: Utc::now(), }; - Ok(Some(WsOutEvent::DraftSaved { room: data.room.clone(), data })) + Ok(Some(WsOutEvent::DraftSaved { + room: data.room.clone(), + data, + })) } pub(super) async fn draft_clear( @@ -40,13 +49,22 @@ impl WsHandler { Self::ensure_room_access(bus, user_id, room).await?; let key = format!("draft:{}:{}", user_id, room); bus.inner.cache.remove(&key).await?; - let dc_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let dc_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let dc_room = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); + let dc_user = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = draft::DraftClearedService { user: dc_user, room: dc_room, cleared_at: Utc::now(), }; - Ok(Some(WsOutEvent::DraftCleared { room: data.room.clone(), data })) + Ok(Some(WsOutEvent::DraftCleared { + room: data.room.clone(), + data, + })) } } diff --git a/lib/channel/http/handler/forward.rs b/lib/channel/http/handler/forward.rs index 5fb6085..6b9f5e7 100644 --- a/lib/channel/http/handler/forward.rs +++ b/lib/channel/http/handler/forward.rs @@ -3,8 +3,8 @@ use uuid::Uuid; use crate::event::{RoomInfo, forward}; use crate::{ChannelBus, ChannelResult}; -use super::WsOutEvent; use super::WsHandler; +use super::WsOutEvent; impl WsHandler { pub(super) async fn message_forward( @@ -63,11 +63,8 @@ impl WsHandler { let fwd_content_type = row.content_type.clone(); let fwd_created_at = row.created_at; - bus.publish_room_message( - row, - Some(bus.lookup_user(user_id).await?), - ) - .await?; + bus.publish_room_message(row, Some(bus.lookup_user(user_id).await?)) + .await?; let data = forward::MessageForwardedService { id: fwd_id, diff --git a/lib/channel/http/handler/helpers.rs b/lib/channel/http/handler/helpers.rs index 7a487d2..944e4e4 100644 --- a/lib/channel/http/handler/helpers.rs +++ b/lib/channel/http/handler/helpers.rs @@ -59,10 +59,13 @@ impl WsHandler { .fetch_all(bus.inner.db.reader()) .await?; - let user_ids: Vec = rows.iter().map(|(_, _, user)| *user).collect(); + let user_ids: Vec = + rows.iter().map(|(_, _, user)| *user).collect(); let users = bus.lookup_users(&user_ids).await.unwrap_or_default(); - let mut grouped: HashMap> = - HashMap::new(); + let mut grouped: HashMap< + Uuid, + HashMap, + > = HashMap::new(); for (message_id, emoji, reactor) in rows { let group = grouped diff --git a/lib/channel/http/handler/invite.rs b/lib/channel/http/handler/invite.rs index cde04e0..3d5e29b 100644 --- a/lib/channel/http/handler/invite.rs +++ b/lib/channel/http/handler/invite.rs @@ -4,8 +4,8 @@ use uuid::Uuid; use crate::event::{RoomInfo, UserInfo, WorkspaceInfo, invite}; use crate::{ChannelBus, ChannelError, ChannelResult}; -use super::WsOutEvent; use super::WsHandler; +use super::WsOutEvent; impl WsHandler { pub(super) async fn invite_create( @@ -29,16 +29,29 @@ impl WsHandler { "expires_at": _expires_at, }); bus.inner.cache.set(&id_key, &meta.to_string()).await?; - bus.inner.cache.set(&code_key, &invite_id.to_string()).await?; + bus.inner + .cache + .set(&code_key, &invite_id.to_string()) + .await?; let inv_room = match _room { - Some(r) => Some(bus.lookup_room(r).await.unwrap_or_else(|_| RoomInfo::unknown(r))), + Some(r) => Some( + bus.lookup_room(r) + .await + .unwrap_or_else(|_| RoomInfo::unknown(r)), + ), None => None, }; let data = invite::InviteCreatedService { id: invite_id, - workspace: bus.lookup_workspace(workspace).await.unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)), + workspace: bus + .lookup_workspace(workspace) + .await + .unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)), room: inv_room, - inviter: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), + inviter: bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)), invitee: None, code, max_uses: _max_uses, @@ -54,7 +67,8 @@ impl WsHandler { code: String, ) -> ChannelResult> { let code_key = format!("invite:code:{}", code); - let invite_id_str: Option = bus.inner.cache.get(&code_key).await?; + let invite_id_str: Option = + bus.inner.cache.get(&code_key).await?; let invite_id = invite_id_str .as_deref() .and_then(|s| Uuid::parse_str(s).ok()) @@ -90,9 +104,15 @@ impl WsHandler { bus.inner.cache.remove(&id_key).await?; let data = invite::InviteAcceptedService { id: Uuid::now_v7(), - workspace: bus.lookup_workspace(wk).await.unwrap_or_else(|_| WorkspaceInfo::unknown(wk)), + workspace: bus + .lookup_workspace(wk) + .await + .unwrap_or_else(|_| WorkspaceInfo::unknown(wk)), room: None, - user: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), + user: bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)), accepted_at: Utc::now(), }; bus.workspace_changed(wk).await?; diff --git a/lib/channel/http/handler/message.rs b/lib/channel/http/handler/message.rs index a23e6a6..74db7c0 100644 --- a/lib/channel/http/handler/message.rs +++ b/lib/channel/http/handler/message.rs @@ -7,9 +7,9 @@ use crate::{ pagination::{MessagePagination, PaginationDirection, PaginationParams}, }; -use super::{MAX_MESSAGES_PER_REQUEST, MAX_TEXT_LEN}; -use super::WsOutEvent; use super::WsHandler; +use super::WsOutEvent; +use super::{MAX_MESSAGES_PER_REQUEST, MAX_TEXT_LEN}; impl WsHandler { /// Count non-deleted sibling replies to the same parent message. @@ -124,16 +124,21 @@ impl WsHandler { // ── Auto-thread logic ────────────────────────────────────────── let mut events: Vec = Vec::new(); - let effective_thread: Option = if let Some(ref parent_id) = in_reply_to { + let effective_thread: Option = if let Some(ref parent_id) = + in_reply_to + { if thread.is_some() { thread } else { - let existing = Self::find_thread_in_chain(bus, *parent_id).await?; + let existing = + Self::find_thread_in_chain(bus, *parent_id).await?; if let Some(tid) = existing { Some(tid) } else { - let sibling_count = Self::count_sibling_replies(bus, *parent_id).await?; - let (root_id, root_seq, chain_depth) = Self::reply_chain_info(bus, *parent_id).await?; + let sibling_count = + Self::count_sibling_replies(bus, *parent_id).await?; + let (root_id, root_seq, chain_depth) = + Self::reply_chain_info(bus, *parent_id).await?; let should_create = sibling_count >= 3 || chain_depth >= 5; if should_create { @@ -152,11 +157,20 @@ impl WsHandler { .await?; let new_thread_id = thread_row.id; - Self::attach_chain_to_thread(bus, *parent_id, new_thread_id).await?; + Self::attach_chain_to_thread( + bus, + *parent_id, + new_thread_id, + ) + .await?; - let tc_room = bus.lookup_room(room).await + let tc_room = bus + .lookup_room(room) + .await .unwrap_or_else(|_| RoomInfo::unknown(room)); - let created_by = bus.lookup_user(user_id).await + let created_by = bus + .lookup_user(user_id) + .await .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = thread::ThreadCreatedService { id: new_thread_id, @@ -166,7 +180,8 @@ impl WsHandler { participants: serde_json::Value::Null, created_at: thread_row.created_at, }; - bus.publish_room_event(room, "thread.created", &data).await?; + bus.publish_room_event(room, "thread.created", &data) + .await?; events.push(WsOutEvent::ThreadCreated { room: data.room.clone(), data, @@ -227,11 +242,11 @@ impl WsHandler { .fetch_one(bus.inner.db.writer()) .await?; - bus.publish_room_message( - row.clone(), - Some(sender), - ).await?; - let msg_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + bus.publish_room_message(row.clone(), Some(sender)).await?; + let msg_room = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); events.push(WsOutEvent::MessageNew { room: msg_room.clone(), data: message::MessageNewService { @@ -254,7 +269,9 @@ impl WsHandler { }, }); - Ok(events.into_iter().find(|e| matches!(e, WsOutEvent::MessageNew { .. }))) + Ok(events + .into_iter() + .find(|e| matches!(e, WsOutEvent::MessageNew { .. }))) } pub(super) async fn message_update( @@ -302,8 +319,14 @@ impl WsHandler { .execute(bus.inner.db.writer()) .await?; - let sender = bus.lookup_user(row.author).await.unwrap_or_else(|_| UserInfo::unknown(row.author)); - let room = bus.lookup_room(row.room).await.unwrap_or_else(|_| RoomInfo::unknown(row.room)); + let sender = bus + .lookup_user(row.author) + .await + .unwrap_or_else(|_| UserInfo::unknown(row.author)); + let room = bus + .lookup_room(row.room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(row.room)); let data = message::MessageEditedService { id: row.id, seq: row.seq, @@ -353,8 +376,14 @@ impl WsHandler { .bind(message_id) .fetch_one(bus.inner.db.writer()) .await?; - let revoked_by = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); - let room = bus.lookup_room(row.room).await.unwrap_or_else(|_| RoomInfo::unknown(row.room)); + let revoked_by = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); + let room = bus + .lookup_room(row.room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(row.room)); let data = message::MessageRevokedService { id: row.id, seq: row.seq, @@ -418,20 +447,28 @@ impl WsHandler { .await?; let mut page_messages = page.messages; - if before_seq.is_some() || (before_seq.is_none() && after_seq.is_none()) { + if before_seq.is_some() || (before_seq.is_none() && after_seq.is_none()) + { page_messages.reverse(); } - let message_ids: Vec = page_messages.iter().map(|m| m.id).collect(); - let reactions = Self::reaction_groups_for_messages(bus, user_id, &message_ids) - .await - .unwrap_or_default(); + let message_ids: Vec = + page_messages.iter().map(|m| m.id).collect(); + let reactions = + Self::reaction_groups_for_messages(bus, user_id, &message_ids) + .await + .unwrap_or_default(); - let list_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let list_room = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); let mut messages: Vec = Vec::with_capacity(page_messages.len()); for m in page_messages { - let sender = bus.lookup_user(m.sender_id).await + let sender = bus + .lookup_user(m.sender_id) + .await .unwrap_or_else(|_| UserInfo::unknown(m.sender_id)); messages.push(message::MessageNewService { id: m.id, @@ -534,10 +571,14 @@ impl WsHandler { let author_ids: Vec = rows.iter().map(|r| r.author).collect(); let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default(); let message_ids: Vec = rows.iter().map(|r| r.id).collect(); - let reactions = Self::reaction_groups_for_messages(bus, user_id, &message_ids) + let reactions = + Self::reaction_groups_for_messages(bus, user_id, &message_ids) + .await + .unwrap_or_default(); + let around_room = bus + .lookup_room(room) .await - .unwrap_or_default(); - let around_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + .unwrap_or_else(|_| RoomInfo::unknown(room)); let messages = rows .into_iter() .map(|r| { @@ -561,7 +602,10 @@ impl WsHandler { thinking_content: None, thinking_is_chunked: None, send_at: r.created_at, - reactions: reactions.get(&r.id).cloned().unwrap_or_default(), + reactions: reactions + .get(&r.id) + .cloned() + .unwrap_or_default(), } }) .collect::>(); @@ -591,13 +635,19 @@ impl WsHandler { .reconnect .get_missed_messages(room, after_seq) .await?; - let author_ids: Vec = messages.iter().map(|m| m.sender_id).collect(); - let message_ids: Vec = messages.iter().map(|m| m.message_id).collect(); + let author_ids: Vec = + messages.iter().map(|m| m.sender_id).collect(); + let message_ids: Vec = + messages.iter().map(|m| m.message_id).collect(); let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default(); - let reactions = Self::reaction_groups_for_messages(bus, user_id, &message_ids) + let reactions = + Self::reaction_groups_for_messages(bus, user_id, &message_ids) + .await + .unwrap_or_default(); + let missed_room = bus + .lookup_room(room) .await - .unwrap_or_default(); - let missed_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + .unwrap_or_else(|_| RoomInfo::unknown(room)); let messages = messages .into_iter() .take(limit) @@ -622,7 +672,10 @@ impl WsHandler { thinking_content: None, thinking_is_chunked: None, send_at: m.send_at, - reactions: reactions.get(&m.message_id).cloned().unwrap_or_default(), + reactions: reactions + .get(&m.message_id) + .cloned() + .unwrap_or_default(), } }) .collect::>(); diff --git a/lib/channel/http/handler/message_read.rs b/lib/channel/http/handler/message_read.rs index 5d849b6..5cade03 100644 --- a/lib/channel/http/handler/message_read.rs +++ b/lib/channel/http/handler/message_read.rs @@ -4,8 +4,8 @@ use uuid::Uuid; use crate::event::{RoomInfo, UserInfo, message_read}; use crate::{ChannelBus, ChannelResult}; -use super::WsOutEvent; use super::WsHandler; +use super::WsOutEvent; impl WsHandler { pub(super) async fn message_mark_read( @@ -60,10 +60,14 @@ impl WsHandler { .await?; } - let room_info = - bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let reader_info = - bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let room_info = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); + let reader_info = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = message_read::MessageReadBatchService { room: room_info.clone(), @@ -72,7 +76,8 @@ impl WsHandler { reader: reader_info, read_at: now, }; - bus.publish_room_event(room, "message.read_batch", &data).await?; + bus.publish_room_event(room, "message.read_batch", &data) + .await?; Ok(Some(WsOutEvent::MessageReadBatch { room: room_info, data, diff --git a/lib/channel/http/handler/mod.rs b/lib/channel/http/handler/mod.rs index 2ba6cc6..095bd8b 100644 --- a/lib/channel/http/handler/mod.rs +++ b/lib/channel/http/handler/mod.rs @@ -12,27 +12,25 @@ pub(crate) const MAX_CATEGORY_NAME_LEN: usize = 50; mod helpers; -mod subscription; -mod message; -mod room; -mod category; -mod reaction; -mod thread; -mod pin; -mod draft; -mod notification; -mod presence; -mod invite; mod ban; -mod voice; -mod ai; -mod search; -mod user; +mod category; mod conversation; -mod dm; +mod draft; mod forward; +mod invite; +mod message; mod message_read; +mod notification; +mod pin; +mod presence; +mod reaction; +mod room; +mod search; mod star; +mod subscription; +mod thread; +mod user; +mod voice; pub struct WsHandler; @@ -73,8 +71,14 @@ impl WsHandler { ) .await } - WsInMessage::MessageAround { room, seq, limit, thread } => { - Self::message_around(bus, user_id, room, seq, limit, thread).await + WsInMessage::MessageAround { + room, + seq, + limit, + thread, + } => { + Self::message_around(bus, user_id, room, seq, limit, thread) + .await } WsInMessage::MessageCreate { room, @@ -108,9 +112,11 @@ impl WsHandler { room_name, public, category, + ai_enabled, } => { Self::room_create( bus, user_id, workspace, room_name, public, category, + ai_enabled, ) .await } @@ -119,7 +125,13 @@ impl WsHandler { room_name, public, category, - } => Self::room_update(bus, user_id, room, room_name, public, category).await, + ai_enabled, + } => { + Self::room_update( + bus, user_id, room, room_name, public, category, ai_enabled, + ) + .await + } WsInMessage::RoomDelete { room } => { Self::room_delete(bus, user_id, room).await } @@ -127,7 +139,10 @@ impl WsHandler { workspace, name, position, - } => Self::category_create(bus, user_id, workspace, name, position).await, + } => { + Self::category_create(bus, user_id, workspace, name, position) + .await + } WsInMessage::CategoryUpdate { id, name, position } => { Self::category_update(bus, user_id, id, name, position).await } @@ -160,7 +175,11 @@ impl WsHandler { dnd_end_hour, } => { Self::dnd_update( - bus, user_id, room, do_not_disturb, dnd_start_hour, + bus, + user_id, + room, + do_not_disturb, + dnd_start_hour, dnd_end_hour, ) .await @@ -205,7 +224,8 @@ impl WsHandler { Self::notification_mark_read(bus, user_id, id).await } WsInMessage::NotificationMarkAllRead { workspace_id } => { - Self::notification_mark_all_read(bus, user_id, workspace_id).await + Self::notification_mark_all_read(bus, user_id, workspace_id) + .await } WsInMessage::NotificationArchive { id } => { Self::notification_archive(bus, user_id, id).await @@ -218,8 +238,10 @@ impl WsHandler { text, expires_at, } => { - Self::custom_status_update(bus, user_id, emoji, text, expires_at) - .await + Self::custom_status_update( + bus, user_id, emoji, text, expires_at, + ) + .await } WsInMessage::InviteCreate { workspace, @@ -227,8 +249,10 @@ impl WsHandler { max_uses, expires_at, } => { - Self::invite_create(bus, user_id, workspace, room, max_uses, expires_at) - .await + Self::invite_create( + bus, user_id, workspace, room, max_uses, expires_at, + ) + .await } WsInMessage::InviteAccept { code } => { Self::invite_accept(bus, user_id, code).await @@ -242,8 +266,10 @@ impl WsHandler { reason, expires_at, } => { - Self::ban_create(bus, user_id, workspace, user, reason, expires_at) - .await + Self::ban_create( + bus, user_id, workspace, user, reason, expires_at, + ) + .await } WsInMessage::BanRemove { workspace, user } => { Self::ban_remove(bus, user_id, workspace, user).await @@ -263,18 +289,6 @@ impl WsHandler { WsInMessage::ScreenShare { room, start } => { Self::screen_share(bus, user_id, room, start).await } - WsInMessage::AiList { room } => { - Self::ai_list(bus, user_id, room).await - } - WsInMessage::AiUpsert { room, model } => { - Self::ai_upsert(bus, user_id, room, model).await - } - WsInMessage::AiDelete { room, agent_id } => { - Self::ai_delete(bus, user_id, room, agent_id).await - } - WsInMessage::AiStop { room } => { - Self::ai_stop(bus, user_id, room).await - } WsInMessage::UserSummary { username } => { Self::user_summary(bus, username).await } @@ -291,29 +305,20 @@ impl WsHandler { WsInMessage::ConversationMute { room, mute } => { Self::conversation_mute(bus, user_id, room, mute).await } - WsInMessage::ConversationNotifyLevel { - room, - notify_level, - } => { + WsInMessage::ConversationNotifyLevel { room, notify_level } => { Self::conversation_notify_level( - bus, user_id, room, notify_level, + bus, + user_id, + room, + notify_level, ) .await } WsInMessage::ConversationList => { Self::conversation_list(bus, user_id).await } - WsInMessage::DmCreate { recipient } => { - Self::dm_create(bus, user_id, recipient).await - } - WsInMessage::DmClose { room } => { - Self::dm_close(bus, user_id, room).await - } - WsInMessage::DmList => Self::dm_list(bus, user_id).await, - WsInMessage::MessageMarkRead { - room, - message_ids, - } => { + + WsInMessage::MessageMarkRead { room, message_ids } => { Self::message_mark_read(bus, user_id, room, message_ids).await } WsInMessage::MessageGetReaders { message_id } => { @@ -331,8 +336,13 @@ impl WsHandler { source_message_id, target_room, } => { - Self::message_forward(bus, user_id, source_message_id, target_room) - .await + Self::message_forward( + bus, + user_id, + source_message_id, + target_room, + ) + .await } } } diff --git a/lib/channel/http/handler/notification.rs b/lib/channel/http/handler/notification.rs index 6ee34f8..07670af 100644 --- a/lib/channel/http/handler/notification.rs +++ b/lib/channel/http/handler/notification.rs @@ -4,8 +4,8 @@ use uuid::Uuid; use crate::event::{UserInfo, notify}; use crate::{ChannelBus, ChannelError, ChannelResult}; -use super::WsOutEvent; use super::WsHandler; +use super::WsOutEvent; impl WsHandler { pub(super) async fn notification_mark_read( @@ -24,7 +24,10 @@ impl WsHandler { if result.rows_affected() == 0 { return Err(ChannelError::RoomNotFound); } - let nr_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let nr_user = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = notify::NotifyReadService { id, user: nr_user, diff --git a/lib/channel/http/handler/pin.rs b/lib/channel/http/handler/pin.rs index 0f916b2..cf3e324 100644 --- a/lib/channel/http/handler/pin.rs +++ b/lib/channel/http/handler/pin.rs @@ -4,8 +4,8 @@ use uuid::Uuid; use crate::event::{RoomInfo, UserInfo, pin}; use crate::{ChannelBus, ChannelResult}; -use super::WsOutEvent; use super::WsHandler; +use super::WsOutEvent; impl WsHandler { pub(super) async fn pin_add( @@ -35,8 +35,14 @@ impl WsHandler { .bind(message) .execute(bus.inner.db.writer()) .await?; - let pa_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let pinned_by = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let pa_room = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); + let pinned_by = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = pin::PinAddedService { room: pa_room, message, @@ -44,7 +50,10 @@ impl WsHandler { pinned_at: Utc::now(), }; bus.publish_room_event(room, "pin.added", &data).await?; - Ok(Some(WsOutEvent::PinAdded { room: data.room.clone(), data })) + Ok(Some(WsOutEvent::PinAdded { + room: data.room.clone(), + data, + })) } pub(super) async fn pin_remove( @@ -69,8 +78,14 @@ impl WsHandler { .bind(message) .execute(bus.inner.db.writer()) .await?; - let pr_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let removed_by = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let pr_room = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); + let removed_by = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = pin::PinRemovedService { room: pr_room, message, @@ -78,6 +93,9 @@ impl WsHandler { removed_at: Utc::now(), }; bus.publish_room_event(room, "pin.removed", &data).await?; - Ok(Some(WsOutEvent::PinRemoved { room: data.room.clone(), data })) + Ok(Some(WsOutEvent::PinRemoved { + room: data.room.clone(), + data, + })) } } diff --git a/lib/channel/http/handler/presence.rs b/lib/channel/http/handler/presence.rs index 3ddfaab..f95a992 100644 --- a/lib/channel/http/handler/presence.rs +++ b/lib/channel/http/handler/presence.rs @@ -4,8 +4,8 @@ use uuid::Uuid; use crate::event::{RoomInfo, UserInfo, member, presence}; use crate::{ChannelBus, ChannelResult}; -use super::WsOutEvent; use super::WsHandler; +use super::WsOutEvent; impl WsHandler { pub(super) async fn dnd_update( @@ -27,8 +27,14 @@ impl WsHandler { "dnd_end_hour": end_hour, }); bus.inner.cache.set(&key, &dnd_data).await?; - let dnd_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let dnd_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let dnd_room = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); + let dnd_user = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = member::DndUpdatedService { room: dnd_room, user: dnd_user, @@ -36,7 +42,8 @@ impl WsHandler { dnd_start_hour: start_hour, dnd_end_hour: end_hour, }; - bus.publish_room_event(room, "member.dnd_updated", &data).await?; + bus.publish_room_event(room, "member.dnd_updated", &data) + .await?; Ok(None) } @@ -45,7 +52,10 @@ impl WsHandler { user_id: Uuid, status: presence::UserPresenceStatus, ) -> ChannelResult> { - let pc_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let pc_user = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = presence::PresenceChangedService { user: pc_user, project: None, @@ -60,7 +70,8 @@ impl WsHandler { ) .await?; for room in rooms { - bus.publish_room_event(room, "presence.changed", &data).await?; + bus.publish_room_event(room, "presence.changed", &data) + .await?; } Ok(Some(WsOutEvent::PresenceChanged { data })) } @@ -72,7 +83,10 @@ impl WsHandler { text: Option, expires_at: Option>, ) -> ChannelResult> { - let cs_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let cs_user = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = presence::CustomStatusUpdatedService { user: cs_user, emoji, @@ -87,7 +101,8 @@ impl WsHandler { ) .await?; for room in rooms { - bus.publish_room_event(room, "custom_status.updated", &data).await?; + bus.publish_room_event(room, "custom_status.updated", &data) + .await?; } Ok(Some(WsOutEvent::CustomStatusUpdated { data })) } diff --git a/lib/channel/http/handler/reaction.rs b/lib/channel/http/handler/reaction.rs index 786cd50..28aba06 100644 --- a/lib/channel/http/handler/reaction.rs +++ b/lib/channel/http/handler/reaction.rs @@ -4,8 +4,8 @@ use uuid::Uuid; use crate::event::{RoomInfo, UserInfo, reaction}; use crate::{ChannelBus, ChannelError, ChannelResult}; -use super::WsOutEvent; use super::WsHandler; +use super::WsOutEvent; impl WsHandler { pub(super) async fn reaction_add( @@ -39,7 +39,10 @@ impl WsHandler { .lookup_user(user_id) .await .unwrap_or_else(|_| UserInfo::unknown(user_id)); - let rct_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let rct_room = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); let data = reaction::ReactionAddedService { id: Uuid::now_v7(), room: rct_room, @@ -48,8 +51,12 @@ impl WsHandler { emoji, created_at: Utc::now(), }; - bus.publish_room_event(room, "reaction.added", &data).await?; - Ok(Some(WsOutEvent::ReactionAdded { room: data.room.clone(), data })) + bus.publish_room_event(room, "reaction.added", &data) + .await?; + Ok(Some(WsOutEvent::ReactionAdded { + room: data.room.clone(), + data, + })) } pub(super) async fn reaction_remove( @@ -76,7 +83,10 @@ impl WsHandler { .lookup_user(user_id) .await .unwrap_or_else(|_| UserInfo::unknown(user_id)); - let rct_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let rct_room = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); let data = reaction::ReactionRemovedService { id: Uuid::now_v7(), room: rct_room, @@ -85,7 +95,11 @@ impl WsHandler { emoji, removed_at: Utc::now(), }; - bus.publish_room_event(room, "reaction.removed", &data).await?; - Ok(Some(WsOutEvent::ReactionRemoved { room: data.room.clone(), data })) + bus.publish_room_event(room, "reaction.removed", &data) + .await?; + Ok(Some(WsOutEvent::ReactionRemoved { + room: data.room.clone(), + data, + })) } } diff --git a/lib/channel/http/handler/room.rs b/lib/channel/http/handler/room.rs index f1ca72a..e880027 100644 --- a/lib/channel/http/handler/room.rs +++ b/lib/channel/http/handler/room.rs @@ -4,9 +4,9 @@ use uuid::Uuid; use crate::event::{RoomInfo, UserInfo, WorkspaceInfo, member, rooms}; use crate::{ChannelBus, ChannelError, ChannelResult}; -use super::{MAX_ROOM_NAME_LEN}; -use super::WsOutEvent; +use super::MAX_ROOM_NAME_LEN; use super::WsHandler; +use super::WsOutEvent; impl WsHandler { pub(super) async fn room_get( @@ -17,7 +17,7 @@ impl WsHandler { Self::ensure_room_access(bus, user_id, room).await?; let row = db::sqlx::query_as::<_, model::room::RoomModel>( "SELECT id, wk, parent, name, topic, room_type, position, \ - is_private, is_archived, created_by, created_at, updated_at, deleted_at \ + is_private, is_archived, ai_enabled, created_by, created_at, updated_at, deleted_at \ FROM room WHERE id = $1 AND deleted_at IS NULL", ) .bind(room) @@ -33,6 +33,7 @@ impl WsHandler { "room_type": row.room_type, "is_private": row.is_private, "is_archived": row.is_archived, + "ai_enabled": row.ai_enabled, "parent": row.parent, "created_by": row.created_by, "created_at": row.created_at, @@ -47,22 +48,25 @@ impl WsHandler { room_name: String, public: bool, category: Option, + ai_enabled: Option, ) -> ChannelResult> { if room_name.is_empty() || room_name.len() > MAX_ROOM_NAME_LEN { return Err(ChannelError::Validation("invalid room name".into())); } Self::ensure_workspace_member(bus, user_id, workspace).await?; let is_private = !public; + let ai = ai_enabled.unwrap_or(false); let row = db::sqlx::query_as::<_, model::room::RoomModel>( - "INSERT INTO room (wk, parent, name, room_type, is_private, created_by, created_at, updated_at) \ - VALUES ($1, $2, $3, 'channel', $4, $5, now(), now()) \ + "INSERT INTO room (wk, parent, name, room_type, is_private, ai_enabled, created_by, created_at, updated_at) \ + VALUES ($1, $2, $3, 'channel', $4, $5, $6, now(), now()) \ RETURNING id, wk, parent, name, topic, room_type, position, \ - is_private, is_archived, created_by, created_at, updated_at, deleted_at", + is_private, is_archived, ai_enabled, created_by, created_at, updated_at, deleted_at", ) .bind(workspace) .bind(category) .bind(&room_name) .bind(is_private) + .bind(ai) .bind(user_id) .fetch_one(bus.inner.db.writer()) .await?; @@ -77,15 +81,25 @@ impl WsHandler { .await?; let data = rooms::RoomCreatedService { room: RoomInfo::from_model(&row), - workspace: bus.lookup_workspace(workspace).await.unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)), + workspace: bus + .lookup_workspace(workspace) + .await + .unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)), public, category, - created_by: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), + created_by: bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)), created_at: row.created_at, }; - bus.publish_room_event(row.id, "room.created", &data).await?; + bus.publish_room_event(row.id, "room.created", &data) + .await?; bus.room_changed(row.id).await?; - Ok(Some(WsOutEvent::RoomCreated { room: data.room.clone(), data })) + Ok(Some(WsOutEvent::RoomCreated { + room: data.room.clone(), + data, + })) } pub(super) async fn room_update( @@ -95,56 +109,74 @@ impl WsHandler { room_name: Option, public: Option, category: Option, + ai_enabled: Option, ) -> ChannelResult> { Self::ensure_room_access(bus, user_id, room).await?; let old = db::sqlx::query_as::<_, model::room::RoomModel>( "SELECT id, wk, parent, name, topic, room_type, position, \ - is_private, is_archived, created_by, created_at, updated_at, deleted_at \ + is_private, is_archived, ai_enabled, created_by, created_at, updated_at, deleted_at \ FROM room WHERE id = $1 AND deleted_at IS NULL", ) .bind(room) .fetch_one(bus.inner.db.reader()) .await?; let new_name = room_name.unwrap_or(old.name.clone()); - let new_private = - public.map(|p| !p).unwrap_or(old.is_private); + let new_private = public.map(|p| !p).unwrap_or(old.is_private); let new_category = category.or(old.parent); + let new_ai = ai_enabled.unwrap_or(old.ai_enabled); let row = db::sqlx::query_as::<_, model::room::RoomModel>( - "UPDATE room SET name = $2, is_private = $3, parent = $4, updated_at = now() \ + "UPDATE room SET name = $2, is_private = $3, parent = $4, ai_enabled = $5, updated_at = now() \ WHERE id = $1 AND deleted_at IS NULL \ RETURNING id, wk, parent, name, topic, room_type, position, \ - is_private, is_archived, created_by, created_at, updated_at, deleted_at", + is_private, is_archived, ai_enabled, created_by, created_at, updated_at, deleted_at", ) .bind(room) .bind(&new_name) .bind(new_private) .bind(new_category) + .bind(new_ai) .fetch_one(bus.inner.db.writer()) .await?; let mut renamed = false; if new_name != old.name { let data = rooms::RoomRenamedService { room: RoomInfo::from_model(&row), - workspace: bus.lookup_workspace(row.wk).await.unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)), + workspace: bus + .lookup_workspace(row.wk) + .await + .unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)), old_name: old.name.clone(), new_name: new_name, - renamed_by: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), + renamed_by: bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)), renamed_at: Utc::now(), }; bus.publish_room_event(room, "room.renamed", &data).await?; renamed = true; } - if new_private != old.is_private || new_category != old.parent { + if new_private != old.is_private + || new_category != old.parent + || new_ai != old.ai_enabled + { let data = rooms::RoomSettingsUpdatedService { room: RoomInfo::from_model(&row), - workspace: bus.lookup_workspace(row.wk).await.unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)), + workspace: bus + .lookup_workspace(row.wk) + .await + .unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)), slowmode_seconds: None, nsfw: false, default_auto_archive_duration: None, - updated_by: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), + updated_by: bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)), updated_at: Utc::now(), }; - bus.publish_room_event(room, "room.settings_updated", &data).await?; + bus.publish_room_event(room, "room.settings_updated", &data) + .await?; } bus.room_changed(room).await?; if renamed { @@ -152,15 +184,30 @@ impl WsHandler { room: RoomInfo::from_model(&row), data: rooms::RoomRenamedService { room: RoomInfo::from_model(&row), - workspace: bus.lookup_workspace(row.wk).await.unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)), + workspace: bus + .lookup_workspace(row.wk) + .await + .unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)), old_name: old.name, new_name: row.name, - renamed_by: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), + renamed_by: bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)), renamed_at: Utc::now(), }, })); } - Ok(None) + Ok(Some(WsOutEvent::Response { + request_id: Uuid::nil(), + data: serde_json::json!({ + "id": row.id, + "name": row.name, + "is_private": row.is_private, + "ai_enabled": row.ai_enabled, + "parent": row.parent, + }), + })) } pub(super) async fn room_delete( @@ -191,13 +238,22 @@ impl WsHandler { .await?; let data = rooms::RoomDeletedService { room: RoomInfo::from_model(&row), - workspace: bus.lookup_workspace(row.wk).await.unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)), - deleted_by: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), + workspace: bus + .lookup_workspace(row.wk) + .await + .unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)), + deleted_by: bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)), deleted_at: Utc::now(), }; bus.publish_room_event(room, "room.deleted", &data).await?; bus.room_changed(room).await?; - Ok(Some(WsOutEvent::RoomDeleted { room: data.room.clone(), data })) + Ok(Some(WsOutEvent::RoomDeleted { + room: data.room.clone(), + data, + })) } pub(super) async fn access_grant( @@ -217,8 +273,14 @@ impl WsHandler { .bind(target_user) .execute(bus.inner.db.writer()) .await?; - let mj_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let mj_user = bus.lookup_user(target_user).await.unwrap_or_else(|_| UserInfo::unknown(target_user)); + let mj_room = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); + let mj_user = bus + .lookup_user(target_user) + .await + .unwrap_or_else(|_| UserInfo::unknown(target_user)); let data = member::MemberJoinedService { room: mj_room, user: mj_user, @@ -227,7 +289,10 @@ impl WsHandler { }; bus.publish_room_event(room, "member.joined", &data).await?; bus.room_changed(room).await?; - Ok(Some(WsOutEvent::MemberJoined { room: data.room.clone(), data })) + Ok(Some(WsOutEvent::MemberJoined { + room: data.room.clone(), + data, + })) } pub(super) async fn access_revoke( @@ -245,17 +310,30 @@ impl WsHandler { .bind(target_user) .execute(bus.inner.db.writer()) .await?; - let mr_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let mr_target = bus.lookup_user(target_user).await.unwrap_or_else(|_| UserInfo::unknown(target_user)); - let mr_remover = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let mr_room = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); + let mr_target = bus + .lookup_user(target_user) + .await + .unwrap_or_else(|_| UserInfo::unknown(target_user)); + let mr_remover = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = member::MemberRemovedService { room: mr_room, user: mr_target, removed_by: mr_remover, removed_at: Utc::now(), }; - bus.publish_room_event(room, "member.removed", &data).await?; + bus.publish_room_event(room, "member.removed", &data) + .await?; bus.room_changed(room).await?; - Ok(Some(WsOutEvent::MemberRemoved { room: data.room.clone(), data })) + Ok(Some(WsOutEvent::MemberRemoved { + room: data.room.clone(), + data, + })) } } diff --git a/lib/channel/http/handler/search.rs b/lib/channel/http/handler/search.rs index 64cb239..bf1ea40 100644 --- a/lib/channel/http/handler/search.rs +++ b/lib/channel/http/handler/search.rs @@ -6,8 +6,8 @@ use crate::{ search::{SearchEngine, SearchQuery}, }; -use super::WsOutEvent; use super::WsHandler; +use super::WsOutEvent; impl WsHandler { pub(super) async fn search( @@ -36,18 +36,27 @@ impl WsHandler { }) .await?; - let author_ids: Vec = result.hits.iter().map(|h| h.sender_id).collect(); - let message_ids: Vec = result.hits.iter().map(|h| h.message_id).collect(); + let author_ids: Vec = + result.hits.iter().map(|h| h.sender_id).collect(); + let message_ids: Vec = + result.hits.iter().map(|h| h.message_id).collect(); let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default(); - let reactions = Self::reaction_groups_for_messages(bus, user_id, &message_ids) - .await - .unwrap_or_default(); + let reactions = + Self::reaction_groups_for_messages(bus, user_id, &message_ids) + .await + .unwrap_or_default(); let search_room = match room { - Some(r) => Some(bus.lookup_room(r).await.unwrap_or_else(|_| RoomInfo::unknown(r))), + Some(r) => Some( + bus.lookup_room(r) + .await + .unwrap_or_else(|_| RoomInfo::unknown(r)), + ), None => None, }; - let search_msg_room = search_room.clone().unwrap_or_else(|| RoomInfo::unknown(room.unwrap_or_default())); + let search_msg_room = search_room + .clone() + .unwrap_or_else(|| RoomInfo::unknown(room.unwrap_or_default())); let data = crate::event::search::SearchResultService { q, room: search_room, @@ -60,26 +69,30 @@ impl WsHandler { .cloned() .unwrap_or_else(|| UserInfo::unknown(h.sender_id)); crate::event::search::SearchMessageHitService { - message: crate::event::message::MessageNewService { - id: h.message_id, - seq: 0, - room: search_msg_room.clone(), - sender_type: "user".to_string(), - sender, - thread: None, - in_reply_to: None, - content: h.content.clone(), - content_type: "text".to_string(), - pinned: false, - system_type: None, - metadata: serde_json::Value::Null, - thinking_content: None, - thinking_is_chunked: None, - send_at: h.send_at, - reactions: reactions.get(&h.message_id).cloned().unwrap_or_default(), - }, - highlighted_content: h.highlighted, - }}) + message: crate::event::message::MessageNewService { + id: h.message_id, + seq: 0, + room: search_msg_room.clone(), + sender_type: "user".to_string(), + sender, + thread: None, + in_reply_to: None, + content: h.content.clone(), + content_type: "text".to_string(), + pinned: false, + system_type: None, + metadata: serde_json::Value::Null, + thinking_content: None, + thinking_is_chunked: None, + send_at: h.send_at, + reactions: reactions + .get(&h.message_id) + .cloned() + .unwrap_or_default(), + }, + highlighted_content: h.highlighted, + } + }) .collect(), total: result.total as i64, took_ms: 0, diff --git a/lib/channel/http/handler/star.rs b/lib/channel/http/handler/star.rs index 1df491d..9f5faa3 100644 --- a/lib/channel/http/handler/star.rs +++ b/lib/channel/http/handler/star.rs @@ -4,8 +4,8 @@ use uuid::Uuid; use crate::event::{RoomInfo, UserInfo, star}; use crate::{ChannelBus, ChannelResult}; -use super::WsOutEvent; use super::WsHandler; +use super::WsOutEvent; impl WsHandler { pub(super) async fn message_star( @@ -18,10 +18,14 @@ impl WsHandler { Self::ensure_room_access(bus, user_id, room).await?; Self::ensure_message_in_room(bus, room, message).await?; - let room_info = - bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let user_info = - bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let room_info = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); + let user_info = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); if do_star { let result = db::sqlx::query( @@ -76,7 +80,8 @@ impl WsHandler { unstarred_by: user_info, unstarred_at: Utc::now(), }; - bus.emit_to_user(user_id, "message.unstarred", &data).await?; + bus.emit_to_user(user_id, "message.unstarred", &data) + .await?; Ok(Some(WsOutEvent::MessageUnstarred { room: room_info, data, @@ -123,7 +128,17 @@ impl WsHandler { let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default(); let mut entries = Vec::with_capacity(rows.len()); - for (_star_id, msg_id, seq, content, content_type, author_id, starred_at, sent_at) in rows { + for ( + _star_id, + msg_id, + seq, + content, + content_type, + author_id, + starred_at, + sent_at, + ) in rows + { let msg_room_row: Option<(Uuid,)> = db::sqlx::query_as( "SELECT room FROM room_message WHERE id = $1", ) diff --git a/lib/channel/http/handler/subscription.rs b/lib/channel/http/handler/subscription.rs index c5a64ca..e09d4a0 100644 --- a/lib/channel/http/handler/subscription.rs +++ b/lib/channel/http/handler/subscription.rs @@ -4,8 +4,8 @@ use uuid::Uuid; use crate::event::{RoomInfo, UserInfo, member}; use crate::{ChannelBus, ChannelResult}; -use super::WsOutEvent; use super::WsHandler; +use super::WsOutEvent; impl WsHandler { pub(super) async fn subscribe( @@ -36,10 +36,18 @@ impl WsHandler { let key = (room, user_id); if action == "start" { - let ty_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let ty_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let ty_room = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); + let ty_user = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); let already_typing = bus.inner.typing_states.contains_key(&key); - if let Some((_, (_, _, old_cancel))) = bus.inner.typing_states.remove(&key) { + if let Some((_, (_, _, old_cancel))) = + bus.inner.typing_states.remove(&key) + { old_cancel.cancel(); } @@ -48,13 +56,18 @@ impl WsHandler { let bus_clone = bus.clone(); let user_clone = ty_user.clone(); let room_clone = ty_room.clone(); - bus.inner.typing_states.insert(key, (ty_user.clone(), ty_room.clone(), cancel)); + bus.inner + .typing_states + .insert(key, (ty_user.clone(), ty_room.clone(), cancel)); tokio::spawn(async move { tokio::time::sleep(std::time::Duration::from_secs(10)).await; if cancel_clone.is_cancelled() { return; } - bus_clone.inner.typing_states.remove(&(room_clone.id, user_clone.id)); + bus_clone + .inner + .typing_states + .remove(&(room_clone.id, user_clone.id)); let room_id = room_clone.id; let stop_data = member::TypingStopService { room: room_clone, @@ -62,7 +75,9 @@ impl WsHandler { sender_type: "user".to_string(), stopped_at: Utc::now(), }; - let _ = bus_clone.publish_room_event(room_id, "typing.stop", &stop_data).await; + let _ = bus_clone + .publish_room_event(room_id, "typing.stop", &stop_data) + .await; }); if !already_typing { let data = member::TypingStartService { @@ -72,16 +87,27 @@ impl WsHandler { started_at: Utc::now(), }; bus.publish_room_event(room, "typing.start", &data).await?; - return Ok(Some(WsOutEvent::TypingStart { room: data.room.clone(), data })); + return Ok(Some(WsOutEvent::TypingStart { + room: data.room.clone(), + data, + })); } Ok(None) } else { - if let Some((_, (_, _, cancel))) = bus.inner.typing_states.remove(&key) { + if let Some((_, (_, _, cancel))) = + bus.inner.typing_states.remove(&key) + { cancel.cancel(); } - let ty_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let ty_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let ty_room = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); + let ty_user = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = member::TypingStopService { room: ty_room, @@ -90,7 +116,10 @@ impl WsHandler { stopped_at: Utc::now(), }; bus.publish_room_event(room, "typing.stop", &data).await?; - Ok(Some(WsOutEvent::TypingStop { room: data.room.clone(), data })) + Ok(Some(WsOutEvent::TypingStop { + room: data.room.clone(), + data, + })) } } @@ -117,8 +146,14 @@ impl WsHandler { .bind(last_read_seq) .execute(bus.inner.db.writer()) .await?; - let rr_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let rr_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let rr_room = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); + let rr_user = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = member::ReadReceiptService { room: rr_room.clone(), user: rr_user, diff --git a/lib/channel/http/handler/thread.rs b/lib/channel/http/handler/thread.rs index a3f4b95..11c6bcb 100644 --- a/lib/channel/http/handler/thread.rs +++ b/lib/channel/http/handler/thread.rs @@ -4,8 +4,8 @@ use uuid::Uuid; use crate::event::{RoomInfo, UserInfo, thread}; use crate::{ChannelBus, ChannelError, ChannelResult}; -use super::WsOutEvent; use super::WsHandler; +use super::WsOutEvent; /// Helper struct for thread_list JOIN query result #[derive(db::sqlx::FromRow)] @@ -48,9 +48,13 @@ impl WsHandler { let mut items = Vec::new(); for row in rows { - let tc_room = bus.lookup_room(row.room).await + let tc_room = bus + .lookup_room(row.room) + .await .unwrap_or_else(|_| RoomInfo::unknown(row.room)); - let created_by = bus.lookup_user(row.created_by).await + let created_by = bus + .lookup_user(row.created_by) + .await .unwrap_or_else(|_| UserInfo::unknown(row.created_by)); // Get last message preview let preview: Option<(String,)> = db::sqlx::query_as( @@ -110,8 +114,14 @@ impl WsHandler { .bind(user_id) .fetch_one(bus.inner.db.writer()) .await?; - let tc_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let created_by = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let tc_room = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); + let created_by = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = thread::ThreadCreatedService { id: row.id, room: tc_room, @@ -120,8 +130,12 @@ impl WsHandler { participants: serde_json::Value::Null, created_at: row.created_at, }; - bus.publish_room_event(room, "thread.created", &data).await?; - Ok(Some(WsOutEvent::ThreadCreated { room: data.room.clone(), data })) + bus.publish_room_event(room, "thread.created", &data) + .await?; + Ok(Some(WsOutEvent::ThreadCreated { + room: data.room.clone(), + data, + })) } pub(super) async fn thread_resolve( @@ -129,13 +143,12 @@ impl WsHandler { user_id: Uuid, thread_id: Uuid, ) -> ChannelResult> { - let existing: (Uuid,) = db::sqlx::query_as( - "SELECT room FROM room_thread WHERE id = $1", - ) - .bind(thread_id) - .fetch_optional(bus.inner.db.reader()) - .await? - .ok_or(ChannelError::RoomNotFound)?; + let existing: (Uuid,) = + db::sqlx::query_as("SELECT room FROM room_thread WHERE id = $1") + .bind(thread_id) + .fetch_optional(bus.inner.db.reader()) + .await? + .ok_or(ChannelError::RoomNotFound)?; Self::ensure_room_access(bus, user_id, existing.0).await?; let row = db::sqlx::query_as::<_, model::room::RoomThreadModel>( "UPDATE room_thread SET locked = true, updated_at = now() \ @@ -146,16 +159,26 @@ impl WsHandler { .bind(thread_id) .fetch_one(bus.inner.db.writer()) .await?; - let tr_room = bus.lookup_room(row.room).await.unwrap_or_else(|_| RoomInfo::unknown(row.room)); - let resolved_by = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let tr_room = bus + .lookup_room(row.room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(row.room)); + let resolved_by = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = thread::ThreadResolvedService { id: row.id, room: tr_room, resolved_by, resolved_at: Utc::now(), }; - bus.publish_room_event(row.room, "thread.resolved", &data).await?; - Ok(Some(WsOutEvent::ThreadResolved { room: data.room.clone(), data })) + bus.publish_room_event(row.room, "thread.resolved", &data) + .await?; + Ok(Some(WsOutEvent::ThreadResolved { + room: data.room.clone(), + data, + })) } pub(super) async fn thread_archive( @@ -163,13 +186,12 @@ impl WsHandler { user_id: Uuid, thread_id: Uuid, ) -> ChannelResult> { - let existing: (Uuid,) = db::sqlx::query_as( - "SELECT room FROM room_thread WHERE id = $1", - ) - .bind(thread_id) - .fetch_optional(bus.inner.db.reader()) - .await? - .ok_or(ChannelError::RoomNotFound)?; + let existing: (Uuid,) = + db::sqlx::query_as("SELECT room FROM room_thread WHERE id = $1") + .bind(thread_id) + .fetch_optional(bus.inner.db.reader()) + .await? + .ok_or(ChannelError::RoomNotFound)?; Self::ensure_room_access(bus, user_id, existing.0).await?; let row = db::sqlx::query_as::<_, model::room::RoomThreadModel>( "UPDATE room_thread SET archived = true, archived_at = now(), updated_at = now() \ @@ -180,15 +202,25 @@ impl WsHandler { .bind(thread_id) .fetch_one(bus.inner.db.writer()) .await?; - let ta_room = bus.lookup_room(row.room).await.unwrap_or_else(|_| RoomInfo::unknown(row.room)); - let archived_by = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let ta_room = bus + .lookup_room(row.room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(row.room)); + let archived_by = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = thread::ThreadArchivedService { id: row.id, room: ta_room, archived_by, archived_at: Utc::now(), }; - bus.publish_room_event(row.room, "thread.archived", &data).await?; - Ok(Some(WsOutEvent::ThreadArchived { room: data.room.clone(), data })) + bus.publish_room_event(row.room, "thread.archived", &data) + .await?; + Ok(Some(WsOutEvent::ThreadArchived { + room: data.room.clone(), + data, + })) } } diff --git a/lib/channel/http/handler/user.rs b/lib/channel/http/handler/user.rs index 37068ab..274bee3 100644 --- a/lib/channel/http/handler/user.rs +++ b/lib/channel/http/handler/user.rs @@ -2,8 +2,8 @@ use uuid::Uuid; use crate::{ChannelBus, ChannelError, ChannelResult}; -use super::WsOutEvent; use super::WsHandler; +use super::WsOutEvent; impl WsHandler { pub(super) async fn user_summary( diff --git a/lib/channel/http/handler/voice.rs b/lib/channel/http/handler/voice.rs index 403f402..b245e91 100644 --- a/lib/channel/http/handler/voice.rs +++ b/lib/channel/http/handler/voice.rs @@ -4,8 +4,8 @@ use uuid::Uuid; use crate::event::{RoomInfo, UserInfo, voice}; use crate::{ChannelBus, ChannelResult}; -use super::WsOutEvent; use super::WsHandler; +use super::WsOutEvent; impl WsHandler { pub(super) async fn voice_join( @@ -14,8 +14,14 @@ impl WsHandler { room: Uuid, ) -> ChannelResult> { Self::ensure_room_access(bus, user_id, room).await?; - let vj_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let vj_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let vj_room = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); + let vj_user = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = voice::VoiceChannelJoinedService { room: vj_room, workspace: None, @@ -25,8 +31,12 @@ impl WsHandler { video: false, joined_at: Utc::now(), }; - bus.publish_room_event(room, "voice.channel_joined", &data).await?; - Ok(Some(WsOutEvent::VoiceChannelJoined { room: data.room.clone(), data })) + bus.publish_room_event(room, "voice.channel_joined", &data) + .await?; + Ok(Some(WsOutEvent::VoiceChannelJoined { + room: data.room.clone(), + data, + })) } pub(super) async fn voice_leave( @@ -35,16 +45,26 @@ impl WsHandler { room: Uuid, ) -> ChannelResult> { Self::ensure_room_access(bus, user_id, room).await?; - let vl_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); - let vl_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let vl_room = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room)); + let vl_user = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = voice::VoiceChannelLeftService { room: vl_room, workspace: None, user: vl_user, left_at: Utc::now(), }; - bus.publish_room_event(room, "voice.channel_left", &data).await?; - Ok(Some(WsOutEvent::VoiceChannelLeft { room: data.room.clone(), data })) + bus.publish_room_event(room, "voice.channel_left", &data) + .await?; + Ok(Some(WsOutEvent::VoiceChannelLeft { + room: data.room.clone(), + data, + })) } pub(super) async fn voice_mute( diff --git a/lib/channel/http/out_event.rs b/lib/channel/http/out_event.rs index 05fee98..1cabb64 100644 --- a/lib/channel/http/out_event.rs +++ b/lib/channel/http/out_event.rs @@ -2,10 +2,9 @@ use serde::Serialize; use uuid::Uuid; use crate::event::{ - RoomInfo, WorkspaceInfo, - ai, attachment, ban, category, conversation, dm, draft, forward, invite, - member, message, message_read, notify, pin, presence, reaction, rooms, - search, star, thread, voice, workspace, + RoomInfo, WorkspaceInfo, attachment, ban, category, conversation, draft, + forward, invite, member, message, message_read, notify, pin, presence, + reaction, rooms, search, star, thread, voice, workspace, }; #[derive(Debug, Clone, Serialize)] @@ -188,22 +187,6 @@ pub enum WsOutEvent { UserUnbanned { data: ban::UnbannedService, }, - AiAgentJoined { - room: RoomInfo, - data: ai::AiAgentJoinedService, - }, - AiAgentLeft { - room: RoomInfo, - data: ai::AiAgentLeftService, - }, - AiAgentList { - room: RoomInfo, - data: ai::RoomAiListService, - }, - AiAgentStatusChanged { - room: RoomInfo, - data: ai::AiAgentStatusChangedService, - }, VoiceChannelJoined { room: RoomInfo, data: voice::VoiceChannelJoinedService, @@ -235,21 +218,6 @@ pub enum WsOutEvent { ConversationList { data: Vec, }, - DmCreated { - room: RoomInfo, - data: dm::DmCreatedService, - }, - DmClosed { - room: RoomInfo, - data: dm::DmClosedService, - }, - DmReopened { - room: RoomInfo, - data: dm::DmReopenedService, - }, - DmList { - data: Vec, - }, MessageRead { room: RoomInfo, data: message_read::MessageReadService, diff --git a/lib/channel/http/types.rs b/lib/channel/http/types.rs index cf63640..a82efe2 100644 --- a/lib/channel/http/types.rs +++ b/lib/channel/http/types.rs @@ -60,12 +60,14 @@ pub enum WsInMessage { room_name: String, public: bool, category: Option, + ai_enabled: Option, }, RoomUpdate { room: Uuid, room_name: Option, public: Option, category: Option, + ai_enabled: Option, }, RoomDelete { room: Uuid, @@ -212,20 +214,6 @@ pub enum WsInMessage { room: Uuid, start: bool, }, - AiList { - room: Uuid, - }, - AiUpsert { - room: Uuid, - model: Uuid, - }, - AiDelete { - room: Uuid, - agent_id: Uuid, - }, - AiStop { - room: Uuid, - }, UserSummary { username: String, }, @@ -242,13 +230,6 @@ pub enum WsInMessage { notify_level: String, }, ConversationList, - DmCreate { - recipient: Uuid, - }, - DmClose { - room: Uuid, - }, - DmList, MessageMarkRead { room: Uuid, message_ids: Vec, @@ -312,14 +293,9 @@ impl WsInMessage { VoiceMute, VoiceDeaf, ScreenShare, - AiList, - AiUpsert, - AiDelete, - AiStop, ConversationPin, ConversationMute, ConversationNotifyLevel, - DmClose, MessageMarkRead, MessageStar, ) diff --git a/lib/channel/http/ws.rs b/lib/channel/http/ws.rs index 3fc92d9..2f34e4e 100644 --- a/lib/channel/http/ws.rs +++ b/lib/channel/http/ws.rs @@ -47,11 +47,7 @@ async fn handle_inbound(bus: &ChannelBus, socket: &Socket, data: EventPayload) { let parsed = payload; let text = serde_json::to_string(payload).unwrap_or_default(); - if parsed - .get("type") - .and_then(|t| t.as_str()) - == Some("ping") - { + if parsed.get("type").and_then(|t| t.as_str()) == Some("ping") { let pong = WsOutEvent::Pong { protocol_version: super::types::WS_PROTOCOL_VERSION, }; @@ -115,7 +111,8 @@ async fn handle_inbound(bus: &ChannelBus, socket: &Socket, data: EventPayload) { code: 400, error: "parse_error".to_string(), message: e.to_string(), - }).unwrap_or_default(), + }) + .unwrap_or_default(), }; send_event(socket, &err_resp).await; } else { @@ -126,7 +123,8 @@ async fn handle_inbound(bus: &ChannelBus, socket: &Socket, data: EventPayload) { error: "parse_error".to_string(), message: e.to_string(), }, - ).await; + ) + .await; } } } diff --git a/lib/channel/lib.rs b/lib/channel/lib.rs index d0e7c26..21baf07 100644 --- a/lib/channel/lib.rs +++ b/lib/channel/lib.rs @@ -17,6 +17,7 @@ mod security; mod seq; mod token; +use crate::event::UserInfo; pub use ack::{AckRequest, AckResponse, AckStatus, AckTracker, MessageAck}; pub use bus::ChannelBus; pub use cdn::{CdnManager, CdnStoredFile}; @@ -37,3 +38,14 @@ pub use seq::SeqAllocator; pub use token::{ ChannelAccessToken, ChannelTokenApply, ChannelTokenContext, TOKEN_TTL_SECS, }; + +use uuid::Uuid; + +lazy_static::lazy_static! { + pub static ref REDPADA: UserInfo = UserInfo { + id: Uuid::nil(), + username: "RedPanda".to_string(), + display_name: "RedPanda".to_string(), + avatar_url: "".to_string(), + }; +} diff --git a/lib/channel/reconnect.rs b/lib/channel/reconnect.rs index 0b56e4e..e4d3c59 100644 --- a/lib/channel/reconnect.rs +++ b/lib/channel/reconnect.rs @@ -4,8 +4,8 @@ use uuid::Uuid; use model::room::RoomMessageModel; use serde::{Deserialize, Serialize}; -use crate::rooms::RM_COLUMNS; use crate::ChannelResult; +use crate::rooms::RM_COLUMNS; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ClientState { diff --git a/lib/channel/rooms.rs b/lib/channel/rooms.rs index b079ac8..ac02d67 100644 --- a/lib/channel/rooms.rs +++ b/lib/channel/rooms.rs @@ -6,8 +6,7 @@ use uuid::Uuid; use crate::{ChannelBusConfig, ChannelResult}; -pub(crate) const RM_COLUMNS: &str = - "id, room, seq, thread, parent, author, content, content_type, pinned, \ +pub(crate) const RM_COLUMNS: &str = "id, room, seq, thread, parent, author, content, content_type, pinned, \ system_type, metadata, edited_at, created_at, updated_at, deleted_at"; pub(crate) fn room_socket_name(room: Uuid) -> String { @@ -24,6 +23,7 @@ pub struct RoomListItem { pub topic: Option, pub room_type: String, pub is_private: bool, + pub ai_enabled: bool, pub category: Option, pub workspace_id: Uuid, } @@ -44,8 +44,20 @@ pub async fn user_rooms_for_api( return Ok(Vec::new()); } - let rows = sqlx::query_as::<_, (Uuid, String, Option, String, bool, Option, Uuid)>( - "SELECT id, name, topic, room_type, is_private, parent, wk \ + let rows = sqlx::query_as::< + _, + ( + Uuid, + String, + Option, + String, + bool, + bool, + Option, + Uuid, + ), + >( + "SELECT id, name, topic, room_type, is_private, ai_enabled, parent, wk \ FROM room \ WHERE id = ANY($1) AND deleted_at IS NULL AND is_archived = false \ ORDER BY name", @@ -56,15 +68,27 @@ pub async fn user_rooms_for_api( Ok(rows .into_iter() - .map(|(id, name, topic, room_type, is_private, category, workspace_id)| RoomListItem { - id, - name, - topic, - room_type, - is_private, - category, - workspace_id, - }) + .map( + |( + id, + name, + topic, + room_type, + is_private, + ai_enabled, + category, + workspace_id, + )| RoomListItem { + id, + name, + topic, + room_type, + is_private, + ai_enabled, + category, + workspace_id, + }, + ) .collect()) } pub async fn user_categories_for_api( @@ -179,14 +203,14 @@ pub(crate) async fn catchup_messages( room: Uuid, after_seq: i64, ) -> ChannelResult> { - let rows = sqlx::query_as::<_, RoomMessageModel>( - db::sqlx::AssertSqlSafe(format!( + let rows = sqlx::query_as::<_, RoomMessageModel>(db::sqlx::AssertSqlSafe( + format!( "SELECT {RM_COLUMNS} FROM room_message \ WHERE room = $1 AND seq > $2 AND deleted_at IS NULL \ ORDER BY seq ASC \ LIMIT $3" - )), - ) + ), + )) .bind(room) .bind(after_seq) .bind(config.catchup_limit) diff --git a/lib/channel/security.rs b/lib/channel/security.rs index 7c2e297..97a59eb 100644 --- a/lib/channel/security.rs +++ b/lib/channel/security.rs @@ -142,7 +142,7 @@ return 0 } } -pub(crate) fn require_cluster( +pub fn require_cluster( cache: &cache::AppCache, ) -> ChannelResult<&cache::ClusterCache> { cache diff --git a/lib/channel/seq.rs b/lib/channel/seq.rs index 129b417..0526ae9 100644 --- a/lib/channel/seq.rs +++ b/lib/channel/seq.rs @@ -128,7 +128,12 @@ impl SeqAllocator { } if state .next - .compare_exchange_weak(current, current + 1, Ordering::AcqRel, Ordering::Acquire) + .compare_exchange_weak( + current, + current + 1, + Ordering::AcqRel, + Ordering::Acquire, + ) .is_ok() { return Some(current); diff --git a/lib/channel/token.rs b/lib/channel/token.rs index 1ad4240..83b5cc3 100644 --- a/lib/channel/token.rs +++ b/lib/channel/token.rs @@ -184,7 +184,9 @@ impl ChannelBus { .arg(&session_key) .query_async(&mut conn) .await - .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; + .map_err(|e| { + ChannelError::Cache(cache::CacheError::Redis(e)) + })?; let device_id = hash_data .get("device_id") @@ -252,9 +254,8 @@ impl ChannelBus { created_at, }; let new_token_bytes = new_payload.encode(&signing_key)?; - let new_access_token = - base64::engine::general_purpose::URL_SAFE_NO_PAD - .encode(&new_token_bytes); + let new_access_token = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(&new_token_bytes); let new_session_key = self.session_hash_key(&payload.user_id, created_at); diff --git a/lib/config/cache.rs b/lib/config/cache.rs index eee340a..ab79acb 100644 --- a/lib/config/cache.rs +++ b/lib/config/cache.rs @@ -44,11 +44,7 @@ impl AppConfig { self.parse_env("APP_CACHE_CLUSTER_WRITE_THROUGH", true) } - pub(crate) fn parse_env( - &self, - key: &str, - default: T, - ) -> anyhow::Result + pub fn parse_env(&self, key: &str, default: T) -> anyhow::Result where T: std::str::FromStr, T::Err: std::error::Error + Send + Sync + 'static, @@ -59,7 +55,7 @@ impl AppConfig { } } - pub(crate) fn parse_duration_secs( + pub fn parse_duration_secs( &self, key: &str, default_secs: u64, @@ -67,7 +63,7 @@ impl AppConfig { Ok(Duration::from_secs(self.parse_env(key, default_secs)?)) } - pub(crate) fn parse_optional_duration_secs( + pub fn parse_optional_duration_secs( &self, key: &str, default_secs: Option, diff --git a/lib/config/channel.rs b/lib/config/channel.rs new file mode 100644 index 0000000..5aa283a --- /dev/null +++ b/lib/config/channel.rs @@ -0,0 +1,70 @@ +use crate::AppConfig; + +impl AppConfig { + pub fn channel_ai_basic_url(&self) -> anyhow::Result { + if let Some(url) = self.env.get("APP_CHANNEL_AI_BASIC_URL") { + return Ok(url.to_string()); + } + Err(anyhow::anyhow!("APP_CHANNEL_AI_BASIC_URL not found")) + } + + pub fn channel_ai_api_key(&self) -> anyhow::Result { + if let Some(api_key) = self.env.get("APP_CHANNEL_AI_API_KEY") { + return Ok(api_key.to_string()); + } + Err(anyhow::anyhow!("APP_CHANNEL_AI_API_KEY not found")) + } + + pub fn channel_ai_model(&self) -> anyhow::Result { + if let Some(model) = self.env.get("APP_CHANNEL_AI_MODEL") { + return Ok(model.to_string()); + } + Err(anyhow::anyhow!("APP_CHANNEL_AI_MODEL not found")) + } + + pub fn channel_ai_context_length(&self) -> anyhow::Result { + if let Some(v) = self.env.get("APP_CHANNEL_AI_CONTEXT_LENGTH") { + return Ok(v.parse::()?); + } + Err(anyhow::anyhow!("APP_CHANNEL_AI_CONTEXT_LENGTH not found")) + } + + pub fn channel_ai_max_tokens(&self) -> anyhow::Result { + if let Some(v) = self.env.get("APP_CHANNEL_AI_MAX_TOKENS") { + return Ok(v.parse::()?); + } + Err(anyhow::anyhow!("APP_CHANNEL_AI_MAX_TOKENS not found")) + } + + pub fn channel_ai_temperature(&self) -> anyhow::Result { + if let Some(v) = self.env.get("APP_CHANNEL_AI_TEMPERATURE") { + return Ok(v.parse::()?); + } + Err(anyhow::anyhow!("APP_CHANNEL_AI_TEMPERATURE not found")) + } + + pub fn channel_ai_top_p(&self) -> anyhow::Result { + if let Some(v) = self.env.get("APP_CHANNEL_AI_TOP_P") { + return Ok(v.parse::()?); + } + Err(anyhow::anyhow!("APP_CHANNEL_AI_TOP_P not found")) + } + + pub fn channel_ai_stop(&self) -> Option { + self.env.get("APP_CHANNEL_AI_STOP").cloned() + } + + pub fn channel_ai_max_retries(&self) -> anyhow::Result { + if let Some(v) = self.env.get("APP_CHANNEL_AI_MAX_RETRIES") { + return Ok(v.parse::()?); + } + Err(anyhow::anyhow!("APP_CHANNEL_AI_MAX_RETRIES not found")) + } + + pub fn channel_ai_retry_delay_secs(&self) -> anyhow::Result { + if let Some(v) = self.env.get("APP_CHANNEL_AI_RETRY_DELAY_SECS") { + return Ok(v.parse::()?); + } + Err(anyhow::anyhow!("APP_CHANNEL_AI_RETRY_DELAY_SECS not found")) + } +} diff --git a/lib/config/git.rs b/lib/config/git.rs index c6f05a5..25bba5a 100644 --- a/lib/config/git.rs +++ b/lib/config/git.rs @@ -27,7 +27,11 @@ impl AppConfig { return Ok(root.to_string()); } let base = std::env::current_dir()?; - Ok(base.join("data").join("repos").to_string_lossy().to_string()) + Ok(base + .join("data") + .join("repos") + .to_string_lossy() + .to_string()) } pub fn gitsync_health_port(&self) -> u16 { diff --git a/lib/config/lib.rs b/lib/config/lib.rs index 0d2598a..5426616 100644 --- a/lib/config/lib.rs +++ b/lib/config/lib.rs @@ -42,6 +42,7 @@ pub mod app; pub mod auth; pub mod avatar; pub mod cache; +pub mod channel; pub mod database; pub mod domain; pub mod embed; diff --git a/lib/db/transaction.rs b/lib/db/transaction.rs index 1f3c5a5..e234f0c 100644 --- a/lib/db/transaction.rs +++ b/lib/db/transaction.rs @@ -4,7 +4,7 @@ use sqlx::{ }; pub struct AppTransaction<'a> { - pub(crate) inner: Transaction<'a, Postgres>, + pub inner: Transaction<'a, Postgres>, } impl<'a> AppTransaction<'a> { diff --git a/lib/git/bare.rs b/lib/git/bare.rs index 98f223c..564a7c4 100644 --- a/lib/git/bare.rs +++ b/lib/git/bare.rs @@ -37,7 +37,10 @@ impl GitBare { } Ok(()) } - pub fn last_commits_for_paths(&self, paths: &[String]) -> GitResult>> { + pub fn last_commits_for_paths( + &self, + paths: &[String], + ) -> GitResult>> { use gix::traverse::commit::simple::CommitTimeOrder; if paths.is_empty() { @@ -45,17 +48,24 @@ impl GitBare { } let repo = self.gix_repo()?; - let head_id = repo.head_id() - .map_err(|e| crate::errors::GitError::Internal(format!("no HEAD: {e}")))?; + let head_id = repo.head_id().map_err(|e| { + crate::errors::GitError::Internal(format!("no HEAD: {e}")) + })?; - let walk = repo.rev_walk(vec![head_id.detach()]) - .sorting(gix::revision::walk::Sorting::ByCommitTime(CommitTimeOrder::NewestFirst)) + let walk = repo + .rev_walk(vec![head_id.detach()]) + .sorting(gix::revision::walk::Sorting::ByCommitTime( + CommitTimeOrder::NewestFirst, + )) .first_parent_only() .all() - .map_err(|e| crate::errors::GitError::Internal(format!("rev_walk: {e}")))?; + .map_err(|e| { + crate::errors::GitError::Internal(format!("rev_walk: {e}")) + })?; let mut result: Vec> = vec![None; paths.len()]; - let mut remaining: std::collections::HashSet = (0..paths.len()).collect(); + let mut remaining: std::collections::HashSet = + (0..paths.len()).collect(); for walk_item in walk { if remaining.is_empty() { @@ -99,24 +109,29 @@ impl GitBare { Err(_) => continue, }; - let changed: std::collections::HashSet = changes.iter() - .map(|c| c.location().to_string()) - .collect(); + let changed: std::collections::HashSet = + changes.iter().map(|c| c.location().to_string()).collect(); let is_root = parent_tree.is_none(); let author_sig = match decoded.author() { Ok(s) => s, Err(_) => continue, }; - let time = author_sig.time().unwrap_or(gix::date::Time { seconds: 0, offset: 0 }); - let msg = commit.message_raw() + let time = author_sig.time().unwrap_or(gix::date::Time { + seconds: 0, + offset: 0, + }); + let msg = commit + .message_raw() .map(|r| r.to_string().trim_end_matches('\n').to_string()) .unwrap_or_default(); let summary = msg.lines().next().unwrap_or("").to_string(); - let matched: Vec = remaining.iter().copied().filter(|&idx| { - is_root || changed.contains(&paths[idx]) - }).collect(); + let matched: Vec = remaining + .iter() + .copied() + .filter(|&idx| is_root || changed.contains(&paths[idx])) + .collect(); for idx in matched { result[idx] = Some(LastCommitInfo { diff --git a/lib/git/cmd/commit/commit_create.rs b/lib/git/cmd/commit/commit_create.rs index 8fc7bde..8fdcb14 100644 --- a/lib/git/cmd/commit/commit_create.rs +++ b/lib/git/cmd/commit/commit_create.rs @@ -25,8 +25,8 @@ pub struct CreateCommitParams { #[derive(Debug, Clone)] struct TreeEntry { - mode: String, // e.g. "100644" - kind: String, // "blob" or "tree" + mode: String, // e.g. "100644" + kind: String, // "blob" or "tree" oid: ObjectId, path: String, } @@ -42,12 +42,14 @@ impl GitBare { let repo = self.gix_repo()?; // 1. Write each file as a blob and collect entries - let mut new_entries: Vec = Vec::with_capacity(params.files.len()); + let mut new_entries: Vec = + Vec::with_capacity(params.files.len()); for fc in ¶ms.files { - let blob_upload_result = self.blob_upload(crate::cmd::blob::BlobUploadParams { - blob: fc.content.clone(), - path: fc.path.clone(), - })?; + let blob_upload_result = + self.blob_upload(crate::cmd::blob::BlobUploadParams { + blob: fc.content.clone(), + path: fc.path.clone(), + })?; new_entries.push(TreeEntry { mode: "100644".to_string(), kind: "blob".to_string(), @@ -64,7 +66,8 @@ impl GitBare { match repo.find_reference(&branch_ref) { Ok(r) => { - let oid = r.into_fully_peeled_id() + let oid = r + .into_fully_peeled_id() .map_err(|e| GitError::Gix(e.to_string()))? .detach(); let oid_str = oid.to_hex().to_string(); @@ -106,10 +109,8 @@ impl GitBare { .as_secs() as i64; let timestamp = crate::cmd::parse::format_git_timestamp(now, 0); - let mut commit_tree_args = vec![ - "commit-tree".to_string(), - tree_oid.as_str().to_string(), - ]; + let mut commit_tree_args = + vec!["commit-tree".to_string(), tree_oid.as_str().to_string()]; if let Some(parent) = &parent_oid { commit_tree_args.push("-p".to_string()); commit_tree_args.push(parent.as_str().to_string()); @@ -120,11 +121,23 @@ impl GitBare { let commit_output = self.git_command_with( GitCommandParams::new(commit_tree_args) .with_stdin(params.message.as_bytes().to_vec()) - .with_env("GIT_AUTHOR_NAME".to_string(), params.author_name.clone()) - .with_env("GIT_AUTHOR_EMAIL".to_string(), params.author_email.clone()) + .with_env( + "GIT_AUTHOR_NAME".to_string(), + params.author_name.clone(), + ) + .with_env( + "GIT_AUTHOR_EMAIL".to_string(), + params.author_email.clone(), + ) .with_env("GIT_AUTHOR_DATE".to_string(), timestamp.clone()) - .with_env("GIT_COMMITTER_NAME".to_string(), params.committer_name.clone()) - .with_env("GIT_COMMITTER_EMAIL".to_string(), params.committer_email.clone()) + .with_env( + "GIT_COMMITTER_NAME".to_string(), + params.committer_name.clone(), + ) + .with_env( + "GIT_COMMITTER_EMAIL".to_string(), + params.committer_email.clone(), + ) .with_env("GIT_COMMITTER_DATE".to_string(), timestamp), )?; @@ -209,15 +222,16 @@ impl GitBare { for entry in entries { input.push_str(&format!( "{} {} {}\t{}\n", - entry.mode, entry.kind, entry.oid.as_str(), entry.path + entry.mode, + entry.kind, + entry.oid.as_str(), + entry.path )); } let output = self.git_command_with( - GitCommandParams::new(vec![ - "mktree".to_string(), - ]) - .with_stdin(input.into_bytes()), + GitCommandParams::new(vec!["mktree".to_string()]) + .with_stdin(input.into_bytes()), )?; if !output.success { diff --git a/lib/git/cmd/commit/commit_walker.rs b/lib/git/cmd/commit/commit_walker.rs index 0d7629d..13dde3b 100644 --- a/lib/git/cmd/commit/commit_walker.rs +++ b/lib/git/cmd/commit/commit_walker.rs @@ -45,7 +45,9 @@ impl GitBare { params: CommitWalkParams, ) -> GitResult> { let repo = self.gix_repo()?; - let tips: Vec = if let Some(ref branch_name) = params.branch { + let tips: Vec = if let Some(ref branch_name) = + params.branch + { if branch_name.is_empty() { vec![repo.head_id()?.detach()] } else { @@ -54,7 +56,9 @@ impl GitBare { Some(reference) => { let target = reference.target(); let target_id = target.try_id().ok_or_else(|| { - GitError::Internal(format!("branch '{branch_name}' has no direct target")) + GitError::Internal(format!( + "branch '{branch_name}' has no direct target" + )) })?; vec![target_id.to_owned()] } @@ -69,7 +73,9 @@ impl GitBare { vec![target_id.to_owned()] } None => { - return Err(GitError::RefNotFound(branch_name.clone())); + return Err(GitError::RefNotFound( + branch_name.clone(), + )); } } } diff --git a/lib/git/rpc/blame.rs b/lib/git/rpc/blame.rs index 17e9004..e537405 100644 --- a/lib/git/rpc/blame.rs +++ b/lib/git/rpc/blame.rs @@ -29,9 +29,14 @@ impl p::blame_service_server::BlameService for BlameServiceImpl { let repo_id = inner.repo_id.clone(); let rev_str = inner.rev.unwrap_or_default(); let path = inner.path.clone(); - let cache_key = format!("git:rpc:cache:blame:file:{}:{}:{}", repo_id, path, rev_str); + let cache_key = format!( + "git:rpc:cache:blame:file:{}:{}:{}", + repo_id, path, rev_str + ); - if let Ok(Some(cached)) = self.cache.get::(&cache_key).await { + if let Ok(Some(cached)) = + self.cache.get::(&cache_key).await + { return Ok(Response::new(cached)); } diff --git a/lib/git/rpc/blob.rs b/lib/git/rpc/blob.rs index cd40309..b213c88 100644 --- a/lib/git/rpc/blob.rs +++ b/lib/git/rpc/blob.rs @@ -27,9 +27,12 @@ impl p::blob_service_server::BlobService for BlobServiceImpl { let repo_id = inner.repo_id.clone(); let oid_str = inner.id.clone().map(|o| o.value).unwrap_or_default(); let path = inner.path.clone(); - let cache_key = format!("git:rpc:cache:blob:load:{}:{}:{}", repo_id, oid_str, path); + let cache_key = + format!("git:rpc:cache:blob:load:{}:{}:{}", repo_id, oid_str, path); - if let Ok(Some(cached)) = self.cache.get::(&cache_key).await { + if let Ok(Some(cached)) = + self.cache.get::(&cache_key).await + { return Ok(Response::new(cached)); } @@ -89,9 +92,12 @@ impl p::blob_service_server::BlobService for BlobServiceImpl { let inner = req.into_inner(); let repo_id = inner.repo_id.clone(); let oid_str = inner.id.clone().map(|o| o.value).unwrap_or_default(); - let cache_key = format!("git:rpc:cache:blob:binary:{}:{}", repo_id, oid_str); + let cache_key = + format!("git:rpc:cache:blob:binary:{}:{}", repo_id, oid_str); - if let Ok(Some(cached)) = self.cache.get::(&cache_key).await { + if let Ok(Some(cached)) = + self.cache.get::(&cache_key).await + { return Ok(Response::new(cached)); } diff --git a/lib/git/rpc/commit.rs b/lib/git/rpc/commit.rs index 83edad2..461d532 100644 --- a/lib/git/rpc/commit.rs +++ b/lib/git/rpc/commit.rs @@ -46,10 +46,16 @@ impl p::commit_service_server::CommitService for CommitServiceImpl { let repo_id = inner.repo_id.clone(); let cache_key = format!( "git:rpc:cache:commit:history:{}:{}:{}:{}:{}", - repo_id, inner.limit, inner.skip, inner.sort, inner.branch.as_deref().unwrap_or("") + repo_id, + inner.limit, + inner.skip, + inner.sort, + inner.branch.as_deref().unwrap_or("") ); - if let Ok(Some(cached)) = self.cache.get::(&cache_key).await { + if let Ok(Some(cached)) = + self.cache.get::(&cache_key).await + { return Ok(Response::new(cached)); } @@ -57,7 +63,11 @@ impl p::commit_service_server::CommitService for CommitServiceImpl { let params = crate::cmd::commit::CommitWalkParams { start_oids: vec![], hide_oids: vec![], - limit: if inner.limit > 0 { Some(inner.limit as usize) } else { None }, + limit: if inner.limit > 0 { + Some(inner.limit as usize) + } else { + None + }, skip: inner.skip as usize, first_parent: false, sort: inner.sort.into(), @@ -86,7 +96,11 @@ impl p::commit_service_server::CommitService for CommitServiceImpl { let params = crate::cmd::commit::CommitWalkParams { start_oids: vec![], hide_oids: vec![], - limit: if inner.limit > 0 { Some(inner.limit as usize) } else { None }, + limit: if inner.limit > 0 { + Some(inner.limit as usize) + } else { + None + }, skip: inner.skip as usize, first_parent: false, sort: inner.sort.into(), @@ -119,7 +133,9 @@ impl p::commit_service_server::CommitService for CommitServiceImpl { let repo_id = inner.repo_id.clone(); let cache_key = format!("git:rpc:cache:commit:summary:{}", repo_id); - if let Ok(Some(cached)) = self.cache.get::(&cache_key).await { + if let Ok(Some(cached)) = + self.cache.get::(&cache_key).await + { return Ok(Response::new(cached)); } @@ -128,7 +144,9 @@ impl p::commit_service_server::CommitService for CommitServiceImpl { .await .map_err(spawn_blocking_error)? .map_err(to_status)?; - let resp = p::CommitSummaryResponse { summary: Some(result.into()) }; + let resp = p::CommitSummaryResponse { + summary: Some(result.into()), + }; let _ = self.cache.set(&cache_key, &resp).await; Ok(Response::new(resp)) } @@ -159,7 +177,8 @@ impl p::commit_service_server::CommitService for CommitServiceImpl { let result = tokio::task::spawn_blocking(move || { let repo = bare.gix_repo()?; let head_id = repo.head_id()?.detach(); - let oid = crate::cmd::oid::ObjectId::new(head_id.to_hex().to_string()); + let oid = + crate::cmd::oid::ObjectId::new(head_id.to_hex().to_string()); bare.commit_refs(oid) }) .await @@ -219,7 +238,9 @@ impl p::commit_service_server::CommitService for CommitServiceImpl { .map_err(to_status)?; if let Ok(repo_uid) = Uuid::parse_str(&repo_id) { - self.sync.send(crate::sync::RepoReceiveSyncTask { repo_uid }).await; + self.sync + .send(crate::sync::RepoReceiveSyncTask { repo_uid }) + .await; } Ok(Response::new(p::CherryPickResponse { @@ -243,7 +264,9 @@ impl p::commit_service_server::CommitService for CommitServiceImpl { .map_err(to_status)?; if let Ok(repo_uid) = Uuid::parse_str(&repo_id) { - self.sync.send(crate::sync::RepoReceiveSyncTask { repo_uid }).await; + self.sync + .send(crate::sync::RepoReceiveSyncTask { repo_uid }) + .await; } Ok(Response::new(p::CherryPickSequenceResponse { @@ -274,14 +297,17 @@ impl p::commit_service_server::CommitService for CommitServiceImpl { }) .collect(), }; - let result = tokio::task::spawn_blocking(move || bare.commit_create(params)) - .await - .map_err(spawn_blocking_error)? - .map_err(to_status)?; + let result = + tokio::task::spawn_blocking(move || bare.commit_create(params)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; // Trigger sync after write if let Ok(repo_uid) = Uuid::parse_str(&repo_id) { - self.sync.send(crate::sync::RepoReceiveSyncTask { repo_uid }).await; + self.sync + .send(crate::sync::RepoReceiveSyncTask { repo_uid }) + .await; } Ok(Response::new(p::CreateCommitResponse { diff --git a/lib/git/rpc/diff.rs b/lib/git/rpc/diff.rs index 9141b3b..d22c92c 100644 --- a/lib/git/rpc/diff.rs +++ b/lib/git/rpc/diff.rs @@ -25,11 +25,18 @@ impl p::diff_service_server::DiffService for DiffServiceImpl { ) -> Result, Status> { let inner = req.into_inner(); let repo_id = inner.repo_id.clone(); - let old_str = inner.old_oid.clone().map(|o| o.value).unwrap_or_default(); - let new_str = inner.new_oid.clone().map(|o| o.value).unwrap_or_default(); - let cache_key = format!("git:rpc:cache:diff:stats:{}:{}:{}", repo_id, old_str, new_str); + let old_str = + inner.old_oid.clone().map(|o| o.value).unwrap_or_default(); + let new_str = + inner.new_oid.clone().map(|o| o.value).unwrap_or_default(); + let cache_key = format!( + "git:rpc:cache:diff:stats:{}:{}:{}", + repo_id, old_str, new_str + ); - if let Ok(Some(cached)) = self.cache.get::(&cache_key).await { + if let Ok(Some(cached)) = + self.cache.get::(&cache_key).await + { return Ok(Response::new(cached)); } @@ -45,8 +52,13 @@ impl p::diff_service_server::DiffService for DiffServiceImpl { .await .map_err(spawn_blocking_error)? .map_err(to_status)?; - let result = crate::cmd::diff::DiffResult { stats, deltas: vec![] }; - let resp = p::DiffStatsResponse { result: Some(result.into()) }; + let result = crate::cmd::diff::DiffResult { + stats, + deltas: vec![], + }; + let resp = p::DiffStatsResponse { + result: Some(result.into()), + }; let _ = self.cache.set(&cache_key, &resp).await; Ok(Response::new(resp)) } @@ -57,11 +69,18 @@ impl p::diff_service_server::DiffService for DiffServiceImpl { ) -> Result, Status> { let inner = req.into_inner(); let repo_id = inner.repo_id.clone(); - let old_str = inner.old_oid.clone().map(|o| o.value).unwrap_or_default(); - let new_str = inner.new_oid.clone().map(|o| o.value).unwrap_or_default(); - let cache_key = format!("git:rpc:cache:diff:patch:{}:{}:{}", repo_id, old_str, new_str); + let old_str = + inner.old_oid.clone().map(|o| o.value).unwrap_or_default(); + let new_str = + inner.new_oid.clone().map(|o| o.value).unwrap_or_default(); + let cache_key = format!( + "git:rpc:cache:diff:patch:{}:{}:{}", + repo_id, old_str, new_str + ); - if let Ok(Some(cached)) = self.cache.get::(&cache_key).await { + if let Ok(Some(cached)) = + self.cache.get::(&cache_key).await + { return Ok(Response::new(cached)); } @@ -77,7 +96,9 @@ impl p::diff_service_server::DiffService for DiffServiceImpl { .await .map_err(spawn_blocking_error)? .map_err(to_status)?; - let resp = p::DiffPatchResponse { result: Some(result.into()) }; + let resp = p::DiffPatchResponse { + result: Some(result.into()), + }; let _ = self.cache.set(&cache_key, &resp).await; Ok(Response::new(resp)) } diff --git a/lib/git/rpc/init.rs b/lib/git/rpc/init.rs index ac8b7a0..4df7e82 100644 --- a/lib/git/rpc/init.rs +++ b/lib/git/rpc/init.rs @@ -56,10 +56,12 @@ impl p::init_service_server::InitService for InitServiceImpl { let inner = req.into_inner(); let bare = self.registry.get(&inner.repo_id).await?; let branch_name = inner.branch_name; - tokio::task::spawn_blocking(move || bare.set_default_branch(&branch_name)) - .await - .map_err(spawn_blocking_error)? - .map_err(to_status)?; + tokio::task::spawn_blocking(move || { + bare.set_default_branch(&branch_name) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; Ok(Response::new(p::SetDefaultBranchResponse {})) } diff --git a/lib/git/rpc/server.rs b/lib/git/rpc/server.rs index 3dcbeb9..f48a79f 100644 --- a/lib/git/rpc/server.rs +++ b/lib/git/rpc/server.rs @@ -5,7 +5,6 @@ use db::database::AppDatabase; use tonic::transport::Server; use uuid::Uuid; -use crate::sync::ReceiveSyncService; use crate::rpc::{ archive::ArchiveServiceImpl, blame::BlameServiceImpl, @@ -33,6 +32,7 @@ use crate::rpc::{ tag::TagServiceImpl, tree::TreeServiceImpl, }; +use crate::sync::ReceiveSyncService; type RepoId = Uuid; @@ -44,7 +44,12 @@ pub struct GitServer { } impl GitServer { - pub fn new(addr: SocketAddr, db: AppDatabase, cache: AppCache, sync: ReceiveSyncService) -> Self { + pub fn new( + addr: SocketAddr, + db: AppDatabase, + cache: AppCache, + sync: ReceiveSyncService, + ) -> Self { Self { addr, cache: cache.clone(), @@ -127,7 +132,12 @@ pub struct GitServerBuilder { } impl GitServerBuilder { - pub fn new(addr: SocketAddr, db: AppDatabase, cache: AppCache, sync: ReceiveSyncService) -> Self { + pub fn new( + addr: SocketAddr, + db: AppDatabase, + cache: AppCache, + sync: ReceiveSyncService, + ) -> Self { Self { addr, db, diff --git a/lib/git/rpc/tree.rs b/lib/git/rpc/tree.rs index 2ab875f..e2e0ac2 100644 --- a/lib/git/rpc/tree.rs +++ b/lib/git/rpc/tree.rs @@ -31,12 +31,18 @@ impl p::tree_service_server::TreeService for TreeServiceImpl { let entries = tokio::task::spawn_blocking(move || { use crate::errors::GitError; bare.tree_entries(oid) - .map_err(|e| GitError::Internal(format!("tree_entries: {e}"))) - .map(|e| e.into_iter().map(Into::into).collect::>()) + .map_err(|e| { + GitError::Internal(format!("tree_entries: {e}")) + }) + .map(|e| { + e.into_iter() + .map(Into::into) + .collect::>() + }) }) - .await - .map_err(spawn_blocking_error)? - .map_err(to_status)?; + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; return Ok(Response::new(p::TreeEntriesResponse { entries })); } let cache_key = format!( @@ -44,8 +50,12 @@ impl p::tree_service_server::TreeService for TreeServiceImpl { repo_id, oid_val, base_path ); - if let Ok(Some(cached)) = self.cache.get::>(&cache_key).await { - return Ok(Response::new(p::TreeEntriesResponse { entries: cached })); + if let Ok(Some(cached)) = + self.cache.get::>(&cache_key).await + { + return Ok(Response::new(p::TreeEntriesResponse { + entries: cached, + })); } let bare = self.registry.get(&repo_id).await?; @@ -59,11 +69,13 @@ impl p::tree_service_server::TreeService for TreeServiceImpl { use crate::errors::GitError; bare.tree_entries(oid) .map_err(|e| GitError::Internal(format!("tree_entries: {e}"))) - .map(|e| e.into_iter().map(Into::into).collect::>()) + .map(|e| { + e.into_iter().map(Into::into).collect::>() + }) }) - .await - .map_err(spawn_blocking_error)? - .map_err(to_status)?; + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; let mut response = entries.clone(); for entry in &mut response { @@ -75,15 +87,19 @@ impl p::tree_service_server::TreeService for TreeServiceImpl { tokio::task::spawn(async move { let enriched = tokio::task::spawn_blocking(move || { - let paths: Vec = entries.iter().map(|e| { - if base_path_bg.is_empty() { - e.name.clone() - } else { - format!("{}/{}", base_path_bg, e.name) - } - }).collect(); + let paths: Vec = entries + .iter() + .map(|e| { + if base_path_bg.is_empty() { + e.name.clone() + } else { + format!("{}/{}", base_path_bg, e.name) + } + }) + .collect(); - let last_commits = match bare_bg.last_commits_for_paths(&paths) { + let last_commits = match bare_bg.last_commits_for_paths(&paths) + { Ok(lc) => lc, Err(_) => return entries, }; @@ -100,7 +116,8 @@ impl p::tree_service_server::TreeService for TreeServiceImpl { } } out - }).await; + }) + .await; if let Ok(entries) = enriched { let _ = cache_bg.set(&cache_key_bg, &entries).await; diff --git a/lib/git/ssh/mod.rs b/lib/git/ssh/mod.rs index 1e554f8..94b3b4f 100644 --- a/lib/git/ssh/mod.rs +++ b/lib/git/ssh/mod.rs @@ -60,8 +60,10 @@ impl SSHHandle { tracing::info!("Loading SSH private key from file: {}", key_file); - let private_key_pem = std::fs::read_to_string(&key_file) - .with_context(|| format!("Failed to read SSH private key file: {}", key_file))?; + let private_key_pem = + std::fs::read_to_string(&key_file).with_context(|| { + format!("Failed to read SSH private key file: {}", key_file) + })?; let private_key = russh::keys::decode_secret_key(&private_key_pem, None) .or_else(|e| { diff --git a/lib/git/sync/branch.rs b/lib/git/sync/branch.rs index dbf9c1b..68766c8 100644 --- a/lib/git/sync/branch.rs +++ b/lib/git/sync/branch.rs @@ -13,22 +13,29 @@ pub struct BranchTip { } pub fn collect_branch_tips(bare: &GitBare) -> Result, GitError> { let repo = bare.gix_repo()?; - let refs = repo.references() - .map_err(|e| GitError::Internal(format!("failed to open references: {}", e)))?; - let iter = refs.all() - .map_err(|e| GitError::Internal(format!("failed to iterate refs: {}", e)))?; + let refs = repo.references().map_err(|e| { + GitError::Internal(format!("failed to open references: {}", e)) + })?; + let iter = refs.all().map_err(|e| { + GitError::Internal(format!("failed to iterate refs: {}", e)) + })?; let mut branches = Vec::new(); for ref_result in iter { - let reference = ref_result - .map_err(|e| GitError::Internal(format!("ref iteration error: {}", e)))?; + let reference = ref_result.map_err(|e| { + GitError::Internal(format!("ref iteration error: {}", e)) + })?; let full_name = reference.name().as_bstr().to_string(); if !full_name.starts_with("refs/heads/") { continue; } - let target_oid = reference.target().try_id() + let target_oid = reference + .target() + .try_id() .map(|id| id.to_hex().to_string()) - .ok_or_else(|| GitError::Internal("ref has no direct target".to_string()))?; + .ok_or_else(|| { + GitError::Internal("ref has no direct target".to_string()) + })?; let shorthand = reference.name().shorten().to_string(); branches.push(BranchTip { name: full_name, diff --git a/lib/git/sync/commit.rs b/lib/git/sync/commit.rs index fde549f..b476101 100644 --- a/lib/git/sync/commit.rs +++ b/lib/git/sync/commit.rs @@ -17,32 +17,42 @@ pub async fn sync_commits( let existing_oids: Vec = sqlx::query_scalar::<_, String>( "SELECT sha FROM repo_commit WHERE repo = $1", ) - .bind(repo_id) - .fetch_all(pool) - .await - .map_err(|e| GitError::Internal(format!("failed to query commits: {}", e)))?; + .bind(repo_id) + .fetch_all(pool) + .await + .map_err(|e| { + GitError::Internal(format!("failed to query commits: {}", e)) + })?; let existing_set: HashSet = existing_oids.into_iter().collect(); - let head_id = repo.head_id() - .map_err(|e| GitError::Internal(format!("failed to resolve HEAD: {}", e)))? + let head_id = repo + .head_id() + .map_err(|e| { + GitError::Internal(format!("failed to resolve HEAD: {}", e)) + })? .detach(); let tips = { - let refs = repo.references() - .map_err(|e| GitError::Internal(format!("failed to open references: {}", e)))?; - let iter = refs.all() - .map_err(|e| GitError::Internal(format!("failed to iterate refs: {}", e)))?; + let refs = repo.references().map_err(|e| { + GitError::Internal(format!("failed to open references: {}", e)) + })?; + let iter = refs.all().map_err(|e| { + GitError::Internal(format!("failed to iterate refs: {}", e)) + })?; let mut tips = vec![head_id]; for ref_result in iter { - let reference = ref_result - .map_err(|e| GitError::Internal(format!("ref iteration error: {}", e)))?; + let reference = ref_result.map_err(|e| { + GitError::Internal(format!("ref iteration error: {}", e)) + })?; let name = reference.name().as_bstr().to_string(); if !name.starts_with("refs/heads/") { continue; } if let Some(target_id) = reference.target().try_id() { let hex = target_id.to_hex().to_string(); - if let Ok(gix_id) = gix::hash::ObjectId::from_hex(hex.as_bytes()) { + if let Ok(gix_id) = + gix::hash::ObjectId::from_hex(hex.as_bytes()) + { tips.push(gix_id); } } @@ -50,16 +60,20 @@ pub async fn sync_commits( tips }; - let platform = repo.rev_walk(tips) - .sorting(gix::revision::walk::Sorting::ByCommitTime( + let platform = repo.rev_walk(tips).sorting( + gix::revision::walk::Sorting::ByCommitTime( gix::traverse::commit::simple::CommitTimeOrder::NewestFirst, - )); - let walk = platform.all() + ), + ); + let walk = platform + .all() .map_err(|e| GitError::Internal(format!("rev_walk failed: {}", e)))?; let mut new_commits: Vec = Vec::new(); for info in walk { - let info = info.map_err(|e| GitError::Internal(format!("walk step error: {}", e)))?; + let info = info.map_err(|e| { + GitError::Internal(format!("walk step error: {}", e)) + })?; let hex = info.id().detach().to_hex().to_string(); if !existing_set.contains(&hex) { new_commits.push(info.id().detach()); @@ -80,7 +94,8 @@ pub async fn sync_commits( .map_err(|e| GitError::Internal(format!("failed to query repo_committer: {}", e)))?; for model in &existing_committers { - committer_map.insert(model.email.clone(), (model.id, model.name.clone())); + committer_map + .insert(model.email.clone(), (model.id, model.name.clone())); } let email_map = resolve_user_ids(db, &committer_map).await?; @@ -90,25 +105,48 @@ pub async fn sync_commits( for gix_id in &new_commits { let hex_oid = gix_id.to_hex().to_string(); let oid = ObjectId::new(&hex_oid); - let commit_meta = bare.commit_info(oid) - .map_err(|e| GitError::Internal(format!("commit_info failed for {}: {}", hex_oid, e)))?; + let commit_meta = bare.commit_info(oid).map_err(|e| { + GitError::Internal(format!( + "commit_info failed for {}: {}", + hex_oid, e + )) + })?; let author_committer_id = ensure_committer( - &mut committer_map, pool, repo_id, &commit_meta.author.email, - &commit_meta.author.name, &email_map, now, - ).await?; + &mut committer_map, + pool, + repo_id, + &commit_meta.author.email, + &commit_meta.author.name, + &email_map, + now, + ) + .await?; let committer_committer_id = ensure_committer( - &mut committer_map, pool, repo_id, &commit_meta.committer.email, - &commit_meta.committer.name, &email_map, now, - ).await?; + &mut committer_map, + pool, + repo_id, + &commit_meta.committer.email, + &commit_meta.committer.name, + &email_map, + now, + ) + .await?; - let parent_shas = commit_meta.parent_ids + let parent_shas = commit_meta + .parent_ids .iter() .map(|p| p.as_str()) .collect::>() .join("."); - let authored_at = git_time_to_datetime(commit_meta.author.time_secs, commit_meta.author.offset_minutes); - let committed_at = git_time_to_datetime(commit_meta.committer.time_secs, commit_meta.committer.offset_minutes); + let authored_at = git_time_to_datetime( + commit_meta.author.time_secs, + commit_meta.author.offset_minutes, + ); + let committed_at = git_time_to_datetime( + commit_meta.committer.time_secs, + commit_meta.committer.offset_minutes, + ); let new_id = Uuid::new_v4(); sqlx::query( diff --git a/lib/git/sync/consumer.rs b/lib/git/sync/consumer.rs index 66dd820..6b95373 100644 --- a/lib/git/sync/consumer.rs +++ b/lib/git/sync/consumer.rs @@ -75,10 +75,7 @@ impl SyncConsumer { if result > 0 { Some(()) } else { None } } - pub(crate) fn queue_key_for_task_type( - &self, - task_type: &TaskType, - ) -> String { + pub fn queue_key_for_task_type(&self, task_type: &TaskType) -> String { let prefix = &self.service.redis_prefix; match task_type { TaskType::Sync => format!("{prefix}:sync"), diff --git a/lib/git/sync/language.rs b/lib/git/sync/language.rs index f82e73d..1bf5a5b 100644 --- a/lib/git/sync/language.rs +++ b/lib/git/sync/language.rs @@ -66,14 +66,19 @@ fn language_from_filename(name: &str) -> Option<&str> { _ => None, } } -fn collect_language_stats(bare: &GitBare) -> Result, GitError> { +fn collect_language_stats( + bare: &GitBare, +) -> Result, GitError> { let repo = bare.gix_repo()?; - let head_id = repo.head_id() - .map_err(|e| GitError::Internal(format!("failed to resolve HEAD: {}", e)))?; - let commit = repo.find_commit(head_id.detach()) - .map_err(|e| GitError::Internal(format!("failed to find HEAD commit: {}", e)))?; - let decoded = commit.decode() - .map_err(|e| GitError::Internal(format!("failed to decode commit: {}", e)))?; + let head_id = repo.head_id().map_err(|e| { + GitError::Internal(format!("failed to resolve HEAD: {}", e)) + })?; + let commit = repo.find_commit(head_id.detach()).map_err(|e| { + GitError::Internal(format!("failed to find HEAD commit: {}", e)) + })?; + let decoded = commit.decode().map_err(|e| { + GitError::Internal(format!("failed to decode commit: {}", e)) + })?; let tree_oid = ObjectId::new(decoded.tree().to_hex().to_string()); let mut stats: HashMap = HashMap::new(); @@ -99,11 +104,10 @@ fn walk_tree( continue; } - let language = language_from_filename(&entry.name) - .or_else(|| { - let ext = entry.name.rsplit('.').next().unwrap_or(""); - language_from_extension(ext) - }); + let language = language_from_filename(&entry.name).or_else(|| { + let ext = entry.name.rsplit('.').next().unwrap_or(""); + language_from_extension(ext) + }); if let Some(lang) = language { let size = blob_size(bare, &entry.oid)?; @@ -114,10 +118,12 @@ fn walk_tree( } fn blob_size(bare: &GitBare, oid: &ObjectId) -> Result { let repo = bare.gix_repo()?; - let gix_id: gix::hash::ObjectId = oid.try_into() + let gix_id: gix::hash::ObjectId = oid + .try_into() .map_err(|e| GitError::Internal(format!("invalid oid: {}", e)))?; - let header = repo.find_header(gix_id) - .map_err(|e| GitError::Internal(format!("blob header not found: {}", e)))?; + let header = repo.find_header(gix_id).map_err(|e| { + GitError::Internal(format!("blob header not found: {}", e)) + })?; Ok(header.size() as u64) } pub async fn sync_languages( @@ -133,15 +139,17 @@ pub async fn sync_languages( let total_bytes: u64 = stats.values().sum(); let pool = db.writer(); - let mut tx = pool.begin() - .await - .map_err(|e| GitError::Internal(format!("failed to begin tx: {}", e)))?; + let mut tx = pool.begin().await.map_err(|e| { + GitError::Internal(format!("failed to begin tx: {}", e)) + })?; sqlx::query("DELETE FROM repo_language WHERE repo = $1") .bind(repo_id) .execute(&mut *tx) .await - .map_err(|e| GitError::Internal(format!("failed to delete repo_language: {}", e)))?; + .map_err(|e| { + GitError::Internal(format!("failed to delete repo_language: {}", e)) + })?; for (language, bytes) in &stats { let percentage = if total_bytes > 0 { @@ -161,9 +169,9 @@ pub async fn sync_languages( .map_err(|e| GitError::Internal(format!("failed to insert repo_language: {}", e)))?; } - tx.commit() - .await - .map_err(|e| GitError::Internal(format!("failed to commit tx: {}", e)))?; + tx.commit().await.map_err(|e| { + GitError::Internal(format!("failed to commit tx: {}", e)) + })?; tracing::info!( repo_id = %repo_id, diff --git a/lib/git/sync/mod.rs b/lib/git/sync/mod.rs index 5906924..5b8e23c 100644 --- a/lib/git/sync/mod.rs +++ b/lib/git/sync/mod.rs @@ -56,7 +56,7 @@ impl ReceiveSyncService { Some((work_items.len() + queued_before + 1, total)) } - pub(crate) fn push_queue_keys(repo_uid: uuid::Uuid) -> (String, String) { + pub fn push_queue_keys(repo_uid: uuid::Uuid) -> (String, String) { let hash_tag = format!("{{push:{}}}", repo_uid); ( format!("git:{}:queue", hash_tag), diff --git a/lib/git/sync/tag.rs b/lib/git/sync/tag.rs index f69ea29..aae5443 100644 --- a/lib/git/sync/tag.rs +++ b/lib/git/sync/tag.rs @@ -12,22 +12,29 @@ pub struct TagTip { } pub fn collect_tag_tips(bare: &GitBare) -> Result, GitError> { let repo = bare.gix_repo()?; - let refs = repo.references() - .map_err(|e| GitError::Internal(format!("failed to open references: {}", e)))?; - let iter = refs.all() - .map_err(|e| GitError::Internal(format!("failed to iterate refs: {}", e)))?; + let refs = repo.references().map_err(|e| { + GitError::Internal(format!("failed to open references: {}", e)) + })?; + let iter = refs.all().map_err(|e| { + GitError::Internal(format!("failed to iterate refs: {}", e)) + })?; let mut tags = Vec::new(); for ref_result in iter { - let reference = ref_result - .map_err(|e| GitError::Internal(format!("ref iteration error: {}", e)))?; + let reference = ref_result.map_err(|e| { + GitError::Internal(format!("ref iteration error: {}", e)) + })?; let full_name = reference.name().as_bstr().to_string(); if !full_name.starts_with("refs/tags/") { continue; } - let target_oid = reference.target().try_id() + let target_oid = reference + .target() + .try_id() .map(|id| id.to_hex().to_string()) - .ok_or_else(|| GitError::Internal("ref has no direct target".to_string()))?; + .ok_or_else(|| { + GitError::Internal("ref has no direct target".to_string()) + })?; let short_name = reference.name().shorten().to_string(); tags.push(TagTip { name: short_name, diff --git a/lib/git/sync/worker.rs b/lib/git/sync/worker.rs index 2eea2cc..60ed476 100644 --- a/lib/git/sync/worker.rs +++ b/lib/git/sync/worker.rs @@ -243,7 +243,8 @@ impl SyncWorker { } if let Err(e) = - crate::sync::language::sync_languages(&self.db, &bare, repo_id).await + crate::sync::language::sync_languages(&self.db, &bare, repo_id) + .await { tracing::error!(error = %e, repo_id = %repo_id, "sync_languages failed"); } diff --git a/lib/migrate/sql/room/dm_conversation_down_01.sql b/lib/migrate/sql/room/dm_conversation_down_01.sql deleted file mode 100644 index 30e712e..0000000 --- a/lib/migrate/sql/room/dm_conversation_down_01.sql +++ /dev/null @@ -1 +0,0 @@ -DROP TABLE IF EXISTS dm_conversation; diff --git a/lib/migrate/sql/room/dm_conversation_up_01.sql b/lib/migrate/sql/room/dm_conversation_up_01.sql deleted file mode 100644 index 47e3551..0000000 --- a/lib/migrate/sql/room/dm_conversation_up_01.sql +++ /dev/null @@ -1,17 +0,0 @@ --- depends_on: room -CREATE TABLE IF NOT EXISTS dm_conversation ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - room UUID NOT NULL REFERENCES room(id) ON DELETE CASCADE, - initiator UUID NOT NULL, - recipient UUID NOT NULL, - is_closed BOOLEAN NOT NULL DEFAULT FALSE, - closed_at TIMESTAMPTZ, - created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), - UNIQUE (initiator, recipient), - CHECK (initiator < recipient) -); - -CREATE INDEX IF NOT EXISTS idx_dm_conversation_room ON dm_conversation (room); -CREATE INDEX IF NOT EXISTS idx_dm_conversation_initiator ON dm_conversation (initiator); -CREATE INDEX IF NOT EXISTS idx_dm_conversation_recipient ON dm_conversation (recipient); diff --git a/lib/migrate/sql/room/room_ai_down_01.sql b/lib/migrate/sql/room/room_ai_down_01.sql deleted file mode 100644 index 2ecc847..0000000 --- a/lib/migrate/sql/room/room_ai_down_01.sql +++ /dev/null @@ -1 +0,0 @@ -DROP TABLE IF EXISTS room_ai CASCADE; diff --git a/lib/migrate/sql/room/room_ai_up_01.sql b/lib/migrate/sql/room/room_ai_up_01.sql deleted file mode 100644 index 90f2704..0000000 --- a/lib/migrate/sql/room/room_ai_up_01.sql +++ /dev/null @@ -1,10 +0,0 @@ -CREATE TABLE IF NOT EXISTS room_ai ( - room UUID NOT NULL REFERENCES room(id), - agent_session UUID NOT NULL, - enabled BOOLEAN NOT NULL DEFAULT FALSE, - auto_reply BOOLEAN NOT NULL DEFAULT FALSE, - created_by UUID NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), - PRIMARY KEY (room, agent_session) -); diff --git a/lib/migrate/sql/room/room_down_02.sql b/lib/migrate/sql/room/room_down_02.sql new file mode 100644 index 0000000..c72bc6a --- /dev/null +++ b/lib/migrate/sql/room/room_down_02.sql @@ -0,0 +1 @@ +ALTER TABLE room DROP COLUMN IF EXISTS ai_enabled; diff --git a/lib/migrate/sql/room/room_up_02.sql b/lib/migrate/sql/room/room_up_02.sql new file mode 100644 index 0000000..9a28030 --- /dev/null +++ b/lib/migrate/sql/room/room_up_02.sql @@ -0,0 +1 @@ +ALTER TABLE room ADD COLUMN IF NOT EXISTS ai_enabled BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/lib/migrate/src/main.rs b/lib/migrate/src/main.rs index ad2867b..82dca78 100644 --- a/lib/migrate/src/main.rs +++ b/lib/migrate/src/main.rs @@ -35,8 +35,12 @@ struct Migration { impl Ord for Migration { fn cmp(&self, other: &Self) -> std::cmp::Ordering { - (&self.domain, &self.table, self.version, &self.direction) - .cmp(&(&other.domain, &other.table, other.version, &other.direction)) + (&self.domain, &self.table, self.version, &self.direction).cmp(&( + &other.domain, + &other.table, + other.version, + &other.direction, + )) } } @@ -222,13 +226,12 @@ fn parse_depends_on(content: &str) -> Vec { .lines() .filter_map(|line| { let line = line.trim(); - line.strip_prefix("-- depends_on:") - .map(|deps| { - deps.split(',') - .map(|d| d.trim().to_string()) - .filter(|d| !d.is_empty()) - .collect::>() - }) + line.strip_prefix("-- depends_on:").map(|deps| { + deps.split(',') + .map(|d| d.trim().to_string()) + .filter(|d| !d.is_empty()) + .collect::>() + }) }) .flatten() .collect() @@ -254,9 +257,8 @@ fn topo_sort(migrations: &mut [Migration]) -> Result<()> { } } - let mut queue: VecDeque = (0..n) - .filter(|&i| in_degree[i] == 0) - .collect(); + let mut queue: VecDeque = + (0..n).filter(|&i| in_degree[i] == 0).collect(); let mut order = Vec::with_capacity(n); while let Some(i) = queue.pop_front() { diff --git a/lib/model/repos/mod.rs b/lib/model/repos/mod.rs index f142237..d54992f 100644 --- a/lib/model/repos/mod.rs +++ b/lib/model/repos/mod.rs @@ -25,6 +25,7 @@ pub mod repo_webhook_delivery; pub use repo::RepoModel; pub use repo_audit_log::RepoAuditLogModel; pub use repo_commit::RepoCommitModel; +pub use repo_commit_status::RepoCommitStatusModel; pub use repo_committer::RepoCommitterModel; pub use repo_deploy_key::RepoDeployKeyModel; pub use repo_fork::RepoForkModel; @@ -38,7 +39,6 @@ pub use repo_protect::RepoProtectModel; pub use repo_ref::RepoRefModel; pub use repo_release::RepoReleaseModel; pub use repo_release_asset::RepoReleaseAssetModel; -pub use repo_commit_status::RepoCommitStatusModel; pub use repo_secret::RepoSecretModel; pub use repo_star::RepoStarModel; pub use repo_topic::RepoTopicModel; diff --git a/lib/model/room/dm_conversation.rs b/lib/model/room/dm_conversation.rs deleted file mode 100644 index 2e37242..0000000 --- a/lib/model/room/dm_conversation.rs +++ /dev/null @@ -1,15 +0,0 @@ -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; -use sqlx::FromRow; -use uuid::Uuid; -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] -pub struct DmConversationModel { - pub id: Uuid, - pub room: Uuid, - pub initiator: Uuid, - pub recipient: Uuid, - pub is_closed: bool, - pub closed_at: Option>, - pub created_at: DateTime, - pub updated_at: DateTime, -} diff --git a/lib/model/room/mod.rs b/lib/model/room/mod.rs index 2b511c8..6312b96 100644 --- a/lib/model/room/mod.rs +++ b/lib/model/room/mod.rs @@ -1,7 +1,9 @@ +pub mod message_read; +pub mod message_star; pub mod room; -pub mod room_ai; pub mod room_attachments; pub mod room_categories; +pub mod room_mention; pub mod room_message; pub mod room_message_edit_history; pub mod room_permission_overwrite; @@ -9,14 +11,11 @@ pub mod room_pins; pub mod room_reactions; pub mod room_server_label; pub mod room_threads; -pub mod room_mention; pub mod user_room_state; -pub mod dm_conversation; -pub mod message_read; -pub mod message_star; +pub use message_read::MessageReadModel; +pub use message_star::MessageStarModel; pub use room::RoomModel; -pub use room_ai::RoomAiModel; pub use room_attachments::RoomAttachmentModel; pub use room_categories::RoomCategoryModel; pub use room_mention::RoomMentionModel; @@ -28,6 +27,3 @@ pub use room_reactions::RoomReactionModel; pub use room_server_label::RoomServerLabelModel; pub use room_threads::RoomThreadModel; pub use user_room_state::UserRoomStateModel; -pub use dm_conversation::DmConversationModel; -pub use message_read::MessageReadModel; -pub use message_star::MessageStarModel; diff --git a/lib/model/room/room.rs b/lib/model/room/room.rs index 8d629b5..17e62ad 100644 --- a/lib/model/room/room.rs +++ b/lib/model/room/room.rs @@ -14,6 +14,7 @@ pub struct RoomModel { pub position: i32, pub is_private: bool, pub is_archived: bool, + pub ai_enabled: bool, pub created_by: Uuid, pub created_at: DateTime, pub updated_at: DateTime, diff --git a/lib/model/room/room_ai.rs b/lib/model/room/room_ai.rs deleted file mode 100644 index 69eea18..0000000 --- a/lib/model/room/room_ai.rs +++ /dev/null @@ -1,15 +0,0 @@ -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; -use sqlx::FromRow; -use uuid::Uuid; - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] -pub struct RoomAiModel { - pub room: Uuid, - pub agent_session: Uuid, - pub enabled: bool, - pub auto_reply: bool, - pub created_by: Uuid, - pub created_at: DateTime, - pub updated_at: DateTime, -} diff --git a/lib/queue/producer.rs b/lib/queue/producer.rs index 6e770b9..1d65d27 100644 --- a/lib/queue/producer.rs +++ b/lib/queue/producer.rs @@ -60,7 +60,7 @@ impl NatsProducer { } } -pub(crate) async fn connect_jetstream( +pub async fn connect_jetstream( config: &AppConfig, ) -> anyhow::Result { let client = match config.nats_token() { @@ -75,7 +75,7 @@ pub(crate) async fn connect_jetstream( Ok(jetstream::new(client)) } -pub(crate) async fn ensure_stream( +pub async fn ensure_stream( config: &AppConfig, jetstream: &jetstream::Context, ) -> anyhow::Result { diff --git a/lib/service/agent/billing.rs b/lib/service/agent/billing.rs index cb7fc2c..3a3b30d 100644 --- a/lib/service/agent/billing.rs +++ b/lib/service/agent/billing.rs @@ -4,11 +4,11 @@ use rust_decimal::Decimal as RustDecimal; use uuid::Uuid; use super::types::{BillingRecord, BillingTarget, SessionContext}; -use crate::error::AppError; use crate::AppService; +use crate::error::AppError; impl AppService { - pub(crate) async fn agent_calculate_cost( + pub async fn agent_calculate_cost( &self, model_version_id: Uuid, input_tokens: i64, @@ -35,27 +35,32 @@ impl AppService { + (output_decimal * output_price / million); Ok(Some((cost, "USD".to_string()))) } - pub(crate) async fn agent_deduct_billing( + pub async fn agent_deduct_billing( &self, ctx: &SessionContext, cost: RustDecimal, ) -> Result<(), AppError> { match ctx.billing_target { BillingTarget::User => { - let user_id = ctx - .user_id - .ok_or_else(|| AppError::BadRequest("user billing target requires user_id".to_string()))?; + let user_id = ctx.user_id.ok_or_else(|| { + AppError::BadRequest( + "user billing target requires user_id".to_string(), + ) + })?; self.deduct_user_balance(user_id, cost).await } BillingTarget::Workspace => { - let wk_id = ctx - .workspace_id - .ok_or_else(|| AppError::BadRequest("workspace billing target requires workspace_id".to_string()))?; + let wk_id = ctx.workspace_id.ok_or_else(|| { + AppError::BadRequest( + "workspace billing target requires workspace_id" + .to_string(), + ) + })?; self.deduct_workspace_balance(wk_id, cost).await } } } - pub(crate) async fn agent_record_usage( + pub async fn agent_record_usage( &self, record: &BillingRecord, ) -> Result<(), AppError> { @@ -88,7 +93,7 @@ impl AppService { Ok(()) } - pub(crate) async fn agent_record_invocation( + pub async fn agent_record_invocation( &self, invocation_id: Uuid, session_id: Uuid, @@ -120,7 +125,7 @@ impl AppService { Ok(()) } - pub(crate) async fn agent_record_tool_call( + pub async fn agent_record_tool_call( &self, invocation_id: Uuid, session_id: Uuid, @@ -175,7 +180,10 @@ impl AppService { Ok(()) => return Ok(()), Err(AppError::TxnError) if attempt < MAX_RETRIES - 1 => { let backoff_ms = 10 * (1 << attempt); - tokio::time::sleep(tokio::time::Duration::from_millis(backoff_ms)).await; + tokio::time::sleep(tokio::time::Duration::from_millis( + backoff_ms, + )) + .await; continue; } Err(e) => return Err(e), @@ -250,7 +258,10 @@ impl AppService { Ok(()) => return Ok(()), Err(AppError::TxnError) if attempt < MAX_RETRIES - 1 => { let backoff_ms = 10 * (1 << attempt); - tokio::time::sleep(tokio::time::Duration::from_millis(backoff_ms)).await; + tokio::time::sleep(tokio::time::Duration::from_millis( + backoff_ms, + )) + .await; continue; } Err(e) => return Err(e), diff --git a/lib/service/agent/compaction.rs b/lib/service/agent/compaction.rs index b63bb00..8ca9112 100644 --- a/lib/service/agent/compaction.rs +++ b/lib/service/agent/compaction.rs @@ -1,13 +1,13 @@ use ai::agent::AgentConfig; use ai::agent::RigAgent; -use ai::client::AiClient; use ai::agent::request::AgentRequest; +use ai::client::AiClient; use db::sqlx; use tracing::{info, warn}; use uuid::Uuid; -use crate::error::AppError; use crate::AppService; +use crate::error::AppError; const COMPACTION_SYSTEM_PROMPT: &str = r#"You are a conversation context compaction assistant. @@ -25,7 +25,7 @@ const COMPACTION_TRIGGER_CHARS: usize = 80_000; const RECENT_MESSAGES_TO_KEEP: usize = 10; impl AppService { - pub(crate) async fn agent_maybe_compact( + pub async fn agent_maybe_compact( &self, ai_client: &AiClient, model_name: &str, @@ -71,7 +71,9 @@ impl AppService { body.push_str("\n"); body.push_str(prev); body.push_str("\n\n\n"); - body.push_str("Merge the previous summary with the new messages below:\n\n"); + body.push_str( + "Merge the previous summary with the new messages below:\n\n", + ); } for (_, role, content) in older { body.push_str(&format!("[{}]: {}\n\n", role, content)); diff --git a/lib/service/agent/config.rs b/lib/service/agent/config.rs index 0420280..a44b97c 100644 --- a/lib/service/agent/config.rs +++ b/lib/service/agent/config.rs @@ -9,11 +9,11 @@ use model::{ use uuid::Uuid; use super::types::SessionContext; -use crate::error::AppError; use crate::AppService; +use crate::error::AppError; impl AppService { - pub(crate) async fn agent_session_context( + pub async fn agent_session_context( &self, session_id: Uuid, user_id: Uuid, @@ -35,16 +35,16 @@ impl AppService { .ok_or_else(|| AppError::NotFound("agent session not found".to_string()))?; if let Some(wk_id) = session.wk { - let _ = self - .workspace_require_member(wk_id, user_id) - .await?; + let _ = self.workspace_require_member(wk_id, user_id).await?; } else if Some(user_id) != session.user { return Err(AppError::PermissionDenied); } - let model_version_id = session - .model_version - .ok_or_else(|| AppError::BadRequest("agent session has no model_version".to_string()))?; + let model_version_id = session.model_version.ok_or_else(|| { + AppError::BadRequest( + "agent session has no model_version".to_string(), + ) + })?; let version = self.resolve_model_version(model_version_id).await?; @@ -58,10 +58,9 @@ impl AppService { session_id, user_id: session.user, workspace_id: session.wk, - system_prompt: self.build_system_prompt_with_context( - &session, - user_id, - ).await, + system_prompt: self + .build_system_prompt_with_context(&session, user_id) + .await, model_version_id: version.id, provider_model_name: version.provider_model_name, temperature: session.temperature, @@ -76,7 +75,7 @@ impl AppService { billing_target, }) } - pub(crate) async fn agent_build_ai_client( + pub async fn agent_build_ai_client( &self, model_version_id: Uuid, ) -> Result { @@ -107,24 +106,24 @@ impl AppService { let base_url = provider .base_url .unwrap_or_else(|| self.config.ai_basic_url().unwrap_or_default()); - let api_key = self - .config - .ai_api_key() - .map_err(|e| AppError::InternalServerError(format!("AI API key: {e}")))?; + let api_key = self.config.ai_api_key().map_err(|e| { + AppError::InternalServerError(format!("AI API key: {e}")) + })?; - let embed_base_url = self - .config - .get_embed_model_base_url() - .map_err(|e| AppError::InternalServerError(format!("embed base url: {e}")))?; - let embed_api_key = self - .config - .get_embed_model_api_key() - .map_err(|e| AppError::InternalServerError(format!("embed api key: {e}")))?; + let embed_base_url = + self.config.get_embed_model_base_url().map_err(|e| { + AppError::InternalServerError(format!("embed base url: {e}")) + })?; + let embed_api_key = + self.config.get_embed_model_api_key().map_err(|e| { + AppError::InternalServerError(format!("embed api key: {e}")) + })?; let llm_config = EndpointConfig::new(&base_url, &api_key) .map_err(|e| AppError::InternalServerError(e.to_string()))?; - let embed_endpoint = EndpointConfig::new(&embed_base_url, &embed_api_key) - .map_err(|e| AppError::InternalServerError(e.to_string()))?; + let embed_endpoint = + EndpointConfig::new(&embed_base_url, &embed_api_key) + .map_err(|e| AppError::InternalServerError(e.to_string()))?; let embed_config = EmbedConfig::new( embed_endpoint, self.config @@ -139,9 +138,10 @@ impl AppService { let client_config = AiClientConfig::new(llm_config, embed_config) .map_err(|e| AppError::InternalServerError(e.to_string()))?; - AiClient::new(client_config).map_err(|e| AppError::InternalServerError(e.to_string())) + AiClient::new(client_config) + .map_err(|e| AppError::InternalServerError(e.to_string())) } - pub(crate) fn agent_build_config( + pub fn agent_build_config( &self, ctx: &SessionContext, max_steps_override: Option, @@ -154,8 +154,9 @@ impl AppService { config.model = ctx.provider_model_name.clone(); config.system_prompt = ctx.system_prompt.clone(); if let Some(ref vars_json) = ctx.variables_json { - if let Ok(vars) = - serde_json::from_str::>(vars_json) + if let Ok(vars) = serde_json::from_str::< + serde_json::Map, + >(vars_json) { if !vars.is_empty() { let mut prompt = config.system_prompt.clone(); @@ -178,7 +179,8 @@ impl AppService { serde_json::Value::String(s) => s.clone(), other => other.to_string(), }; - prompt.push_str(&format!("- {}: {}\n", key, val_str)); + prompt + .push_str(&format!("- {}: {}\n", key, val_str)); } prompt.push_str(""); } @@ -302,7 +304,7 @@ impl AppService { Ok(version) } - pub(crate) async fn agent_resolve_pricing( + pub async fn agent_resolve_pricing( &self, model_version_id: Uuid, ) -> Result<(Option, Option), AppError> { @@ -330,21 +332,20 @@ impl AppService { session: &model::agent::AgentSessionModel, user_id: Uuid, ) -> String { - let base = session - .system_prompt - .clone() - .unwrap_or_else(|| ai::agent::config::default_system_prompt().to_string()); + let base = session.system_prompt.clone().unwrap_or_else(|| { + ai::agent::config::default_system_prompt().to_string() + }); let mut context_section = String::new(); // Workspace context if let Some(wk_id) = session.wk { - let wk: Option<(String,)> = sqlx::query_as( - "SELECT name FROM workspace WHERE id = $1") - .bind(wk_id) - .fetch_optional(self.db.reader()) - .await - .unwrap_or(None); + let wk: Option<(String,)> = + sqlx::query_as("SELECT name FROM workspace WHERE id = $1") + .bind(wk_id) + .fetch_optional(self.db.reader()) + .await + .unwrap_or(None); if let Some((wk_name,)) = wk { context_section.push_str(&format!( "- You are operating in workspace \"{wk_name}\" (id: {wk_id}).\n" @@ -358,26 +359,36 @@ impl AppService { // User context if let Some(session_user_id) = session.user { let u: Option<(String, String)> = sqlx::query_as( - "SELECT display_name, username FROM \"user\" WHERE id = $1") - .bind(session_user_id) - .fetch_optional(self.db.reader()) - .await - .unwrap_or(None); + "SELECT display_name, username FROM \"user\" WHERE id = $1", + ) + .bind(session_user_id) + .fetch_optional(self.db.reader()) + .await + .unwrap_or(None); if let Some((display_name, username)) = u { - let name = if display_name.is_empty() { &username } else { &display_name }; + let name = if display_name.is_empty() { + &username + } else { + &display_name + }; context_section.push_str(&format!( "- The current user is {name} (username: {username}, id: {session_user_id}).\n" )); } } else { let u: Option<(String, String)> = sqlx::query_as( - "SELECT display_name, username FROM \"user\" WHERE id = $1") - .bind(user_id) - .fetch_optional(self.db.reader()) - .await - .unwrap_or(None); + "SELECT display_name, username FROM \"user\" WHERE id = $1", + ) + .bind(user_id) + .fetch_optional(self.db.reader()) + .await + .unwrap_or(None); if let Some((display_name, username)) = u { - let name = if display_name.is_empty() { &username } else { &display_name }; + let name = if display_name.is_empty() { + &username + } else { + &display_name + }; context_section.push_str(&format!( "- The current user is {name} (username: {username}, id: {user_id}).\n" )); diff --git a/lib/service/agent/context.rs b/lib/service/agent/context.rs index 86d7b32..ab48aff 100644 --- a/lib/service/agent/context.rs +++ b/lib/service/agent/context.rs @@ -1,27 +1,23 @@ use std::time::Duration; use ai::{ - agent::request::{ - AgentContextChunk, AgentMessage, AgentRequest, - }, + agent::request::{AgentContextChunk, AgentMessage, AgentRequest}, client::AiClient, - rag::{ - RagClient, RagConfig, RagDocument, - }, + rag::{RagClient, RagConfig, RagDocument}, }; use db::sqlx; use model::repos::RepoModel; use uuid::Uuid; use super::types::SessionContext; -use crate::error::AppError; use crate::AppService; +use crate::error::AppError; const MAX_HISTORY_MESSAGES: u32 = 50; const MAX_HISTORY_CHARS: usize = 500_000; const MAX_HISTORY_ESTIMATED_TOKENS: u64 = 64_000; impl AppService { - pub(crate) async fn agent_build_request( + pub async fn agent_build_request( &self, ai_client: &AiClient, ctx: &SessionContext, @@ -51,9 +47,8 @@ impl AppService { ))); } - let messages = self - .agent_load_conversation_messages(conv_id) - .await?; + let messages = + self.agent_load_conversation_messages(conv_id).await?; all_messages.extend(messages); request = request.with_messages(all_messages); @@ -61,7 +56,8 @@ impl AppService { let kb_context = self .agent_load_knowledge_base(ai_client, ctx, &input) .await?; - let (memories_text, _memory_rows) = self.agent_load_memories(ctx.session_id).await?; + let (memories_text, _memory_rows) = + self.agent_load_memories(ctx.session_id).await?; let mut all_context = kb_context; if !memories_text.is_empty() { @@ -89,7 +85,7 @@ impl AppService { Ok(request) } - pub(crate) async fn agent_load_conversation_messages( + pub async fn agent_load_conversation_messages( &self, conversation_id: Uuid, ) -> Result, AppError> { @@ -146,7 +142,9 @@ impl AppService { }) .sum(); let mut trimmed_for_tokens = 0usize; - while estimated_tokens > MAX_HISTORY_ESTIMATED_TOKENS && !result.is_empty() { + while estimated_tokens > MAX_HISTORY_ESTIMATED_TOKENS + && !result.is_empty() + { let removed = result.remove(0); estimated_tokens -= match &removed { AgentMessage::User(c) | AgentMessage::Assistant(c) => { @@ -165,7 +163,7 @@ impl AppService { Ok(result) } - pub(crate) async fn agent_load_knowledge_base( + pub async fn agent_load_knowledge_base( &self, ai_client: &AiClient, ctx: &SessionContext, @@ -203,13 +201,12 @@ impl AppService { .get_embed_model_dimensions() .map_err(|e| AppError::InternalServerError(e.to_string()))?; - let rag_config = RagConfig::new(qdrant_url, "agent_knowledge", vector_size) - .map_err(|e| AppError::InternalServerError(e.to_string()))? - .with_api_key( - self.config - .qdrant_api_key() - .map_err(|e| AppError::InternalServerError(e.to_string()))?, - ); + let rag_config = + RagConfig::new(qdrant_url, "agent_knowledge", vector_size) + .map_err(|e| AppError::InternalServerError(e.to_string()))? + .with_api_key(self.config.qdrant_api_key().map_err(|e| { + AppError::InternalServerError(e.to_string()) + })?); let rag = RagClient::connect(ai_client, rag_config) .map_err(|e| AppError::InternalServerError(e.to_string()))?; @@ -219,13 +216,15 @@ impl AppService { match rag.search_session(&session_key, query).await { Ok(hits) => { for hit in hits { - all_hits.push(AgentContextChunk::from(ai::rag::RagSearchHit { - id: hit.id, - session_id: hit.session_id, - score: hit.score, - content: hit.content, - metadata: hit.metadata, - })); + all_hits.push(AgentContextChunk::from( + ai::rag::RagSearchHit { + id: hit.id, + session_id: hit.session_id, + score: hit.score, + content: hit.content, + metadata: hit.metadata, + }, + )); } } Err(e) => { @@ -242,7 +241,7 @@ impl AppService { } /// Parse `@[repo:name:label]` mentions from the input and resolve them to /// AgentContextChunks with repo metadata from the database. - pub(crate) async fn agent_resolve_mentioned_repos( + pub async fn agent_resolve_mentioned_repos( &self, workspace_id: Uuid, input: &str, @@ -307,7 +306,7 @@ impl AppService { } #[allow(dead_code)] - pub(crate) async fn agent_upsert_knowledge( + pub async fn agent_upsert_knowledge( &self, ai_client: &AiClient, kb_id: Uuid, @@ -322,13 +321,12 @@ impl AppService { .get_embed_model_dimensions() .map_err(|e| AppError::InternalServerError(e.to_string()))?; - let rag_config = RagConfig::new(qdrant_url, "agent_knowledge", vector_size) - .map_err(|e| AppError::InternalServerError(e.to_string()))? - .with_api_key( - self.config - .qdrant_api_key() - .map_err(|e| AppError::InternalServerError(e.to_string()))?, - ); + let rag_config = + RagConfig::new(qdrant_url, "agent_knowledge", vector_size) + .map_err(|e| AppError::InternalServerError(e.to_string()))? + .with_api_key(self.config.qdrant_api_key().map_err(|e| { + AppError::InternalServerError(e.to_string()) + })?); let rag = RagClient::connect(ai_client, rag_config) .map_err(|e| AppError::InternalServerError(e.to_string()))?; @@ -393,10 +391,7 @@ fn extract_repo_mentions(input: &str) -> Vec { /// Format a RepoModel into a concise context string for the AI. fn format_repo_context(repo: &RepoModel) -> String { - let mut s = format!( - "Repository: {} (id: {})\n", - repo.name, repo.id - ); + let mut s = format!("Repository: {} (id: {})\n", repo.name, repo.id); if let Some(ref desc) = repo.description { if !desc.trim().is_empty() { s.push_str(&format!("Description: {}\n", desc.trim())); @@ -430,7 +425,8 @@ mod tests { #[test] fn test_extract_repo_mentions_multiple() { - let input = "compare @[repo:backend:backend] with @[repo:frontend:frontend]"; + let input = + "compare @[repo:backend:backend] with @[repo:frontend:frontend]"; let names = extract_repo_mentions(input); assert_eq!(names, vec!["backend", "frontend"]); } diff --git a/lib/service/agent/conversation.rs b/lib/service/agent/conversation.rs index 7c5fdd8..ce402e8 100644 --- a/lib/service/agent/conversation.rs +++ b/lib/service/agent/conversation.rs @@ -5,8 +5,8 @@ use serde::{Deserialize, Serialize}; use utoipa::ToSchema; use uuid::Uuid; -use crate::error::AppError; use crate::AppService; +use crate::error::AppError; #[derive(Debug, Clone, Deserialize, ToSchema)] pub struct CreateConversation { @@ -46,8 +46,15 @@ impl From for ToolCallResponse { Self { id: m.tool_call_id.unwrap_or_default(), name: m.tool_name, - arguments: m.arguments.as_deref().and_then(|s| serde_json::from_str(s).ok()).unwrap_or_default(), - output: m.result.as_deref().and_then(|s| serde_json::from_str(s).ok()), + arguments: m + .arguments + .as_deref() + .and_then(|s| serde_json::from_str(s).ok()) + .unwrap_or_default(), + output: m + .result + .as_deref() + .and_then(|s| serde_json::from_str(s).ok()), error: m.error, status: m.status, elapsed_ms: m.latency_ms, @@ -116,7 +123,7 @@ impl From for ConversationWithSessionResponse { } impl AppService { - pub(crate) async fn agent_require_conversation_access( + pub async fn agent_require_conversation_access( &self, user_id: Uuid, conversation_id: Uuid, @@ -141,7 +148,9 @@ impl AppService { .fetch_optional(self.db.reader()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))? - .ok_or_else(|| AppError::NotFound("agent session not found".to_string()))?; + .ok_or_else(|| { + AppError::NotFound("agent session not found".to_string()) + })?; let (session_user, session_wk) = session; if session_user != Some(user_id) { @@ -171,7 +180,9 @@ impl AppService { .fetch_optional(self.db.reader()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))? - .ok_or_else(|| AppError::NotFound("agent session not found".to_string()))?; + .ok_or_else(|| { + AppError::NotFound("agent session not found".to_string()) + })?; if session.0 != Some(user_id) { if let Some(wk) = session.1 { @@ -345,7 +356,9 @@ impl AppService { .map_err(|e| AppError::DatabaseError(e.to_string()))?; if rows.rows_affected() == 0 { - return Err(AppError::NotFound("conversation not found".to_string())); + return Err(AppError::NotFound( + "conversation not found".to_string(), + )); } Ok(()) } @@ -464,8 +477,10 @@ impl AppService { }; // Group tool calls by message_id. - let mut tool_calls_by_message: std::collections::HashMap> = - std::collections::HashMap::new(); + let mut tool_calls_by_message: std::collections::HashMap< + Uuid, + Vec, + > = std::collections::HashMap::new(); for log in tool_call_logs { if let Some(msg_id) = log.message { tool_calls_by_message @@ -479,9 +494,8 @@ impl AppService { .into_iter() .map(|row| { let mut msg: MessageResponse = row.into(); - msg.tool_calls = tool_calls_by_message - .remove(&msg.id) - .unwrap_or_default(); + msg.tool_calls = + tool_calls_by_message.remove(&msg.id).unwrap_or_default(); msg }) .collect(); diff --git a/lib/service/agent/git_tools/blame.rs b/lib/service/agent/git_tools/blame.rs index 1384ede..5fe6108 100644 --- a/lib/service/agent/git_tools/blame.rs +++ b/lib/service/agent/git_tools/blame.rs @@ -3,26 +3,34 @@ use ai::tool::tools::FunctionCall; use async_trait::async_trait; use git::rpc::proto as p; use git::rpc::proto::blame_service_client::BlameServiceClient; -use serde_json::{json, Value}; +use serde_json::{Value, json}; -use super::helpers::{arg_str, arg_opt_str, git_ctx, require_repo_member, rpc_err}; +use super::helpers::{ + arg_opt_str, arg_str, git_ctx, require_repo_member, rpc_err, +}; use crate::agent::run::AppAgentContext; pub struct GitBlameTool; impl GitBlameTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for GitBlameTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for GitBlameTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "git_blame" } + fn name(&self) -> &'static str { + "git_blame" + } fn description(&self) -> &'static str { "Blame a file to see which commits authored each line range." @@ -43,46 +51,79 @@ impl FunctionCall for GitBlameTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; let path = arg_str(&args, "path")?; let rev = arg_opt_str(&args, "rev").map(String::from); - let start_line = args.get("start_line").and_then(|v| v.as_u64()).map(|v| v as u32); - let end_line = args.get("end_line").and_then(|v| v.as_u64()).map(|v| v as u32); + let start_line = args + .get("start_line") + .and_then(|v| v.as_u64()) + .map(|v| v as u32); + let end_line = args + .get("end_line") + .and_then(|v| v.as_u64()) + .map(|v| v as u32); - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = BlameServiceClient::new(git.channel.clone()); if let (Some(start), Some(end)) = (start_line, end_line) { let resp = client .blame_lines(p::BlameLinesRequest { - repo_id: repo.id.to_string(), path: path.to_string(), rev, - start_line: start, end_line: end, + repo_id: repo.id.to_string(), + path: path.to_string(), + rev, + start_line: start, + end_line: end, }) - .await.map_err(rpc_err)?.into_inner(); + .await + .map_err(rpc_err)? + .into_inner(); - let lines: Vec = resp.lines.iter().map(|l| json!({ - "line_no": l.line_no, - "content": l.content, - "commit_oid": l.commit_oid.as_ref().map(|o| &o.value), - })).collect(); + let lines: Vec = resp + .lines + .iter() + .map(|l| { + json!({ + "line_no": l.line_no, + "content": l.content, + "commit_oid": l.commit_oid.as_ref().map(|o| &o.value), + }) + }) + .collect(); Ok(json!({ "lines": lines, "count": lines.len() })) } else { let resp = client .blame_file(p::BlameFileRequest { - repo_id: repo.id.to_string(), path: path.to_string(), rev, options: None, + repo_id: repo.id.to_string(), + path: path.to_string(), + rev, + options: None, }) - .await.map_err(rpc_err)?.into_inner(); + .await + .map_err(rpc_err)? + .into_inner(); - let hunks: Vec = resp.hunks.iter().map(|h| json!({ - "commit_oid": h.commit_oid.as_ref().map(|o| &o.value), - "final_start_line": h.final_start_line, - "final_lines": h.final_lines, - })).collect(); + let hunks: Vec = resp + .hunks + .iter() + .map(|h| { + json!({ + "commit_oid": h.commit_oid.as_ref().map(|o| &o.value), + "final_start_line": h.final_start_line, + "final_lines": h.final_lines, + }) + }) + .collect(); Ok(json!({ "hunks": hunks, "count": hunks.len() })) } diff --git a/lib/service/agent/git_tools/branch.rs b/lib/service/agent/git_tools/branch.rs index f80f58b..d2ca3aa 100644 --- a/lib/service/agent/git_tools/branch.rs +++ b/lib/service/agent/git_tools/branch.rs @@ -3,7 +3,7 @@ use ai::tool::tools::FunctionCall; use async_trait::async_trait; use git::rpc::proto as p; use git::rpc::proto::branch_service_client::BranchServiceClient; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use super::helpers::{arg_str, git_ctx, require_repo_member, rpc_err}; use crate::agent::run::AppAgentContext; @@ -11,18 +11,24 @@ use crate::agent::run::AppAgentContext; pub struct GitBranchListTool; impl GitBranchListTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for GitBranchListTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for GitBranchListTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "git_branch_list" } + fn name(&self) -> &'static str { + "git_branch_list" + } fn description(&self) -> &'static str { "List all branches in a repository with their HEAD commit OID." @@ -39,16 +45,23 @@ impl FunctionCall for GitBranchListTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = BranchServiceClient::new(git.channel.clone()); let resp = client - .branch_list(p::BranchListRequest { repo_id: repo.id.to_string() }) + .branch_list(p::BranchListRequest { + repo_id: repo.id.to_string(), + }) .await .map_err(rpc_err)? .into_inner(); @@ -67,18 +80,24 @@ impl FunctionCall for GitBranchListTool { pub struct GitBranchInfoTool; impl GitBranchInfoTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for GitBranchInfoTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for GitBranchInfoTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "git_branch_info" } + fn name(&self) -> &'static str { + "git_branch_info" + } fn description(&self) -> &'static str { "Get detailed information about a single branch, including its HEAD OID and upstream." @@ -96,13 +115,18 @@ impl FunctionCall for GitBranchInfoTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; let branch = arg_str(&args, "branch")?; - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = BranchServiceClient::new(git.channel.clone()); let resp = client @@ -114,7 +138,9 @@ impl FunctionCall for GitBranchInfoTool { .map_err(rpc_err)? .into_inner(); - let b = resp.branch.ok_or_else(|| AiError::Config(format!("branch '{branch}' not found")))?; + let b = resp.branch.ok_or_else(|| { + AiError::Config(format!("branch '{branch}' not found")) + })?; Ok(json!({ "name": b.name, "oid": b.oid.as_ref().map(|o| &o.value), @@ -128,18 +154,24 @@ impl FunctionCall for GitBranchInfoTool { pub struct GitBranchAheadBehindTool; impl GitBranchAheadBehindTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for GitBranchAheadBehindTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for GitBranchAheadBehindTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "git_branch_ahead_behind" } + fn name(&self) -> &'static str { + "git_branch_ahead_behind" + } fn description(&self) -> &'static str { "Compare a local branch with its remote tracking branch. Returns commits ahead and behind." @@ -158,14 +190,19 @@ impl FunctionCall for GitBranchAheadBehindTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; let local_branch = arg_str(&args, "local_branch")?; let remote_branch = arg_str(&args, "remote_branch")?; - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = BranchServiceClient::new(git.channel.clone()); let resp = client @@ -185,18 +222,24 @@ impl FunctionCall for GitBranchAheadBehindTool { pub struct GitBranchDeleteTool; impl GitBranchDeleteTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for GitBranchDeleteTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for GitBranchDeleteTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "git_branch_delete" } + fn name(&self) -> &'static str { + "git_branch_delete" + } fn description(&self) -> &'static str { "Delete a branch from the repository. Requires write access." @@ -215,14 +258,20 @@ impl FunctionCall for GitBranchDeleteTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; let name = arg_str(&args, "name")?; - let force = args.get("force").and_then(|v| v.as_bool()).unwrap_or(false); + let force = + args.get("force").and_then(|v| v.as_bool()).unwrap_or(false); - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = BranchServiceClient::new(git.channel.clone()); client @@ -243,18 +292,24 @@ impl FunctionCall for GitBranchDeleteTool { pub struct GitCreateBranchTool; impl GitCreateBranchTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for GitCreateBranchTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for GitCreateBranchTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "git_create_branch" } + fn name(&self) -> &'static str { + "git_create_branch" + } fn description(&self) -> &'static str { "Create a new branch in a repository. Requires write access." @@ -274,15 +329,21 @@ impl FunctionCall for GitCreateBranchTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; let name = arg_str(&args, "name")?; let oid = arg_str(&args, "oid")?; - let force = args.get("force").and_then(|v| v.as_bool()).unwrap_or(false); + let force = + args.get("force").and_then(|v| v.as_bool()).unwrap_or(false); - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = BranchServiceClient::new(git.channel.clone()); client @@ -290,7 +351,9 @@ impl FunctionCall for GitCreateBranchTool { repo_id: repo.id.to_string(), params: Some(p::BranchForkParams { name: name.to_string(), - oid: Some(p::ObjectId { value: oid.to_string() }), + oid: Some(p::ObjectId { + value: oid.to_string(), + }), force, }), }) diff --git a/lib/service/agent/git_tools/commit.rs b/lib/service/agent/git_tools/commit.rs index 0eb232b..2cff9cb 100644 --- a/lib/service/agent/git_tools/commit.rs +++ b/lib/service/agent/git_tools/commit.rs @@ -3,9 +3,11 @@ use ai::tool::tools::FunctionCall; use async_trait::async_trait; use git::rpc::proto as p; use git::rpc::proto::commit_service_client::CommitServiceClient; -use serde_json::{json, Value}; +use serde_json::{Value, json}; -use super::helpers::{arg_str, arg_opt_str, arg_u64, git_ctx, require_repo_member, rpc_err}; +use super::helpers::{ + arg_opt_str, arg_str, arg_u64, git_ctx, require_repo_member, rpc_err, +}; use crate::agent::run::AppAgentContext; pub struct GitCommitHistoryTool; @@ -63,7 +65,11 @@ impl FunctionCall for GitCommitHistoryTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; @@ -71,7 +77,8 @@ impl FunctionCall for GitCommitHistoryTool { let limit = arg_u64(&args, "limit", 20).min(100); let skip = arg_u64(&args, "skip", 0); - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = CommitServiceClient::new(git.channel.clone()); let resp = client @@ -143,19 +150,26 @@ impl FunctionCall for GitCommitInfoTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; let oid = arg_str(&args, "oid")?; - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = CommitServiceClient::new(git.channel.clone()); let resp = client .commit_info(p::CommitInfoRequest { repo_id: repo.id.to_string(), - oid: Some(p::ObjectId { value: oid.to_string() }), + oid: Some(p::ObjectId { + value: oid.to_string(), + }), }) .await .map_err(rpc_err)? @@ -164,7 +178,8 @@ impl FunctionCall for GitCommitInfoTool { let c = resp .commit .ok_or_else(|| AiError::Response("commit not found".to_string()))?; - let parent_ids: Vec = c.parent_ids.iter().map(|o| o.value.clone()).collect(); + let parent_ids: Vec = + c.parent_ids.iter().map(|o| o.value.clone()).collect(); Ok(json!({ "oid": c.oid.as_ref().map(|o| &o.value), @@ -180,22 +195,27 @@ impl FunctionCall for GitCommitInfoTool { } } - pub struct GitCommitExistsTool; impl GitCommitExistsTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for GitCommitExistsTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for GitCommitExistsTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "git_commit_exists" } + fn name(&self) -> &'static str { + "git_commit_exists" + } fn description(&self) -> &'static str { "Check whether a specific commit OID exists in the repository." @@ -213,19 +233,26 @@ impl FunctionCall for GitCommitExistsTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; let oid = arg_str(&args, "oid")?; - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = CommitServiceClient::new(git.channel.clone()); let resp = client .commit_exists(p::CommitExistsRequest { repo_id: repo.id.to_string(), - oid: Some(p::ObjectId { value: oid.to_string() }), + oid: Some(p::ObjectId { + value: oid.to_string(), + }), }) .await .map_err(rpc_err)? @@ -235,23 +262,27 @@ impl FunctionCall for GitCommitExistsTool { } } - - pub struct GitCherryPickTool; impl GitCherryPickTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for GitCherryPickTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for GitCherryPickTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "git_cherry_pick" } + fn name(&self) -> &'static str { + "git_cherry_pick" + } fn description(&self) -> &'static str { "Cherry-pick a commit onto the current branch. Requires write access." @@ -271,22 +302,32 @@ impl FunctionCall for GitCherryPickTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; let oid = arg_str(&args, "oid")?; - let message = args.get("message").and_then(|v| v.as_str()).map(String::from); + let message = args + .get("message") + .and_then(|v| v.as_str()) + .map(String::from); let update_ref = arg_opt_str(&args, "update_ref").map(String::from); - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = CommitServiceClient::new(git.channel.clone()); let resp = client .cherry_pick(p::CherryPickRequest { repo_id: repo.id.to_string(), params: Some(p::CommitCherryPickParams { - cherrypick_oid: Some(p::ObjectId { value: oid.to_string() }), + cherrypick_oid: Some(p::ObjectId { + value: oid.to_string(), + }), message, update_ref, ..Default::default() @@ -303,7 +344,6 @@ impl FunctionCall for GitCherryPickTool { } } - pub struct GitCommitCreateTool; impl GitCommitCreateTool { @@ -373,7 +413,11 @@ impl FunctionCall for GitCommitCreateTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; @@ -383,18 +427,27 @@ impl FunctionCall for GitCommitCreateTool { let files_val = args .get("files") .and_then(|v| v.as_array()) - .ok_or_else(|| AiError::Config("'files' must be an array of {path, content} objects".to_string()))?; + .ok_or_else(|| { + AiError::Config( + "'files' must be an array of {path, content} objects" + .to_string(), + ) + })?; let mut file_changes: Vec = Vec::new(); for f in files_val { - let path = f - .get("path") - .and_then(|v| v.as_str()) - .ok_or_else(|| AiError::Config("each file must have a 'path' field".to_string()))?; - let content = f - .get("content") - .and_then(|v| v.as_str()) - .ok_or_else(|| AiError::Config("each file must have a 'content' field".to_string()))?; + let path = + f.get("path").and_then(|v| v.as_str()).ok_or_else(|| { + AiError::Config( + "each file must have a 'path' field".to_string(), + ) + })?; + let content = + f.get("content").and_then(|v| v.as_str()).ok_or_else(|| { + AiError::Config( + "each file must have a 'content' field".to_string(), + ) + })?; file_changes.push(super::helpers::FileChange { path: path.to_string(), content: content.as_bytes().to_vec(), @@ -402,10 +455,13 @@ impl FunctionCall for GitCommitCreateTool { } if file_changes.is_empty() { - return Err(AiError::Config("'files' array must not be empty".to_string())); + return Err(AiError::Config( + "'files' array must not be empty".to_string(), + )); } - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = CommitServiceClient::new(git.channel.clone()); let resp = client diff --git a/lib/service/agent/git_tools/diff.rs b/lib/service/agent/git_tools/diff.rs index f5c12d4..3b36982 100644 --- a/lib/service/agent/git_tools/diff.rs +++ b/lib/service/agent/git_tools/diff.rs @@ -3,7 +3,7 @@ use ai::tool::tools::FunctionCall; use async_trait::async_trait; use git::rpc::proto as p; use git::rpc::proto::diff_service_client::DiffServiceClient; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use super::helpers::{arg_str, git_ctx, require_repo_member, rpc_err}; use crate::agent::run::AppAgentContext; @@ -11,18 +11,24 @@ use crate::agent::run::AppAgentContext; pub struct GitDiffStatsTool; impl GitDiffStatsTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for GitDiffStatsTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for GitDiffStatsTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "git_diff_stats" } + fn name(&self) -> &'static str { + "git_diff_stats" + } fn description(&self) -> &'static str { "Get diff statistics between two commits: files changed, insertions, deletions." @@ -41,29 +47,42 @@ impl FunctionCall for GitDiffStatsTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; let old_oid = arg_str(&args, "old_oid")?; let new_oid = arg_str(&args, "new_oid")?; - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = DiffServiceClient::new(git.channel.clone()); let resp = client .diff_stats(p::DiffStatsRequest { repo_id: repo.id.to_string(), - old_oid: Some(p::ObjectId { value: old_oid.to_string() }), - new_oid: Some(p::ObjectId { value: new_oid.to_string() }), + old_oid: Some(p::ObjectId { + value: old_oid.to_string(), + }), + new_oid: Some(p::ObjectId { + value: new_oid.to_string(), + }), options: None, }) .await .map_err(rpc_err)? .into_inner(); - let result = resp.result.ok_or_else(|| AiError::Response("no diff result".to_string()))?; - let stats = result.stats.ok_or_else(|| AiError::Response("no stats".to_string()))?; + let result = resp + .result + .ok_or_else(|| AiError::Response("no diff result".to_string()))?; + let stats = result + .stats + .ok_or_else(|| AiError::Response("no stats".to_string()))?; let files: Vec = result.deltas.iter().map(|d| { let status = match p::DiffDeltaStatus::try_from(d.status) { @@ -90,18 +109,24 @@ impl FunctionCall for GitDiffStatsTool { pub struct GitDiffPatchTool; impl GitDiffPatchTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for GitDiffPatchTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for GitDiffPatchTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "git_diff_patch" } + fn name(&self) -> &'static str { + "git_diff_patch" + } fn description(&self) -> &'static str { "Get the full diff (unified format) between two commits, including line-level changes." @@ -121,22 +146,34 @@ impl FunctionCall for GitDiffPatchTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; let old_oid = arg_str(&args, "old_oid")?; let new_oid = arg_str(&args, "new_oid")?; - let ctx_lines = args.get("context_lines").and_then(|v| v.as_u64()).unwrap_or(3) as u32; + let ctx_lines = args + .get("context_lines") + .and_then(|v| v.as_u64()) + .unwrap_or(3) as u32; - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = DiffServiceClient::new(git.channel.clone()); let resp = client .diff_patch(p::DiffPatchRequest { repo_id: repo.id.to_string(), - old_oid: Some(p::ObjectId { value: old_oid.to_string() }), - new_oid: Some(p::ObjectId { value: new_oid.to_string() }), + old_oid: Some(p::ObjectId { + value: old_oid.to_string(), + }), + new_oid: Some(p::ObjectId { + value: new_oid.to_string(), + }), options: Some(p::DiffOptions { context_lines: ctx_lines, ..Default::default() @@ -146,8 +183,12 @@ impl FunctionCall for GitDiffPatchTool { .map_err(rpc_err)? .into_inner(); - let result = resp.result.ok_or_else(|| AiError::Response("no diff result".to_string()))?; - let stats = result.stats.ok_or_else(|| AiError::Response("no stats".to_string()))?; + let result = resp + .result + .ok_or_else(|| AiError::Response("no diff result".to_string()))?; + let stats = result + .stats + .ok_or_else(|| AiError::Response("no stats".to_string()))?; let mut patch_text = String::new(); for delta in &result.deltas { @@ -158,15 +199,29 @@ impl FunctionCall for GitDiffPatchTool { Ok(p::DiffDeltaStatus::Renamed) => "renamed", _ => "unknown", }; - let old = delta.old_file.as_ref().and_then(|f| f.path.as_deref()).unwrap_or("unknown"); - let new = delta.new_file.as_ref().and_then(|f| f.path.as_deref()).unwrap_or("unknown"); - patch_text.push_str(&format!("--- {}\n+++ {}\n@@ status: {status} @@\n", old, new)); + let old = delta + .old_file + .as_ref() + .and_then(|f| f.path.as_deref()) + .unwrap_or("unknown"); + let new = delta + .new_file + .as_ref() + .and_then(|f| f.path.as_deref()) + .unwrap_or("unknown"); + patch_text.push_str(&format!( + "--- {}\n+++ {}\n@@ status: {status} @@\n", + old, new + )); for hunk in &delta.hunks { patch_text.push_str(&hunk.header); patch_text.push('\n'); for line in &delta.lines { - patch_text.push_str(&format!("{}{}\n", line.origin, line.content)); + patch_text.push_str(&format!( + "{}{}\n", + line.origin, line.content + )); } patch_text.push('\n'); } diff --git a/lib/service/agent/git_tools/helpers.rs b/lib/service/agent/git_tools/helpers.rs index e0843fb..820c99e 100644 --- a/lib/service/agent/git_tools/helpers.rs +++ b/lib/service/agent/git_tools/helpers.rs @@ -44,14 +44,17 @@ pub(super) async fn require_repo_member( workspace_name: &str, repo_name: &str, ) -> AiResult { - let wk_id: Uuid = sqlx::query_scalar( - "SELECT id FROM workspace WHERE name = $1", - ) - .bind(workspace_name) - .fetch_optional(git.db.reader()) - .await - .map_err(AiError::Database)? - .ok_or_else(|| AiError::Config(format!("workspace '{workspace_name}' not found")))?; + let wk_id: Uuid = + sqlx::query_scalar("SELECT id FROM workspace WHERE name = $1") + .bind(workspace_name) + .fetch_optional(git.db.reader()) + .await + .map_err(AiError::Database)? + .ok_or_else(|| { + AiError::Config(format!( + "workspace '{workspace_name}' not found" + )) + })?; let is_member: i64 = sqlx::query_scalar( "SELECT COUNT(*) FROM wk_member \ @@ -86,9 +89,11 @@ pub(super) async fn require_repo_member( } pub(super) fn git_ctx(ctx: &AppAgentContext) -> AiResult<&GitAgentContext> { - ctx.git - .as_ref() - .ok_or_else(|| AiError::Config("git tools are not available in this session".to_string())) + ctx.git.as_ref().ok_or_else(|| { + AiError::Config( + "git tools are not available in this session".to_string(), + ) + }) } pub(super) fn rpc_err(status: tonic::Status) -> AiError { @@ -96,9 +101,9 @@ pub(super) fn rpc_err(status: tonic::Status) -> AiError { } pub(super) fn arg_str<'a>(args: &'a Value, key: &str) -> AiResult<&'a str> { - args.get(key) - .and_then(|v| v.as_str()) - .ok_or_else(|| AiError::Config(format!("'{key}' parameter is required"))) + args.get(key).and_then(|v| v.as_str()).ok_or_else(|| { + AiError::Config(format!("'{key}' parameter is required")) + }) } pub(super) fn arg_opt_str<'a>(args: &'a Value, key: &str) -> Option<&'a str> { @@ -106,7 +111,5 @@ pub(super) fn arg_opt_str<'a>(args: &'a Value, key: &str) -> Option<&'a str> { } pub(super) fn arg_u64(args: &Value, key: &str, default: u64) -> u64 { - args.get(key) - .and_then(|v| v.as_u64()) - .unwrap_or(default) + args.get(key).and_then(|v| v.as_u64()).unwrap_or(default) } diff --git a/lib/service/agent/git_tools/merge.rs b/lib/service/agent/git_tools/merge.rs index 1626a5b..7e7e0c5 100644 --- a/lib/service/agent/git_tools/merge.rs +++ b/lib/service/agent/git_tools/merge.rs @@ -3,7 +3,7 @@ use ai::tool::tools::FunctionCall; use async_trait::async_trait; use git::rpc::proto as p; use git::rpc::proto::merge_service_client::MergeServiceClient; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use super::helpers::{arg_str, git_ctx, require_repo_member, rpc_err}; use crate::agent::run::AppAgentContext; @@ -11,18 +11,24 @@ use crate::agent::run::AppAgentContext; pub struct GitMergeBaseTool; impl GitMergeBaseTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for GitMergeBaseTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for GitMergeBaseTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "git_merge_base" } + fn name(&self) -> &'static str { + "git_merge_base" + } fn description(&self) -> &'static str { "Find the common ancestor (merge base) of two commits." @@ -41,21 +47,30 @@ impl FunctionCall for GitMergeBaseTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; let oid_a = arg_str(&args, "oid_a")?; let oid_b = arg_str(&args, "oid_b")?; - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = MergeServiceClient::new(git.channel.clone()); let resp = client .merge_base(p::MergeBaseRequest { repo_id: repo.id.to_string(), - oid_a: Some(p::ObjectId { value: oid_a.to_string() }), - oid_b: Some(p::ObjectId { value: oid_b.to_string() }), + oid_a: Some(p::ObjectId { + value: oid_a.to_string(), + }), + oid_b: Some(p::ObjectId { + value: oid_b.to_string(), + }), }) .await .map_err(rpc_err)? @@ -68,18 +83,24 @@ impl FunctionCall for GitMergeBaseTool { pub struct GitMergeAnalysisTool; impl GitMergeAnalysisTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for GitMergeAnalysisTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for GitMergeAnalysisTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "git_merge_analysis" } + fn name(&self) -> &'static str { + "git_merge_analysis" + } fn description(&self) -> &'static str { "Analyze whether two commits can be merged (fast-forward, normal, up-to-date, etc)." @@ -98,28 +119,41 @@ impl FunctionCall for GitMergeAnalysisTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; let oid_a = arg_str(&args, "oid_a")?; let oid_b = arg_str(&args, "oid_b")?; - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = MergeServiceClient::new(git.channel.clone()); let resp = client .merge_analysis(p::MergeAnalysisRequest { repo_id: repo.id.to_string(), - oid_a: Some(p::ObjectId { value: oid_a.to_string() }), - oid_b: Some(p::ObjectId { value: oid_b.to_string() }), + oid_a: Some(p::ObjectId { + value: oid_a.to_string(), + }), + oid_b: Some(p::ObjectId { + value: oid_b.to_string(), + }), }) .await .map_err(rpc_err)? .into_inner(); - let analysis = resp.analysis.ok_or_else(|| AiError::Response("no analysis".to_string()))?; - let pref = resp.preference.ok_or_else(|| AiError::Response("no preference".to_string()))?; + let analysis = resp + .analysis + .ok_or_else(|| AiError::Response("no analysis".to_string()))?; + let pref = resp + .preference + .ok_or_else(|| AiError::Response("no preference".to_string()))?; // Determine overall status let status = if analysis.is_up_to_date { @@ -155,18 +189,24 @@ impl FunctionCall for GitMergeAnalysisTool { pub struct GitMergeIsConflictedTool; impl GitMergeIsConflictedTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for GitMergeIsConflictedTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for GitMergeIsConflictedTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "git_merge_is_conflicted" } + fn name(&self) -> &'static str { + "git_merge_is_conflicted" + } fn description(&self) -> &'static str { "Check if the repository is currently in a conflicted merge state." @@ -183,12 +223,17 @@ impl FunctionCall for GitMergeIsConflictedTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = MergeServiceClient::new(git.channel.clone()); let resp = client diff --git a/lib/service/agent/git_tools/tag.rs b/lib/service/agent/git_tools/tag.rs index 81c9960..91a37e6 100644 --- a/lib/service/agent/git_tools/tag.rs +++ b/lib/service/agent/git_tools/tag.rs @@ -3,29 +3,38 @@ use ai::tool::tools::FunctionCall; use async_trait::async_trait; use git::rpc::proto as p; use git::rpc::proto::tag_service_client::TagServiceClient; -use serde_json::{json, Value}; +use serde_json::{Value, json}; -use super::helpers::{arg_str, arg_opt_str, git_ctx, require_repo_member, rpc_err}; +use super::helpers::{ + arg_opt_str, arg_str, git_ctx, require_repo_member, rpc_err, +}; use crate::agent::run::AppAgentContext; - pub struct GitTagListTool; impl GitTagListTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for GitTagListTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for GitTagListTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "git_tag_list" } + fn name(&self) -> &'static str { + "git_tag_list" + } - fn description(&self) -> &'static str { "List all tags in a repository." } + fn description(&self) -> &'static str { + "List all tags in a repository." + } fn schema(&self) -> Value { json!({ @@ -38,51 +47,71 @@ impl FunctionCall for GitTagListTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = TagServiceClient::new(git.channel.clone()); let resp = client - .tag_list(p::TagListRequest { repo_id: repo.id.to_string() }) + .tag_list(p::TagListRequest { + repo_id: repo.id.to_string(), + }) .await .map_err(rpc_err)? .into_inner(); - let tags: Vec = resp.tags.iter().map(|t| json!({ - "name": t.name, - "oid": t.oid.as_ref().map(|o| &o.value), - "target": t.target.as_ref().map(|o| &o.value), - "is_annotated": t.is_annotated, - "message": t.message, - "tagger": t.tagger, - })).collect(); + let tags: Vec = resp + .tags + .iter() + .map(|t| { + json!({ + "name": t.name, + "oid": t.oid.as_ref().map(|o| &o.value), + "target": t.target.as_ref().map(|o| &o.value), + "is_annotated": t.is_annotated, + "message": t.message, + "tagger": t.tagger, + }) + }) + .collect(); Ok(json!({ "tags": tags, "count": tags.len() })) } } - pub struct GitCreateTagTool; impl GitCreateTagTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for GitCreateTagTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for GitCreateTagTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "git_create_tag" } + fn name(&self) -> &'static str { + "git_create_tag" + } - fn description(&self) -> &'static str { "Create a new tag pointing at a commit." } + fn description(&self) -> &'static str { + "Create a new tag pointing at a commit." + } fn schema(&self) -> Value { json!({ @@ -99,16 +128,22 @@ impl FunctionCall for GitCreateTagTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; let name = arg_str(&args, "name")?; let target_oid = arg_str(&args, "target_oid")?; let message = arg_opt_str(&args, "message").map(String::from); - let force = args.get("force").and_then(|v| v.as_bool()).unwrap_or(false); + let force = + args.get("force").and_then(|v| v.as_bool()).unwrap_or(false); - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = TagServiceClient::new(git.channel.clone()); let resp = client @@ -116,7 +151,9 @@ impl FunctionCall for GitCreateTagTool { repo_id: repo.id.to_string(), params: Some(p::TagInitParams { name: name.to_string(), - target: Some(p::ObjectId { value: target_oid.to_string() }), + target: Some(p::ObjectId { + value: target_oid.to_string(), + }), message, tagger: None, force, @@ -131,24 +168,31 @@ impl FunctionCall for GitCreateTagTool { } } - pub struct GitTagInfoTool; impl GitTagInfoTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for GitTagInfoTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for GitTagInfoTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "git_tag_info" } + fn name(&self) -> &'static str { + "git_tag_info" + } - fn description(&self) -> &'static str { "Get detailed information about a specific tag." } + fn description(&self) -> &'static str { + "Get detailed information about a specific tag." + } fn schema(&self) -> Value { json!({ @@ -162,22 +206,32 @@ impl FunctionCall for GitTagInfoTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; let name = arg_str(&args, "name")?; - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = TagServiceClient::new(git.channel.clone()); let resp = client - .tag_info(p::TagInfoRequest { repo_id: repo.id.to_string(), name: name.to_string() }) + .tag_info(p::TagInfoRequest { + repo_id: repo.id.to_string(), + name: name.to_string(), + }) .await .map_err(rpc_err)? .into_inner(); - let t = resp.tag.ok_or_else(|| AiError::Config(format!("tag '{name}' not found")))?; + let t = resp.tag.ok_or_else(|| { + AiError::Config(format!("tag '{name}' not found")) + })?; Ok(json!({ "name": t.name, "oid": t.oid.as_ref().map(|o| &o.value), @@ -190,24 +244,31 @@ impl FunctionCall for GitTagInfoTool { } } - pub struct GitDeleteTagTool; impl GitDeleteTagTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for GitDeleteTagTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for GitDeleteTagTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "git_delete_tag" } + fn name(&self) -> &'static str { + "git_delete_tag" + } - fn description(&self) -> &'static str { "Delete a tag from the repository. Requires write access." } + fn description(&self) -> &'static str { + "Delete a tag from the repository. Requires write access." + } fn schema(&self) -> Value { json!({ @@ -221,19 +282,26 @@ impl FunctionCall for GitDeleteTagTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; let name = arg_str(&args, "name")?; - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut client = TagServiceClient::new(git.channel.clone()); client .tag_delete(p::TagDeleteRequest { repo_id: repo.id.to_string(), - params: Some(p::TagDeleteParams { name: name.to_string() }), + params: Some(p::TagDeleteParams { + name: name.to_string(), + }), }) .await .map_err(rpc_err)?; diff --git a/lib/service/agent/git_tools/tree.rs b/lib/service/agent/git_tools/tree.rs index a651caf..8e555c1 100644 --- a/lib/service/agent/git_tools/tree.rs +++ b/lib/service/agent/git_tools/tree.rs @@ -5,27 +5,34 @@ use git::rpc::proto as p; use git::rpc::proto::blob_service_client::BlobServiceClient; use git::rpc::proto::commit_service_client::CommitServiceClient; use git::rpc::proto::tree_service_client::TreeServiceClient; -use serde_json::{json, Value}; +use serde_json::{Value, json}; -use super::helpers::{arg_str, arg_opt_str, git_ctx, require_repo_member, rpc_err}; +use super::helpers::{ + arg_opt_str, arg_str, git_ctx, require_repo_member, rpc_err, +}; use crate::agent::run::AppAgentContext; - pub struct GitTreeEntriesTool; impl GitTreeEntriesTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for GitTreeEntriesTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for GitTreeEntriesTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "git_tree_entries" } + fn name(&self) -> &'static str { + "git_tree_entries" + } fn description(&self) -> &'static str { "List files and subdirectories at a given path in a commit's tree. Use this to explore repo structure." @@ -44,29 +51,36 @@ impl FunctionCall for GitTreeEntriesTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; let commit_oid = arg_str(&args, "commit_oid")?; let path = arg_opt_str(&args, "path").unwrap_or(""); - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut commit_client = CommitServiceClient::new(git.channel.clone()); let commit_resp = commit_client .commit_info(p::CommitInfoRequest { repo_id: repo.id.to_string(), - oid: Some(p::ObjectId { value: commit_oid.to_string() }), + oid: Some(p::ObjectId { + value: commit_oid.to_string(), + }), }) .await .map_err(rpc_err)? .into_inner(); - let tree_oid = commit_resp - .commit - .and_then(|c| c.tree_id) - .ok_or_else(|| AiError::Response("commit has no tree".to_string()))?; + let tree_oid = + commit_resp.commit.and_then(|c| c.tree_id).ok_or_else(|| { + AiError::Response("commit has no tree".to_string()) + })?; let mut client = TreeServiceClient::new(git.channel.clone()); let resp = client @@ -80,43 +94,51 @@ impl FunctionCall for GitTreeEntriesTool { .map_err(rpc_err)? .into_inner(); - let entries: Vec = resp.entries.iter().map(|e| { - let kind = match p::TreeKind::try_from(e.kind) { - Ok(p::TreeKind::Blob) => "file", - Ok(p::TreeKind::Tree) => "dir", - Ok(p::TreeKind::LfsPointer) => "lfs", - _ => "unknown", - }; - json!({ - "name": e.name, - "oid": e.oid.as_ref().map(|o| &o.value), - "kind": kind, - "is_binary": e.is_binary, - "is_lfs": e.is_lfs, + let entries: Vec = resp + .entries + .iter() + .map(|e| { + let kind = match p::TreeKind::try_from(e.kind) { + Ok(p::TreeKind::Blob) => "file", + Ok(p::TreeKind::Tree) => "dir", + Ok(p::TreeKind::LfsPointer) => "lfs", + _ => "unknown", + }; + json!({ + "name": e.name, + "oid": e.oid.as_ref().map(|o| &o.value), + "kind": kind, + "is_binary": e.is_binary, + "is_lfs": e.is_lfs, + }) }) - }).collect(); + .collect(); Ok(json!({ "entries": entries, "count": entries.len() })) } } - - pub struct GitFileContentTool; impl GitFileContentTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for GitFileContentTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for GitFileContentTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "git_file_content" } + fn name(&self) -> &'static str { + "git_file_content" + } fn description(&self) -> &'static str { "Read the content of a file at a given path from a specific commit. Returns the file content as text." @@ -135,32 +157,43 @@ impl FunctionCall for GitFileContentTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let repo_name = arg_str(&args, "repo")?; let commit_oid = arg_str(&args, "commit_oid")?; let path = arg_str(&args, "path")?; - let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + let repo = + require_repo_member(git, ctx.user_id, workspace, repo_name).await?; let mut tree_client = TreeServiceClient::new(git.channel.clone()); let entry_resp = tree_client - .tree_entry_by_path_from_commit(p::TreeEntryByPathFromCommitRequest { - repo_id: repo.id.to_string(), - commit_oid: Some(p::ObjectId { value: commit_oid.to_string() }), - path: path.to_string(), - }) + .tree_entry_by_path_from_commit( + p::TreeEntryByPathFromCommitRequest { + repo_id: repo.id.to_string(), + commit_oid: Some(p::ObjectId { + value: commit_oid.to_string(), + }), + path: path.to_string(), + }, + ) .await .map_err(rpc_err)? .into_inner(); - let entry = entry_resp - .entry - .ok_or_else(|| AiError::Config(format!("file not found: {path}")))?; + let entry = entry_resp.entry.ok_or_else(|| { + AiError::Config(format!("file not found: {path}")) + })?; if entry.kind == p::TreeKind::Tree as i32 { - return Err(AiError::Config(format!("'{path}' is a directory, not a file"))); + return Err(AiError::Config(format!( + "'{path}' is a directory, not a file" + ))); } let blob_oid = entry diff --git a/lib/service/agent/issue_tools/helpers.rs b/lib/service/agent/issue_tools/helpers.rs index b6e835d..d98ce8b 100644 --- a/lib/service/agent/issue_tools/helpers.rs +++ b/lib/service/agent/issue_tools/helpers.rs @@ -18,14 +18,17 @@ pub(super) async fn require_workspace_member( user_id: Uuid, workspace_name: &str, ) -> AiResult { - let wk_id: Uuid = sqlx::query_scalar( - "SELECT id FROM workspace WHERE name = $1", - ) - .bind(workspace_name) - .fetch_optional(git.db.reader()) - .await - .map_err(AiError::Database)? - .ok_or_else(|| AiError::Config(format!("workspace '{workspace_name}' not found")))?; + let wk_id: Uuid = + sqlx::query_scalar("SELECT id FROM workspace WHERE name = $1") + .bind(workspace_name) + .fetch_optional(git.db.reader()) + .await + .map_err(AiError::Database)? + .ok_or_else(|| { + AiError::Config(format!( + "workspace '{workspace_name}' not found" + )) + })?; let is_member: i64 = sqlx::query_scalar( "SELECT COUNT(*) FROM wk_member \ @@ -47,13 +50,15 @@ pub(super) async fn require_workspace_member( } pub(super) fn git_ctx(ctx: &AppAgentContext) -> AiResult<&GitAgentContext> { - ctx.git - .as_ref() - .ok_or_else(|| AiError::Config("issue tools are not available in this session".to_string())) + ctx.git.as_ref().ok_or_else(|| { + AiError::Config( + "issue tools are not available in this session".to_string(), + ) + }) } pub(super) fn arg_str<'a>(args: &'a Value, key: &str) -> AiResult<&'a str> { - args.get(key) - .and_then(|v| v.as_str()) - .ok_or_else(|| AiError::Config(format!("'{key}' parameter is required"))) + args.get(key).and_then(|v| v.as_str()).ok_or_else(|| { + AiError::Config(format!("'{key}' parameter is required")) + }) } diff --git a/lib/service/agent/issue_tools/issue.rs b/lib/service/agent/issue_tools/issue.rs index 6cb3769..5d54507 100644 --- a/lib/service/agent/issue_tools/issue.rs +++ b/lib/service/agent/issue_tools/issue.rs @@ -2,7 +2,7 @@ use ai::error::{AiError, AiResult}; use ai::tool::tools::FunctionCall; use async_trait::async_trait; use db::sqlx; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use super::helpers::{arg_str, git_ctx, require_workspace_member}; use crate::agent::run::AppAgentContext; @@ -10,18 +10,24 @@ use crate::agent::run::AppAgentContext; pub struct IssueListTool; impl IssueListTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for IssueListTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for IssueListTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "issue_list" } + fn name(&self) -> &'static str { + "issue_list" + } fn description(&self) -> &'static str { "List issues in a workspace with optional filters: state (open/closed), priority, label, milestone, assignee." @@ -43,40 +49,63 @@ impl FunctionCall for IssueListTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; - let wk_id = require_workspace_member(git, ctx.user_id, workspace).await?; - let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(20).min(100); + let wk_id = + require_workspace_member(git, ctx.user_id, workspace).await?; + let limit = args + .get("limit") + .and_then(|v| v.as_i64()) + .unwrap_or(20) + .min(100); - let mut conditions = vec!["i.wk = $1".to_string(), "i.deleted_at IS NULL".to_string()]; + let mut conditions = + vec!["i.wk = $1".to_string(), "i.deleted_at IS NULL".to_string()]; let mut params: Vec = vec![wk_id.to_string()]; let mut idx = 2i32; - for (arg, col) in [ - ("state", "i.state"), - ("priority", "i.priority"), - ] { - if let Some(v) = args.get(arg).and_then(|v| v.as_str()).filter(|s| !s.is_empty()) { + for (arg, col) in [("state", "i.state"), ("priority", "i.priority")] { + if let Some(v) = args + .get(arg) + .and_then(|v| v.as_str()) + .filter(|s| !s.is_empty()) + { conditions.push(format!("{col} = ${idx}")); params.push(v.to_string()); idx += 1; } } - if let Some(v) = args.get("label").and_then(|v| v.as_str()).filter(|s| !s.is_empty()) { + if let Some(v) = args + .get("label") + .and_then(|v| v.as_str()) + .filter(|s| !s.is_empty()) + { conditions.push(format!("EXISTS(SELECT 1 FROM issue_label il INNER JOIN label l ON l.id = il.label WHERE il.issue = i.id AND l.name = ${idx})")); params.push(v.to_string()); idx += 1; } - if let Some(v) = args.get("milestone").and_then(|v| v.as_str()).filter(|s| !s.is_empty()) { + if let Some(v) = args + .get("milestone") + .and_then(|v| v.as_str()) + .filter(|s| !s.is_empty()) + { conditions.push(format!("EXISTS(SELECT 1 FROM issue_milestone im INNER JOIN milestone m ON m.id = im.milestone WHERE im.issue = i.id AND m.title = ${idx})")); params.push(v.to_string()); idx += 1; } - if let Some(v) = args.get("assignee").and_then(|v| v.as_str()).filter(|s| !s.is_empty()) { + if let Some(v) = args + .get("assignee") + .and_then(|v| v.as_str()) + .filter(|s| !s.is_empty()) + { conditions.push(format!("EXISTS(SELECT 1 FROM issue_assignee ia INNER JOIN \"user\" u ON u.id = ia.\"user\" WHERE ia.issue = i.id AND u.username = ${idx})")); params.push(v.to_string()); idx += 1; @@ -90,14 +119,18 @@ impl FunctionCall for IssueListTool { ORDER BY i.created_at DESC LIMIT ${idx}", ); - let mut q = sqlx::query_as::<_, IssueRow>(db::sqlx::AssertSqlSafe(query)); + let mut q = + sqlx::query_as::<_, IssueRow>(db::sqlx::AssertSqlSafe(query)); q = q.bind(wk_id); for i in 1..params.len() { q = q.bind(¶ms[i]); } q = q.bind(limit); - let rows = q.fetch_all(git.db.reader()).await.map_err(AiError::Database)?; + let rows = q + .fetch_all(git.db.reader()) + .await + .map_err(AiError::Database)?; let issues: Vec = rows.iter().map(|r| json!({ "number": r.number, @@ -127,18 +160,24 @@ struct IssueRow { pub struct IssueGetTool; impl IssueGetTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for IssueGetTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for IssueGetTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "issue_get" } + fn name(&self) -> &'static str { + "issue_get" + } fn description(&self) -> &'static str { "Get full details of a single issue by its number." @@ -155,12 +194,19 @@ impl FunctionCall for IssueGetTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; - let number = args.get("number").and_then(|v| v.as_i64()) - .ok_or_else(|| AiError::Config("'number' parameter is required".to_string()))?; - let wk_id = require_workspace_member(git, ctx.user_id, workspace).await?; + let number = + args.get("number").and_then(|v| v.as_i64()).ok_or_else(|| { + AiError::Config("'number' parameter is required".to_string()) + })?; + let wk_id = + require_workspace_member(git, ctx.user_id, workspace).await?; let row = sqlx::query_as::<_, IssueRow>( "SELECT number, title, body, state, priority, \ @@ -176,7 +222,9 @@ impl FunctionCall for IssueGetTool { // Load labels #[derive(sqlx::FromRow)] - struct LabelRow { name: String } + struct LabelRow { + name: String, + } let labels: Vec = sqlx::query_as::<_, LabelRow>( "SELECT l.name FROM label l \ @@ -189,7 +237,9 @@ impl FunctionCall for IssueGetTool { // Load assignees #[derive(sqlx::FromRow)] - struct AssigneeRow { username: String } + struct AssigneeRow { + username: String, + } let assignees: Vec = sqlx::query_as::<_, AssigneeRow>( "SELECT u.username FROM \"user\" u \ @@ -218,18 +268,24 @@ impl FunctionCall for IssueGetTool { pub struct IssueCommentsTool; impl IssueCommentsTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for IssueCommentsTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for IssueCommentsTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "issue_comments" } + fn name(&self) -> &'static str { + "issue_comments" + } fn description(&self) -> &'static str { "List comments on an issue, ordered by time." @@ -247,13 +303,24 @@ impl FunctionCall for IssueCommentsTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; - let number = args.get("number").and_then(|v| v.as_i64()) - .ok_or_else(|| AiError::Config("'number' parameter is required".to_string()))?; - let wk_id = require_workspace_member(git, ctx.user_id, workspace).await?; - let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(50).min(200); + let number = + args.get("number").and_then(|v| v.as_i64()).ok_or_else(|| { + AiError::Config("'number' parameter is required".to_string()) + })?; + let wk_id = + require_workspace_member(git, ctx.user_id, workspace).await?; + let limit = args + .get("limit") + .and_then(|v| v.as_i64()) + .unwrap_or(50) + .min(200); #[derive(sqlx::FromRow)] struct CommentRow { @@ -273,31 +340,44 @@ impl FunctionCall for IssueCommentsTool { .bind(wk_id).bind(number).bind(limit) .fetch_all(git.db.reader()).await.map_err(AiError::Database)?; - let comments: Vec = rows.iter().map(|r| json!({ - "author": r.username, - "body": r.body, - "created_at": r.created_at.to_rfc3339(), - })).collect(); + let comments: Vec = rows + .iter() + .map(|r| { + json!({ + "author": r.username, + "body": r.body, + "created_at": r.created_at.to_rfc3339(), + }) + }) + .collect(); - Ok(json!({ "issue_number": number, "comments": comments, "count": comments.len() })) + Ok( + json!({ "issue_number": number, "comments": comments, "count": comments.len() }), + ) } } pub struct IssueEventsTool; impl IssueEventsTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for IssueEventsTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for IssueEventsTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "issue_events" } + fn name(&self) -> &'static str { + "issue_events" + } fn description(&self) -> &'static str { "List the timeline of events for an issue (created, commented, closed, labeled, etc)." @@ -314,12 +394,19 @@ impl FunctionCall for IssueEventsTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; - let number = args.get("number").and_then(|v| v.as_i64()) - .ok_or_else(|| AiError::Config("'number' parameter is required".to_string()))?; - let wk_id = require_workspace_member(git, ctx.user_id, workspace).await?; + let number = + args.get("number").and_then(|v| v.as_i64()).ok_or_else(|| { + AiError::Config("'number' parameter is required".to_string()) + })?; + let wk_id = + require_workspace_member(git, ctx.user_id, workspace).await?; #[derive(sqlx::FromRow)] struct EventRow { @@ -340,14 +427,21 @@ impl FunctionCall for IssueEventsTool { .bind(wk_id).bind(number) .fetch_all(git.db.reader()).await.map_err(AiError::Database)?; - let events: Vec = rows.iter().map(|r| json!({ - "event": r.event, - "actor": r.username, - "from": r.from_value, - "to": r.to_value, - "created_at": r.created_at.to_rfc3339(), - })).collect(); + let events: Vec = rows + .iter() + .map(|r| { + json!({ + "event": r.event, + "actor": r.username, + "from": r.from_value, + "to": r.to_value, + "created_at": r.created_at.to_rfc3339(), + }) + }) + .collect(); - Ok(json!({ "issue_number": number, "events": events, "count": events.len() })) + Ok( + json!({ "issue_number": number, "events": events, "count": events.len() }), + ) } } diff --git a/lib/service/agent/memory.rs b/lib/service/agent/memory.rs index 91da680..8f76950 100644 --- a/lib/service/agent/memory.rs +++ b/lib/service/agent/memory.rs @@ -5,13 +5,13 @@ use ai::{ use async_trait::async_trait; use chrono::Utc; use db::sqlx; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use tracing::info; use uuid::Uuid; use super::run::AppAgentContext; -use crate::error::AppError; use crate::AppService; +use crate::error::AppError; pub struct SaveMemoryTool; impl SaveMemoryTool { @@ -62,15 +62,15 @@ impl FunctionCall for SaveMemoryTool { context: &mut Self::Context, args: Value, ) -> AiResult { - let key = args - .get("key") - .and_then(|v| v.as_str()) - .ok_or_else(|| AiError::Config("key parameter is required".to_string()))?; + let key = + args.get("key").and_then(|v| v.as_str()).ok_or_else(|| { + AiError::Config("key parameter is required".to_string()) + })?; - let value = args - .get("value") - .and_then(|v| v.as_str()) - .ok_or_else(|| AiError::Config("value parameter is required".to_string()))?; + let value = + args.get("value").and_then(|v| v.as_str()).ok_or_else(|| { + AiError::Config("value parameter is required".to_string()) + })?; let importance = args .get("importance") @@ -126,10 +126,7 @@ impl FunctionCall for RecallMemoryTool { _context: &mut Self::Context, args: Value, ) -> AiResult { - let query = args - .get("query") - .and_then(|v| v.as_str()) - .unwrap_or(""); + let query = args.get("query").and_then(|v| v.as_str()).unwrap_or(""); if query.is_empty() { return Ok(json!({ @@ -153,7 +150,7 @@ pub struct PendingMemory { } impl AppService { - pub(crate) async fn agent_load_memories( + pub async fn agent_load_memories( &self, session_id: Uuid, ) -> Result<(String, Vec<(Uuid, String, String, i32)>), AppError> { @@ -173,7 +170,8 @@ impl AppService { return Ok((String::new(), rows)); } - let mut formatted = String::from("Long-term memories for this session:\n"); + let mut formatted = + String::from("Long-term memories for this session:\n"); for (_, key, value, importance) in &rows { formatted.push_str(&format!( "- [{}] {} (importance: {})\n", @@ -183,7 +181,7 @@ impl AppService { Ok((formatted, rows)) } - pub(crate) async fn agent_persist_memories( + pub async fn agent_persist_memories( &self, session_id: Uuid, memories: &[PendingMemory], @@ -221,7 +219,7 @@ impl AppService { Ok(()) } #[allow(dead_code)] - pub(crate) async fn agent_touch_memories( + pub async fn agent_touch_memories( &self, memory_ids: &[Uuid], ) -> Result<(), AppError> { diff --git a/lib/service/agent/persistence.rs b/lib/service/agent/persistence.rs index 8429b40..1295a0f 100644 --- a/lib/service/agent/persistence.rs +++ b/lib/service/agent/persistence.rs @@ -2,9 +2,12 @@ use chrono::Utc; use db::sqlx; use uuid::Uuid; -use super::types::{AgentCostInfo, AgentStepInfo, AgentToolCallInfo, BillingRecord, SessionContext}; -use crate::error::AppError; +use super::types::{ + AgentCostInfo, AgentStepInfo, AgentToolCallInfo, BillingRecord, + SessionContext, +}; use crate::AppService; +use crate::error::AppError; impl AppService { pub(super) async fn persist_user_message( @@ -67,7 +70,11 @@ impl AppService { output_tokens: i64, ) -> Result, AppError> { let cost_result = self - .agent_calculate_cost(ctx.model_version_id, input_tokens, output_tokens) + .agent_calculate_cost( + ctx.model_version_id, + input_tokens, + output_tokens, + ) .await?; let (cost, currency) = match cost_result { @@ -160,7 +167,9 @@ impl AppService { } #[allow(dead_code)] -pub(super) fn step_info_from_agent(step: ai::agent::AgentStep) -> AgentStepInfo { +pub(super) fn step_info_from_agent( + step: ai::agent::AgentStep, +) -> AgentStepInfo { AgentStepInfo { index: step.index, assistant: step.assistant, @@ -174,7 +183,9 @@ pub(super) fn step_info_from_agent(step: ai::agent::AgentStep) -> AgentStepInfo } #[allow(dead_code)] -pub(super) fn tool_call_info_from_record(record: ai::agent::ToolCallRecord) -> AgentToolCallInfo { +pub(super) fn tool_call_info_from_record( + record: ai::agent::ToolCallRecord, +) -> AgentToolCallInfo { AgentToolCallInfo { id: record.id, name: record.name, diff --git a/lib/service/agent/run.rs b/lib/service/agent/run.rs index 7edca60..e6dd7c5 100644 --- a/lib/service/agent/run.rs +++ b/lib/service/agent/run.rs @@ -1,21 +1,16 @@ use std::sync::Arc; use std::time::Duration; -use ai::{ - agent::RigAgent, - tool::register::ToolRegister, -}; +use ai::{agent::RigAgent, tool::register::ToolRegister}; use cache::AppCache; use db::AppDatabase; use tonic::transport::Channel; use tracing::{info, warn}; use uuid::Uuid; -use super::types::{ - AgentRunRequest, AgentRunResponse, AgentUsageInfo, -}; -use crate::error::AppError; +use super::types::{AgentRunRequest, AgentRunResponse, AgentUsageInfo}; use crate::AppService; +use crate::error::AppError; #[derive(Clone)] pub struct GitAgentContext { @@ -42,9 +37,9 @@ impl AppService { ) -> Result { let ctx = self.agent_session_context(req.session_id, user_id).await?; - let conversation_id = req - .conversation_id - .ok_or_else(|| AppError::BadRequest("conversation_id is required".to_string()))?; + let conversation_id = req.conversation_id.ok_or_else(|| { + AppError::BadRequest("conversation_id is required".to_string()) + })?; let conversation = self .agent_require_conversation_access(user_id, conversation_id) .await?; @@ -54,23 +49,25 @@ impl AppService { )); } - let ai_client = self - .agent_build_ai_client(ctx.model_version_id) - .await?; + let ai_client = + self.agent_build_ai_client(ctx.model_version_id).await?; - let agent_config = self.agent_build_config( - &ctx, - req.max_steps, - ); + let agent_config = self.agent_build_config(&ctx, req.max_steps); - self.agent_maybe_compact(&ai_client, &ctx.provider_model_name, conversation_id) - .await - .unwrap_or_else(|e| { - warn!(error = %e, "compaction check failed, continuing"); - }); + self.agent_maybe_compact( + &ai_client, + &ctx.provider_model_name, + conversation_id, + ) + .await + .unwrap_or_else(|e| { + warn!(error = %e, "compaction check failed, continuing"); + }); let mut tools: ToolRegister = ToolRegister::new(); - if conversation.title == "New Chat" || conversation.title.trim().is_empty() { + if conversation.title == "New Chat" + || conversation.title.trim().is_empty() + { tools.register(super::tools::SetTitleTool::new()); } tools.register(super::memory::SaveMemoryTool::new()); @@ -99,7 +96,8 @@ impl AppService { }; let shared_ctx = Arc::new(tokio::sync::Mutex::new(agent_ctx)); - let mut tool_set = ai::agent::RigToolSet::from_register(&tools, shared_ctx); + let mut tool_set = + ai::agent::RigToolSet::from_register(&tools, shared_ctx); let rig_tools = tool_set.take_tools(); let agent = RigAgent::new(ai_client.clone(), agent_config) @@ -156,8 +154,8 @@ impl AppService { ) .await?; - let cost_info = - self.persist_billing_and_deduct( + let cost_info = self + .persist_billing_and_deduct( &ctx, invocation_id, 0, // input_tokens not tracked in chat() mode @@ -176,15 +174,16 @@ impl AppService { ) .await?; - self.update_conversation_timestamp(conversation_id) - .await?; + self.update_conversation_timestamp(conversation_id).await?; let title = agent_ctx .pending_title .filter(|t| !t.trim().is_empty()) .or_else(|| { // Only auto-set title from input when still default. - if conversation.title == "New Chat" || conversation.title.trim().is_empty() { + if conversation.title == "New Chat" + || conversation.title.trim().is_empty() + { let first_line = req.input.lines().next().unwrap_or(&req.input); let truncated: String = @@ -203,10 +202,7 @@ impl AppService { if let Some(new_title) = &title { if let Err(e) = self - .update_conversation_title( - conversation_id, - new_title, - ) + .update_conversation_title(conversation_id, new_title) .await { warn!( diff --git a/lib/service/agent/session.rs b/lib/service/agent/session.rs index f1e15fb..3e4261b 100644 --- a/lib/service/agent/session.rs +++ b/lib/service/agent/session.rs @@ -5,8 +5,8 @@ use serde::{Deserialize, Serialize}; use utoipa::ToSchema; use uuid::Uuid; -use crate::error::AppError; use crate::AppService; +use crate::error::AppError; #[derive(Debug, Clone, Deserialize, ToSchema)] pub struct CreateAgentSession { @@ -85,10 +85,8 @@ impl AppService { params: CreateAgentSession, ) -> Result { let wk_uuid: Option = if let Some(ref wk_name) = params.wk { - let wk = crate::AppService::workspace_resolve( - &*self, wk_name, - ) - .await?; + let wk = + crate::AppService::workspace_resolve(&*self, wk_name).await?; let _ = crate::AppService::workspace_require_member( &*self, wk.id, user_id, ) @@ -100,9 +98,13 @@ impl AppService { let id = Uuid::now_v7(); let now = Utc::now(); - let visibility = params.visibility.unwrap_or_else(|| "private".to_string()); + let visibility = + params.visibility.unwrap_or_else(|| "private".to_string()); let kb_ids = params.knowledge_base_ids.map(|ids| { - ids.iter().map(|id| id.to_string()).collect::>().join(",") + ids.iter() + .map(|id| id.to_string()) + .collect::>() + .join(",") }); let row = sqlx::query_as::<_, AgentSessionModel>( @@ -242,18 +244,28 @@ impl AppService { let description = params.description.or(existing.description); let system_prompt = params.system_prompt.or(existing.system_prompt); let temperature = params.temperature.or(existing.temperature); - let max_output_tokens = params.max_output_tokens.or(existing.max_output_tokens); + let max_output_tokens = + params.max_output_tokens.or(existing.max_output_tokens); let model_version = params.model_version.or(existing.model_version); let tool_policy = params.tool_policy.or(existing.tool_policy); let toolset_json = params.toolset_json.or(existing.toolset_json); - let memory_provider = params.memory_provider.or(existing.memory_provider); - let memory_provider_config = params.memory_provider_config.or(existing.memory_provider_config); - let iteration_budget = params.iteration_budget.or(existing.iteration_budget); + let memory_provider = + params.memory_provider.or(existing.memory_provider); + let memory_provider_config = params + .memory_provider_config + .or(existing.memory_provider_config); + let iteration_budget = + params.iteration_budget.or(existing.iteration_budget); let visibility = params.visibility.unwrap_or(existing.visibility); let enabled = params.enabled.unwrap_or(existing.enabled); let kb_ids = params .knowledge_base_ids - .map(|ids| ids.iter().map(|id| id.to_string()).collect::>().join(",")) + .map(|ids| { + ids.iter() + .map(|id| id.to_string()) + .collect::>() + .join(",") + }) .or(existing.knowledge_base_ids); let variables = params.variables.or(existing.variables); @@ -427,7 +439,9 @@ impl AppService { current.insert( "disabled".to_string(), serde_json::Value::Array( - dis.into_iter().map(serde_json::Value::String).collect(), + dis.into_iter() + .map(serde_json::Value::String) + .collect(), ), ); } diff --git a/lib/service/agent/sse.rs b/lib/service/agent/sse.rs index e7bcfcf..530c2ce 100644 --- a/lib/service/agent/sse.rs +++ b/lib/service/agent/sse.rs @@ -2,15 +2,15 @@ use std::sync::Arc; use ai::agent::{RigAgent, RigStreamChunk, RigToolSet}; use ai::tool::register::ToolRegister; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use tokio::sync::mpsc; use tracing::{error, info, warn}; use uuid::Uuid; use super::run::AppAgentContext; use super::types::AgentRunRequest; -use crate::error::AppError; use crate::AppService; +use crate::error::AppError; impl AppService { pub async fn agent_run_streaming( @@ -19,9 +19,9 @@ impl AppService { req: AgentRunRequest, ) -> Result, AppError> { let ctx = self.agent_session_context(req.session_id, user_id).await?; - let conversation_id = req - .conversation_id - .ok_or_else(|| AppError::BadRequest("conversation_id is required".to_string()))?; + let conversation_id = req.conversation_id.ok_or_else(|| { + AppError::BadRequest("conversation_id is required".to_string()) + })?; let conversation = self .agent_require_conversation_access(user_id, conversation_id) .await?; @@ -31,17 +31,24 @@ impl AppService { )); } - let ai_client = self.agent_build_ai_client(ctx.model_version_id).await?; + let ai_client = + self.agent_build_ai_client(ctx.model_version_id).await?; let agent_config = self.agent_build_config(&ctx, req.max_steps); - self.agent_maybe_compact(&ai_client, &ctx.provider_model_name, conversation_id) - .await - .unwrap_or_else(|e| { - warn!(error = %e, "compaction check failed, continuing"); - }); + self.agent_maybe_compact( + &ai_client, + &ctx.provider_model_name, + conversation_id, + ) + .await + .unwrap_or_else(|e| { + warn!(error = %e, "compaction check failed, continuing"); + }); let mut tools: ToolRegister = ToolRegister::new(); - if conversation.title == "New Chat" || conversation.title.trim().is_empty() { + if conversation.title == "New Chat" + || conversation.title.trim().is_empty() + { tools.register(super::tools::SetTitleTool::new()); } tools.register(super::memory::SaveMemoryTool::new()); @@ -84,10 +91,14 @@ impl AppService { "agent sse stream starting" ); - if let Err(e) = self.cache.set::( - &format!("agent:stream:active:{}", conversation_id), - &invocation_id, - ).await { + if let Err(e) = self + .cache + .set::( + &format!("agent:stream:active:{}", conversation_id), + &invocation_id, + ) + .await + { warn!(error = %e, "agent sse: failed to mark stream active"); } @@ -97,8 +108,11 @@ impl AppService { { Ok(id) => Some(id), Err(e) => { - let _ = tx.send(super::persistence::stream_error("failed to persist user message")); - let _ = self.cache + let _ = tx.send(super::persistence::stream_error( + "failed to persist user message", + )); + let _ = self + .cache .remove(&format!("agent:stream:active:{}", conversation_id)) .await; return Err(e); @@ -128,11 +142,14 @@ impl AppService { tokio::spawn(async move { let mut tracer = super::trace::TraceAccumulator::new( - trace_svc, invocation_id, conversation_id, + trace_svc, + invocation_id, + conversation_id, ); let mut phase: &str = "think"; while let Some(chunk) = chunk_rx.recv().await { - let (new_phase, sse_event) = process_chunk_with_phase(&chunk, phase, &mut tracer).await; + let (new_phase, sse_event) = + process_chunk_with_phase(&chunk, phase, &mut tracer).await; if new_phase != phase { phase = new_phase; let _ = tx.send(phase_sse(phase)); @@ -149,10 +166,13 @@ impl AppService { { Ok(Ok(inner)) => inner, Ok(Err(e)) => Err(ai::error::AiError::Response(e.to_string())), - Err(_) => Err(ai::error::AiError::Timeout { seconds: timeout_secs }), + Err(_) => Err(ai::error::AiError::Timeout { + seconds: timeout_secs, + }), }; - let _ = self_clone.cache + let _ = self_clone + .cache .remove(&format!("agent:stream:active:{}", conversation_id)) .await; @@ -166,7 +186,11 @@ impl AppService { .iter() .filter_map(|step| step.reasoning_content.clone()) .collect(); - if collected.is_empty() { None } else { Some(collected.join("\n\n")) } + if collected.is_empty() { + None + } else { + Some(collected.join("\n\n")) + } }; match self_clone @@ -184,85 +208,159 @@ impl AppService { for tc in &step.tool_calls { let _ = self_clone .agent_record_tool_call( - invocation_id, ctx_clone.session_id, - Some(conversation_id), Some(msg_id), - &tc.id, &tc.name, + invocation_id, + ctx_clone.session_id, + Some(conversation_id), + Some(msg_id), + &tc.id, + &tc.name, Some(&tc.arguments.to_string()), - tc.output.as_ref().map(|v| v.to_string()).as_deref(), + tc.output + .as_ref() + .map(|v| v.to_string()) + .as_deref(), tc.error.as_deref(), - if tc.error.is_some() { "error" } else { "success" }, + if tc.error.is_some() { + "error" + } else { + "success" + }, tc.elapsed_ms, ) .await; } } - let _ = self_clone.persist_billing_and_deduct( - &ctx_clone, invocation_id, - result.input_tokens, result.output_tokens, - ).await; + let _ = self_clone + .persist_billing_and_deduct( + &ctx_clone, + invocation_id, + result.input_tokens, + result.output_tokens, + ) + .await; - let _ = self_clone.agent_record_invocation( - invocation_id, ctx_clone.session_id, - Some(conversation_id), Some(msg_id), - ctx_clone.model_version_id, "completed", None, - ).await; + let _ = self_clone + .agent_record_invocation( + invocation_id, + ctx_clone.session_id, + Some(conversation_id), + Some(msg_id), + ctx_clone.model_version_id, + "completed", + None, + ) + .await; - let _ = self_clone.update_conversation_timestamp(conversation_id).await; + let _ = self_clone + .update_conversation_timestamp(conversation_id) + .await; - let title = agent_ctx.pending_title + let title = agent_ctx + .pending_title .filter(|t| !t.trim().is_empty()) .or_else(|| { // Only auto-set title from input when still default. - if conversation.title == "New Chat" || conversation.title.trim().is_empty() { - let first_line = first_input.lines().next().unwrap_or(&first_input); - let truncated: String = first_line.chars().take(50).collect(); - if truncated.trim().is_empty() { None } - else { Some(if first_line.len() > 50 { format!("{}…", truncated.trim_end()) } else { truncated.trim().to_string() }) } + if conversation.title == "New Chat" + || conversation.title.trim().is_empty() + { + let first_line = first_input + .lines() + .next() + .unwrap_or(&first_input); + let truncated: String = first_line + .chars() + .take(50) + .collect(); + if truncated.trim().is_empty() { + None + } else { + Some(if first_line.len() > 50 { + format!( + "{}…", + truncated.trim_end() + ) + } else { + truncated.trim().to_string() + }) + } } else { None } }); if let Some(new_title) = &title { - if self_clone.update_conversation_title(conversation_id, new_title).await.is_ok() { + if self_clone + .update_conversation_title( + conversation_id, + new_title, + ) + .await + .is_ok() + { let title_event = serde_json::json!({ "type": "title_updated", "conversation_id": conversation_id.to_string(), "title": new_title, }); - let _ = tx.send(format!("data: {}\n\n", title_event)); + let _ = tx.send(format!( + "data: {}\n\n", + title_event + )); } } if !agent_ctx.pending_memories.is_empty() { - let _ = self_clone.agent_persist_memories( - ctx_clone.session_id, &agent_ctx.pending_memories, - ).await; + let _ = self_clone + .agent_persist_memories( + ctx_clone.session_id, + &agent_ctx.pending_memories, + ) + .await; } - let _ = tx.send(done_sse_with_phase(msg_id, &result.output, "summarize")); + let _ = tx.send(done_sse_with_phase( + msg_id, + &result.output, + "summarize", + )); info!(invocation_id = %invocation_id, message_id = %msg_id, "agent sse stream completed"); } Err(e) => { error!(error = %e, "sse: failed to persist assistant message"); - let _ = tx.send(super::persistence::stream_error("persistence failed")); + let _ = tx.send(super::persistence::stream_error( + "persistence failed", + )); } } } Err(e) => { warn!(invocation_id = %invocation_id, error = %e, "agent sse stream failed"); - let _ = tx.send(super::persistence::stream_error(&e.to_string())); + let _ = tx + .send(super::persistence::stream_error(&e.to_string())); let error_content = format!( - "I encountered an error while processing your request: {}", e + "I encountered an error while processing your request: {}", + e ); - let _ = self_clone.persist_assistant_message( - conversation_id, ctx_clone.session_id, &error_content, None, invocation_id, - ).await; - let _ = self_clone.agent_record_invocation( - invocation_id, ctx_clone.session_id, - Some(conversation_id), user_message_id, - ctx_clone.model_version_id, "failed", Some(&e.to_string()), - ).await; + let _ = self_clone + .persist_assistant_message( + conversation_id, + ctx_clone.session_id, + &error_content, + None, + invocation_id, + ) + .await; + let _ = self_clone + .agent_record_invocation( + invocation_id, + ctx_clone.session_id, + Some(conversation_id), + user_message_id, + ctx_clone.model_version_id, + "failed", + Some(&e.to_string()), + ) + .await; } } }); @@ -284,27 +382,59 @@ async fn process_chunk_with_phase( tracer.feed_text(content).await; ("answer", Some(format_chunk_sse(chunk))) } - RigStreamChunk::ToolCallStarted { tool_call_id, tool_name, arguments } => { - let args_val: Value = serde_json::from_str(arguments).unwrap_or(Value::Null); - tracer.feed_tool_call(tool_call_id, tool_name, &args_val).await; + RigStreamChunk::ToolCallStarted { + tool_call_id, + tool_name, + arguments, + } => { + let args_val: Value = + serde_json::from_str(arguments).unwrap_or(Value::Null); + tracer + .feed_tool_call(tool_call_id, tool_name, &args_val) + .await; ("act", Some(format_chunk_sse(chunk))) } - RigStreamChunk::ToolCallFinished { tool_call_id, tool_name, output, error } => { + RigStreamChunk::ToolCallFinished { + tool_call_id, + tool_name, + output, + error, + } => { let out_val: Option = match output { o if o.is_empty() => None, o => serde_json::from_str(o).ok(), }; - tracer.feed_tool_result(tool_call_id, tool_name, out_val.as_ref(), error.as_deref(), 0).await; + tracer + .feed_tool_result( + tool_call_id, + tool_name, + out_val.as_ref(), + error.as_deref(), + 0, + ) + .await; ("act", Some(format_chunk_sse(chunk))) } - RigStreamChunk::Final { content, input_tokens, output_tokens } => { - tracer.finish(content, *input_tokens as i64, *output_tokens as i64).await; + RigStreamChunk::Final { + content, + input_tokens, + output_tokens, + } => { + tracer + .finish(content, *input_tokens as i64, *output_tokens as i64) + .await; ("summarize", Some(format_chunk_sse(chunk))) } RigStreamChunk::Failed { .. } => ("summarize", None), - RigStreamChunk::SubagentStarted { .. } => ("act", Some(format_chunk_sse(chunk))), - RigStreamChunk::SubagentCompleted { .. } => ("act", Some(format_chunk_sse(chunk))), - RigStreamChunk::SubagentFailed { .. } => ("summarize", Some(format_chunk_sse(chunk))), + RigStreamChunk::SubagentStarted { .. } => { + ("act", Some(format_chunk_sse(chunk))) + } + RigStreamChunk::SubagentCompleted { .. } => { + ("act", Some(format_chunk_sse(chunk))) + } + RigStreamChunk::SubagentFailed { .. } => { + ("summarize", Some(format_chunk_sse(chunk))) + } } } @@ -316,26 +446,44 @@ fn format_chunk_sse(chunk: &RigStreamChunk) -> String { RigStreamChunk::Thinking { index, content } => json!({ "type": "thinking", "index": index, "content": content, }), - RigStreamChunk::ToolCallStarted { tool_call_id, tool_name, arguments } => json!({ + RigStreamChunk::ToolCallStarted { + tool_call_id, + tool_name, + arguments, + } => json!({ "type": "tool_call_started", "tool_call_id": tool_call_id, "tool_name": tool_name, "arguments": arguments, }), - RigStreamChunk::ToolCallFinished { tool_call_id, tool_name, output, error } => json!({ + RigStreamChunk::ToolCallFinished { + tool_call_id, + tool_name, + output, + error, + } => json!({ "type": "tool_call_finished", "tool_call_id": tool_call_id, "tool_name": tool_name, "output": output, "error": error, }), - RigStreamChunk::SubagentStarted { subagent_id, role, task } => json!({ + RigStreamChunk::SubagentStarted { + subagent_id, + role, + task, + } => json!({ "type": "subagent_started", "subagent_id": subagent_id, "role": role, "task": task, }), - RigStreamChunk::SubagentCompleted { subagent_id, role, task, output } => json!({ + RigStreamChunk::SubagentCompleted { + subagent_id, + role, + task, + output, + } => json!({ "type": "subagent_completed", "subagent_id": subagent_id, "role": role, @@ -346,7 +494,9 @@ fn format_chunk_sse(chunk: &RigStreamChunk) -> String { "type": "subagent_failed", "error": error, }), - RigStreamChunk::Final { .. } | RigStreamChunk::Failed { .. } => return String::new(), + RigStreamChunk::Final { .. } | RigStreamChunk::Failed { .. } => { + return String::new(); + } }; format!("data: {}\n\n", payload) } diff --git a/lib/service/agent/tools.rs b/lib/service/agent/tools.rs index 6aa4668..e1f50ca 100644 --- a/lib/service/agent/tools.rs +++ b/lib/service/agent/tools.rs @@ -1,7 +1,7 @@ use ai::error::{AiError, AiResult}; use ai::tool::tools::FunctionCall; use async_trait::async_trait; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use super::run::AppAgentContext; pub struct SetTitleTool; @@ -44,10 +44,10 @@ impl FunctionCall for SetTitleTool { context: &mut Self::Context, args: Value, ) -> AiResult { - let title = args - .get("title") - .and_then(|v| v.as_str()) - .ok_or_else(|| AiError::Config("title parameter is required".to_string()))?; + let title = + args.get("title").and_then(|v| v.as_str()).ok_or_else(|| { + AiError::Config("title parameter is required".to_string()) + })?; let title = title.trim(); if title.is_empty() { diff --git a/lib/service/agent/trace.rs b/lib/service/agent/trace.rs index 975acd1..45c8733 100644 --- a/lib/service/agent/trace.rs +++ b/lib/service/agent/trace.rs @@ -1,11 +1,11 @@ use chrono::Utc; use db::sqlx; use model::agent::AgentTraceModel; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use uuid::Uuid; -use crate::error::AppError; use crate::AppService; +use crate::error::AppError; pub struct TraceContext { pub invocation_id: Uuid, @@ -88,7 +88,8 @@ impl AppService { .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; - let conversation_id = rows.first().map(|r| r.conversation).unwrap_or(Uuid::nil()); + let conversation_id = + rows.first().map(|r| r.conversation).unwrap_or(Uuid::nil()); Ok(TraceReplay { invocation_id, @@ -126,8 +127,10 @@ impl AppService { .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; - let mut grouped: std::collections::BTreeMap> = - std::collections::BTreeMap::new(); + let mut grouped: std::collections::BTreeMap< + Uuid, + Vec, + > = std::collections::BTreeMap::new(); for row in rows { grouped.entry(row.invocation).or_default().push(row); } @@ -169,9 +172,16 @@ pub struct TraceAccumulator { } impl TraceAccumulator { - pub fn new(svc: AppService, invocation_id: Uuid, conversation_id: Uuid) -> Self { + pub fn new( + svc: AppService, + invocation_id: Uuid, + conversation_id: Uuid, + ) -> Self { Self { - ctx: TraceContext { invocation_id, conversation_id }, + ctx: TraceContext { + invocation_id, + conversation_id, + }, seq: 0, think_buf: String::new(), answer_buf: String::new(), @@ -191,10 +201,16 @@ impl TraceAccumulator { self.flush_think().await; } self.answer_buf.push_str(chunk); - self.answer_tokens += (chunk.chars().count() as f64 / 2.5).ceil() as i64; + self.answer_tokens += + (chunk.chars().count() as f64 / 2.5).ceil() as i64; } - pub async fn feed_tool_call(&mut self, tool_call_id: &str, tool_name: &str, args: &Value) { + pub async fn feed_tool_call( + &mut self, + tool_call_id: &str, + tool_name: &str, + args: &Value, + ) { if !self.answer_buf.is_empty() { self.flush_answer().await; } @@ -208,49 +224,83 @@ impl TraceAccumulator { self.seq += 1; } - pub async fn feed_tool_result(&mut self, tool_call_id: &str, tool_name: &str, - output: Option<&Value>, error: Option<&str>, elapsed_ms: i64) { - let _ = self.svc.trace_record( - &self.ctx, self.seq, "act", - None, - None, - Some(&json!({ - "tool_call_id": tool_call_id, - "name": tool_name, - "output": output, - "error": error, - "elapsed_ms": elapsed_ms, - })), - None, None, None, - ).await; + pub async fn feed_tool_result( + &mut self, + tool_call_id: &str, + tool_name: &str, + output: Option<&Value>, + error: Option<&str>, + elapsed_ms: i64, + ) { + let _ = self + .svc + .trace_record( + &self.ctx, + self.seq, + "act", + None, + None, + Some(&json!({ + "tool_call_id": tool_call_id, + "name": tool_name, + "output": output, + "error": error, + "elapsed_ms": elapsed_ms, + })), + None, + None, + None, + ) + .await; self.seq += 1; } - pub async fn finish(&mut self, output: &str, input_tokens: i64, output_tokens: i64) { + pub async fn finish( + &mut self, + output: &str, + input_tokens: i64, + output_tokens: i64, + ) { if !self.think_buf.is_empty() { self.flush_think().await; } if !self.answer_buf.is_empty() { self.flush_answer().await; } - let _ = self.svc.trace_record( - &self.ctx, self.seq, "summarize", - Some(output), - None, None, - Some(input_tokens), Some(output_tokens), - None, - ).await; + let _ = self + .svc + .trace_record( + &self.ctx, + self.seq, + "summarize", + Some(output), + None, + None, + Some(input_tokens), + Some(output_tokens), + None, + ) + .await; } async fn flush_think(&mut self) { let content = std::mem::take(&mut self.think_buf); let tokens = self.think_tokens; self.think_tokens = 0; - let _ = self.svc.trace_record( - &self.ctx, self.seq, "think", - Some(&content), None, None, - Some(tokens), None, None, - ).await; + let _ = self + .svc + .trace_record( + &self.ctx, + self.seq, + "think", + Some(&content), + None, + None, + Some(tokens), + None, + None, + ) + .await; self.seq += 1; } @@ -258,11 +308,20 @@ impl TraceAccumulator { let content = std::mem::take(&mut self.answer_buf); let tokens = self.answer_tokens; self.answer_tokens = 0; - let _ = self.svc.trace_record( - &self.ctx, self.seq, "answer", - Some(&content), None, None, - None, Some(tokens), None, - ).await; + let _ = self + .svc + .trace_record( + &self.ctx, + self.seq, + "answer", + Some(&content), + None, + None, + None, + Some(tokens), + None, + ) + .await; self.seq += 1; } } diff --git a/lib/service/agent/types.rs b/lib/service/agent/types.rs index ff29692..e61d58e 100644 --- a/lib/service/agent/types.rs +++ b/lib/service/agent/types.rs @@ -56,7 +56,7 @@ pub struct AgentCostInfo { pub currency: String, } #[derive(Debug, Clone)] -pub(crate) struct SessionContext { +pub struct SessionContext { pub session_id: Uuid, pub user_id: Option, pub workspace_id: Option, @@ -77,7 +77,7 @@ pub(crate) struct SessionContext { pub billing_target: BillingTarget, } #[derive(Debug, Clone)] -pub(crate) struct BillingRecord { +pub struct BillingRecord { pub invocation_id: Uuid, pub session_id: Uuid, pub model_version_id: Uuid, @@ -94,7 +94,7 @@ pub(crate) struct BillingRecord { } #[derive(Debug, Clone)] #[allow(dead_code)] -pub(crate) struct RunPersistState { +pub struct RunPersistState { pub message_id: Uuid, pub conversation_id: Uuid, pub invocation_id: Uuid, diff --git a/lib/service/agent/workspace_tools/helpers.rs b/lib/service/agent/workspace_tools/helpers.rs index 16a5691..20146c2 100644 --- a/lib/service/agent/workspace_tools/helpers.rs +++ b/lib/service/agent/workspace_tools/helpers.rs @@ -18,14 +18,17 @@ pub(super) async fn require_workspace_member( user_id: Uuid, workspace_name: &str, ) -> AiResult { - let wk_id: Uuid = sqlx::query_scalar( - "SELECT id FROM workspace WHERE name = $1", - ) - .bind(workspace_name) - .fetch_optional(git.db.reader()) - .await - .map_err(AiError::Database)? - .ok_or_else(|| AiError::Config(format!("workspace '{workspace_name}' not found")))?; + let wk_id: Uuid = + sqlx::query_scalar("SELECT id FROM workspace WHERE name = $1") + .bind(workspace_name) + .fetch_optional(git.db.reader()) + .await + .map_err(AiError::Database)? + .ok_or_else(|| { + AiError::Config(format!( + "workspace '{workspace_name}' not found" + )) + })?; let is_member: i64 = sqlx::query_scalar( "SELECT COUNT(*) FROM wk_member \ @@ -47,13 +50,15 @@ pub(super) async fn require_workspace_member( } pub(super) fn git_ctx(ctx: &AppAgentContext) -> AiResult<&GitAgentContext> { - ctx.git - .as_ref() - .ok_or_else(|| AiError::Config("workspace tools are not available in this session".to_string())) + ctx.git.as_ref().ok_or_else(|| { + AiError::Config( + "workspace tools are not available in this session".to_string(), + ) + }) } pub(super) fn arg_str<'a>(args: &'a Value, key: &str) -> AiResult<&'a str> { - args.get(key) - .and_then(|v| v.as_str()) - .ok_or_else(|| AiError::Config(format!("'{key}' parameter is required"))) + args.get(key).and_then(|v| v.as_str()).ok_or_else(|| { + AiError::Config(format!("'{key}' parameter is required")) + }) } diff --git a/lib/service/agent/workspace_tools/workspace.rs b/lib/service/agent/workspace_tools/workspace.rs index 9cd5ac6..76581f6 100644 --- a/lib/service/agent/workspace_tools/workspace.rs +++ b/lib/service/agent/workspace_tools/workspace.rs @@ -2,7 +2,7 @@ use ai::error::{AiError, AiResult}; use ai::tool::tools::FunctionCall; use async_trait::async_trait; use db::sqlx; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use uuid::Uuid; use super::helpers::{arg_str, git_ctx, require_workspace_member}; @@ -11,18 +11,24 @@ use crate::agent::run::AppAgentContext; pub struct WorkspaceInfoTool; impl WorkspaceInfoTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for WorkspaceInfoTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for WorkspaceInfoTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "workspace_info" } + fn name(&self) -> &'static str { + "workspace_info" + } fn description(&self) -> &'static str { "Get information about a workspace: name, description, avatar." @@ -38,10 +44,15 @@ impl FunctionCall for WorkspaceInfoTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; - let _wk_id = require_workspace_member(git, ctx.user_id, workspace).await?; + let _wk_id = + require_workspace_member(git, ctx.user_id, workspace).await?; let row = sqlx::query_as::<_, (String, String, String, chrono::DateTime)>( "SELECT name, description, avatar_url, created_at FROM workspace WHERE name = $1", @@ -73,18 +84,24 @@ impl FunctionCall for WorkspaceInfoTool { pub struct WorkspaceMembersTool; impl WorkspaceMembersTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for WorkspaceMembersTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for WorkspaceMembersTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "workspace_members" } + fn name(&self) -> &'static str { + "workspace_members" + } fn description(&self) -> &'static str { "List all members of a workspace with their roles." @@ -101,11 +118,20 @@ impl FunctionCall for WorkspaceMembersTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; - let wk_id = require_workspace_member(git, ctx.user_id, workspace).await?; - let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(50).min(200); + let wk_id = + require_workspace_member(git, ctx.user_id, workspace).await?; + let limit = args + .get("limit") + .and_then(|v| v.as_i64()) + .unwrap_or(50) + .min(200); #[derive(sqlx::FromRow)] struct MemberRow { @@ -142,18 +168,24 @@ impl FunctionCall for WorkspaceMembersTool { pub struct WorkspaceGroupsTool; impl WorkspaceGroupsTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for WorkspaceGroupsTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for WorkspaceGroupsTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "workspace_groups" } + fn name(&self) -> &'static str { + "workspace_groups" + } fn description(&self) -> &'static str { "List all user groups in a workspace." @@ -169,10 +201,15 @@ impl FunctionCall for WorkspaceGroupsTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; - let wk_id = require_workspace_member(git, ctx.user_id, workspace).await?; + let wk_id = + require_workspace_member(git, ctx.user_id, workspace).await?; #[derive(sqlx::FromRow)] struct GroupRow { @@ -215,18 +252,24 @@ impl FunctionCall for WorkspaceGroupsTool { pub struct WorkspaceGroupMembersTool; impl WorkspaceGroupMembersTool { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } impl Default for WorkspaceGroupMembersTool { - fn default() -> Self { Self::new() } + fn default() -> Self { + Self::new() + } } #[async_trait] impl FunctionCall for WorkspaceGroupMembersTool { type Context = AppAgentContext; - fn name(&self) -> &'static str { "workspace_group_members" } + fn name(&self) -> &'static str { + "workspace_group_members" + } fn description(&self) -> &'static str { "List members of a specific workspace group." @@ -243,11 +286,16 @@ impl FunctionCall for WorkspaceGroupMembersTool { }) } - async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + async fn call( + &self, + ctx: &mut AppAgentContext, + args: Value, + ) -> AiResult { let git = git_ctx(ctx)?; let workspace = arg_str(&args, "workspace")?; let group_name = arg_str(&args, "group_name")?; - let wk_id = require_workspace_member(git, ctx.user_id, workspace).await?; + let wk_id = + require_workspace_member(git, ctx.user_id, workspace).await?; #[derive(sqlx::FromRow)] struct MemberRow { @@ -268,11 +316,18 @@ impl FunctionCall for WorkspaceGroupMembersTool { .await .map_err(AiError::Database)?; - let members: Vec = rows.iter().map(|r| json!({ - "username": r.username, - "display_name": r.display_name, - })).collect(); + let members: Vec = rows + .iter() + .map(|r| { + json!({ + "username": r.username, + "display_name": r.display_name, + }) + }) + .collect(); - Ok(json!({ "group": group_name, "members": members, "count": members.len() })) + Ok( + json!({ "group": group_name, "members": members, "count": members.len() }), + ) } } diff --git a/lib/service/ai/card.rs b/lib/service/ai/card.rs index 22574c2..62a55dd 100644 --- a/lib/service/ai/card.rs +++ b/lib/service/ai/card.rs @@ -15,7 +15,7 @@ impl AppService { self.ai_card_get_inner(model_id).await } - pub(crate) async fn ai_card_get_inner( + pub async fn ai_card_get_inner( &self, model_id: uuid::Uuid, ) -> Result, AppError> { diff --git a/lib/service/ai/like.rs b/lib/service/ai/like.rs index 3e1b645..ba7172e 100644 --- a/lib/service/ai/like.rs +++ b/lib/service/ai/like.rs @@ -33,7 +33,7 @@ impl AppService { Ok(results) } - pub(crate) async fn ai_like_count_inner( + pub async fn ai_like_count_inner( &self, model_id: uuid::Uuid, ) -> Result { diff --git a/lib/service/ai/mod.rs b/lib/service/ai/mod.rs index 3361e6c..08caf67 100644 --- a/lib/service/ai/mod.rs +++ b/lib/service/ai/mod.rs @@ -15,7 +15,7 @@ use session::Session; use uuid::Uuid; impl AppService { - pub(crate) async fn ai_require_login( + pub async fn ai_require_login( &self, ctx: &Session, ) -> Result { diff --git a/lib/service/ai/model.rs b/lib/service/ai/model.rs index b495b56..dbaa770 100644 --- a/lib/service/ai/model.rs +++ b/lib/service/ai/model.rs @@ -144,7 +144,7 @@ impl AppService { }) } - pub(crate) async fn ai_provider_by_id( + pub async fn ai_provider_by_id( &self, id: uuid::Uuid, ) -> Result { diff --git a/lib/service/ai/sync.rs b/lib/service/ai/sync.rs index ecc1c1e..5cfa3e4 100644 --- a/lib/service/ai/sync.rs +++ b/lib/service/ai/sync.rs @@ -1,7 +1,7 @@ use std::time::Duration; -use ai::sync::{UpstreamModel, UpstreamPricing}; use ai::client::EndpointConfig; +use ai::sync::{UpstreamModel, UpstreamPricing}; use chrono::Utc; use db::sqlx::{self, types::Decimal}; use model::ai::{AiModelModel, AiModelVersionModel, AiProviderModel}; @@ -279,7 +279,10 @@ async fn upsert_pricing( let Some(p) = pricing else { return Ok(PricingResult::Skipped); }; - let input_million: Option = p.prompt.as_deref().and_then(parse_token_price_decimal) + let input_million: Option = p + .prompt + .as_deref() + .and_then(parse_token_price_decimal) .map(|per_token| per_token * Decimal::from(1_000_000u64)) .or_else(|| { p.input @@ -287,7 +290,10 @@ async fn upsert_pricing( .map(|v| Decimal::try_from(v).unwrap_or_default()) }); - let output_million: Option = p.completion.as_deref().and_then(parse_token_price_decimal) + let output_million: Option = p + .completion + .as_deref() + .and_then(parse_token_price_decimal) .map(|per_token| per_token * Decimal::from(1_000_000u64)) .or_else(|| { p.output @@ -295,7 +301,8 @@ async fn upsert_pricing( .map(|v| Decimal::try_from(v).unwrap_or_default()) }); - let cache_input: Option = p.cache_read + let cache_input: Option = p + .cache_read .filter(|v| *v > 0.0) .map(|v| Decimal::try_from(v).unwrap_or_default()); @@ -371,7 +378,9 @@ async fn disable_all_models(db: &db::AppDatabase) -> Result { Ok(result.rows_affected() as i64) } -async fn deactivate_orphaned_models(db: &db::AppDatabase) -> Result { +async fn deactivate_orphaned_models( + db: &db::AppDatabase, +) -> Result { let now = Utc::now(); sqlx::query( "UPDATE ai_model_version SET enabled = false, updated_at = $1 \ @@ -425,24 +434,25 @@ async fn sync_models_from_upstream( } }; - let (model_record, _is_new) = match upsert_model(db, provider.id, model).await { - Ok((m, created)) => { - if created { - models_created += 1; - } else { - models_updated += 1; + let (model_record, _is_new) = + match upsert_model(db, provider.id, model).await { + Ok((m, created)) => { + if created { + models_created += 1; + } else { + models_updated += 1; + } + (m, created) } - (m, created) - } - Err(e) => { - tracing::warn!( - model = %model.id, - error = %e, - "sync: upsert_model error" - ); - continue; - } - }; + Err(e) => { + tracing::warn!( + model = %model.id, + error = %e, + "sync: upsert_model error" + ); + continue; + } + }; let (version_record, version_is_new) = match upsert_version(db, model_record.id, &model.id).await { @@ -460,7 +470,9 @@ async fn sync_models_from_upstream( versions_created += 1; } - match upsert_pricing(db, version_record.id, model.pricing.as_ref()).await { + match upsert_pricing(db, version_record.id, model.pricing.as_ref()) + .await + { Ok(PricingResult::Created) => pricing_created += 1, Ok(PricingResult::Updated) => pricing_updated += 1, Ok(PricingResult::Skipped) => {} @@ -488,11 +500,15 @@ async fn sync_models_from_upstream( } impl AppService { - pub async fn sync_upstream_models(&self) -> Result { - let api_key = self - .config - .ai_api_key() - .map_err(|e| AppError::InternalServerError(format!("AI API key not configured: {}", e)))?; + pub async fn sync_upstream_models( + &self, + ) -> Result { + let api_key = self.config.ai_api_key().map_err(|e| { + AppError::InternalServerError(format!( + "AI API key not configured: {}", + e + )) + })?; let base_url = self.config.ai_basic_url().unwrap_or_default(); @@ -524,7 +540,9 @@ impl AppService { } } -pub fn spawn_model_sync_loop(service: AppService) -> tokio::task::JoinHandle<()> { +pub fn spawn_model_sync_loop( + service: AppService, +) -> tokio::task::JoinHandle<()> { let db = service.db.clone(); let config = service.config.clone(); diff --git a/lib/service/ai/tag.rs b/lib/service/ai/tag.rs index 70d9f4b..60ae83e 100644 --- a/lib/service/ai/tag.rs +++ b/lib/service/ai/tag.rs @@ -14,7 +14,7 @@ impl AppService { self.ai_tag_list_inner(model_id).await } - pub(crate) async fn ai_tag_list_inner( + pub async fn ai_tag_list_inner( &self, model_id: uuid::Uuid, ) -> Result, AppError> { diff --git a/lib/service/ai/version.rs b/lib/service/ai/version.rs index 1f6ac52..e6571f9 100644 --- a/lib/service/ai/version.rs +++ b/lib/service/ai/version.rs @@ -16,7 +16,7 @@ impl AppService { self.ai_version_list_inner(model_id).await } - pub(crate) async fn ai_version_list_inner( + pub async fn ai_version_list_inner( &self, model_id: uuid::Uuid, ) -> Result, AppError> { diff --git a/lib/service/auth/login.rs b/lib/service/auth/login.rs index 903e3bc..d8135fe 100644 --- a/lib/service/auth/login.rs +++ b/lib/service/auth/login.rs @@ -102,7 +102,7 @@ impl AppService { Ok(()) } - pub(crate) async fn auth_find_user_by_username( + pub async fn auth_find_user_by_username( &self, username: &str, ) -> Result { @@ -118,7 +118,7 @@ impl AppService { .ok_or(AppError::UserNotFound) } - pub(crate) async fn auth_find_user_by_email( + pub async fn auth_find_user_by_email( &self, email: &str, ) -> Result { @@ -136,7 +136,7 @@ impl AppService { .ok_or(AppError::UserNotFound) } - pub(crate) async fn auth_find_user_by_uid( + pub async fn auth_find_user_by_uid( &self, uid: uuid::Uuid, ) -> Result { @@ -152,9 +152,7 @@ impl AppService { .ok_or(AppError::UserNotFound) } - pub(crate) fn validate_password_strength( - password: &str, - ) -> Result<(), AppError> { + pub fn validate_password_strength(password: &str) -> Result<(), AppError> { if password.len() < 8 { return Err(AppError::PasswordTooWeak); } diff --git a/lib/service/auth/me.rs b/lib/service/auth/me.rs index 55e3beb..8f6f7af 100644 --- a/lib/service/auth/me.rs +++ b/lib/service/auth/me.rs @@ -18,7 +18,8 @@ impl AppService { pub async fn auth_me(&self, ctx: Session) -> Result { let user_id = ctx.user().ok_or(AppError::Unauthorized)?; let user = self.auth_find_user_by_uid(user_id).await?; - let unread = self.unread_notifications_count(user_id).await.unwrap_or(0); + let unread = + self.unread_notifications_count(user_id).await.unwrap_or(0); Ok(ContextMe { id: user.id, diff --git a/lib/service/auth/totp.rs b/lib/service/auth/totp.rs index 03bd0c4..7a0ff0a 100644 --- a/lib/service/auth/totp.rs +++ b/lib/service/auth/totp.rs @@ -150,7 +150,7 @@ impl AppService { Ok(()) } - pub(crate) async fn auth_2fa_verify( + pub async fn auth_2fa_verify( &self, user_uid: Uuid, code: &str, @@ -164,7 +164,7 @@ impl AppService { self.verify_2fa_or_backup_code(&two_fa, code).await } - pub(crate) async fn auth_2fa_status_by_uid( + pub async fn auth_2fa_status_by_uid( &self, user_uid: Uuid, ) -> Result { diff --git a/lib/service/git/commit_status.rs b/lib/service/git/commit_status.rs index 36f38e6..16f2124 100644 --- a/lib/service/git/commit_status.rs +++ b/lib/service/git/commit_status.rs @@ -4,8 +4,8 @@ use model::repos::RepoCommitStatusModel; use session::Session; use uuid::Uuid; -use crate::error::AppError; use crate::AppService; +use crate::error::AppError; #[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)] pub struct CommitStatusResponse { @@ -79,9 +79,12 @@ impl AppService { commit_sha: &str, params: CreateCommitStatus, ) -> Result { - if !["pending", "success", "failure", "error"].contains(¶ms.state.as_str()) { + if !["pending", "success", "failure", "error"] + .contains(¶ms.state.as_str()) + { return Err(AppError::BadRequest( - "state must be one of: pending, success, failure, error".to_string(), + "state must be one of: pending, success, failure, error" + .to_string(), )); } @@ -112,17 +115,38 @@ impl AppService { } impl AppService { - pub async fn git_commit_status_list_by_name(&self, ctx: &Session, wk: &str, repo: &str, sha: &str) -> Result, AppError> { + pub async fn git_commit_status_list_by_name( + &self, + ctx: &Session, + wk: &str, + repo: &str, + sha: &str, + ) -> Result, AppError> { let repo = self.git_require_member(ctx, wk, repo).await?; self.git_commit_status_list(repo.id, sha).await } - pub async fn git_commit_status_combined_by_name(&self, ctx: &Session, wk: &str, repo: &str, sha: &str) -> Result { + pub async fn git_commit_status_combined_by_name( + &self, + ctx: &Session, + wk: &str, + repo: &str, + sha: &str, + ) -> Result { let repo = self.git_require_member(ctx, wk, repo).await?; self.git_commit_status_combined(repo.id, sha).await } - pub async fn git_commit_status_create_by_name(&self, ctx: &Session, user_id: Uuid, wk: &str, repo: &str, sha: &str, params: CreateCommitStatus) -> Result { + pub async fn git_commit_status_create_by_name( + &self, + ctx: &Session, + user_id: Uuid, + wk: &str, + repo: &str, + sha: &str, + params: CreateCommitStatus, + ) -> Result { let repo = self.git_require_member(ctx, wk, repo).await?; - self.git_commit_status_create(repo.id, user_id, sha, params).await + self.git_commit_status_create(repo.id, user_id, sha, params) + .await } } @@ -144,8 +168,14 @@ fn combined_state(statuses: &[CommitStatusResponse]) -> String { return "pending".to_string(); } let has = |s: &str| statuses.iter().any(|st| st.state == s); - (if has("error") { "error" } - else if has("failure") { "failure" } - else if has("pending") { "pending" } - else { "success" }).to_string() + (if has("error") { + "error" + } else if has("failure") { + "failure" + } else if has("pending") { + "pending" + } else { + "success" + }) + .to_string() } diff --git a/lib/service/git/compare.rs b/lib/service/git/compare.rs index aa4b8b1..f64bcf8 100644 --- a/lib/service/git/compare.rs +++ b/lib/service/git/compare.rs @@ -37,7 +37,9 @@ impl AppService { let mut client = CommitServiceClient::new(self.git.clone()); fn oid(s: &str) -> p::ObjectId { - p::ObjectId { value: s.to_string() } + p::ObjectId { + value: s.to_string(), + } } let base_info = client @@ -45,41 +47,60 @@ impl AppService { repo_id: repo.id.to_string(), oid: Some(oid(base)), })) - .await.map_err(rpc_err)?.into_inner(); + .await + .map_err(rpc_err)? + .into_inner(); let head_info = client .commit_info(tonic::Request::new(p::CommitInfoRequest { repo_id: repo.id.to_string(), oid: Some(oid(head)), })) - .await.map_err(rpc_err)?.into_inner(); + .await + .map_err(rpc_err)? + .into_inner(); let history = client .commit_history(tonic::Request::new(p::CommitHistoryRequest { repo_id: repo.id.to_string(), - limit: 250, skip: 0, sort: 0, + limit: 250, + skip: 0, + sort: 0, branch: Some(format!("{base}..{head}")), })) - .await.map_err(rpc_err)?.into_inner(); + .await + .map_err(rpc_err)? + .into_inner(); - let commits: Vec = history.commits.into_iter().map(|c| { - let author_name = c.author.as_ref().map(|a| a.name.clone()); - let author_email = c.author.as_ref().map(|a| a.email.clone()); - CompareCommit { - sha: c.oid.map(|o| o.value).unwrap_or_default(), - message: c.summary, - author_name, - author_email, - } - }).collect(); + let commits: Vec = history + .commits + .into_iter() + .map(|c| { + let author_name = c.author.as_ref().map(|a| a.name.clone()); + let author_email = c.author.as_ref().map(|a| a.email.clone()); + CompareCommit { + sha: c.oid.map(|o| o.value).unwrap_or_default(), + message: c.summary, + author_name, + author_email, + } + }) + .collect(); let diff = crate::AppService::git_diff_stats( - self, ctx, wk_name, repo_name, - base.to_string(), head.to_string(), None, - ).await?; + self, + ctx, + wk_name, + repo_name, + base.to_string(), + head.to_string(), + None, + ) + .await?; let stats = diff.result.and_then(|r| r.stats); - let files_changed = stats.as_ref().map(|s| s.files_changed).unwrap_or(0); + let files_changed = + stats.as_ref().map(|s| s.files_changed).unwrap_or(0); let insertions = stats.as_ref().map(|s| s.insertions).unwrap_or(0); let deletions = stats.as_ref().map(|s| s.deletions).unwrap_or(0); @@ -107,8 +128,11 @@ fn cmt(c: Option) -> CompareCommit { author_name, author_email, } - }).unwrap_or_else(|| CompareCommit { - sha: String::new(), message: String::new(), - author_name: None, author_email: None, + }) + .unwrap_or_else(|| CompareCommit { + sha: String::new(), + message: String::new(), + author_name: None, + author_email: None, }) } diff --git a/lib/service/git/contents.rs b/lib/service/git/contents.rs index 552c248..c6e7769 100644 --- a/lib/service/git/contents.rs +++ b/lib/service/git/contents.rs @@ -1,8 +1,5 @@ use db::sqlx; -use git::rpc::{ - proto as p, - proto::blob_service_client::BlobServiceClient, -}; +use git::rpc::{proto as p, proto::blob_service_client::BlobServiceClient}; use session::Session; use crate::{AppService, error::AppError, git::rpc_err}; @@ -34,21 +31,52 @@ pub struct UpdateContent { } impl AppService { - pub async fn git_contents_get_by_name(&self, ctx: &Session, wk: &str, repo: &str, path: &str, ref_name: Option<&str>) -> Result { + pub async fn git_contents_get_by_name( + &self, + ctx: &Session, + wk: &str, + repo: &str, + path: &str, + ref_name: Option<&str>, + ) -> Result { let _ = self.git_require_member(ctx, wk, repo).await?; self.git_contents_get(ctx, wk, repo, path, ref_name).await } - pub async fn git_contents_create_by_name(&self, ctx: &Session, wk: &str, repo: &str, path: &str, params: CreateContent) -> Result { + pub async fn git_contents_create_by_name( + &self, + ctx: &Session, + wk: &str, + repo: &str, + path: &str, + params: CreateContent, + ) -> Result { let _ = self.git_require_member(ctx, wk, repo).await?; self.git_contents_create(ctx, wk, repo, path, params).await } - pub async fn git_contents_update_by_name(&self, ctx: &Session, wk: &str, repo: &str, path: &str, params: UpdateContent) -> Result { + pub async fn git_contents_update_by_name( + &self, + ctx: &Session, + wk: &str, + repo: &str, + path: &str, + params: UpdateContent, + ) -> Result { let _ = self.git_require_member(ctx, wk, repo).await?; self.git_contents_update(ctx, wk, repo, path, params).await } - pub async fn git_contents_delete_by_name(&self, ctx: &Session, wk: &str, repo: &str, path: &str, msg: &str, sha: &str, branch: Option<&str>) -> Result<(), AppError> { + pub async fn git_contents_delete_by_name( + &self, + ctx: &Session, + wk: &str, + repo: &str, + path: &str, + msg: &str, + sha: &str, + branch: Option<&str>, + ) -> Result<(), AppError> { let _ = self.git_require_member(ctx, wk, repo).await?; - self.git_contents_delete(ctx, wk, repo, path, msg, sha, branch).await + self.git_contents_delete(ctx, wk, repo, path, msg, sha, branch) + .await } } @@ -64,7 +92,9 @@ impl AppService { let repo = self.git_require_member(ctx, wk_name, repo_name).await?; let mut blob_client = BlobServiceClient::new(self.git.clone()); - let empty_oid = p::ObjectId { value: String::new() }; + let empty_oid = p::ObjectId { + value: String::new(), + }; let resp = blob_client .blob_load(tonic::Request::new(p::BlobLoadRequest { @@ -72,14 +102,18 @@ impl AppService { id: Some(empty_oid.clone()), path: path.to_string(), })) - .await.map_err(rpc_err)?.into_inner(); + .await + .map_err(rpc_err)? + .into_inner(); let is_binary = blob_client .blob_is_binary(tonic::Request::new(p::BlobIsBinaryRequest { repo_id: repo.id.to_string(), id: Some(empty_oid.clone()), })) - .await.map(|r| r.into_inner().is_binary).unwrap_or(false); + .await + .map(|r| r.into_inner().is_binary) + .unwrap_or(false); let size_resp = blob_client .blob_size(tonic::Request::new(p::BlobSizeRequest { @@ -87,7 +121,9 @@ impl AppService { id: Some(empty_oid), path: String::new(), })) - .await.map_err(rpc_err)?.into_inner(); + .await + .map_err(rpc_err)? + .into_inner(); let blob_data = resp.blob; let size = size_resp.size as i64; @@ -109,7 +145,11 @@ impl AppService { } pub async fn git_contents_create( - &self, ctx: &Session, wk_name: &str, repo_name: &str, path: &str, + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + path: &str, params: CreateContent, ) -> Result { let repo = self.git_require_member(ctx, wk_name, repo_name).await?; @@ -127,15 +167,23 @@ impl AppService { let display = user_model.display_name.clone(); let username = user_model.username.clone(); - let author_name = if display.is_empty() { username.clone() } else { display }; + let author_name = if display.is_empty() { + username.clone() + } else { + display + }; let file_size = params.content.len() as i64; let content_bytes = params.content.clone().into_bytes(); - let mut client = p::commit_service_client::CommitServiceClient::new(self.git.clone()); + let mut client = p::commit_service_client::CommitServiceClient::new( + self.git.clone(), + ); let resp = client .create_commit(tonic::Request::new(p::CreateCommitRequest { repo_id: repo.id.to_string(), - branch: params.branch.unwrap_or_else(|| repo.default_branch.clone()), + branch: params + .branch + .unwrap_or_else(|| repo.default_branch.clone()), message: params.message, author_name: author_name.clone(), author_email: format!("{username}@gitdata.ai"), @@ -162,17 +210,31 @@ impl AppService { } pub async fn git_contents_update( - &self, _ctx: &Session, _wk: &str, _repo: &str, _path: &str, + &self, + _ctx: &Session, + _wk: &str, + _repo: &str, + _path: &str, _params: UpdateContent, ) -> Result { - Err(AppError::InternalServerError("contents update not yet implemented".to_string())) + Err(AppError::InternalServerError( + "contents update not yet implemented".to_string(), + )) } pub async fn git_contents_delete( - &self, _ctx: &Session, _wk: &str, _repo: &str, _path: &str, - _message: &str, _sha: &str, _branch: Option<&str>, + &self, + _ctx: &Session, + _wk: &str, + _repo: &str, + _path: &str, + _message: &str, + _sha: &str, + _branch: Option<&str>, ) -> Result<(), AppError> { - Err(AppError::InternalServerError("contents delete not yet implemented".to_string())) + Err(AppError::InternalServerError( + "contents delete not yet implemented".to_string(), + )) } } diff --git a/lib/service/git/init.rs b/lib/service/git/init.rs index 9140d2c..d8b6dcd 100644 --- a/lib/service/git/init.rs +++ b/lib/service/git/init.rs @@ -135,12 +135,16 @@ impl AppService { let name = params.name.trim(); if name.is_empty() { - return Err(AppError::BadRequest("repo name is required".to_string())); + return Err(AppError::BadRequest( + "repo name is required".to_string(), + )); } let source_url = params.source_url.trim(); if source_url.is_empty() { - return Err(AppError::BadRequest("source URL is required".to_string())); + return Err(AppError::BadRequest( + "source URL is required".to_string(), + )); } let existing = sqlx::query_scalar::<_, bool>( diff --git a/lib/service/git/mod.rs b/lib/service/git/mod.rs index cfe1078..5fb92dd 100644 --- a/lib/service/git/mod.rs +++ b/lib/service/git/mod.rs @@ -30,12 +30,12 @@ use session::Session; use crate::{AppService, error::AppError, session_user}; impl AppService { - pub(crate) async fn queue_sync(&self, repo_uid: uuid::Uuid) { + pub async fn queue_sync(&self, repo_uid: uuid::Uuid) { let sync_service = ReceiveSyncService::new(self.redis_pool.clone()); sync_service.send(RepoReceiveSyncTask { repo_uid }).await; } - pub(crate) async fn repo_resolve( + pub async fn repo_resolve( &self, wk_id: uuid::Uuid, repo_name: &str, @@ -53,7 +53,7 @@ impl AppService { .ok_or(AppError::RepoNotFound) } - pub(crate) async fn git_require_member( + pub async fn git_require_member( &self, ctx: &Session, wk_name: &str, @@ -65,7 +65,7 @@ impl AppService { self.repo_resolve(wk.id, repo_name).await } - pub(crate) async fn git_require_admin( + pub async fn git_require_admin( &self, ctx: &Session, wk_name: &str, @@ -78,7 +78,7 @@ impl AppService { } } -pub(crate) fn rpc_err(status: tonic::Status) -> AppError { +pub fn rpc_err(status: tonic::Status) -> AppError { match status.code() { tonic::Code::NotFound => { AppError::NotFound(status.message().to_string()) diff --git a/lib/service/git/protect.rs b/lib/service/git/protect.rs index 7bd369e..a309455 100644 --- a/lib/service/git/protect.rs +++ b/lib/service/git/protect.rs @@ -25,7 +25,7 @@ pub struct ProtectResponse { pub updated_at: chrono::DateTime, } -pub(crate) fn protect_response(p: RepoProtectModel) -> ProtectResponse { +pub fn protect_response(p: RepoProtectModel) -> ProtectResponse { ProtectResponse { id: p.id, repo: p.repo, diff --git a/lib/service/git/readme.rs b/lib/service/git/readme.rs index d7e252c..f199976 100644 --- a/lib/service/git/readme.rs +++ b/lib/service/git/readme.rs @@ -30,17 +30,12 @@ impl AppService { for name in &readme_names { match self - .git_tree_entry_by_path_from_commit_for_readme( - &repo.id, - name, - ) + .git_tree_entry_by_path_from_commit_for_readme(&repo.id, name) .await? { Some((content, oid)) => { return self - .git_blob_load_for_readme( - &repo, &content, &oid, - ) + .git_blob_load_for_readme(&repo, &content, &oid) .await; } None => continue, @@ -56,13 +51,15 @@ impl AppService { repo_id: &uuid::Uuid, readme_name: &str, ) -> Result, AppError> { + use crate::git::rpc_err; use git::rpc::proto as p; use git::rpc::proto::tree_service_client::TreeServiceClient; - use crate::git::rpc_err; let mut client = TreeServiceClient::new(self.git.clone()); let mut commit_client = - git::rpc::proto::commit_service_client::CommitServiceClient::new(self.git.clone()); + git::rpc::proto::commit_service_client::CommitServiceClient::new( + self.git.clone(), + ); let summary_resp = commit_client .commit_summary(tonic::Request::new(p::CommitSummaryRequest { repo_id: repo_id.to_string(), @@ -82,11 +79,15 @@ impl AppService { }; let resp = client - .tree_entry_by_path(tonic::Request::new(p::TreeEntryByPathRequest { - repo_id: repo_id.to_string(), - tree_oid: Some(p::ObjectId { value: tree_id.clone() }), - path: readme_name.to_string(), - })) + .tree_entry_by_path(tonic::Request::new( + p::TreeEntryByPathRequest { + repo_id: repo_id.to_string(), + tree_oid: Some(p::ObjectId { + value: tree_id.clone(), + }), + path: readme_name.to_string(), + }, + )) .await .map_err(rpc_err)? .into_inner(); @@ -109,15 +110,17 @@ impl AppService { _path: &str, oid: &str, ) -> Result, AppError> { + use crate::git::rpc_err; use git::rpc::proto as p; use git::rpc::proto::blob_service_client::BlobServiceClient; - use crate::git::rpc_err; let mut client = BlobServiceClient::new(self.git.clone()); let resp = client .blob_load(tonic::Request::new(p::BlobLoadRequest { repo_id: repo.id.to_string(), - id: Some(p::ObjectId { value: oid.to_string() }), + id: Some(p::ObjectId { + value: oid.to_string(), + }), path: String::new(), })) .await @@ -129,11 +132,11 @@ impl AppService { return Ok(None); } - let html = comrak::markdown_to_html(&content, &comrak::ComrakOptions::default()); + let html = comrak::markdown_to_html( + &content, + &comrak::ComrakOptions::default(), + ); - Ok(Some(super::readme::ReadmeDto { - content, - html, - })) + Ok(Some(super::readme::ReadmeDto { content, html })) } } diff --git a/lib/service/git/refs.rs b/lib/service/git/refs.rs index e861daa..5da6e35 100644 --- a/lib/service/git/refs.rs +++ b/lib/service/git/refs.rs @@ -3,8 +3,8 @@ use model::repos::RepoRefModel; use session::Session; use uuid::Uuid; -use crate::error::AppError; use crate::AppService; +use crate::error::AppError; #[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)] pub struct GitRefResponse { @@ -16,11 +16,22 @@ pub struct GitRefResponse { } impl AppService { - pub async fn git_ref_list_by_name(&self, ctx: &Session, wk: &str, repo: &str) -> Result, AppError> { + pub async fn git_ref_list_by_name( + &self, + ctx: &Session, + wk: &str, + repo: &str, + ) -> Result, AppError> { let repo = self.git_require_member(ctx, wk, repo).await?; self.git_ref_list(repo.id).await } - pub async fn git_ref_get_by_name(&self, ctx: &Session, wk: &str, repo: &str, ref_name: &str) -> Result { + pub async fn git_ref_get_by_name( + &self, + ctx: &Session, + wk: &str, + repo: &str, + ref_name: &str, + ) -> Result { let repo = self.git_require_member(ctx, wk, repo).await?; self.git_ref_get(repo.id, ref_name).await } @@ -40,13 +51,16 @@ impl AppService { .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; - Ok(refs.into_iter().map(|r| GitRefResponse { - name: r.name, - kind: r.kind, - target_sha: r.target_sha, - is_default: r.is_default, - is_protected: r.is_protected, - }).collect()) + Ok(refs + .into_iter() + .map(|r| GitRefResponse { + name: r.name, + kind: r.kind, + target_sha: r.target_sha, + is_default: r.is_default, + is_protected: r.is_protected, + }) + .collect()) } pub async fn git_ref_get( diff --git a/lib/service/git/release.rs b/lib/service/git/release.rs index 967d8b9..1ff3094 100644 --- a/lib/service/git/release.rs +++ b/lib/service/git/release.rs @@ -1,13 +1,11 @@ use chrono::Utc; use db::sqlx; -use model::repos::{ - RepoReleaseAssetModel, RepoReleaseModel, -}; +use model::repos::{RepoReleaseAssetModel, RepoReleaseModel}; use session::Session; use uuid::Uuid; -use crate::error::AppError; use crate::AppService; +use crate::error::AppError; #[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)] pub struct ReleaseResponse { @@ -132,7 +130,11 @@ impl AppService { let published_at = if params.draft { None } else { Some(now) }; let target = if let Some(ref sha) = params.target_commit_sha { - if !sha.trim().is_empty() { sha.clone() } else { self.default_branch_sha(repo_id).await? } + if !sha.trim().is_empty() { + sha.clone() + } else { + self.default_branch_sha(repo_id).await? + } } else { self.default_branch_sha(repo_id).await? }; @@ -177,7 +179,11 @@ impl AppService { let draft = params.draft.unwrap_or(existing.draft); let prerelease = params.prerelease.unwrap_or(existing.prerelease); - let published_at = if draft { None } else { existing.published_at.or(Some(now)) }; + let published_at = if draft { + None + } else { + existing.published_at.or(Some(now)) + }; let r = sqlx::query_as::<_, RepoReleaseModel>( "UPDATE repo_release SET tag_name=$1, name=$2, body=$3, draft=$4, \ @@ -208,14 +214,13 @@ impl AppService { repo_id: Uuid, release_id: Uuid, ) -> Result<(), AppError> { - let rows = sqlx::query( - "DELETE FROM repo_release WHERE id = $1 AND repo = $2", - ) - .bind(release_id) - .bind(repo_id) - .execute(self.db.writer()) - .await - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let rows = + sqlx::query("DELETE FROM repo_release WHERE id = $1 AND repo = $2") + .bind(release_id) + .bind(repo_id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; if rows.rows_affected() == 0 { return Err(AppError::NotFound("release not found".to_string())); } @@ -235,8 +240,9 @@ impl AppService { .fetch_optional(self.db.reader()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; - let release_id = release_id - .ok_or_else(|| AppError::NotFound("release not found".to_string()))?; + let release_id = release_id.ok_or_else(|| { + AppError::NotFound("release not found".to_string()) + })?; self.git_release_delete(repo_id, release_id).await } @@ -254,14 +260,17 @@ impl AppService { .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; - Ok(assets.into_iter().map(|a| ReleaseAssetResponse { - id: a.id, - name: a.name, - content_type: a.content_type, - size: a.size, - download_count: a.download_count, - created_at: a.created_at, - }).collect()) + Ok(assets + .into_iter() + .map(|a| ReleaseAssetResponse { + id: a.id, + name: a.name, + content_type: a.content_type, + size: a.size, + download_count: a.download_count, + created_at: a.created_at, + }) + .collect()) } pub async fn git_release_asset_create( @@ -325,38 +334,90 @@ impl AppService { } impl AppService { - pub async fn git_release_list_by_name(&self, ctx: &Session, _user_id: Uuid, wk: &str, repo: &str) -> Result, AppError> { + pub async fn git_release_list_by_name( + &self, + ctx: &Session, + _user_id: Uuid, + wk: &str, + repo: &str, + ) -> Result, AppError> { let repo = self.git_require_member(ctx, wk, repo).await?; self.git_release_list(repo.id).await } - pub async fn git_release_get_by_name(&self, ctx: &Session, _user_id: Uuid, wk: &str, repo: &str, id: Uuid) -> Result { + pub async fn git_release_get_by_name( + &self, + ctx: &Session, + _user_id: Uuid, + wk: &str, + repo: &str, + id: Uuid, + ) -> Result { let repo = self.git_require_member(ctx, wk, repo).await?; self.git_release_get(repo.id, id).await } - pub async fn git_release_get_by_tag_name(&self, ctx: &Session, _user_id: Uuid, wk: &str, repo: &str, tag: &str) -> Result { + pub async fn git_release_get_by_tag_name( + &self, + ctx: &Session, + _user_id: Uuid, + wk: &str, + repo: &str, + tag: &str, + ) -> Result { let repo = self.git_require_member(ctx, wk, repo).await?; self.git_release_get_by_tag(repo.id, tag).await } - pub async fn git_release_create_by_name(&self, ctx: &Session, user_id: Uuid, wk: &str, repo: &str, params: CreateRelease) -> Result { + pub async fn git_release_create_by_name( + &self, + ctx: &Session, + user_id: Uuid, + wk: &str, + repo: &str, + params: CreateRelease, + ) -> Result { let repo = self.git_require_member(ctx, wk, repo).await?; self.git_release_create(ctx, repo.id, user_id, params).await } - pub async fn git_release_update_by_name(&self, ctx: &Session, _user_id: Uuid, wk: &str, repo: &str, id: Uuid, params: UpdateRelease) -> Result { + pub async fn git_release_update_by_name( + &self, + ctx: &Session, + _user_id: Uuid, + wk: &str, + repo: &str, + id: Uuid, + params: UpdateRelease, + ) -> Result { let repo = self.git_require_member(ctx, wk, repo).await?; self.git_release_update(repo.id, id, params).await } - pub async fn git_release_delete_by_name(&self, ctx: &Session, _user_id: Uuid, wk: &str, repo: &str, id: Uuid) -> Result<(), AppError> { + pub async fn git_release_delete_by_name( + &self, + ctx: &Session, + _user_id: Uuid, + wk: &str, + repo: &str, + id: Uuid, + ) -> Result<(), AppError> { let repo = self.git_require_member(ctx, wk, repo).await?; self.git_release_delete(repo.id, id).await } - pub async fn git_release_delete_by_tag_name(&self, ctx: &Session, _user_id: Uuid, wk: &str, repo: &str, tag: &str) -> Result<(), AppError> { + pub async fn git_release_delete_by_tag_name( + &self, + ctx: &Session, + _user_id: Uuid, + wk: &str, + repo: &str, + tag: &str, + ) -> Result<(), AppError> { let repo = self.git_require_member(ctx, wk, repo).await?; self.git_release_delete_by_tag(repo.id, tag).await } } impl AppService { - async fn default_branch_sha(&self, repo_id: Uuid) -> Result { + async fn default_branch_sha( + &self, + repo_id: Uuid, + ) -> Result { sqlx::query_scalar("SELECT target_sha FROM repo_ref WHERE repo = $1 AND is_default = true") .bind(repo_id) .fetch_optional(self.db.reader()) diff --git a/lib/service/git/repo.rs b/lib/service/git/repo.rs index 4315b68..902746b 100644 --- a/lib/service/git/repo.rs +++ b/lib/service/git/repo.rs @@ -4,7 +4,9 @@ use model::repos::{RepoModel, RepoTopicModel}; use serde::{Deserialize, Serialize}; use session::Session; -use crate::{AppService, Pagination, error::AppError, git::rpc_err, session_user}; +use crate::{ + AppService, Pagination, error::AppError, git::rpc_err, session_user, +}; #[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] pub struct RepoResponse { @@ -27,7 +29,7 @@ pub struct RepoResponse { pub updated_at: chrono::DateTime, } -pub(crate) fn repo_response(repo: RepoModel) -> RepoResponse { +pub fn repo_response(repo: RepoModel) -> RepoResponse { RepoResponse { id: repo.id, name: repo.name, @@ -421,7 +423,7 @@ impl AppService { wk_name: &str, ) -> Result)>, AppError> { let wk = sqlx::query_as::<_, (uuid::Uuid,)>( - "SELECT id FROM workspace WHERE name = $1" + "SELECT id FROM workspace WHERE name = $1", ) .bind(wk_name) .fetch_optional(self.db.reader()) diff --git a/lib/service/git/star.rs b/lib/service/git/star.rs index 089b034..d542c2f 100644 --- a/lib/service/git/star.rs +++ b/lib/service/git/star.rs @@ -24,13 +24,12 @@ impl AppService { .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; - let count: (i64,) = sqlx::query_as( - "SELECT COUNT(*) FROM repo_star WHERE repo = $1", - ) - .bind(repo.id) - .fetch_one(self.db.reader()) - .await - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let count: (i64,) = + sqlx::query_as("SELECT COUNT(*) FROM repo_star WHERE repo = $1") + .bind(repo.id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; Ok(serde_json::json!({ "starred": true, "count": count.0 })) } @@ -44,22 +43,19 @@ impl AppService { let user_uid = session_user(ctx)?; let repo = self.git_require_member(ctx, wk_name, repo_name).await?; - sqlx::query( - "DELETE FROM repo_star WHERE repo = $1 AND \"user\" = $2", - ) - .bind(repo.id) - .bind(user_uid) - .execute(self.db.writer()) - .await - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + sqlx::query("DELETE FROM repo_star WHERE repo = $1 AND \"user\" = $2") + .bind(repo.id) + .bind(user_uid) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; - let count: (i64,) = sqlx::query_as( - "SELECT COUNT(*) FROM repo_star WHERE repo = $1", - ) - .bind(repo.id) - .fetch_one(self.db.reader()) - .await - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let count: (i64,) = + sqlx::query_as("SELECT COUNT(*) FROM repo_star WHERE repo = $1") + .bind(repo.id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; Ok(serde_json::json!({ "starred": false, "count": count.0 })) } @@ -82,13 +78,12 @@ impl AppService { .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; - let count: (i64,) = sqlx::query_as( - "SELECT COUNT(*) FROM repo_star WHERE repo = $1", - ) - .bind(repo.id) - .fetch_one(self.db.reader()) - .await - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let count: (i64,) = + sqlx::query_as("SELECT COUNT(*) FROM repo_star WHERE repo = $1") + .bind(repo.id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; Ok(serde_json::json!({ "starred": exists.0, diff --git a/lib/service/git/watch.rs b/lib/service/git/watch.rs index b0dd850..d40574b 100644 --- a/lib/service/git/watch.rs +++ b/lib/service/git/watch.rs @@ -29,15 +29,16 @@ impl AppService { .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; - let count: (i64,) = sqlx::query_as( - "SELECT COUNT(*) FROM repo_watch WHERE repo = $1", - ) - .bind(repo.id) - .fetch_one(self.db.reader()) - .await - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let count: (i64,) = + sqlx::query_as("SELECT COUNT(*) FROM repo_watch WHERE repo = $1") + .bind(repo.id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; - Ok(serde_json::json!({ "watching": true, "count": count.0, "level": level })) + Ok( + serde_json::json!({ "watching": true, "count": count.0, "level": level }), + ) } pub async fn git_repo_unwatch( @@ -49,22 +50,19 @@ impl AppService { let user_uid = session_user(ctx)?; let repo = self.git_require_member(ctx, wk_name, repo_name).await?; - sqlx::query( - "DELETE FROM repo_watch WHERE repo = $1 AND \"user\" = $2", - ) - .bind(repo.id) - .bind(user_uid) - .execute(self.db.writer()) - .await - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + sqlx::query("DELETE FROM repo_watch WHERE repo = $1 AND \"user\" = $2") + .bind(repo.id) + .bind(user_uid) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; - let count: (i64,) = sqlx::query_as( - "SELECT COUNT(*) FROM repo_watch WHERE repo = $1", - ) - .bind(repo.id) - .fetch_one(self.db.reader()) - .await - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let count: (i64,) = + sqlx::query_as("SELECT COUNT(*) FROM repo_watch WHERE repo = $1") + .bind(repo.id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; Ok(serde_json::json!({ "watching": false, "count": count.0 })) } @@ -88,13 +86,12 @@ impl AppService { .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; - let count: (i64,) = sqlx::query_as( - "SELECT COUNT(*) FROM repo_watch WHERE repo = $1", - ) - .bind(repo.id) - .fetch_one(self.db.reader()) - .await - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let count: (i64,) = + sqlx::query_as("SELECT COUNT(*) FROM repo_watch WHERE repo = $1") + .bind(repo.id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; Ok(serde_json::json!({ "watching": watch.as_ref().map(|(w, _)| *w).unwrap_or(false), diff --git a/lib/service/git/webhook.rs b/lib/service/git/webhook.rs index 4cf068c..a8ea001 100644 --- a/lib/service/git/webhook.rs +++ b/lib/service/git/webhook.rs @@ -38,7 +38,7 @@ pub struct WebhookResponse { pub updated_at: chrono::DateTime, } -pub(crate) fn webhook_response(w: RepoWebhookModel) -> WebhookResponse { +pub fn webhook_response(w: RepoWebhookModel) -> WebhookResponse { WebhookResponse { id: w.id, repo: w.repo, @@ -71,7 +71,7 @@ pub struct WebhookDeliveryResponse { pub created_at: chrono::DateTime, } -pub(crate) fn delivery_response( +pub fn delivery_response( d: RepoWebhookDeliveryModel, ) -> WebhookDeliveryResponse { WebhookDeliveryResponse { diff --git a/lib/service/issues/assignee.rs b/lib/service/issues/assignee.rs index 7898b18..4a060f3 100644 --- a/lib/service/issues/assignee.rs +++ b/lib/service/issues/assignee.rs @@ -96,7 +96,7 @@ impl AppService { self.issue_assignees(issue.id).await } - pub(crate) async fn issue_assignees( + pub async fn issue_assignees( &self, issue_id: uuid::Uuid, ) -> Result, AppError> { diff --git a/lib/service/issues/binding.rs b/lib/service/issues/binding.rs index 88529a1..753a431 100644 --- a/lib/service/issues/binding.rs +++ b/lib/service/issues/binding.rs @@ -218,7 +218,7 @@ impl AppService { self.issue_pull_requests(issue.id).await } - pub(crate) async fn issue_repos( + pub async fn issue_repos( &self, issue_id: uuid::Uuid, ) -> Result, AppError> { @@ -238,7 +238,7 @@ impl AppService { Ok(repos.into_iter().map(issue_repo_response).collect()) } - pub(crate) async fn issue_pull_requests( + pub async fn issue_pull_requests( &self, issue_id: uuid::Uuid, ) -> Result, AppError> { diff --git a/lib/service/issues/issue.rs b/lib/service/issues/issue.rs index b82e96e..f305c97 100644 --- a/lib/service/issues/issue.rs +++ b/lib/service/issues/issue.rs @@ -402,7 +402,7 @@ impl AppService { Ok(()) } - pub(crate) async fn issue_resolve( + pub async fn issue_resolve( &self, wk_id: uuid::Uuid, number: i64, @@ -459,7 +459,7 @@ impl AppService { }) } - pub(crate) async fn users_find_by_id( + pub async fn users_find_by_id( &self, uid: uuid::Uuid, ) -> Result { @@ -481,7 +481,7 @@ impl AppService { wk_name: &str, ) -> Result, AppError> { let wk = sqlx::query_as::<_, (uuid::Uuid,)>( - "SELECT id FROM workspace WHERE name = $1" + "SELECT id FROM workspace WHERE name = $1", ) .bind(wk_name) .fetch_optional(self.db.reader()) diff --git a/lib/service/issues/label.rs b/lib/service/issues/label.rs index e6fb740..f11ac19 100644 --- a/lib/service/issues/label.rs +++ b/lib/service/issues/label.rs @@ -265,7 +265,7 @@ impl AppService { self.issue_labels(issue.id).await } - pub(crate) async fn issue_labels( + pub async fn issue_labels( &self, issue_id: uuid::Uuid, ) -> Result, AppError> { diff --git a/lib/service/issues/milestone.rs b/lib/service/issues/milestone.rs index d5d7ab9..cdfa016 100644 --- a/lib/service/issues/milestone.rs +++ b/lib/service/issues/milestone.rs @@ -282,7 +282,7 @@ impl AppService { Ok(()) } - pub(crate) async fn issue_milestone( + pub async fn issue_milestone( &self, issue_id: uuid::Uuid, ) -> Result, AppError> { diff --git a/lib/service/issues/reaction.rs b/lib/service/issues/reaction.rs index f440d98..1bf2d52 100644 --- a/lib/service/issues/reaction.rs +++ b/lib/service/issues/reaction.rs @@ -143,7 +143,7 @@ impl AppService { self.issue_reactions_for(issue.id, Some(comment_id)).await } - pub(crate) async fn issue_reactions_for( + pub async fn issue_reactions_for( &self, issue_id: uuid::Uuid, comment_id: Option, diff --git a/lib/service/issues/types.rs b/lib/service/issues/types.rs index f6547fd..db502f1 100644 --- a/lib/service/issues/types.rs +++ b/lib/service/issues/types.rs @@ -107,7 +107,7 @@ pub struct IssueReactionResponse { pub created_at: chrono::DateTime, } -pub(crate) fn issue_author(user: UserModel) -> IssueAuthor { +pub fn issue_author(user: UserModel) -> IssueAuthor { IssueAuthor { username: user.username, display_name: non_empty(user.display_name), @@ -115,7 +115,7 @@ pub(crate) fn issue_author(user: UserModel) -> IssueAuthor { } } -pub(crate) fn label_response(label: LabelModel) -> LabelResponse { +pub fn label_response(label: LabelModel) -> LabelResponse { LabelResponse { id: label.id, name: label.name, @@ -124,9 +124,7 @@ pub(crate) fn label_response(label: LabelModel) -> LabelResponse { } } -pub(crate) fn milestone_response( - milestone: MilestoneModel, -) -> MilestoneResponse { +pub fn milestone_response(milestone: MilestoneModel) -> MilestoneResponse { MilestoneResponse { id: milestone.id, title: milestone.title, @@ -136,16 +134,14 @@ pub(crate) fn milestone_response( } } -pub(crate) fn issue_repo_response(repo: RepoModel) -> IssueRepoResponse { +pub fn issue_repo_response(repo: RepoModel) -> IssueRepoResponse { IssueRepoResponse { id: repo.id, name: repo.name, } } -pub(crate) fn issue_pr_response( - pr: PullRequestModel, -) -> IssuePullRequestResponse { +pub fn issue_pr_response(pr: PullRequestModel) -> IssuePullRequestResponse { IssuePullRequestResponse { id: pr.id, number: pr.number, diff --git a/lib/service/lib.rs b/lib/service/lib.rs index 7a19d43..c6a5fef 100644 --- a/lib/service/lib.rs +++ b/lib/service/lib.rs @@ -20,15 +20,15 @@ pub mod pull_request; pub mod user; pub mod users; pub mod workspace; -pub(crate) fn session_user(ctx: &Session) -> Result { +pub fn session_user(ctx: &Session) -> Result { ctx.user().ok_or(AppError::Unauthorized) } -pub(crate) fn non_empty(value: String) -> Option { +pub fn non_empty(value: String) -> Option { if value.is_empty() { None } else { Some(value) } } -pub(crate) fn constant_time_eq(a: &str, b: &str) -> bool { +pub fn constant_time_eq(a: &str, b: &str) -> bool { if a.len() != b.len() { return false; } diff --git a/lib/service/pull_request/assignee.rs b/lib/service/pull_request/assignee.rs index a94a015..6bf7f0f 100644 --- a/lib/service/pull_request/assignee.rs +++ b/lib/service/pull_request/assignee.rs @@ -86,7 +86,7 @@ impl AppService { self.pr_assignees_list(pr.id).await } - pub(crate) async fn pr_assignees_list( + pub async fn pr_assignees_list( &self, pr_id: uuid::Uuid, ) -> Result, AppError> { diff --git a/lib/service/pull_request/label.rs b/lib/service/pull_request/label.rs index fb78cae..609f9c4 100644 --- a/lib/service/pull_request/label.rs +++ b/lib/service/pull_request/label.rs @@ -77,7 +77,7 @@ impl AppService { self.pr_labels(pr.id).await } - pub(crate) async fn pr_labels( + pub async fn pr_labels( &self, pr_id: uuid::Uuid, ) -> Result, AppError> { diff --git a/lib/service/pull_request/mod.rs b/lib/service/pull_request/mod.rs index ef5472b..452a15a 100644 --- a/lib/service/pull_request/mod.rs +++ b/lib/service/pull_request/mod.rs @@ -15,7 +15,7 @@ use session::Session; use crate::{AppService, error::AppError, git::rpc_err}; impl AppService { - pub(crate) async fn pr_resolve( + pub async fn pr_resolve( &self, repo_id: uuid::Uuid, number: i64, @@ -34,7 +34,7 @@ impl AppService { .ok_or(AppError::PullRequestNotFound) } - pub(crate) async fn pr_resolve_repo( + pub async fn pr_resolve_repo( &self, ctx: &Session, wk_name: &str, @@ -44,7 +44,7 @@ impl AppService { Ok((repo.id, repo)) } - pub(crate) async fn pr_resolve_repo_admin( + pub async fn pr_resolve_repo_admin( &self, ctx: &Session, wk_name: &str, @@ -54,7 +54,7 @@ impl AppService { Ok((repo.id, repo)) } - pub(crate) async fn branch_head_sha( + pub async fn branch_head_sha( &self, repo_id: uuid::Uuid, branch: &str, diff --git a/lib/service/pull_request/pull_request.rs b/lib/service/pull_request/pull_request.rs index 7c2df1a..535a59a 100644 --- a/lib/service/pull_request/pull_request.rs +++ b/lib/service/pull_request/pull_request.rs @@ -215,10 +215,13 @@ impl AppService { ) -> Result { if let Some(ref state) = params.state { return match state.as_str() { - "closed" => self.pr_close(ctx, wk_name, repo_name, number).await, + "closed" => { + self.pr_close(ctx, wk_name, repo_name, number).await + } "open" => self.pr_reopen(ctx, wk_name, repo_name, number).await, other => Err(AppError::BadRequest(format!( - "invalid state '{}': must be 'open' or 'closed'", other + "invalid state '{}': must be 'open' or 'closed'", + other ))), }; } diff --git a/lib/service/pull_request/review.rs b/lib/service/pull_request/review.rs index 5a94d2b..92ef981 100644 --- a/lib/service/pull_request/review.rs +++ b/lib/service/pull_request/review.rs @@ -207,7 +207,7 @@ impl AppService { }) } - pub(crate) async fn pr_reviews_list( + pub async fn pr_reviews_list( &self, pr_id: uuid::Uuid, ) -> Result, AppError> { diff --git a/lib/service/user/accessibility.rs b/lib/service/user/accessibility.rs index 5d92601..224c478 100644 --- a/lib/service/user/accessibility.rs +++ b/lib/service/user/accessibility.rs @@ -69,7 +69,7 @@ impl AppService { Ok(config) } - pub(crate) async fn user_accessibility_config( + pub async fn user_accessibility_config( &self, user_uid: uuid::Uuid, ) -> Result { diff --git a/lib/service/user/appearance.rs b/lib/service/user/appearance.rs index d33a5b2..669d462 100644 --- a/lib/service/user/appearance.rs +++ b/lib/service/user/appearance.rs @@ -69,7 +69,7 @@ impl AppService { Ok(config) } - pub(crate) async fn user_appearance_config( + pub async fn user_appearance_config( &self, user_uid: uuid::Uuid, ) -> Result { diff --git a/lib/service/user/chpc.rs b/lib/service/user/chpc.rs index 1771fdf..25adca6 100644 --- a/lib/service/user/chpc.rs +++ b/lib/service/user/chpc.rs @@ -43,7 +43,7 @@ impl AppService { .await } - pub(crate) async fn user_contribution_heatmap_for_user( + pub async fn user_contribution_heatmap_for_user( &self, user_uid: uuid::Uuid, username: String, @@ -120,7 +120,7 @@ impl AppService { self.invalidate_user_heatmap_cache(user_uid).await } - pub(crate) async fn invalidate_user_heatmap_cache( + pub async fn invalidate_user_heatmap_cache( &self, user_uid: uuid::Uuid, ) -> Result<(), AppError> { diff --git a/lib/service/user/notification.rs b/lib/service/user/notification.rs index 5cdf8e6..07c3c2f 100644 --- a/lib/service/user/notification.rs +++ b/lib/service/user/notification.rs @@ -177,7 +177,7 @@ impl AppService { Ok(rows.into_iter().map(Into::into).collect()) } - pub(crate) async fn unread_notifications_count( + pub async fn unread_notifications_count( &self, user_uid: uuid::Uuid, ) -> Result { @@ -193,7 +193,7 @@ impl AppService { Ok(row.0.unwrap_or(0)) } - pub(crate) async fn user_notification_config( + pub async fn user_notification_config( &self, user_uid: uuid::Uuid, ) -> Result { diff --git a/lib/service/user/privacy.rs b/lib/service/user/privacy.rs index 86a9564..91515d6 100644 --- a/lib/service/user/privacy.rs +++ b/lib/service/user/privacy.rs @@ -76,7 +76,7 @@ impl AppService { Ok(config) } - pub(crate) async fn user_privacy_config( + pub async fn user_privacy_config( &self, user_uid: uuid::Uuid, ) -> Result { diff --git a/lib/service/user/profile.rs b/lib/service/user/profile.rs index bff7f3c..c340352 100644 --- a/lib/service/user/profile.rs +++ b/lib/service/user/profile.rs @@ -8,12 +8,8 @@ use uuid::Uuid; use crate::{AppService, error::AppError, session_user}; /// Allowed image MIME types for avatars. -const ALLOWED_AVATAR_TYPES: &[&str] = &[ - "image/png", - "image/jpeg", - "image/webp", - "image/gif", -]; +const ALLOWED_AVATAR_TYPES: &[&str] = + &["image/png", "image/jpeg", "image/webp", "image/gif"]; /// Maximum avatar file size: 5 MB. const MAX_AVATAR_SIZE: usize = 5 * 1024 * 1024; @@ -113,10 +109,8 @@ impl AppService { } let ext = extension_from_content_type(content_type); - let key = format!( - "avatars/users/{user_uid}-{}.{ext}", - uuid::Uuid::now_v7() - ); + let key = + format!("avatars/users/{user_uid}-{}.{ext}", uuid::Uuid::now_v7()); let stored = self .storage @@ -145,7 +139,7 @@ impl AppService { }) } - pub(crate) async fn user_profile_config( + pub async fn user_profile_config( &self, user_uid: Uuid, ) -> Result { diff --git a/lib/service/users/summary.rs b/lib/service/users/summary.rs index 98a4895..3ff0383 100644 --- a/lib/service/users/summary.rs +++ b/lib/service/users/summary.rs @@ -35,7 +35,7 @@ impl AppService { Ok(user.avatar_url) } - pub(crate) async fn users_find_active_user_by_username( + pub async fn users_find_active_user_by_username( &self, username: &str, ) -> Result { diff --git a/lib/service/workspace/group.rs b/lib/service/workspace/group.rs index a3f35f4..ba10ac8 100644 --- a/lib/service/workspace/group.rs +++ b/lib/service/workspace/group.rs @@ -208,7 +208,7 @@ impl AppService { .collect()) } - pub(crate) async fn workspace_group_by_name( + pub async fn workspace_group_by_name( &self, wk_id: uuid::Uuid, group_name: &str, diff --git a/lib/service/workspace/types.rs b/lib/service/workspace/types.rs index 22492e7..ac30ecf 100644 --- a/lib/service/workspace/types.rs +++ b/lib/service/workspace/types.rs @@ -36,7 +36,7 @@ pub struct WorkspaceGroupResponse { pub created_at: chrono::DateTime, } -pub(crate) fn normalize_name(name: &str) -> Result { +pub fn normalize_name(name: &str) -> Result { let name = name.trim(); if name.is_empty() { return Err(AppError::BadRequest( @@ -60,7 +60,7 @@ pub(crate) fn normalize_name(name: &str) -> Result { Ok(name.to_string()) } -pub(crate) fn workspace_response( +pub fn workspace_response( wk: WorkspaceModel, owner: bool, admin: bool, @@ -75,7 +75,7 @@ pub(crate) fn workspace_response( } } -pub(crate) fn member_response( +pub fn member_response( user: UserModel, owner: bool, admin: bool, @@ -91,7 +91,7 @@ pub(crate) fn member_response( } } -pub(crate) fn group_response(group: WkGroupModel) -> WorkspaceGroupResponse { +pub fn group_response(group: WkGroupModel) -> WorkspaceGroupResponse { WorkspaceGroupResponse { name: group.name, avatar_url: group.avatar_url, @@ -112,7 +112,7 @@ impl From for WorkspaceModel { } #[derive(db::sqlx::FromRow)] -pub(crate) struct WorkspaceListRow { +pub struct WorkspaceListRow { id: uuid::Uuid, pub name: String, pub description: String, @@ -136,7 +136,7 @@ impl From for WorkspaceMemberResponse { } #[derive(db::sqlx::FromRow)] -pub(crate) struct WorkspaceMemberRow { +pub struct WorkspaceMemberRow { pub username: String, pub display_name: String, pub avatar_url: String, @@ -159,7 +159,7 @@ impl From for WorkspaceMemberResponse { } #[derive(db::sqlx::FromRow)] -pub(crate) struct WorkspaceGroupMemberRow { +pub struct WorkspaceGroupMemberRow { pub username: String, pub display_name: String, pub avatar_url: String, diff --git a/lib/service/workspace/workspace.rs b/lib/service/workspace/workspace.rs index 683ee57..990e6b6 100644 --- a/lib/service/workspace/workspace.rs +++ b/lib/service/workspace/workspace.rs @@ -9,12 +9,8 @@ use super::types::{ }; use crate::{AppService, error::AppError, session_user}; -const ALLOWED_AVATAR_TYPES: &[&str] = &[ - "image/png", - "image/jpeg", - "image/webp", - "image/gif", -]; +const ALLOWED_AVATAR_TYPES: &[&str] = + &["image/png", "image/jpeg", "image/webp", "image/gif"]; const MAX_AVATAR_SIZE: usize = 5 * 1024 * 1024; #[derive(Debug, Clone, Serialize, utoipa::ToSchema)] @@ -248,21 +244,19 @@ impl AppService { AppError::AvatarUploadError(format!("storage error: {e}")) })?; - sqlx::query( - "UPDATE workspace SET avatar_url = $1 WHERE id = $2", - ) - .bind(&stored.url) - .bind(wk.id) - .execute(self.db.writer()) - .await - .map_err(|e| AppError::DatabaseError(e.to_string()))?; + sqlx::query("UPDATE workspace SET avatar_url = $1 WHERE id = $2") + .bind(&stored.url) + .bind(wk.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; Ok(AvatarUploadResponse { avatar_url: stored.url, }) } - pub(crate) async fn workspace_resolve( + pub async fn workspace_resolve( &self, name: &str, ) -> Result { @@ -298,7 +292,7 @@ impl AppService { Err(AppError::NotFound("workspace not found".to_string())) } - pub(crate) async fn workspace_member( + pub async fn workspace_member( &self, wk_id: uuid::Uuid, user_uid: uuid::Uuid, @@ -315,7 +309,7 @@ impl AppService { .ok_or(AppError::PermissionDenied) } - pub(crate) async fn workspace_require_member( + pub async fn workspace_require_member( &self, wk_id: uuid::Uuid, user_uid: uuid::Uuid, @@ -323,7 +317,7 @@ impl AppService { self.workspace_member(wk_id, user_uid).await } - pub(crate) async fn workspace_require_admin( + pub async fn workspace_require_admin( &self, wk_id: uuid::Uuid, user_uid: uuid::Uuid, @@ -336,7 +330,7 @@ impl AppService { } } - pub(crate) async fn workspace_require_owner( + pub async fn workspace_require_owner( &self, wk_id: uuid::Uuid, user_uid: uuid::Uuid, diff --git a/lib/session/config.rs b/lib/session/config.rs index a184d1f..77204e9 100644 --- a/lib/session/config.rs +++ b/lib/session/config.rs @@ -84,11 +84,11 @@ pub enum CookieContentSecurity { Signed, } -pub(crate) const fn default_ttl() -> Duration { +pub const fn default_ttl() -> Duration { Duration::days(1) } -pub(crate) const fn default_ttl_extension_policy() -> TtlExtensionPolicy { +pub const fn default_ttl_extension_policy() -> TtlExtensionPolicy { TtlExtensionPolicy::OnStateChanges } @@ -99,7 +99,7 @@ pub struct SessionMiddlewareBuilder { } impl SessionMiddlewareBuilder { - pub(crate) fn new(store: Store, configuration: Configuration) -> Self { + pub fn new(store: Store, configuration: Configuration) -> Self { Self { storage_backend: store, configuration, @@ -178,31 +178,31 @@ impl SessionMiddlewareBuilder { } #[derive(Clone)] -pub(crate) struct Configuration { - pub(crate) cookie: CookieConfiguration, - pub(crate) session: SessionConfiguration, - pub(crate) ttl_extension_policy: TtlExtensionPolicy, +pub struct Configuration { + pub cookie: CookieConfiguration, + pub session: SessionConfiguration, + pub ttl_extension_policy: TtlExtensionPolicy, } #[derive(Clone)] -pub(crate) struct SessionConfiguration { - pub(crate) state_ttl: Duration, +pub struct SessionConfiguration { + pub state_ttl: Duration, } #[derive(Clone)] -pub(crate) struct CookieConfiguration { - pub(crate) secure: bool, - pub(crate) http_only: bool, - pub(crate) name: String, - pub(crate) same_site: SameSite, - pub(crate) path: String, - pub(crate) domain: Option, - pub(crate) max_age: Option, - pub(crate) content_security: CookieContentSecurity, - pub(crate) key: Key, +pub struct CookieConfiguration { + pub secure: bool, + pub http_only: bool, + pub name: String, + pub same_site: SameSite, + pub path: String, + pub domain: Option, + pub max_age: Option, + pub content_security: CookieContentSecurity, + pub key: Key, } -pub(crate) fn default_configuration(key: Key) -> Configuration { +pub fn default_configuration(key: Key) -> Configuration { Configuration { cookie: CookieConfiguration { secure: true, diff --git a/lib/session/middleware.rs b/lib/session/middleware.rs index 4473700..0a18f6f 100644 --- a/lib/session/middleware.rs +++ b/lib/session/middleware.rs @@ -38,10 +38,7 @@ impl SessionMiddleware { SessionMiddlewareBuilder::new(store, config::default_configuration(key)) } - pub(crate) fn from_parts( - store: Store, - configuration: Configuration, - ) -> Self { + pub fn from_parts(store: Store, configuration: Configuration) -> Self { Self { storage_backend: Rc::new(store), configuration: Rc::new(configuration), diff --git a/lib/session/session.rs b/lib/session/session.rs index 0ec888f..f58067e 100644 --- a/lib/session/session.rs +++ b/lib/session/session.rs @@ -268,7 +268,7 @@ impl Session { } #[allow(clippy::needless_pass_by_ref_mut)] - pub(crate) fn set_session( + pub fn set_session( req: &mut ServiceRequest, data: impl IntoIterator, ) { @@ -278,7 +278,7 @@ impl Session { } #[allow(clippy::needless_pass_by_ref_mut)] - pub(crate) fn get_changes( + pub fn get_changes( res: &mut ServiceResponse, ) -> (SessionStatus, Map) { if let Some(s_impl) = res diff --git a/lib/session/storage/format.rs b/lib/session/storage/format.rs index b6e8df1..c7ba803 100644 --- a/lib/session/storage/format.rs +++ b/lib/session/storage/format.rs @@ -24,7 +24,7 @@ impl Serialize for StoredSessionStateRef<'_> { } } -pub(crate) fn serialize_session_state( +pub fn serialize_session_state( session_state: &SessionState, ) -> Result { let stored = StoredSessionStateRef { @@ -34,7 +34,7 @@ pub(crate) fn serialize_session_state( serde_json::to_string(&stored).map_err(anyhow::Error::new) } -pub(crate) fn deserialize_session_state( +pub fn deserialize_session_state( value: &str, ) -> Result { let value = serde_json::from_str::(value)?; diff --git a/lib/session/storage/interface.rs b/lib/session/storage/interface.rs index d3f6b6f..06ef618 100644 --- a/lib/session/storage/interface.rs +++ b/lib/session/storage/interface.rs @@ -6,7 +6,7 @@ use serde_json::{Map, Value}; use super::SessionKey; -pub(crate) type SessionState = Map; +pub type SessionState = Map; pub trait SessionStore { fn load( diff --git a/lib/socketio/actix.rs b/lib/socketio/actix.rs index 6fc5613..b865ee1 100644 --- a/lib/socketio/actix.rs +++ b/lib/socketio/actix.rs @@ -91,7 +91,11 @@ async fn engine_post( .ok_or_else(|| ErrorBadRequest("missing sid"))?; validate_sid(sid)?; let session = io.session(sid).await.ok_or_else(|| ErrorNotFound("sid"))?; - if session.post_active.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed).is_err() { + if session + .post_active + .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) + .is_err() + { // Another POST is in progress — return error without destroying session return Err(ErrorBadRequest("concurrent polling request")); } @@ -147,7 +151,11 @@ async fn polling_get( let sid = sid.ok_or_else(|| ErrorBadRequest("missing sid"))?; validate_sid(sid)?; let session = io.session(sid).await.ok_or_else(|| ErrorNotFound("sid"))?; - if session.get_active.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed).is_err() { + if session + .get_active + .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) + .is_err() + { // Another GET is in progress — return error without destroying session return Err(ErrorBadRequest("concurrent polling request")); } diff --git a/lib/socketio/engine_packet.rs b/lib/socketio/engine_packet.rs index 013d7d5..dc34327 100644 --- a/lib/socketio/engine_packet.rs +++ b/lib/socketio/engine_packet.rs @@ -6,7 +6,7 @@ use crate::error::{Result, SocketIoError}; const RECORD_SEPARATOR: char = '\x1e'; #[derive(Clone, Debug, PartialEq)] -pub(crate) enum EnginePacket { +pub enum EnginePacket { Open(Value), Close, Ping(Option), @@ -17,12 +17,12 @@ pub(crate) enum EnginePacket { } #[derive(Clone, Debug, PartialEq)] -pub(crate) enum SocketPayload { +pub enum SocketPayload { Text(String), Binary(Vec), } -pub(crate) fn encode_engine_payload( +pub fn encode_engine_payload( packets: &[EnginePacket], polling: bool, ) -> String { @@ -33,9 +33,7 @@ pub(crate) fn encode_engine_payload( .join(&RECORD_SEPARATOR.to_string()) } -pub(crate) fn decode_engine_payload( - payload: &str, -) -> Result> { +pub fn decode_engine_payload(payload: &str) -> Result> { payload .split(RECORD_SEPARATOR) .filter(|item| !item.is_empty()) @@ -43,10 +41,7 @@ pub(crate) fn decode_engine_payload( .collect() } -pub(crate) fn encode_engine_packet( - packet: &EnginePacket, - _polling: bool, -) -> String { +pub fn encode_engine_packet(packet: &EnginePacket, _polling: bool) -> String { match packet { EnginePacket::Open(data) => format!("0{data}"), EnginePacket::Close => "1".to_owned(), @@ -65,7 +60,7 @@ pub(crate) fn encode_engine_packet( } } -pub(crate) fn decode_engine_text_packet(input: &str) -> Result { +pub fn decode_engine_text_packet(input: &str) -> Result { if let Some(encoded) = input.strip_prefix('b') { return Ok(EnginePacket::Message(SocketPayload::Binary( STANDARD.decode(encoded).map_err(|_| { diff --git a/lib/socketio/packet.rs b/lib/socketio/packet.rs index bf58dc8..754b28b 100644 --- a/lib/socketio/packet.rs +++ b/lib/socketio/packet.rs @@ -211,7 +211,7 @@ impl Packet { }) } - pub(crate) fn into_event_payload( + pub fn into_event_payload( self, ack: Option, ) -> Result { diff --git a/lib/socketio/server.rs b/lib/socketio/server.rs index 6d3b913..8f28b9a 100644 --- a/lib/socketio/server.rs +++ b/lib/socketio/server.rs @@ -22,13 +22,13 @@ use crate::{ socket::{AckSender, DisconnectReason, Socket}, }; -pub(crate) type BoxFuture = Pin + Send>>; -pub(crate) type ConnectHandler = Arc BoxFuture + Send + Sync>; -pub(crate) type DisconnectHandler = +pub type BoxFuture = Pin + Send>>; +pub type ConnectHandler = Arc BoxFuture + Send + Sync>; +pub type DisconnectHandler = Arc BoxFuture + Send + Sync>; -pub(crate) type EventHandler = +pub type EventHandler = Arc BoxFuture + Send + Sync>; -pub(crate) type Middleware = Arc< +pub type Middleware = Arc< dyn Fn( Socket, Option, @@ -39,33 +39,33 @@ pub(crate) type Middleware = Arc< #[derive(Clone)] pub struct SocketIo { - pub(crate) inner: Arc, + pub inner: Arc, } #[derive(Clone)] pub struct Namespace { - pub(crate) io: SocketIo, - pub(crate) name: String, + pub io: SocketIo, + pub name: String, } pub struct SocketIoBuilder { - pub(crate) config: SocketIoConfig, - pub(crate) adapter: Arc, + pub config: SocketIoConfig, + pub adapter: Arc, } -pub(crate) struct Inner { - pub(crate) config: SocketIoConfig, - pub(crate) sessions: RwLock>>, - pub(crate) namespaces: RwLock>>, - pub(crate) adapter: Arc, - pub(crate) next_ack_id: AtomicU64, +pub struct Inner { + pub config: SocketIoConfig, + pub sessions: RwLock>>, + pub namespaces: RwLock>>, + pub adapter: Arc, + pub next_ack_id: AtomicU64, } -pub(crate) struct NamespaceState { - pub(crate) connect_handler: RwLock>, - pub(crate) disconnect_handler: RwLock>, - pub(crate) event_handlers: RwLock>, - pub(crate) middleware: RwLock>, +pub struct NamespaceState { + pub connect_handler: RwLock>, + pub disconnect_handler: RwLock>, + pub event_handlers: RwLock>, + pub middleware: RwLock>, } impl Default for SocketIo { @@ -140,11 +140,11 @@ impl SocketIo { .await } - pub(crate) async fn session(&self, sid: &str) -> Option> { + pub async fn session(&self, sid: &str) -> Option> { self.inner.sessions.read().await.get(sid).cloned() } - pub(crate) async fn insert_session(&self, session: Arc) { + pub async fn insert_session(&self, session: Arc) { self.inner .sessions .write() @@ -152,7 +152,7 @@ impl SocketIo { .insert(session.engine_id.clone(), session); } - pub(crate) async fn remove_session( + pub async fn remove_session( &self, session: &Arc, reason: DisconnectReason, @@ -172,7 +172,7 @@ impl SocketIo { } } - pub(crate) async fn handle_socket_payload( + pub async fn handle_socket_payload( &self, session: Arc, payload: SocketPayload, @@ -291,7 +291,7 @@ impl SocketIo { Ok(()) } - pub(crate) async fn disconnect_socket( + pub async fn disconnect_socket( &self, namespace: &str, session: &Arc, @@ -388,7 +388,7 @@ impl SocketIo { Ok(()) } - pub(crate) async fn join( + pub async fn join( &self, namespace: &str, engine_id: &str, @@ -409,7 +409,7 @@ impl SocketIo { .await } - pub(crate) async fn leave( + pub async fn leave( &self, namespace: &str, engine_id: &str, @@ -430,7 +430,7 @@ impl SocketIo { .await } - pub(crate) async fn emit_to_sid( + pub async fn emit_to_sid( &self, namespace: &str, engine_id: &str, @@ -445,7 +445,7 @@ impl SocketIo { .await } - pub(crate) async fn emit_binary_to_sid( + pub async fn emit_binary_to_sid( &self, namespace: &str, engine_id: &str, @@ -460,7 +460,7 @@ impl SocketIo { .await } - pub(crate) async fn emit_to_sid_with_ack( + pub async fn emit_to_sid_with_ack( &self, namespace: &str, engine_id: &str, @@ -499,7 +499,7 @@ impl SocketIo { } } - pub(crate) async fn broadcast_with_opts( + pub async fn broadcast_with_opts( &self, mut opts: BroadcastOptions, event: &str, @@ -542,7 +542,7 @@ impl SocketIo { } } - pub(crate) async fn deliver_remote_packet( + pub async fn deliver_remote_packet( &self, opts: BroadcastOptions, packet: Packet, @@ -580,7 +580,7 @@ impl SocketIo { Ok(()) } - pub(crate) async fn ensure_namespace( + pub async fn ensure_namespace( &self, namespace: &str, ) -> Arc { diff --git a/lib/socketio/session.rs b/lib/socketio/session.rs index 707f2ba..7fa816e 100644 --- a/lib/socketio/session.rs +++ b/lib/socketio/session.rs @@ -16,41 +16,40 @@ use crate::{ }; #[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub(crate) enum Transport { +pub enum Transport { Polling, WebSocket, } #[derive(Debug)] -pub(crate) struct SocketState { - pub(crate) sid: String, - pub(crate) rooms: HashSet, - pub(crate) auth: Option, +pub struct SocketState { + pub sid: String, + pub rooms: HashSet, + pub auth: Option, } -pub(crate) struct PendingBinary { - pub(crate) packet: Packet, +pub struct PendingBinary { + pub packet: Packet, } -pub(crate) struct Session { - pub(crate) engine_id: String, - pub(crate) user: StdMutex>, - pub(crate) transport: Mutex, - pub(crate) namespaces: Mutex>, - pub(crate) pending_binary: Mutex>, - pub(crate) ack_waiters: Mutex, - pub(crate) last_pong: Mutex, +pub struct Session { + pub engine_id: String, + pub user: StdMutex>, + pub transport: Mutex, + pub namespaces: Mutex>, + pub pending_binary: Mutex>, + pub ack_waiters: Mutex, + pub last_pong: Mutex, queue: Mutex>, - pub(crate) get_active: Arc, - pub(crate) post_active: Arc, - pub(crate) notify: Notify, + pub get_active: Arc, + pub post_active: Arc, + pub notify: Notify, } -pub(crate) type AckWaiters = - HashMap<(String, u64), oneshot::Sender>>; +pub type AckWaiters = HashMap<(String, u64), oneshot::Sender>>; impl Session { - pub(crate) fn new(user: Option) -> Arc { + pub fn new(user: Option) -> Arc { Arc::new(Self { engine_id: Uuid::new_v4().to_string(), user: StdMutex::new(user), @@ -66,12 +65,12 @@ impl Session { }) } - pub(crate) async fn enqueue(&self, packet: EnginePacket) { + pub async fn enqueue(&self, packet: EnginePacket) { self.queue.lock().await.push_back(packet); self.notify.notify_waiters(); } - pub(crate) async fn enqueue_socket_packet(&self, packet: Packet) { + pub async fn enqueue_socket_packet(&self, packet: Packet) { self.enqueue(EnginePacket::Message(SocketPayload::Text( packet.encode(), ))) @@ -84,7 +83,7 @@ impl Session { } } - pub(crate) async fn drain(&self) -> Vec { + pub async fn drain(&self) -> Vec { self.queue.lock().await.drain(..).collect() } } diff --git a/lib/socketio/socket.rs b/lib/socketio/socket.rs index 5e57be1..c37aac7 100644 --- a/lib/socketio/socket.rs +++ b/lib/socketio/socket.rs @@ -11,10 +11,10 @@ use crate::{ #[derive(Clone)] pub struct Socket { - pub(crate) io: SocketIo, - pub(crate) session: Arc, - pub(crate) namespace: String, - pub(crate) sid: String, + pub io: SocketIo, + pub session: Arc, + pub namespace: String, + pub sid: String, } #[derive(Clone, Debug, Eq, PartialEq)] @@ -42,11 +42,7 @@ impl fmt::Debug for AckSender { } impl AckSender { - pub(crate) fn new( - session: Arc, - namespace: String, - id: u64, - ) -> Self { + pub fn new(session: Arc, namespace: String, id: u64) -> Self { Self { session, namespace, @@ -76,11 +72,16 @@ impl Socket { } pub fn session_user(&self) -> Option { - self.session.user.lock().unwrap_or_else(|e| e.into_inner()).clone() + self.session + .user + .lock() + .unwrap_or_else(|e| e.into_inner()) + .clone() } pub fn set_user(&self, user: uuid::Uuid) { - *self.session.user.lock().unwrap_or_else(|e| e.into_inner()) = Some(user); + *self.session.user.lock().unwrap_or_else(|e| e.into_inner()) = + Some(user); } pub async fn rooms(&self) -> HashSet { diff --git a/lib/storage/lib.rs b/lib/storage/lib.rs index dd89ef6..7dd630f 100644 --- a/lib/storage/lib.rs +++ b/lib/storage/lib.rs @@ -176,7 +176,7 @@ impl ObjectStorage for AppStorage { } } -pub(crate) async fn collect_byte_stream( +pub async fn collect_byte_stream( body: ByteStream, ) -> Result, ByteStreamError> { body.collect().await.map(|data| data.to_vec()) diff --git a/openapi.json b/openapi.json index c787958..90327a2 100644 --- a/openapi.json +++ b/openapi.json @@ -11721,119 +11721,6 @@ } } }, - "/api/v1/ws/rooms/{room_id}/ai": { - "get": { - "tags": [ - "channel" - ], - "operationId": "channel_ai_list", - "parameters": [ - { - "name": "room_id", - "in": "path", - "required": true, - "schema": { - "type": "string", - "format": "uuid" - } - } - ], - "responses": { - "200": { - "description": "AI agents in room" - } - } - }, - "post": { - "tags": [ - "channel" - ], - "operationId": "channel_ai_add", - "parameters": [ - { - "name": "room_id", - "in": "path", - "required": true, - "schema": { - "type": "string", - "format": "uuid" - } - } - ], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/AiAddRequest" - } - } - }, - "required": true - }, - "responses": { - "201": { - "description": "AI agent added to room" - } - } - } - }, - "/api/v1/ws/rooms/{room_id}/ai/stop": { - "post": { - "tags": [ - "channel" - ], - "operationId": "channel_ai_stop", - "parameters": [ - { - "name": "room_id", - "in": "path", - "required": true, - "schema": { - "type": "string", - "format": "uuid" - } - } - ], - "responses": { - "204": { - "description": "AI agent stopped" - } - } - } - }, - "/api/v1/ws/rooms/{room_id}/ai/{agent_session_id}": { - "delete": { - "tags": [ - "channel" - ], - "operationId": "channel_ai_remove", - "parameters": [ - { - "name": "room_id", - "in": "path", - "required": true, - "schema": { - "type": "string", - "format": "uuid" - } - }, - { - "name": "agent_session_id", - "in": "path", - "required": true, - "schema": { - "type": "string", - "format": "uuid" - } - } - ], - "responses": { - "200": { - "description": "AI agent removed from room" - } - } - } - }, "/api/v1/ws/rooms/{room_id}/dnd": { "patch": { "tags": [ @@ -12113,6 +12000,18 @@ "minimum": 0 } }, + { + "name": "thread", + "in": "query", + "required": false, + "schema": { + "type": [ + "string", + "null" + ], + "format": "uuid" + } + }, { "name": "room_id", "in": "path", @@ -13269,18 +13168,6 @@ } } }, - "AiAddRequest": { - "type": "object", - "required": [ - "agent_session" - ], - "properties": { - "agent_session": { - "type": "string", - "format": "uuid" - } - } - }, "AiDiscussionResponse": { "type": "object", "required": [ @@ -16903,6 +16790,12 @@ "public" ], "properties": { + "ai_enabled": { + "type": [ + "boolean", + "null" + ] + }, "category": { "type": [ "string", @@ -16925,6 +16818,12 @@ "RoomUpdateRequest": { "type": "object", "properties": { + "ai_enabled": { + "type": [ + "boolean", + "null" + ] + }, "category": { "type": [ "string", diff --git a/src/App.tsx b/src/App.tsx index 47dc276..2a31406 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -3,6 +3,7 @@ import { useEffect } from "react"; import { Navigate, createBrowserRouter, useParams } from "react-router"; import { RouterProvider } from "react-router/dom"; import { getSavedThemeId, getThemeById, applyTheme, defaultThemeId } from "@/lib/theme"; +import RootLayout from "@/app/root-layout"; import AuthLayout from "@/page/auth/layout"; import LoginPage from "@/page/auth/login"; @@ -60,6 +61,9 @@ function App() { }, []); const router = createBrowserRouter([ + { + element: , + children: [ { path: "/", element: , @@ -244,7 +248,9 @@ function App() { element: , }, ], - }, + }, + ], + }, ]); return ( diff --git a/src/app/root-layout.tsx b/src/app/root-layout.tsx new file mode 100644 index 0000000..1e98744 --- /dev/null +++ b/src/app/root-layout.tsx @@ -0,0 +1,10 @@ +import { Outlet } from "react-router"; +import { SettingsModalProvider } from "@/components/settings/SettingsModalContext"; + +export default function RootLayout() { + return ( + + + + ); +} diff --git a/src/client/endpoints.ts b/src/client/endpoints.ts index 9000b40..2224503 100644 --- a/src/client/endpoints.ts +++ b/src/client/endpoints.ts @@ -24,7 +24,6 @@ import type { AgentRunRequest, AgentRunResponse, AgentSessionResponse, - AiAddRequest, AiDiscussionResponse, AiLikeResponse, AiListDiscussionsParams, @@ -34,7 +33,6 @@ import type { AiModelResponse, AiModelVersionResponse, AiProviderResponse, - AppNotificationItem, ApproveWorkspaceJoinApply, AssignIssueUser, AssignPrUser, @@ -654,14 +652,6 @@ const userConfig = ( ); } -const userListNotifications = ( - options?: AxiosRequestConfig - ): Promise> => { - return axiosInstance.get( - `/api/v1/user/notifications`,options - ); - } - const userUpdateAccessibility = ( updateUserAccessibilityConfig: UpdateUserAccessibilityConfig, options?: AxiosRequestConfig ): Promise> => { @@ -2681,42 +2671,6 @@ const channelRoomUpdate = ( ); } -const channelAiList = ( - roomId: string, options?: AxiosRequestConfig - ): Promise> => { - return axiosInstance.get( - `/api/v1/ws/rooms/${roomId}/ai`,options - ); - } - -const channelAiAdd = ( - roomId: string, - aiAddRequest: AiAddRequest, options?: AxiosRequestConfig - ): Promise> => { - return axiosInstance.post( - `/api/v1/ws/rooms/${roomId}/ai`, - aiAddRequest,options - ); - } - -const channelAiStop = ( - roomId: string, options?: AxiosRequestConfig - ): Promise> => { - return axiosInstance.post( - `/api/v1/ws/rooms/${roomId}/ai/stop`, - undefined,options - ); - } - -const channelAiRemove = ( - roomId: string, - agentSessionId: string, options?: AxiosRequestConfig - ): Promise> => { - return axiosInstance.delete( - `/api/v1/ws/rooms/${roomId}/ai/${agentSessionId}`,options - ); - } - const channelDndUpdate = ( roomId: string, dndRequest: DndRequest, options?: AxiosRequestConfig @@ -3014,7 +2968,7 @@ const channelCategoryCreate = ( ); } -return {agentListAllConversations,agentGetConversation,agentDeleteConversation,agentUpdateConversation,agentListMessages,agentSendMessage,agentStreamAgent,agentListSessions,agentCreateSession,agentGetSession,agentDeleteSession,agentUpdateSession,agentListConversations,agentCreateConversation,aiListModels,aiGetModel,aiGetCard,aiListDiscussions,aiListLikes,aiListTags,aiListVersions,aiListProviders,aiGetProvider,authStatus2fa,authDisable2fa,authRegenerateBackupCodes,authEnable2fa,authVerify2fa,authCaptcha,authGetEmail,authEmailChangeRequest,authEmailVerify,authLogin,authLogout,authMe,authRsa,authRegister,authResetPasswordRequest,authResetPasswordVerify,search,userListAccessTokens,userCreateAccessToken,userUpdateAccessToken,userRevokeAccessToken,userUploadAvatar,userListNotifications,userConfig,userUpdateAccessibility,userUpdateAppearance,userUpdateNotification,userUpdatePrivacy,userUpdateProfile,userContributionHeatmap,userInvalidateChpcCache,userListSshKeys,userAddSshKey,userUpdateSshKey,userRevokeSshKey,usersUserAvatar,usersBlockedList,usersBlockUser,usersUserChpc,usersFollowUser,usersFollowers,usersFollowing,usersUserPublic,usersRelationStatus,usersRelationCounts,usersUserSummary,usersUnblockUser,usersUnfollowUser,workspaceCreateWorkspace,workspaceMyJoinApplies,workspaceMyWorkspaces,workspaceGetWorkspace,workspaceUpdateWorkspace,workspaceGetAvatar,workspaceUploadAvatar,workspaceListGroups,workspaceCreateGroup,workspaceUpdateGroup,workspaceDeleteGroup,workspaceListGroupMembers,workspaceAddGroupMember,workspaceRemoveGroupMember,issuesListIssues,issuesCreateIssue,issuesGetIssue,issuesUpdateIssue,issuesDeleteIssue,issuesAssignUser,issuesUnassignUser,issuesCloseIssue,issuesListComments,issuesCreateComment,issuesUpdateComment,issuesDeleteComment,issuesAddCommentReaction,issuesRemoveCommentReaction,issuesListEvents,issuesAddIssueLabel,issuesRemoveIssueLabel,issuesSetIssueMilestone,issuesClearIssueMilestone,issuesBindPullRequest,issuesUnbindPullRequest,issuesAddReaction,issuesRemoveReaction,issuesReopenIssue,issuesBindRepo,issuesUnbindRepo,workspaceJoinStrategy,workspaceUpdateJoinStrategy,workspaceListJoinApplies,workspaceApproveJoin,workspaceApplyJoin,workspaceCancelJoin,issuesListLabels,issuesCreateLabel,issuesUpdateLabel,issuesDeleteLabel,workspaceListMembers,workspaceAddMember,workspaceUpdateMember,workspaceRemoveMember,issuesListMilestones,issuesCreateMilestone,issuesUpdateMilestone,issuesDeleteMilestone,gitListRepos,gitCreateRepo,gitCloneRepo,gitGetRepo,gitUpdateRepo,gitDeleteRepo,gitArchiveRepo,gitCombinedStatus,gitListStatuses,gitCompare,gitGetContents,gitUpdateContents,gitCreateContents,gitDeleteContents,gitListForks,gitCreateFork,gitArchive,gitBlameFile,gitBlobUpload,gitBlobInfo,gitListBranches,gitForkBranch,gitBranchInfo,gitDeleteBranch,gitRenameBranch,gitAheadBehind,gitBranchUpstream,gitListCommits,gitCherryPick,gitCommitHistory,gitCommitWalk,gitCommitInfo,gitTreeEntryByPathFromCommit,gitListContributors,gitDiff,gitDiffBranches,gitGetLanguages,gitGetReadme,gitListRefs,gitStarStatus,gitStarRepo,gitUnstarRepo,gitListTags,gitInitTag,gitTagInfo,gitDeleteTag,gitUpdateTag,gitTreeEntries,gitTreeEntryByPath,gitWatchStatus,gitWatchRepo,gitUnwatchRepo,gitListProtects,gitCreateProtect,gitUpdateProtect,gitDeleteProtect,pullRequestListPrs,pullRequestCreatePr,pullRequestGetPr,pullRequestDeletePr,pullRequestUpdatePr,pullRequestAssignUser,pullRequestUnassignUser,pullRequestListComments,pullRequestCreateComment,pullRequestUpdateComment,pullRequestDeleteComment,pullRequestAddCommentReaction,pullRequestRemoveCommentReaction,pullRequestAddLabel,pullRequestRemoveLabel,pullRequestMergeAnalysis,pullRequestMergePr,pullRequestAddReaction,pullRequestRemoveReaction,pullRequestCreateReviewComment,pullRequestListReviews,pullRequestCreateReview,pullRequestDismissReview,pullRequestUpdateBranch,gitListReleases,gitCreateRelease,gitGetReleaseByTag,gitDeleteReleaseByTag,gitGetRelease,gitDeleteRelease,gitUpdateRelease,gitCreateStatus,gitGetTopics,gitUpdateTopics,gitTransferRepo,gitListWebhooks,gitCreateWebhook,gitUpdateWebhook,gitDeleteWebhook,gitListDeliveries,channelCategoryDelete,channelCategoryUpdate,channelCsrfToken,channelCustomStatusUpdate,channelInviteCreate,channelInviteAccept,channelInviteRevoke,channelRevokeMessage,channelUpdateMessage,channelNotificationMarkAllRead,channelNotificationArchive,channelNotificationMarkRead,channelPing,channelPresenceUpdate,channelRoomCreate,channelRoomGet,channelRoomDelete,channelRoomUpdate,channelAiList,channelAiAdd,channelAiStop,channelAiRemove,channelDndUpdate,channelDraftSave,channelDraftClear,channelAccessGrant,channelAccessRevoke,channelListMessages,channelCreateMessage,channelMessagesAround,channelMissedMessages,channelPinAdd,channelPinRemove,channelReactionAdd,channelReactionRemove,channelReadReceipt,channelScreenShare,channelSubscribe,channelUnsubscribe,channelThreadCreate,channelTyping,channelVoiceDeaf,channelVoiceJoin,channelVoiceLeave,channelVoiceMute,channelSearch,channelThreadArchive,channelThreadResolve,channelGenerateToken,channelUserSummary,channelBanCreate,channelBanRemove,channelCategoryCreate}}; +return {agentListAllConversations,agentGetConversation,agentDeleteConversation,agentUpdateConversation,agentListMessages,agentSendMessage,agentStreamAgent,agentListSessions,agentCreateSession,agentGetSession,agentDeleteSession,agentUpdateSession,agentListConversations,agentCreateConversation,aiListModels,aiGetModel,aiGetCard,aiListDiscussions,aiListLikes,aiListTags,aiListVersions,aiListProviders,aiGetProvider,authStatus2fa,authDisable2fa,authRegenerateBackupCodes,authEnable2fa,authVerify2fa,authCaptcha,authGetEmail,authEmailChangeRequest,authEmailVerify,authLogin,authLogout,authMe,authRsa,authRegister,authResetPasswordRequest,authResetPasswordVerify,search,userListAccessTokens,userCreateAccessToken,userUpdateAccessToken,userRevokeAccessToken,userUploadAvatar,userConfig,userUpdateAccessibility,userUpdateAppearance,userUpdateNotification,userUpdatePrivacy,userUpdateProfile,userContributionHeatmap,userInvalidateChpcCache,userListSshKeys,userAddSshKey,userUpdateSshKey,userRevokeSshKey,usersUserAvatar,usersBlockedList,usersBlockUser,usersUserChpc,usersFollowUser,usersFollowers,usersFollowing,usersUserPublic,usersRelationStatus,usersRelationCounts,usersUserSummary,usersUnblockUser,usersUnfollowUser,workspaceCreateWorkspace,workspaceMyJoinApplies,workspaceMyWorkspaces,workspaceGetWorkspace,workspaceUpdateWorkspace,workspaceGetAvatar,workspaceUploadAvatar,workspaceListGroups,workspaceCreateGroup,workspaceUpdateGroup,workspaceDeleteGroup,workspaceListGroupMembers,workspaceAddGroupMember,workspaceRemoveGroupMember,issuesListIssues,issuesCreateIssue,issuesGetIssue,issuesUpdateIssue,issuesDeleteIssue,issuesAssignUser,issuesUnassignUser,issuesCloseIssue,issuesListComments,issuesCreateComment,issuesUpdateComment,issuesDeleteComment,issuesAddCommentReaction,issuesRemoveCommentReaction,issuesListEvents,issuesAddIssueLabel,issuesRemoveIssueLabel,issuesSetIssueMilestone,issuesClearIssueMilestone,issuesBindPullRequest,issuesUnbindPullRequest,issuesAddReaction,issuesRemoveReaction,issuesReopenIssue,issuesBindRepo,issuesUnbindRepo,workspaceJoinStrategy,workspaceUpdateJoinStrategy,workspaceListJoinApplies,workspaceApproveJoin,workspaceApplyJoin,workspaceCancelJoin,issuesListLabels,issuesCreateLabel,issuesUpdateLabel,issuesDeleteLabel,workspaceListMembers,workspaceAddMember,workspaceUpdateMember,workspaceRemoveMember,issuesListMilestones,issuesCreateMilestone,issuesUpdateMilestone,issuesDeleteMilestone,gitListRepos,gitCreateRepo,gitCloneRepo,gitGetRepo,gitUpdateRepo,gitDeleteRepo,gitArchiveRepo,gitCombinedStatus,gitListStatuses,gitCompare,gitGetContents,gitUpdateContents,gitCreateContents,gitDeleteContents,gitListForks,gitCreateFork,gitArchive,gitBlameFile,gitBlobUpload,gitBlobInfo,gitListBranches,gitForkBranch,gitBranchInfo,gitDeleteBranch,gitRenameBranch,gitAheadBehind,gitBranchUpstream,gitListCommits,gitCherryPick,gitCommitHistory,gitCommitWalk,gitCommitInfo,gitTreeEntryByPathFromCommit,gitListContributors,gitDiff,gitDiffBranches,gitGetLanguages,gitGetReadme,gitListRefs,gitStarStatus,gitStarRepo,gitUnstarRepo,gitListTags,gitInitTag,gitTagInfo,gitDeleteTag,gitUpdateTag,gitTreeEntries,gitTreeEntryByPath,gitWatchStatus,gitWatchRepo,gitUnwatchRepo,gitListProtects,gitCreateProtect,gitUpdateProtect,gitDeleteProtect,pullRequestListPrs,pullRequestCreatePr,pullRequestGetPr,pullRequestDeletePr,pullRequestUpdatePr,pullRequestAssignUser,pullRequestUnassignUser,pullRequestListComments,pullRequestCreateComment,pullRequestUpdateComment,pullRequestDeleteComment,pullRequestAddCommentReaction,pullRequestRemoveCommentReaction,pullRequestAddLabel,pullRequestRemoveLabel,pullRequestMergeAnalysis,pullRequestMergePr,pullRequestAddReaction,pullRequestRemoveReaction,pullRequestCreateReviewComment,pullRequestListReviews,pullRequestCreateReview,pullRequestDismissReview,pullRequestUpdateBranch,gitListReleases,gitCreateRelease,gitGetReleaseByTag,gitDeleteReleaseByTag,gitGetRelease,gitDeleteRelease,gitUpdateRelease,gitCreateStatus,gitGetTopics,gitUpdateTopics,gitTransferRepo,gitListWebhooks,gitCreateWebhook,gitUpdateWebhook,gitDeleteWebhook,gitListDeliveries,channelCategoryDelete,channelCategoryUpdate,channelCsrfToken,channelCustomStatusUpdate,channelInviteCreate,channelInviteAccept,channelInviteRevoke,channelRevokeMessage,channelUpdateMessage,channelNotificationMarkAllRead,channelNotificationArchive,channelNotificationMarkRead,channelPing,channelPresenceUpdate,channelRoomCreate,channelRoomGet,channelRoomDelete,channelRoomUpdate,channelDndUpdate,channelDraftSave,channelDraftClear,channelAccessGrant,channelAccessRevoke,channelListMessages,channelCreateMessage,channelMessagesAround,channelMissedMessages,channelPinAdd,channelPinRemove,channelReactionAdd,channelReactionRemove,channelReadReceipt,channelScreenShare,channelSubscribe,channelUnsubscribe,channelThreadCreate,channelTyping,channelVoiceDeaf,channelVoiceJoin,channelVoiceLeave,channelVoiceMute,channelSearch,channelThreadArchive,channelThreadResolve,channelGenerateToken,channelUserSummary,channelBanCreate,channelBanRemove,channelCategoryCreate}}; export type AgentListAllConversationsResult = AxiosResponse export type AgentGetConversationResult = AxiosResponse export type AgentDeleteConversationResult = AxiosResponse @@ -3257,10 +3211,6 @@ export type ChannelRoomCreateResult = AxiosResponse export type ChannelRoomGetResult = AxiosResponse export type ChannelRoomDeleteResult = AxiosResponse export type ChannelRoomUpdateResult = AxiosResponse -export type ChannelAiListResult = AxiosResponse -export type ChannelAiAddResult = AxiosResponse -export type ChannelAiStopResult = AxiosResponse -export type ChannelAiRemoveResult = AxiosResponse export type ChannelDndUpdateResult = AxiosResponse export type ChannelDraftSaveResult = AxiosResponse export type ChannelDraftClearResult = AxiosResponse diff --git a/src/client/models/aiAddRequest.ts b/src/client/models/aiAddRequest.ts deleted file mode 100644 index 6f5b9e7..0000000 --- a/src/client/models/aiAddRequest.ts +++ /dev/null @@ -1,11 +0,0 @@ -/** - * Generated by orval v8.12.3 🍺 - * Do not edit manually. - * GitDataAI API - * GitDataAI platform REST API - * OpenAPI spec version: 1.0.0 - */ - -export interface AiAddRequest { - agent_session: string; -} diff --git a/src/client/models/appNotificationItem.ts b/src/client/models/appNotificationItem.ts deleted file mode 100644 index 4624cad..0000000 --- a/src/client/models/appNotificationItem.ts +++ /dev/null @@ -1,9 +0,0 @@ -export interface AppNotificationItem { - body: string; - created_at: string; - id: string; - notify_type: string; - /** @nullable */ - read_at?: string | null; - title: string; -} diff --git a/src/client/models/channelMessagesAroundParams.ts b/src/client/models/channelMessagesAroundParams.ts index 819cf3d..32230b0 100644 --- a/src/client/models/channelMessagesAroundParams.ts +++ b/src/client/models/channelMessagesAroundParams.ts @@ -13,4 +13,8 @@ seq: number; * @nullable */ limit?: number | null; +/** + * @nullable + */ +thread?: string | null; }; diff --git a/src/client/models/index.ts b/src/client/models/index.ts index 54d6f0c..8fbabdd 100644 --- a/src/client/models/index.ts +++ b/src/client/models/index.ts @@ -21,7 +21,6 @@ export * from './agentSessionResponse'; export * from './agentStepInfo'; export * from './agentToolCallInfo'; export * from './agentUsageInfo'; -export * from './aiAddRequest'; export * from './aiDiscussionResponse'; export * from './aiLikeResponse'; export * from './aiListDiscussionsParams'; @@ -31,7 +30,6 @@ export * from './aiModelListItem'; export * from './aiModelResponse'; export * from './aiModelVersionResponse'; export * from './aiProviderResponse'; -export * from './appNotificationItem'; export * from './approveWorkspaceJoinApply'; export * from './assignIssueUser'; export * from './assignPrUser'; diff --git a/src/client/models/roomCreateRequest.ts b/src/client/models/roomCreateRequest.ts index 7e5bc7a..916f84a 100644 --- a/src/client/models/roomCreateRequest.ts +++ b/src/client/models/roomCreateRequest.ts @@ -7,6 +7,8 @@ */ export interface RoomCreateRequest { + /** @nullable */ + ai_enabled?: boolean | null; /** @nullable */ category?: string | null; public: boolean; diff --git a/src/client/models/roomUpdateRequest.ts b/src/client/models/roomUpdateRequest.ts index ca3ff9a..abf50ce 100644 --- a/src/client/models/roomUpdateRequest.ts +++ b/src/client/models/roomUpdateRequest.ts @@ -7,6 +7,8 @@ */ export interface RoomUpdateRequest { + /** @nullable */ + ai_enabled?: boolean | null; /** @nullable */ category?: string | null; /** @nullable */ diff --git a/src/components/settings/SettingsModal.tsx b/src/components/settings/SettingsModal.tsx new file mode 100644 index 0000000..1ecf32f --- /dev/null +++ b/src/components/settings/SettingsModal.tsx @@ -0,0 +1,248 @@ +import { useState, useCallback, useEffect } from "react"; +import { useNavigate, useLocation } from "react-router"; +import { Dialog, DialogContent, DialogTitle } from "@/components/ui/dialog"; +import { + PersonStanding, + Lock, + Paintbrush, + Bell, + Shield, + Eye, + Key, + Terminal, + ExternalLink, + XIcon, +} from "lucide-react"; +import { cn } from "@/lib/utils"; +import SettingsProfilePage from "@/page/settings/profile"; +import SettingsSecurityPage from "@/page/settings/security"; +import SettingsAppearancePage from "@/page/settings/appearance"; +import SettingsNotificationsPage from "@/page/settings/notifications"; +import SettingsPrivacyPage from "@/page/settings/privacy"; +import SettingsAccessibilityPage from "@/page/settings/accessibility"; +import SettingsTokensPage from "@/page/settings/tokens"; +import SettingsSshKeysPage from "@/page/settings/ssh-keys"; + +type SectionKey = + | "profile" + | "security" + | "appearance" + | "notifications" + | "privacy" + | "accessibility" + | "tokens" + | "ssh-keys"; + +const NAV_SECTIONS = [ + { + label: "Account", + items: [ + { key: "profile" as SectionKey, icon: PersonStanding, label: "Profile" }, + { key: "security" as SectionKey, icon: Lock, label: "Security" }, + ], + }, + { + label: "Preferences", + items: [ + { key: "appearance" as SectionKey, icon: Paintbrush, label: "Appearance" }, + { key: "notifications" as SectionKey, icon: Bell, label: "Notifications" }, + { key: "privacy" as SectionKey, icon: Shield, label: "Privacy" }, + { key: "accessibility" as SectionKey, icon: Eye, label: "Accessibility" }, + ], + }, + { + label: "Developer", + items: [ + { key: "tokens" as SectionKey, icon: Key, label: "Access Tokens" }, + { key: "ssh-keys" as SectionKey, icon: Terminal, label: "SSH Keys" }, + ], + }, +]; + +const SECTIONS: Record = { + profile: SettingsProfilePage, + security: SettingsSecurityPage, + appearance: SettingsAppearancePage, + notifications: SettingsNotificationsPage, + privacy: SettingsPrivacyPage, + accessibility: SettingsAccessibilityPage, + tokens: SettingsTokensPage, + "ssh-keys": SettingsSshKeysPage, +}; + +const SETTINGS_RETURN_PATH_KEY = "settings_return_path"; + +const sectionPath: Record = { + profile: "/settings/profile", + security: "/settings/security", + appearance: "/settings/appearance", + notifications: "/settings/notifications", + privacy: "/settings/privacy", + accessibility: "/settings/accessibility", + tokens: "/settings/tokens", + "ssh-keys": "/settings/ssh-keys", +}; + +interface SettingsModalProps { + open: boolean; + onClose: () => void; + initialSection?: string; +} + +export function SettingsModal({ open, onClose, initialSection }: SettingsModalProps) { + const [activeSection, setActiveSectionState] = useState("profile"); + const [isExpanding, setIsExpanding] = useState(false); + const navigate = useNavigate(); + const location = useLocation(); + + // Sync initial section when modal opens + useEffect(() => { + if (open && initialSection && initialSection in SECTIONS) { + setActiveSectionState(initialSection as SectionKey); + } + }, [open, initialSection]); + + const setActiveSection = (section: SectionKey) => { + setActiveSectionState(section); + }; + + const handleOpenAsPage = useCallback(() => { + localStorage.setItem(SETTINGS_RETURN_PATH_KEY, location.pathname); + + const el = document.querySelector('[data-slot="dialog-content"]') as HTMLDivElement | null; + const overlayEl = document.querySelector('[data-slot="dialog-overlay"]') as HTMLDivElement | null; + if (!el) return; + + setIsExpanding(true); + + const duration = 350; + const start = performance.now(); + const startW = 80; + const endW = 100; + const startH = 85; + const endH = 100; + const startR = 12; + const endR = 0; + + function easeOutCubic(t: number): number { + return 1 - Math.pow(1 - t, 3); + } + + if (overlayEl) { + overlayEl.animate( + [{ opacity: 1 }, { opacity: 0 }], + { duration: 300, easing: "ease", fill: "forwards" }, + ); + } + + function frame(now: number) { + const elapsed = now - start; + const progress = Math.min(elapsed / duration, 1); + const eased = easeOutCubic(progress); + + const w = startW + (endW - startW) * eased; + const h = startH + (endH - startH) * eased; + const r = startR + (endR - startR) * eased; + + el.style.setProperty("width", `${w}vw`, "important"); + el.style.setProperty("height", `${h}vh`, "important"); + el.style.setProperty("max-width", `${w}vw`, "important"); + el.style.setProperty("max-height", `${h}vh`, "important"); + el.style.setProperty("border-radius", `${r}px`, "important"); + + if (progress < 1) { + requestAnimationFrame(frame); + } else { + setIsExpanding(false); + setTimeout(() => { + onClose(); + navigate(sectionPath[activeSection]); + }, 0); + } + } + + requestAnimationFrame(frame); + }, [activeSection, onClose, navigate, location.pathname]); + + const ActiveComponent = SECTIONS[activeSection]; + + return ( + { if (!isOpen && !isExpanding) onClose(); }}> + + Settings + +
+ {/* Sidebar */} + + + {/* Content */} +
+ {/* Top bar */} +
+ + +
+ + {/* Scrollable content */} +
+
+ +
+
+
+
+
+
+ ); +} diff --git a/src/components/settings/SettingsModalContext.tsx b/src/components/settings/SettingsModalContext.tsx new file mode 100644 index 0000000..dc61230 --- /dev/null +++ b/src/components/settings/SettingsModalContext.tsx @@ -0,0 +1,43 @@ +import { createContext, useContext, useState, useCallback } from "react"; +import { SettingsModal } from "./SettingsModal"; + +export interface SettingsModalContextType { + showSettingsModal: boolean; + openSettingsModal: (section?: string) => void; + closeSettingsModal: () => void; +} + +export const SettingsModalContext = createContext({ + showSettingsModal: false, + openSettingsModal: () => {}, + closeSettingsModal: () => {}, +}); + +export const useSettingsModal = () => useContext(SettingsModalContext); + +export function SettingsModalProvider({ children }: { children: React.ReactNode }) { + const [showSettingsModal, setShowSettingsModal] = useState(false); + const [initialSection, setInitialSection] = useState(); + + const openSettingsModal = useCallback((section?: string) => { + setInitialSection(section); + setShowSettingsModal(true); + }, []); + + const closeSettingsModal = useCallback(() => { + setShowSettingsModal(false); + }, []); + + return ( + + {children} + + + ); +} diff --git a/src/components/shell/rail.tsx b/src/components/shell/rail.tsx index ad09a47..91c01b3 100644 --- a/src/components/shell/rail.tsx +++ b/src/components/shell/rail.tsx @@ -7,6 +7,7 @@ import { CreateWorkspaceDialog } from "@/page/workspace/create"; import { Button } from "@/components/ui/button"; import { cn } from "@/lib/utils"; import { workspaceColor, workspaceInitial } from "./shared"; +import { useSettingsModal } from "@/components/settings/SettingsModalContext"; function WorkspaceIcon({ active, @@ -49,6 +50,20 @@ function WorkspaceIcon({ ); } +function SettingsRailButton() { + const { openSettingsModal } = useSettingsModal(); + return ( + + ); +} + function RailButton({ label, to, @@ -131,9 +146,7 @@ export function WorkspaceRail() {
- - - +
); diff --git a/src/components/shell/settings-sidebar.tsx b/src/components/shell/settings-sidebar.tsx index 3926474..b72d482 100644 --- a/src/components/shell/settings-sidebar.tsx +++ b/src/components/shell/settings-sidebar.tsx @@ -1,4 +1,4 @@ -import { Link, Outlet, useLocation } from "react-router"; +import { NavLink, Outlet, useLocation, useNavigate } from "react-router"; import { Eye, Key, @@ -8,145 +8,120 @@ import { Shield, Terminal, Bell, - ChevronDown, - Search, + XIcon, } from "lucide-react"; import { cn } from "@/lib/utils"; import NavShell from "./rail"; -import { workspaceColor, workspaceInitial } from "./shared"; -import { useAuth } from "@/context/auth-context"; -import { Button } from "@/components/ui/button"; -function SettingsAvatar() { - const { me } = useAuth(); - const name = me?.display_name || me?.username || "User"; +const NAV_SECTIONS = [ + { + label: "Account", + items: [ + { to: "/settings/profile", end: false, icon: PersonStanding, label: "Profile" }, + { to: "/settings/security", end: false, icon: Lock, label: "Security" }, + ], + }, + { + label: "Preferences", + items: [ + { to: "/settings/appearance", end: false, icon: Paintbrush, label: "Appearance" }, + { to: "/settings/notifications", end: false, icon: Bell, label: "Notifications" }, + { to: "/settings/privacy", end: false, icon: Shield, label: "Privacy" }, + { to: "/settings/accessibility", end: false, icon: Eye, label: "Accessibility" }, + ], + }, + { + label: "Developer", + items: [ + { to: "/settings/tokens", end: false, icon: Key, label: "Access Tokens" }, + { to: "/settings/ssh-keys", end: false, icon: Terminal, label: "SSH Keys" }, + ], + }, +]; - return ( - - {me?.avatar_url ? ( - - ) : ( - workspaceInitial(name) - )} - - ); -} - -function SettingsNavLink({ item, active }: { item: { label: string; to: string; icon: React.ReactNode }; active: boolean }) { - return ( - - {item.icon} - {item.label} - - ); -} - -function SettingsSidebar() { - const location = useLocation(); - - const accountItems = [ - { label: "Profile", to: "/settings/profile", icon: }, - { label: "Security", to: "/settings/security", icon: }, - ]; - - const preferenceItems = [ - { label: "Appearance", to: "/settings/appearance", icon: }, - { label: "Notifications", to: "/settings/notifications", icon: }, - { label: "Privacy", to: "/settings/privacy", icon: }, - { label: "Accessibility", to: "/settings/accessibility", icon: }, - ]; - - const developerItems = [ - { label: "Access Tokens", to: "/settings/tokens", icon: }, - { label: "SSH Keys", to: "/settings/ssh-keys", icon: }, - ]; - - return ( - - ); -} +const SETTINGS_RETURN_PATH_KEY = "settings_return_path"; export function SettingsShell() { + const location = useLocation(); + const navigate = useNavigate(); + + const currentSection = NAV_SECTIONS + .flatMap((s) => s.items) + .find((item) => location.pathname === item.to || location.pathname.startsWith(item.to + "/")); + + const handleClose = () => { + const returnPath = localStorage.getItem(SETTINGS_RETURN_PATH_KEY); + localStorage.removeItem(SETTINGS_RETURN_PATH_KEY); + navigate(returnPath || "/me"); + }; + return ( -
- -
-
-
- +
+ {/* Sidebar */} +
+ + {NAV_SECTIONS.map((section, si) => ( +
+
+ {section.label} +
+ {section.items.map((item) => ( + + cn( + "flex items-center gap-2.5 mx-2 px-3 py-1.5 rounded-lg text-[14px] transition-colors", + isActive + ? "bg-accent text-foreground font-medium" + : "text-muted-foreground hover:bg-accent/50 hover:text-foreground", + ) + } + > + + {item.label} + + ))} +
+ ))} + + + {/* Content */} +
+ {/* Top bar */} +
+
+

+ Settings +

+

+ {currentSection?.label || "Settings"} +

+
+ +
+ + {/* Scrollable content */} +
+
+ +
+
+
); -} \ No newline at end of file +} diff --git a/src/page/me/chat-conversation.tsx b/src/page/me/chat-conversation.tsx index 73a9591..8fc9bba 100644 --- a/src/page/me/chat-conversation.tsx +++ b/src/page/me/chat-conversation.tsx @@ -16,6 +16,7 @@ import { import { ModelSelectorPopover } from "@/page/workspace/workplan/chat/model-selector-popover"; import { MessageBubble } from "@/page/workspace/workplan/chat/message-bubble"; import { StreamingView } from "@/page/workspace/workplan/chat/streaming-view"; +import { CodePreviewProvider } from "@/page/workspace/workplan/chat/code-preview-context"; import type { Message, Conversation, @@ -420,6 +421,7 @@ export default function MeChatConversationPage() { return ( +
{/* Header */}
@@ -529,6 +531,7 @@ export default function MeChatConversationPage() { onModelChange={setModelProvider} />
+
); } diff --git a/src/page/workspace/channel/channel-sidebar.tsx b/src/page/workspace/channel/channel-sidebar.tsx index d38c162..2437f04 100644 --- a/src/page/workspace/channel/channel-sidebar.tsx +++ b/src/page/workspace/channel/channel-sidebar.tsx @@ -17,6 +17,7 @@ export type Room = { topic?: string | null; room_type: string; is_private: boolean; + ai_enabled?: boolean; category?: string | null; workspace_id: string; }; diff --git a/src/page/workspace/channel/index.tsx b/src/page/workspace/channel/index.tsx index 30a3cda..ed24a50 100644 --- a/src/page/workspace/channel/index.tsx +++ b/src/page/workspace/channel/index.tsx @@ -1,5 +1,6 @@ import { useCallback, useState } from "react"; import { useParams } from "react-router"; +import { useQueryClient } from "@tanstack/react-query"; import { api } from "@/client"; import { MessageSquare } from "lucide-react"; import { useChannelState } from "./use-channel-state"; @@ -11,6 +12,7 @@ import MessageView from "./message-view"; export default function ChannelPage() { const { roomId } = useParams(); const { state, actions } = useChannelState(roomId); + const queryClient = useQueryClient(); const [showThreads, setShowThreads] = useState(false); const [activeThreadId, setActiveThreadId] = useState(null); const [activeThreadSeq, setActiveThreadSeq] = useState(0); @@ -109,9 +111,10 @@ export default function ChannelPage() { categories={state.categories} categoryId={state.currentRoom.category ?? null} isPrivate={state.currentRoom.is_private} - onDeleted={() => {}} + aiEnabled={state.currentRoom.ai_enabled} + onDeleted={() => queryClient.invalidateQueries({queryKey: ["channel", "rooms"]})} onOpenChange={setShowRoomSettings} - onUpdated={() => {}} + onUpdated={() => queryClient.invalidateQueries({queryKey: ["channel", "rooms"]})} open={showRoomSettings} roomId={roomId} roomName={state.currentRoom.name} diff --git a/src/page/workspace/channel/room-create-dialog.tsx b/src/page/workspace/channel/room-create-dialog.tsx index f27631d..636dc7a 100644 --- a/src/page/workspace/channel/room-create-dialog.tsx +++ b/src/page/workspace/channel/room-create-dialog.tsx @@ -80,6 +80,7 @@ export default function RoomCreateDialog({ setIsPublic(true); setCategoryId(""); setNewCategoryName(""); + setAiEnabled(false); setOpen(false); onCreated?.(); } catch { diff --git a/src/page/workspace/channel/room-settings-dialog.tsx b/src/page/workspace/channel/room-settings-dialog.tsx index 457315d..df39a6a 100644 --- a/src/page/workspace/channel/room-settings-dialog.tsx +++ b/src/page/workspace/channel/room-settings-dialog.tsx @@ -23,6 +23,7 @@ type Props = { roomName: string; topic?: string | null; isPrivate: boolean; + aiEnabled?: boolean; categoryId?: string | null; categories: Category[]; open: boolean; @@ -36,6 +37,7 @@ export default function RoomSettingsDialog({ roomName, topic, isPrivate, + aiEnabled, categoryId, categories, open, @@ -46,6 +48,7 @@ export default function RoomSettingsDialog({ const [name, setName] = useState(roomName); const [topicText, setTopicText] = useState(topic ?? ""); const [isPublic, setIsPublic] = useState(!isPrivate); + const [aiEnabledState, setAiEnabledState] = useState(aiEnabled ?? false); const [category, setCategory] = useState(categoryId ?? ""); const [saving, setSaving] = useState(false); const [deleting, setDeleting] = useState(false); @@ -56,9 +59,10 @@ export default function RoomSettingsDialog({ setName(roomName); setTopicText(topic ?? ""); setIsPublic(!isPrivate); + setAiEnabledState(aiEnabled ?? false); setCategory(categoryId ?? ""); setConfirmDelete(false); - }, [roomName, topic, isPrivate, categoryId, open]); + }, [roomName, topic, isPrivate, aiEnabled, categoryId, open]); const handleSave = useCallback(async () => { const trimmed = name.trim(); @@ -73,6 +77,10 @@ export default function RoomSettingsDialog({ (category || null) !== (categoryId ?? null) ? category || null : undefined, + ai_enabled: + aiEnabledState !== (aiEnabled ?? false) + ? aiEnabledState + : undefined, }); toast({ title: "Room updated" }); onUpdated?.(); @@ -90,6 +98,8 @@ export default function RoomSettingsDialog({ name, topicText, isPublic, + aiEnabledState, + aiEnabled, category, roomId, roomName, @@ -197,6 +207,22 @@ export default function RoomSettingsDialog({ /> +
+
+ +

+ Enable AI responses in this channel +

+
+ +
+
diff --git a/src/page/workspace/channel/use-channel-state.ts b/src/page/workspace/channel/use-channel-state.ts index 8b93fba..e7ca928 100644 --- a/src/page/workspace/channel/use-channel-state.ts +++ b/src/page/workspace/channel/use-channel-state.ts @@ -510,6 +510,7 @@ export function useChannelState(roomId: string | undefined) { case "room_settings_updated": toast({title: "Room settings updated"}); + queryClient.invalidateQueries({queryKey: ["channel", "rooms"]}); break; case "user_banned": { diff --git a/src/socket/index.ts b/src/socket/index.ts index fd3eabe..eb0cdcf 100644 --- a/src/socket/index.ts +++ b/src/socket/index.ts @@ -19,10 +19,6 @@ export type { WsTokenResponse, } from "./types"; export type { - AgentInfo, - AiAgentJoinedService, - AiAgentLeftService, - AiAgentStatusChangedService, AttachmentUploadedService, BannedService, CategoryCreatedService, @@ -35,9 +31,6 @@ export type { ConversationUnpinnedService, ConversationUnreadUpdatedService, CustomStatusUpdatedService, - DmClosedService, - DmCreatedService, - DmReopenedService, DraftClearedService, DraftSavedService, InviteAcceptedService, @@ -68,8 +61,6 @@ export type { ReactionGroup, ReactionRemovedService, ReadReceiptService, - RoomAiEntry, - RoomAiListService, RoomCreatedService, RoomDeletedService, RoomInfo, diff --git a/src/socket/manager.ts b/src/socket/manager.ts index 5aedcf9..1aae7be 100644 --- a/src/socket/manager.ts +++ b/src/socket/manager.ts @@ -585,10 +585,6 @@ const EVENT_TYPE_MAP: Record = { "custom_status.updated": "custom_status_updated", "draft.saved": "draft_saved", "draft.cleared": "draft_cleared", - "ai.agent_joined": "ai_agent_joined", - "ai.agent_left": "ai_agent_left", - "ai.agent_status_changed": "ai_agent_status_changed", - "ai.stop": "ai_stop", "invite.created": "invite_created", "invite.accepted": "invite_accepted", "notify.created": "notify_created", diff --git a/src/socket/schema.ts b/src/socket/schema.ts index a5f0b98..b065e9e 100644 --- a/src/socket/schema.ts +++ b/src/socket/schema.ts @@ -17,13 +17,6 @@ export type WorkspaceInfo = { avatar_url: string; }; -export type AgentInfo = { - id: string; - name: string; - agent_type: string; - model_name: string | null; -}; - // --- Enums --- export type UserPresenceStatus = "online" | "idle" | "dnd" | "offline"; @@ -415,41 +408,6 @@ export type VoiceChannelLeftService = { left_at: string; }; -// --- AI Service structs --- -export type AiAgentJoinedService = { - room: RoomInfo; - agent: AgentInfo; - joined_at: string; -}; - -export type AiAgentLeftService = { - room: RoomInfo; - agent: AgentInfo; - left_at: string; -}; - -export type RoomAiEntry = { - agent_session: string; - name: string; - agent_kind: string; - model_version: string | null; - enabled: boolean; - auto_reply: boolean; -}; - -export type RoomAiListService = { - room: RoomInfo; - agents: RoomAiEntry[]; -}; - -export type AiAgentStatusChangedService = { - room: RoomInfo; - agent: AgentInfo; - old_status: string; - new_status: string; - changed_at: string; -}; - // --- Search Service structs --- export type SearchMessageHitService = MessageNewService & { highlighted_content: string; @@ -522,26 +480,6 @@ export type ConversationSummary = { last_read_at: string | null; }; -// --- Direct messaging (OpenIM SingleChat + Rocket.Chat DM) --- -export type DmCreatedService = { - room: RoomInfo; - initiator: UserInfo; - recipient: UserInfo; - created_at: string; -}; - -export type DmClosedService = { - room: RoomInfo; - closed_by: UserInfo; - closed_at: string; -}; - -export type DmReopenedService = { - room: RoomInfo; - reopened_by: UserInfo; - reopened_at: string; -}; - // --- Per-message read tracking (OpenIM MarkMsgsAsRead) --- export type MessageReadService = { room: RoomInfo; @@ -665,11 +603,6 @@ export type WsOutEvent = | { type: "attachment_uploaded"; data: AttachmentUploadedService } | { type: "user_banned"; data: BannedService } | { type: "user_unbanned"; data: UnbannedService } - | { type: "ai_agent_joined"; data: AiAgentJoinedService } - | { type: "ai_agent_left"; data: AiAgentLeftService } - | { type: "ai_agent_list"; room: RoomInfo; data: RoomAiListService } - | { type: "ai_agent_status_changed"; data: AiAgentStatusChangedService } - | { type: "voice_channel_joined"; data: VoiceChannelJoinedService } | { type: "voice_channel_left"; data: VoiceChannelLeftService } | { type: "conversation_pinned"; room: RoomInfo; data: ConversationPinnedService } | { type: "conversation_unpinned"; room: RoomInfo; data: ConversationUnpinnedService } @@ -677,10 +610,6 @@ export type WsOutEvent = | { type: "conversation_unmuted"; room: RoomInfo; data: ConversationUnmutedService } | { type: "conversation_unread_updated"; room: RoomInfo; data: ConversationUnreadUpdatedService } | { type: "conversation_list"; data: ConversationSummary[] } - | { type: "dm_created"; room: RoomInfo; data: DmCreatedService } - | { type: "dm_closed"; room: RoomInfo; data: DmClosedService } - | { type: "dm_reopened"; room: RoomInfo; data: DmReopenedService } - | { type: "dm_list"; data: DmCreatedService[] } | { type: "message_read"; room: RoomInfo; data: MessageReadService } | { type: "message_read_batch"; room: RoomInfo; data: MessageReadBatchService } | { type: "message_readers"; data: MessageReadersService }