actor(visibilityref): update function visibility and formatting across modules

This commit is contained in:
zhenyi 2026-05-30 22:54:09 +08:00
parent 9ffc7c9fb3
commit f947c931cd
242 changed files with 4861 additions and 3330 deletions

1
Cargo.lock generated
View File

@ -2021,6 +2021,7 @@ dependencies = [
"db", "db",
"futures", "futures",
"hmac 0.13.0", "hmac 0.13.0",
"lazy_static",
"model", "model",
"redis", "redis",
"serde", "serde",

View File

@ -7,13 +7,16 @@ use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{info, warn}; use tracing::{info, warn};
use super::RigStreamChunk;
use super::config::AgentConfig; use super::config::AgentConfig;
use super::helpers::{build_input_string, check_token_budget, estimate_tokens}; use super::helpers::{build_input_string, check_token_budget, estimate_tokens};
use super::hooks::{HookChain, HookLlmResponse, HookMessage, HookToolDef, ToolCallOutcome, ToolGuardrailDecision}; use super::hooks::{
HookChain, HookLlmResponse, HookMessage, HookToolDef, ToolCallOutcome,
ToolGuardrailDecision,
};
use super::persistence::ActiveAgentRun; use super::persistence::ActiveAgentRun;
use super::request::{AgentRequest, AgentResult, AgentStep, ToolCallRecord}; use super::request::{AgentRequest, AgentResult, AgentStep, ToolCallRecord};
use super::subagent::run_experts; use super::subagent::run_experts;
use super::RigStreamChunk;
use crate::client::AiClient; use crate::client::AiClient;
use crate::error::{AiError, AiResult}; use crate::error::{AiError, AiResult};
@ -48,9 +51,7 @@ impl RigAgent {
tools: Vec<Box<dyn ToolDyn>>, tools: Vec<Box<dyn ToolDyn>>,
) -> AiResult<String> { ) -> AiResult<String> {
let (mut rx, handle) = self.run(request, tools); let (mut rx, handle) = self.run(request, tools);
tokio::spawn(async move { tokio::spawn(async move { while rx.recv().await.is_some() {} });
while rx.recv().await.is_some() {}
});
let result = handle.await.map_err(|_| { let result = handle.await.map_err(|_| {
AiError::Response("agent task panicked".to_string()) AiError::Response("agent task panicked".to_string())
})?; })?;
@ -152,15 +153,24 @@ async fn execute_agent_run(
// ---- SubAgent execution ---- // ---- SubAgent execution ----
let expert_outputs = if !request.experts.is_empty() { let expert_outputs = if !request.experts.is_empty() {
let run = ActiveAgentRun { let run = ActiveAgentRun {
conversation_id: request.run_context.as_ref().and_then(|c| c.conversation_id), conversation_id: request
.run_context
.as_ref()
.and_then(|c| c.conversation_id),
message_id: None, message_id: None,
invocation_id: request.run_context.as_ref().and_then(|c| c.invocation_id), invocation_id: request
.run_context
.as_ref()
.and_then(|c| c.invocation_id),
session_id: request.run_context.as_ref().and_then(|c| c.session_id), session_id: request.run_context.as_ref().and_then(|c| c.session_id),
user_id: request.run_context.as_ref().and_then(|c| c.user_id), user_id: request.run_context.as_ref().and_then(|c| c.user_id),
started_at: std::time::Instant::now(), started_at: std::time::Instant::now(),
current_step: 0, current_step: 0,
}; };
let realtime = request.run_context.as_ref().and_then(|c| c.realtime.as_ref()); let realtime = request
.run_context
.as_ref()
.and_then(|c| c.realtime.as_ref());
// Notify frontend that subagents are starting. // Notify frontend that subagents are starting.
for expert in &request.experts { for expert in &request.experts {
@ -173,7 +183,15 @@ async fn execute_agent_run(
.await; .await;
} }
match run_experts(&ai_client, &agent_config, &request.experts, realtime, &run).await { match run_experts(
&ai_client,
&agent_config,
&request.experts,
realtime,
&run,
)
.await
{
Ok(outputs) => { Ok(outputs) => {
for out in &outputs { for out in &outputs {
let _ = tx let _ = tx
@ -252,7 +270,10 @@ async fn execute_agent_run(
Err(_elapsed) => { Err(_elapsed) => {
let _ = tx let _ = tx
.send(RigStreamChunk::Failed { .send(RigStreamChunk::Failed {
error: format!("agent timed out after {}s", dur.as_secs()), error: format!(
"agent timed out after {}s",
dur.as_secs()
),
}) })
.await; .await;
return Err(AiError::Timeout { return Err(AiError::Timeout {
@ -284,7 +305,11 @@ async fn execute_agent_run(
} }
if let Some(limit) = max_total_tokens if let Some(limit) = max_total_tokens
&& check_token_budget(estimated_input_tokens, accumulated_output_chars, limit) && check_token_budget(
estimated_input_tokens,
accumulated_output_chars,
limit,
)
{ {
let _ = tx let _ = tx
.send(RigStreamChunk::Failed { .send(RigStreamChunk::Failed {
@ -317,7 +342,8 @@ async fn execute_agent_run(
)) => { )) => {
for part in &reasoning.content { for part in &reasoning.content {
if let rig::completion::message::ReasoningContent::Text { if let rig::completion::message::ReasoningContent::Text {
text, .. text,
..
} = part } = part
{ {
accumulated_output_chars += text.chars().count(); accumulated_output_chars += text.chars().count();
@ -334,7 +360,8 @@ async fn execute_agent_run(
} }
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
rig::streaming::StreamedAssistantContent::ReasoningDelta { rig::streaming::StreamedAssistantContent::ReasoningDelta {
reasoning, .. reasoning,
..
}, },
)) => { )) => {
accumulated_output_chars += reasoning.chars().count(); accumulated_output_chars += reasoning.chars().count();
@ -363,7 +390,9 @@ async fn execute_agent_run(
let tool_args: serde_json::Value = let tool_args: serde_json::Value =
serde_json::from_str(&args).unwrap_or_default(); serde_json::from_str(&args).unwrap_or_default();
if let Ok(Some(decision)) = hooks.run_pre_tool_call(&tool_name, &tool_args).await { if let Ok(Some(decision)) =
hooks.run_pre_tool_call(&tool_name, &tool_args).await
{
match decision { match decision {
ToolGuardrailDecision::Allow => {} ToolGuardrailDecision::Allow => {}
ToolGuardrailDecision::Block { reason } => { ToolGuardrailDecision::Block { reason } => {
@ -390,7 +419,9 @@ async fn execute_agent_run(
.send(RigStreamChunk::ToolCallFinished { .send(RigStreamChunk::ToolCallFinished {
tool_call_id: tool_call.id.clone(), tool_call_id: tool_call.id.clone(),
tool_name: tool_name.clone(), tool_name: tool_name.clone(),
output: format!("awaiting approval: {message}"), output: format!(
"awaiting approval: {message}"
),
error: None, error: None,
}) })
.await; .await;
@ -399,7 +430,9 @@ async fn execute_agent_run(
name: tool_name.clone(), name: tool_name.clone(),
arguments: tool_args.clone(), arguments: tool_args.clone(),
output: None, output: None,
error: Some(format!("requires approval: {message}")), error: Some(format!(
"requires approval: {message}"
)),
elapsed_ms: None, elapsed_ms: None,
}); });
continue; continue;
@ -424,16 +457,22 @@ async fn execute_agent_run(
}); });
} }
Ok(rig::agent::MultiTurnStreamItem::StreamUserItem( Ok(rig::agent::MultiTurnStreamItem::StreamUserItem(
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. }, rig::streaming::StreamedUserContent::ToolResult {
tool_result,
..
},
)) => { )) => {
let content = let content = super::helpers::tool_result_content_to_string(
super::helpers::tool_result_content_to_string(&tool_result.content); &tool_result.content,
);
accumulated_output_chars += content.chars().count(); accumulated_output_chars += content.chars().count();
if let Some(last) = current_step_tool_calls.last_mut() if let Some(last) = current_step_tool_calls.last_mut()
&& last.id == tool_result.id && last.id == tool_result.id
{ {
last.output = Some(serde_json::from_str(&content).unwrap_or_default()); last.output = Some(
serde_json::from_str(&content).unwrap_or_default(),
);
} }
let tool_name = current_step_tool_calls let tool_name = current_step_tool_calls
@ -464,15 +503,21 @@ async fn execute_agent_run(
Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => { Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => {
let usage = resp.usage(); let usage = resp.usage();
if !current_step_tool_calls.is_empty() || !current_step_assistant.is_empty() { if !current_step_tool_calls.is_empty()
|| !current_step_assistant.is_empty()
{
let reasoning = (!current_step_reasoning.is_empty()) let reasoning = (!current_step_reasoning.is_empty())
.then_some(std::mem::take(&mut current_step_reasoning)); .then_some(std::mem::take(&mut current_step_reasoning));
steps.push(AgentStep { steps.push(AgentStep {
index: steps.len(), index: steps.len(),
assistant: (!current_step_assistant.is_empty()) assistant: (!current_step_assistant.is_empty())
.then_some(std::mem::take(&mut current_step_assistant)), .then_some(std::mem::take(
&mut current_step_assistant,
)),
reasoning_content: reasoning, reasoning_content: reasoning,
tool_calls: std::mem::take(&mut current_step_tool_calls), tool_calls: std::mem::take(
&mut current_step_tool_calls,
),
reflection: None, reflection: None,
}); });
} }
@ -533,7 +578,9 @@ async fn execute_agent_run(
} }
} }
Err(AiError::Response("agent stream ended without final response".to_string())) Err(AiError::Response(
"agent stream ended without final response".to_string(),
))
} }
impl Clone for HookChain { impl Clone for HookChain {

View File

@ -65,7 +65,10 @@ impl CompressionStrategy {
self self
} }
pub fn with_custom_instructions(mut self, instructions: impl Into<String>) -> Self { pub fn with_custom_instructions(
mut self,
instructions: impl Into<String>,
) -> Self {
self.custom_instructions = Some(instructions.into()); self.custom_instructions = Some(instructions.into());
self self
} }
@ -91,7 +94,11 @@ pub struct CompactionResult {
} }
impl CompactionResult { impl CompactionResult {
pub fn new(summary: String, messages_compacted: usize, tokens_saved: i64) -> Self { pub fn new(
summary: String,
messages_compacted: usize,
tokens_saved: i64,
) -> Self {
Self { Self {
summary, summary,
messages_compacted, messages_compacted,
@ -115,7 +122,12 @@ pub fn build_compression_prompt(
existing_summary: Option<&str>, existing_summary: Option<&str>,
messages_text: &str, messages_text: &str,
) -> String { ) -> String {
build_compression_prompt_with_options(existing_summary, messages_text, None, 1500) build_compression_prompt_with_options(
existing_summary,
messages_text,
None,
1500,
)
} }
/// Build the compaction prompt with custom instructions and word limit. /// Build the compaction prompt with custom instructions and word limit.

View File

@ -132,7 +132,10 @@ impl AgentConfig {
self self
} }
pub fn with_max_completion_tokens(mut self, max_completion_tokens: Option<u64>) -> Self { pub fn with_max_completion_tokens(
mut self,
max_completion_tokens: Option<u64>,
) -> Self {
self.max_completion_tokens = max_completion_tokens; self.max_completion_tokens = max_completion_tokens;
self self
} }
@ -142,19 +145,31 @@ impl AgentConfig {
self self
} }
pub fn with_toolset_policy(mut self, enabled: Vec<String>, disabled: Vec<String>) -> Self { pub fn with_toolset_policy(
mut self,
enabled: Vec<String>,
disabled: Vec<String>,
) -> Self {
self.enabled_toolsets = enabled; self.enabled_toolsets = enabled;
self.disabled_toolsets = disabled; self.disabled_toolsets = disabled;
self self
} }
pub fn with_tool_policy(mut self, allowed_tools: Vec<String>, denied_tools: Vec<String>) -> Self { pub fn with_tool_policy(
mut self,
allowed_tools: Vec<String>,
denied_tools: Vec<String>,
) -> Self {
self.allowed_tools = allowed_tools; self.allowed_tools = allowed_tools;
self.denied_tools = denied_tools; self.denied_tools = denied_tools;
self self
} }
pub fn with_retry(mut self, max_attempts: usize, base_delay_ms: u64) -> Self { pub fn with_retry(
mut self,
max_attempts: usize,
base_delay_ms: u64,
) -> Self {
self.retry_max_attempts = max_attempts; self.retry_max_attempts = max_attempts;
self.retry_base_delay_ms = base_delay_ms; self.retry_base_delay_ms = base_delay_ms;
self self
@ -165,7 +180,10 @@ impl AgentConfig {
self self
} }
pub fn with_fallback_model(mut self, fallback_model: impl Into<String>) -> Self { pub fn with_fallback_model(
mut self,
fallback_model: impl Into<String>,
) -> Self {
self.fallback_model = Some(fallback_model.into()); self.fallback_model = Some(fallback_model.into());
self self
} }

View File

@ -44,7 +44,8 @@ impl RetryPolicy {
let half = (ms as f64 * 0.25) as u64; let half = (ms as f64 * 0.25) as u64;
let lo = ms.saturating_sub(half); let lo = ms.saturating_sub(half);
let hi = ms.saturating_add(half); let hi = ms.saturating_add(half);
let mix = ((attempt as u64).wrapping_mul(1_103_515_245)) % (hi - lo + 1); let mix =
((attempt as u64).wrapping_mul(1_103_515_245)) % (hi - lo + 1);
lo + mix lo + mix
} else { } else {
ms ms
@ -58,17 +59,26 @@ impl RetryPolicy {
/// ///
/// Inspects both the HTTP status code (when available) and the error message /// Inspects both the HTTP status code (when available) and the error message
/// content to determine the most appropriate category. /// content to determine the most appropriate category.
pub fn classify_error(error: &AiError, http_status: Option<u16>) -> ErrorCategory { pub fn classify_error(
error: &AiError,
http_status: Option<u16>,
) -> ErrorCategory {
// HTTP status-based classification takes precedence // HTTP status-based classification takes precedence
let from_status = match http_status { let from_status = match http_status {
Some(429) => Some(ErrorCategory::Retryable { Some(429) => Some(ErrorCategory::Retryable {
reason: "rate limited (HTTP 429)".to_string(), reason: "rate limited (HTTP 429)".to_string(),
}), }),
Some(401) | Some(403) => Some(ErrorCategory::FallbackModel { Some(401) | Some(403) => Some(ErrorCategory::FallbackModel {
reason: format!("authentication failed (HTTP {})", http_status.unwrap()), reason: format!(
"authentication failed (HTTP {})",
http_status.unwrap()
),
}), }),
Some(502) | Some(503) => Some(ErrorCategory::Overloaded { Some(502) | Some(503) => Some(ErrorCategory::Overloaded {
reason: format!("provider unavailable (HTTP {})", http_status.unwrap()), reason: format!(
"provider unavailable (HTTP {})",
http_status.unwrap()
),
}), }),
Some(504) => Some(ErrorCategory::Timeout), Some(504) => Some(ErrorCategory::Timeout),
Some(413) => Some(ErrorCategory::ContextWindowExceeded { Some(413) => Some(ErrorCategory::ContextWindowExceeded {
@ -90,7 +100,9 @@ pub fn classify_error(error: &AiError, http_status: Option<u16>) -> ErrorCategor
// Message-based classification // Message-based classification
match error { match error {
AiError::Timeout { .. } => ErrorCategory::Timeout, AiError::Timeout { .. } => ErrorCategory::Timeout,
AiError::TokenBudgetExceeded { .. } => ErrorCategory::TokenBudgetExceeded, AiError::TokenBudgetExceeded { .. } => {
ErrorCategory::TokenBudgetExceeded
}
AiError::Api(msg) => classify_api_message(msg), AiError::Api(msg) => classify_api_message(msg),
AiError::Response(msg) => classify_response_message(msg), AiError::Response(msg) => classify_response_message(msg),
AiError::ModelRetriesExhausted { .. } => ErrorCategory::Fatal { AiError::ModelRetriesExhausted { .. } => ErrorCategory::Fatal {
@ -107,7 +119,10 @@ fn classify_api_message(msg: &str) -> ErrorCategory {
let lower = msg.to_lowercase(); let lower = msg.to_lowercase();
// Rate limiting // Rate limiting
if lower.contains("rate") || lower.contains("too many requests") || lower.contains("throttl") { if lower.contains("rate")
|| lower.contains("too many requests")
|| lower.contains("throttl")
{
return ErrorCategory::Retryable { return ErrorCategory::Retryable {
reason: msg.to_string(), reason: msg.to_string(),
}; };
@ -213,15 +228,15 @@ pub fn retry_policy_for(
exponential: false, exponential: false,
switch_to_fallback: false, switch_to_fallback: false,
}, },
ErrorCategory::TokenBudgetExceeded | ErrorCategory::Cancelled | ErrorCategory::Fatal { .. } => { ErrorCategory::TokenBudgetExceeded
RetryPolicy { | ErrorCategory::Cancelled
max_attempts: 0, | ErrorCategory::Fatal { .. } => RetryPolicy {
base_delay: Duration::from_millis(0), max_attempts: 0,
jitter: false, base_delay: Duration::from_millis(0),
exponential: false, jitter: false,
switch_to_fallback: false, exponential: false,
} switch_to_fallback: false,
} },
} }
} }

View File

@ -140,7 +140,9 @@ impl EventSink {
} }
/// Subscribe to events, returns a receiver. /// Subscribe to events, returns a receiver.
pub fn subscribe(&mut self) -> tokio::sync::mpsc::UnboundedReceiver<AgentEvent> { pub fn subscribe(
&mut self,
) -> tokio::sync::mpsc::UnboundedReceiver<AgentEvent> {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
self.senders.push(tx); self.senders.push(tx);
rx rx

View File

@ -69,7 +69,9 @@ where
match f().await { match f().await {
Ok(result) => return Ok(result), Ok(result) => return Ok(result),
Err(e) if is_retryable(&e) && attempt + 1 < max_attempts => { Err(e) if is_retryable(&e) && attempt + 1 < max_attempts => {
let delay = Duration::from_millis(base_delay_ms * 2u64.pow(attempt as u32)); let delay = Duration::from_millis(
base_delay_ms * 2u64.pow(attempt as u32),
);
tracing::warn!( tracing::warn!(
error = %e, error = %e,
attempt = attempt + 1, attempt = attempt + 1,
@ -94,12 +96,16 @@ where
fn is_retryable(error: &AiError) -> bool { fn is_retryable(error: &AiError) -> bool {
matches!( matches!(
error, error,
AiError::Api(_) | AiError::Response(_) | AiError::ModelRetriesExhausted { .. } AiError::Api(_)
| AiError::Response(_)
| AiError::ModelRetriesExhausted { .. }
) )
} }
pub fn tool_result_content_to_string( pub fn tool_result_content_to_string(
content: &rig::one_or_many::OneOrMany<rig::completion::message::ToolResultContent>, content: &rig::one_or_many::OneOrMany<
rig::completion::message::ToolResultContent,
>,
) -> String { ) -> String {
use rig::completion::message::ToolResultContent; use rig::completion::message::ToolResultContent;
content content

View File

@ -51,11 +51,19 @@ pub trait AgentHook: Send + Sync {
Ok(()) Ok(())
} }
async fn on_session_end(&self, _ctx: &AgentRunContext, _success: bool) -> AiResult<()> { async fn on_session_end(
&self,
_ctx: &AgentRunContext,
_success: bool,
) -> AiResult<()> {
Ok(()) Ok(())
} }
async fn pre_llm_call(&self, _messages: &[HookMessage], _tools: &[HookToolDef]) -> AiResult<()> { async fn pre_llm_call(
&self,
_messages: &[HookMessage],
_tools: &[HookToolDef],
) -> AiResult<()> {
Ok(()) Ok(())
} }
@ -93,28 +101,42 @@ impl HookChain {
self.hooks.is_empty() self.hooks.is_empty()
} }
pub async fn run_session_start(&self, ctx: &AgentRunContext) -> AiResult<()> { pub async fn run_session_start(
&self,
ctx: &AgentRunContext,
) -> AiResult<()> {
for hook in &self.hooks { for hook in &self.hooks {
hook.on_session_start(ctx).await?; hook.on_session_start(ctx).await?;
} }
Ok(()) Ok(())
} }
pub async fn run_session_end(&self, ctx: &AgentRunContext, success: bool) -> AiResult<()> { pub async fn run_session_end(
&self,
ctx: &AgentRunContext,
success: bool,
) -> AiResult<()> {
for hook in &self.hooks { for hook in &self.hooks {
hook.on_session_end(ctx, success).await?; hook.on_session_end(ctx, success).await?;
} }
Ok(()) Ok(())
} }
pub async fn run_pre_llm_call(&self, messages: &[HookMessage], tools: &[HookToolDef]) -> AiResult<()> { pub async fn run_pre_llm_call(
&self,
messages: &[HookMessage],
tools: &[HookToolDef],
) -> AiResult<()> {
for hook in &self.hooks { for hook in &self.hooks {
hook.pre_llm_call(messages, tools).await?; hook.pre_llm_call(messages, tools).await?;
} }
Ok(()) Ok(())
} }
pub async fn run_post_llm_call(&self, response: &HookLlmResponse) -> AiResult<()> { pub async fn run_post_llm_call(
&self,
response: &HookLlmResponse,
) -> AiResult<()> {
for hook in &self.hooks { for hook in &self.hooks {
hook.post_llm_call(response).await?; hook.post_llm_call(response).await?;
} }
@ -127,7 +149,9 @@ impl HookChain {
arguments: &Value, arguments: &Value,
) -> AiResult<Option<ToolGuardrailDecision>> { ) -> AiResult<Option<ToolGuardrailDecision>> {
for hook in &self.hooks { for hook in &self.hooks {
if let Some(decision) = hook.pre_tool_call(tool_name, arguments).await? { if let Some(decision) =
hook.pre_tool_call(tool_name, arguments).await?
{
if !matches!(decision, ToolGuardrailDecision::Allow) { if !matches!(decision, ToolGuardrailDecision::Allow) {
return Ok(Some(decision)); return Ok(Some(decision));
} }
@ -136,7 +160,10 @@ impl HookChain {
Ok(None) Ok(None)
} }
pub async fn run_post_tool_call(&self, outcome: &ToolCallOutcome) -> AiResult<()> { pub async fn run_post_tool_call(
&self,
outcome: &ToolCallOutcome,
) -> AiResult<()> {
for hook in &self.hooks { for hook in &self.hooks {
hook.post_tool_call(outcome).await?; hook.post_tool_call(outcome).await?;
} }

View File

@ -11,16 +11,19 @@ use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{info, warn}; use tracing::{info, warn};
use super::RigStreamChunk;
use super::config::AgentConfig; use super::config::AgentConfig;
use super::error_classifier::{ use super::error_classifier::{
classify_error, retry_policy_for, should_switch_to_fallback, classify_error, retry_policy_for, should_switch_to_fallback,
}; };
use super::events::{AgentEvent, EventSink}; use super::events::{AgentEvent, EventSink};
use super::helpers::{build_input_string, estimate_tokens}; use super::helpers::{build_input_string, estimate_tokens};
use super::hooks::{HookChain, HookLlmResponse, HookMessage, ToolCallOutcome, ToolGuardrailDecision}; use super::hooks::{
HookChain, HookLlmResponse, HookMessage, ToolCallOutcome,
ToolGuardrailDecision,
};
use super::iteration_budget::IterationBudget; use super::iteration_budget::IterationBudget;
use super::request::{AgentRequest, AgentResult, AgentStep, ToolCallRecord}; use super::request::{AgentRequest, AgentResult, AgentStep, ToolCallRecord};
use super::RigStreamChunk;
use crate::client::AiClient; use crate::client::AiClient;
use crate::error::{AiError, AiResult}; use crate::error::{AiError, AiResult};
@ -50,13 +53,13 @@ pub type FollowUpFn = Arc<
>; >;
/// Callback to decide whether the agent should stop after a turn. /// Callback to decide whether the agent should stop after a turn.
pub type ShouldStopFn = Arc< pub type ShouldStopFn = Arc<dyn Fn(&TurnContext) -> bool + Send + Sync>;
dyn Fn(&TurnContext) -> bool + Send + Sync,
>;
/// Callback to prepare/modify state before the next turn. /// Callback to prepare/modify state before the next turn.
pub type PrepareNextTurnFn = Arc< pub type PrepareNextTurnFn = Arc<
dyn Fn(&TurnContext) -> Pin<Box<dyn Future<Output = Option<TurnUpdate>> + Send>> dyn Fn(
&TurnContext,
) -> Pin<Box<dyn Future<Output = Option<TurnUpdate>> + Send>>
+ Send + Send
+ Sync, + Sync,
>; >;
@ -144,7 +147,10 @@ pub struct EnhancedAgent {
} }
impl EnhancedAgent { impl EnhancedAgent {
pub fn new(client: AiClient, loop_config: AgentLoopConfig) -> AiResult<Self> { pub fn new(
client: AiClient,
loop_config: AgentLoopConfig,
) -> AiResult<Self> {
loop_config.config.validate()?; loop_config.config.validate()?;
Ok(Self { Ok(Self {
client, client,
@ -270,7 +276,11 @@ async fn run_enhanced_loop(
loop { loop {
// Check cancellation // Check cancellation
if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) { if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) {
let _ = tx.send(RigStreamChunk::Failed { error: "cancelled".to_string() }).await; let _ = tx
.send(RigStreamChunk::Failed {
error: "cancelled".to_string(),
})
.await;
if let Some(sink) = &event_sink { if let Some(sink) = &event_sink {
sink.emit(AgentEvent::ErrorClassified { sink.emit(AgentEvent::ErrorClassified {
category: "cancelled".to_string(), category: "cancelled".to_string(),
@ -279,7 +289,9 @@ async fn run_enhanced_loop(
retry_delay_ms: None, retry_delay_ms: None,
}); });
} }
return Err(AiError::Response("agent run cancelled".to_string())); return Err(AiError::Response(
"agent run cancelled".to_string(),
));
} }
// Inject steering messages if any // Inject steering messages if any
@ -298,10 +310,12 @@ async fn run_enhanced_loop(
if let Some(sink) = &event_sink { if let Some(sink) = &event_sink {
sink.emit(AgentEvent::TurnStart { turn_index }); sink.emit(AgentEvent::TurnStart { turn_index });
} }
let _ = tx.send(RigStreamChunk::TextDelta { let _ = tx
index: 0, .send(RigStreamChunk::TextDelta {
content: String::new(), // placeholder for turn boundary detection index: 0,
}).await; content: String::new(), // placeholder for turn boundary detection
})
.await;
// Run one LLM turn with retry // Run one LLM turn with retry
let turn_result = run_single_turn( let turn_result = run_single_turn(
@ -325,7 +339,9 @@ async fn run_enhanced_loop(
// Collect step // Collect step
let tool_call_count = turn_output.tool_calls.len(); let tool_call_count = turn_output.tool_calls.len();
if !turn_output.tool_calls.is_empty() || !turn_output.assistant_text.is_empty() { if !turn_output.tool_calls.is_empty()
|| !turn_output.assistant_text.is_empty()
{
all_steps.push(AgentStep { all_steps.push(AgentStep {
index: all_steps.len(), index: all_steps.len(),
assistant: (!turn_output.assistant_text.is_empty()) assistant: (!turn_output.assistant_text.is_empty())
@ -340,7 +356,9 @@ async fn run_enhanced_loop(
if let Some(sink) = &event_sink { if let Some(sink) = &event_sink {
sink.emit(AgentEvent::TurnEnd { sink.emit(AgentEvent::TurnEnd {
turn_index, turn_index,
assistant_text: Some(turn_output.assistant_text.clone()), assistant_text: Some(
turn_output.assistant_text.clone(),
),
tool_call_count, tool_call_count,
}); });
} }
@ -357,7 +375,10 @@ async fn run_enhanced_loop(
if let Some(stop_fn) = &should_stop { if let Some(stop_fn) = &should_stop {
if stop_fn(&turn_ctx) { if stop_fn(&turn_ctx) {
info!(turn_index, "agent stopped by should_stop callback"); info!(
turn_index,
"agent stopped by should_stop callback"
);
break; break;
} }
} }
@ -378,7 +399,8 @@ async fn run_enhanced_loop(
if let Some(temp) = update.temperature { if let Some(temp) = update.temperature {
config.temperature = Some(temp); config.temperature = Some(temp);
} }
if let Some(max_tok) = update.max_completion_tokens { if let Some(max_tok) = update.max_completion_tokens
{
config.max_completion_tokens = Some(max_tok); config.max_completion_tokens = Some(max_tok);
} }
} }
@ -397,14 +419,21 @@ async fn run_enhanced_loop(
Err(e) => { Err(e) => {
// Error classification and retry with fallback // Error classification and retry with fallback
let category = classify_error(&e, None); let category = classify_error(&e, None);
let policy = retry_policy_for(&category, config.retry_max_attempts, config.retry_base_delay_ms); let policy = retry_policy_for(
&category,
config.retry_max_attempts,
config.retry_base_delay_ms,
);
if let Some(sink) = &event_sink { if let Some(sink) = &event_sink {
sink.emit(AgentEvent::ErrorClassified { sink.emit(AgentEvent::ErrorClassified {
category: format!("{category:?}"), category: format!("{category:?}"),
message: e.to_string(), message: e.to_string(),
will_retry: policy.switch_to_fallback || policy.max_attempts > 0, will_retry: policy.switch_to_fallback
retry_delay_ms: Some(policy.base_delay.as_millis() as u64), || policy.max_attempts > 0,
retry_delay_ms: Some(
policy.base_delay.as_millis() as u64
),
}); });
} }
@ -443,16 +472,20 @@ async fn run_enhanced_loop(
match retry_result { match retry_result {
Ok(turn_output) => { Ok(turn_output) => {
total_input_tokens += turn_output.input_tokens; total_input_tokens +=
total_output_tokens += turn_output.output_tokens; turn_output.input_tokens;
total_output_tokens +=
turn_output.output_tokens;
let tc_count = turn_output.tool_calls.len(); let tc_count = turn_output.tool_calls.len();
let has_tools = tc_count > 0; let has_tools = tc_count > 0;
let has_text = !turn_output.assistant_text.is_empty(); let has_text =
!turn_output.assistant_text.is_empty();
let assistant = turn_output.assistant_text; let assistant = turn_output.assistant_text;
if has_tools || has_text { if has_tools || has_text {
all_steps.push(AgentStep { all_steps.push(AgentStep {
index: all_steps.len(), index: all_steps.len(),
assistant: has_text.then_some(assistant.clone()), assistant: has_text
.then_some(assistant.clone()),
reasoning_content: None, reasoning_content: None,
tool_calls: turn_output.tool_calls, tool_calls: turn_output.tool_calls,
reflection: None, reflection: None,
@ -472,7 +505,9 @@ async fn run_enhanced_loop(
}) })
.await; .await;
if let Some(ctx) = &request.run_context { if let Some(ctx) = &request.run_context {
let _ = hooks.run_session_end(ctx, false).await; let _ = hooks
.run_session_end(ctx, false)
.await;
} }
return Err(retry_err); return Err(retry_err);
} }
@ -582,7 +617,9 @@ async fn run_single_turn(
tx: &mpsc::Sender<RigStreamChunk>, tx: &mpsc::Sender<RigStreamChunk>,
) -> AiResult<TurnOutput> { ) -> AiResult<TurnOutput> {
if !budget.consume() { if !budget.consume() {
return Err(AiError::Response("iteration budget exhausted".to_string())); return Err(AiError::Response(
"iteration budget exhausted".to_string(),
));
} }
let model = client.completion_model(&config.model); let model = client.completion_model(&config.model);
@ -674,7 +711,11 @@ async fn run_single_turn(
rig::streaming::StreamedAssistantContent::Reasoning(reasoning), rig::streaming::StreamedAssistantContent::Reasoning(reasoning),
)) => { )) => {
for part in &reasoning.content { for part in &reasoning.content {
if let rig::completion::message::ReasoningContent::Text { text, .. } = part { if let rig::completion::message::ReasoningContent::Text {
text,
..
} = part
{
_accumulated_output_chars += text.chars().count(); _accumulated_output_chars += text.chars().count();
if let Some(sink) = &event_sink { if let Some(sink) = &event_sink {
sink.emit(AgentEvent::MessageThinkingDelta { sink.emit(AgentEvent::MessageThinkingDelta {
@ -693,7 +734,10 @@ async fn run_single_turn(
} }
} }
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
rig::streaming::StreamedAssistantContent::ReasoningDelta { reasoning, .. }, rig::streaming::StreamedAssistantContent::ReasoningDelta {
reasoning,
..
},
)) => { )) => {
_accumulated_output_chars += reasoning.chars().count(); _accumulated_output_chars += reasoning.chars().count();
if let Some(sink) = &event_sink { if let Some(sink) = &event_sink {
@ -711,7 +755,10 @@ async fn run_single_turn(
delta_index += 1; delta_index += 1;
} }
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
rig::streaming::StreamedAssistantContent::ToolCall { tool_call, .. }, rig::streaming::StreamedAssistantContent::ToolCall {
tool_call,
..
},
)) => { )) => {
let args = match &tool_call.function.arguments { let args = match &tool_call.function.arguments {
serde_json::Value::String(s) => s.clone(), serde_json::Value::String(s) => s.clone(),
@ -724,7 +771,9 @@ async fn run_single_turn(
serde_json::from_str(&args).unwrap_or_default(); serde_json::from_str(&args).unwrap_or_default();
// Pre-tool-call guardrail hook // Pre-tool-call guardrail hook
if let Ok(Some(decision)) = hooks.run_pre_tool_call(&tool_name, &tool_args).await { if let Ok(Some(decision)) =
hooks.run_pre_tool_call(&tool_name, &tool_args).await
{
match decision { match decision {
ToolGuardrailDecision::Allow => {} ToolGuardrailDecision::Allow => {}
ToolGuardrailDecision::Block { reason } => { ToolGuardrailDecision::Block { reason } => {
@ -761,7 +810,9 @@ async fn run_single_turn(
name: tool_name.clone(), name: tool_name.clone(),
arguments: tool_args, arguments: tool_args,
output: None, output: None,
error: Some(format!("requires approval: {message}")), error: Some(format!(
"requires approval: {message}"
)),
elapsed_ms: None, elapsed_ms: None,
}); });
continue; continue;
@ -794,10 +845,14 @@ async fn run_single_turn(
}); });
} }
Ok(rig::agent::MultiTurnStreamItem::StreamUserItem( Ok(rig::agent::MultiTurnStreamItem::StreamUserItem(
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. }, rig::streaming::StreamedUserContent::ToolResult {
tool_result,
..
},
)) => { )) => {
let content = let content = super::helpers::tool_result_content_to_string(
super::helpers::tool_result_content_to_string(&tool_result.content); &tool_result.content,
);
_accumulated_output_chars += content.chars().count(); _accumulated_output_chars += content.chars().count();
let tool_name = tool_calls let tool_name = tool_calls
@ -808,14 +863,18 @@ async fn run_single_turn(
if let Some(last) = tool_calls.last_mut() if let Some(last) = tool_calls.last_mut()
&& last.id == tool_result.id && last.id == tool_result.id
{ {
last.output = Some(serde_json::from_str(&content).unwrap_or_default()); last.output = Some(
serde_json::from_str(&content).unwrap_or_default(),
);
} }
if let Some(sink) = &event_sink { if let Some(sink) = &event_sink {
sink.emit(AgentEvent::ToolExecutionEnd { sink.emit(AgentEvent::ToolExecutionEnd {
tool_call_id: tool_result.id.clone(), tool_call_id: tool_result.id.clone(),
tool_name: tool_name.clone(), tool_name: tool_name.clone(),
output: Some(serde_json::Value::String(content.clone())), output: Some(serde_json::Value::String(
content.clone(),
)),
error: None, error: None,
elapsed_ms: 0, elapsed_ms: 0,
}); });
@ -872,5 +931,3 @@ async fn run_single_turn(
output_tokens, output_tokens,
}) })
} }

View File

@ -37,13 +37,12 @@ impl RigAgent {
} }
let agent = builder.build(); let agent = builder.build();
let response = agent let response =
.prompt(&ui) agent.prompt(&ui).extended_details().await.map_err(
.extended_details() |e: rig::completion::PromptError| {
.await AiError::Api(e.to_string())
.map_err(|e: rig::completion::PromptError| { },
AiError::Api(e.to_string()) )?;
})?;
Ok(( Ok((
response.output, response.output,

View File

@ -63,8 +63,13 @@ impl SystemPromptBuilder {
} }
/// Add a one-line tool description snippet. /// Add a one-line tool description snippet.
pub fn tool_snippet(mut self, tool_name: impl Into<String>, description: impl Into<String>) -> Self { pub fn tool_snippet(
self.tool_snippets.push((tool_name.into(), description.into())); mut self,
tool_name: impl Into<String>,
description: impl Into<String>,
) -> Self {
self.tool_snippets
.push((tool_name.into(), description.into()));
self self
} }
@ -75,7 +80,11 @@ impl SystemPromptBuilder {
} }
/// Add a project context file (e.g., AGENTS.md content). /// Add a project context file (e.g., AGENTS.md content).
pub fn project_context(mut self, path: impl Into<String>, content: impl Into<String>) -> Self { pub fn project_context(
mut self,
path: impl Into<String>,
content: impl Into<String>,
) -> Self {
self.project_contexts.push((path.into(), content.into())); self.project_contexts.push((path.into(), content.into()));
self self
} }
@ -87,13 +96,20 @@ impl SystemPromptBuilder {
} }
/// Set a variable for {{key}} substitution. /// Set a variable for {{key}} substitution.
pub fn variable(mut self, key: impl Into<String>, value: impl Into<String>) -> Self { pub fn variable(
mut self,
key: impl Into<String>,
value: impl Into<String>,
) -> Self {
self.variables.insert(key.into(), value.into()); self.variables.insert(key.into(), value.into());
self self
} }
/// Set multiple variables from an iterator. /// Set multiple variables from an iterator.
pub fn variables(mut self, vars: impl IntoIterator<Item = (String, String)>) -> Self { pub fn variables(
mut self,
vars: impl IntoIterator<Item = (String, String)>,
) -> Self {
self.variables.extend(vars); self.variables.extend(vars);
self self
} }
@ -105,7 +121,11 @@ impl SystemPromptBuilder {
} }
/// Add a custom named section to the prompt. /// Add a custom named section to the prompt.
pub fn custom_section(mut self, name: impl Into<String>, content: impl Into<String>) -> Self { pub fn custom_section(
mut self,
name: impl Into<String>,
content: impl Into<String>,
) -> Self {
self.custom_sections.push((name.into(), content.into())); self.custom_sections.push((name.into(), content.into()));
self self
} }
@ -142,7 +162,8 @@ impl SystemPromptBuilder {
// 4. Project context files // 4. Project context files
if !self.project_contexts.is_empty() { if !self.project_contexts.is_empty() {
let mut section = String::from("\n<project_context>\n\n"); let mut section = String::from("\n<project_context>\n\n");
section.push_str("Project-specific instructions and guidelines:\n\n"); section
.push_str("Project-specific instructions and guidelines:\n\n");
for (path, content) in &self.project_contexts { for (path, content) in &self.project_contexts {
section.push_str(&format!("<project_instructions path=\"{path}\">\n{content}\n</project_instructions>\n\n")); section.push_str(&format!("<project_instructions path=\"{path}\">\n{content}\n</project_instructions>\n\n"));
} }

View File

@ -38,7 +38,9 @@ impl AgentRequest {
pub fn validate(&self) -> AiResult<()> { pub fn validate(&self) -> AiResult<()> {
if self.input.trim().is_empty() { if self.input.trim().is_empty() {
return Err(AiError::Config("agent request input is required".to_string())); return Err(AiError::Config(
"agent request input is required".to_string(),
));
} }
if self.input.len() > 1_000_000 { if self.input.len() > 1_000_000 {
return Err(AiError::Config( return Err(AiError::Config(
@ -83,12 +85,18 @@ impl AgentRequest {
self self
} }
pub fn with_prefill_messages(mut self, prefill_messages: Vec<rig::completion::Message>) -> Self { pub fn with_prefill_messages(
mut self,
prefill_messages: Vec<rig::completion::Message>,
) -> Self {
self.prefill_messages = prefill_messages; self.prefill_messages = prefill_messages;
self self
} }
pub fn with_cancellation_token(mut self, cancellation_token: CancellationToken) -> Self { pub fn with_cancellation_token(
mut self,
cancellation_token: CancellationToken,
) -> Self {
self.cancellation_token = Some(cancellation_token); self.cancellation_token = Some(cancellation_token);
self self
} }
@ -119,7 +127,11 @@ pub struct AgentExpert {
} }
impl AgentExpert { impl AgentExpert {
pub fn new(id: impl Into<String>, role: impl Into<String>, task: impl Into<String>) -> Self { pub fn new(
id: impl Into<String>,
role: impl Into<String>,
task: impl Into<String>,
) -> Self {
Self { Self {
id: id.into(), id: id.into(),
role: role.into(), role: role.into(),
@ -131,7 +143,10 @@ impl AgentExpert {
} }
} }
pub fn with_system_prompt(mut self, system_prompt: impl Into<String>) -> Self { pub fn with_system_prompt(
mut self,
system_prompt: impl Into<String>,
) -> Self {
self.system_prompt = Some(system_prompt.into()); self.system_prompt = Some(system_prompt.into());
self self
} }

View File

@ -145,7 +145,10 @@ impl SessionEntry {
} }
/// Create a user message entry. /// Create a user message entry.
pub fn user_message(parent_id: Option<Uuid>, content: impl Into<String>) -> Self { pub fn user_message(
parent_id: Option<Uuid>,
content: impl Into<String>,
) -> Self {
Self::Message { Self::Message {
id: Uuid::new_v4(), id: Uuid::new_v4(),
parent_id, parent_id,
@ -328,7 +331,13 @@ impl Session {
pub fn active_messages(&self) -> Vec<&SessionEntry> { pub fn active_messages(&self) -> Vec<&SessionEntry> {
self.active_branch() self.active_branch()
.into_iter() .into_iter()
.filter(|e| matches!(e, SessionEntry::Message { .. } | SessionEntry::Compaction { .. })) .filter(|e| {
matches!(
e,
SessionEntry::Message { .. }
| SessionEntry::Compaction { .. }
)
})
.collect() .collect()
} }
@ -342,11 +351,8 @@ impl Session {
/// Get all leaf entries (entries with no children). /// Get all leaf entries (entries with no children).
pub fn leaves(&self) -> Vec<&SessionEntry> { pub fn leaves(&self) -> Vec<&SessionEntry> {
let parent_ids: std::collections::HashSet<Uuid> = self let parent_ids: std::collections::HashSet<Uuid> =
.entries self.entries.iter().filter_map(|e| e.parent_id()).collect();
.iter()
.filter_map(|e| e.parent_id())
.collect();
self.entries self.entries
.iter() .iter()
@ -367,7 +373,9 @@ impl Session {
.iter() .iter()
.position(|e| e.id() == fork_entry_id) .position(|e| e.id() == fork_entry_id)
.ok_or_else(|| { .ok_or_else(|| {
AiError::Config(format!("fork entry {fork_entry_id} not found in session")) AiError::Config(format!(
"fork entry {fork_entry_id} not found in session"
))
})?; })?;
let mut new_session = Session::new(); let mut new_session = Session::new();
@ -445,9 +453,11 @@ fn iso_now() -> String {
// Simple ISO 8601 format (UTC) // Simple ISO 8601 format (UTC)
let days = secs / 86400; let days = secs / 86400;
let years = (days * 400) / 146097; let years = (days * 400) / 146097;
let remaining_days = days - (years * 365 + years / 4 - years / 100 + years / 400); let remaining_days =
days - (years * 365 + years / 4 - years / 100 + years / 400);
let month_days = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]; let month_days = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
let is_leap = (years % 4 == 0 && years % 100 != 0) || years % 400 == 0; let is_leap =
(years % 4 == 0 && years % 100 != 0) || years % 400 == 0;
let mut month = 0usize; let mut month = 0usize;
let mut day_acc = remaining_days as i64; let mut day_acc = remaining_days as i64;
for (i, &md) in month_days.iter().enumerate() { for (i, &md) in month_days.iter().enumerate() {
@ -487,7 +497,8 @@ mod tests {
let msg1_id = msg1.id(); let msg1_id = msg1.id();
session.push(msg1); session.push(msg1);
let msg2 = SessionEntry::assistant_message(Some(msg1_id), "Hi there!", None); let msg2 =
SessionEntry::assistant_message(Some(msg1_id), "Hi there!", None);
session.push(msg2); session.push(msg2);
assert_eq!(session.entry_count(), 2); assert_eq!(session.entry_count(), 2);
@ -502,7 +513,8 @@ mod tests {
let msg1_id = msg1.id(); let msg1_id = msg1.id();
session.push(msg1); session.push(msg1);
let msg2 = SessionEntry::assistant_message(Some(msg1_id), "Reply 1", None); let msg2 =
SessionEntry::assistant_message(Some(msg1_id), "Reply 1", None);
let msg2_id = msg2.id(); let msg2_id = msg2.id();
session.push(msg2); session.push(msg2);
@ -524,8 +536,10 @@ mod tests {
session.push(msg1); session.push(msg1);
// Two children branching from root // Two children branching from root
let msg2a = SessionEntry::assistant_message(Some(msg1_id), "Branch A", None); let msg2a =
let msg2b = SessionEntry::assistant_message(Some(msg1_id), "Branch B", None); SessionEntry::assistant_message(Some(msg1_id), "Branch A", None);
let msg2b =
SessionEntry::assistant_message(Some(msg1_id), "Branch B", None);
session.push(msg2a); session.push(msg2a);
session.push(msg2b); session.push(msg2b);

View File

@ -31,13 +31,24 @@ pub async fn run_experts(
} }
Err(error) => { Err(error) => {
warn!(subagent_id = %expert.id, role = %expert.role, error = %error, "subagent failed"); warn!(subagent_id = %expert.id, role = %expert.role, error = %error, "subagent failed");
let _ = publish_subagent_failed(realtime, run, expert, &error.to_string()).await; let _ = publish_subagent_failed(
realtime,
run,
expert,
&error.to_string(),
)
.await;
failed_count += 1; failed_count += 1;
} }
} }
} }
debug!(total = experts.len(), ok = outputs.len(), failed = failed_count, "experts done"); debug!(
total = experts.len(),
ok = outputs.len(),
failed = failed_count,
"experts done"
);
Ok(outputs) Ok(outputs)
} }
@ -53,7 +64,9 @@ async fn run_single(
let rig_client = client.llm_client().clone(); let rig_client = client.llm_client().clone();
let model_name = config.model.clone(); let model_name = config.model.clone();
let temperature = expert.temperature.or(config.temperature); let temperature = expert.temperature.or(config.temperature);
let max_completion_tokens = expert.max_completion_tokens.or(config.max_completion_tokens); let max_completion_tokens = expert
.max_completion_tokens
.or(config.max_completion_tokens);
let retry_attempts = config.retry_max_attempts; let retry_attempts = config.retry_max_attempts;
let retry_delay_ms = config.retry_base_delay_ms; let retry_delay_ms = config.retry_base_delay_ms;
@ -66,10 +79,8 @@ async fn run_single(
let task = build_expert_task(expert); let task = build_expert_task(expert);
let (output, input_tokens_usage, output_tokens_usage) = with_retry( let (output, input_tokens_usage, output_tokens_usage) =
retry_attempts, with_retry(retry_attempts, retry_delay_ms, || {
retry_delay_ms,
|| {
let rig_client = rig_client.clone(); let rig_client = rig_client.clone();
let model_name = model_name.clone(); let model_name = model_name.clone();
let prompt = prompt.clone(); let prompt = prompt.clone();
@ -85,13 +96,12 @@ async fn run_single(
} }
let agent = builder.build(); let agent = builder.build();
let response = agent let response =
.prompt(&task) agent.prompt(&task).extended_details().await.map_err(
.extended_details() |e: rig::completion::PromptError| {
.await AiError::Api(e.to_string())
.map_err(|e: rig::completion::PromptError| { },
AiError::Api(e.to_string()) )?;
})?;
Ok(( Ok((
response.output, response.output,
@ -99,9 +109,8 @@ async fn run_single(
response.usage.output_tokens, response.usage.output_tokens,
)) ))
} }
}, })
) .await?;
.await?;
let input_tokens = input_tokens_usage as i64; let input_tokens = input_tokens_usage as i64;
let output_tokens = if output_tokens_usage > 0 { let output_tokens = if output_tokens_usage > 0 {
@ -150,17 +159,19 @@ async fn publish_subagent_started(
config: &AgentConfig, config: &AgentConfig,
expert: &AgentExpert, expert: &AgentExpert,
) -> AiResult<()> { ) -> AiResult<()> {
AgentRuntime::default().publish( AgentRuntime::default()
realtime, .publish(
&AgentStreamEvent::SubagentStarted { realtime,
conversation_id: run.conversation_id, &AgentStreamEvent::SubagentStarted {
message_id: run.message_id, conversation_id: run.conversation_id,
subagent_id: expert.id.clone(), message_id: run.message_id,
role: expert.role.clone(), subagent_id: expert.id.clone(),
task: expert.task.clone(), role: expert.role.clone(),
model: config.model.clone(), task: expert.task.clone(),
}, model: config.model.clone(),
).await },
)
.await
} }
async fn publish_subagent_completed( async fn publish_subagent_completed(
@ -169,20 +180,22 @@ async fn publish_subagent_completed(
config: &AgentConfig, config: &AgentConfig,
output: &AgentExpertOutput, output: &AgentExpertOutput,
) -> AiResult<()> { ) -> AiResult<()> {
AgentRuntime::default().publish( AgentRuntime::default()
realtime, .publish(
&AgentStreamEvent::SubagentCompleted { realtime,
conversation_id: run.conversation_id, &AgentStreamEvent::SubagentCompleted {
message_id: run.message_id, conversation_id: run.conversation_id,
subagent_id: output.id.clone(), message_id: run.message_id,
role: output.role.clone(), subagent_id: output.id.clone(),
task: output.task.clone(), role: output.role.clone(),
output: output.output.clone(), task: output.task.clone(),
input_tokens: output.input_tokens, output: output.output.clone(),
output_tokens: output.output_tokens, input_tokens: output.input_tokens,
model: config.model.clone(), output_tokens: output.output_tokens,
}, model: config.model.clone(),
).await },
)
.await
} }
async fn publish_subagent_failed( async fn publish_subagent_failed(
@ -191,13 +204,15 @@ async fn publish_subagent_failed(
expert: &AgentExpert, expert: &AgentExpert,
error: &str, error: &str,
) -> AiResult<()> { ) -> AiResult<()> {
AgentRuntime::default().publish( AgentRuntime::default()
realtime, .publish(
&AgentStreamEvent::SubagentFailed { realtime,
conversation_id: run.conversation_id, &AgentStreamEvent::SubagentFailed {
message_id: run.message_id, conversation_id: run.conversation_id,
subagent_id: expert.id.clone(), message_id: run.message_id,
error: error.to_string(), subagent_id: expert.id.clone(),
}, error: error.to_string(),
).await },
)
.await
} }

View File

@ -23,7 +23,10 @@ impl<C> RigTool<C>
where where
C: Clone + Send + Sync + 'static, C: Clone + Send + Sync + 'static,
{ {
pub fn new(tool: Arc<dyn FunctionCall<Context = C>>, context: Arc<Mutex<C>>) -> Self { pub fn new(
tool: Arc<dyn FunctionCall<Context = C>>,
context: Arc<Mutex<C>>,
) -> Self {
let name = tool.name().to_string(); let name = tool.name().to_string();
let description = tool.description().to_string(); let description = tool.description().to_string();
let schema = tool.schema(); let schema = tool.schema();
@ -49,7 +52,8 @@ where
fn definition<'a>( fn definition<'a>(
&'a self, &'a self,
_prompt: String, _prompt: String,
) -> Pin<Box<dyn std::future::Future<Output = RigToolDefinition> + Send + 'a>> { ) -> Pin<Box<dyn std::future::Future<Output = RigToolDefinition> + Send + 'a>>
{
let name = self.name.clone(); let name = self.name.clone();
let description = self.description.clone(); let description = self.description.clone();
let params = self.schema.clone(); let params = self.schema.clone();
@ -67,23 +71,28 @@ where
&'a self, &'a self,
args: String, args: String,
) -> Pin< ) -> Pin<
Box<dyn std::future::Future<Output = Result<String, rig::tool::ToolError>> + Send + 'a>, Box<
dyn std::future::Future<
Output = Result<String, rig::tool::ToolError>,
> + Send
+ 'a,
>,
> { > {
let tool = self.tool.clone(); let tool = self.tool.clone();
let context = self.context.clone(); let context = self.context.clone();
Box::pin(async move { Box::pin(async move {
let args_value: Value = let args_value: Value = serde_json::from_str(&args)
serde_json::from_str(&args).map_err(rig::tool::ToolError::JsonError)?; .map_err(rig::tool::ToolError::JsonError)?;
let mut ctx = context.lock().await; let mut ctx = context.lock().await;
match tool.call(&mut *ctx, args_value).await { match tool.call(&mut *ctx, args_value).await {
Ok(value) => serde_json::to_string(&value) Ok(value) => serde_json::to_string(&value)
.map_err(rig::tool::ToolError::JsonError), .map_err(rig::tool::ToolError::JsonError),
Err(ai_err) => Err(rig::tool::ToolError::ToolCallError(Box::new( Err(ai_err) => Err(rig::tool::ToolError::ToolCallError(
std::io::Error::other(ai_err.to_string()), Box::new(std::io::Error::other(ai_err.to_string())),
))), )),
} }
}) })
} }
@ -112,10 +121,14 @@ where
register: &crate::tool::register::ToolRegister<C>, register: &crate::tool::register::ToolRegister<C>,
context: Arc<Mutex<C>>, context: Arc<Mutex<C>>,
) -> Self { ) -> Self {
let mut tools: Vec<Box<dyn ToolDyn + 'static>> = Vec::with_capacity(register.len()); let mut tools: Vec<Box<dyn ToolDyn + 'static>> =
Vec::with_capacity(register.len());
for tool_arc in &register.tools { for tool_arc in &register.tools {
tools.push(Box::new(RigTool::new(tool_arc.clone(), context.clone()))); tools.push(Box::new(RigTool::new(
tool_arc.clone(),
context.clone(),
)));
} }
Self { Self {

View File

@ -24,7 +24,10 @@ pub struct EndpointConfig {
} }
impl EndpointConfig { impl EndpointConfig {
pub fn new(base_url: impl Into<String>, api_key: impl Into<String>) -> AiResult<Self> { pub fn new(
base_url: impl Into<String>,
api_key: impl Into<String>,
) -> AiResult<Self> {
let config = Self { let config = Self {
base_url: base_url.into(), base_url: base_url.into(),
api_key: api_key.into(), api_key: api_key.into(),
@ -51,7 +54,11 @@ impl EndpointConfig {
.api_key(&self.api_key) .api_key(&self.api_key)
.base_url(self.base_url.trim()) .base_url(self.base_url.trim())
.build() .build()
.map_err(|e| AiError::Config(format!("failed to build rig OpenAI client: {e}"))) .map_err(|e| {
AiError::Config(format!(
"failed to build rig OpenAI client: {e}"
))
})
} }
} }

View File

@ -1,7 +1,10 @@
use rig::client::EmbeddingsClient; use rig::client::EmbeddingsClient;
use rig::embeddings::EmbeddingModel; use rig::embeddings::EmbeddingModel;
use crate::{client::AiClient, error::{AiError, AiResult}}; use crate::{
client::AiClient,
error::{AiError, AiResult},
};
#[derive(Clone)] #[derive(Clone)]
pub struct EmbedClient { pub struct EmbedClient {
@ -23,23 +26,32 @@ impl EmbedClient {
pub async fn embed_text(&self, text: String) -> AiResult<Vec<f32>> { pub async fn embed_text(&self, text: String) -> AiResult<Vec<f32>> {
let model = self.embedding_model(); let model = self.embedding_model();
let mut embeddings = model.embed_texts(vec![text]) let mut embeddings = model
.embed_texts(vec![text])
.await .await
.map_err(|e| AiError::Api(e.to_string()))?; .map_err(|e| AiError::Api(e.to_string()))?;
embeddings.pop() embeddings
.pop()
.map(|e| e.vec.into_iter().map(|v| v as f32).collect()) .map(|e| e.vec.into_iter().map(|v| v as f32).collect())
.ok_or_else(|| AiError::Response("no embedding returned".to_string())) .ok_or_else(|| {
AiError::Response("no embedding returned".to_string())
})
} }
pub async fn embed_texts(&self, texts: Vec<String>) -> AiResult<Vec<Vec<f32>>> { pub async fn embed_texts(
&self,
texts: Vec<String>,
) -> AiResult<Vec<Vec<f32>>> {
if texts.is_empty() { if texts.is_empty() {
return Ok(Vec::new()); return Ok(Vec::new());
} }
let model = self.embedding_model(); let model = self.embedding_model();
let embeddings = model.embed_texts(texts) let embeddings = model
.embed_texts(texts)
.await .await
.map_err(|e| AiError::Api(e.to_string()))?; .map_err(|e| AiError::Api(e.to_string()))?;
Ok(embeddings.into_iter() Ok(embeddings
.into_iter()
.map(|e| e.vec.into_iter().map(|v| v as f32).collect()) .map(|e| e.vec.into_iter().map(|v| v as f32).collect())
.collect()) .collect())
} }
@ -55,11 +67,15 @@ impl EmbedClient {
let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len()); let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
for chunk in texts.chunks(batch_size) { for chunk in texts.chunks(batch_size) {
let model = self.embedding_model(); let model = self.embedding_model();
let chunk_embeddings = model.embed_texts(chunk.to_vec()) let chunk_embeddings = model
.embed_texts(chunk.to_vec())
.await .await
.map_err(|e| AiError::Api(e.to_string()))?; .map_err(|e| AiError::Api(e.to_string()))?;
embeddings.extend(chunk_embeddings.into_iter() embeddings.extend(
.map(|e| e.vec.into_iter().map(|v| v as f32).collect())); chunk_embeddings
.into_iter()
.map(|e| e.vec.into_iter().map(|v| v as f32).collect()),
);
} }
Ok(embeddings) Ok(embeddings)
} }

View File

@ -24,10 +24,7 @@ pub enum AiError {
Response(String), Response(String),
#[error("model retries exhausted after {attempts} attempts: {last_error}")] #[error("model retries exhausted after {attempts} attempts: {last_error}")]
ModelRetriesExhausted { ModelRetriesExhausted { attempts: usize, last_error: String },
attempts: usize,
last_error: String,
},
#[error("agent timeout after {seconds}s")] #[error("agent timeout after {seconds}s")]
Timeout { seconds: u64 }, Timeout { seconds: u64 },

View File

@ -34,10 +34,7 @@ pub trait MemoryProvider: Send + Sync {
) -> AiResult<Vec<MemoryEntry>> { ) -> AiResult<Vec<MemoryEntry>> {
Ok(Vec::new()) Ok(Vec::new())
} }
async fn build_context_block( async fn build_context_block(&self, _session_id: Uuid) -> AiResult<String> {
&self,
_session_id: Uuid,
) -> AiResult<String> {
Ok(String::new()) Ok(String::new())
} }
async fn setup(&self) -> AiResult<()> { async fn setup(&self) -> AiResult<()> {

View File

@ -42,10 +42,7 @@ impl RagClient {
}) })
} }
pub fn connect( pub fn connect(ai_client: &AiClient, config: RagConfig) -> AiResult<Self> {
ai_client: &AiClient,
config: RagConfig,
) -> AiResult<Self> {
config.validate()?; config.validate()?;
let mut builder = let mut builder =
Qdrant::from_url(config.url.trim()).timeout(config.timeout); Qdrant::from_url(config.url.trim()).timeout(config.timeout);
@ -132,10 +129,8 @@ impl RagClient {
validate_session_id(session_id)?; validate_session_id(session_id)?;
validate_documents(&documents)?; validate_documents(&documents)?;
let texts: Vec<String> = documents let texts: Vec<String> =
.iter() documents.iter().map(|d| d.content.clone()).collect();
.map(|d| d.content.clone())
.collect();
let vectors = self let vectors = self
.embedder .embedder
.embed_texts_chunked(texts, self.config.upsert_batch_size) .embed_texts_chunked(texts, self.config.upsert_batch_size)

View File

@ -20,8 +20,8 @@ pub(super) fn point_id(session_id: &str, document_id: &str) -> u64 {
let uuid = Uuid::new_v5(&ns, key.as_bytes()); let uuid = Uuid::new_v5(&ns, key.as_bytes());
let bytes = uuid.as_bytes(); let bytes = uuid.as_bytes();
u64::from_be_bytes([ u64::from_be_bytes([
bytes[0], bytes[1], bytes[2], bytes[3], bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6],
bytes[4], bytes[5], bytes[6], bytes[7], bytes[7],
]) ])
} }

View File

@ -75,7 +75,9 @@ static HTTP_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
} }
} }
#[allow(clippy::expect_used)] #[allow(clippy::expect_used)]
builder.build().expect("failed to build reqwest HTTP client — check system TLS configuration") builder.build().expect(
"failed to build reqwest HTTP client — check system TLS configuration",
)
}); });
pub async fn list_models( pub async fn list_models(
config: &EndpointConfig, config: &EndpointConfig,
@ -102,12 +104,14 @@ pub async fn list_models(
AiError::Response(format!("failed to list models: {}", e)) AiError::Response(format!("failed to list models: {}", e))
})?; })?;
let body = resp let body = resp.text().await.map_err(|e| {
.text() AiError::Response(format!("failed to read models body: {}", e))
.await })?;
.map_err(|e| AiError::Response(format!("failed to read models body: {}", e)))?;
if let Ok(parsed) = serde_json::from_str::<ModelsListResponse>(&body) { if let Ok(parsed) = serde_json::from_str::<ModelsListResponse>(&body) {
debug!(count = parsed.data.len(), "parsed models in standard format"); debug!(
count = parsed.data.len(),
"parsed models in standard format"
);
return Ok(parsed.data); return Ok(parsed.data);
} }
if let Ok(parsed) = serde_json::from_str::<Vec<UpstreamModel>>(&body) { if let Ok(parsed) = serde_json::from_str::<Vec<UpstreamModel>>(&body) {

View File

@ -27,7 +27,10 @@ impl Toolset {
self self
} }
pub fn with_tools(mut self, tool_names: impl IntoIterator<Item = impl Into<String>>) -> Self { pub fn with_tools(
mut self,
tool_names: impl IntoIterator<Item = impl Into<String>>,
) -> Self {
self.tools.extend(tool_names.into_iter().map(Into::into)); self.tools.extend(tool_names.into_iter().map(Into::into));
self self
} }
@ -36,7 +39,8 @@ impl Toolset {
mut self, mut self,
env_vars: impl IntoIterator<Item = impl Into<String>>, env_vars: impl IntoIterator<Item = impl Into<String>>,
) -> Self { ) -> Self {
self.requires_env.extend(env_vars.into_iter().map(Into::into)); self.requires_env
.extend(env_vars.into_iter().map(Into::into));
self self
} }
pub fn is_available(&self) -> bool { pub fn is_available(&self) -> bool {

View File

@ -1,7 +1,8 @@
use actix_web::{HttpResponse, web, web::ServiceConfig}; use actix_web::{HttpResponse, web, web::ServiceConfig};
use service::AppService; use service::AppService;
use service::agent::conversation::{ use service::agent::conversation::{
ConversationResponse, ConversationWithSessionResponse, CreateConversation, MessageResponse, UpdateConversation, ConversationResponse, ConversationWithSessionResponse, CreateConversation,
MessageResponse, UpdateConversation,
}; };
use service::agent::types::{AgentRunRequest, AgentRunResponse}; use service::agent::types::{AgentRunRequest, AgentRunResponse};
use session::Session; use session::Session;
@ -53,8 +54,14 @@ pub async fn list_conversations(
service: web::Data<AppService>, service: web::Data<AppService>,
path: web::Path<Uuid>, path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
ok_json(service.agent_conversation_list(user_id, path.into_inner()).await?) .user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(
service
.agent_conversation_list(user_id, path.into_inner())
.await?,
)
} }
#[utoipa::path( #[utoipa::path(
@ -70,10 +77,16 @@ pub async fn create_conversation(
path: web::Path<Uuid>, path: web::Path<Uuid>,
body: web::Json<CreateConversation>, body: web::Json<CreateConversation>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
.user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json( ok_json(
service service
.agent_conversation_create(user_id, path.into_inner(), body.into_inner()) .agent_conversation_create(
user_id,
path.into_inner(),
body.into_inner(),
)
.await?, .await?,
) )
} }
@ -94,7 +107,9 @@ pub async fn list_all_conversations(
service: web::Data<AppService>, service: web::Data<AppService>,
query: web::Query<ListAllConversationsQuery>, query: web::Query<ListAllConversationsQuery>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
.user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json( ok_json(
service service
.agent_conversation_list_all(user_id, query.wk.as_deref()) .agent_conversation_list_all(user_id, query.wk.as_deref())
@ -113,8 +128,14 @@ pub async fn get_conversation(
service: web::Data<AppService>, service: web::Data<AppService>,
path: web::Path<Uuid>, path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
ok_json(service.agent_conversation_get(user_id, path.into_inner()).await?) .user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(
service
.agent_conversation_get(user_id, path.into_inner())
.await?,
)
} }
#[utoipa::path( #[utoipa::path(
@ -130,10 +151,16 @@ pub async fn update_conversation(
path: web::Path<Uuid>, path: web::Path<Uuid>,
body: web::Json<UpdateConversation>, body: web::Json<UpdateConversation>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
.user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json( ok_json(
service service
.agent_conversation_update(user_id, path.into_inner(), body.into_inner()) .agent_conversation_update(
user_id,
path.into_inner(),
body.into_inner(),
)
.await?, .await?,
) )
} }
@ -149,8 +176,12 @@ pub async fn delete_conversation(
service: web::Data<AppService>, service: web::Data<AppService>,
path: web::Path<Uuid>, path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
service.agent_conversation_delete(user_id, path.into_inner()).await?; .user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
service
.agent_conversation_delete(user_id, path.into_inner())
.await?;
Ok(HttpResponse::Ok().json(serde_json::json!({ "deleted": true }))) Ok(HttpResponse::Ok().json(serde_json::json!({ "deleted": true })))
} }
@ -165,8 +196,14 @@ pub async fn archive_conversation(
service: web::Data<AppService>, service: web::Data<AppService>,
path: web::Path<Uuid>, path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
ok_json(service.agent_conversation_archive(user_id, path.into_inner()).await?) .user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(
service
.agent_conversation_archive(user_id, path.into_inner())
.await?,
)
} }
#[utoipa::path( #[utoipa::path(
@ -180,8 +217,14 @@ pub async fn unarchive_conversation(
service: web::Data<AppService>, service: web::Data<AppService>,
path: web::Path<Uuid>, path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
ok_json(service.agent_conversation_unarchive(user_id, path.into_inner()).await?) .user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(
service
.agent_conversation_unarchive(user_id, path.into_inner())
.await?,
)
} }
#[utoipa::path( #[utoipa::path(
get, path = "/api/v1/agent/conversations/{id}/messages", get, path = "/api/v1/agent/conversations/{id}/messages",
@ -195,10 +238,17 @@ pub async fn list_messages(
path: web::Path<Uuid>, path: web::Path<Uuid>,
query: web::Query<MessageListQuery>, query: web::Query<MessageListQuery>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
.user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json( ok_json(
service service
.agent_message_list(user_id, path.into_inner(), query.limit, query.before) .agent_message_list(
user_id,
path.into_inner(),
query.limit,
query.before,
)
.await?, .await?,
) )
} }
@ -221,7 +271,9 @@ pub async fn send_message(
path: web::Path<Uuid>, path: web::Path<Uuid>,
body: web::Json<AgentRunRequest>, body: web::Json<AgentRunRequest>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
.user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
let conversation_id = path.into_inner(); let conversation_id = path.into_inner();
let mut req = body.into_inner(); let mut req = body.into_inner();
req.conversation_id = Some(conversation_id); req.conversation_id = Some(conversation_id);
@ -240,7 +292,9 @@ pub async fn stream_agent(
path: web::Path<Uuid>, path: web::Path<Uuid>,
body: web::Json<AgentRunRequest>, body: web::Json<AgentRunRequest>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
.user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
let conversation_id = path.into_inner(); let conversation_id = path.into_inner();
let mut req = body.into_inner(); let mut req = body.into_inner();
req.conversation_id = Some(conversation_id); req.conversation_id = Some(conversation_id);
@ -282,7 +336,9 @@ pub async fn fork_conversation(
path: web::Path<Uuid>, path: web::Path<Uuid>,
body: web::Json<ForkConversationRequest>, body: web::Json<ForkConversationRequest>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
.user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json( ok_json(
service service
.agent_conversation_fork( .agent_conversation_fork(

View File

@ -15,8 +15,7 @@ pub fn configure(cfg: &mut ServiceConfig) {
.route(web::post().to(create_session)), .route(web::post().to(create_session)),
) )
.service( .service(
web::resource("/sessions/search") web::resource("/sessions/search").route(web::get().to(search_sessions)),
.route(web::get().to(search_sessions)),
) )
.service( .service(
web::resource("/sessions/{id}") web::resource("/sessions/{id}")
@ -38,7 +37,9 @@ pub async fn list_sessions(
session: Session, session: Session,
service: web::Data<AppService>, service: web::Data<AppService>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
.user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(service.agent_session_list(user_id).await?) ok_json(service.agent_session_list(user_id).await?)
} }
#[utoipa::path( #[utoipa::path(
@ -52,8 +53,14 @@ pub async fn create_session(
service: web::Data<AppService>, service: web::Data<AppService>,
body: web::Json<CreateAgentSession>, body: web::Json<CreateAgentSession>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
ok_json(service.agent_session_create(user_id, body.into_inner()).await?) .user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(
service
.agent_session_create(user_id, body.into_inner())
.await?,
)
} }
#[utoipa::path( #[utoipa::path(
get, path = "/api/v1/agent/sessions/{id}", get, path = "/api/v1/agent/sessions/{id}",
@ -66,8 +73,14 @@ pub async fn get_session(
service: web::Data<AppService>, service: web::Data<AppService>,
path: web::Path<Uuid>, path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
ok_json(service.agent_session_get(user_id, path.into_inner()).await?) .user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(
service
.agent_session_get(user_id, path.into_inner())
.await?,
)
} }
#[utoipa::path( #[utoipa::path(
patch, path = "/api/v1/agent/sessions/{id}", patch, path = "/api/v1/agent/sessions/{id}",
@ -82,8 +95,14 @@ pub async fn update_session(
path: web::Path<Uuid>, path: web::Path<Uuid>,
body: web::Json<UpdateAgentSession>, body: web::Json<UpdateAgentSession>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
ok_json(service.agent_session_update(user_id, path.into_inner(), body.into_inner()).await?) .user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(
service
.agent_session_update(user_id, path.into_inner(), body.into_inner())
.await?,
)
} }
#[utoipa::path( #[utoipa::path(
delete, path = "/api/v1/agent/sessions/{id}", delete, path = "/api/v1/agent/sessions/{id}",
@ -96,8 +115,12 @@ pub async fn delete_session(
service: web::Data<AppService>, service: web::Data<AppService>,
path: web::Path<Uuid>, path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
service.agent_session_delete(user_id, path.into_inner()).await?; .user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
service
.agent_session_delete(user_id, path.into_inner())
.await?;
Ok(HttpResponse::Ok().json(serde_json::json!({ "deleted": true }))) Ok(HttpResponse::Ok().json(serde_json::json!({ "deleted": true })))
} }
@ -122,7 +145,9 @@ pub async fn search_sessions(
service: web::Data<AppService>, service: web::Data<AppService>,
query: web::Query<SearchQuery>, query: web::Query<SearchQuery>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
.user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json( ok_json(
service service
.agent_session_search(user_id, &query.q, query.limit) .agent_session_search(user_id, &query.q, query.limit)
@ -148,7 +173,9 @@ pub async fn update_session_toolsets(
path: web::Path<Uuid>, path: web::Path<Uuid>,
body: web::Json<UpdateToolsetsRequest>, body: web::Json<UpdateToolsetsRequest>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
.user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json( ok_json(
service service
.agent_session_update_toolsets( .agent_session_update_toolsets(

View File

@ -37,7 +37,9 @@ pub fn configure(cfg: &mut ServiceConfig) {
web::post().to(reset_pass::reset_password_verify), web::post().to(reset_pass::reset_password_verify),
)), )),
) )
.service(web::resource("/public-key").route(web::get().to(rsa::rsa))) .service(
web::resource("/public-key").route(web::get().to(rsa::rsa)),
)
.service( .service(
web::scope("/2fa") web::scope("/2fa")
.service( .service(

View File

@ -1,9 +1,9 @@
pub mod rest; pub mod rest;
pub mod rest_ai;
pub mod rest_interact; pub mod rest_interact;
pub mod rest_member; pub mod rest_member;
pub mod rest_message; pub mod rest_message;
pub mod rest_room; pub mod rest_room;
pub mod rest_user;
pub mod rest_voice; pub mod rest_voice;
pub mod token; pub mod token;
@ -65,8 +65,9 @@ pub fn configure(cfg: &mut ServiceConfig, bus: ChannelBus) {
.route(actix_web::web::post().to(rest_room::access_grant)), .route(actix_web::web::post().to(rest_room::access_grant)),
) )
.service( .service(
actix_web::web::resource("/workspaces/{workspace_id}/members") actix_web::web::resource("/workspaces/{workspace_id}/members").route(
.route(actix_web::web::get().to(rest_member::list_workspace_members)), actix_web::web::get().to(rest_member::list_workspace_members),
),
) )
.service( .service(
actix_web::web::resource("/rooms/{room_id}/members/{user_id}") actix_web::web::resource("/rooms/{room_id}/members/{user_id}")
@ -184,21 +185,8 @@ pub fn configure(cfg: &mut ServiceConfig, bus: ChannelBus) {
.route(actix_web::web::post().to(rest_voice::screen_share)), .route(actix_web::web::post().to(rest_voice::screen_share)),
); );
cfg.service( cfg.service(
actix_web::web::resource("/rooms/{room_id}/ai/stop")
.route(actix_web::web::post().to(rest_ai::ai_stop)),
)
.service(
actix_web::web::resource("/rooms/{room_id}/ai")
.route(actix_web::web::get().to(rest_ai::ai_list))
.route(actix_web::web::post().to(rest_ai::ai_add)),
)
.service(
actix_web::web::resource("/rooms/{room_id}/ai/{agent_session_id}")
.route(actix_web::web::delete().to(rest_ai::ai_remove)),
)
.service(
actix_web::web::resource("/users/summary/{username}") actix_web::web::resource("/users/summary/{username}")
.route(actix_web::web::get().to(rest_ai::user_summary)), .route(actix_web::web::get().to(rest_user::user_summary)),
); );
cfg.service( cfg.service(
actix_web::web::resource("/token") actix_web::web::resource("/token")

View File

@ -6,13 +6,13 @@ use uuid::Uuid;
use crate::error::ApiError; use crate::error::ApiError;
pub(crate) fn extract_user(req: &HttpRequest) -> Result<Uuid, ApiError> { pub fn extract_user(req: &HttpRequest) -> Result<Uuid, ApiError> {
req.get_session() req.get_session()
.user() .user()
.ok_or_else(|| ApiError(service::error::AppError::Unauthorized)) .ok_or_else(|| ApiError(service::error::AppError::Unauthorized))
} }
pub(crate) fn channel_err(e: ChannelError) -> ApiError { pub fn channel_err(e: ChannelError) -> ApiError {
ApiError(match e { ApiError(match e {
ChannelError::Unauthorized | ChannelError::TokenInvalidOrExpired => { ChannelError::Unauthorized | ChannelError::TokenInvalidOrExpired => {
service::error::AppError::Unauthorized service::error::AppError::Unauthorized
@ -61,14 +61,14 @@ pub(crate) fn channel_err(e: ChannelError) -> ApiError {
}) })
} }
pub(crate) fn ok_json(event: Option<WsOutEvent>) -> HttpResponse { pub fn ok_json(event: Option<WsOutEvent>) -> HttpResponse {
match event { match event {
Some(e) => HttpResponse::Ok().json(e), Some(e) => HttpResponse::Ok().json(e),
None => HttpResponse::NoContent().finish(), None => HttpResponse::NoContent().finish(),
} }
} }
pub(crate) fn created_json(event: Option<WsOutEvent>) -> HttpResponse { pub fn created_json(event: Option<WsOutEvent>) -> HttpResponse {
match event { match event {
Some(e) => HttpResponse::Created().json(e), Some(e) => HttpResponse::Created().json(e),
None => HttpResponse::NoContent().finish(), None => HttpResponse::NoContent().finish(),

View File

@ -1,120 +0,0 @@
use actix_web::{HttpRequest, HttpResponse, web};
use channel::ChannelBus;
use channel::http::{WsHandler, WsInMessage};
use serde::Deserialize;
use uuid::Uuid;
use super::rest::{channel_err, created_json, extract_user, ok_json};
use crate::error::ApiError;
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct AiAddRequest {
pub agent_session: Uuid,
}
#[utoipa::path(
get,
path = "/api/v1/ws/rooms/{room_id}/ai",
responses((status = 200, description = "AI agents in room")),
tag = "channel",
)]
pub async fn ai_list(
req: HttpRequest,
room_id: web::Path<Uuid>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::AiList {
room: room_id.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
post,
path = "/api/v1/ws/rooms/{room_id}/ai",
request_body = AiAddRequest,
responses((status = 201, description = "AI agent added to room")),
tag = "channel",
)]
pub async fn ai_add(
req: HttpRequest,
room_id: web::Path<Uuid>,
body: web::Json<AiAddRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::AiUpsert {
room: room_id.into_inner(),
model: body.agent_session,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(created_json(result))
}
#[utoipa::path(
delete,
path = "/api/v1/ws/rooms/{room_id}/ai/{agent_session_id}",
responses((status = 200, description = "AI agent removed from room")),
tag = "channel",
)]
pub async fn ai_remove(
req: HttpRequest,
path: web::Path<(Uuid, Uuid)>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let (room, agent_id) = path.into_inner();
let msg = WsInMessage::AiDelete { room, agent_id };
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
post,
path = "/api/v1/ws/rooms/{room_id}/ai/stop",
responses((status = 204, description = "AI agent stopped")),
tag = "channel",
)]
pub async fn ai_stop(
req: HttpRequest,
room_id: web::Path<Uuid>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::AiStop {
room: room_id.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
get,
path = "/api/v1/ws/users/summary/{username}",
responses((status = 200, description = "User summary")),
tag = "channel",
)]
pub async fn user_summary(
req: HttpRequest,
username: web::Path<String>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::UserSummary {
username: username.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}

View File

@ -360,7 +360,10 @@ pub async fn list_workspace_members(
let _user_id = extract_user(&req)?; let _user_id = extract_user(&req)?;
let workspace = workspace_id.into_inner(); let workspace = workspace_id.into_inner();
let members = bus.list_workspace_members(workspace).await.map_err(channel_err)?; let members = bus
.list_workspace_members(workspace)
.await
.map_err(channel_err)?;
let result: Vec<RoomMember> = members let result: Vec<RoomMember> = members
.into_iter() .into_iter()
.map(|(id, username, display_name, avatar_url)| RoomMember { .map(|(id, username, display_name, avatar_url)| RoomMember {

View File

@ -13,6 +13,7 @@ pub struct RoomCreateRequest {
pub room_name: String, pub room_name: String,
pub public: bool, pub public: bool,
pub category: Option<Uuid>, pub category: Option<Uuid>,
pub ai_enabled: Option<bool>,
} }
#[derive(Debug, Deserialize, utoipa::ToSchema)] #[derive(Debug, Deserialize, utoipa::ToSchema)]
@ -20,6 +21,7 @@ pub struct RoomUpdateRequest {
pub room_name: Option<String>, pub room_name: Option<String>,
pub public: Option<bool>, pub public: Option<bool>,
pub category: Option<Uuid>, pub category: Option<Uuid>,
pub ai_enabled: Option<bool>,
} }
#[derive(Debug, Deserialize, utoipa::ToSchema)] #[derive(Debug, Deserialize, utoipa::ToSchema)]
@ -49,10 +51,9 @@ pub async fn list_rooms(
bus: web::Data<ChannelBus>, bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?; let user_id = extract_user(&req)?;
let rooms = bus.list_user_rooms(user_id) let rooms = bus.list_user_rooms(user_id).await.map_err(channel_err)?;
.await let categories = bus
.map_err(channel_err)?; .list_user_categories(user_id)
let categories = bus.list_user_categories(user_id)
.await .await
.map_err(channel_err)?; .map_err(channel_err)?;
let workspace_id = if let Some(r) = rooms.first() { let workspace_id = if let Some(r) = rooms.first() {
@ -148,6 +149,7 @@ pub async fn room_create(
room_name: body.room_name.clone(), room_name: body.room_name.clone(),
public: body.public, public: body.public,
category: body.category, category: body.category,
ai_enabled: body.ai_enabled,
}; };
let result = WsHandler::handle(&bus, user_id, msg) let result = WsHandler::handle(&bus, user_id, msg)
.await .await
@ -174,6 +176,7 @@ pub async fn room_update(
room_name: body.room_name.clone(), room_name: body.room_name.clone(),
public: body.public, public: body.public,
category: body.category, category: body.category,
ai_enabled: body.ai_enabled,
}; };
let result = WsHandler::handle(&bus, user_id, msg) let result = WsHandler::handle(&bus, user_id, msg)
.await .await

View File

@ -0,0 +1,27 @@
use actix_web::{HttpRequest, HttpResponse, web};
use channel::ChannelBus;
use channel::http::{WsHandler, WsInMessage};
use super::rest::{channel_err, extract_user, ok_json};
use crate::error::ApiError;
#[utoipa::path(
get,
path = "/api/v1/ws/users/summary/{username}",
responses((status = 200, description = "User summary")),
tag = "channel",
)]
pub async fn user_summary(
req: HttpRequest,
username: web::Path<String>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let _user_id = extract_user(&req)?;
let msg = WsInMessage::UserSummary {
username: username.into_inner(),
};
let result = WsHandler::handle(&bus, _user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}

View File

@ -41,7 +41,11 @@ pub async fn archive(
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let WkRepoPath { wk, repo } = path.into_inner(); let WkRepoPath { wk, repo } = path.into_inner();
match query.format.as_str() { match query.format.as_str() {
"zip" => ok_json(service.git_archive_zip(&session, &wk, &repo, None).await?), "zip" => {
_ => ok_json(service.git_archive_tar(&session, &wk, &repo, None).await?), ok_json(service.git_archive_zip(&session, &wk, &repo, None).await?)
}
_ => {
ok_json(service.git_archive_tar(&session, &wk, &repo, None).await?)
}
} }
} }

View File

@ -41,8 +41,13 @@ pub async fn blame_file(
(Some(start), Some(end)) => { (Some(start), Some(end)) => {
let data: dto::BlameFileResponseDto = service let data: dto::BlameFileResponseDto = service
.git_blame_hunk( .git_blame_hunk(
&session, &wk, &repo, query.path.clone(), &session,
query.rev.clone(), start, end, &wk,
&repo,
query.path.clone(),
query.rev.clone(),
start,
end,
) )
.await? .await?
.into(); .into();
@ -51,8 +56,12 @@ pub async fn blame_file(
_ => { _ => {
let data: dto::BlameFileResponseDto = service let data: dto::BlameFileResponseDto = service
.git_blame_file( .git_blame_file(
&session, &wk, &repo, query.path.clone(), &session,
query.rev.clone(), None, &wk,
&repo,
query.path.clone(),
query.rev.clone(),
None,
) )
.await? .await?
.into(); .into();

View File

@ -65,10 +65,8 @@ pub async fn list_branches(
return ok_json(data); return ok_json(data);
} }
if query.default_only { if query.default_only {
let data: dto::BranchHeadResponseDto = service let data: dto::BranchHeadResponseDto =
.git_branch_head(&session, &wk, &repo) service.git_branch_head(&session, &wk, &repo).await?.into();
.await?
.into();
return ok_json(data); return ok_json(data);
} }
let data: dto::BranchListResponseDto = service let data: dto::BranchListResponseDto = service
@ -180,7 +178,13 @@ pub async fn ahead_behind(
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let WkRepoBranchPath { wk, repo, name } = path.into_inner(); let WkRepoBranchPath { wk, repo, name } = path.into_inner();
let data: dto::BranchAheadBehindResponseDto = service let data: dto::BranchAheadBehindResponseDto = service
.git_branch_ahead_behind(&session, &wk, &repo, name, query.remote_branch.clone()) .git_branch_ahead_behind(
&session,
&wk,
&repo,
name,
query.remote_branch.clone(),
)
.await? .await?
.into(); .into();
ok_json(data) ok_json(data)

View File

@ -64,10 +64,8 @@ pub async fn list_commits(
return ok_json(data); return ok_json(data);
} }
if query.refs { if query.refs {
let data: dto::CommitRefsResponseDto = service let data: dto::CommitRefsResponseDto =
.git_commit_refs(&session, &wk, &repo) service.git_commit_refs(&session, &wk, &repo).await?.into();
.await?
.into();
return ok_json(data); return ok_json(data);
} }
if query.summary { if query.summary {
@ -98,7 +96,9 @@ pub async fn commit_history(
let WkRepoPath { wk, repo } = path.into_inner(); let WkRepoPath { wk, repo } = path.into_inner();
let data: dto::CommitHistoryResponseDto = service let data: dto::CommitHistoryResponseDto = service
.git_commit_history( .git_commit_history(
&session, &wk, &repo, &session,
&wk,
&repo,
query.limit.unwrap_or(20), query.limit.unwrap_or(20),
query.skip.unwrap_or(0), query.skip.unwrap_or(0),
query.sort.unwrap_or(0), query.sort.unwrap_or(0),

View File

@ -40,9 +40,13 @@ pub async fn list_statuses(
service: web::Data<AppService>, service: web::Data<AppService>,
path: web::Path<CommitShaPath>, path: web::Path<CommitShaPath>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
ok_json(service.git_commit_status_list_by_name( ok_json(
&session, &path.wk, &path.repo, &path.sha, service
).await?) .git_commit_status_list_by_name(
&session, &path.wk, &path.repo, &path.sha,
)
.await?,
)
} }
#[utoipa::path( #[utoipa::path(
@ -56,9 +60,13 @@ pub async fn combined_status(
service: web::Data<AppService>, service: web::Data<AppService>,
path: web::Path<CommitShaPath>, path: web::Path<CommitShaPath>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
ok_json(service.git_commit_status_combined_by_name( ok_json(
&session, &path.wk, &path.repo, &path.sha, service
).await?) .git_commit_status_combined_by_name(
&session, &path.wk, &path.repo, &path.sha,
)
.await?,
)
} }
#[utoipa::path( #[utoipa::path(
@ -74,8 +82,19 @@ pub async fn create_status(
path: web::Path<CommitShaPath>, path: web::Path<CommitShaPath>,
body: web::Json<CreateCommitStatus>, body: web::Json<CreateCommitStatus>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
ok_created(service.git_commit_status_create_by_name( .user()
&session, user_id, &path.wk, &path.repo, &path.sha, body.into_inner(), .ok_or(ApiError(service::error::AppError::Unauthorized))?;
).await?) ok_created(
service
.git_commit_status_create_by_name(
&session,
user_id,
&path.wk,
&path.repo,
&path.sha,
body.into_inner(),
)
.await?,
)
} }

View File

@ -33,11 +33,15 @@ pub async fn compare(
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let (wk, repo_name, basehead) = path.into_inner(); let (wk, repo_name, basehead) = path.into_inner();
let (base, head) = basehead let (base, head) = basehead.split_once("...").ok_or_else(|| {
.split_once("...") ApiError(service::error::AppError::BadRequest(
.ok_or_else(|| ApiError(service::error::AppError::BadRequest(
"basehead must be in format 'base...head'".to_string(), "basehead must be in format 'base...head'".to_string(),
)))?; ))
})?;
ok_json(service.git_compare(&session, &wk, &repo_name, base, head).await?) ok_json(
service
.git_compare(&session, &wk, &repo_name, base, head)
.await?,
)
} }

View File

@ -34,9 +34,17 @@ pub async fn get_contents(
query: web::Query<ContentQuery>, query: web::Query<ContentQuery>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let (wk, repo_name, file_path) = info.into_inner(); let (wk, repo_name, file_path) = info.into_inner();
ok_json(service.git_contents_get_by_name( ok_json(
&session, &wk, &repo_name, &file_path, query.r#ref.as_deref(), service
).await?) .git_contents_get_by_name(
&session,
&wk,
&repo_name,
&file_path,
query.r#ref.as_deref(),
)
.await?,
)
} }
#[utoipa::path( #[utoipa::path(
@ -53,9 +61,15 @@ pub async fn create_contents(
body: web::Json<CreateContent>, body: web::Json<CreateContent>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let (wk, repo_name, file_path) = info.into_inner(); let (wk, repo_name, file_path) = info.into_inner();
let resp = service.git_contents_create_by_name( let resp = service
&session, &wk, &repo_name, &file_path, body.into_inner(), .git_contents_create_by_name(
).await?; &session,
&wk,
&repo_name,
&file_path,
body.into_inner(),
)
.await?;
Ok(HttpResponse::Created().json(resp)) Ok(HttpResponse::Created().json(resp))
} }
@ -73,9 +87,17 @@ pub async fn update_contents(
body: web::Json<UpdateContent>, body: web::Json<UpdateContent>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let (wk, repo_name, file_path) = info.into_inner(); let (wk, repo_name, file_path) = info.into_inner();
ok_json(service.git_contents_update_by_name( ok_json(
&session, &wk, &repo_name, &file_path, body.into_inner(), service
).await?) .git_contents_update_by_name(
&session,
&wk,
&repo_name,
&file_path,
body.into_inner(),
)
.await?,
)
} }
#[derive(Deserialize, utoipa::IntoParams)] #[derive(Deserialize, utoipa::IntoParams)]
@ -98,8 +120,16 @@ pub async fn delete_contents(
query: web::Query<DeleteContentQuery>, query: web::Query<DeleteContentQuery>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let (wk, repo_name, file_path) = info.into_inner(); let (wk, repo_name, file_path) = info.into_inner();
service.git_contents_delete_by_name( service
&session, &wk, &repo_name, &file_path, &query.message, &query.sha, query.branch.as_deref(), .git_contents_delete_by_name(
).await?; &session,
&wk,
&repo_name,
&file_path,
&query.message,
&query.sha,
query.branch.as_deref(),
)
.await?;
Ok(HttpResponse::NoContent().finish()) Ok(HttpResponse::NoContent().finish())
} }

View File

@ -44,18 +44,34 @@ pub async fn diff(
query: web::Query<DiffQuery>, query: web::Query<DiffQuery>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let WkRepoPath { wk, repo } = path.into_inner(); let WkRepoPath { wk, repo } = path.into_inner();
if let (Some(old_tree), Some(new_tree)) = (&query.old_tree, &query.new_tree) { if let (Some(old_tree), Some(new_tree)) = (&query.old_tree, &query.new_tree)
{
let proto_resp = service let proto_resp = service
.git_diff_tree_to_tree(&session, &wk, &repo, old_tree.clone(), new_tree.clone(), None) .git_diff_tree_to_tree(
&session,
&wk,
&repo,
old_tree.clone(),
new_tree.clone(),
None,
)
.await?; .await?;
let data: dto::DiffResultDto = proto_resp.result.unwrap_or_default().into(); let data: dto::DiffResultDto =
proto_resp.result.unwrap_or_default().into();
return ok_json(data); return ok_json(data);
} }
if let Some(tree_oid) = &query.tree_oid { if let Some(tree_oid) = &query.tree_oid {
let proto_resp = service let proto_resp = service
.git_diff_index_to_tree(&session, &wk, &repo, tree_oid.clone(), None) .git_diff_index_to_tree(
&session,
&wk,
&repo,
tree_oid.clone(),
None,
)
.await?; .await?;
let data: dto::DiffResultDto = proto_resp.result.unwrap_or_default().into(); let data: dto::DiffResultDto =
proto_resp.result.unwrap_or_default().into();
return ok_json(data); return ok_json(data);
} }
let old_oid = query.old_oid.clone().unwrap_or_default(); let old_oid = query.old_oid.clone().unwrap_or_default();
@ -66,21 +82,29 @@ pub async fn diff(
let proto_resp = service let proto_resp = service
.git_diff_stats(&session, &wk, &repo, old_oid, new_oid, None) .git_diff_stats(&session, &wk, &repo, old_oid, new_oid, None)
.await?; .await?;
let data: dto::DiffStatsDto = proto_resp.result.and_then(|r| r.stats).unwrap_or_default().into(); let data: dto::DiffStatsDto = proto_resp
.result
.and_then(|r| r.stats)
.unwrap_or_default()
.into();
ok_json(data) ok_json(data)
} }
"side-by-side" => { "side-by-side" => {
let proto_resp = service let proto_resp = service
.git_diff_patch_side_by_side(&session, &wk, &repo, old_oid, new_oid, None) .git_diff_patch_side_by_side(
&session, &wk, &repo, old_oid, new_oid, None,
)
.await?; .await?;
let data: dto::SideBySideDiffResultDto = proto_resp.result.unwrap_or_default().into(); let data: dto::SideBySideDiffResultDto =
proto_resp.result.unwrap_or_default().into();
ok_json(data) ok_json(data)
} }
_ => { _ => {
let proto_resp = service let proto_resp = service
.git_diff_patch(&session, &wk, &repo, old_oid, new_oid, None) .git_diff_patch(&session, &wk, &repo, old_oid, new_oid, None)
.await?; .await?;
let data: dto::DiffResultDto = proto_resp.result.unwrap_or_default().into(); let data: dto::DiffResultDto =
proto_resp.result.unwrap_or_default().into();
ok_json(data) ok_json(data)
} }
} }

View File

@ -1,7 +1,7 @@
use base64::Engine; use base64::Engine;
use git::rpc::proto as p;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use utoipa::ToSchema; use utoipa::ToSchema;
use git::rpc::proto as p;
fn oid_val(oid: Option<p::ObjectId>) -> String { fn oid_val(oid: Option<p::ObjectId>) -> String {
oid.map(|o| o.value).unwrap_or_default() oid.map(|o| o.value).unwrap_or_default()
@ -430,7 +430,9 @@ impl From<p::BranchSummaryResponse> for BranchSummaryResponseDto {
impl From<p::BranchHeadResponse> for BranchHeadResponseDto { impl From<p::BranchHeadResponse> for BranchHeadResponseDto {
fn from(r: p::BranchHeadResponse) -> Self { fn from(r: p::BranchHeadResponse) -> Self {
BranchHeadResponseDto { head_name: r.head_name } BranchHeadResponseDto {
head_name: r.head_name,
}
} }
} }
@ -531,7 +533,9 @@ impl From<p::CommitRefsResponse> for CommitRefsResponseDto {
impl From<p::CommitPrefixResponse> for CommitPrefixResponseDto { impl From<p::CommitPrefixResponse> for CommitPrefixResponseDto {
fn from(r: p::CommitPrefixResponse) -> Self { fn from(r: p::CommitPrefixResponse) -> Self {
CommitPrefixResponseDto { oid: oid_opt(r.oid) } CommitPrefixResponseDto {
oid: oid_opt(r.oid),
}
} }
} }
@ -543,13 +547,17 @@ impl From<p::CommitExistsResponse> for CommitExistsResponseDto {
impl From<p::CherryPickResponse> for CherryPickResponseDto { impl From<p::CherryPickResponse> for CherryPickResponseDto {
fn from(r: p::CherryPickResponse) -> Self { fn from(r: p::CherryPickResponse) -> Self {
CherryPickResponseDto { oid: oid_opt(r.oid) } CherryPickResponseDto {
oid: oid_opt(r.oid),
}
} }
} }
impl From<p::CherryPickSequenceResponse> for CherryPickResponseDto { impl From<p::CherryPickSequenceResponse> for CherryPickResponseDto {
fn from(r: p::CherryPickSequenceResponse) -> Self { fn from(r: p::CherryPickSequenceResponse) -> Self {
CherryPickResponseDto { oid: oid_opt(r.oid) } CherryPickResponseDto {
oid: oid_opt(r.oid),
}
} }
} }
@ -646,7 +654,9 @@ impl From<p::BlobExistsResponse> for BlobExistsResponseDto {
impl From<p::BlobIsBinaryResponse> for BlobIsBinaryResponseDto { impl From<p::BlobIsBinaryResponse> for BlobIsBinaryResponseDto {
fn from(r: p::BlobIsBinaryResponse) -> Self { fn from(r: p::BlobIsBinaryResponse) -> Self {
BlobIsBinaryResponseDto { is_binary: r.is_binary } BlobIsBinaryResponseDto {
is_binary: r.is_binary,
}
} }
} }
@ -680,31 +690,41 @@ impl From<p::TagListResponse> for TagListResponseDto {
impl From<p::TagInfoResponse> for TagInfoResponseDto { impl From<p::TagInfoResponse> for TagInfoResponseDto {
fn from(r: p::TagInfoResponse) -> Self { fn from(r: p::TagInfoResponse) -> Self {
TagInfoResponseDto { tag: r.tag.map(Into::into) } TagInfoResponseDto {
tag: r.tag.map(Into::into),
}
} }
} }
impl From<p::TagSummary> for TagSummaryDto { impl From<p::TagSummary> for TagSummaryDto {
fn from(s: p::TagSummary) -> Self { fn from(s: p::TagSummary) -> Self {
TagSummaryDto { total_count: s.total_count } TagSummaryDto {
total_count: s.total_count,
}
} }
} }
impl From<p::TagSummaryResponse> for TagSummaryResponseDto { impl From<p::TagSummaryResponse> for TagSummaryResponseDto {
fn from(r: p::TagSummaryResponse) -> Self { fn from(r: p::TagSummaryResponse) -> Self {
TagSummaryResponseDto { summary: r.summary.map(Into::into) } TagSummaryResponseDto {
summary: r.summary.map(Into::into),
}
} }
} }
impl From<p::TagInitResponse> for TagInitResponseDto { impl From<p::TagInitResponse> for TagInitResponseDto {
fn from(r: p::TagInitResponse) -> Self { fn from(r: p::TagInitResponse) -> Self {
TagInitResponseDto { oid: oid_opt(r.oid) } TagInitResponseDto {
oid: oid_opt(r.oid),
}
} }
} }
impl From<p::TagUpdateMessageResponse> for TagUpdateMessageResponseDto { impl From<p::TagUpdateMessageResponse> for TagUpdateMessageResponseDto {
fn from(r: p::TagUpdateMessageResponse) -> Self { fn from(r: p::TagUpdateMessageResponse) -> Self {
TagUpdateMessageResponseDto { oid: oid_opt(r.oid) } TagUpdateMessageResponseDto {
oid: oid_opt(r.oid),
}
} }
} }
@ -843,10 +863,16 @@ impl From<p::DiffResult> for DiffResultDto {
impl From<p::SideBySideChangeType> for SideBySideChangeTypeDto { impl From<p::SideBySideChangeType> for SideBySideChangeTypeDto {
fn from(t: p::SideBySideChangeType) -> Self { fn from(t: p::SideBySideChangeType) -> Self {
match t { match t {
p::SideBySideChangeType::Unchanged => SideBySideChangeTypeDto::Unchanged, p::SideBySideChangeType::Unchanged => {
SideBySideChangeTypeDto::Unchanged
}
p::SideBySideChangeType::Added => SideBySideChangeTypeDto::Added, p::SideBySideChangeType::Added => SideBySideChangeTypeDto::Added,
p::SideBySideChangeType::Removed => SideBySideChangeTypeDto::Removed, p::SideBySideChangeType::Removed => {
p::SideBySideChangeType::Modified => SideBySideChangeTypeDto::Modified, SideBySideChangeTypeDto::Removed
}
p::SideBySideChangeType::Modified => {
SideBySideChangeTypeDto::Modified
}
p::SideBySideChangeType::Empty => SideBySideChangeTypeDto::Empty, p::SideBySideChangeType::Empty => SideBySideChangeTypeDto::Empty,
} }
} }

View File

@ -1,6 +1,9 @@
use actix_web::{HttpResponse, web}; use actix_web::{HttpResponse, web};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use service::{AppService, git::init::{CloneRepo, CreateRepo}}; use service::{
AppService,
git::init::{CloneRepo, CreateRepo},
};
use session::Session; use session::Session;
use crate::error::ApiError; use crate::error::ApiError;

View File

@ -31,8 +31,7 @@ pub fn configure(cfg: &mut ServiceConfig) {
.route(web::get().to(repo::list_repos)), .route(web::get().to(repo::list_repos)),
); );
cfg.service( cfg.service(
web::resource("/clone") web::resource("/clone").route(web::post().to(init::clone_repo)),
.route(web::post().to(init::clone_repo)),
); );
cfg.service( cfg.service(
web::resource("/{repo}") web::resource("/{repo}")
@ -132,8 +131,7 @@ pub fn configure(cfg: &mut ServiceConfig) {
.route(web::get().to(blob::blob_info)), .route(web::get().to(blob::blob_info)),
) )
.service( .service(
web::resource("/blame") web::resource("/blame").route(web::get().to(blame::blame_file)),
.route(web::get().to(blame::blame_file)),
) )
.service( .service(
web::resource("/trees/{oid}") web::resource("/trees/{oid}")
@ -147,10 +145,7 @@ pub fn configure(cfg: &mut ServiceConfig) {
web::resource("/commits/{oid}/tree") web::resource("/commits/{oid}/tree")
.route(web::get().to(tree::tree_entry_by_path_from_commit)), .route(web::get().to(tree::tree_entry_by_path_from_commit)),
) )
.service( .service(web::resource("/diff").route(web::get().to(diff::diff)))
web::resource("/diff")
.route(web::get().to(diff::diff)),
)
.service( .service(
web::resource("/diff/branches") web::resource("/diff/branches")
.route(web::get().to(readme::diff_branches)), .route(web::get().to(readme::diff_branches)),
@ -195,8 +190,7 @@ pub fn configure(cfg: &mut ServiceConfig) {
.route(web::get().to(readme::get_readme)), .route(web::get().to(readme::get_readme)),
) )
.service( .service(
web::resource("/refs") web::resource("/refs").route(web::get().to(refs::list_refs)),
.route(web::get().to(refs::list_refs)),
), ),
); );
cfg.service( cfg.service(

View File

@ -34,8 +34,14 @@ pub async fn list_refs(
query: web::Query<RefQuery>, query: web::Query<RefQuery>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
if let Some(ref_name) = &query.r#ref { if let Some(ref_name) = &query.r#ref {
let r = service.git_ref_get_by_name(&session, &path.wk, &path.repo, ref_name).await?; let r = service
.git_ref_get_by_name(&session, &path.wk, &path.repo, ref_name)
.await?;
return ok_json(vec![r]); return ok_json(vec![r]);
} }
ok_json(service.git_ref_list_by_name(&session, &path.wk, &path.repo).await?) ok_json(
service
.git_ref_list_by_name(&session, &path.wk, &path.repo)
.await?,
)
} }

View File

@ -43,8 +43,14 @@ pub async fn list_releases(
service: web::Data<AppService>, service: web::Data<AppService>,
path: web::Path<WkRepoPath>, path: web::Path<WkRepoPath>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
ok_json(service.git_release_list_by_name(&session, user_id, &path.wk, &path.repo).await?) .user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(
service
.git_release_list_by_name(&session, user_id, &path.wk, &path.repo)
.await?,
)
} }
#[utoipa::path( #[utoipa::path(
@ -58,8 +64,16 @@ pub async fn get_release(
service: web::Data<AppService>, service: web::Data<AppService>,
path: web::Path<ReleaseIdPath>, path: web::Path<ReleaseIdPath>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
ok_json(service.git_release_get_by_name(&session, user_id, &path.wk, &path.repo, path.id).await?) .user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(
service
.git_release_get_by_name(
&session, user_id, &path.wk, &path.repo, path.id,
)
.await?,
)
} }
#[utoipa::path( #[utoipa::path(
@ -74,8 +88,16 @@ pub async fn get_release_by_tag(
path: web::Path<(String, String, String)>, path: web::Path<(String, String, String)>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let (wk, repo_name, tag) = path.into_inner(); let (wk, repo_name, tag) = path.into_inner();
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
ok_json(service.git_release_get_by_tag_name(&session, user_id, &wk, &repo_name, &tag).await?) .user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(
service
.git_release_get_by_tag_name(
&session, user_id, &wk, &repo_name, &tag,
)
.await?,
)
} }
#[utoipa::path( #[utoipa::path(
@ -91,8 +113,20 @@ pub async fn create_release(
path: web::Path<WkRepoPath>, path: web::Path<WkRepoPath>,
body: web::Json<CreateRelease>, body: web::Json<CreateRelease>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
ok_created(service.git_release_create_by_name(&session, user_id, &path.wk, &path.repo, body.into_inner()).await?) .user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_created(
service
.git_release_create_by_name(
&session,
user_id,
&path.wk,
&path.repo,
body.into_inner(),
)
.await?,
)
} }
#[utoipa::path( #[utoipa::path(
@ -108,8 +142,21 @@ pub async fn update_release(
path: web::Path<ReleaseIdPath>, path: web::Path<ReleaseIdPath>,
body: web::Json<UpdateRelease>, body: web::Json<UpdateRelease>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
ok_json(service.git_release_update_by_name(&session, user_id, &path.wk, &path.repo, path.id, body.into_inner()).await?) .user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(
service
.git_release_update_by_name(
&session,
user_id,
&path.wk,
&path.repo,
path.id,
body.into_inner(),
)
.await?,
)
} }
#[utoipa::path( #[utoipa::path(
@ -123,8 +170,14 @@ pub async fn delete_release(
service: web::Data<AppService>, service: web::Data<AppService>,
path: web::Path<ReleaseIdPath>, path: web::Path<ReleaseIdPath>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
service.git_release_delete_by_name(&session, user_id, &path.wk, &path.repo, path.id).await?; .user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
service
.git_release_delete_by_name(
&session, user_id, &path.wk, &path.repo, path.id,
)
.await?;
ok_empty() ok_empty()
} }
@ -140,7 +193,13 @@ pub async fn delete_release_by_tag(
path: web::Path<(String, String, String)>, path: web::Path<(String, String, String)>,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let (wk, repo_name, tag) = path.into_inner(); let (wk, repo_name, tag) = path.into_inner();
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; let user_id = session
service.git_release_delete_by_tag_name(&session, user_id, &wk, &repo_name, &tag).await?; .user()
.ok_or(ApiError(service::error::AppError::Unauthorized))?;
service
.git_release_delete_by_tag_name(
&session, user_id, &wk, &repo_name, &tag,
)
.await?;
ok_empty() ok_empty()
} }

View File

@ -54,10 +54,8 @@ pub async fn list_tags(
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let WkRepoPath { wk, repo } = path.into_inner(); let WkRepoPath { wk, repo } = path.into_inner();
if query.summary { if query.summary {
let data: dto::TagSummaryResponseDto = service let data: dto::TagSummaryResponseDto =
.git_tag_summary(&session, &wk, &repo) service.git_tag_summary(&session, &wk, &repo).await?.into();
.await?
.into();
return ok_json(data); return ok_json(data);
} }
let data: dto::TagListResponseDto = service let data: dto::TagListResponseDto = service
@ -117,9 +115,7 @@ pub async fn delete_tag(
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let WkRepoTagPath { wk, repo, name } = path.into_inner(); let WkRepoTagPath { wk, repo, name } = path.into_inner();
let params = git::rpc::proto::TagDeleteParams { name }; let params = git::rpc::proto::TagDeleteParams { name };
let _ = service let _ = service.git_tag_delete(&session, &wk, &repo, params).await?;
.git_tag_delete(&session, &wk, &repo, params)
.await?;
ok_json(serde_json::json!({})) ok_json(serde_json::json!({}))
} }
#[utoipa::path( #[utoipa::path(
@ -146,9 +142,7 @@ pub async fn update_tag(
new_name: new_name.to_string(), new_name: new_name.to_string(),
force: false, force: false,
}; };
let _ = service let _ = service.git_tag_rename(&session, &wk, &repo, params).await?;
.git_tag_rename(&session, &wk, &repo, params)
.await?;
return ok_json(serde_json::json!({})); return ok_json(serde_json::json!({}));
} }
let message = body let message = body

View File

@ -86,7 +86,13 @@ pub async fn tree_entry_by_path(
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let WkRepoTreeSubPath { wk, repo, tree_oid } = path.into_inner(); let WkRepoTreeSubPath { wk, repo, tree_oid } = path.into_inner();
let data: dto::TreeEntryByPathResponseDto = service let data: dto::TreeEntryByPathResponseDto = service
.git_tree_entry_by_path(&session, &wk, &repo, tree_oid, query.path.clone()) .git_tree_entry_by_path(
&session,
&wk,
&repo,
tree_oid,
query.path.clone(),
)
.await? .await?
.into(); .into();
ok_json(data) ok_json(data)
@ -112,7 +118,13 @@ pub async fn tree_entry_by_path_from_commit(
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let WkRepoCommitPath { wk, repo, oid } = path.into_inner(); let WkRepoCommitPath { wk, repo, oid } = path.into_inner();
let data: dto::TreeEntryByPathResponseDto = service let data: dto::TreeEntryByPathResponseDto = service
.git_tree_entry_by_path_from_commit(&session, &wk, &repo, oid, query.path.clone()) .git_tree_entry_by_path_from_commit(
&session,
&wk,
&repo,
oid,
query.path.clone(),
)
.await? .await?
.into(); .into();
ok_json(data) ok_json(data)

View File

@ -31,29 +31,28 @@ pub fn configure(cfg: &mut ServiceConfig, channel_bus: channel::ChannelBus) {
.service( .service(
web::scope("/repos") web::scope("/repos")
.configure(git::configure) .configure(git::configure)
.configure(pull_request::configure) .configure(pull_request::configure),
) )
.service( .service(
web::scope("/issues") web::scope("/issues")
.configure(issues::configure) .configure(issues::configure),
) )
.service( .service(
web::scope("/labels") web::scope("/labels")
.configure(issues::configure_labels) .configure(issues::configure_labels),
) )
.service( .service(
web::scope("/milestones") web::scope("/milestones")
.configure(issues::configure_milestones) .configure(issues::configure_milestones),
) ),
) ),
) )
.service( .service(
web::scope("/ws") web::scope("/ws")
.configure(|cfg| channel::configure(cfg, channel_bus)), .configure(|cfg| channel::configure(cfg, channel_bus)),
) )
.service( .service(
web::resource("/search") web::resource("/search").route(web::get().to(search::search)),
.route(web::get().to(search::search)), ),
)
); );
} }

View File

@ -297,11 +297,7 @@ use utoipa::openapi::security::{
crate::channel::rest_voice::voice_mute, crate::channel::rest_voice::voice_mute,
crate::channel::rest_voice::voice_deaf, crate::channel::rest_voice::voice_deaf,
crate::channel::rest_voice::screen_share, crate::channel::rest_voice::screen_share,
crate::channel::rest_ai::ai_list, crate::channel::rest_user::user_summary,
crate::channel::rest_ai::ai_add,
crate::channel::rest_ai::ai_remove,
crate::channel::rest_ai::ai_stop,
crate::channel::rest_ai::user_summary,
crate::search::search, crate::search::search,
), ),
modifiers(&SecurityAddon) modifiers(&SecurityAddon)

View File

@ -2,8 +2,8 @@ use actix_web::{HttpResponse, web};
use serde::Serialize; use serde::Serialize;
use utoipa::ToSchema; use utoipa::ToSchema;
use crate::error::ApiError;
use crate::channel::ChannelBus; use crate::channel::ChannelBus;
use crate::error::ApiError;
use service::AppService; use service::AppService;
use session::Session; use session::Session;
@ -177,14 +177,9 @@ async fn search_rooms(
user_id: uuid::Uuid, user_id: uuid::Uuid,
q: &str, q: &str,
) -> Result<SearchGroup<RoomHit>, ApiError> { ) -> Result<SearchGroup<RoomHit>, ApiError> {
let rooms = bus let rooms = bus.list_user_rooms(user_id).await.map_err(|e| {
.list_user_rooms(user_id) ApiError(service::error::AppError::InternalServerError(e.to_string()))
.await })?;
.map_err(|e| {
ApiError(service::error::AppError::InternalServerError(
e.to_string(),
))
})?;
let all: Vec<RoomHit> = rooms let all: Vec<RoomHit> = rooms
.into_iter() .into_iter()

View File

@ -2,7 +2,9 @@ use actix_web::{HttpRequest, HttpResponse, web};
use serde::Serialize; use serde::Serialize;
use service::{ use service::{
AppService, AppService,
user::profile::{AvatarUploadResponse, UpdateUserProfileConfig, UserProfileConfig}, user::profile::{
AvatarUploadResponse, UpdateUserProfileConfig, UserProfileConfig,
},
}; };
use session::Session; use session::Session;

View File

@ -106,12 +106,7 @@ pub async fn update_group(
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let GroupPath { wk, group_name } = path.into_inner(); let GroupPath { wk, group_name } = path.into_inner();
let data = service let data = service
.workspace_update_group( .workspace_update_group(&session, &wk, &group_name, params.into_inner())
&session,
&wk,
&group_name,
params.into_inner(),
)
.await?; .await?;
ok_json(data) ok_json(data)
} }

View File

@ -102,12 +102,7 @@ pub async fn update_member(
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let MemberPath { wk, username } = path.into_inner(); let MemberPath { wk, username } = path.into_inner();
let data = service let data = service
.workspace_update_member( .workspace_update_member(&session, &wk, &username, params.into_inner())
&session,
&wk,
&username,
params.into_inner(),
)
.await?; .await?;
ok_json(data) ok_json(data)
} }

View File

@ -6,12 +6,10 @@ pub mod workspace;
use actix_web::{web, web::ServiceConfig}; use actix_web::{web, web::ServiceConfig};
pub fn configure(cfg: &mut ServiceConfig) { pub fn configure(cfg: &mut ServiceConfig) {
cfg.service( cfg.service(
web::resource("") web::resource("").route(web::post().to(workspace::create_workspace)),
.route(web::post().to(workspace::create_workspace)),
); );
cfg.service( cfg.service(
web::resource("/my") web::resource("/my").route(web::get().to(workspace::my_workspaces)),
.route(web::get().to(workspace::my_workspaces)),
); );
cfg.service( cfg.service(
web::resource("/join/my-applies") web::resource("/join/my-applies")
@ -64,12 +62,10 @@ pub fn configure_wk(cfg: &mut ServiceConfig) {
.route(web::put().to(join::update_join_strategy)), .route(web::put().to(join::update_join_strategy)),
); );
cfg.service( cfg.service(
web::resource("/join/apply") web::resource("/join/apply").route(web::post().to(join::apply_join)),
.route(web::post().to(join::apply_join)),
); );
cfg.service( cfg.service(
web::resource("/join/cancel") web::resource("/join/cancel").route(web::post().to(join::cancel_join)),
.route(web::post().to(join::cancel_join)),
); );
cfg.service( cfg.service(
web::resource("/join/applies") web::resource("/join/applies")

2
lib/cache/local.rs vendored
View File

@ -26,7 +26,7 @@ impl Default for LocalCacheConfig {
#[derive(Clone)] #[derive(Clone)]
pub struct MokaCache { pub struct MokaCache {
pub(crate) inner: Cache<Arc<str>, Arc<[u8]>>, pub inner: Cache<Arc<str>, Arc<[u8]>>,
} }
impl MokaCache { impl MokaCache {

View File

@ -35,5 +35,6 @@ tokio = { workspace = true, features = ["sync", "time"] }
tokio-util = { workspace = true } tokio-util = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
uuid = { workspace = true, features = ["serde", "v7"] } uuid = { workspace = true, features = ["serde", "v7"] }
lazy_static = "1.5.0"
[lints] [lints]
workspace = true workspace = true

View File

@ -33,24 +33,31 @@ const ROOM_MESSAGE_EVENT: &str = "room.message";
#[derive(Clone)] #[derive(Clone)]
pub struct ChannelBus { pub struct ChannelBus {
pub(crate) inner: Arc<Inner>, pub inner: Arc<Inner>,
} }
pub(crate) struct Inner { pub struct Inner {
pub(crate) db: AppDatabase, pub db: AppDatabase,
pub(crate) cache: AppCache, pub cache: AppCache,
pub(crate) io: SocketIo, pub io: SocketIo,
pub(crate) config: ChannelBusConfig, pub config: ChannelBusConfig,
pub(crate) online: RwLock<HashMap<Uuid, HashMap<String, Socket>>>, pub online: RwLock<HashMap<Uuid, HashMap<String, Socket>>>,
pub(crate) user_sync_locks: DashMap<Uuid, Arc<Mutex<()>>>, pub user_sync_locks: DashMap<Uuid, Arc<Mutex<()>>>,
pub(crate) typing_states: DashMap<(Uuid, Uuid), (crate::event::UserInfo, crate::event::RoomInfo, tokio_util::sync::CancellationToken)>, pub typing_states: DashMap<
pub(crate) seq: SeqAllocator, (Uuid, Uuid),
pub(crate) dedup: DeduplicationManager, (
pub(crate) metrics: ChannelMetrics, crate::event::UserInfo,
pub(crate) reconnect: ReconnectManager, crate::event::RoomInfo,
pub(crate) rate_limiter: RateLimiter, tokio_util::sync::CancellationToken,
pub(crate) csrf: CsrfProtection, ),
pub(crate) circuit_breaker: CircuitBreaker, >,
pub seq: SeqAllocator,
pub dedup: DeduplicationManager,
pub metrics: ChannelMetrics,
pub reconnect: ReconnectManager,
pub rate_limiter: RateLimiter,
pub csrf: CsrfProtection,
pub circuit_breaker: CircuitBreaker,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -79,7 +86,7 @@ impl ChannelBus {
&self, &self,
room: Uuid, room: Uuid,
) -> ChannelResult<crate::event::RoomInfo> { ) -> ChannelResult<crate::event::RoomInfo> {
let row = db::sqlx::query_as::<_, (String,)>( let row = db::sqlx::query_as::<_, (String,)>(
"SELECT name FROM room WHERE id = $1", "SELECT name FROM room WHERE id = $1",
) )
.bind(room) .bind(room)
@ -132,7 +139,8 @@ impl ChannelBus {
pub async fn lookup_users( pub async fn lookup_users(
&self, &self,
users: &[Uuid], users: &[Uuid],
) -> ChannelResult<std::collections::HashMap<Uuid, crate::event::UserInfo>> { ) -> ChannelResult<std::collections::HashMap<Uuid, crate::event::UserInfo>>
{
if users.is_empty() { if users.is_empty() {
return Ok(std::collections::HashMap::new()); return Ok(std::collections::HashMap::new());
} }
@ -585,7 +593,9 @@ impl ChannelBus {
Err(_) => None, Err(_) => None,
}; };
let event = match sender { let event = match sender {
Some(s) => ChannelEvent::message_created_with_sender(message, s), Some(s) => {
ChannelEvent::message_created_with_sender(message, s)
}
None => ChannelEvent::message_created(message), None => ChannelEvent::message_created(message),
}; };
socket.emit(ROOM_MESSAGE_EVENT, event).await?; socket.emit(ROOM_MESSAGE_EVENT, event).await?;

View File

@ -71,17 +71,15 @@ impl CircuitBreaker {
let slot_reserved = { let slot_reserved = {
let mut state = self.inner.state.lock().await; let mut state = self.inner.state.lock().await;
match state.status { match state.status {
STATUS_OPEN => { STATUS_OPEN => match state.last_failure_time {
match state.last_failure_time { Some(t) if t.elapsed() > self.inner.config.timeout => {
Some(t) if t.elapsed() > self.inner.config.timeout => { state.status = STATUS_HALF_OPEN;
state.status = STATUS_HALF_OPEN; state.half_open_calls = 1;
state.half_open_calls = 1; state.success_count = 0;
state.success_count = 0; true
true
}
_ => false,
} }
} _ => false,
},
STATUS_HALF_OPEN => { STATUS_HALF_OPEN => {
if state.half_open_calls if state.half_open_calls
< self.inner.config.half_open_max_calls < self.inner.config.half_open_max_calls

View File

@ -1,44 +0,0 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::event::{AgentInfo, RoomInfo};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AiAgentJoinedService {
pub room: RoomInfo,
pub agent: AgentInfo,
pub joined_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AiAgentLeftService {
pub room: RoomInfo,
pub agent: AgentInfo,
pub left_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoomAiEntry {
pub agent_session: Uuid,
pub name: String,
pub agent_kind: String,
pub model_version: Option<Uuid>,
pub enabled: bool,
pub auto_reply: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoomAiListService {
pub room: RoomInfo,
pub agents: Vec<RoomAiEntry>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AiAgentStatusChangedService {
pub room: RoomInfo,
pub agent: AgentInfo,
pub old_status: String,
pub new_status: String,
pub changed_at: DateTime<Utc>,
}

View File

@ -1,6 +1,6 @@
use crate::event::{RoomInfo, UserInfo};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use uuid::Uuid; use uuid::Uuid;
use crate::event::{RoomInfo, UserInfo};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};

View File

@ -1,6 +1,6 @@
use crate::event::{UserInfo, WorkspaceInfo};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use uuid::Uuid; use uuid::Uuid;
use crate::event::{UserInfo, WorkspaceInfo};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};

View File

@ -72,21 +72,3 @@ impl WorkspaceInfo {
} }
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentInfo {
pub id: Uuid,
pub name: String,
pub agent_type: String,
pub model_name: Option<String>,
}
impl AgentInfo {
pub fn unknown(id: Uuid) -> Self {
Self {
id,
name: String::new(),
agent_type: String::new(),
model_name: None,
}
}
}

View File

@ -1,66 +0,0 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::event::{RoomInfo, UserInfo};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DmEventType {
Created,
Closed,
Reopened,
}
impl DmEventType {
pub fn as_str(&self) -> &str {
match self {
Self::Created => "dm.created",
Self::Closed => "dm.closed",
Self::Reopened => "dm.reopened",
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum DmEvent {
#[serde(rename = "dm.created")]
Created(DmCreatedService),
#[serde(rename = "dm.closed")]
Closed(DmClosedService),
#[serde(rename = "dm.reopened")]
Reopened(DmReopenedService),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DmCreatedService {
pub room: RoomInfo,
pub initiator: UserInfo,
pub recipient: UserInfo,
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DmClosedService {
pub room: RoomInfo,
pub closed_by: UserInfo,
pub closed_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DmReopenedService {
pub room: RoomInfo,
pub reopened_by: UserInfo,
pub reopened_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DmCreateClient {
pub recipient: Uuid,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DmCloseClient {
pub room: Uuid,
}

View File

@ -1,10 +1,8 @@
pub mod ai;
pub mod attachment; pub mod attachment;
pub mod ban; pub mod ban;
pub mod category; pub mod category;
pub mod common; pub mod common;
pub mod conversation; pub mod conversation;
pub mod dm;
pub mod draft; pub mod draft;
pub mod forward; pub mod forward;
pub mod invite; pub mod invite;
@ -22,7 +20,7 @@ pub mod thread;
pub mod voice; pub mod voice;
pub mod workspace; pub mod workspace;
pub use common::{AgentInfo, RoomInfo, UserInfo, WorkspaceInfo}; pub use common::{RoomInfo, UserInfo, WorkspaceInfo};
use model::room::RoomMessageModel; use model::room::RoomMessageModel;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -37,8 +35,6 @@ pub enum ChannelEventType {
ReactionCreated, ReactionCreated,
ReactionDeleted, ReactionDeleted,
MessageRead, MessageRead,
DmCreated,
DmClosed,
ConversationUpdated, ConversationUpdated,
Custom(String), Custom(String),
} }
@ -52,8 +48,6 @@ impl ChannelEventType {
Self::ReactionCreated => "reaction.created", Self::ReactionCreated => "reaction.created",
Self::ReactionDeleted => "reaction.deleted", Self::ReactionDeleted => "reaction.deleted",
Self::MessageRead => "message.read", Self::MessageRead => "message.read",
Self::DmCreated => "dm.created",
Self::DmClosed => "dm.closed",
Self::ConversationUpdated => "conversation.updated", Self::ConversationUpdated => "conversation.updated",
Self::Custom(value) => value, Self::Custom(value) => value,
} }

View File

@ -1,6 +1,6 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use uuid::Uuid;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::event::{RoomInfo, UserInfo}; use crate::event::{RoomInfo, UserInfo};

View File

@ -1,6 +1,6 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use uuid::Uuid;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::event::{RoomInfo, UserInfo}; use crate::event::{RoomInfo, UserInfo};

View File

@ -13,7 +13,6 @@ pub enum RoomEventType {
TopicUpdated, TopicUpdated,
SettingsUpdated, SettingsUpdated,
Moved, Moved,
AiUpdated,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -27,8 +26,6 @@ pub enum RoomEvent {
Renamed(RoomRenamedService), Renamed(RoomRenamedService),
#[serde(rename = "room.moved")] #[serde(rename = "room.moved")]
Moved(RoomMovedService), Moved(RoomMovedService),
#[serde(rename = "room.ai_updated")]
AiUpdated(RoomAiUpdatedService),
#[serde(rename = "room.topic_updated")] #[serde(rename = "room.topic_updated")]
TopicUpdated(RoomTopicUpdatedService), TopicUpdated(RoomTopicUpdatedService),
#[serde(rename = "room.settings_updated")] #[serde(rename = "room.settings_updated")]
@ -73,18 +70,6 @@ pub struct RoomMovedService {
pub moved_at: DateTime<Utc>, pub moved_at: DateTime<Utc>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoomAiUpdatedService {
pub room: RoomInfo,
pub workspace: WorkspaceInfo,
pub model: Uuid,
pub model_name: String,
pub version: i64,
pub agent_type: String,
pub updated_by: UserInfo,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoomTopicUpdatedService { pub struct RoomTopicUpdatedService {
pub room: RoomInfo, pub room: RoomInfo,
@ -112,6 +97,7 @@ pub struct RoomCreateClient {
pub room_name: String, pub room_name: String,
pub public: bool, pub public: bool,
pub category: Option<Uuid>, pub category: Option<Uuid>,
pub ai_enabled: Option<bool>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -123,6 +109,7 @@ pub struct RoomUpdateClient {
pub slowmode_seconds: Option<i32>, pub slowmode_seconds: Option<i32>,
pub nsfw: Option<bool>, pub nsfw: Option<bool>,
pub default_auto_archive_duration: Option<i32>, pub default_auto_archive_duration: Option<i32>,
pub ai_enabled: Option<bool>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]

View File

@ -1,6 +1,6 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use uuid::Uuid;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::event::{RoomInfo, UserInfo, message::MessageNewService}; use crate::event::{RoomInfo, UserInfo, message::MessageNewService};

View File

@ -1,159 +0,0 @@
use chrono::Utc;
use uuid::Uuid;
use crate::event::{AgentInfo, RoomInfo, ai};
use crate::{ChannelBus, ChannelError, ChannelResult};
use super::WsOutEvent;
use super::WsHandler;
impl WsHandler {
pub(super) async fn ai_list(
bus: &ChannelBus,
user_id: Uuid,
room: Uuid,
) -> ChannelResult<Option<WsOutEvent>> {
Self::ensure_room_access(bus, user_id, room).await?;
let rows = db::sqlx::query_as::<_, (Uuid, Option<String>, Option<String>, Option<Uuid>, bool, bool)>(
"SELECT ra.agent_session, s.name, s.agent_kind, s.model_version, ra.enabled, ra.auto_reply \
FROM room_ai ra \
LEFT JOIN agent_session s ON s.id = ra.agent_session AND s.deleted_at IS NULL \
WHERE ra.room = $1",
)
.bind(room)
.fetch_all(bus.inner.db.reader())
.await?;
let agents = rows
.into_iter()
.filter_map(|(agent_session, name, agent_kind, model_version, enabled, auto_reply)| {
name.map(|n| ai::RoomAiEntry {
agent_session,
name: n,
agent_kind: agent_kind.unwrap_or_default(),
model_version,
enabled,
auto_reply,
})
})
.collect();
let ai_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room));
Ok(Some(WsOutEvent::AiAgentList {
room: ai_room.clone(),
data: ai::RoomAiListService {
room: ai_room,
agents,
},
}))
}
pub(super) async fn ai_upsert(
bus: &ChannelBus,
user_id: Uuid,
room: Uuid,
model: Uuid,
) -> ChannelResult<Option<WsOutEvent>> {
Self::ensure_room_access(bus, user_id, room).await?;
let session = db::sqlx::query_as::<_, model::agent::AgentSessionModel>(
"SELECT id, \"user\", wk, name, description, agent_kind, model_version, \
system_prompt, temperature, max_output_tokens, enabled, created_by, \
created_at, updated_at, deleted_at \
FROM agent_session WHERE id = $1 AND deleted_at IS NULL",
)
.bind(model)
.fetch_one(bus.inner.db.reader())
.await
.map_err(|e| match e {
db::sqlx::Error::RowNotFound => ChannelError::RoomNotFound,
other => ChannelError::Database(other),
})?;
db::sqlx::query_as::<_, model::room::RoomAiModel>(
"INSERT INTO room_ai (room, agent_session, enabled, auto_reply, created_by, created_at, updated_at) \
VALUES ($1, $2, true, false, $3, now(), now()) \
ON CONFLICT (room, agent_session) DO UPDATE SET enabled = true, updated_at = now() \
RETURNING room, agent_session, enabled, auto_reply, created_by, created_at, updated_at",
)
.bind(room)
.bind(model)
.bind(user_id)
.fetch_one(bus.inner.db.writer())
.await?;
let ai_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room));
let data = ai::AiAgentJoinedService {
room: ai_room,
agent: AgentInfo {
id: model,
name: session.name.clone(),
agent_type: session.agent_kind.clone(),
model_name: None,
},
joined_at: Utc::now(),
};
bus.publish_room_event(room, "ai.agent_joined", &data)
.await?;
Ok(Some(WsOutEvent::AiAgentJoined { room: data.room.clone(), data }))
}
pub(super) async fn ai_delete(
bus: &ChannelBus,
user_id: Uuid,
room: Uuid,
agent_id: Uuid,
) -> ChannelResult<Option<WsOutEvent>> {
Self::ensure_room_access(bus, user_id, room).await?;
let session = db::sqlx::query_as::<_, model::agent::AgentSessionModel>(
"SELECT id, \"user\", wk, name, description, agent_kind, model_version, \
system_prompt, temperature, max_output_tokens, enabled, created_by, \
created_at, updated_at, deleted_at \
FROM agent_session WHERE id = $1 AND deleted_at IS NULL",
)
.bind(agent_id)
.fetch_optional(bus.inner.db.reader())
.await?;
let result = db::sqlx::query(
"DELETE FROM room_ai WHERE room = $1 AND agent_session = $2",
)
.bind(room)
.bind(agent_id)
.execute(bus.inner.db.writer())
.await?;
if result.rows_affected() == 0 {
return Err(ChannelError::RoomNotFound);
}
let ai_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room));
let agent_info = session.map(|s| AgentInfo {
id: s.id,
name: s.name,
agent_type: s.agent_kind,
model_name: None,
}).unwrap_or_else(|| AgentInfo::unknown(agent_id));
let data = ai::AiAgentLeftService {
room: ai_room,
agent: agent_info,
left_at: Utc::now(),
};
bus.publish_room_event(room, "ai.agent_left", &data).await?;
Ok(Some(WsOutEvent::AiAgentLeft { room: data.room.clone(), data }))
}
pub(super) async fn ai_stop(
bus: &ChannelBus,
user_id: Uuid,
room: Uuid,
) -> ChannelResult<Option<WsOutEvent>> {
Self::ensure_room_access(bus, user_id, room).await?;
bus.publish_room_event(
room,
"ai.stop",
&serde_json::json!({"stopped_by": user_id}),
)
.await?;
Ok(None)
}
}

View File

@ -4,8 +4,8 @@ use uuid::Uuid;
use crate::event::{UserInfo, WorkspaceInfo, ban}; use crate::event::{UserInfo, WorkspaceInfo, ban};
use crate::{ChannelBus, ChannelResult}; use crate::{ChannelBus, ChannelResult};
use super::WsOutEvent;
use super::WsHandler; use super::WsHandler;
use super::WsOutEvent;
impl WsHandler { impl WsHandler {
pub(super) async fn ban_create( pub(super) async fn ban_create(
@ -36,9 +36,18 @@ impl WsHandler {
}); });
bus.inner.cache.set(&ban_key, &ban_data).await?; bus.inner.cache.set(&ban_key, &ban_data).await?;
let data = ban::BannedService { let data = ban::BannedService {
workspace: bus.lookup_workspace(workspace).await.unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)), workspace: bus
user: bus.lookup_user(user).await.unwrap_or_else(|_| UserInfo::unknown(user)), .lookup_workspace(workspace)
banned_by: bus.lookup_user(_user_id).await.unwrap_or_else(|_| UserInfo::unknown(_user_id)), .await
.unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)),
user: bus
.lookup_user(user)
.await
.unwrap_or_else(|_| UserInfo::unknown(user)),
banned_by: bus
.lookup_user(_user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(_user_id)),
reason, reason,
expires_at: _expires_at, expires_at: _expires_at,
banned_at: Utc::now(), banned_at: Utc::now(),
@ -64,9 +73,18 @@ impl WsHandler {
let ban_key = format!("ban:{}:{}:{}", workspace, _user_id, user); let ban_key = format!("ban:{}:{}:{}", workspace, _user_id, user);
bus.inner.cache.remove(&ban_key).await?; bus.inner.cache.remove(&ban_key).await?;
let data = ban::UnbannedService { let data = ban::UnbannedService {
workspace: bus.lookup_workspace(workspace).await.unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)), workspace: bus
user: bus.lookup_user(user).await.unwrap_or_else(|_| UserInfo::unknown(user)), .lookup_workspace(workspace)
unbanned_by: bus.lookup_user(_user_id).await.unwrap_or_else(|_| UserInfo::unknown(_user_id)), .await
.unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)),
user: bus
.lookup_user(user)
.await
.unwrap_or_else(|_| UserInfo::unknown(user)),
unbanned_by: bus
.lookup_user(_user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(_user_id)),
unbanned_at: Utc::now(), unbanned_at: Utc::now(),
}; };
bus.workspace_changed(workspace).await?; bus.workspace_changed(workspace).await?;

View File

@ -5,8 +5,8 @@ use crate::event::{UserInfo, WorkspaceInfo, category};
use crate::{ChannelBus, ChannelError, ChannelResult}; use crate::{ChannelBus, ChannelError, ChannelResult};
use super::MAX_CATEGORY_NAME_LEN; use super::MAX_CATEGORY_NAME_LEN;
use super::WsOutEvent;
use super::WsHandler; use super::WsHandler;
use super::WsOutEvent;
impl WsHandler { impl WsHandler {
pub(super) async fn category_create( pub(super) async fn category_create(
@ -17,7 +17,9 @@ impl WsHandler {
position: Option<i32>, position: Option<i32>,
) -> ChannelResult<Option<WsOutEvent>> { ) -> ChannelResult<Option<WsOutEvent>> {
if name.is_empty() || name.len() > MAX_CATEGORY_NAME_LEN { if name.is_empty() || name.len() > MAX_CATEGORY_NAME_LEN {
return Err(ChannelError::Validation("invalid category name".into())); return Err(ChannelError::Validation(
"invalid category name".into(),
));
} }
Self::ensure_workspace_member(bus, user_id, workspace).await?; Self::ensure_workspace_member(bus, user_id, workspace).await?;
let row = db::sqlx::query_as::<_, model::room::RoomCategoryModel>( let row = db::sqlx::query_as::<_, model::room::RoomCategoryModel>(
@ -95,7 +97,10 @@ impl WsHandler {
updated_at: Utc::now(), updated_at: Utc::now(),
}; };
bus.workspace_changed(old.wk).await?; bus.workspace_changed(old.wk).await?;
Ok(Some(WsOutEvent::CategoryUpdated { workspace: data.project.clone(), data })) Ok(Some(WsOutEvent::CategoryUpdated {
workspace: data.project.clone(),
data,
}))
} }
pub(super) async fn category_delete( pub(super) async fn category_delete(
@ -133,6 +138,9 @@ impl WsHandler {
deleted_at: Utc::now(), deleted_at: Utc::now(),
}; };
bus.workspace_changed(row.wk).await?; bus.workspace_changed(row.wk).await?;
Ok(Some(WsOutEvent::CategoryDeleted { workspace: cd_workspace, data })) Ok(Some(WsOutEvent::CategoryDeleted {
workspace: cd_workspace,
data,
}))
} }
} }

View File

@ -4,8 +4,8 @@ use uuid::Uuid;
use crate::event::{RoomInfo, UserInfo, conversation}; use crate::event::{RoomInfo, UserInfo, conversation};
use crate::{ChannelBus, ChannelResult}; use crate::{ChannelBus, ChannelResult};
use super::WsOutEvent;
use super::WsHandler; use super::WsHandler;
use super::WsOutEvent;
impl WsHandler { impl WsHandler {
pub(super) async fn conversation_pin( pub(super) async fn conversation_pin(
@ -29,10 +29,14 @@ impl WsHandler {
.execute(bus.inner.db.writer()) .execute(bus.inner.db.writer())
.await?; .await?;
let room_info = let room_info = bus
bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); .lookup_room(room)
let user_info = .await
bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); .unwrap_or_else(|_| RoomInfo::unknown(room));
let user_info = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
if pin { if pin {
let data = conversation::ConversationPinnedService { let data = conversation::ConversationPinnedService {
@ -40,7 +44,8 @@ impl WsHandler {
room: room_info.clone(), room: room_info.clone(),
pinned_at: now, pinned_at: now,
}; };
bus.emit_to_user(user_id, "conversation.pinned", &data).await?; bus.emit_to_user(user_id, "conversation.pinned", &data)
.await?;
Ok(Some(WsOutEvent::ConversationPinned { Ok(Some(WsOutEvent::ConversationPinned {
room: room_info, room: room_info,
data, data,
@ -51,7 +56,8 @@ impl WsHandler {
room: room_info.clone(), room: room_info.clone(),
unpinned_at: now, unpinned_at: now,
}; };
bus.emit_to_user(user_id, "conversation.unpinned", &data).await?; bus.emit_to_user(user_id, "conversation.unpinned", &data)
.await?;
Ok(Some(WsOutEvent::ConversationUnpinned { Ok(Some(WsOutEvent::ConversationUnpinned {
room: room_info, room: room_info,
data, data,
@ -80,10 +86,14 @@ impl WsHandler {
.execute(bus.inner.db.writer()) .execute(bus.inner.db.writer())
.await?; .await?;
let room_info = let room_info = bus
bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); .lookup_room(room)
let user_info = .await
bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); .unwrap_or_else(|_| RoomInfo::unknown(room));
let user_info = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
if mute { if mute {
let data = conversation::ConversationMutedService { let data = conversation::ConversationMutedService {
@ -91,7 +101,8 @@ impl WsHandler {
room: room_info.clone(), room: room_info.clone(),
muted_at: now, muted_at: now,
}; };
bus.emit_to_user(user_id, "conversation.muted", &data).await?; bus.emit_to_user(user_id, "conversation.muted", &data)
.await?;
Ok(Some(WsOutEvent::ConversationMuted { Ok(Some(WsOutEvent::ConversationMuted {
room: room_info, room: room_info,
data, data,
@ -102,7 +113,8 @@ impl WsHandler {
room: room_info.clone(), room: room_info.clone(),
unmuted_at: now, unmuted_at: now,
}; };
bus.emit_to_user(user_id, "conversation.unmuted", &data).await?; bus.emit_to_user(user_id, "conversation.unmuted", &data)
.await?;
Ok(Some(WsOutEvent::ConversationUnmuted { Ok(Some(WsOutEvent::ConversationUnmuted {
room: room_info, room: room_info,
data, data,
@ -116,7 +128,8 @@ impl WsHandler {
notify_level: String, notify_level: String,
) -> ChannelResult<Option<WsOutEvent>> { ) -> ChannelResult<Option<WsOutEvent>> {
Self::ensure_room_access(bus, user_id, room).await?; Self::ensure_room_access(bus, user_id, room).await?;
let valid = matches!(notify_level.as_str(), "all" | "mentions" | "none"); let valid =
matches!(notify_level.as_str(), "all" | "mentions" | "none");
if !valid { if !valid {
return Err(crate::ChannelError::Internal( return Err(crate::ChannelError::Internal(
"notify_level must be 'all', 'mentions', or 'none'".to_string(), "notify_level must be 'all', 'mentions', or 'none'".to_string(),
@ -147,10 +160,14 @@ impl WsHandler {
.execute(bus.inner.db.writer()) .execute(bus.inner.db.writer())
.await?; .await?;
let room_info = let room_info = bus
bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); .lookup_room(room)
let user_info = .await
bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); .unwrap_or_else(|_| RoomInfo::unknown(room));
let user_info = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = conversation::ConversationNotifyLevelChangedService { let data = conversation::ConversationNotifyLevelChangedService {
user: user_info, user: user_info,
@ -208,7 +225,16 @@ impl WsHandler {
let summaries: Vec<conversation::ConversationSummary> = rows let summaries: Vec<conversation::ConversationSummary> = rows
.into_iter() .into_iter()
.map( .map(
|(id, name, room_type, is_pinned, is_muted, notify_level, last_read_seq, max_seq)| { |(
id,
name,
room_type,
is_pinned,
is_muted,
notify_level,
last_read_seq,
max_seq,
)| {
let unread = (max_seq - last_read_seq).max(0); let unread = (max_seq - last_read_seq).max(0);
conversation::ConversationSummary { conversation::ConversationSummary {
room: id, room: id,

View File

@ -1,248 +0,0 @@
use chrono::Utc;
use uuid::Uuid;
use crate::event::{RoomInfo, UserInfo, dm};
use crate::{ChannelBus, ChannelResult};
use super::WsOutEvent;
use super::WsHandler;
impl WsHandler {
pub(super) async fn dm_create(
bus: &ChannelBus,
user_id: Uuid,
recipient: Uuid,
) -> ChannelResult<Option<WsOutEvent>> {
if user_id == recipient {
return Err(crate::ChannelError::Internal(
"cannot create DM with yourself".to_string(),
));
}
let recipient_exists: Option<(Uuid,)> = db::sqlx::query_as(
"SELECT id FROM \"user\" WHERE id = $1",
)
.bind(recipient)
.fetch_optional(bus.inner.db.reader())
.await?;
if recipient_exists.is_none() {
return Err(crate::ChannelError::UserNotFound);
}
let (initiator, other) = if user_id < recipient {
(user_id, recipient)
} else {
(recipient, user_id)
};
let existing: Option<(Uuid, Uuid, bool)> = db::sqlx::query_as(
"SELECT room, initiator, is_closed FROM dm_conversation \
WHERE initiator = $1 AND recipient = $2",
)
.bind(initiator)
.bind(other)
.fetch_optional(bus.inner.db.reader())
.await?;
let now = Utc::now();
let (room_id, is_reopen) = if let Some((room, _, is_closed)) = existing {
if is_closed {
db::sqlx::query(
"UPDATE dm_conversation SET is_closed = false, closed_at = NULL, \
updated_at = now() WHERE initiator = $1 AND recipient = $2",
)
.bind(initiator)
.bind(other)
.execute(bus.inner.db.writer())
.await?;
db::sqlx::query(
"UPDATE room SET is_archived = false, updated_at = now() WHERE id = $1",
)
.bind(room)
.execute(bus.inner.db.writer())
.await?;
(room, true)
} else {
(room, false)
}
} else {
let shared_wk: Option<(Uuid,)> = db::sqlx::query_as(
"SELECT wm1.wk FROM wk_member wm1 \
INNER JOIN wk_member wm2 ON wm2.wk = wm1.wk \
WHERE wm1.\"user\" = $1 AND wm1.leave_at IS NULL \
AND wm2.\"user\" = $2 AND wm2.leave_at IS NULL \
LIMIT 1",
)
.bind(user_id)
.bind(recipient)
.fetch_optional(bus.inner.db.reader())
.await?;
let wk = shared_wk.map(|r| r.0).unwrap_or_else(|| {
Uuid::nil()
});
let room_id = Uuid::new_v4();
db::sqlx::query(
"INSERT INTO room (id, wk, name, topic, room_type, position, is_private, \
created_by, created_at, updated_at) \
VALUES ($1, $2, $3, NULL, 'DM', 0, true, $4, now(), now())",
)
.bind(room_id)
.bind(wk)
.bind(format!("dm-{}", &room_id.to_string()[..8]))
.bind(user_id)
.execute(bus.inner.db.writer())
.await?;
db::sqlx::query(
"INSERT INTO dm_conversation (room, initiator, recipient, created_at, updated_at) \
VALUES ($1, $2, $3, now(), now()) \
ON CONFLICT (initiator, recipient) DO NOTHING",
)
.bind(room_id)
.bind(initiator)
.bind(other)
.execute(bus.inner.db.writer())
.await?;
for uid in &[user_id, recipient] {
db::sqlx::query(
"INSERT INTO room_permission_overwrite \
(room, target_type, target_id, allow_mask, deny_mask, created_at) \
VALUES ($1, 'user', $2, 0, 0, now()) \
ON CONFLICT DO NOTHING",
)
.bind(room_id)
.bind(uid)
.execute(bus.inner.db.writer())
.await?;
}
(room_id, false)
};
let _ = crate::rooms::refresh_user_rooms_cache(
&bus.inner.db,
&bus.inner.cache,
&bus.inner.config,
user_id,
)
.await;
let _ = crate::rooms::refresh_user_rooms_cache(
&bus.inner.db,
&bus.inner.cache,
&bus.inner.config,
recipient,
)
.await;
let room_info =
bus.lookup_room(room_id).await.unwrap_or_else(|_| RoomInfo::unknown(room_id));
let initiator_info =
bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id));
let recipient_info =
bus.lookup_user(recipient).await.unwrap_or_else(|_| UserInfo::unknown(recipient));
if is_reopen {
let data = dm::DmReopenedService {
room: room_info.clone(),
reopened_by: initiator_info,
reopened_at: now,
};
bus.emit_to_user(user_id, "dm.reopened", &data).await?;
bus.emit_to_user(recipient, "dm.reopened", &data).await?;
Ok(Some(WsOutEvent::DmReopened {
room: room_info,
data,
}))
} else {
let data = dm::DmCreatedService {
room: room_info.clone(),
initiator: initiator_info,
recipient: recipient_info,
created_at: now,
};
bus.emit_to_user(user_id, "dm.created", &data).await?;
bus.emit_to_user(recipient, "dm.created", &data).await?;
Ok(Some(WsOutEvent::DmCreated {
room: room_info,
data,
}))
}
}
pub(super) async fn dm_close(
bus: &ChannelBus,
user_id: Uuid,
room: Uuid,
) -> ChannelResult<Option<WsOutEvent>> {
let now = Utc::now();
let result = db::sqlx::query(
"UPDATE dm_conversation SET is_closed = true, closed_at = $1, updated_at = $1 \
WHERE room = $2 AND (initiator = $3 OR recipient = $3) AND is_closed = false",
)
.bind(now)
.bind(room)
.bind(user_id)
.execute(bus.inner.db.writer())
.await?;
if result.rows_affected() == 0 {
return Ok(None);
}
let room_info =
bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room));
let closed_by =
bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = dm::DmClosedService {
room: room_info.clone(),
closed_by,
closed_at: now,
};
bus.publish_room_event(room, "dm.closed", &data).await?;
Ok(Some(WsOutEvent::DmClosed {
room: room_info,
data,
}))
}
pub(super) async fn dm_list(
bus: &ChannelBus,
user_id: Uuid,
) -> ChannelResult<Option<WsOutEvent>> {
let rows = db::sqlx::query_as::<_, (Uuid, Uuid, Uuid, chrono::DateTime<Utc>)>(
"SELECT dc.room, dc.initiator, dc.recipient, dc.created_at \
FROM dm_conversation dc \
INNER JOIN room r ON r.id = dc.room \
WHERE (dc.initiator = $1 OR dc.recipient = $1) \
AND dc.is_closed = false \
AND r.deleted_at IS NULL \
ORDER BY dc.updated_at DESC",
)
.bind(user_id)
.fetch_all(bus.inner.db.reader())
.await?;
let mut results = Vec::with_capacity(rows.len());
for (room_id, initiator_id, recipient_id, created_at) in rows {
let room_info = bus
.lookup_room(room_id)
.await
.unwrap_or_else(|_| RoomInfo::unknown(room_id));
let initiator_info = bus
.lookup_user(initiator_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(initiator_id));
let recipient_info = bus
.lookup_user(recipient_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(recipient_id));
results.push(dm::DmCreatedService {
room: room_info,
initiator: initiator_info,
recipient: recipient_info,
created_at,
});
}
Ok(Some(WsOutEvent::DmList { data: results }))
}
}

View File

@ -4,9 +4,9 @@ use uuid::Uuid;
use crate::event::{RoomInfo, UserInfo, draft}; use crate::event::{RoomInfo, UserInfo, draft};
use crate::{ChannelBus, ChannelError, ChannelResult}; use crate::{ChannelBus, ChannelError, ChannelResult};
use super::{MAX_TEXT_LEN}; use super::MAX_TEXT_LEN;
use super::WsOutEvent;
use super::WsHandler; use super::WsHandler;
use super::WsOutEvent;
impl WsHandler { impl WsHandler {
pub(super) async fn draft_save( pub(super) async fn draft_save(
@ -21,15 +21,24 @@ impl WsHandler {
} }
let key = format!("draft:{}:{}", user_id, room); let key = format!("draft:{}:{}", user_id, room);
bus.inner.cache.set(&key, &content).await?; bus.inner.cache.set(&key, &content).await?;
let ds_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); let ds_room = bus
let ds_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); .lookup_room(room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(room));
let ds_user = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = draft::DraftSavedService { let data = draft::DraftSavedService {
user: ds_user, user: ds_user,
room: ds_room, room: ds_room,
content, content,
saved_at: Utc::now(), saved_at: Utc::now(),
}; };
Ok(Some(WsOutEvent::DraftSaved { room: data.room.clone(), data })) Ok(Some(WsOutEvent::DraftSaved {
room: data.room.clone(),
data,
}))
} }
pub(super) async fn draft_clear( pub(super) async fn draft_clear(
@ -40,13 +49,22 @@ impl WsHandler {
Self::ensure_room_access(bus, user_id, room).await?; Self::ensure_room_access(bus, user_id, room).await?;
let key = format!("draft:{}:{}", user_id, room); let key = format!("draft:{}:{}", user_id, room);
bus.inner.cache.remove(&key).await?; bus.inner.cache.remove(&key).await?;
let dc_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); let dc_room = bus
let dc_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); .lookup_room(room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(room));
let dc_user = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = draft::DraftClearedService { let data = draft::DraftClearedService {
user: dc_user, user: dc_user,
room: dc_room, room: dc_room,
cleared_at: Utc::now(), cleared_at: Utc::now(),
}; };
Ok(Some(WsOutEvent::DraftCleared { room: data.room.clone(), data })) Ok(Some(WsOutEvent::DraftCleared {
room: data.room.clone(),
data,
}))
} }
} }

View File

@ -3,8 +3,8 @@ use uuid::Uuid;
use crate::event::{RoomInfo, forward}; use crate::event::{RoomInfo, forward};
use crate::{ChannelBus, ChannelResult}; use crate::{ChannelBus, ChannelResult};
use super::WsOutEvent;
use super::WsHandler; use super::WsHandler;
use super::WsOutEvent;
impl WsHandler { impl WsHandler {
pub(super) async fn message_forward( pub(super) async fn message_forward(
@ -63,11 +63,8 @@ impl WsHandler {
let fwd_content_type = row.content_type.clone(); let fwd_content_type = row.content_type.clone();
let fwd_created_at = row.created_at; let fwd_created_at = row.created_at;
bus.publish_room_message( bus.publish_room_message(row, Some(bus.lookup_user(user_id).await?))
row, .await?;
Some(bus.lookup_user(user_id).await?),
)
.await?;
let data = forward::MessageForwardedService { let data = forward::MessageForwardedService {
id: fwd_id, id: fwd_id,

View File

@ -59,10 +59,13 @@ impl WsHandler {
.fetch_all(bus.inner.db.reader()) .fetch_all(bus.inner.db.reader())
.await?; .await?;
let user_ids: Vec<Uuid> = rows.iter().map(|(_, _, user)| *user).collect(); let user_ids: Vec<Uuid> =
rows.iter().map(|(_, _, user)| *user).collect();
let users = bus.lookup_users(&user_ids).await.unwrap_or_default(); let users = bus.lookup_users(&user_ids).await.unwrap_or_default();
let mut grouped: HashMap<Uuid, HashMap<String, reaction::ReactionGroup>> = let mut grouped: HashMap<
HashMap::new(); Uuid,
HashMap<String, reaction::ReactionGroup>,
> = HashMap::new();
for (message_id, emoji, reactor) in rows { for (message_id, emoji, reactor) in rows {
let group = grouped let group = grouped

View File

@ -4,8 +4,8 @@ use uuid::Uuid;
use crate::event::{RoomInfo, UserInfo, WorkspaceInfo, invite}; use crate::event::{RoomInfo, UserInfo, WorkspaceInfo, invite};
use crate::{ChannelBus, ChannelError, ChannelResult}; use crate::{ChannelBus, ChannelError, ChannelResult};
use super::WsOutEvent;
use super::WsHandler; use super::WsHandler;
use super::WsOutEvent;
impl WsHandler { impl WsHandler {
pub(super) async fn invite_create( pub(super) async fn invite_create(
@ -29,16 +29,29 @@ impl WsHandler {
"expires_at": _expires_at, "expires_at": _expires_at,
}); });
bus.inner.cache.set(&id_key, &meta.to_string()).await?; bus.inner.cache.set(&id_key, &meta.to_string()).await?;
bus.inner.cache.set(&code_key, &invite_id.to_string()).await?; bus.inner
.cache
.set(&code_key, &invite_id.to_string())
.await?;
let inv_room = match _room { let inv_room = match _room {
Some(r) => Some(bus.lookup_room(r).await.unwrap_or_else(|_| RoomInfo::unknown(r))), Some(r) => Some(
bus.lookup_room(r)
.await
.unwrap_or_else(|_| RoomInfo::unknown(r)),
),
None => None, None => None,
}; };
let data = invite::InviteCreatedService { let data = invite::InviteCreatedService {
id: invite_id, id: invite_id,
workspace: bus.lookup_workspace(workspace).await.unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)), workspace: bus
.lookup_workspace(workspace)
.await
.unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)),
room: inv_room, room: inv_room,
inviter: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), inviter: bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id)),
invitee: None, invitee: None,
code, code,
max_uses: _max_uses, max_uses: _max_uses,
@ -54,7 +67,8 @@ impl WsHandler {
code: String, code: String,
) -> ChannelResult<Option<WsOutEvent>> { ) -> ChannelResult<Option<WsOutEvent>> {
let code_key = format!("invite:code:{}", code); let code_key = format!("invite:code:{}", code);
let invite_id_str: Option<String> = bus.inner.cache.get(&code_key).await?; let invite_id_str: Option<String> =
bus.inner.cache.get(&code_key).await?;
let invite_id = invite_id_str let invite_id = invite_id_str
.as_deref() .as_deref()
.and_then(|s| Uuid::parse_str(s).ok()) .and_then(|s| Uuid::parse_str(s).ok())
@ -90,9 +104,15 @@ impl WsHandler {
bus.inner.cache.remove(&id_key).await?; bus.inner.cache.remove(&id_key).await?;
let data = invite::InviteAcceptedService { let data = invite::InviteAcceptedService {
id: Uuid::now_v7(), id: Uuid::now_v7(),
workspace: bus.lookup_workspace(wk).await.unwrap_or_else(|_| WorkspaceInfo::unknown(wk)), workspace: bus
.lookup_workspace(wk)
.await
.unwrap_or_else(|_| WorkspaceInfo::unknown(wk)),
room: None, room: None,
user: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), user: bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id)),
accepted_at: Utc::now(), accepted_at: Utc::now(),
}; };
bus.workspace_changed(wk).await?; bus.workspace_changed(wk).await?;

View File

@ -7,9 +7,9 @@ use crate::{
pagination::{MessagePagination, PaginationDirection, PaginationParams}, pagination::{MessagePagination, PaginationDirection, PaginationParams},
}; };
use super::{MAX_MESSAGES_PER_REQUEST, MAX_TEXT_LEN};
use super::WsOutEvent;
use super::WsHandler; use super::WsHandler;
use super::WsOutEvent;
use super::{MAX_MESSAGES_PER_REQUEST, MAX_TEXT_LEN};
impl WsHandler { impl WsHandler {
/// Count non-deleted sibling replies to the same parent message. /// Count non-deleted sibling replies to the same parent message.
@ -124,16 +124,21 @@ impl WsHandler {
// ── Auto-thread logic ────────────────────────────────────────── // ── Auto-thread logic ──────────────────────────────────────────
let mut events: Vec<WsOutEvent> = Vec::new(); let mut events: Vec<WsOutEvent> = Vec::new();
let effective_thread: Option<Uuid> = if let Some(ref parent_id) = in_reply_to { let effective_thread: Option<Uuid> = if let Some(ref parent_id) =
in_reply_to
{
if thread.is_some() { if thread.is_some() {
thread thread
} else { } else {
let existing = Self::find_thread_in_chain(bus, *parent_id).await?; let existing =
Self::find_thread_in_chain(bus, *parent_id).await?;
if let Some(tid) = existing { if let Some(tid) = existing {
Some(tid) Some(tid)
} else { } else {
let sibling_count = Self::count_sibling_replies(bus, *parent_id).await?; let sibling_count =
let (root_id, root_seq, chain_depth) = Self::reply_chain_info(bus, *parent_id).await?; Self::count_sibling_replies(bus, *parent_id).await?;
let (root_id, root_seq, chain_depth) =
Self::reply_chain_info(bus, *parent_id).await?;
let should_create = sibling_count >= 3 || chain_depth >= 5; let should_create = sibling_count >= 3 || chain_depth >= 5;
if should_create { if should_create {
@ -152,11 +157,20 @@ impl WsHandler {
.await?; .await?;
let new_thread_id = thread_row.id; let new_thread_id = thread_row.id;
Self::attach_chain_to_thread(bus, *parent_id, new_thread_id).await?; Self::attach_chain_to_thread(
bus,
*parent_id,
new_thread_id,
)
.await?;
let tc_room = bus.lookup_room(room).await let tc_room = bus
.lookup_room(room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(room)); .unwrap_or_else(|_| RoomInfo::unknown(room));
let created_by = bus.lookup_user(user_id).await let created_by = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id)); .unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = thread::ThreadCreatedService { let data = thread::ThreadCreatedService {
id: new_thread_id, id: new_thread_id,
@ -166,7 +180,8 @@ impl WsHandler {
participants: serde_json::Value::Null, participants: serde_json::Value::Null,
created_at: thread_row.created_at, created_at: thread_row.created_at,
}; };
bus.publish_room_event(room, "thread.created", &data).await?; bus.publish_room_event(room, "thread.created", &data)
.await?;
events.push(WsOutEvent::ThreadCreated { events.push(WsOutEvent::ThreadCreated {
room: data.room.clone(), room: data.room.clone(),
data, data,
@ -227,11 +242,11 @@ impl WsHandler {
.fetch_one(bus.inner.db.writer()) .fetch_one(bus.inner.db.writer())
.await?; .await?;
bus.publish_room_message( bus.publish_room_message(row.clone(), Some(sender)).await?;
row.clone(), let msg_room = bus
Some(sender), .lookup_room(room)
).await?; .await
let msg_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); .unwrap_or_else(|_| RoomInfo::unknown(room));
events.push(WsOutEvent::MessageNew { events.push(WsOutEvent::MessageNew {
room: msg_room.clone(), room: msg_room.clone(),
data: message::MessageNewService { data: message::MessageNewService {
@ -254,7 +269,9 @@ impl WsHandler {
}, },
}); });
Ok(events.into_iter().find(|e| matches!(e, WsOutEvent::MessageNew { .. }))) Ok(events
.into_iter()
.find(|e| matches!(e, WsOutEvent::MessageNew { .. })))
} }
pub(super) async fn message_update( pub(super) async fn message_update(
@ -302,8 +319,14 @@ impl WsHandler {
.execute(bus.inner.db.writer()) .execute(bus.inner.db.writer())
.await?; .await?;
let sender = bus.lookup_user(row.author).await.unwrap_or_else(|_| UserInfo::unknown(row.author)); let sender = bus
let room = bus.lookup_room(row.room).await.unwrap_or_else(|_| RoomInfo::unknown(row.room)); .lookup_user(row.author)
.await
.unwrap_or_else(|_| UserInfo::unknown(row.author));
let room = bus
.lookup_room(row.room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(row.room));
let data = message::MessageEditedService { let data = message::MessageEditedService {
id: row.id, id: row.id,
seq: row.seq, seq: row.seq,
@ -353,8 +376,14 @@ impl WsHandler {
.bind(message_id) .bind(message_id)
.fetch_one(bus.inner.db.writer()) .fetch_one(bus.inner.db.writer())
.await?; .await?;
let revoked_by = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); let revoked_by = bus
let room = bus.lookup_room(row.room).await.unwrap_or_else(|_| RoomInfo::unknown(row.room)); .lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let room = bus
.lookup_room(row.room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(row.room));
let data = message::MessageRevokedService { let data = message::MessageRevokedService {
id: row.id, id: row.id,
seq: row.seq, seq: row.seq,
@ -418,20 +447,28 @@ impl WsHandler {
.await?; .await?;
let mut page_messages = page.messages; let mut page_messages = page.messages;
if before_seq.is_some() || (before_seq.is_none() && after_seq.is_none()) { if before_seq.is_some() || (before_seq.is_none() && after_seq.is_none())
{
page_messages.reverse(); page_messages.reverse();
} }
let message_ids: Vec<Uuid> = page_messages.iter().map(|m| m.id).collect(); let message_ids: Vec<Uuid> =
let reactions = Self::reaction_groups_for_messages(bus, user_id, &message_ids) page_messages.iter().map(|m| m.id).collect();
.await let reactions =
.unwrap_or_default(); Self::reaction_groups_for_messages(bus, user_id, &message_ids)
.await
.unwrap_or_default();
let list_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); let list_room = bus
.lookup_room(room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(room));
let mut messages: Vec<message::MessageNewService> = let mut messages: Vec<message::MessageNewService> =
Vec::with_capacity(page_messages.len()); Vec::with_capacity(page_messages.len());
for m in page_messages { for m in page_messages {
let sender = bus.lookup_user(m.sender_id).await let sender = bus
.lookup_user(m.sender_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(m.sender_id)); .unwrap_or_else(|_| UserInfo::unknown(m.sender_id));
messages.push(message::MessageNewService { messages.push(message::MessageNewService {
id: m.id, id: m.id,
@ -534,10 +571,14 @@ impl WsHandler {
let author_ids: Vec<Uuid> = rows.iter().map(|r| r.author).collect(); let author_ids: Vec<Uuid> = rows.iter().map(|r| r.author).collect();
let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default(); let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default();
let message_ids: Vec<Uuid> = rows.iter().map(|r| r.id).collect(); let message_ids: Vec<Uuid> = rows.iter().map(|r| r.id).collect();
let reactions = Self::reaction_groups_for_messages(bus, user_id, &message_ids) let reactions =
Self::reaction_groups_for_messages(bus, user_id, &message_ids)
.await
.unwrap_or_default();
let around_room = bus
.lookup_room(room)
.await .await
.unwrap_or_default(); .unwrap_or_else(|_| RoomInfo::unknown(room));
let around_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room));
let messages = rows let messages = rows
.into_iter() .into_iter()
.map(|r| { .map(|r| {
@ -561,7 +602,10 @@ impl WsHandler {
thinking_content: None, thinking_content: None,
thinking_is_chunked: None, thinking_is_chunked: None,
send_at: r.created_at, send_at: r.created_at,
reactions: reactions.get(&r.id).cloned().unwrap_or_default(), reactions: reactions
.get(&r.id)
.cloned()
.unwrap_or_default(),
} }
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -591,13 +635,19 @@ impl WsHandler {
.reconnect .reconnect
.get_missed_messages(room, after_seq) .get_missed_messages(room, after_seq)
.await?; .await?;
let author_ids: Vec<Uuid> = messages.iter().map(|m| m.sender_id).collect(); let author_ids: Vec<Uuid> =
let message_ids: Vec<Uuid> = messages.iter().map(|m| m.message_id).collect(); messages.iter().map(|m| m.sender_id).collect();
let message_ids: Vec<Uuid> =
messages.iter().map(|m| m.message_id).collect();
let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default(); let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default();
let reactions = Self::reaction_groups_for_messages(bus, user_id, &message_ids) let reactions =
Self::reaction_groups_for_messages(bus, user_id, &message_ids)
.await
.unwrap_or_default();
let missed_room = bus
.lookup_room(room)
.await .await
.unwrap_or_default(); .unwrap_or_else(|_| RoomInfo::unknown(room));
let missed_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room));
let messages = messages let messages = messages
.into_iter() .into_iter()
.take(limit) .take(limit)
@ -622,7 +672,10 @@ impl WsHandler {
thinking_content: None, thinking_content: None,
thinking_is_chunked: None, thinking_is_chunked: None,
send_at: m.send_at, send_at: m.send_at,
reactions: reactions.get(&m.message_id).cloned().unwrap_or_default(), reactions: reactions
.get(&m.message_id)
.cloned()
.unwrap_or_default(),
} }
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();

View File

@ -4,8 +4,8 @@ use uuid::Uuid;
use crate::event::{RoomInfo, UserInfo, message_read}; use crate::event::{RoomInfo, UserInfo, message_read};
use crate::{ChannelBus, ChannelResult}; use crate::{ChannelBus, ChannelResult};
use super::WsOutEvent;
use super::WsHandler; use super::WsHandler;
use super::WsOutEvent;
impl WsHandler { impl WsHandler {
pub(super) async fn message_mark_read( pub(super) async fn message_mark_read(
@ -60,10 +60,14 @@ impl WsHandler {
.await?; .await?;
} }
let room_info = let room_info = bus
bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); .lookup_room(room)
let reader_info = .await
bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); .unwrap_or_else(|_| RoomInfo::unknown(room));
let reader_info = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = message_read::MessageReadBatchService { let data = message_read::MessageReadBatchService {
room: room_info.clone(), room: room_info.clone(),
@ -72,7 +76,8 @@ impl WsHandler {
reader: reader_info, reader: reader_info,
read_at: now, read_at: now,
}; };
bus.publish_room_event(room, "message.read_batch", &data).await?; bus.publish_room_event(room, "message.read_batch", &data)
.await?;
Ok(Some(WsOutEvent::MessageReadBatch { Ok(Some(WsOutEvent::MessageReadBatch {
room: room_info, room: room_info,
data, data,

View File

@ -12,27 +12,25 @@ pub(crate) const MAX_CATEGORY_NAME_LEN: usize = 50;
mod helpers; mod helpers;
mod subscription;
mod message;
mod room;
mod category;
mod reaction;
mod thread;
mod pin;
mod draft;
mod notification;
mod presence;
mod invite;
mod ban; mod ban;
mod voice; mod category;
mod ai;
mod search;
mod user;
mod conversation; mod conversation;
mod dm; mod draft;
mod forward; mod forward;
mod invite;
mod message;
mod message_read; mod message_read;
mod notification;
mod pin;
mod presence;
mod reaction;
mod room;
mod search;
mod star; mod star;
mod subscription;
mod thread;
mod user;
mod voice;
pub struct WsHandler; pub struct WsHandler;
@ -73,8 +71,14 @@ impl WsHandler {
) )
.await .await
} }
WsInMessage::MessageAround { room, seq, limit, thread } => { WsInMessage::MessageAround {
Self::message_around(bus, user_id, room, seq, limit, thread).await room,
seq,
limit,
thread,
} => {
Self::message_around(bus, user_id, room, seq, limit, thread)
.await
} }
WsInMessage::MessageCreate { WsInMessage::MessageCreate {
room, room,
@ -108,9 +112,11 @@ impl WsHandler {
room_name, room_name,
public, public,
category, category,
ai_enabled,
} => { } => {
Self::room_create( Self::room_create(
bus, user_id, workspace, room_name, public, category, bus, user_id, workspace, room_name, public, category,
ai_enabled,
) )
.await .await
} }
@ -119,7 +125,13 @@ impl WsHandler {
room_name, room_name,
public, public,
category, category,
} => Self::room_update(bus, user_id, room, room_name, public, category).await, ai_enabled,
} => {
Self::room_update(
bus, user_id, room, room_name, public, category, ai_enabled,
)
.await
}
WsInMessage::RoomDelete { room } => { WsInMessage::RoomDelete { room } => {
Self::room_delete(bus, user_id, room).await Self::room_delete(bus, user_id, room).await
} }
@ -127,7 +139,10 @@ impl WsHandler {
workspace, workspace,
name, name,
position, position,
} => Self::category_create(bus, user_id, workspace, name, position).await, } => {
Self::category_create(bus, user_id, workspace, name, position)
.await
}
WsInMessage::CategoryUpdate { id, name, position } => { WsInMessage::CategoryUpdate { id, name, position } => {
Self::category_update(bus, user_id, id, name, position).await Self::category_update(bus, user_id, id, name, position).await
} }
@ -160,7 +175,11 @@ impl WsHandler {
dnd_end_hour, dnd_end_hour,
} => { } => {
Self::dnd_update( Self::dnd_update(
bus, user_id, room, do_not_disturb, dnd_start_hour, bus,
user_id,
room,
do_not_disturb,
dnd_start_hour,
dnd_end_hour, dnd_end_hour,
) )
.await .await
@ -205,7 +224,8 @@ impl WsHandler {
Self::notification_mark_read(bus, user_id, id).await Self::notification_mark_read(bus, user_id, id).await
} }
WsInMessage::NotificationMarkAllRead { workspace_id } => { WsInMessage::NotificationMarkAllRead { workspace_id } => {
Self::notification_mark_all_read(bus, user_id, workspace_id).await Self::notification_mark_all_read(bus, user_id, workspace_id)
.await
} }
WsInMessage::NotificationArchive { id } => { WsInMessage::NotificationArchive { id } => {
Self::notification_archive(bus, user_id, id).await Self::notification_archive(bus, user_id, id).await
@ -218,8 +238,10 @@ impl WsHandler {
text, text,
expires_at, expires_at,
} => { } => {
Self::custom_status_update(bus, user_id, emoji, text, expires_at) Self::custom_status_update(
.await bus, user_id, emoji, text, expires_at,
)
.await
} }
WsInMessage::InviteCreate { WsInMessage::InviteCreate {
workspace, workspace,
@ -227,8 +249,10 @@ impl WsHandler {
max_uses, max_uses,
expires_at, expires_at,
} => { } => {
Self::invite_create(bus, user_id, workspace, room, max_uses, expires_at) Self::invite_create(
.await bus, user_id, workspace, room, max_uses, expires_at,
)
.await
} }
WsInMessage::InviteAccept { code } => { WsInMessage::InviteAccept { code } => {
Self::invite_accept(bus, user_id, code).await Self::invite_accept(bus, user_id, code).await
@ -242,8 +266,10 @@ impl WsHandler {
reason, reason,
expires_at, expires_at,
} => { } => {
Self::ban_create(bus, user_id, workspace, user, reason, expires_at) Self::ban_create(
.await bus, user_id, workspace, user, reason, expires_at,
)
.await
} }
WsInMessage::BanRemove { workspace, user } => { WsInMessage::BanRemove { workspace, user } => {
Self::ban_remove(bus, user_id, workspace, user).await Self::ban_remove(bus, user_id, workspace, user).await
@ -263,18 +289,6 @@ impl WsHandler {
WsInMessage::ScreenShare { room, start } => { WsInMessage::ScreenShare { room, start } => {
Self::screen_share(bus, user_id, room, start).await Self::screen_share(bus, user_id, room, start).await
} }
WsInMessage::AiList { room } => {
Self::ai_list(bus, user_id, room).await
}
WsInMessage::AiUpsert { room, model } => {
Self::ai_upsert(bus, user_id, room, model).await
}
WsInMessage::AiDelete { room, agent_id } => {
Self::ai_delete(bus, user_id, room, agent_id).await
}
WsInMessage::AiStop { room } => {
Self::ai_stop(bus, user_id, room).await
}
WsInMessage::UserSummary { username } => { WsInMessage::UserSummary { username } => {
Self::user_summary(bus, username).await Self::user_summary(bus, username).await
} }
@ -291,29 +305,20 @@ impl WsHandler {
WsInMessage::ConversationMute { room, mute } => { WsInMessage::ConversationMute { room, mute } => {
Self::conversation_mute(bus, user_id, room, mute).await Self::conversation_mute(bus, user_id, room, mute).await
} }
WsInMessage::ConversationNotifyLevel { WsInMessage::ConversationNotifyLevel { room, notify_level } => {
room,
notify_level,
} => {
Self::conversation_notify_level( Self::conversation_notify_level(
bus, user_id, room, notify_level, bus,
user_id,
room,
notify_level,
) )
.await .await
} }
WsInMessage::ConversationList => { WsInMessage::ConversationList => {
Self::conversation_list(bus, user_id).await Self::conversation_list(bus, user_id).await
} }
WsInMessage::DmCreate { recipient } => {
Self::dm_create(bus, user_id, recipient).await WsInMessage::MessageMarkRead { room, message_ids } => {
}
WsInMessage::DmClose { room } => {
Self::dm_close(bus, user_id, room).await
}
WsInMessage::DmList => Self::dm_list(bus, user_id).await,
WsInMessage::MessageMarkRead {
room,
message_ids,
} => {
Self::message_mark_read(bus, user_id, room, message_ids).await Self::message_mark_read(bus, user_id, room, message_ids).await
} }
WsInMessage::MessageGetReaders { message_id } => { WsInMessage::MessageGetReaders { message_id } => {
@ -331,8 +336,13 @@ impl WsHandler {
source_message_id, source_message_id,
target_room, target_room,
} => { } => {
Self::message_forward(bus, user_id, source_message_id, target_room) Self::message_forward(
.await bus,
user_id,
source_message_id,
target_room,
)
.await
} }
} }
} }

View File

@ -4,8 +4,8 @@ use uuid::Uuid;
use crate::event::{UserInfo, notify}; use crate::event::{UserInfo, notify};
use crate::{ChannelBus, ChannelError, ChannelResult}; use crate::{ChannelBus, ChannelError, ChannelResult};
use super::WsOutEvent;
use super::WsHandler; use super::WsHandler;
use super::WsOutEvent;
impl WsHandler { impl WsHandler {
pub(super) async fn notification_mark_read( pub(super) async fn notification_mark_read(
@ -24,7 +24,10 @@ impl WsHandler {
if result.rows_affected() == 0 { if result.rows_affected() == 0 {
return Err(ChannelError::RoomNotFound); return Err(ChannelError::RoomNotFound);
} }
let nr_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); let nr_user = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = notify::NotifyReadService { let data = notify::NotifyReadService {
id, id,
user: nr_user, user: nr_user,

View File

@ -4,8 +4,8 @@ use uuid::Uuid;
use crate::event::{RoomInfo, UserInfo, pin}; use crate::event::{RoomInfo, UserInfo, pin};
use crate::{ChannelBus, ChannelResult}; use crate::{ChannelBus, ChannelResult};
use super::WsOutEvent;
use super::WsHandler; use super::WsHandler;
use super::WsOutEvent;
impl WsHandler { impl WsHandler {
pub(super) async fn pin_add( pub(super) async fn pin_add(
@ -35,8 +35,14 @@ impl WsHandler {
.bind(message) .bind(message)
.execute(bus.inner.db.writer()) .execute(bus.inner.db.writer())
.await?; .await?;
let pa_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); let pa_room = bus
let pinned_by = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); .lookup_room(room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(room));
let pinned_by = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = pin::PinAddedService { let data = pin::PinAddedService {
room: pa_room, room: pa_room,
message, message,
@ -44,7 +50,10 @@ impl WsHandler {
pinned_at: Utc::now(), pinned_at: Utc::now(),
}; };
bus.publish_room_event(room, "pin.added", &data).await?; bus.publish_room_event(room, "pin.added", &data).await?;
Ok(Some(WsOutEvent::PinAdded { room: data.room.clone(), data })) Ok(Some(WsOutEvent::PinAdded {
room: data.room.clone(),
data,
}))
} }
pub(super) async fn pin_remove( pub(super) async fn pin_remove(
@ -69,8 +78,14 @@ impl WsHandler {
.bind(message) .bind(message)
.execute(bus.inner.db.writer()) .execute(bus.inner.db.writer())
.await?; .await?;
let pr_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); let pr_room = bus
let removed_by = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); .lookup_room(room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(room));
let removed_by = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = pin::PinRemovedService { let data = pin::PinRemovedService {
room: pr_room, room: pr_room,
message, message,
@ -78,6 +93,9 @@ impl WsHandler {
removed_at: Utc::now(), removed_at: Utc::now(),
}; };
bus.publish_room_event(room, "pin.removed", &data).await?; bus.publish_room_event(room, "pin.removed", &data).await?;
Ok(Some(WsOutEvent::PinRemoved { room: data.room.clone(), data })) Ok(Some(WsOutEvent::PinRemoved {
room: data.room.clone(),
data,
}))
} }
} }

View File

@ -4,8 +4,8 @@ use uuid::Uuid;
use crate::event::{RoomInfo, UserInfo, member, presence}; use crate::event::{RoomInfo, UserInfo, member, presence};
use crate::{ChannelBus, ChannelResult}; use crate::{ChannelBus, ChannelResult};
use super::WsOutEvent;
use super::WsHandler; use super::WsHandler;
use super::WsOutEvent;
impl WsHandler { impl WsHandler {
pub(super) async fn dnd_update( pub(super) async fn dnd_update(
@ -27,8 +27,14 @@ impl WsHandler {
"dnd_end_hour": end_hour, "dnd_end_hour": end_hour,
}); });
bus.inner.cache.set(&key, &dnd_data).await?; bus.inner.cache.set(&key, &dnd_data).await?;
let dnd_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); let dnd_room = bus
let dnd_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); .lookup_room(room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(room));
let dnd_user = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = member::DndUpdatedService { let data = member::DndUpdatedService {
room: dnd_room, room: dnd_room,
user: dnd_user, user: dnd_user,
@ -36,7 +42,8 @@ impl WsHandler {
dnd_start_hour: start_hour, dnd_start_hour: start_hour,
dnd_end_hour: end_hour, dnd_end_hour: end_hour,
}; };
bus.publish_room_event(room, "member.dnd_updated", &data).await?; bus.publish_room_event(room, "member.dnd_updated", &data)
.await?;
Ok(None) Ok(None)
} }
@ -45,7 +52,10 @@ impl WsHandler {
user_id: Uuid, user_id: Uuid,
status: presence::UserPresenceStatus, status: presence::UserPresenceStatus,
) -> ChannelResult<Option<WsOutEvent>> { ) -> ChannelResult<Option<WsOutEvent>> {
let pc_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); let pc_user = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = presence::PresenceChangedService { let data = presence::PresenceChangedService {
user: pc_user, user: pc_user,
project: None, project: None,
@ -60,7 +70,8 @@ impl WsHandler {
) )
.await?; .await?;
for room in rooms { for room in rooms {
bus.publish_room_event(room, "presence.changed", &data).await?; bus.publish_room_event(room, "presence.changed", &data)
.await?;
} }
Ok(Some(WsOutEvent::PresenceChanged { data })) Ok(Some(WsOutEvent::PresenceChanged { data }))
} }
@ -72,7 +83,10 @@ impl WsHandler {
text: Option<String>, text: Option<String>,
expires_at: Option<chrono::DateTime<chrono::Utc>>, expires_at: Option<chrono::DateTime<chrono::Utc>>,
) -> ChannelResult<Option<WsOutEvent>> { ) -> ChannelResult<Option<WsOutEvent>> {
let cs_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); let cs_user = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = presence::CustomStatusUpdatedService { let data = presence::CustomStatusUpdatedService {
user: cs_user, user: cs_user,
emoji, emoji,
@ -87,7 +101,8 @@ impl WsHandler {
) )
.await?; .await?;
for room in rooms { for room in rooms {
bus.publish_room_event(room, "custom_status.updated", &data).await?; bus.publish_room_event(room, "custom_status.updated", &data)
.await?;
} }
Ok(Some(WsOutEvent::CustomStatusUpdated { data })) Ok(Some(WsOutEvent::CustomStatusUpdated { data }))
} }

View File

@ -4,8 +4,8 @@ use uuid::Uuid;
use crate::event::{RoomInfo, UserInfo, reaction}; use crate::event::{RoomInfo, UserInfo, reaction};
use crate::{ChannelBus, ChannelError, ChannelResult}; use crate::{ChannelBus, ChannelError, ChannelResult};
use super::WsOutEvent;
use super::WsHandler; use super::WsHandler;
use super::WsOutEvent;
impl WsHandler { impl WsHandler {
pub(super) async fn reaction_add( pub(super) async fn reaction_add(
@ -39,7 +39,10 @@ impl WsHandler {
.lookup_user(user_id) .lookup_user(user_id)
.await .await
.unwrap_or_else(|_| UserInfo::unknown(user_id)); .unwrap_or_else(|_| UserInfo::unknown(user_id));
let rct_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); let rct_room = bus
.lookup_room(room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(room));
let data = reaction::ReactionAddedService { let data = reaction::ReactionAddedService {
id: Uuid::now_v7(), id: Uuid::now_v7(),
room: rct_room, room: rct_room,
@ -48,8 +51,12 @@ impl WsHandler {
emoji, emoji,
created_at: Utc::now(), created_at: Utc::now(),
}; };
bus.publish_room_event(room, "reaction.added", &data).await?; bus.publish_room_event(room, "reaction.added", &data)
Ok(Some(WsOutEvent::ReactionAdded { room: data.room.clone(), data })) .await?;
Ok(Some(WsOutEvent::ReactionAdded {
room: data.room.clone(),
data,
}))
} }
pub(super) async fn reaction_remove( pub(super) async fn reaction_remove(
@ -76,7 +83,10 @@ impl WsHandler {
.lookup_user(user_id) .lookup_user(user_id)
.await .await
.unwrap_or_else(|_| UserInfo::unknown(user_id)); .unwrap_or_else(|_| UserInfo::unknown(user_id));
let rct_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); let rct_room = bus
.lookup_room(room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(room));
let data = reaction::ReactionRemovedService { let data = reaction::ReactionRemovedService {
id: Uuid::now_v7(), id: Uuid::now_v7(),
room: rct_room, room: rct_room,
@ -85,7 +95,11 @@ impl WsHandler {
emoji, emoji,
removed_at: Utc::now(), removed_at: Utc::now(),
}; };
bus.publish_room_event(room, "reaction.removed", &data).await?; bus.publish_room_event(room, "reaction.removed", &data)
Ok(Some(WsOutEvent::ReactionRemoved { room: data.room.clone(), data })) .await?;
Ok(Some(WsOutEvent::ReactionRemoved {
room: data.room.clone(),
data,
}))
} }
} }

View File

@ -4,9 +4,9 @@ use uuid::Uuid;
use crate::event::{RoomInfo, UserInfo, WorkspaceInfo, member, rooms}; use crate::event::{RoomInfo, UserInfo, WorkspaceInfo, member, rooms};
use crate::{ChannelBus, ChannelError, ChannelResult}; use crate::{ChannelBus, ChannelError, ChannelResult};
use super::{MAX_ROOM_NAME_LEN}; use super::MAX_ROOM_NAME_LEN;
use super::WsOutEvent;
use super::WsHandler; use super::WsHandler;
use super::WsOutEvent;
impl WsHandler { impl WsHandler {
pub(super) async fn room_get( pub(super) async fn room_get(
@ -17,7 +17,7 @@ impl WsHandler {
Self::ensure_room_access(bus, user_id, room).await?; Self::ensure_room_access(bus, user_id, room).await?;
let row = db::sqlx::query_as::<_, model::room::RoomModel>( let row = db::sqlx::query_as::<_, model::room::RoomModel>(
"SELECT id, wk, parent, name, topic, room_type, position, \ "SELECT id, wk, parent, name, topic, room_type, position, \
is_private, is_archived, created_by, created_at, updated_at, deleted_at \ is_private, is_archived, ai_enabled, created_by, created_at, updated_at, deleted_at \
FROM room WHERE id = $1 AND deleted_at IS NULL", FROM room WHERE id = $1 AND deleted_at IS NULL",
) )
.bind(room) .bind(room)
@ -33,6 +33,7 @@ impl WsHandler {
"room_type": row.room_type, "room_type": row.room_type,
"is_private": row.is_private, "is_private": row.is_private,
"is_archived": row.is_archived, "is_archived": row.is_archived,
"ai_enabled": row.ai_enabled,
"parent": row.parent, "parent": row.parent,
"created_by": row.created_by, "created_by": row.created_by,
"created_at": row.created_at, "created_at": row.created_at,
@ -47,22 +48,25 @@ impl WsHandler {
room_name: String, room_name: String,
public: bool, public: bool,
category: Option<Uuid>, category: Option<Uuid>,
ai_enabled: Option<bool>,
) -> ChannelResult<Option<WsOutEvent>> { ) -> ChannelResult<Option<WsOutEvent>> {
if room_name.is_empty() || room_name.len() > MAX_ROOM_NAME_LEN { if room_name.is_empty() || room_name.len() > MAX_ROOM_NAME_LEN {
return Err(ChannelError::Validation("invalid room name".into())); return Err(ChannelError::Validation("invalid room name".into()));
} }
Self::ensure_workspace_member(bus, user_id, workspace).await?; Self::ensure_workspace_member(bus, user_id, workspace).await?;
let is_private = !public; let is_private = !public;
let ai = ai_enabled.unwrap_or(false);
let row = db::sqlx::query_as::<_, model::room::RoomModel>( let row = db::sqlx::query_as::<_, model::room::RoomModel>(
"INSERT INTO room (wk, parent, name, room_type, is_private, created_by, created_at, updated_at) \ "INSERT INTO room (wk, parent, name, room_type, is_private, ai_enabled, created_by, created_at, updated_at) \
VALUES ($1, $2, $3, 'channel', $4, $5, now(), now()) \ VALUES ($1, $2, $3, 'channel', $4, $5, $6, now(), now()) \
RETURNING id, wk, parent, name, topic, room_type, position, \ RETURNING id, wk, parent, name, topic, room_type, position, \
is_private, is_archived, created_by, created_at, updated_at, deleted_at", is_private, is_archived, ai_enabled, created_by, created_at, updated_at, deleted_at",
) )
.bind(workspace) .bind(workspace)
.bind(category) .bind(category)
.bind(&room_name) .bind(&room_name)
.bind(is_private) .bind(is_private)
.bind(ai)
.bind(user_id) .bind(user_id)
.fetch_one(bus.inner.db.writer()) .fetch_one(bus.inner.db.writer())
.await?; .await?;
@ -77,15 +81,25 @@ impl WsHandler {
.await?; .await?;
let data = rooms::RoomCreatedService { let data = rooms::RoomCreatedService {
room: RoomInfo::from_model(&row), room: RoomInfo::from_model(&row),
workspace: bus.lookup_workspace(workspace).await.unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)), workspace: bus
.lookup_workspace(workspace)
.await
.unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)),
public, public,
category, category,
created_by: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), created_by: bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id)),
created_at: row.created_at, created_at: row.created_at,
}; };
bus.publish_room_event(row.id, "room.created", &data).await?; bus.publish_room_event(row.id, "room.created", &data)
.await?;
bus.room_changed(row.id).await?; bus.room_changed(row.id).await?;
Ok(Some(WsOutEvent::RoomCreated { room: data.room.clone(), data })) Ok(Some(WsOutEvent::RoomCreated {
room: data.room.clone(),
data,
}))
} }
pub(super) async fn room_update( pub(super) async fn room_update(
@ -95,56 +109,74 @@ impl WsHandler {
room_name: Option<String>, room_name: Option<String>,
public: Option<bool>, public: Option<bool>,
category: Option<Uuid>, category: Option<Uuid>,
ai_enabled: Option<bool>,
) -> ChannelResult<Option<WsOutEvent>> { ) -> ChannelResult<Option<WsOutEvent>> {
Self::ensure_room_access(bus, user_id, room).await?; Self::ensure_room_access(bus, user_id, room).await?;
let old = db::sqlx::query_as::<_, model::room::RoomModel>( let old = db::sqlx::query_as::<_, model::room::RoomModel>(
"SELECT id, wk, parent, name, topic, room_type, position, \ "SELECT id, wk, parent, name, topic, room_type, position, \
is_private, is_archived, created_by, created_at, updated_at, deleted_at \ is_private, is_archived, ai_enabled, created_by, created_at, updated_at, deleted_at \
FROM room WHERE id = $1 AND deleted_at IS NULL", FROM room WHERE id = $1 AND deleted_at IS NULL",
) )
.bind(room) .bind(room)
.fetch_one(bus.inner.db.reader()) .fetch_one(bus.inner.db.reader())
.await?; .await?;
let new_name = room_name.unwrap_or(old.name.clone()); let new_name = room_name.unwrap_or(old.name.clone());
let new_private = let new_private = public.map(|p| !p).unwrap_or(old.is_private);
public.map(|p| !p).unwrap_or(old.is_private);
let new_category = category.or(old.parent); let new_category = category.or(old.parent);
let new_ai = ai_enabled.unwrap_or(old.ai_enabled);
let row = db::sqlx::query_as::<_, model::room::RoomModel>( let row = db::sqlx::query_as::<_, model::room::RoomModel>(
"UPDATE room SET name = $2, is_private = $3, parent = $4, updated_at = now() \ "UPDATE room SET name = $2, is_private = $3, parent = $4, ai_enabled = $5, updated_at = now() \
WHERE id = $1 AND deleted_at IS NULL \ WHERE id = $1 AND deleted_at IS NULL \
RETURNING id, wk, parent, name, topic, room_type, position, \ RETURNING id, wk, parent, name, topic, room_type, position, \
is_private, is_archived, created_by, created_at, updated_at, deleted_at", is_private, is_archived, ai_enabled, created_by, created_at, updated_at, deleted_at",
) )
.bind(room) .bind(room)
.bind(&new_name) .bind(&new_name)
.bind(new_private) .bind(new_private)
.bind(new_category) .bind(new_category)
.bind(new_ai)
.fetch_one(bus.inner.db.writer()) .fetch_one(bus.inner.db.writer())
.await?; .await?;
let mut renamed = false; let mut renamed = false;
if new_name != old.name { if new_name != old.name {
let data = rooms::RoomRenamedService { let data = rooms::RoomRenamedService {
room: RoomInfo::from_model(&row), room: RoomInfo::from_model(&row),
workspace: bus.lookup_workspace(row.wk).await.unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)), workspace: bus
.lookup_workspace(row.wk)
.await
.unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)),
old_name: old.name.clone(), old_name: old.name.clone(),
new_name: new_name, new_name: new_name,
renamed_by: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), renamed_by: bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id)),
renamed_at: Utc::now(), renamed_at: Utc::now(),
}; };
bus.publish_room_event(room, "room.renamed", &data).await?; bus.publish_room_event(room, "room.renamed", &data).await?;
renamed = true; renamed = true;
} }
if new_private != old.is_private || new_category != old.parent { if new_private != old.is_private
|| new_category != old.parent
|| new_ai != old.ai_enabled
{
let data = rooms::RoomSettingsUpdatedService { let data = rooms::RoomSettingsUpdatedService {
room: RoomInfo::from_model(&row), room: RoomInfo::from_model(&row),
workspace: bus.lookup_workspace(row.wk).await.unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)), workspace: bus
.lookup_workspace(row.wk)
.await
.unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)),
slowmode_seconds: None, slowmode_seconds: None,
nsfw: false, nsfw: false,
default_auto_archive_duration: None, default_auto_archive_duration: None,
updated_by: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), updated_by: bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id)),
updated_at: Utc::now(), updated_at: Utc::now(),
}; };
bus.publish_room_event(room, "room.settings_updated", &data).await?; bus.publish_room_event(room, "room.settings_updated", &data)
.await?;
} }
bus.room_changed(room).await?; bus.room_changed(room).await?;
if renamed { if renamed {
@ -152,15 +184,30 @@ impl WsHandler {
room: RoomInfo::from_model(&row), room: RoomInfo::from_model(&row),
data: rooms::RoomRenamedService { data: rooms::RoomRenamedService {
room: RoomInfo::from_model(&row), room: RoomInfo::from_model(&row),
workspace: bus.lookup_workspace(row.wk).await.unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)), workspace: bus
.lookup_workspace(row.wk)
.await
.unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)),
old_name: old.name, old_name: old.name,
new_name: row.name, new_name: row.name,
renamed_by: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), renamed_by: bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id)),
renamed_at: Utc::now(), renamed_at: Utc::now(),
}, },
})); }));
} }
Ok(None) Ok(Some(WsOutEvent::Response {
request_id: Uuid::nil(),
data: serde_json::json!({
"id": row.id,
"name": row.name,
"is_private": row.is_private,
"ai_enabled": row.ai_enabled,
"parent": row.parent,
}),
}))
} }
pub(super) async fn room_delete( pub(super) async fn room_delete(
@ -191,13 +238,22 @@ impl WsHandler {
.await?; .await?;
let data = rooms::RoomDeletedService { let data = rooms::RoomDeletedService {
room: RoomInfo::from_model(&row), room: RoomInfo::from_model(&row),
workspace: bus.lookup_workspace(row.wk).await.unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)), workspace: bus
deleted_by: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), .lookup_workspace(row.wk)
.await
.unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)),
deleted_by: bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id)),
deleted_at: Utc::now(), deleted_at: Utc::now(),
}; };
bus.publish_room_event(room, "room.deleted", &data).await?; bus.publish_room_event(room, "room.deleted", &data).await?;
bus.room_changed(room).await?; bus.room_changed(room).await?;
Ok(Some(WsOutEvent::RoomDeleted { room: data.room.clone(), data })) Ok(Some(WsOutEvent::RoomDeleted {
room: data.room.clone(),
data,
}))
} }
pub(super) async fn access_grant( pub(super) async fn access_grant(
@ -217,8 +273,14 @@ impl WsHandler {
.bind(target_user) .bind(target_user)
.execute(bus.inner.db.writer()) .execute(bus.inner.db.writer())
.await?; .await?;
let mj_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); let mj_room = bus
let mj_user = bus.lookup_user(target_user).await.unwrap_or_else(|_| UserInfo::unknown(target_user)); .lookup_room(room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(room));
let mj_user = bus
.lookup_user(target_user)
.await
.unwrap_or_else(|_| UserInfo::unknown(target_user));
let data = member::MemberJoinedService { let data = member::MemberJoinedService {
room: mj_room, room: mj_room,
user: mj_user, user: mj_user,
@ -227,7 +289,10 @@ impl WsHandler {
}; };
bus.publish_room_event(room, "member.joined", &data).await?; bus.publish_room_event(room, "member.joined", &data).await?;
bus.room_changed(room).await?; bus.room_changed(room).await?;
Ok(Some(WsOutEvent::MemberJoined { room: data.room.clone(), data })) Ok(Some(WsOutEvent::MemberJoined {
room: data.room.clone(),
data,
}))
} }
pub(super) async fn access_revoke( pub(super) async fn access_revoke(
@ -245,17 +310,30 @@ impl WsHandler {
.bind(target_user) .bind(target_user)
.execute(bus.inner.db.writer()) .execute(bus.inner.db.writer())
.await?; .await?;
let mr_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); let mr_room = bus
let mr_target = bus.lookup_user(target_user).await.unwrap_or_else(|_| UserInfo::unknown(target_user)); .lookup_room(room)
let mr_remover = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); .await
.unwrap_or_else(|_| RoomInfo::unknown(room));
let mr_target = bus
.lookup_user(target_user)
.await
.unwrap_or_else(|_| UserInfo::unknown(target_user));
let mr_remover = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = member::MemberRemovedService { let data = member::MemberRemovedService {
room: mr_room, room: mr_room,
user: mr_target, user: mr_target,
removed_by: mr_remover, removed_by: mr_remover,
removed_at: Utc::now(), removed_at: Utc::now(),
}; };
bus.publish_room_event(room, "member.removed", &data).await?; bus.publish_room_event(room, "member.removed", &data)
.await?;
bus.room_changed(room).await?; bus.room_changed(room).await?;
Ok(Some(WsOutEvent::MemberRemoved { room: data.room.clone(), data })) Ok(Some(WsOutEvent::MemberRemoved {
room: data.room.clone(),
data,
}))
} }
} }

View File

@ -6,8 +6,8 @@ use crate::{
search::{SearchEngine, SearchQuery}, search::{SearchEngine, SearchQuery},
}; };
use super::WsOutEvent;
use super::WsHandler; use super::WsHandler;
use super::WsOutEvent;
impl WsHandler { impl WsHandler {
pub(super) async fn search( pub(super) async fn search(
@ -36,18 +36,27 @@ impl WsHandler {
}) })
.await?; .await?;
let author_ids: Vec<Uuid> = result.hits.iter().map(|h| h.sender_id).collect(); let author_ids: Vec<Uuid> =
let message_ids: Vec<Uuid> = result.hits.iter().map(|h| h.message_id).collect(); result.hits.iter().map(|h| h.sender_id).collect();
let message_ids: Vec<Uuid> =
result.hits.iter().map(|h| h.message_id).collect();
let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default(); let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default();
let reactions = Self::reaction_groups_for_messages(bus, user_id, &message_ids) let reactions =
.await Self::reaction_groups_for_messages(bus, user_id, &message_ids)
.unwrap_or_default(); .await
.unwrap_or_default();
let search_room = match room { let search_room = match room {
Some(r) => Some(bus.lookup_room(r).await.unwrap_or_else(|_| RoomInfo::unknown(r))), Some(r) => Some(
bus.lookup_room(r)
.await
.unwrap_or_else(|_| RoomInfo::unknown(r)),
),
None => None, None => None,
}; };
let search_msg_room = search_room.clone().unwrap_or_else(|| RoomInfo::unknown(room.unwrap_or_default())); let search_msg_room = search_room
.clone()
.unwrap_or_else(|| RoomInfo::unknown(room.unwrap_or_default()));
let data = crate::event::search::SearchResultService { let data = crate::event::search::SearchResultService {
q, q,
room: search_room, room: search_room,
@ -60,26 +69,30 @@ impl WsHandler {
.cloned() .cloned()
.unwrap_or_else(|| UserInfo::unknown(h.sender_id)); .unwrap_or_else(|| UserInfo::unknown(h.sender_id));
crate::event::search::SearchMessageHitService { crate::event::search::SearchMessageHitService {
message: crate::event::message::MessageNewService { message: crate::event::message::MessageNewService {
id: h.message_id, id: h.message_id,
seq: 0, seq: 0,
room: search_msg_room.clone(), room: search_msg_room.clone(),
sender_type: "user".to_string(), sender_type: "user".to_string(),
sender, sender,
thread: None, thread: None,
in_reply_to: None, in_reply_to: None,
content: h.content.clone(), content: h.content.clone(),
content_type: "text".to_string(), content_type: "text".to_string(),
pinned: false, pinned: false,
system_type: None, system_type: None,
metadata: serde_json::Value::Null, metadata: serde_json::Value::Null,
thinking_content: None, thinking_content: None,
thinking_is_chunked: None, thinking_is_chunked: None,
send_at: h.send_at, send_at: h.send_at,
reactions: reactions.get(&h.message_id).cloned().unwrap_or_default(), reactions: reactions
}, .get(&h.message_id)
highlighted_content: h.highlighted, .cloned()
}}) .unwrap_or_default(),
},
highlighted_content: h.highlighted,
}
})
.collect(), .collect(),
total: result.total as i64, total: result.total as i64,
took_ms: 0, took_ms: 0,

View File

@ -4,8 +4,8 @@ use uuid::Uuid;
use crate::event::{RoomInfo, UserInfo, star}; use crate::event::{RoomInfo, UserInfo, star};
use crate::{ChannelBus, ChannelResult}; use crate::{ChannelBus, ChannelResult};
use super::WsOutEvent;
use super::WsHandler; use super::WsHandler;
use super::WsOutEvent;
impl WsHandler { impl WsHandler {
pub(super) async fn message_star( pub(super) async fn message_star(
@ -18,10 +18,14 @@ impl WsHandler {
Self::ensure_room_access(bus, user_id, room).await?; Self::ensure_room_access(bus, user_id, room).await?;
Self::ensure_message_in_room(bus, room, message).await?; Self::ensure_message_in_room(bus, room, message).await?;
let room_info = let room_info = bus
bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); .lookup_room(room)
let user_info = .await
bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); .unwrap_or_else(|_| RoomInfo::unknown(room));
let user_info = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
if do_star { if do_star {
let result = db::sqlx::query( let result = db::sqlx::query(
@ -76,7 +80,8 @@ impl WsHandler {
unstarred_by: user_info, unstarred_by: user_info,
unstarred_at: Utc::now(), unstarred_at: Utc::now(),
}; };
bus.emit_to_user(user_id, "message.unstarred", &data).await?; bus.emit_to_user(user_id, "message.unstarred", &data)
.await?;
Ok(Some(WsOutEvent::MessageUnstarred { Ok(Some(WsOutEvent::MessageUnstarred {
room: room_info, room: room_info,
data, data,
@ -123,7 +128,17 @@ impl WsHandler {
let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default(); let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default();
let mut entries = Vec::with_capacity(rows.len()); let mut entries = Vec::with_capacity(rows.len());
for (_star_id, msg_id, seq, content, content_type, author_id, starred_at, sent_at) in rows { for (
_star_id,
msg_id,
seq,
content,
content_type,
author_id,
starred_at,
sent_at,
) in rows
{
let msg_room_row: Option<(Uuid,)> = db::sqlx::query_as( let msg_room_row: Option<(Uuid,)> = db::sqlx::query_as(
"SELECT room FROM room_message WHERE id = $1", "SELECT room FROM room_message WHERE id = $1",
) )

View File

@ -4,8 +4,8 @@ use uuid::Uuid;
use crate::event::{RoomInfo, UserInfo, member}; use crate::event::{RoomInfo, UserInfo, member};
use crate::{ChannelBus, ChannelResult}; use crate::{ChannelBus, ChannelResult};
use super::WsOutEvent;
use super::WsHandler; use super::WsHandler;
use super::WsOutEvent;
impl WsHandler { impl WsHandler {
pub(super) async fn subscribe( pub(super) async fn subscribe(
@ -36,10 +36,18 @@ impl WsHandler {
let key = (room, user_id); let key = (room, user_id);
if action == "start" { if action == "start" {
let ty_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); let ty_room = bus
let ty_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); .lookup_room(room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(room));
let ty_user = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let already_typing = bus.inner.typing_states.contains_key(&key); let already_typing = bus.inner.typing_states.contains_key(&key);
if let Some((_, (_, _, old_cancel))) = bus.inner.typing_states.remove(&key) { if let Some((_, (_, _, old_cancel))) =
bus.inner.typing_states.remove(&key)
{
old_cancel.cancel(); old_cancel.cancel();
} }
@ -48,13 +56,18 @@ impl WsHandler {
let bus_clone = bus.clone(); let bus_clone = bus.clone();
let user_clone = ty_user.clone(); let user_clone = ty_user.clone();
let room_clone = ty_room.clone(); let room_clone = ty_room.clone();
bus.inner.typing_states.insert(key, (ty_user.clone(), ty_room.clone(), cancel)); bus.inner
.typing_states
.insert(key, (ty_user.clone(), ty_room.clone(), cancel));
tokio::spawn(async move { tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_secs(10)).await; tokio::time::sleep(std::time::Duration::from_secs(10)).await;
if cancel_clone.is_cancelled() { if cancel_clone.is_cancelled() {
return; return;
} }
bus_clone.inner.typing_states.remove(&(room_clone.id, user_clone.id)); bus_clone
.inner
.typing_states
.remove(&(room_clone.id, user_clone.id));
let room_id = room_clone.id; let room_id = room_clone.id;
let stop_data = member::TypingStopService { let stop_data = member::TypingStopService {
room: room_clone, room: room_clone,
@ -62,7 +75,9 @@ impl WsHandler {
sender_type: "user".to_string(), sender_type: "user".to_string(),
stopped_at: Utc::now(), stopped_at: Utc::now(),
}; };
let _ = bus_clone.publish_room_event(room_id, "typing.stop", &stop_data).await; let _ = bus_clone
.publish_room_event(room_id, "typing.stop", &stop_data)
.await;
}); });
if !already_typing { if !already_typing {
let data = member::TypingStartService { let data = member::TypingStartService {
@ -72,16 +87,27 @@ impl WsHandler {
started_at: Utc::now(), started_at: Utc::now(),
}; };
bus.publish_room_event(room, "typing.start", &data).await?; bus.publish_room_event(room, "typing.start", &data).await?;
return Ok(Some(WsOutEvent::TypingStart { room: data.room.clone(), data })); return Ok(Some(WsOutEvent::TypingStart {
room: data.room.clone(),
data,
}));
} }
Ok(None) Ok(None)
} else { } else {
if let Some((_, (_, _, cancel))) = bus.inner.typing_states.remove(&key) { if let Some((_, (_, _, cancel))) =
bus.inner.typing_states.remove(&key)
{
cancel.cancel(); cancel.cancel();
} }
let ty_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); let ty_room = bus
let ty_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); .lookup_room(room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(room));
let ty_user = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = member::TypingStopService { let data = member::TypingStopService {
room: ty_room, room: ty_room,
@ -90,7 +116,10 @@ impl WsHandler {
stopped_at: Utc::now(), stopped_at: Utc::now(),
}; };
bus.publish_room_event(room, "typing.stop", &data).await?; bus.publish_room_event(room, "typing.stop", &data).await?;
Ok(Some(WsOutEvent::TypingStop { room: data.room.clone(), data })) Ok(Some(WsOutEvent::TypingStop {
room: data.room.clone(),
data,
}))
} }
} }
@ -117,8 +146,14 @@ impl WsHandler {
.bind(last_read_seq) .bind(last_read_seq)
.execute(bus.inner.db.writer()) .execute(bus.inner.db.writer())
.await?; .await?;
let rr_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); let rr_room = bus
let rr_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); .lookup_room(room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(room));
let rr_user = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = member::ReadReceiptService { let data = member::ReadReceiptService {
room: rr_room.clone(), room: rr_room.clone(),
user: rr_user, user: rr_user,

View File

@ -4,8 +4,8 @@ use uuid::Uuid;
use crate::event::{RoomInfo, UserInfo, thread}; use crate::event::{RoomInfo, UserInfo, thread};
use crate::{ChannelBus, ChannelError, ChannelResult}; use crate::{ChannelBus, ChannelError, ChannelResult};
use super::WsOutEvent;
use super::WsHandler; use super::WsHandler;
use super::WsOutEvent;
/// Helper struct for thread_list JOIN query result /// Helper struct for thread_list JOIN query result
#[derive(db::sqlx::FromRow)] #[derive(db::sqlx::FromRow)]
@ -48,9 +48,13 @@ impl WsHandler {
let mut items = Vec::new(); let mut items = Vec::new();
for row in rows { for row in rows {
let tc_room = bus.lookup_room(row.room).await let tc_room = bus
.lookup_room(row.room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(row.room)); .unwrap_or_else(|_| RoomInfo::unknown(row.room));
let created_by = bus.lookup_user(row.created_by).await let created_by = bus
.lookup_user(row.created_by)
.await
.unwrap_or_else(|_| UserInfo::unknown(row.created_by)); .unwrap_or_else(|_| UserInfo::unknown(row.created_by));
// Get last message preview // Get last message preview
let preview: Option<(String,)> = db::sqlx::query_as( let preview: Option<(String,)> = db::sqlx::query_as(
@ -110,8 +114,14 @@ impl WsHandler {
.bind(user_id) .bind(user_id)
.fetch_one(bus.inner.db.writer()) .fetch_one(bus.inner.db.writer())
.await?; .await?;
let tc_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); let tc_room = bus
let created_by = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); .lookup_room(room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(room));
let created_by = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = thread::ThreadCreatedService { let data = thread::ThreadCreatedService {
id: row.id, id: row.id,
room: tc_room, room: tc_room,
@ -120,8 +130,12 @@ impl WsHandler {
participants: serde_json::Value::Null, participants: serde_json::Value::Null,
created_at: row.created_at, created_at: row.created_at,
}; };
bus.publish_room_event(room, "thread.created", &data).await?; bus.publish_room_event(room, "thread.created", &data)
Ok(Some(WsOutEvent::ThreadCreated { room: data.room.clone(), data })) .await?;
Ok(Some(WsOutEvent::ThreadCreated {
room: data.room.clone(),
data,
}))
} }
pub(super) async fn thread_resolve( pub(super) async fn thread_resolve(
@ -129,13 +143,12 @@ impl WsHandler {
user_id: Uuid, user_id: Uuid,
thread_id: Uuid, thread_id: Uuid,
) -> ChannelResult<Option<WsOutEvent>> { ) -> ChannelResult<Option<WsOutEvent>> {
let existing: (Uuid,) = db::sqlx::query_as( let existing: (Uuid,) =
"SELECT room FROM room_thread WHERE id = $1", db::sqlx::query_as("SELECT room FROM room_thread WHERE id = $1")
) .bind(thread_id)
.bind(thread_id) .fetch_optional(bus.inner.db.reader())
.fetch_optional(bus.inner.db.reader()) .await?
.await? .ok_or(ChannelError::RoomNotFound)?;
.ok_or(ChannelError::RoomNotFound)?;
Self::ensure_room_access(bus, user_id, existing.0).await?; Self::ensure_room_access(bus, user_id, existing.0).await?;
let row = db::sqlx::query_as::<_, model::room::RoomThreadModel>( let row = db::sqlx::query_as::<_, model::room::RoomThreadModel>(
"UPDATE room_thread SET locked = true, updated_at = now() \ "UPDATE room_thread SET locked = true, updated_at = now() \
@ -146,16 +159,26 @@ impl WsHandler {
.bind(thread_id) .bind(thread_id)
.fetch_one(bus.inner.db.writer()) .fetch_one(bus.inner.db.writer())
.await?; .await?;
let tr_room = bus.lookup_room(row.room).await.unwrap_or_else(|_| RoomInfo::unknown(row.room)); let tr_room = bus
let resolved_by = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); .lookup_room(row.room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(row.room));
let resolved_by = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = thread::ThreadResolvedService { let data = thread::ThreadResolvedService {
id: row.id, id: row.id,
room: tr_room, room: tr_room,
resolved_by, resolved_by,
resolved_at: Utc::now(), resolved_at: Utc::now(),
}; };
bus.publish_room_event(row.room, "thread.resolved", &data).await?; bus.publish_room_event(row.room, "thread.resolved", &data)
Ok(Some(WsOutEvent::ThreadResolved { room: data.room.clone(), data })) .await?;
Ok(Some(WsOutEvent::ThreadResolved {
room: data.room.clone(),
data,
}))
} }
pub(super) async fn thread_archive( pub(super) async fn thread_archive(
@ -163,13 +186,12 @@ impl WsHandler {
user_id: Uuid, user_id: Uuid,
thread_id: Uuid, thread_id: Uuid,
) -> ChannelResult<Option<WsOutEvent>> { ) -> ChannelResult<Option<WsOutEvent>> {
let existing: (Uuid,) = db::sqlx::query_as( let existing: (Uuid,) =
"SELECT room FROM room_thread WHERE id = $1", db::sqlx::query_as("SELECT room FROM room_thread WHERE id = $1")
) .bind(thread_id)
.bind(thread_id) .fetch_optional(bus.inner.db.reader())
.fetch_optional(bus.inner.db.reader()) .await?
.await? .ok_or(ChannelError::RoomNotFound)?;
.ok_or(ChannelError::RoomNotFound)?;
Self::ensure_room_access(bus, user_id, existing.0).await?; Self::ensure_room_access(bus, user_id, existing.0).await?;
let row = db::sqlx::query_as::<_, model::room::RoomThreadModel>( let row = db::sqlx::query_as::<_, model::room::RoomThreadModel>(
"UPDATE room_thread SET archived = true, archived_at = now(), updated_at = now() \ "UPDATE room_thread SET archived = true, archived_at = now(), updated_at = now() \
@ -180,15 +202,25 @@ impl WsHandler {
.bind(thread_id) .bind(thread_id)
.fetch_one(bus.inner.db.writer()) .fetch_one(bus.inner.db.writer())
.await?; .await?;
let ta_room = bus.lookup_room(row.room).await.unwrap_or_else(|_| RoomInfo::unknown(row.room)); let ta_room = bus
let archived_by = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); .lookup_room(row.room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(row.room));
let archived_by = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = thread::ThreadArchivedService { let data = thread::ThreadArchivedService {
id: row.id, id: row.id,
room: ta_room, room: ta_room,
archived_by, archived_by,
archived_at: Utc::now(), archived_at: Utc::now(),
}; };
bus.publish_room_event(row.room, "thread.archived", &data).await?; bus.publish_room_event(row.room, "thread.archived", &data)
Ok(Some(WsOutEvent::ThreadArchived { room: data.room.clone(), data })) .await?;
Ok(Some(WsOutEvent::ThreadArchived {
room: data.room.clone(),
data,
}))
} }
} }

View File

@ -2,8 +2,8 @@ use uuid::Uuid;
use crate::{ChannelBus, ChannelError, ChannelResult}; use crate::{ChannelBus, ChannelError, ChannelResult};
use super::WsOutEvent;
use super::WsHandler; use super::WsHandler;
use super::WsOutEvent;
impl WsHandler { impl WsHandler {
pub(super) async fn user_summary( pub(super) async fn user_summary(

View File

@ -4,8 +4,8 @@ use uuid::Uuid;
use crate::event::{RoomInfo, UserInfo, voice}; use crate::event::{RoomInfo, UserInfo, voice};
use crate::{ChannelBus, ChannelResult}; use crate::{ChannelBus, ChannelResult};
use super::WsOutEvent;
use super::WsHandler; use super::WsHandler;
use super::WsOutEvent;
impl WsHandler { impl WsHandler {
pub(super) async fn voice_join( pub(super) async fn voice_join(
@ -14,8 +14,14 @@ impl WsHandler {
room: Uuid, room: Uuid,
) -> ChannelResult<Option<WsOutEvent>> { ) -> ChannelResult<Option<WsOutEvent>> {
Self::ensure_room_access(bus, user_id, room).await?; Self::ensure_room_access(bus, user_id, room).await?;
let vj_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); let vj_room = bus
let vj_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); .lookup_room(room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(room));
let vj_user = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = voice::VoiceChannelJoinedService { let data = voice::VoiceChannelJoinedService {
room: vj_room, room: vj_room,
workspace: None, workspace: None,
@ -25,8 +31,12 @@ impl WsHandler {
video: false, video: false,
joined_at: Utc::now(), joined_at: Utc::now(),
}; };
bus.publish_room_event(room, "voice.channel_joined", &data).await?; bus.publish_room_event(room, "voice.channel_joined", &data)
Ok(Some(WsOutEvent::VoiceChannelJoined { room: data.room.clone(), data })) .await?;
Ok(Some(WsOutEvent::VoiceChannelJoined {
room: data.room.clone(),
data,
}))
} }
pub(super) async fn voice_leave( pub(super) async fn voice_leave(
@ -35,16 +45,26 @@ impl WsHandler {
room: Uuid, room: Uuid,
) -> ChannelResult<Option<WsOutEvent>> { ) -> ChannelResult<Option<WsOutEvent>> {
Self::ensure_room_access(bus, user_id, room).await?; Self::ensure_room_access(bus, user_id, room).await?;
let vl_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); let vl_room = bus
let vl_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); .lookup_room(room)
.await
.unwrap_or_else(|_| RoomInfo::unknown(room));
let vl_user = bus
.lookup_user(user_id)
.await
.unwrap_or_else(|_| UserInfo::unknown(user_id));
let data = voice::VoiceChannelLeftService { let data = voice::VoiceChannelLeftService {
room: vl_room, room: vl_room,
workspace: None, workspace: None,
user: vl_user, user: vl_user,
left_at: Utc::now(), left_at: Utc::now(),
}; };
bus.publish_room_event(room, "voice.channel_left", &data).await?; bus.publish_room_event(room, "voice.channel_left", &data)
Ok(Some(WsOutEvent::VoiceChannelLeft { room: data.room.clone(), data })) .await?;
Ok(Some(WsOutEvent::VoiceChannelLeft {
room: data.room.clone(),
data,
}))
} }
pub(super) async fn voice_mute( pub(super) async fn voice_mute(

View File

@ -2,10 +2,9 @@ use serde::Serialize;
use uuid::Uuid; use uuid::Uuid;
use crate::event::{ use crate::event::{
RoomInfo, WorkspaceInfo, RoomInfo, WorkspaceInfo, attachment, ban, category, conversation, draft,
ai, attachment, ban, category, conversation, dm, draft, forward, invite, forward, invite, member, message, message_read, notify, pin, presence,
member, message, message_read, notify, pin, presence, reaction, rooms, reaction, rooms, search, star, thread, voice, workspace,
search, star, thread, voice, workspace,
}; };
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
@ -188,22 +187,6 @@ pub enum WsOutEvent {
UserUnbanned { UserUnbanned {
data: ban::UnbannedService, data: ban::UnbannedService,
}, },
AiAgentJoined {
room: RoomInfo,
data: ai::AiAgentJoinedService,
},
AiAgentLeft {
room: RoomInfo,
data: ai::AiAgentLeftService,
},
AiAgentList {
room: RoomInfo,
data: ai::RoomAiListService,
},
AiAgentStatusChanged {
room: RoomInfo,
data: ai::AiAgentStatusChangedService,
},
VoiceChannelJoined { VoiceChannelJoined {
room: RoomInfo, room: RoomInfo,
data: voice::VoiceChannelJoinedService, data: voice::VoiceChannelJoinedService,
@ -235,21 +218,6 @@ pub enum WsOutEvent {
ConversationList { ConversationList {
data: Vec<conversation::ConversationSummary>, data: Vec<conversation::ConversationSummary>,
}, },
DmCreated {
room: RoomInfo,
data: dm::DmCreatedService,
},
DmClosed {
room: RoomInfo,
data: dm::DmClosedService,
},
DmReopened {
room: RoomInfo,
data: dm::DmReopenedService,
},
DmList {
data: Vec<dm::DmCreatedService>,
},
MessageRead { MessageRead {
room: RoomInfo, room: RoomInfo,
data: message_read::MessageReadService, data: message_read::MessageReadService,

View File

@ -60,12 +60,14 @@ pub enum WsInMessage {
room_name: String, room_name: String,
public: bool, public: bool,
category: Option<Uuid>, category: Option<Uuid>,
ai_enabled: Option<bool>,
}, },
RoomUpdate { RoomUpdate {
room: Uuid, room: Uuid,
room_name: Option<String>, room_name: Option<String>,
public: Option<bool>, public: Option<bool>,
category: Option<Uuid>, category: Option<Uuid>,
ai_enabled: Option<bool>,
}, },
RoomDelete { RoomDelete {
room: Uuid, room: Uuid,
@ -212,20 +214,6 @@ pub enum WsInMessage {
room: Uuid, room: Uuid,
start: bool, start: bool,
}, },
AiList {
room: Uuid,
},
AiUpsert {
room: Uuid,
model: Uuid,
},
AiDelete {
room: Uuid,
agent_id: Uuid,
},
AiStop {
room: Uuid,
},
UserSummary { UserSummary {
username: String, username: String,
}, },
@ -242,13 +230,6 @@ pub enum WsInMessage {
notify_level: String, notify_level: String,
}, },
ConversationList, ConversationList,
DmCreate {
recipient: Uuid,
},
DmClose {
room: Uuid,
},
DmList,
MessageMarkRead { MessageMarkRead {
room: Uuid, room: Uuid,
message_ids: Vec<Uuid>, message_ids: Vec<Uuid>,
@ -312,14 +293,9 @@ impl WsInMessage {
VoiceMute, VoiceMute,
VoiceDeaf, VoiceDeaf,
ScreenShare, ScreenShare,
AiList,
AiUpsert,
AiDelete,
AiStop,
ConversationPin, ConversationPin,
ConversationMute, ConversationMute,
ConversationNotifyLevel, ConversationNotifyLevel,
DmClose,
MessageMarkRead, MessageMarkRead,
MessageStar, MessageStar,
) )

View File

@ -47,11 +47,7 @@ async fn handle_inbound(bus: &ChannelBus, socket: &Socket, data: EventPayload) {
let parsed = payload; let parsed = payload;
let text = serde_json::to_string(payload).unwrap_or_default(); let text = serde_json::to_string(payload).unwrap_or_default();
if parsed if parsed.get("type").and_then(|t| t.as_str()) == Some("ping") {
.get("type")
.and_then(|t| t.as_str())
== Some("ping")
{
let pong = WsOutEvent::Pong { let pong = WsOutEvent::Pong {
protocol_version: super::types::WS_PROTOCOL_VERSION, protocol_version: super::types::WS_PROTOCOL_VERSION,
}; };
@ -115,7 +111,8 @@ async fn handle_inbound(bus: &ChannelBus, socket: &Socket, data: EventPayload) {
code: 400, code: 400,
error: "parse_error".to_string(), error: "parse_error".to_string(),
message: e.to_string(), message: e.to_string(),
}).unwrap_or_default(), })
.unwrap_or_default(),
}; };
send_event(socket, &err_resp).await; send_event(socket, &err_resp).await;
} else { } else {
@ -126,7 +123,8 @@ async fn handle_inbound(bus: &ChannelBus, socket: &Socket, data: EventPayload) {
error: "parse_error".to_string(), error: "parse_error".to_string(),
message: e.to_string(), message: e.to_string(),
}, },
).await; )
.await;
} }
} }
} }

View File

@ -17,6 +17,7 @@ mod security;
mod seq; mod seq;
mod token; mod token;
use crate::event::UserInfo;
pub use ack::{AckRequest, AckResponse, AckStatus, AckTracker, MessageAck}; pub use ack::{AckRequest, AckResponse, AckStatus, AckTracker, MessageAck};
pub use bus::ChannelBus; pub use bus::ChannelBus;
pub use cdn::{CdnManager, CdnStoredFile}; pub use cdn::{CdnManager, CdnStoredFile};
@ -37,3 +38,14 @@ pub use seq::SeqAllocator;
pub use token::{ pub use token::{
ChannelAccessToken, ChannelTokenApply, ChannelTokenContext, TOKEN_TTL_SECS, ChannelAccessToken, ChannelTokenApply, ChannelTokenContext, TOKEN_TTL_SECS,
}; };
use uuid::Uuid;
lazy_static::lazy_static! {
pub static ref REDPADA: UserInfo = UserInfo {
id: Uuid::nil(),
username: "RedPanda".to_string(),
display_name: "RedPanda".to_string(),
avatar_url: "".to_string(),
};
}

View File

@ -4,8 +4,8 @@ use uuid::Uuid;
use model::room::RoomMessageModel; use model::room::RoomMessageModel;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::rooms::RM_COLUMNS;
use crate::ChannelResult; use crate::ChannelResult;
use crate::rooms::RM_COLUMNS;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientState { pub struct ClientState {

View File

@ -6,8 +6,7 @@ use uuid::Uuid;
use crate::{ChannelBusConfig, ChannelResult}; use crate::{ChannelBusConfig, ChannelResult};
pub(crate) const RM_COLUMNS: &str = pub(crate) const RM_COLUMNS: &str = "id, room, seq, thread, parent, author, content, content_type, pinned, \
"id, room, seq, thread, parent, author, content, content_type, pinned, \
system_type, metadata, edited_at, created_at, updated_at, deleted_at"; system_type, metadata, edited_at, created_at, updated_at, deleted_at";
pub(crate) fn room_socket_name(room: Uuid) -> String { pub(crate) fn room_socket_name(room: Uuid) -> String {
@ -24,6 +23,7 @@ pub struct RoomListItem {
pub topic: Option<String>, pub topic: Option<String>,
pub room_type: String, pub room_type: String,
pub is_private: bool, pub is_private: bool,
pub ai_enabled: bool,
pub category: Option<Uuid>, pub category: Option<Uuid>,
pub workspace_id: Uuid, pub workspace_id: Uuid,
} }
@ -44,8 +44,20 @@ pub async fn user_rooms_for_api(
return Ok(Vec::new()); return Ok(Vec::new());
} }
let rows = sqlx::query_as::<_, (Uuid, String, Option<String>, String, bool, Option<Uuid>, Uuid)>( let rows = sqlx::query_as::<
"SELECT id, name, topic, room_type, is_private, parent, wk \ _,
(
Uuid,
String,
Option<String>,
String,
bool,
bool,
Option<Uuid>,
Uuid,
),
>(
"SELECT id, name, topic, room_type, is_private, ai_enabled, parent, wk \
FROM room \ FROM room \
WHERE id = ANY($1) AND deleted_at IS NULL AND is_archived = false \ WHERE id = ANY($1) AND deleted_at IS NULL AND is_archived = false \
ORDER BY name", ORDER BY name",
@ -56,15 +68,27 @@ pub async fn user_rooms_for_api(
Ok(rows Ok(rows
.into_iter() .into_iter()
.map(|(id, name, topic, room_type, is_private, category, workspace_id)| RoomListItem { .map(
id, |(
name, id,
topic, name,
room_type, topic,
is_private, room_type,
category, is_private,
workspace_id, ai_enabled,
}) category,
workspace_id,
)| RoomListItem {
id,
name,
topic,
room_type,
is_private,
ai_enabled,
category,
workspace_id,
},
)
.collect()) .collect())
} }
pub async fn user_categories_for_api( pub async fn user_categories_for_api(
@ -179,14 +203,14 @@ pub(crate) async fn catchup_messages(
room: Uuid, room: Uuid,
after_seq: i64, after_seq: i64,
) -> ChannelResult<Vec<RoomMessageModel>> { ) -> ChannelResult<Vec<RoomMessageModel>> {
let rows = sqlx::query_as::<_, RoomMessageModel>( let rows = sqlx::query_as::<_, RoomMessageModel>(db::sqlx::AssertSqlSafe(
db::sqlx::AssertSqlSafe(format!( format!(
"SELECT {RM_COLUMNS} FROM room_message \ "SELECT {RM_COLUMNS} FROM room_message \
WHERE room = $1 AND seq > $2 AND deleted_at IS NULL \ WHERE room = $1 AND seq > $2 AND deleted_at IS NULL \
ORDER BY seq ASC \ ORDER BY seq ASC \
LIMIT $3" LIMIT $3"
)), ),
) ))
.bind(room) .bind(room)
.bind(after_seq) .bind(after_seq)
.bind(config.catchup_limit) .bind(config.catchup_limit)

View File

@ -142,7 +142,7 @@ return 0
} }
} }
pub(crate) fn require_cluster( pub fn require_cluster(
cache: &cache::AppCache, cache: &cache::AppCache,
) -> ChannelResult<&cache::ClusterCache> { ) -> ChannelResult<&cache::ClusterCache> {
cache cache

View File

@ -128,7 +128,12 @@ impl SeqAllocator {
} }
if state if state
.next .next
.compare_exchange_weak(current, current + 1, Ordering::AcqRel, Ordering::Acquire) .compare_exchange_weak(
current,
current + 1,
Ordering::AcqRel,
Ordering::Acquire,
)
.is_ok() .is_ok()
{ {
return Some(current); return Some(current);

View File

@ -184,7 +184,9 @@ impl ChannelBus {
.arg(&session_key) .arg(&session_key)
.query_async(&mut conn) .query_async(&mut conn)
.await .await
.map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; .map_err(|e| {
ChannelError::Cache(cache::CacheError::Redis(e))
})?;
let device_id = hash_data let device_id = hash_data
.get("device_id") .get("device_id")
@ -252,9 +254,8 @@ impl ChannelBus {
created_at, created_at,
}; };
let new_token_bytes = new_payload.encode(&signing_key)?; let new_token_bytes = new_payload.encode(&signing_key)?;
let new_access_token = let new_access_token = base64::engine::general_purpose::URL_SAFE_NO_PAD
base64::engine::general_purpose::URL_SAFE_NO_PAD .encode(&new_token_bytes);
.encode(&new_token_bytes);
let new_session_key = let new_session_key =
self.session_hash_key(&payload.user_id, created_at); self.session_hash_key(&payload.user_id, created_at);

Some files were not shown because too many files have changed in this diff Show More