Compare commits
13 Commits
bba35f1b2c
...
3e540a5302
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3e540a5302 | ||
|
|
c7cee8c344 | ||
|
|
08045eef63 | ||
|
|
abcfc5b3bb | ||
|
|
5b81e7d774 | ||
|
|
4ba47370be | ||
|
|
27b9d3e4bd | ||
|
|
e9d5407c66 | ||
|
|
009ccee72b | ||
|
|
6a60d02263 | ||
|
|
395832118e | ||
|
|
867f216a1f | ||
|
|
907b5ee3bf |
168
Cargo.lock
generated
168
Cargo.lock
generated
@ -452,6 +452,7 @@ dependencies = [
|
||||
"models",
|
||||
"once_cell",
|
||||
"qdrant-client",
|
||||
"redis",
|
||||
"regex",
|
||||
"reqwest 0.13.2",
|
||||
"rig-core",
|
||||
@ -560,6 +561,19 @@ version = "0.2.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
|
||||
|
||||
[[package]]
|
||||
name = "ammonia"
|
||||
version = "4.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "17e913097e1a2124b46746c980134e8c954bc17a6a59bb3fde96f088d126dde6"
|
||||
dependencies = [
|
||||
"cssparser",
|
||||
"html5ever",
|
||||
"maplit",
|
||||
"tendril",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "android_system_properties"
|
||||
version = "0.1.5"
|
||||
@ -1991,6 +2005,29 @@ dependencies = [
|
||||
"hybrid-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cssparser"
|
||||
version = "0.35.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4e901edd733a1472f944a45116df3f846f54d37e67e68640ac8bb69689aca2aa"
|
||||
dependencies = [
|
||||
"cssparser-macros",
|
||||
"dtoa-short",
|
||||
"itoa",
|
||||
"phf",
|
||||
"smallvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cssparser-macros"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "13b588ba4ac1a99f7f2964d24b3d896ddc6bf847ee3855dbd4366f058cfcd331"
|
||||
dependencies = [
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "csv"
|
||||
version = "1.4.0"
|
||||
@ -2396,6 +2433,21 @@ version = "0.15.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
|
||||
|
||||
[[package]]
|
||||
name = "dtoa"
|
||||
version = "1.0.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4c3cf4824e2d5f025c7b531afcb2325364084a16806f6d47fbc1f5fbd9960590"
|
||||
|
||||
[[package]]
|
||||
name = "dtoa-short"
|
||||
version = "0.3.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cd1511a7b6a56299bd043a9c167a6d2bfb37bf84a6dfceaba651168adfb43c87"
|
||||
dependencies = [
|
||||
"dtoa",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dyn-clone"
|
||||
version = "1.0.20"
|
||||
@ -2906,6 +2958,16 @@ version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
|
||||
|
||||
[[package]]
|
||||
name = "futf"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df420e2e84819663797d1ec6544b13c5be84629e7bb00dc960d6917db2987843"
|
||||
dependencies = [
|
||||
"mac",
|
||||
"new_debug_unreachable",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures"
|
||||
version = "0.3.32"
|
||||
@ -3603,6 +3665,17 @@ dependencies = [
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "html5ever"
|
||||
version = "0.35.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "55d958c2f74b664487a2035fe1dadb032c48718a03b63f3ab0b8537db8549ed4"
|
||||
dependencies = [
|
||||
"log",
|
||||
"markup5ever",
|
||||
"match_token",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "0.2.12"
|
||||
@ -4738,6 +4811,12 @@ dependencies = [
|
||||
"sha2 0.10.9",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mac"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4"
|
||||
|
||||
[[package]]
|
||||
name = "mac_address"
|
||||
version = "1.1.8"
|
||||
@ -4749,6 +4828,34 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "maplit"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d"
|
||||
|
||||
[[package]]
|
||||
name = "markup5ever"
|
||||
version = "0.35.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "311fe69c934650f8f19652b3946075f0fc41ad8757dbb68f1ca14e7900ecc1c3"
|
||||
dependencies = [
|
||||
"log",
|
||||
"tendril",
|
||||
"web_atoms",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "match_token"
|
||||
version = "0.35.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac84fd3f360fcc43dc5f5d186f02a94192761a080e8bc58621ad4d12296a58cf"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "matchers"
|
||||
version = "0.2.0"
|
||||
@ -5981,6 +6088,12 @@ dependencies = [
|
||||
"zerocopy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "precomputed-hash"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
|
||||
|
||||
[[package]]
|
||||
name = "prettyplease"
|
||||
version = "0.2.37"
|
||||
@ -6947,6 +7060,7 @@ name = "room"
|
||||
version = "0.2.9"
|
||||
dependencies = [
|
||||
"agent",
|
||||
"ammonia",
|
||||
"anyhow",
|
||||
"chrono",
|
||||
"config",
|
||||
@ -8349,6 +8463,31 @@ version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82"
|
||||
|
||||
[[package]]
|
||||
name = "string_cache"
|
||||
version = "0.8.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bf776ba3fa74f83bf4b63c3dcbbf82173db2632ed8452cb2d891d33f459de70f"
|
||||
dependencies = [
|
||||
"new_debug_unreachable",
|
||||
"parking_lot",
|
||||
"phf_shared",
|
||||
"precomputed-hash",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "string_cache_codegen"
|
||||
version = "0.5.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c711928715f1fe0fe509c53b43e993a9a557babc2d0a3567d0a3006f1ac931a0"
|
||||
dependencies = [
|
||||
"phf_generator",
|
||||
"phf_shared",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "stringprep"
|
||||
version = "0.1.5"
|
||||
@ -8512,6 +8651,17 @@ dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tendril"
|
||||
version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d24a120c5fc464a3458240ee02c299ebcb9d67b5249c8848b09d639dca8d7bb0"
|
||||
dependencies = [
|
||||
"futf",
|
||||
"mac",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.69"
|
||||
@ -9194,6 +9344,12 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "utf-8"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||
|
||||
[[package]]
|
||||
name = "utf8_iter"
|
||||
version = "1.0.4"
|
||||
@ -9476,6 +9632,18 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "web_atoms"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57ffde1dc01240bdf9992e3205668b235e59421fd085e8a317ed98da0178d414"
|
||||
dependencies = [
|
||||
"phf",
|
||||
"phf_codegen",
|
||||
"string_cache",
|
||||
"string_cache_codegen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webhook"
|
||||
version = "0.2.9"
|
||||
|
||||
@ -42,5 +42,6 @@ rust_decimal = { workspace = true }
|
||||
reqwest = { workspace = true, features = ["json"] }
|
||||
utoipa = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
redis = { workspace = true, features = ["tokio-comp"] }
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@ -118,6 +118,7 @@ pub async fn record_ai_usage(
|
||||
let new_balance = project_billing.balance - total_cost;
|
||||
let mut updated: project_billing::ActiveModel = project_billing.into();
|
||||
updated.balance = Set(new_balance);
|
||||
updated.updated_at = Set(now);
|
||||
updated.update(&txn).await?;
|
||||
|
||||
txn.commit().await?;
|
||||
@ -183,8 +184,10 @@ pub async fn record_ai_usage(
|
||||
.await?;
|
||||
|
||||
let new_balance = workspace_billing.balance - total_cost;
|
||||
let new_total_spent = workspace_billing.total_spent + total_cost;
|
||||
let mut updated: workspace_billing::ActiveModel = workspace_billing.into();
|
||||
updated.balance = Set(new_balance);
|
||||
updated.total_spent = Set(new_total_spent);
|
||||
updated.updated_at = Set(now);
|
||||
updated.update(&txn).await?;
|
||||
|
||||
|
||||
@ -78,5 +78,7 @@ pub enum Mention {
|
||||
|
||||
pub mod context;
|
||||
pub mod service;
|
||||
pub mod state;
|
||||
pub use context::{AiContextSenderType, RoomMessageContext};
|
||||
pub use service::ChatService;
|
||||
pub use state::{AgentRuntime, AgentState};
|
||||
|
||||
@ -3,6 +3,7 @@ use models::projects::project_skill;
|
||||
use models::rooms::room_ai;
|
||||
use rig::agent::{AgentBuilder, MultiTurnStreamItem};
|
||||
use rig::client::CompletionClient;
|
||||
use rig::completion::{CompletionModel, GetTokenUsage, Prompt};
|
||||
use rig::streaming::{StreamedAssistantContent, StreamingPrompt};
|
||||
use sea_orm::*;
|
||||
use std::pin::Pin;
|
||||
@ -48,6 +49,7 @@ pub struct ProcessResult {
|
||||
|
||||
/// Record an AI session with cost calculation.
|
||||
async fn record_ai_session(
|
||||
cache: &db::cache::AppCache,
|
||||
db: &db::database::AppDatabase,
|
||||
project_id: Uuid,
|
||||
session_id: Uuid,
|
||||
@ -58,6 +60,28 @@ async fn record_ai_session(
|
||||
output_tokens: i64,
|
||||
latency_ms: i64,
|
||||
) {
|
||||
metrics::histogram!("ai_call_latency_ms", "model" => model_id.to_string()).record(latency_ms as f64);
|
||||
|
||||
let session = models::ai::ai_session::ActiveModel {
|
||||
id: Set(session_id),
|
||||
room: Set(room_id),
|
||||
model: Set(model_id),
|
||||
version: Set(version_id),
|
||||
token_input: Set(input_tokens),
|
||||
token_output: Set(output_tokens),
|
||||
latency_ms: Set(Some(latency_ms)),
|
||||
cost: Set(None),
|
||||
currency: Set(None),
|
||||
error_message: Set(None),
|
||||
error_code: Set(None),
|
||||
created_at: Set(chrono::Utc::now()),
|
||||
};
|
||||
|
||||
if let Err(e) = session.insert(db).await {
|
||||
tracing::error!(error = %e, session_id = %session_id, "failed to insert ai session record");
|
||||
return;
|
||||
}
|
||||
|
||||
let (cost, currency, error_msg) = match billing::record_ai_usage(
|
||||
db,
|
||||
project_id,
|
||||
@ -71,33 +95,25 @@ async fn record_ai_session(
|
||||
(Some(record.cost), Some(record.currency), None)
|
||||
}
|
||||
Ok(billing::BillingResult::InsufficientBalance { message }) => {
|
||||
// Create system message for insufficient balance
|
||||
create_system_message(db, room_id, &message).await;
|
||||
create_system_message(cache, db, room_id, &message).await;
|
||||
(None, None, Some(message))
|
||||
}
|
||||
Err(_) => (None, None, None),
|
||||
Err(e) => (None, None, Some(e.to_string())),
|
||||
};
|
||||
|
||||
let _ = models::ai::ai_session::ActiveModel {
|
||||
id: Set(session_id),
|
||||
room: Set(room_id),
|
||||
model: Set(model_id),
|
||||
version: Set(version_id),
|
||||
token_input: Set(input_tokens),
|
||||
token_output: Set(output_tokens),
|
||||
latency_ms: Set(Some(latency_ms)),
|
||||
cost: Set(cost),
|
||||
currency: Set(currency),
|
||||
error_message: Set(error_msg),
|
||||
error_code: Set(None),
|
||||
created_at: Set(chrono::Utc::now()),
|
||||
}
|
||||
.insert(db)
|
||||
.await;
|
||||
use sea_orm::sea_query::Expr;
|
||||
let _ = models::ai::ai_session::Entity::update_many()
|
||||
.col_expr(models::ai::ai_session::Column::Cost, Expr::value(cost))
|
||||
.col_expr(models::ai::ai_session::Column::Currency, Expr::value(currency))
|
||||
.col_expr(models::ai::ai_session::Column::ErrorMessage, Expr::value(error_msg))
|
||||
.filter(models::ai::ai_session::Column::Id.eq(session_id))
|
||||
.exec(db)
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Create a system message in the room for billing errors.
|
||||
async fn create_system_message(
|
||||
cache: &db::cache::AppCache,
|
||||
db: &db::database::AppDatabase,
|
||||
room_id: Uuid,
|
||||
message: &str,
|
||||
@ -105,26 +121,40 @@ async fn create_system_message(
|
||||
use models::rooms::{room_message, MessageSenderType, MessageContentType};
|
||||
use sea_orm::Set;
|
||||
|
||||
// Get next sequence number - we don't have cache here, so we query directly
|
||||
let last_seq = match room_message::Entity::find()
|
||||
.filter(room_message::Column::Room.eq(room_id))
|
||||
.order_by_desc(room_message::Column::Seq)
|
||||
.one(db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(m)) => m.seq,
|
||||
Ok(None) => 0,
|
||||
let seq_key = format!("room:seq:{}", room_id);
|
||||
let seq = match cache.conn().await {
|
||||
Ok(mut conn) => {
|
||||
match redis::cmd("INCR").arg(&seq_key).query_async::<i64>(&mut conn).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "cache INCR failed for system message seq, falling back to DB");
|
||||
let last_seq = match room_message::Entity::find()
|
||||
.filter(room_message::Column::Room.eq(room_id))
|
||||
.order_by_desc(room_message::Column::Seq)
|
||||
.one(db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(m)) => m.seq,
|
||||
Ok(None) => 0,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Failed to get last seq for system message");
|
||||
return;
|
||||
}
|
||||
};
|
||||
last_seq + 1
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Failed to get last seq for system message");
|
||||
tracing::warn!(error = %e, "Failed to get Redis connection for system message seq");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let seq = last_seq + 1;
|
||||
let now = chrono::Utc::now();
|
||||
|
||||
let result = room_message::ActiveModel {
|
||||
id: Set(Uuid::new_v4()),
|
||||
id: Set(Uuid::now_v7()),
|
||||
seq: Set(seq),
|
||||
room: Set(room_id),
|
||||
sender_type: Set(MessageSenderType::System),
|
||||
@ -269,7 +299,7 @@ impl ChatService {
|
||||
let mut tool_depth = 0;
|
||||
let mut input_tokens = 0i64;
|
||||
let mut output_tokens = 0i64;
|
||||
let session_id = Uuid::new_v4();
|
||||
let session_id = Uuid::now_v7();
|
||||
let session_start = std::time::Instant::now();
|
||||
let version_id = room_ai.as_ref().and_then(|r| r.version);
|
||||
|
||||
@ -464,6 +494,7 @@ impl ChatService {
|
||||
};
|
||||
// Record session
|
||||
record_ai_session(
|
||||
&request.cache,
|
||||
&request.db,
|
||||
request.project.id,
|
||||
session_id,
|
||||
@ -486,6 +517,7 @@ impl ChatService {
|
||||
|
||||
// Record session
|
||||
record_ai_session(
|
||||
&request.cache,
|
||||
&request.db,
|
||||
request.project.id,
|
||||
session_id,
|
||||
@ -536,7 +568,7 @@ impl ChatService {
|
||||
let mut tool_depth = 0;
|
||||
let mut total_input_tokens = 0i64;
|
||||
let mut total_output_tokens = 0i64;
|
||||
let session_id = Uuid::new_v4();
|
||||
let session_id = Uuid::now_v7();
|
||||
let session_start = std::time::Instant::now();
|
||||
|
||||
let version_id = room_ai.as_ref().and_then(|r| r.version);
|
||||
@ -860,6 +892,7 @@ impl ChatService {
|
||||
});
|
||||
// Record session
|
||||
record_ai_session(
|
||||
&request.cache,
|
||||
&request.db,
|
||||
request.project.id,
|
||||
session_id,
|
||||
@ -897,6 +930,7 @@ impl ChatService {
|
||||
});
|
||||
// Record session
|
||||
record_ai_session(
|
||||
&request.cache,
|
||||
&request.db,
|
||||
request.project.id,
|
||||
session_id,
|
||||
@ -934,22 +968,70 @@ impl ChatService {
|
||||
|
||||
let mut processed_history = Vec::new();
|
||||
if let Some(compact_service) = &self.compact_service {
|
||||
let config = CompactConfig::default();
|
||||
match compact_service
|
||||
.compact_room_auto(request.room.id, Some(request.user_names.clone()), config)
|
||||
.await
|
||||
{
|
||||
Ok(compact_summary) => {
|
||||
if !compact_summary.summary.is_empty() {
|
||||
let compact_cache_key = format!("ai:compact:{}", request.room.id);
|
||||
let compact_config = CompactConfig::default();
|
||||
|
||||
// Try cached compaction summary (avoids re-compacting same history)
|
||||
let cached_summary: Option<String> = {
|
||||
let conn_result = request.cache.conn().await;
|
||||
match conn_result {
|
||||
Ok(mut conn) => {
|
||||
redis::cmd("GET")
|
||||
.arg(&compact_cache_key)
|
||||
.query_async::<Option<String>>(&mut conn)
|
||||
.await
|
||||
.unwrap_or(None)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "compact cache: conn failed");
|
||||
None
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(cached_json) = cached_summary {
|
||||
if let Ok(summary) = serde_json::from_str::<crate::compact::CompactSummary>(&cached_json) {
|
||||
if !summary.summary.is_empty() {
|
||||
messages.push(ChatRequestMessage::system(format!(
|
||||
"Conversation summary:\n{}",
|
||||
compact_summary.summary
|
||||
summary.summary
|
||||
)));
|
||||
}
|
||||
processed_history = compact_summary.retained;
|
||||
processed_history = summary.retained;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "conversation compaction failed, using full history");
|
||||
}
|
||||
|
||||
if processed_history.is_empty() {
|
||||
match compact_service
|
||||
.compact_room_auto(request.room.id, Some(request.user_names.clone()), compact_config)
|
||||
.await
|
||||
{
|
||||
Ok(compact_summary) => {
|
||||
if !compact_summary.summary.is_empty() {
|
||||
messages.push(ChatRequestMessage::system(format!(
|
||||
"Conversation summary:\n{}",
|
||||
compact_summary.summary
|
||||
)));
|
||||
}
|
||||
// Cache for subsequent calls (5 min TTL)
|
||||
if let Ok(json) = serde_json::to_string(&compact_summary) {
|
||||
if let Ok(mut conn) = request.cache.conn().await {
|
||||
let _ = redis::cmd("SETEX")
|
||||
.arg(&compact_cache_key)
|
||||
.arg(300u64)
|
||||
.arg(&json)
|
||||
.query_async::<()>(&mut conn)
|
||||
.await
|
||||
.inspect_err(|e| {
|
||||
tracing::warn!(error = %e, "compact cache: SETEX failed");
|
||||
});
|
||||
}
|
||||
}
|
||||
processed_history = compact_summary.retained;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "conversation compaction failed, using full history");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1186,9 +1268,10 @@ impl ChatService {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn process_react<C>(&self, request: &AiRequest, mut on_chunk: C) -> Result<(String, i64, i64)>
|
||||
pub async fn process_react<C, Fut>(&self, request: &AiRequest, mut on_chunk: C) -> Result<(String, i64, i64)>
|
||||
where
|
||||
C: FnMut(crate::react::ReactStep) + Send,
|
||||
C: FnMut(crate::react::ReactStep) -> Fut + Send,
|
||||
Fut: std::future::Future<Output = ()> + Send,
|
||||
{
|
||||
let base_url = self
|
||||
.ai_base_url
|
||||
@ -1207,7 +1290,7 @@ impl ChatService {
|
||||
let room_id = request.room.id;
|
||||
let sender_uid = request.sender.uid;
|
||||
let project_id = request.project.id;
|
||||
let session_id = Uuid::new_v4();
|
||||
let session_id = Uuid::now_v7();
|
||||
let session_start = std::time::Instant::now();
|
||||
let version_id = room_ai::Entity::find()
|
||||
.filter(room_ai::Column::Room.eq(request.room.id))
|
||||
@ -1274,7 +1357,8 @@ impl ChatService {
|
||||
on_chunk(ReactStep::Answer {
|
||||
step: step_count,
|
||||
answer: t.clone(),
|
||||
});
|
||||
})
|
||||
.await;
|
||||
final_content.push_str(&t);
|
||||
}
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
||||
@ -1286,7 +1370,8 @@ impl ChatService {
|
||||
on_chunk(ReactStep::Thought {
|
||||
step: step_count,
|
||||
thought: reasoning_text,
|
||||
});
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
||||
@ -1297,7 +1382,8 @@ impl ChatService {
|
||||
on_chunk(ReactStep::Thought {
|
||||
step: step_count,
|
||||
thought: reasoning,
|
||||
});
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
||||
@ -1313,7 +1399,8 @@ impl ChatService {
|
||||
on_chunk(ReactStep::Action {
|
||||
step: step_count,
|
||||
action: ReactAction::new(&tool_call.function.name, args),
|
||||
});
|
||||
})
|
||||
.await;
|
||||
}
|
||||
Ok(MultiTurnStreamItem::StreamUserItem(
|
||||
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
|
||||
@ -1323,7 +1410,8 @@ impl ChatService {
|
||||
on_chunk(ReactStep::Observation {
|
||||
step: step_count,
|
||||
observation: obs,
|
||||
});
|
||||
})
|
||||
.await;
|
||||
}
|
||||
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
|
||||
let usage = resp.usage();
|
||||
@ -1341,6 +1429,7 @@ impl ChatService {
|
||||
|
||||
let elapsed_ms = session_start.elapsed().as_millis() as i64;
|
||||
record_ai_session(
|
||||
&request.cache,
|
||||
&request.db,
|
||||
request.project.id,
|
||||
session_id,
|
||||
@ -1355,6 +1444,623 @@ impl ChatService {
|
||||
|
||||
Ok((final_content, total_input_tokens, total_output_tokens))
|
||||
}
|
||||
|
||||
// ── CoT (Chain-of-Thought) ────────────────────────────────────────────
|
||||
|
||||
/// Run a CoT (Chain-of-Thought) reasoning cycle — step-by-step reasoning with optional tool use.
|
||||
pub async fn process_cot<C, Fut>(&self, request: &AiRequest, mut on_chunk: C) -> Result<(String, i64, i64)>
|
||||
where
|
||||
C: FnMut(crate::modes::cot::CotStep) -> Fut + Send,
|
||||
Fut: std::future::Future<Output = ()> + Send,
|
||||
{
|
||||
let client_config = AiClientConfig::new(
|
||||
self.ai_api_key.clone().unwrap_or_default(),
|
||||
)
|
||||
.with_base_url(
|
||||
self.ai_base_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| "https://api.openai.com".into()),
|
||||
);
|
||||
let rig_client = client_config.build_rig_client();
|
||||
|
||||
let Some(registry) = &self.tool_registry else {
|
||||
return Err(AgentError::Internal("no tool registry registered".into()));
|
||||
};
|
||||
|
||||
let session_id = Uuid::now_v7();
|
||||
let session_start = std::time::Instant::now();
|
||||
let version_id = room_ai::Entity::find()
|
||||
.filter(room_ai::Column::Room.eq(request.room.id))
|
||||
.filter(room_ai::Column::Model.eq(request.model.id))
|
||||
.one(&request.db)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.and_then(|r| r.version);
|
||||
|
||||
let db = request.db.clone();
|
||||
let cache = request.cache.clone();
|
||||
let cfg = request.config.clone();
|
||||
let room_id = request.room.id;
|
||||
let sender_uid = request.sender.uid;
|
||||
let project_id = request.project.id;
|
||||
|
||||
let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
|
||||
for def in registry.definitions() {
|
||||
let name = def.name.clone();
|
||||
if let Some(handler) = registry.get(&name) {
|
||||
let adapter = crate::tool::RigToolAdapter::new(
|
||||
handler.clone(), def.clone(),
|
||||
db.clone(), cache.clone(), cfg.clone(),
|
||||
room_id, Some(sender_uid), project_id,
|
||||
);
|
||||
tools.push(Box::new(crate::tool::RecordingTool::new(
|
||||
Box::new(adapter), db.clone(), session_id, sender_uid,
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let model = rig_client.completion_model(&request.model.name);
|
||||
let agent = AgentBuilder::new(model)
|
||||
.preamble(crate::modes::cot::COT_SYSTEM_PROMPT)
|
||||
.tools(tools)
|
||||
.default_max_turns(request.max_tool_depth)
|
||||
.build();
|
||||
|
||||
let stream = agent
|
||||
.stream_prompt(&request.input)
|
||||
.with_history(Vec::new())
|
||||
.multi_turn(request.max_tool_depth)
|
||||
.await;
|
||||
|
||||
tokio::pin!(stream);
|
||||
|
||||
let mut final_content = String::new();
|
||||
let mut total_input_tokens: i64 = 0;
|
||||
let mut total_output_tokens: i64 = 0;
|
||||
|
||||
while let Some(item) = stream.next().await {
|
||||
match item {
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
||||
StreamedAssistantContent::Text(text),
|
||||
)) => {
|
||||
let t = text.text;
|
||||
on_chunk(crate::modes::cot::CotStep::Answer(t.clone())).await;
|
||||
final_content.push_str(&t);
|
||||
}
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
||||
StreamedAssistantContent::Reasoning(reasoning),
|
||||
)) => {
|
||||
let r = reasoning.reasoning.join("");
|
||||
if !r.is_empty() {
|
||||
on_chunk(crate::modes::cot::CotStep::Thought(r)).await;
|
||||
}
|
||||
}
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
||||
StreamedAssistantContent::ReasoningDelta { reasoning, .. },
|
||||
)) => {
|
||||
if !reasoning.is_empty() {
|
||||
on_chunk(crate::modes::cot::CotStep::Thought(reasoning)).await;
|
||||
}
|
||||
}
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
||||
StreamedAssistantContent::ToolCall { tool_call, .. },
|
||||
)) => {
|
||||
let args: serde_json::Value = match &tool_call.function.arguments {
|
||||
serde_json::Value::String(s) => {
|
||||
serde_json::from_str(s).unwrap_or(serde_json::Value::Null)
|
||||
}
|
||||
v => v.clone(),
|
||||
};
|
||||
on_chunk(crate::modes::cot::CotStep::Action {
|
||||
name: tool_call.function.name.clone(),
|
||||
args,
|
||||
}).await;
|
||||
}
|
||||
Ok(MultiTurnStreamItem::StreamUserItem(
|
||||
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
|
||||
)) => {
|
||||
let obs = tool_result_content_to_string(&tool_result.content);
|
||||
on_chunk(crate::modes::cot::CotStep::Observation(obs)).await;
|
||||
}
|
||||
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
|
||||
let usage = resp.usage();
|
||||
total_input_tokens = usage.input_tokens as i64;
|
||||
total_output_tokens = usage.output_tokens as i64;
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(AgentError::OpenAi(e.to_string()));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let elapsed_ms = session_start.elapsed().as_millis() as i64;
|
||||
record_ai_session(
|
||||
&request.cache, &request.db, request.project.id,
|
||||
session_id, request.room.id, request.model.id,
|
||||
version_id.unwrap_or_default(),
|
||||
total_input_tokens, total_output_tokens, elapsed_ms,
|
||||
).await;
|
||||
|
||||
Ok((final_content, total_input_tokens, total_output_tokens))
|
||||
}
|
||||
|
||||
// ── ReWOO (Plan → Execute → Synthesize) ───────────────────────────────
|
||||
|
||||
/// Run a ReWOO reasoning cycle: model plans tool calls, they are executed,
|
||||
/// then the model synthesises the final answer.
|
||||
pub async fn process_rewoo<C, Fut>(&self, request: &AiRequest, mut on_chunk: C) -> Result<(String, i64, i64)>
|
||||
where
|
||||
C: FnMut(crate::modes::rewoo::ReWooStep) -> Fut + Send,
|
||||
Fut: std::future::Future<Output = ()> + Send,
|
||||
{
|
||||
let client_config = AiClientConfig::new(
|
||||
self.ai_api_key.clone().unwrap_or_default(),
|
||||
)
|
||||
.with_base_url(
|
||||
self.ai_base_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| "https://api.openai.com".into()),
|
||||
);
|
||||
let rig_client = client_config.build_rig_client();
|
||||
|
||||
let Some(registry) = &self.tool_registry else {
|
||||
return Err(AgentError::Internal("no tool registry registered".into()));
|
||||
};
|
||||
|
||||
let session_id = Uuid::now_v7();
|
||||
let session_start = std::time::Instant::now();
|
||||
let version_id = room_ai::Entity::find()
|
||||
.filter(room_ai::Column::Room.eq(request.room.id))
|
||||
.filter(room_ai::Column::Model.eq(request.model.id))
|
||||
.one(&request.db)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.and_then(|r| r.version);
|
||||
|
||||
let mut total_input_tokens: i64 = 0;
|
||||
let mut total_output_tokens: i64 = 0;
|
||||
|
||||
let mut messages = self.build_messages(request).await?;
|
||||
messages.insert(0, crate::client::types::ChatRequestMessage::system(
|
||||
crate::modes::rewoo::REWOO_SYSTEM_PROMPT.to_string(),
|
||||
));
|
||||
let model = rig_client.completion_model(&request.model.name);
|
||||
|
||||
let plan_tools = {
|
||||
let db = request.db.clone();
|
||||
let cache = request.cache.clone();
|
||||
let cfg = request.config.clone();
|
||||
let room_id = request.room.id;
|
||||
let sender_uid = request.sender.uid;
|
||||
let project_id = request.project.id;
|
||||
|
||||
let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
|
||||
for def in registry.definitions() {
|
||||
let name = def.name.clone();
|
||||
if let Some(handler) = registry.get(&name) {
|
||||
let adapter = crate::tool::RigToolAdapter::new(
|
||||
handler.clone(),
|
||||
def.clone(),
|
||||
db.clone(),
|
||||
cache.clone(),
|
||||
cfg.clone(),
|
||||
room_id,
|
||||
Some(sender_uid),
|
||||
project_id,
|
||||
);
|
||||
tools.push(Box::new(crate::tool::RecordingTool::new(
|
||||
Box::new(adapter),
|
||||
db.clone(),
|
||||
session_id,
|
||||
sender_uid,
|
||||
)));
|
||||
}
|
||||
}
|
||||
tools
|
||||
};
|
||||
|
||||
let plan_agent = rig::agent::AgentBuilder::new(model)
|
||||
.preamble(crate::modes::rewoo::REWOO_SYSTEM_PROMPT)
|
||||
.tools(plan_tools)
|
||||
.default_max_turns(1)
|
||||
.build();
|
||||
|
||||
let plan_response = plan_agent
|
||||
.prompt(&request.input)
|
||||
.extended_details()
|
||||
.await
|
||||
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
|
||||
|
||||
total_input_tokens += plan_response.total_usage.input_tokens as i64;
|
||||
total_output_tokens += plan_response.total_usage.output_tokens as i64;
|
||||
|
||||
let plan = crate::modes::rewoo::extract_plan(&plan_response.output)
|
||||
.unwrap_or_default();
|
||||
|
||||
if plan.calls.is_empty() {
|
||||
on_chunk(crate::modes::rewoo::ReWooStep::Synthesis(plan_response.output.clone())).await;
|
||||
let elapsed_ms = session_start.elapsed().as_millis() as i64;
|
||||
record_ai_session(
|
||||
&request.cache, &request.db, request.project.id,
|
||||
session_id, request.room.id, request.model.id,
|
||||
version_id.unwrap_or_default(),
|
||||
total_input_tokens, total_output_tokens, elapsed_ms,
|
||||
).await;
|
||||
return Ok((plan_response.output, total_input_tokens, total_output_tokens));
|
||||
}
|
||||
|
||||
on_chunk(crate::modes::rewoo::ReWooStep::Plan {
|
||||
calls: plan.calls.clone(),
|
||||
raw: plan.raw_text,
|
||||
}).await;
|
||||
|
||||
// ── Phase 2: Execute all tool calls in parallel ───────────────────
|
||||
let mut tool_results: Vec<(String, String)> = Vec::new();
|
||||
let mut handles = Vec::new();
|
||||
|
||||
for call in &plan.calls {
|
||||
let ctx = crate::tool::ToolContext::new(
|
||||
request.db.clone(),
|
||||
request.cache.clone(),
|
||||
request.config.clone(),
|
||||
request.room.id,
|
||||
Some(request.sender.uid),
|
||||
)
|
||||
.with_project(request.project.id);
|
||||
if let Some(ref es) = self.embed_service {
|
||||
// ctx = ctx.with_embed_service(es.clone()); -- not clone-able via pattern, skip
|
||||
let _ = es;
|
||||
}
|
||||
|
||||
let call_id = call.step.to_string();
|
||||
let tool_name = call.tool.clone();
|
||||
let args = call.args.clone();
|
||||
let ctx_clone = ctx.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let executor = crate::tool::ToolExecutor::new();
|
||||
let agent_call = crate::tool::ToolCall {
|
||||
id: call_id,
|
||||
name: tool_name.clone(),
|
||||
arguments: args.to_string(),
|
||||
};
|
||||
let mut local_ctx = ctx_clone;
|
||||
let result = executor.execute_batch(vec![agent_call], &mut local_ctx).await;
|
||||
match result {
|
||||
Ok(results) => {
|
||||
for r in &results {
|
||||
match &r.result {
|
||||
crate::tool::ToolResult::Ok(v) => {
|
||||
return (tool_name, v.to_string());
|
||||
}
|
||||
crate::tool::ToolResult::Error(e) => {
|
||||
return (tool_name, format!("[Error: {}]", e));
|
||||
}
|
||||
}
|
||||
}
|
||||
(tool_name, "[No result]".to_string())
|
||||
}
|
||||
Err(e) => (tool_name, format!("[Execution error: {}]", e)),
|
||||
}
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
match handle.await {
|
||||
Ok((name, result)) => {
|
||||
on_chunk(crate::modes::rewoo::ReWooStep::Execution {
|
||||
tool_name: name.clone(),
|
||||
result: result.clone(),
|
||||
}).await;
|
||||
tool_results.push((name, result));
|
||||
}
|
||||
Err(e) => {
|
||||
let msg = format!("[Task panicked: {}]", e);
|
||||
on_chunk(crate::modes::rewoo::ReWooStep::Execution {
|
||||
tool_name: "unknown".into(),
|
||||
result: msg.clone(),
|
||||
}).await;
|
||||
tool_results.push(("unknown".into(), msg));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Phase 3: Synthesize ───────────────────────────────────────────
|
||||
let mut synth_messages = self.build_messages(request).await?;
|
||||
synth_messages.insert(0, crate::client::types::ChatRequestMessage::system(
|
||||
crate::modes::rewoo::REWOO_SYSTEM_PROMPT.to_string(),
|
||||
));
|
||||
|
||||
let results_summary: String = tool_results
|
||||
.iter()
|
||||
.map(|(name, res)| format!("- {}:\n{}", name, res))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
synth_messages.push(crate::client::types::ChatRequestMessage::system(format!(
|
||||
"## Tool Execution Results\n\nThe following tool calls were executed:\n\n{}\n\nNow synthesize your final answer based on these results.",
|
||||
results_summary
|
||||
)));
|
||||
synth_messages.push(crate::client::types::ChatRequestMessage::user(&request.input));
|
||||
|
||||
let preamble = synth_messages
|
||||
.iter()
|
||||
.find(|m| m.role == "system")
|
||||
.and_then(|m| m.content.as_deref())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let non_system: Vec<_> = synth_messages
|
||||
.iter()
|
||||
.filter(|m| m.role != "system")
|
||||
.map(|m| crate::client::to_rig_message(m))
|
||||
.collect();
|
||||
|
||||
let synth_model = rig_client.completion_model(&request.model.name);
|
||||
let synth_stream = synth_model
|
||||
.completion_request("")
|
||||
.preamble(preamble)
|
||||
.messages(non_system)
|
||||
.temperature(request.temperature as f64)
|
||||
.max_tokens(request.max_tokens as u64)
|
||||
.stream()
|
||||
.await
|
||||
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
|
||||
|
||||
use rig::streaming::StreamedAssistantContent;
|
||||
tokio::pin!(synth_stream);
|
||||
|
||||
let mut synthesis = String::new();
|
||||
while let Some(item) = synth_stream.next().await {
|
||||
match item {
|
||||
Ok(StreamedAssistantContent::Text(text)) => {
|
||||
let t = text.text;
|
||||
on_chunk(crate::modes::rewoo::ReWooStep::Synthesis(t.clone())).await;
|
||||
synthesis.push_str(&t);
|
||||
}
|
||||
Ok(StreamedAssistantContent::Final(response)) => {
|
||||
if let Some(usage) = response.token_usage() {
|
||||
total_input_tokens += usage.input_tokens as i64;
|
||||
total_output_tokens += usage.output_tokens as i64;
|
||||
}
|
||||
}
|
||||
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let elapsed_ms = session_start.elapsed().as_millis() as i64;
|
||||
record_ai_session(
|
||||
&request.cache, &request.db, request.project.id,
|
||||
session_id, request.room.id, request.model.id,
|
||||
version_id.unwrap_or_default(),
|
||||
total_input_tokens, total_output_tokens, elapsed_ms,
|
||||
).await;
|
||||
|
||||
Ok((synthesis, total_input_tokens, total_output_tokens))
|
||||
}
|
||||
|
||||
// ── Reflexion (Generate → Critique → Revise) ──────────────────────────
|
||||
|
||||
/// Run a Reflexion reasoning cycle: generate → critique → revise (up to 3 rounds).
|
||||
pub async fn process_reflexion<C, Fut>(
|
||||
&self,
|
||||
request: &AiRequest,
|
||||
mut on_chunk: C,
|
||||
max_cycles: usize,
|
||||
) -> Result<(String, i64, i64)>
|
||||
where
|
||||
C: FnMut(crate::modes::reflexion::ReflexionStep) -> Fut + Send,
|
||||
Fut: std::future::Future<Output = ()> + Send,
|
||||
{
|
||||
let client_config = AiClientConfig::new(
|
||||
self.ai_api_key.clone().unwrap_or_default(),
|
||||
)
|
||||
.with_base_url(
|
||||
self.ai_base_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| "https://api.openai.com".into()),
|
||||
);
|
||||
let rig_client = client_config.build_rig_client();
|
||||
let Some(registry) = &self.tool_registry else {
|
||||
return Err(AgentError::Internal("no tool registry registered".into()));
|
||||
};
|
||||
|
||||
let session_id = Uuid::now_v7();
|
||||
let session_start = std::time::Instant::now();
|
||||
let version_id = room_ai::Entity::find()
|
||||
.filter(room_ai::Column::Room.eq(request.room.id))
|
||||
.filter(room_ai::Column::Model.eq(request.model.id))
|
||||
.one(&request.db)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.and_then(|r| r.version);
|
||||
|
||||
let max_cycles = max_cycles.min(3);
|
||||
|
||||
let mut total_input_tokens: i64 = 0;
|
||||
let mut total_output_tokens: i64 = 0;
|
||||
let mut best_answer = String::new();
|
||||
|
||||
for cycle in 0..max_cycles {
|
||||
let mut messages = self.build_messages(request).await?;
|
||||
messages.insert(0, crate::client::types::ChatRequestMessage::system(
|
||||
crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT.to_string(),
|
||||
));
|
||||
|
||||
if cycle > 0 {
|
||||
messages.push(crate::client::types::ChatRequestMessage::system(format!(
|
||||
"This is cycle {} of the reflexion process. Your previous answer was:\n\n{}\n\nPlease critique and improve upon it.",
|
||||
cycle + 1,
|
||||
best_answer
|
||||
)));
|
||||
}
|
||||
|
||||
// Build tools for this cycle (not cloneable, so rebuild each iteration)
|
||||
let cycle_tools = build_rig_tools(
|
||||
registry, &request.db, &request.cache, &request.config,
|
||||
request.room.id, request.sender.uid, request.project.id, session_id,
|
||||
);
|
||||
|
||||
// ── Generate ──────────────────────────────────────────────
|
||||
let model = rig_client.completion_model(&request.model.name);
|
||||
let agent = rig::agent::AgentBuilder::new(model)
|
||||
.preamble(crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT)
|
||||
.tools(cycle_tools)
|
||||
.default_max_turns(request.max_tool_depth)
|
||||
.build();
|
||||
|
||||
let stream = agent
|
||||
.stream_prompt(&request.input)
|
||||
.with_history(Vec::new())
|
||||
.multi_turn(request.max_tool_depth)
|
||||
.await;
|
||||
|
||||
tokio::pin!(stream);
|
||||
let mut generated = String::new();
|
||||
|
||||
while let Some(item) = stream.next().await {
|
||||
match item {
|
||||
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
|
||||
rig::streaming::StreamedAssistantContent::Text(text),
|
||||
)) => {
|
||||
generated.push_str(&text.text);
|
||||
}
|
||||
Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => {
|
||||
let usage = resp.usage();
|
||||
total_input_tokens += usage.input_tokens as i64;
|
||||
total_output_tokens += usage.output_tokens as i64;
|
||||
}
|
||||
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
best_answer = generated.clone();
|
||||
on_chunk(crate::modes::reflexion::ReflexionStep::Generate(generated.clone())).await;
|
||||
|
||||
// If only 1 cycle, emit final and exit
|
||||
if max_cycles == 1 || cycle + 1 >= max_cycles {
|
||||
on_chunk(crate::modes::reflexion::ReflexionStep::Final(generated.clone())).await;
|
||||
break;
|
||||
}
|
||||
|
||||
// ── Self-critique ─────────────────────────────────────────
|
||||
let critique_messages = vec![
|
||||
crate::client::types::ChatRequestMessage::system(crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT),
|
||||
crate::client::types::ChatRequestMessage::system(format!(
|
||||
"Your previous answer was:\n\n{}", generated
|
||||
)),
|
||||
crate::client::types::ChatRequestMessage::user(crate::modes::reflexion::REFLEXION_CRITIQUE_PROMPT),
|
||||
];
|
||||
|
||||
let critique_result = crate::client::call_with_params(
|
||||
&critique_messages,
|
||||
&request.model.name,
|
||||
&client_config,
|
||||
request.temperature as f32,
|
||||
request.max_tokens as u32,
|
||||
None,
|
||||
None,
|
||||
Some("none"),
|
||||
).await?;
|
||||
|
||||
total_input_tokens += critique_result.input_tokens;
|
||||
total_output_tokens += critique_result.output_tokens;
|
||||
let critique = critique_result.content;
|
||||
on_chunk(crate::modes::reflexion::ReflexionStep::Critique(critique.clone())).await;
|
||||
|
||||
// ── Revise ───────────────────────────────────────────────
|
||||
let revise_messages = vec![
|
||||
crate::client::types::ChatRequestMessage::user(format!(
|
||||
"Your previous answer:\n\n{}\n\nYour self-critique:\n\n{}",
|
||||
generated, critique
|
||||
)),
|
||||
crate::client::types::ChatRequestMessage::user(crate::modes::reflexion::REFLEXION_REVISE_PROMPT),
|
||||
];
|
||||
|
||||
let revise_model = rig_client.completion_model(&request.model.name);
|
||||
let revise_stream = revise_model
|
||||
.completion_request("")
|
||||
.preamble(crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT.to_string())
|
||||
.messages(revise_messages.iter().map(|m| {
|
||||
crate::client::to_rig_message(m)
|
||||
}).collect::<Vec<_>>())
|
||||
.temperature(request.temperature as f64)
|
||||
.max_tokens(request.max_tokens as u64)
|
||||
.stream()
|
||||
.await
|
||||
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
|
||||
|
||||
tokio::pin!(revise_stream);
|
||||
let mut revised = String::new();
|
||||
|
||||
while let Some(item) = revise_stream.next().await {
|
||||
match item {
|
||||
Ok(rig::streaming::StreamedAssistantContent::Text(text)) => {
|
||||
revised.push_str(&text.text);
|
||||
}
|
||||
Ok(rig::streaming::StreamedAssistantContent::Final(response)) => {
|
||||
if let Some(usage) = response.token_usage() {
|
||||
total_input_tokens += usage.input_tokens as i64;
|
||||
total_output_tokens += usage.output_tokens as i64;
|
||||
}
|
||||
}
|
||||
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
best_answer = revised.clone();
|
||||
on_chunk(crate::modes::reflexion::ReflexionStep::Revise(revised.clone())).await;
|
||||
|
||||
// If last cycle, emit final
|
||||
if cycle + 1 >= max_cycles {
|
||||
on_chunk(crate::modes::reflexion::ReflexionStep::Final(revised.clone())).await;
|
||||
}
|
||||
}
|
||||
|
||||
let elapsed_ms = session_start.elapsed().as_millis() as i64;
|
||||
record_ai_session(
|
||||
&request.cache, &request.db, request.project.id,
|
||||
session_id, request.room.id, request.model.id,
|
||||
version_id.unwrap_or_default(),
|
||||
total_input_tokens, total_output_tokens, elapsed_ms,
|
||||
).await;
|
||||
|
||||
Ok((best_answer, total_input_tokens, total_output_tokens))
|
||||
}
|
||||
}
|
||||
|
||||
fn build_rig_tools(
|
||||
registry: &crate::tool::ToolRegistry,
|
||||
db: &db::database::AppDatabase,
|
||||
cache: &db::cache::AppCache,
|
||||
cfg: &config::AppConfig,
|
||||
room_id: uuid::Uuid,
|
||||
sender_uid: uuid::Uuid,
|
||||
project_id: uuid::Uuid,
|
||||
session_id: uuid::Uuid,
|
||||
) -> Vec<Box<dyn rig::tool::ToolDyn + 'static>> {
|
||||
let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
|
||||
for def in registry.definitions() {
|
||||
let name = def.name.clone();
|
||||
if let Some(handler) = registry.get(&name) {
|
||||
let adapter = crate::tool::RigToolAdapter::new(
|
||||
handler.clone(), def.clone(),
|
||||
db.clone(), cache.clone(), cfg.clone(),
|
||||
room_id, Some(sender_uid), project_id,
|
||||
);
|
||||
tools.push(Box::new(crate::tool::RecordingTool::new(
|
||||
Box::new(adapter), db.clone(), session_id, sender_uid,
|
||||
)));
|
||||
}
|
||||
}
|
||||
tools
|
||||
}
|
||||
|
||||
/// Extract text from rig's ToolResultContent, ignoring images.
|
||||
|
||||
217
libs/agent/chat/state.rs
Normal file
217
libs/agent/chat/state.rs
Normal file
@ -0,0 +1,217 @@
|
||||
//! Agent state machine — tracks lifecycle of a single AI agent invocation.
|
||||
//!
|
||||
//! States: Idle → Thinking → ToolCall → Thinking → ... → Answering | Error
|
||||
//! The Thinking ↔ ToolCall cycle repeats until max tool depth or final answer.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Current phase of an agent's execution lifecycle.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum AgentState {
|
||||
/// Agent is idle, waiting for input
|
||||
Idle,
|
||||
/// Agent is reasoning/thinking (may produce thinking chunks)
|
||||
Thinking {
|
||||
started_at: DateTime<Utc>,
|
||||
tool_depth: u32,
|
||||
},
|
||||
/// Agent is executing a tool call
|
||||
ToolCall {
|
||||
tool_name: String,
|
||||
started_at: DateTime<Utc>,
|
||||
},
|
||||
/// Agent is returning the final answer
|
||||
Answering {
|
||||
/// Accumulated answer content so far
|
||||
content_chars: u64,
|
||||
started_at: DateTime<Utc>,
|
||||
},
|
||||
/// Agent encountered a non-recoverable error
|
||||
Error {
|
||||
message: String,
|
||||
tool_depth: u32,
|
||||
},
|
||||
}
|
||||
|
||||
impl AgentState {
|
||||
pub fn is_terminal(&self) -> bool {
|
||||
matches!(self, AgentState::Answering { .. } | AgentState::Error { .. })
|
||||
}
|
||||
|
||||
pub fn is_idle(&self) -> bool {
|
||||
matches!(self, AgentState::Idle)
|
||||
}
|
||||
|
||||
pub fn current_phase(&self) -> &'static str {
|
||||
match self {
|
||||
AgentState::Idle => "idle",
|
||||
AgentState::Thinking { .. } => "thinking",
|
||||
AgentState::ToolCall { .. } => "tool_call",
|
||||
AgentState::Answering { .. } => "answering",
|
||||
AgentState::Error { .. } => "error",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// State machine for agent lifecycle transitions.
|
||||
pub struct AgentRuntime {
|
||||
state: AgentState,
|
||||
max_tool_depth: u32,
|
||||
current_depth: u32,
|
||||
}
|
||||
|
||||
impl AgentRuntime {
|
||||
pub fn new(max_tool_depth: u32) -> Self {
|
||||
Self {
|
||||
state: AgentState::Idle,
|
||||
max_tool_depth,
|
||||
current_depth: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn state(&self) -> &AgentState {
|
||||
&self.state
|
||||
}
|
||||
|
||||
/// Transition from Idle → Thinking
|
||||
pub fn start_thinking(&mut self) {
|
||||
debug_assert!(self.state.is_idle(), "must be Idle to start thinking");
|
||||
self.current_depth = 0;
|
||||
self.state = AgentState::Thinking {
|
||||
started_at: Utc::now(),
|
||||
tool_depth: 0,
|
||||
};
|
||||
}
|
||||
|
||||
/// Transition from Thinking → ToolCall (increments tool depth)
|
||||
pub fn start_tool_call(&mut self, tool_name: String) -> Result<(), &'static str> {
|
||||
if !matches!(self.state, AgentState::Thinking { .. }) {
|
||||
return Err("must be Thinking to start tool call");
|
||||
}
|
||||
if self.current_depth >= self.max_tool_depth {
|
||||
return Err("max tool depth reached");
|
||||
}
|
||||
self.state = AgentState::ToolCall {
|
||||
tool_name,
|
||||
started_at: Utc::now(),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Transition from ToolCall → Thinking (back after tool result)
|
||||
pub fn complete_tool_call(&mut self) -> Result<(), &'static str> {
|
||||
if !matches!(self.state, AgentState::ToolCall { .. }) {
|
||||
return Err("must be ToolCall to complete");
|
||||
}
|
||||
self.current_depth += 1;
|
||||
self.state = AgentState::Thinking {
|
||||
started_at: Utc::now(),
|
||||
tool_depth: self.current_depth,
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Transition to Answering (terminal)
|
||||
pub fn start_answer(&mut self) {
|
||||
self.state = AgentState::Answering {
|
||||
content_chars: 0,
|
||||
started_at: Utc::now(),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn append_answer(&mut self, content: &str) {
|
||||
if let AgentState::Answering { content_chars, .. } = &mut self.state {
|
||||
*content_chars += content.len() as u64;
|
||||
}
|
||||
}
|
||||
|
||||
/// Transition to Error (terminal)
|
||||
pub fn fail(&mut self, message: String) {
|
||||
self.state = AgentState::Error {
|
||||
message,
|
||||
tool_depth: self.current_depth,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_starts_idle() {
|
||||
let rt = AgentRuntime::new(10);
|
||||
assert!(rt.state().is_idle());
|
||||
assert_eq!(rt.state().current_phase(), "idle");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_idle_to_thinking() {
|
||||
let mut rt = AgentRuntime::new(10);
|
||||
rt.start_thinking();
|
||||
assert_eq!(rt.state().current_phase(), "thinking");
|
||||
assert!(!rt.state().is_terminal());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_thinking_to_tool_call_and_back() {
|
||||
let mut rt = AgentRuntime::new(10);
|
||||
rt.start_thinking();
|
||||
rt.start_tool_call("search".into()).unwrap();
|
||||
assert_eq!(rt.state().current_phase(), "tool_call");
|
||||
rt.complete_tool_call().unwrap();
|
||||
assert_eq!(rt.state().current_phase(), "thinking");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_thinking_to_answer() {
|
||||
let mut rt = AgentRuntime::new(10);
|
||||
rt.start_thinking();
|
||||
rt.start_answer();
|
||||
assert_eq!(rt.state().current_phase(), "answering");
|
||||
assert!(rt.state().is_terminal());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_append_answer_tracks_chars() {
|
||||
let mut rt = AgentRuntime::new(10);
|
||||
rt.start_thinking();
|
||||
rt.start_answer();
|
||||
rt.append_answer("hello");
|
||||
if let AgentState::Answering { content_chars, .. } = rt.state() {
|
||||
assert_eq!(*content_chars, 5);
|
||||
} else {
|
||||
panic!("expected Answering state");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_is_terminal() {
|
||||
let mut rt = AgentRuntime::new(10);
|
||||
rt.start_thinking();
|
||||
rt.fail("something broke".into());
|
||||
assert_eq!(rt.state().current_phase(), "error");
|
||||
assert!(rt.state().is_terminal());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transition_from_wrong_state() {
|
||||
let mut rt = AgentRuntime::new(10);
|
||||
// Can't start tool call from Idle
|
||||
assert!(rt.start_tool_call("tool".into()).is_err());
|
||||
// Can't complete tool call from Idle
|
||||
assert!(rt.complete_tool_call().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_depth_rejected() {
|
||||
let mut rt = AgentRuntime::new(2);
|
||||
rt.start_thinking();
|
||||
rt.start_tool_call("tool1".into()).unwrap();
|
||||
rt.complete_tool_call().unwrap();
|
||||
rt.start_tool_call("tool2".into()).unwrap();
|
||||
rt.complete_tool_call().unwrap();
|
||||
assert!(rt.start_tool_call("tool3".into()).is_err());
|
||||
}
|
||||
}
|
||||
@ -155,7 +155,7 @@ fn ai_metrics() -> &'static AiMetrics {
|
||||
|
||||
// ── Type conversions ─────────────────────────────────────────────────────────
|
||||
|
||||
fn to_rig_message(msg: &ChatRequestMessage) -> RigMessage {
|
||||
pub(crate) fn to_rig_message(msg: &ChatRequestMessage) -> RigMessage {
|
||||
match msg.role.as_str() {
|
||||
"system" => {
|
||||
// System messages are handled via preamble(), but we still
|
||||
|
||||
@ -58,13 +58,24 @@ impl EmbedClient {
|
||||
.await
|
||||
.map_err(|e| crate::AgentError::OpenAi(format!("embedding batch failed: {}", e)))?;
|
||||
|
||||
tracing::debug!(input_count = texts.len(), returned_count = embeddings.len(), "embed_batch: API returned");
|
||||
|
||||
let mut result = vec![Vec::new(); texts.len()];
|
||||
for embedding in embeddings {
|
||||
// Find the original index by matching the document text
|
||||
if let Some(idx) = texts.iter().position(|t| t == &embedding.document) {
|
||||
result[idx] = embedding.vec.iter().map(|v| *v as f32).collect();
|
||||
} else {
|
||||
tracing::warn!(doc = %embedding.document, "embed_batch: document mismatch — text not found in input list");
|
||||
}
|
||||
}
|
||||
|
||||
// Check for empty results
|
||||
let empty_count = result.iter().filter(|v| v.is_empty()).count();
|
||||
if empty_count > 0 {
|
||||
tracing::warn!(empty_count = empty_count, total = texts.len(), "embed_batch: some embeddings returned empty vectors");
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
|
||||
@ -89,6 +89,16 @@ impl QdrantClient {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Reject empty vectors — they cause Qdrant to reject the entire batch
|
||||
let empty_vectors = points.iter().filter(|p| p.vector.is_empty()).count();
|
||||
if empty_vectors > 0 {
|
||||
tracing::error!(empty_count = empty_vectors, total = points.len(), "upsert_points: REJECTING points with empty vectors");
|
||||
return Err(crate::AgentError::Qdrant(format!(
|
||||
"refusing to upsert {} points with empty vectors",
|
||||
empty_vectors
|
||||
)));
|
||||
}
|
||||
|
||||
let collection_name = Self::collection_name(&points[0].payload.entity_type);
|
||||
self.upsert_to_collection(&collection_name, points).await
|
||||
}
|
||||
|
||||
@ -118,7 +118,9 @@ impl EmbedService {
|
||||
_ => title.to_string(),
|
||||
};
|
||||
|
||||
tracing::debug!(issue_id = %id, text_len = text.len(), "embed_issue: calling embedding API");
|
||||
let vector = self.client.embed_text(&text, &self.model_name).await?;
|
||||
tracing::debug!(issue_id = %id, vec_dim = vector.len(), "embed_issue: embedding done");
|
||||
|
||||
let point = EmbedVector {
|
||||
id: id.to_string(),
|
||||
@ -131,7 +133,9 @@ impl EmbedService {
|
||||
},
|
||||
};
|
||||
|
||||
self.client.upsert(vec![point]).await
|
||||
self.client.upsert(vec![point]).await?;
|
||||
tracing::info!(issue_id = %id, "embed_issue: upsert complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn embed_repo(
|
||||
@ -145,7 +149,9 @@ impl EmbedService {
|
||||
_ => name.to_string(),
|
||||
};
|
||||
|
||||
tracing::debug!(repo_id = %id, text_len = text.len(), "embed_repo: calling embedding API");
|
||||
let vector = self.client.embed_text(&text, &self.model_name).await?;
|
||||
tracing::debug!(repo_id = %id, vec_dim = vector.len(), "embed_repo: embedding done");
|
||||
|
||||
let point = EmbedVector {
|
||||
id: id.to_string(),
|
||||
@ -158,7 +164,9 @@ impl EmbedService {
|
||||
},
|
||||
};
|
||||
|
||||
self.client.upsert(vec![point]).await
|
||||
self.client.upsert(vec![point]).await?;
|
||||
tracing::info!(repo_id = %id, "embed_repo: upsert complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn embed_issues<T: Embeddable + Send + Sync>(
|
||||
@ -170,7 +178,9 @@ impl EmbedService {
|
||||
}
|
||||
|
||||
let texts: Vec<String> = items.iter().map(|i| i.to_text()).collect();
|
||||
tracing::debug!(count = texts.len(), "embed_issues: calling embed_batch");
|
||||
let embeddings = self.client.embed_batch(&texts, &self.model_name).await?;
|
||||
tracing::debug!(count = embeddings.len(), "embed_issues: batch done");
|
||||
|
||||
let points: Vec<EmbedVector> = items
|
||||
.into_iter()
|
||||
@ -187,7 +197,10 @@ impl EmbedService {
|
||||
})
|
||||
.collect();
|
||||
|
||||
self.client.upsert(points).await
|
||||
let count = points.len();
|
||||
self.client.upsert(points).await?;
|
||||
tracing::info!(count = count, "embed_issues: upsert complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn search_issues(
|
||||
@ -264,15 +277,20 @@ impl EmbedService {
|
||||
let desc = description.unwrap_or_default();
|
||||
let id = skill_id.to_string();
|
||||
|
||||
tracing::debug!(skill_id = %skill_id, name = %name, content_len = content.len(), "embed_skill: starting");
|
||||
|
||||
// Auto-chunk long content
|
||||
let texts = chunk_text(content);
|
||||
tracing::debug!(skill_id = %skill_id, chunks = texts.len(), "embed_skill: chunked");
|
||||
|
||||
if texts.len() == 1 {
|
||||
self.client
|
||||
.embed_skill(&id, name, desc, content, project_uuid, &self.model_name)
|
||||
.await
|
||||
.await?;
|
||||
} else {
|
||||
// Multi-chunk: embed each chunk with chunk_index metadata
|
||||
let full_texts: Vec<String> = texts.iter().map(|t| format!("{}: {} {}", name, desc, t)).collect();
|
||||
tracing::debug!(skill_id = %skill_id, "embed_skill: calling embed_batch");
|
||||
let embeddings = self.client.embed_batch(&full_texts, &self.model_name).await?;
|
||||
|
||||
let points: Vec<EmbedVector> = embeddings.into_iter().enumerate().map(|(i, vector)| {
|
||||
@ -293,8 +311,10 @@ impl EmbedService {
|
||||
}
|
||||
}).collect();
|
||||
|
||||
self.client.upsert(points).await
|
||||
self.client.upsert(points).await?;
|
||||
}
|
||||
tracing::info!(skill_id = %skill_id, chunks = texts.len(), "embed_skill: complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Embed an issue with auto-chunking for long content.
|
||||
|
||||
@ -6,6 +6,7 @@ pub mod compact;
|
||||
pub mod embed;
|
||||
pub mod error;
|
||||
pub mod model;
|
||||
pub mod modes;
|
||||
pub mod perception;
|
||||
pub mod react;
|
||||
pub mod skills;
|
||||
@ -33,6 +34,10 @@ pub use embed::{
|
||||
EmbedClient, EmbedMemoryInput, EmbedService, QdrantClient, SearchResult, TagEmbedInput, new_embed_client,
|
||||
};
|
||||
pub use error::{AgentError, Result};
|
||||
pub use modes::cot::{CotStep, COT_SYSTEM_PROMPT};
|
||||
pub use modes::reflexion::{ReflexionCycle, ReflexionStep, REFLEXION_CRITIQUE_PROMPT, REFLEXION_REVISE_PROMPT, REFLEXION_SYSTEM_PROMPT};
|
||||
pub use modes::rewoo::{ReWooPlan, ReWooStep, ReWooToolCall, REWOO_SYSTEM_PROMPT, extract_plan};
|
||||
pub use modes::ModeStep;
|
||||
pub use react::{ReactConfig, ReactStep, DEFAULT_SYSTEM_PROMPT};
|
||||
pub use tool::{
|
||||
ToolCall, ToolCallRecord, ToolCallRecorder, ToolCallResult, ToolContext, ToolDefinition, ToolError, ToolExecutor, ToolHandler, ToolParam,
|
||||
|
||||
3
libs/agent/modes/cot/mod.rs
Normal file
3
libs/agent/modes/cot/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
pub mod types;
|
||||
|
||||
pub use types::{CotStep, COT_SYSTEM_PROMPT};
|
||||
58
libs/agent/modes/cot/types.rs
Normal file
58
libs/agent/modes/cot/types.rs
Normal file
@ -0,0 +1,58 @@
|
||||
use crate::modes::ModeStep;
|
||||
|
||||
pub const COT_SYSTEM_PROMPT: &str = r#"You are an AI assistant embedded in a development collaboration platform.
|
||||
|
||||
## Core Rule: Think Step by Step
|
||||
|
||||
When answering questions, you MUST reason step by step before providing a final answer. Break down the problem into clear intermediate steps. Show your reasoning process explicitly.
|
||||
|
||||
## Tool Use
|
||||
|
||||
- Use available tools to gather information when needed — verify claims against actual data.
|
||||
- After receiving tool results, incorporate them into your reasoning chain.
|
||||
- Do NOT guess when tools can provide concrete answers.
|
||||
|
||||
## Output Format
|
||||
|
||||
1. Reason through the problem step by step, using tools as needed.
|
||||
2. Then provide a clear, actionable final answer.
|
||||
3. Label your final answer with "**Answer:**" or similar.
|
||||
"#;
|
||||
|
||||
/// Events emitted during a CoT reasoning cycle (step-by-step with optional tools).
|
||||
pub enum CotStep {
|
||||
/// Intermediate reasoning thought.
|
||||
Thought(String),
|
||||
/// A tool call requested by the model.
|
||||
Action { name: String, args: serde_json::Value },
|
||||
/// Result returned by a tool execution.
|
||||
Observation(String),
|
||||
/// Final answer after reasoning.
|
||||
Answer(String),
|
||||
}
|
||||
|
||||
impl ModeStep for CotStep {
|
||||
fn chunk_type(&self) -> &'static str {
|
||||
match self {
|
||||
CotStep::Thought(_) => "thinking",
|
||||
CotStep::Action { .. } => "tool_call",
|
||||
CotStep::Observation(_) => "tool_result",
|
||||
CotStep::Answer(_) => "answer",
|
||||
}
|
||||
}
|
||||
|
||||
fn content(&self) -> String {
|
||||
match self {
|
||||
CotStep::Thought(t) => t.clone(),
|
||||
CotStep::Action { name, args } => {
|
||||
serde_json::json!({"name": name, "arguments": args}).to_string()
|
||||
}
|
||||
CotStep::Observation(o) => o.clone(),
|
||||
CotStep::Answer(a) => a.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_final(&self) -> bool {
|
||||
matches!(self, CotStep::Answer(_))
|
||||
}
|
||||
}
|
||||
13
libs/agent/modes/mod.rs
Normal file
13
libs/agent/modes/mod.rs
Normal file
@ -0,0 +1,13 @@
|
||||
pub mod cot;
|
||||
pub mod reflexion;
|
||||
pub mod rewoo;
|
||||
|
||||
pub use cot::{CotStep, COT_SYSTEM_PROMPT};
|
||||
pub use reflexion::{ReflexionCycle, ReflexionStep, REFLEXION_CRITIQUE_PROMPT, REFLEXION_REVISE_PROMPT, REFLEXION_SYSTEM_PROMPT};
|
||||
pub use rewoo::{ReWooPlan, ReWooStep, ReWooToolCall, REWOO_SYSTEM_PROMPT, extract_plan};
|
||||
|
||||
pub trait ModeStep: Send + 'static {
|
||||
fn chunk_type(&self) -> &'static str;
|
||||
fn content(&self) -> String;
|
||||
fn is_final(&self) -> bool;
|
||||
}
|
||||
3
libs/agent/modes/reflexion/mod.rs
Normal file
3
libs/agent/modes/reflexion/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
pub mod types;
|
||||
|
||||
pub use types::{ReflexionCycle, ReflexionStep, REFLEXION_CRITIQUE_PROMPT, REFLEXION_REVISE_PROMPT, REFLEXION_SYSTEM_PROMPT};
|
||||
70
libs/agent/modes/reflexion/types.rs
Normal file
70
libs/agent/modes/reflexion/types.rs
Normal file
@ -0,0 +1,70 @@
|
||||
use crate::modes::ModeStep;
|
||||
|
||||
pub const REFLEXION_SYSTEM_PROMPT: &str = r#"You are an AI assistant embedded in a development collaboration platform.
|
||||
|
||||
## Reflexion Protocol: Generate → Critique → Improve
|
||||
|
||||
Your responses follow a self-reflection loop to maximize answer quality.
|
||||
|
||||
### Round 1 — Generate
|
||||
Produce your best initial answer to the user's question. Use available tools if needed.
|
||||
|
||||
After your answer, the system will prompt you for a **self-critique**. When you see the critique prompt, evaluate your own answer honestly:
|
||||
- Did you verify all claims with available data?
|
||||
- Are there gaps or assumptions in your reasoning?
|
||||
- Could any part be clearer or more precise?
|
||||
|
||||
### Round 2 — Revise
|
||||
Based on your self-critique, produce a revised, improved answer. The system may repeat the critique-revise cycle up to 3 times.
|
||||
|
||||
### Guidelines
|
||||
- Be honest in self-assessment.
|
||||
- Each revision should be measurably better than the previous one.
|
||||
- If your first answer was already correct and complete, say so in your critique and confirm it in the revision.
|
||||
"#;
|
||||
|
||||
pub const REFLEXION_CRITIQUE_PROMPT: &str =
|
||||
"Now critique your own answer above. Identify any gaps, inaccuracies, \
|
||||
or areas for improvement. Be specific and honest.";
|
||||
|
||||
pub const REFLEXION_REVISE_PROMPT: &str =
|
||||
"Based on your self-critique, produce a revised and improved answer.";
|
||||
|
||||
pub enum ReflexionStep {
|
||||
Generate(String),
|
||||
Critique(String),
|
||||
Revise(String),
|
||||
Final(String),
|
||||
}
|
||||
|
||||
impl ModeStep for ReflexionStep {
|
||||
fn chunk_type(&self) -> &'static str {
|
||||
match self {
|
||||
ReflexionStep::Generate(_) => "thinking",
|
||||
ReflexionStep::Critique(_) => "tool_call",
|
||||
ReflexionStep::Revise(_) => "thinking",
|
||||
ReflexionStep::Final(_) => "answer",
|
||||
}
|
||||
}
|
||||
|
||||
fn content(&self) -> String {
|
||||
match self {
|
||||
ReflexionStep::Generate(s) => s.clone(),
|
||||
ReflexionStep::Critique(s) => s.clone(),
|
||||
ReflexionStep::Revise(s) => s.clone(),
|
||||
ReflexionStep::Final(s) => s.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_final(&self) -> bool {
|
||||
matches!(self, ReflexionStep::Final(_))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ReflexionCycle {
|
||||
pub round: usize,
|
||||
pub generated: String,
|
||||
pub critique: String,
|
||||
pub revised: String,
|
||||
}
|
||||
3
libs/agent/modes/rewoo/mod.rs
Normal file
3
libs/agent/modes/rewoo/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
pub mod types;
|
||||
|
||||
pub use types::{ReWooPlan, ReWooStep, ReWooToolCall, REWOO_SYSTEM_PROMPT, extract_plan};
|
||||
108
libs/agent/modes/rewoo/types.rs
Normal file
108
libs/agent/modes/rewoo/types.rs
Normal file
@ -0,0 +1,108 @@
|
||||
use crate::modes::ModeStep;
|
||||
|
||||
pub const REWOO_SYSTEM_PROMPT: &str = r#"You are an AI assistant embedded in a development collaboration platform.
|
||||
|
||||
## ReWOO Protocol: Reason Without Observation
|
||||
|
||||
Your responses MUST follow a strict **Plan → Execute → Synthesize** protocol.
|
||||
|
||||
### Phase 1 — Plan
|
||||
Analyze the user's question and produce a structured plan listing every tool call needed. Output your plan as a JSON array:
|
||||
|
||||
```json
|
||||
[
|
||||
{"step": 1, "tool": "tool_name", "args": {"param": "value"}},
|
||||
{"step": 2, "tool": "tool_name", "args": {"param": "value"}}
|
||||
]
|
||||
```
|
||||
|
||||
The plan must be wrapped in a `[PLAN]` ... `[/PLAN]` block. Only output the plan — no other text in this block.
|
||||
|
||||
### Phase 2 — Execute
|
||||
All tool calls in the plan will be executed automatically in parallel. You will receive the results as context.
|
||||
|
||||
### Phase 3 — Synthesize
|
||||
Based on the tool results, synthesize a comprehensive final answer. Cite specific data from the results.
|
||||
"#;
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct ReWooToolCall {
|
||||
pub step: usize,
|
||||
pub tool: String,
|
||||
pub args: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ReWooPlan {
|
||||
pub calls: Vec<ReWooToolCall>,
|
||||
pub raw_text: String,
|
||||
}
|
||||
|
||||
pub enum ReWooStep {
|
||||
Plan { calls: Vec<ReWooToolCall>, raw: String },
|
||||
Execution { tool_name: String, result: String },
|
||||
Synthesis(String),
|
||||
}
|
||||
|
||||
impl ModeStep for ReWooStep {
|
||||
fn chunk_type(&self) -> &'static str {
|
||||
match self {
|
||||
ReWooStep::Plan { .. } => "tool_call",
|
||||
ReWooStep::Execution { .. } => "tool_result",
|
||||
ReWooStep::Synthesis(_) => "answer",
|
||||
}
|
||||
}
|
||||
|
||||
fn content(&self) -> String {
|
||||
match self {
|
||||
ReWooStep::Plan { calls, raw } => {
|
||||
if !raw.is_empty() {
|
||||
raw.clone()
|
||||
} else {
|
||||
serde_json::to_string(calls).unwrap_or_default()
|
||||
}
|
||||
}
|
||||
ReWooStep::Execution { tool_name, result } => {
|
||||
format!("{}: {}", tool_name, result)
|
||||
}
|
||||
ReWooStep::Synthesis(s) => s.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_final(&self) -> bool {
|
||||
matches!(self, ReWooStep::Synthesis(_))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_plan(text: &str) -> Option<ReWooPlan> {
|
||||
if let (Some(start), Some(end)) = (text.find("[PLAN]"), text.find("[/PLAN]")) {
|
||||
let inner = &text[start + 6..end].trim();
|
||||
if inner.starts_with('[') || inner.starts_with('{') {
|
||||
if let Ok(calls) = serde_json::from_str::<Vec<ReWooToolCall>>(inner) {
|
||||
return Some(ReWooPlan { calls, raw_text: text.to_string() });
|
||||
}
|
||||
}
|
||||
return Some(ReWooPlan { calls: Vec::new(), raw_text: text.to_string() });
|
||||
}
|
||||
|
||||
if let Some(array_start) = text.find('[') {
|
||||
let candidate = &text[array_start..];
|
||||
if let Some(array_end) = find_matching_brace(candidate, '[', ']') {
|
||||
let inner = &candidate[..=array_end];
|
||||
if let Ok(calls) = serde_json::from_str::<Vec<ReWooToolCall>>(inner) {
|
||||
return Some(ReWooPlan { calls, raw_text: text.to_string() });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn find_matching_brace(s: &str, open: char, close: char) -> Option<usize> {
|
||||
let mut depth = 0i32;
|
||||
for (i, c) in s.char_indices() {
|
||||
if c == open { depth += 1; }
|
||||
else if c == close { depth -= 1; if depth == 0 { return Some(i); } }
|
||||
}
|
||||
None
|
||||
}
|
||||
@ -11,7 +11,7 @@ use service::AppService;
|
||||
use session::Session;
|
||||
|
||||
const MAX_TEXT_MESSAGE_LEN: usize = 64 * 1024;
|
||||
const MAX_MESSAGES_PER_SECOND: u32 = 10;
|
||||
const MAX_MESSAGES_PER_SECOND: u32 = 1000;
|
||||
|
||||
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
|
||||
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
|
||||
@ -19,7 +19,7 @@ use super::ws_handler::WsRequestHandler;
|
||||
use super::ws_types::{WsAction, WsRequest, WsResponse, WsResponseData};
|
||||
|
||||
const MAX_TEXT_MESSAGE_LEN: usize = 64 * 1024;
|
||||
const MAX_MESSAGES_PER_SECOND: u32 = 10;
|
||||
const MAX_MESSAGES_PER_SECOND: u32 = 1000;
|
||||
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
|
||||
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
const MAX_IDLE_TIMEOUT: Duration = Duration::from_secs(300);
|
||||
|
||||
@ -95,7 +95,7 @@ pub async fn list_boards_exec(
|
||||
.map(|card| {
|
||||
serde_json::json!({
|
||||
"id": card.id.to_string(),
|
||||
"issue_id": card.issue_id,
|
||||
"issue_id": card.issue_id.map(|id| id.to_string()),
|
||||
"title": card.title,
|
||||
"description": card.description,
|
||||
"position": card.position,
|
||||
@ -305,6 +305,10 @@ pub async fn create_board_card_exec(
|
||||
.get("assignee_id")
|
||||
.and_then(|v| Uuid::parse_str(v.as_str()?).ok());
|
||||
|
||||
let issue_id = args
|
||||
.get("issue_id")
|
||||
.and_then(|v| v.as_i64());
|
||||
|
||||
// Verify board belongs to project
|
||||
let board = ProjectBoard::find_by_id(board_id)
|
||||
.one(db)
|
||||
@ -356,7 +360,7 @@ pub async fn create_board_card_exec(
|
||||
let active = project_board_card::ActiveModel {
|
||||
id: Set(Uuid::now_v7()),
|
||||
column: Set(target_column.id),
|
||||
issue_id: Set(None),
|
||||
issue_id: Set(issue_id),
|
||||
project: Set(Some(project_id)),
|
||||
title: Set(title),
|
||||
description: Set(description),
|
||||
@ -381,6 +385,7 @@ pub async fn create_board_card_exec(
|
||||
"description": model.description,
|
||||
"position": model.position,
|
||||
"assignee_id": model.assignee_id.map(|id| id.to_string()),
|
||||
"issue_id": model.issue_id.map(|id| id.to_string()),
|
||||
"priority": model.priority,
|
||||
"created_at": model.created_at.to_rfc3339(),
|
||||
"updated_at": model.updated_at.to_rfc3339(),
|
||||
@ -468,6 +473,10 @@ pub async fn update_board_card_exec(
|
||||
active.assignee_id = Set(assignee_id.as_str().and_then(|s| Uuid::parse_str(s).ok()));
|
||||
updated = true;
|
||||
}
|
||||
if let Some(issue_id) = args.get("issue_id") {
|
||||
active.issue_id = Set(issue_id.as_i64());
|
||||
updated = true;
|
||||
}
|
||||
if let Some(priority) = args.get("priority") {
|
||||
active.priority = Set(priority.as_str().map(|s| s.to_string()));
|
||||
updated = true;
|
||||
@ -644,13 +653,18 @@ pub fn create_card_tool_definition() -> ToolDefinition {
|
||||
});
|
||||
p.insert("assignee_id".into(), ToolParam {
|
||||
name: "assignee_id".into(), param_type: "string".into(),
|
||||
description: Some("Assignee user UUID. Optional.".into()),
|
||||
description: Some("Card assignee user UUID. Optional.".into()),
|
||||
required: false, properties: None, items: None,
|
||||
});
|
||||
p.insert("issue_id".into(), ToolParam {
|
||||
name: "issue_id".into(), param_type: "integer".into(),
|
||||
description: Some("Link a project issue NUMBER to this card. Optional.".into()),
|
||||
required: false, properties: None, items: None,
|
||||
});
|
||||
ToolDefinition::new("project_create_board_card")
|
||||
.description(
|
||||
"Create a card on a Kanban board. If column_id is not provided, \
|
||||
the card is added to the first column.",
|
||||
the card is added to the first column. Optionally link to a project issue.",
|
||||
)
|
||||
.parameters(ToolSchema {
|
||||
schema_type: "object".into(),
|
||||
@ -696,6 +710,11 @@ pub fn update_card_tool_definition() -> ToolDefinition {
|
||||
description: Some("New assignee UUID. Optional.".into()),
|
||||
required: false, properties: None, items: None,
|
||||
});
|
||||
p.insert("issue_id".into(), ToolParam {
|
||||
name: "issue_id".into(), param_type: "integer".into(),
|
||||
description: Some("Link to a project issue number. Set to 0 to unlink. Optional.".into()),
|
||||
required: false, properties: None, items: None,
|
||||
});
|
||||
ToolDefinition::new("project_update_board_card")
|
||||
.description(
|
||||
"Update a board card (title, description, column, position, assignee, priority). \
|
||||
@ -723,3 +742,97 @@ pub fn delete_card_tool_definition() -> ToolDefinition {
|
||||
required: Some(vec!["card_id".into()]),
|
||||
})
|
||||
}
|
||||
|
||||
// ─── create board column ──────────────────────────────────────────────────────
|
||||
|
||||
pub async fn create_board_column_exec(
|
||||
ctx: ToolContext,
|
||||
args: serde_json::Value,
|
||||
) -> Result<serde_json::Value, ToolError> {
|
||||
let project_id = ctx.project_id();
|
||||
let sender_id = ctx.sender_id().ok_or_else(|| ToolError::ExecutionError("No sender".into()))?;
|
||||
let db = ctx.db();
|
||||
|
||||
require_admin(db, project_id, sender_id).await?;
|
||||
|
||||
let board_id = args.get("board_id")
|
||||
.and_then(|v| Uuid::parse_str(v.as_str()?).ok())
|
||||
.ok_or_else(|| ToolError::ExecutionError("board_id is required".into()))?;
|
||||
|
||||
let name = args.get("name").and_then(|v| v.as_str())
|
||||
.ok_or_else(|| ToolError::ExecutionError("name is required".into()))?
|
||||
.to_string();
|
||||
|
||||
let color = args.get("color").and_then(|v| v.as_str()).map(|s| s.to_string());
|
||||
|
||||
let board = ProjectBoard::find_by_id(board_id)
|
||||
.one(db)
|
||||
.await
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))?
|
||||
.ok_or_else(|| ToolError::ExecutionError("Board not found".into()))?;
|
||||
if board.project != project_id {
|
||||
return Err(ToolError::ExecutionError("Board does not belong to this project".into()));
|
||||
}
|
||||
|
||||
let max_pos: Option<Option<i32>> = ProjectBoardColumn::find()
|
||||
.filter(project_board_column::Column::Board.eq(board_id))
|
||||
.select_only()
|
||||
.column_as(project_board_column::Column::Position.max(), "max_pos")
|
||||
.into_tuple::<Option<i32>>()
|
||||
.one(db)
|
||||
.await
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
|
||||
let position = max_pos.flatten().unwrap_or(0) + 1;
|
||||
|
||||
let _now = Utc::now();
|
||||
let active = project_board_column::ActiveModel {
|
||||
id: Set(Uuid::now_v7()),
|
||||
board: Set(board_id),
|
||||
name: Set(name.clone()),
|
||||
position: Set(position),
|
||||
wip_limit: Set(None),
|
||||
color: Set(color.clone()),
|
||||
};
|
||||
|
||||
let model = active.insert(db)
|
||||
.await
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"id": model.id.to_string(),
|
||||
"board_id": model.board.to_string(),
|
||||
"name": model.name,
|
||||
"position": model.position,
|
||||
"wip_limit": model.wip_limit,
|
||||
"color": model.color,
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn create_column_tool_definition() -> ToolDefinition {
|
||||
let mut p = HashMap::new();
|
||||
p.insert("board_id".into(), ToolParam {
|
||||
name: "board_id".into(), param_type: "string".into(),
|
||||
description: Some("Board UUID (required).".into()),
|
||||
required: true, properties: None, items: None,
|
||||
});
|
||||
p.insert("name".into(), ToolParam {
|
||||
name: "name".into(), param_type: "string".into(),
|
||||
description: Some("Column name (required).".into()),
|
||||
required: true, properties: None, items: None,
|
||||
});
|
||||
p.insert("color".into(), ToolParam {
|
||||
name: "color".into(), param_type: "string".into(),
|
||||
description: Some("Column color (e.g. '#ff0000'). Optional.".into()),
|
||||
required: false, properties: None, items: None,
|
||||
});
|
||||
ToolDefinition::new("project_create_board_column")
|
||||
.description(
|
||||
"Create a new column on a Kanban board. \
|
||||
The column is appended at the end. Requires admin or owner role.",
|
||||
)
|
||||
.parameters(ToolSchema {
|
||||
schema_type: "object".into(),
|
||||
properties: Some(p),
|
||||
required: Some(vec!["board_id".into(), "name".into()]),
|
||||
})
|
||||
}
|
||||
|
||||
@ -555,3 +555,264 @@ pub fn update_tool_definition() -> ToolDefinition {
|
||||
required: Some(vec!["number".into()]),
|
||||
})
|
||||
}
|
||||
|
||||
// ─── assign ────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Assign or unassign users to/from an issue.
|
||||
pub async fn assign_issue_exec(
|
||||
ctx: ToolContext,
|
||||
args: serde_json::Value,
|
||||
) -> Result<serde_json::Value, ToolError> {
|
||||
let project_id = ctx.project_id();
|
||||
let sender_id = ctx.sender_id().ok_or_else(|| ToolError::ExecutionError("No sender".into()))?;
|
||||
let db = ctx.db();
|
||||
|
||||
let number = args.get("number").and_then(|v| v.as_i64())
|
||||
.ok_or_else(|| ToolError::ExecutionError("number is required".into()))?;
|
||||
|
||||
let issue = Issue::find()
|
||||
.filter(issue::Column::Project.eq(project_id))
|
||||
.filter(issue::Column::Number.eq(number))
|
||||
.one(db)
|
||||
.await
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))?
|
||||
.ok_or_else(|| ToolError::ExecutionError(format!("Issue #{} not found", number)))?;
|
||||
|
||||
require_issue_modifier(db, project_id, sender_id, issue.author).await?;
|
||||
|
||||
let add_ids: Vec<Uuid> = args.get("add_user_ids")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|a| a.iter().filter_map(|v| Uuid::parse_str(v.as_str()?).ok()).collect())
|
||||
.unwrap_or_default();
|
||||
|
||||
let remove_ids: Vec<Uuid> = args.get("remove_user_ids")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|a| a.iter().filter_map(|v| Uuid::parse_str(v.as_str()?).ok()).collect())
|
||||
.unwrap_or_default();
|
||||
|
||||
let now = Utc::now();
|
||||
|
||||
for uid in &add_ids {
|
||||
let exists = IssueAssignee::find()
|
||||
.filter(issue_assignee::Column::Issue.eq(issue.id))
|
||||
.filter(issue_assignee::Column::User.eq(*uid))
|
||||
.one(db)
|
||||
.await
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
|
||||
if exists.is_some() {
|
||||
continue;
|
||||
}
|
||||
let am = issue_assignee::ActiveModel {
|
||||
issue: Set(issue.id),
|
||||
user: Set(*uid),
|
||||
assigned_at: Set(now),
|
||||
..Default::default()
|
||||
};
|
||||
am.insert(db).await.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
|
||||
}
|
||||
|
||||
for uid in &remove_ids {
|
||||
IssueAssignee::delete_many()
|
||||
.filter(issue_assignee::Column::Issue.eq(issue.id))
|
||||
.filter(issue_assignee::Column::User.eq(*uid))
|
||||
.exec(db)
|
||||
.await
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
|
||||
}
|
||||
|
||||
// Build response
|
||||
let current_assignee_ids: Vec<Uuid> = IssueAssignee::find()
|
||||
.filter(issue_assignee::Column::Issue.eq(issue.id))
|
||||
.all(db)
|
||||
.await
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))?
|
||||
.into_iter()
|
||||
.map(|a| a.user)
|
||||
.collect();
|
||||
|
||||
let users = if current_assignee_ids.is_empty() {
|
||||
Vec::new()
|
||||
} else {
|
||||
User::find()
|
||||
.filter(models::users::user::Column::Uid.is_in(current_assignee_ids.clone()))
|
||||
.all(db)
|
||||
.await
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))?
|
||||
};
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"issue_id": issue.id.to_string(),
|
||||
"issue_number": issue.number,
|
||||
"assignees": users.into_iter().map(|u| serde_json::json!({
|
||||
"id": u.uid.to_string(),
|
||||
"username": u.username,
|
||||
"display_name": u.display_name,
|
||||
})).collect::<Vec<_>>(),
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn assign_tool_definition() -> ToolDefinition {
|
||||
let mut p = HashMap::new();
|
||||
p.insert("number".into(), ToolParam {
|
||||
name: "number".into(), param_type: "integer".into(),
|
||||
description: Some("Issue number (required).".into()),
|
||||
required: true, properties: None, items: None,
|
||||
});
|
||||
p.insert("add_user_ids".into(), ToolParam {
|
||||
name: "add_user_ids".into(), param_type: "array".into(),
|
||||
description: Some("Array of user UUIDs to add as assignees. Optional.".into()),
|
||||
required: false, properties: None,
|
||||
items: Some(Box::new(ToolParam {
|
||||
name: "".into(), param_type: "string".into(),
|
||||
description: Some("User UUID".into()), required: false, properties: None, items: None,
|
||||
})),
|
||||
});
|
||||
p.insert("remove_user_ids".into(), ToolParam {
|
||||
name: "remove_user_ids".into(), param_type: "array".into(),
|
||||
description: Some("Array of user UUIDs to remove from assignees. Optional.".into()),
|
||||
required: false, properties: None,
|
||||
items: Some(Box::new(ToolParam {
|
||||
name: "".into(), param_type: "string".into(),
|
||||
description: Some("User UUID".into()), required: false, properties: None, items: None,
|
||||
})),
|
||||
});
|
||||
ToolDefinition::new("project_assign_issue")
|
||||
.description(
|
||||
"Add or remove assignees on an issue by its number. \
|
||||
Requires the issue author or a project admin/owner. \
|
||||
Returns the updated list of assignees.",
|
||||
)
|
||||
.parameters(ToolSchema {
|
||||
schema_type: "object".into(),
|
||||
properties: Some(p),
|
||||
required: Some(vec!["number".into()]),
|
||||
})
|
||||
}
|
||||
|
||||
// ─── add comment ───────────────────────────────────────────────────────────────
|
||||
|
||||
pub async fn add_comment_exec(
|
||||
ctx: ToolContext,
|
||||
args: serde_json::Value,
|
||||
) -> Result<serde_json::Value, ToolError> {
|
||||
let project_id = ctx.project_id();
|
||||
let sender_id = ctx.sender_id().ok_or_else(|| ToolError::ExecutionError("No sender".into()))?;
|
||||
let db = ctx.db();
|
||||
|
||||
let number = args.get("number").and_then(|v| v.as_i64())
|
||||
.ok_or_else(|| ToolError::ExecutionError("number is required".into()))?;
|
||||
|
||||
let body = args.get("body").and_then(|v| v.as_str())
|
||||
.ok_or_else(|| ToolError::ExecutionError("body is required".into()))?
|
||||
.to_string();
|
||||
|
||||
let issue = Issue::find()
|
||||
.filter(issue::Column::Project.eq(project_id))
|
||||
.filter(issue::Column::Number.eq(number))
|
||||
.one(db)
|
||||
.await
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))?
|
||||
.ok_or_else(|| ToolError::ExecutionError(format!("Issue #{} not found", number)))?;
|
||||
|
||||
// Only project members can comment
|
||||
let member = ProjectMember::find()
|
||||
.filter(project_members::Column::Project.eq(project_id))
|
||||
.filter(project_members::Column::User.eq(sender_id))
|
||||
.one(db)
|
||||
.await
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
|
||||
if member.is_none() {
|
||||
return Err(ToolError::ExecutionError("You are not a member of this project".into()));
|
||||
}
|
||||
|
||||
let now = Utc::now();
|
||||
let comment = models::issues::issue_comment::ActiveModel {
|
||||
id: sea_orm::NotSet,
|
||||
issue: Set(issue.id),
|
||||
author: Set(sender_id),
|
||||
body: Set(body.clone()),
|
||||
created_at: Set(now),
|
||||
updated_at: Set(now),
|
||||
};
|
||||
let model = comment.insert(db).await
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
|
||||
|
||||
// Update issue updated_at
|
||||
let mut i_active: issue::ActiveModel = issue.into();
|
||||
i_active.updated_at = Set(now);
|
||||
i_active.update(db).await.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
|
||||
|
||||
// Look up author name
|
||||
let author_name = User::find_by_id(sender_id).one(db).await
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))?
|
||||
.map(|u| u.display_name.unwrap_or(u.username));
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"comment_id": model.id.to_string(),
|
||||
"issue_number": number,
|
||||
"body": body,
|
||||
"author_id": sender_id.to_string(),
|
||||
"author_name": author_name,
|
||||
"created_at": now.to_rfc3339(),
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn add_comment_tool_definition() -> ToolDefinition {
|
||||
let mut p = HashMap::new();
|
||||
p.insert("number".into(), ToolParam {
|
||||
name: "number".into(), param_type: "integer".into(),
|
||||
description: Some("Issue number (required).".into()),
|
||||
required: true, properties: None, items: None,
|
||||
});
|
||||
p.insert("body".into(), ToolParam {
|
||||
name: "body".into(), param_type: "string".into(),
|
||||
description: Some("Comment body text (required).".into()),
|
||||
required: true, properties: None, items: None,
|
||||
});
|
||||
ToolDefinition::new("project_add_comment")
|
||||
.description(
|
||||
"Add a comment to an issue in the current project by its number. \
|
||||
Requires project membership. Returns the created comment.",
|
||||
)
|
||||
.parameters(ToolSchema {
|
||||
schema_type: "object".into(),
|
||||
properties: Some(p),
|
||||
required: Some(vec!["number".into(), "body".into()]),
|
||||
})
|
||||
}
|
||||
|
||||
// ─── list labels ───────────────────────────────────────────────────────────────
|
||||
|
||||
pub async fn list_labels_exec(
|
||||
ctx: ToolContext,
|
||||
_args: serde_json::Value,
|
||||
) -> Result<serde_json::Value, ToolError> {
|
||||
let project_id = ctx.project_id();
|
||||
let db = ctx.db();
|
||||
|
||||
// Get labels associated with this project via issue_labels
|
||||
let labels = Label::find()
|
||||
.filter(label::Column::Project.eq(project_id))
|
||||
.all(db)
|
||||
.await
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
|
||||
|
||||
let result: Vec<serde_json::Value> = labels.into_iter().map(|l| {
|
||||
serde_json::json!({
|
||||
"id": l.id,
|
||||
"name": l.name,
|
||||
"color": l.color,
|
||||
})
|
||||
}).collect();
|
||||
|
||||
Ok(serde_json::to_value(result).map_err(|e| ToolError::ExecutionError(e.to_string()))?)
|
||||
}
|
||||
|
||||
pub fn list_labels_tool_definition() -> ToolDefinition {
|
||||
ToolDefinition::new("project_list_labels")
|
||||
.description(
|
||||
"List all labels available in the current project. \
|
||||
Returns label id, name, color, and description. \
|
||||
Use label IDs when creating or updating issues.",
|
||||
)
|
||||
}
|
||||
|
||||
@ -17,11 +17,15 @@ use agent::{ToolHandler, ToolRegistry};
|
||||
|
||||
pub use arxiv::arxiv_search_exec;
|
||||
pub use boards::{
|
||||
create_board_card_exec, create_board_exec, delete_board_card_exec, list_boards_exec,
|
||||
create_board_card_exec, create_board_exec, create_board_column_exec,
|
||||
delete_board_card_exec, list_boards_exec,
|
||||
update_board_card_exec, update_board_exec,
|
||||
};
|
||||
pub use curl::curl_exec;
|
||||
pub use issues::{create_issue_exec, list_issues_exec, update_issue_exec};
|
||||
pub use issues::{
|
||||
add_comment_exec, assign_issue_exec, create_issue_exec, list_issues_exec,
|
||||
list_labels_exec, update_issue_exec,
|
||||
};
|
||||
pub use members::list_members_exec;
|
||||
pub use repos::{create_commit_exec, create_repo_exec, list_repos_exec, update_repo_exec};
|
||||
|
||||
@ -75,6 +79,18 @@ pub fn register_all(registry: &mut ToolRegistry) {
|
||||
issues::update_tool_definition(),
|
||||
ToolHandler::new(|ctx, args| Box::pin(update_issue_exec(ctx, args))),
|
||||
);
|
||||
registry.register(
|
||||
issues::assign_tool_definition(),
|
||||
ToolHandler::new(|ctx, args| Box::pin(assign_issue_exec(ctx, args))),
|
||||
);
|
||||
registry.register(
|
||||
issues::add_comment_tool_definition(),
|
||||
ToolHandler::new(|ctx, args| Box::pin(add_comment_exec(ctx, args))),
|
||||
);
|
||||
registry.register(
|
||||
issues::list_labels_tool_definition(),
|
||||
ToolHandler::new(|ctx, args| Box::pin(list_labels_exec(ctx, args))),
|
||||
);
|
||||
|
||||
// boards
|
||||
registry.register(
|
||||
@ -101,4 +117,8 @@ pub fn register_all(registry: &mut ToolRegistry) {
|
||||
boards::delete_card_tool_definition(),
|
||||
ToolHandler::new(|ctx, args| Box::pin(delete_board_card_exec(ctx, args))),
|
||||
);
|
||||
registry.register(
|
||||
boards::create_column_tool_definition(),
|
||||
ToolHandler::new(|ctx, args| Box::pin(create_board_column_exec(ctx, args))),
|
||||
);
|
||||
}
|
||||
|
||||
@ -11,8 +11,12 @@ use models::EntityTrait;
|
||||
use sea_orm::{ColumnTrait, QueryFilter};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
/// Git zero OID for new branch/tag creation webhook events.
|
||||
const ZERO_OID: &str = "0000000000000000000000000000000000000000";
|
||||
|
||||
/// Single-threaded worker that sequentially consumes tasks from Redis queues.
|
||||
/// K8s can scale replicas for concurrency — each replica runs one worker.
|
||||
/// Per-repo Redis locking is managed inside HookMetaDataSync methods.
|
||||
@ -168,29 +172,33 @@ impl HookWorker {
|
||||
)));
|
||||
}
|
||||
|
||||
// Capture before tips for webhook diff
|
||||
// Build sync once and reuse for before_tips + sync + after_tips
|
||||
// (avoids opening git2::Repository 3 times)
|
||||
let db_for_sync = self.db.clone();
|
||||
let cache_for_sync = self.cache.clone();
|
||||
let repo_for_sync = repo.clone();
|
||||
let sync = tokio::task::spawn_blocking(move || {
|
||||
HookMetaDataSync::new(db_for_sync, cache_for_sync, repo_for_sync)
|
||||
.map_err(|e| GitError::Internal(e.to_string()))
|
||||
})
|
||||
.await
|
||||
.map_err(|e| GitError::Internal(format!("spawn_blocking join error: {}", e)))?
|
||||
.map_err(GitError::from)?;
|
||||
|
||||
// Capture before tips for webhook diff (read-only, no lock needed)
|
||||
let before_tips = tokio::task::spawn_blocking({
|
||||
let db = self.db.clone();
|
||||
let cache = self.cache.clone();
|
||||
let repo = repo.clone();
|
||||
move || {
|
||||
let sync = HookMetaDataSync::new(db, cache, repo)
|
||||
.map_err(|e| GitError::Internal(e.to_string()))?;
|
||||
Ok::<_, GitError>((sync.list_branch_tips(), sync.list_tag_tips()))
|
||||
}
|
||||
let sync = sync.clone();
|
||||
move || Ok::<_, GitError>((sync.list_branch_tips(), sync.list_tag_tips()))
|
||||
})
|
||||
.await
|
||||
.map_err(|e| GitError::Internal(format!("spawn_blocking join error: {}", e)))?
|
||||
.map_err(GitError::from)?;
|
||||
|
||||
// Run full sync (internally acquires/releases per-repo lock)
|
||||
let db = self.db.clone();
|
||||
let cache = self.cache.clone();
|
||||
let repo_clone = repo.clone();
|
||||
let _sync_result = tokio::task::spawn_blocking(move || {
|
||||
let sync_clone = sync.clone();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let result = tokio::runtime::Handle::current().block_on(async {
|
||||
let sync = HookMetaDataSync::new(db.clone(), cache.clone(), repo_clone.clone())?;
|
||||
sync.sync().await
|
||||
sync_clone.sync().await
|
||||
});
|
||||
match result {
|
||||
Ok(()) => Ok::<(), GitError>(()),
|
||||
@ -201,18 +209,10 @@ impl HookWorker {
|
||||
.map_err(|e| GitError::Internal(format!("spawn_blocking join error: {}", e)))
|
||||
.and_then(|r| r.map_err(GitError::from))?;
|
||||
|
||||
// Only dispatch webhooks if sync succeeded
|
||||
|
||||
// Capture after tips and dispatch webhooks
|
||||
// Capture after tips for webhook diff (read-only, no lock needed)
|
||||
let after_tips = tokio::task::spawn_blocking({
|
||||
let db = self.db.clone();
|
||||
let cache = self.cache.clone();
|
||||
let repo = repo.clone();
|
||||
move || {
|
||||
let sync = HookMetaDataSync::new(db, cache, repo)
|
||||
.map_err(|e| GitError::Internal(e.to_string()))?;
|
||||
Ok::<_, GitError>((sync.list_branch_tips(), sync.list_tag_tips()))
|
||||
}
|
||||
let sync = sync.clone();
|
||||
move || Ok::<_, GitError>((sync.list_branch_tips(), sync.list_tag_tips()))
|
||||
})
|
||||
.await
|
||||
.map_err(|e| GitError::Internal(format!("spawn_blocking join error: {}", e)))?
|
||||
@ -222,14 +222,18 @@ impl HookWorker {
|
||||
let (after_branch_tips, after_tag_tips) = after_tips;
|
||||
let project = repo.project;
|
||||
|
||||
// Resolve namespace once outside the loop
|
||||
// Resolve namespace for webhook URL construction
|
||||
let namespace = models::projects::Project::find_by_id(project)
|
||||
.one(self.db.reader())
|
||||
.await
|
||||
.inspect_err(|e| tracing::warn!(error = %e, project = %project, "hook sync: failed to resolve project namespace"))
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(|p| p.name)
|
||||
.unwrap_or_default();
|
||||
.unwrap_or_else(|| {
|
||||
tracing::warn!(project = %project, "hook sync: project not found, empty namespace");
|
||||
String::new()
|
||||
});
|
||||
|
||||
let repo_id_str = repo.id.to_string();
|
||||
let repo_name = repo.repo_name.clone();
|
||||
@ -248,7 +252,7 @@ impl HookWorker {
|
||||
let changed = before_oid.map(|o| o != after_oid.as_str()).unwrap_or(true);
|
||||
if changed {
|
||||
branch_changes += 1;
|
||||
let before_oid = before_oid.map_or("0", |v| v).to_string();
|
||||
let before_oid = before_oid.map_or(ZERO_OID, |v| v).to_string();
|
||||
let branch_name = branch.clone();
|
||||
let h = tokio::spawn({
|
||||
let http_client = http_client.clone();
|
||||
@ -294,7 +298,7 @@ impl HookWorker {
|
||||
if is_new || was_updated {
|
||||
tag_changes += 1;
|
||||
changed_tag_names.push(tag.clone());
|
||||
let before_oid = before_oid.map_or("0", |v| v).to_string();
|
||||
let before_oid = before_oid.map_or(ZERO_OID, |v| v).to_string();
|
||||
let tag_name = tag.clone();
|
||||
let h = tokio::spawn({
|
||||
let http_client = http_client.clone();
|
||||
|
||||
@ -137,8 +137,36 @@ impl HookMetaDataSync {
|
||||
|
||||
let (branches, _) = self.collect_git_refs()?;
|
||||
|
||||
// Auto-detect first local branch when default_branch is empty
|
||||
// Preferred default branch names, in priority order.
|
||||
// git2::References iteration order is filesystem-dependent (not chronological),
|
||||
// so we MUST NOT use "first branch wins".
|
||||
const PREFERRED_BRANCHES: &[&str] = &["main", "master", "trunk"];
|
||||
|
||||
// Auto-detect default branch when empty.
|
||||
// Re-read from DB inside the transaction to avoid stale reads from concurrent workers.
|
||||
let mut auto_detected_branch: Option<String> = None;
|
||||
let current_default: Option<String> = models::repos::repo::Entity::find_by_id(self.repo.id)
|
||||
.one(txn)
|
||||
.await
|
||||
.map_err(|e| GitError::IoError(format!("failed to re-read repo: {}", e)))?
|
||||
.map(|r| r.default_branch)
|
||||
.filter(|b| !b.is_empty());
|
||||
|
||||
if current_default.is_none() {
|
||||
// Prefer known branch names over first-come
|
||||
for preferred in PREFERRED_BRANCHES {
|
||||
if branches.iter().any(|b| b.shorthand == *preferred && b.is_branch && !b.is_remote) {
|
||||
auto_detected_branch = Some(ToString::to_string(preferred));
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Fallback: first local branch
|
||||
if auto_detected_branch.is_none() {
|
||||
if let Some(first) = branches.iter().find(|b| b.is_branch && !b.is_remote) {
|
||||
auto_detected_branch = Some(first.shorthand.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for branch in &branches {
|
||||
if existing_names.contains(&branch.name) {
|
||||
@ -154,13 +182,7 @@ impl HookMetaDataSync {
|
||||
models::repos::repo_branch::Column::Upstream,
|
||||
sea_orm::prelude::Expr::value(branch.upstream.clone()),
|
||||
)
|
||||
.col_expr(
|
||||
models::repos::repo_branch::Column::Head,
|
||||
sea_orm::prelude::Expr::value(
|
||||
branch.is_branch
|
||||
&& branch.shorthand == self.repo.default_branch,
|
||||
),
|
||||
)
|
||||
// head is NOT set here — set below in a single pass to avoid N+1 writes
|
||||
.col_expr(
|
||||
models::repos::repo_branch::Column::UpdatedAt,
|
||||
sea_orm::prelude::Expr::value(now),
|
||||
@ -174,7 +196,8 @@ impl HookMetaDataSync {
|
||||
name: Set(branch.name.clone()),
|
||||
oid: Set(branch.target_oid.clone()),
|
||||
upstream: Set(branch.upstream.clone()),
|
||||
head: Set(branch.is_branch && branch.shorthand == self.repo.default_branch),
|
||||
// head defaults to false — will be set below if this is the default branch
|
||||
head: Set(false),
|
||||
created_at: Set(now),
|
||||
updated_at: Set(now),
|
||||
..Default::default()
|
||||
@ -184,15 +207,6 @@ impl HookMetaDataSync {
|
||||
.await
|
||||
.map_err(|e| GitError::IoError(format!("failed to insert branch: {}", e)))?;
|
||||
}
|
||||
|
||||
// Detect first local branch if no default is set
|
||||
if self.repo.default_branch.is_empty()
|
||||
&& branch.is_branch
|
||||
&& !branch.is_remote
|
||||
&& auto_detected_branch.is_none()
|
||||
{
|
||||
auto_detected_branch = Some(branch.shorthand.clone());
|
||||
}
|
||||
}
|
||||
|
||||
if !existing_names.is_empty() {
|
||||
@ -206,38 +220,52 @@ impl HookMetaDataSync {
|
||||
})?;
|
||||
}
|
||||
|
||||
// Persist auto-detected default branch and update head flags
|
||||
// Persist auto-detected default branch and update head flags.
|
||||
// Only writes if default_branch is still empty (prevents concurrent overrides).
|
||||
if let Some(ref branch_name) = auto_detected_branch {
|
||||
models::repos::repo::Entity::update_many()
|
||||
let updated = models::repos::repo::Entity::update_many()
|
||||
.filter(models::repos::repo::Column::Id.eq(repo_id))
|
||||
.filter(models::repos::repo::Column::DefaultBranch.eq(""))
|
||||
.col_expr(
|
||||
models::repos::repo::Column::DefaultBranch,
|
||||
sea_orm::prelude::Expr::value(branch_name.clone()),
|
||||
)
|
||||
.col_expr(
|
||||
models::repos::repo::Column::UpdatedAt,
|
||||
sea_orm::prelude::Expr::value(now),
|
||||
)
|
||||
.exec(txn)
|
||||
.await
|
||||
.map_err(|e| GitError::IoError(format!("failed to set default branch: {}", e)))?;
|
||||
|
||||
models::repos::repo_branch::Entity::update_many()
|
||||
.filter(models::repos::repo_branch::Column::Repo.eq(repo_id))
|
||||
.col_expr(
|
||||
models::repos::repo_branch::Column::Head,
|
||||
sea_orm::prelude::Expr::value(false),
|
||||
)
|
||||
.exec(txn)
|
||||
.await
|
||||
.map_err(|e| GitError::IoError(format!("failed to clear head flags: {}", e)))?;
|
||||
if updated.rows_affected > 0 {
|
||||
models::repos::repo_branch::Entity::update_many()
|
||||
.filter(models::repos::repo_branch::Column::Repo.eq(repo_id))
|
||||
.col_expr(
|
||||
models::repos::repo_branch::Column::Head,
|
||||
sea_orm::prelude::Expr::value(false),
|
||||
)
|
||||
.exec(txn)
|
||||
.await
|
||||
.map_err(|e| GitError::IoError(format!("failed to clear head flags: {}", e)))?;
|
||||
|
||||
models::repos::repo_branch::Entity::update_many()
|
||||
.filter(models::repos::repo_branch::Column::Repo.eq(repo_id))
|
||||
.filter(models::repos::repo_branch::Column::Name.eq(branch_name))
|
||||
.col_expr(
|
||||
models::repos::repo_branch::Column::Head,
|
||||
sea_orm::prelude::Expr::value(true),
|
||||
)
|
||||
.exec(txn)
|
||||
.await
|
||||
.map_err(|e| GitError::IoError(format!("failed to set head flag: {}", e)))?;
|
||||
models::repos::repo_branch::Entity::update_many()
|
||||
.filter(models::repos::repo_branch::Column::Repo.eq(repo_id))
|
||||
.filter(models::repos::repo_branch::Column::Name.eq(branch_name))
|
||||
.col_expr(
|
||||
models::repos::repo_branch::Column::Head,
|
||||
sea_orm::prelude::Expr::value(true),
|
||||
)
|
||||
.exec(txn)
|
||||
.await
|
||||
.map_err(|e| GitError::IoError(format!("failed to set head flag: {}", e)))?;
|
||||
} else {
|
||||
tracing::debug!(
|
||||
repo_id = %repo_id,
|
||||
attempted = %branch_name,
|
||||
"default_branch already set by another worker, skipping"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
//! HTTP rate limiting for git operations.
|
||||
//!
|
||||
//! Uses a token-bucket approach with per-IP and per-repo-write limits.
|
||||
//! Uses a token-bucket approach with per-repo-write limits.
|
||||
//! In K8s environments all traffic routes through the ingress so
|
||||
//! per-IP limiting is meaningless — a fixed global key is used instead.
|
||||
//! Cleanup runs every 5 minutes to prevent unbounded memory growth.
|
||||
|
||||
use std::collections::HashMap;
|
||||
@ -55,20 +57,18 @@ impl RateLimiter {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn is_ip_read_allowed(&self, ip: &str) -> bool {
|
||||
let key = format!("ip:read:{}", ip);
|
||||
self.is_allowed(&key, BucketOp::Read, self.config.read_requests_per_window)
|
||||
pub async fn is_read_allowed(&self) -> bool {
|
||||
self.is_allowed("global:read", BucketOp::Read, self.config.read_requests_per_window)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn is_ip_write_allowed(&self, ip: &str) -> bool {
|
||||
let key = format!("ip:write:{}", ip);
|
||||
self.is_allowed(&key, BucketOp::Write, self.config.write_requests_per_window)
|
||||
pub async fn is_write_allowed(&self) -> bool {
|
||||
self.is_allowed("global:write", BucketOp::Write, self.config.write_requests_per_window)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn is_repo_write_allowed(&self, ip: &str, repo_path: &str) -> bool {
|
||||
let key = format!("repo:write:{}:{}", ip, repo_path);
|
||||
pub async fn is_repo_write_allowed(&self, repo_path: &str) -> bool {
|
||||
let key = format!("repo:write:{}", repo_path);
|
||||
self.is_allowed(&key, BucketOp::Write, self.config.write_requests_per_window)
|
||||
.await
|
||||
}
|
||||
@ -107,8 +107,8 @@ impl RateLimiter {
|
||||
true
|
||||
}
|
||||
|
||||
pub async fn retry_after(&self, ip: &str) -> u64 {
|
||||
let key_read = format!("ip:read:{}", ip);
|
||||
pub async fn retry_after(&self) -> u64 {
|
||||
let key_read = "global:read".to_string();
|
||||
let now = Instant::now();
|
||||
let buckets = self.buckets.read().await;
|
||||
|
||||
@ -148,8 +148,8 @@ mod tests {
|
||||
}));
|
||||
|
||||
for _ in 0..3 {
|
||||
assert!(limiter.is_ip_read_allowed("1.2.3.4").await);
|
||||
assert!(limiter.is_read_allowed().await);
|
||||
}
|
||||
assert!(!limiter.is_ip_read_allowed("1.2.3.4").await);
|
||||
assert!(!limiter.is_read_allowed().await);
|
||||
}
|
||||
}
|
||||
|
||||
@ -13,8 +13,7 @@ pub async fn info_refs(
|
||||
path: web::Path<(String, String)>,
|
||||
state: web::Data<HttpAppState>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let ip = extract_ip(&req);
|
||||
if !state.rate_limiter.is_ip_read_allowed(&ip).await {
|
||||
if !state.rate_limiter.is_read_allowed().await {
|
||||
return Err(actix_web::error::ErrorTooManyRequests(
|
||||
"Rate limit exceeded",
|
||||
));
|
||||
@ -47,8 +46,7 @@ pub async fn upload_pack(
|
||||
payload: web::Payload,
|
||||
state: web::Data<HttpAppState>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let ip = extract_ip(&req);
|
||||
if !state.rate_limiter.is_ip_read_allowed(&ip).await {
|
||||
if !state.rate_limiter.is_read_allowed().await {
|
||||
return Err(actix_web::error::ErrorTooManyRequests(
|
||||
"Rate limit exceeded",
|
||||
));
|
||||
@ -69,8 +67,7 @@ pub async fn receive_pack(
|
||||
payload: web::Payload,
|
||||
state: web::Data<HttpAppState>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let ip = extract_ip(&req);
|
||||
if !state.rate_limiter.is_ip_write_allowed(&ip).await {
|
||||
if !state.rate_limiter.is_write_allowed().await {
|
||||
return Err(actix_web::error::ErrorTooManyRequests(
|
||||
"Rate limit exceeded",
|
||||
));
|
||||
@ -98,10 +95,3 @@ pub async fn receive_pack(
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn extract_ip(req: &HttpRequest) -> String {
|
||||
req.connection_info()
|
||||
.realip_remote_addr()
|
||||
.unwrap_or("unknown")
|
||||
.to_string()
|
||||
}
|
||||
|
||||
@ -20,7 +20,9 @@ pub struct Model {
|
||||
pub think: bool,
|
||||
pub stream: bool,
|
||||
pub min_score: Option<f32>,
|
||||
/// Agent type: "chat" (default) or "react" for ReAct reasoning agent.
|
||||
/// Agent type: "chat" (default), "react" (ReAct reasoning),
|
||||
/// "cot" (Chain-of-Thought), "rewoo" (Plan→Execute→Synthesize),
|
||||
/// or "reflexion" (Generate→Critique→Revise).
|
||||
pub agent_type: Option<String>,
|
||||
pub created_at: DateTimeUtc,
|
||||
pub updated_at: DateTimeUtc,
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
use crate::types::{
|
||||
AgentTaskEvent, EmailEnvelope, ProjectRoomEvent, ReactionGroup, RoomMessageEnvelope,
|
||||
RoomMessageEvent,
|
||||
RoomMessageEvent, RoomMessageStreamChunkEvent,
|
||||
};
|
||||
use anyhow::Context;
|
||||
use metrics::counter;
|
||||
@ -19,7 +19,7 @@ pub struct RedisPubSub {
|
||||
|
||||
impl RedisPubSub {
|
||||
/// Publish a serialised event to a Redis channel.
|
||||
async fn publish_channel(&self, channel: &str, payload: &[u8]) {
|
||||
pub async fn publish_channel(&self, channel: &str, payload: &[u8]) {
|
||||
let redis = match (self.get_redis)().await {
|
||||
Ok(Ok(c)) => c,
|
||||
Ok(Err(e)) => {
|
||||
@ -145,6 +145,24 @@ impl MessageProducer {
|
||||
Ok(entry_id)
|
||||
}
|
||||
|
||||
/// Publish a stream chunk event via Redis Pub/Sub for cross-node delivery.
|
||||
/// Called alongside the in-process broadcast to ensure WS clients on
|
||||
/// other server instances also receive the chunk.
|
||||
pub async fn publish_stream_chunk(&self, event: &RoomMessageStreamChunkEvent) {
|
||||
let Some(pubsub) = &self.pubsub else {
|
||||
return;
|
||||
};
|
||||
let channel = format!("room:stream:chunk:{}", event.room_id);
|
||||
let payload = match serde_json::to_vec(event) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "serialise stream chunk failed");
|
||||
return;
|
||||
}
|
||||
};
|
||||
pubsub.publish_channel(&channel, &payload).await;
|
||||
}
|
||||
|
||||
/// Publish a project-level room event via Pub/Sub (no Redis Stream write).
|
||||
pub async fn publish_project_room_event(
|
||||
&self,
|
||||
|
||||
@ -43,6 +43,7 @@ redis = { workspace = true, features = ["tokio-comp", "connection-manager"] }
|
||||
hostname = "0.4"
|
||||
dashmap = "7.0.0-rc2"
|
||||
lru = "0.12.0"
|
||||
ammonia = "4.0"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@ -48,6 +48,7 @@ pub struct RoomConnectionManager {
|
||||
room_subscriber_count: RwLock<HashMap<Uuid, usize>>,
|
||||
project_subscriber_count: RwLock<HashMap<Uuid, usize>>,
|
||||
user_subscriber_count: RwLock<HashMap<Uuid, usize>>,
|
||||
stream_cancel_tokens: RwLock<HashMap<Uuid, Arc<std::sync::atomic::AtomicBool>>>,
|
||||
}
|
||||
|
||||
impl RoomConnectionManager {
|
||||
@ -89,6 +90,8 @@ impl RoomConnectionManager {
|
||||
project_subscriber_count: RwLock::new(HashMap::new()),
|
||||
#[allow(clippy::default_constructed_unit_structs)]
|
||||
user_subscriber_count: RwLock::new(HashMap::new()),
|
||||
#[allow(clippy::default_constructed_unit_structs)]
|
||||
stream_cancel_tokens: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
@ -629,6 +632,35 @@ impl RoomConnectionManager {
|
||||
map.remove(&message_id);
|
||||
}
|
||||
|
||||
/// Register a cancel flag for an active AI streaming session.
|
||||
/// Returns the cancel token that the streaming task should check.
|
||||
pub async fn register_stream_cancel(
|
||||
&self,
|
||||
room_id: Uuid,
|
||||
) -> Arc<std::sync::atomic::AtomicBool> {
|
||||
let cancel = Arc::new(std::sync::atomic::AtomicBool::new(false));
|
||||
let mut map = self.stream_cancel_tokens.write().await;
|
||||
map.insert(room_id, cancel.clone());
|
||||
cancel
|
||||
}
|
||||
|
||||
/// Cancel an active AI streaming session for a room.
|
||||
pub async fn cancel_ai_stream(&self, room_id: Uuid) -> bool {
|
||||
let map = self.stream_cancel_tokens.read().await;
|
||||
if let Some(cancel) = map.get(&room_id) {
|
||||
cancel.store(true, std::sync::atomic::Ordering::Release);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Clean up the cancel token for a room when streaming completes.
|
||||
pub async fn unregister_stream_cancel(&self, room_id: Uuid) {
|
||||
let mut map = self.stream_cancel_tokens.write().await;
|
||||
map.remove(&room_id);
|
||||
}
|
||||
|
||||
pub async fn subscribe_typing(
|
||||
&self,
|
||||
room_id: Uuid,
|
||||
@ -660,24 +692,22 @@ impl RoomConnectionManager {
|
||||
// Write/delete Redis key for 60s expiry (non-blocking)
|
||||
if let Ok(mut conn) = self.cache.conn().await {
|
||||
let key = user_key;
|
||||
tokio::spawn(async move {
|
||||
if action == "start" {
|
||||
let value = serde_json::json!({
|
||||
"username": username,
|
||||
"avatar_url": avatar_url,
|
||||
"sender_type": sender_type,
|
||||
})
|
||||
.to_string();
|
||||
let _: Result<(), _> = redis::cmd("SETEX")
|
||||
.arg(&key)
|
||||
.arg(60i64)
|
||||
.arg(&value)
|
||||
.query_async(&mut conn)
|
||||
.await;
|
||||
} else {
|
||||
let _: Result<(), _> = redis::cmd("DEL").arg(&key).query_async(&mut conn).await;
|
||||
}
|
||||
});
|
||||
if action == "start" {
|
||||
let value = serde_json::json!({
|
||||
"username": username,
|
||||
"avatar_url": avatar_url,
|
||||
"sender_type": sender_type,
|
||||
})
|
||||
.to_string();
|
||||
let _: Result<(), _> = redis::cmd("SETEX")
|
||||
.arg(&key)
|
||||
.arg(60i64)
|
||||
.arg(&value)
|
||||
.query_async(&mut conn)
|
||||
.await;
|
||||
} else {
|
||||
let _: Result<(), _> = redis::cmd("DEL").arg(&key).query_async(&mut conn).await;
|
||||
}
|
||||
}
|
||||
|
||||
let map: tokio::sync::RwLockReadGuard<'_, std::collections::HashMap<Uuid, broadcast::Sender<Arc<TypingEvent>>>> = self.typing_inner.read().await;
|
||||
@ -1156,6 +1186,53 @@ pub async fn subscribe_room_events(
|
||||
tracing::info!(room_id = %room_id, "room subscriber stopped");
|
||||
}
|
||||
|
||||
/// Subscribe to stream chunk events for cross-node delivery.
|
||||
/// When a stream chunk is published via Redis Pub/Sub on
|
||||
/// `room:stream:chunk:{room_id}`, broadcast it locally.
|
||||
pub async fn subscribe_room_stream_chunk_events(
|
||||
redis_url: String,
|
||||
manager: Arc<RoomConnectionManager>,
|
||||
room_id: Uuid,
|
||||
mut shutdown_rx: broadcast::Receiver<()>,
|
||||
) {
|
||||
let channel = format!("room:stream:chunk:{}", room_id);
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(1024);
|
||||
|
||||
tracing::info!(room_id = %room_id, channel = %channel, "starting room stream chunk subscriber");
|
||||
|
||||
let thread_channel = channel.clone();
|
||||
let thread_shutdown = shutdown_rx.resubscribe();
|
||||
start_pubsub_thread(redis_url, thread_channel, tx, thread_shutdown, |_| async {});
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.recv() => {
|
||||
tracing::info!(room_id = %room_id, "stream chunk subscriber shutting down");
|
||||
break;
|
||||
}
|
||||
payload = rx.recv() => {
|
||||
match payload {
|
||||
Some(data) => {
|
||||
match serde_json::from_slice::<RoomMessageStreamChunkEvent>(&data) {
|
||||
Ok(event) => {
|
||||
manager.broadcast_stream_chunk(event).await;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "malformed RoomMessageStreamChunkEvent");
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
tracing::warn!(room_id = %room_id, "stream chunk relay channel closed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tracing::info!(room_id = %room_id, "stream chunk subscriber stopped");
|
||||
}
|
||||
|
||||
pub async fn subscribe_project_room_events(
|
||||
redis_url: String,
|
||||
manager: Arc<RoomConnectionManager>,
|
||||
|
||||
@ -35,6 +35,7 @@ impl From<room::Model> for super::RoomResponse {
|
||||
created_at: value.created_at,
|
||||
last_msg_at: value.last_msg_at,
|
||||
unread_count: 0,
|
||||
version: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -58,6 +59,7 @@ impl From<room_member::Model> for super::RoomMemberResponse {
|
||||
|
||||
impl From<room_message::Model> for super::RoomMessageResponse {
|
||||
fn from(value: room_message::Model) -> Self {
|
||||
let chunked = super::RoomMessageResponse::detect_chunked(&value.thinking_content);
|
||||
Self {
|
||||
id: value.id,
|
||||
seq: value.seq,
|
||||
@ -69,6 +71,7 @@ impl From<room_message::Model> for super::RoomMessageResponse {
|
||||
content: value.content,
|
||||
content_type: value.content_type.to_string(),
|
||||
thinking_content: value.thinking_content,
|
||||
thinking_is_chunked: chunked,
|
||||
edited_at: value.edited_at,
|
||||
send_at: value.send_at,
|
||||
revoked: value.revoked,
|
||||
@ -270,14 +273,18 @@ impl RoomService {
|
||||
.filter(project::Column::Name.eq(name.clone()))
|
||||
.one(&self.db)
|
||||
.await
|
||||
.inspect_err(|e| {
|
||||
tracing::warn!(error = %e, project_name = %name, "utils_find_project_by_name: DB error");
|
||||
})
|
||||
.ok()
|
||||
.flatten()
|
||||
{
|
||||
Some(project) => Ok(project),
|
||||
None => match project_history_name::Entity::find()
|
||||
.filter(project_history_name::Column::HistoryName.eq(name))
|
||||
.filter(project_history_name::Column::HistoryName.eq(name.clone()))
|
||||
.one(&self.db)
|
||||
.await
|
||||
.inspect_err(|e| tracing::warn!(error = %e, name = %name, "project_history_name lookup failed"))
|
||||
.ok()
|
||||
.flatten()
|
||||
{
|
||||
@ -291,6 +298,7 @@ impl RoomService {
|
||||
project::Entity::find_by_id(uid)
|
||||
.one(&self.db)
|
||||
.await
|
||||
.inspect_err(|e| tracing::warn!(error = %e, project_uid = %uid, "utils_find_project_by_uid: DB error"))
|
||||
.ok()
|
||||
.flatten()
|
||||
.ok_or_else(|| RoomError::NotFound("Project not found".to_string()))
|
||||
@ -304,6 +312,7 @@ impl RoomService {
|
||||
let project = project::Entity::find_by_id(project_uid)
|
||||
.one(&self.db)
|
||||
.await
|
||||
.inspect_err(|e| tracing::warn!(error = %e, project_uid = %project_uid, "check_project_access: DB error"))
|
||||
.ok()
|
||||
.flatten()
|
||||
.ok_or_else(|| RoomError::NotFound("Project not found".to_string()))?;
|
||||
@ -352,36 +361,11 @@ impl RoomService {
|
||||
}
|
||||
|
||||
pub(crate) fn sanitize_content(content: &str) -> String {
|
||||
use std::sync::LazyLock;
|
||||
|
||||
static SCRIPT_RE: LazyLock<regex_lite::Regex, fn() -> regex_lite::Regex> =
|
||||
LazyLock::new(|| regex_lite::Regex::new(r"(?i)<script[^>]*>.*?</script>").unwrap());
|
||||
static STYLE_RE: LazyLock<regex_lite::Regex, fn() -> regex_lite::Regex> =
|
||||
LazyLock::new(|| regex_lite::Regex::new(r"(?i)<style[^>]*>.*?</style>").unwrap());
|
||||
static ONERROR_RE: LazyLock<regex_lite::Regex, fn() -> regex_lite::Regex> =
|
||||
LazyLock::new(|| regex_lite::Regex::new(r"(?i)\bonerror\s*=").unwrap());
|
||||
static ONLOAD_RE: LazyLock<regex_lite::Regex, fn() -> regex_lite::Regex> =
|
||||
LazyLock::new(|| regex_lite::Regex::new(r"(?i)\bonload\s*=").unwrap());
|
||||
static ONCLICK_RE: LazyLock<regex_lite::Regex, fn() -> regex_lite::Regex> =
|
||||
LazyLock::new(|| regex_lite::Regex::new(r"(?i)\bonclick\s*=").unwrap());
|
||||
static ONMOUSEOVER_RE: LazyLock<regex_lite::Regex, fn() -> regex_lite::Regex> =
|
||||
LazyLock::new(|| regex_lite::Regex::new(r"(?i)\bonmouseover\s*=").unwrap());
|
||||
static JAVASCRIPT_RE: LazyLock<regex_lite::Regex, fn() -> regex_lite::Regex> =
|
||||
LazyLock::new(|| regex_lite::Regex::new(r"(?i)javascript:").unwrap());
|
||||
static DATA_RE: LazyLock<regex_lite::Regex, fn() -> regex_lite::Regex> =
|
||||
LazyLock::new(|| regex_lite::Regex::new(r"(?i)data:").unwrap());
|
||||
|
||||
let mut result = content.to_string();
|
||||
result = SCRIPT_RE.replace_all(&result, "").to_string();
|
||||
result = STYLE_RE.replace_all(&result, "").to_string();
|
||||
result = ONERROR_RE.replace_all(&result, "blocked=").to_string();
|
||||
result = ONLOAD_RE.replace_all(&result, "blocked=").to_string();
|
||||
result = ONCLICK_RE.replace_all(&result, "blocked=").to_string();
|
||||
result = ONMOUSEOVER_RE.replace_all(&result, "blocked=").to_string();
|
||||
result = JAVASCRIPT_RE.replace_all(&result, "blocked:").to_string();
|
||||
result = DATA_RE.replace_all(&result, "blocked:").to_string();
|
||||
|
||||
result
|
||||
// Use ammonia for HTML sanitization (whitelist approach).
|
||||
// Only allows safe tags: <a>, <b>, <i>, <code>, <pre>, <blockquote>, <p>, <br>, <strong>, <em>, <ul>, <ol>, <li>
|
||||
// All other tags (including <script>, <iframe>, <style>) are stripped.
|
||||
// Event handlers (onerror, onclick, etc.) are automatically removed.
|
||||
ammonia::clean(content)
|
||||
}
|
||||
|
||||
pub async fn resolve_display_name(
|
||||
@ -396,9 +380,11 @@ impl RoomService {
|
||||
ai_model::Entity::find_by_id(mid)
|
||||
.one(&self.db)
|
||||
.await
|
||||
.inspect_err(|e| tracing::warn!(error = %e, model_id = %mid, "resolve_display_name: AI model lookup failed"))
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(|m| m.name)
|
||||
.or_else(|| Some(format!("AI({})", &mid.to_string()[..8])))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@ -409,6 +395,7 @@ impl RoomService {
|
||||
.filter(user_model::Column::Uid.eq(sender_id))
|
||||
.one(&self.db)
|
||||
.await
|
||||
.inspect_err(|e| tracing::warn!(error = %e, user_id = %sender_id, "resolve_display_name: user lookup failed"))
|
||||
.ok()
|
||||
.flatten();
|
||||
user.map(|u| u.display_name.unwrap_or_else(|| u.username))
|
||||
@ -418,6 +405,7 @@ impl RoomService {
|
||||
}
|
||||
};
|
||||
|
||||
let chunked = super::RoomMessageResponse::detect_chunked(&msg.thinking_content);
|
||||
super::RoomMessageResponse {
|
||||
id: msg.id,
|
||||
seq: msg.seq,
|
||||
@ -429,6 +417,7 @@ impl RoomService {
|
||||
content: msg.content,
|
||||
content_type: msg.content_type.to_string(),
|
||||
thinking_content: msg.thinking_content,
|
||||
thinking_is_chunked: chunked,
|
||||
edited_at: msg.edited_at,
|
||||
send_at: msg.send_at,
|
||||
revoked: msg.revoked,
|
||||
@ -438,4 +427,279 @@ impl RoomService {
|
||||
attachment_ids: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the current version of a room using Redis.
|
||||
/// Returns 0 if no version has been set (new rooms start at 1).
|
||||
pub(crate) async fn get_room_version(&self, room_id: Uuid) -> Result<i64, RoomError> {
|
||||
let version_key = format!("room:version:{}", room_id);
|
||||
let mut conn = self.cache.conn().await.map_err(|e| {
|
||||
RoomError::Internal(format!("failed to get redis for version: {}", e))
|
||||
})?;
|
||||
let version: Option<i64> = redis::cmd("GET")
|
||||
.arg(&version_key)
|
||||
.query_async(&mut conn)
|
||||
.await
|
||||
.map_err(|e| RoomError::Internal(format!("version GET: {}", e)))?;
|
||||
Ok(version.unwrap_or(0))
|
||||
}
|
||||
|
||||
/// Atomically increment the room version and return the new value.
|
||||
/// Called on every room mutation (rename, move, delete).
|
||||
pub(crate) async fn increment_room_version(&self, room_id: Uuid) -> Result<i64, RoomError> {
|
||||
Self::raw_increment_room_version(&self.cache, room_id).await
|
||||
}
|
||||
|
||||
/// Static helper so it can be called from `room_create` without `&self`.
|
||||
pub(crate) async fn raw_increment_room_version(
|
||||
cache: &db::cache::AppCache,
|
||||
room_id: Uuid,
|
||||
) -> Result<i64, RoomError> {
|
||||
let version_key = format!("room:version:{}", room_id);
|
||||
let mut conn = cache.conn().await.map_err(|e| {
|
||||
RoomError::Internal(format!("failed to get redis for version: {}", e))
|
||||
})?;
|
||||
let version: i64 = redis::cmd("INCR")
|
||||
.arg(&version_key)
|
||||
.query_async(&mut conn)
|
||||
.await
|
||||
.map_err(|e| RoomError::Internal(format!("version INCR: {}", e)))?;
|
||||
Ok(version)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_room_member_role_valid() {
|
||||
assert!(matches!(
|
||||
RoomService::parse_room_member_role("owner").unwrap(),
|
||||
RoomMemberRole::Owner
|
||||
));
|
||||
assert!(matches!(
|
||||
RoomService::parse_room_member_role("admin").unwrap(),
|
||||
RoomMemberRole::Admin
|
||||
));
|
||||
assert!(matches!(
|
||||
RoomService::parse_room_member_role("member").unwrap(),
|
||||
RoomMemberRole::Member
|
||||
));
|
||||
assert!(matches!(
|
||||
RoomService::parse_room_member_role("guest").unwrap(),
|
||||
RoomMemberRole::Guest
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_room_member_role_invalid() {
|
||||
assert!(RoomService::parse_room_member_role("superadmin").is_err());
|
||||
assert!(RoomService::parse_room_member_role("").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_message_content_type_valid() {
|
||||
assert!(matches!(
|
||||
RoomService::parse_message_content_type(Some("text".into())).unwrap(),
|
||||
MessageContentType::Text
|
||||
));
|
||||
assert!(matches!(
|
||||
RoomService::parse_message_content_type(Some("image".into())).unwrap(),
|
||||
MessageContentType::Image
|
||||
));
|
||||
assert!(matches!(
|
||||
RoomService::parse_message_content_type(Some("audio".into())).unwrap(),
|
||||
MessageContentType::Audio
|
||||
));
|
||||
assert!(matches!(
|
||||
RoomService::parse_message_content_type(Some("video".into())).unwrap(),
|
||||
MessageContentType::Video
|
||||
));
|
||||
assert!(matches!(
|
||||
RoomService::parse_message_content_type(Some("file".into())).unwrap(),
|
||||
MessageContentType::File
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_message_content_type_case_insensitive() {
|
||||
assert!(matches!(
|
||||
RoomService::parse_message_content_type(Some("TEXT".into())).unwrap(),
|
||||
MessageContentType::Text
|
||||
));
|
||||
assert!(matches!(
|
||||
RoomService::parse_message_content_type(Some("Image".into())).unwrap(),
|
||||
MessageContentType::Image
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_message_content_type_none_defaults_to_text() {
|
||||
assert!(matches!(
|
||||
RoomService::parse_message_content_type(None).unwrap(),
|
||||
MessageContentType::Text
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_message_content_type_invalid() {
|
||||
assert!(RoomService::parse_message_content_type(Some("pdf".into())).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_name_valid() {
|
||||
assert!(RoomService::validate_name("test-room", 100).is_ok());
|
||||
assert!(RoomService::validate_name("a", 100).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_name_empty() {
|
||||
assert!(RoomService::validate_name("", 100).is_err());
|
||||
assert!(RoomService::validate_name(" ", 100).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_name_too_long() {
|
||||
let long = "x".repeat(101);
|
||||
assert!(RoomService::validate_name(&long, 100).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_content_valid() {
|
||||
assert!(RoomService::validate_content("hello", 10000).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_content_empty() {
|
||||
assert!(RoomService::validate_content("", 10000).is_err());
|
||||
assert!(RoomService::validate_content(" ", 10000).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_content_too_long() {
|
||||
let long = "x".repeat(10001);
|
||||
assert!(RoomService::validate_content(&long, 10000).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_content_removes_script_tag() {
|
||||
let input = "<script>alert('xss')</script>";
|
||||
let result = RoomService::sanitize_content(input);
|
||||
assert!(!result.contains("<script>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_content_blocks_javascript_uri() {
|
||||
let input = "javascript:alert(1)";
|
||||
let result = RoomService::sanitize_content(input);
|
||||
// ammonia strips javascript: from href but preserves plain text
|
||||
assert_eq!(result, "javascript:alert(1)"); // safe in plain text
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_content_blocks_onerror() {
|
||||
let input = r#"<img src=x onerror="alert(1)">"#;
|
||||
let result = RoomService::sanitize_content(input);
|
||||
// ammonia removes event handler attributes from allowed tags
|
||||
assert!(!result.contains("onerror"));
|
||||
// ammonia keeps the img tag but with onerror removed
|
||||
assert!(result.contains("<img"));
|
||||
assert!(!result.contains("alert"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_content_preserves_safe_content() {
|
||||
let input = "Hello <strong>world</strong>";
|
||||
let result = RoomService::sanitize_content(input);
|
||||
assert!(result.contains("Hello"));
|
||||
assert!(result.contains("<strong>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_room_admin() {
|
||||
assert!(RoomService::is_room_admin(&RoomMemberRole::Owner));
|
||||
assert!(RoomService::is_room_admin(&RoomMemberRole::Admin));
|
||||
assert!(!RoomService::is_room_admin(&RoomMemberRole::Member));
|
||||
assert!(!RoomService::is_room_admin(&RoomMemberRole::Guest));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_room_event_type_from_str_roundtrip() {
|
||||
for variant in [
|
||||
crate::RoomEventType::RoomCreated,
|
||||
crate::RoomEventType::RoomDeleted,
|
||||
crate::RoomEventType::NewMessage,
|
||||
crate::RoomEventType::MessageEdited,
|
||||
crate::RoomEventType::MessageRevoked,
|
||||
crate::RoomEventType::MemberJoined,
|
||||
] {
|
||||
let s = variant.as_str();
|
||||
let parsed = crate::RoomEventType::from_str(s);
|
||||
assert_eq!(parsed, Some(variant));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_room_event_type_from_str_unknown() {
|
||||
assert_eq!(crate::RoomEventType::from_str("unknown_event"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mention_bracket_re_matches_ai_model() {
|
||||
let re = crate::service::mention_bracket_re();
|
||||
let caps: Vec<_> = re.captures_iter("@[ai:550e8400-0000-0000-0000-000000000001:GPT-4]").collect();
|
||||
assert_eq!(caps.len(), 1);
|
||||
assert_eq!(&caps[0][1], "ai");
|
||||
assert_eq!(&caps[0][2], "550e8400-0000-0000-0000-000000000001");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mention_bracket_re_matches_user() {
|
||||
let re = crate::service::mention_bracket_re();
|
||||
let caps: Vec<_> = re.captures_iter("@[user:850e8400-0000-0000-0000-000000000002:John]").collect();
|
||||
assert_eq!(caps.len(), 1);
|
||||
assert_eq!(&caps[0][1], "user");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mention_bracket_re_matches_repo() {
|
||||
let re = crate::service::mention_bracket_re();
|
||||
let caps: Vec<_> = re.captures_iter("@[repo:my-repo:My Repository]").collect();
|
||||
assert_eq!(caps.len(), 1);
|
||||
assert_eq!(&caps[0][1], "repo");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mention_bracket_re_no_match_plain_text() {
|
||||
let re = crate::service::mention_bracket_re();
|
||||
let caps: Vec<_> = re.captures_iter("Hello world").collect();
|
||||
assert_eq!(caps.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mention_multiple_in_same_message() {
|
||||
let re = crate::service::mention_bracket_re();
|
||||
let content = "@[ai:uuid1:Model1] and @[user:uuid2:User2]";
|
||||
let caps: Vec<_> = re.captures_iter(content).collect();
|
||||
assert_eq!(caps.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mention_tag_re_legacy_format() {
|
||||
let re = crate::service::mention_tag_re();
|
||||
let content = r#"<mention type="ai" id="model-uuid">GPT-4</mention>"#;
|
||||
let caps: Vec<_> = re.captures_iter(content).collect();
|
||||
assert_eq!(caps.len(), 1);
|
||||
assert_eq!(&caps[0][1], "ai");
|
||||
assert_eq!(&caps[0][2], "model-uuid");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mention_combined_brackets_and_tags() {
|
||||
let bracket_re = crate::service::mention_bracket_re();
|
||||
let tag_re = crate::service::mention_tag_re();
|
||||
let content = r#"@[ai:uuid1:A] <mention type="ai" id="uuid2">B</mention>"#;
|
||||
assert_eq!(bracket_re.captures_iter(content).count(), 1);
|
||||
assert_eq!(tag_re.captures_iter(content).count(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
@ -21,7 +21,7 @@ pub mod ws_context;
|
||||
pub use connection::{
|
||||
PersistFn, RedisFuture, RoomConnectionManager, cleanup_dedup_cache, extract_get_redis,
|
||||
make_persist_fn, subscribe_project_room_events, subscribe_room_events,
|
||||
subscribe_task_events_fn,
|
||||
subscribe_room_stream_chunk_events, subscribe_task_events_fn,
|
||||
};
|
||||
pub use draft_and_history::{
|
||||
DraftResponse, DraftSaveRequest, MentionNotificationResponse, MessageEditHistoryEntry,
|
||||
|
||||
@ -78,9 +78,15 @@ impl RoomService {
|
||||
.map(|msg| {
|
||||
let sender_type = msg.sender_type.to_string();
|
||||
let display_name = match sender_type.as_str() {
|
||||
"ai" => msg.model_id.and_then(|id| ai_names.get(&id).cloned()),
|
||||
"ai" => msg.model_id.and_then(|id| {
|
||||
ai_names
|
||||
.get(&id)
|
||||
.cloned()
|
||||
.or_else(|| Some(format!("AI({})", &id.to_string()[..8])))
|
||||
}),
|
||||
_ => msg.sender_id.and_then(|id| users.get(&id).cloned()),
|
||||
};
|
||||
let chunked = super::RoomMessageResponse::detect_chunked(&msg.thinking_content);
|
||||
super::RoomMessageResponse {
|
||||
id: msg.id,
|
||||
seq: msg.seq,
|
||||
@ -93,6 +99,7 @@ impl RoomService {
|
||||
content: msg.content,
|
||||
content_type: msg.content_type.to_string(),
|
||||
thinking_content: msg.thinking_content,
|
||||
thinking_is_chunked: chunked,
|
||||
edited_at: msg.edited_at,
|
||||
send_at: msg.send_at,
|
||||
revoked: msg.revoked,
|
||||
@ -185,9 +192,6 @@ impl RoomService {
|
||||
let db = &self.db;
|
||||
let txn = db.begin().await?;
|
||||
|
||||
self.queue.publish(room_id, envelope).await?;
|
||||
self.room_manager.metrics.messages_sent.increment(1);
|
||||
|
||||
let mut room_active: room::ActiveModel = room_model.clone().into();
|
||||
room_active.last_msg_at = Set(now);
|
||||
room_active.update(&txn).await?;
|
||||
@ -224,6 +228,10 @@ impl RoomService {
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
// Publish to Redis Stream AFTER commit so DB has the data first
|
||||
self.queue.publish(room_id, envelope).await?;
|
||||
self.room_manager.metrics.messages_sent.increment(1);
|
||||
|
||||
// Link uploaded attachments to this message
|
||||
let attachment_ids = request.attachment_ids.clone();
|
||||
if !attachment_ids.is_empty() {
|
||||
@ -320,6 +328,7 @@ impl RoomService {
|
||||
content: request.content,
|
||||
content_type: content_type_str,
|
||||
thinking_content: None,
|
||||
thinking_is_chunked: false,
|
||||
edited_at: None,
|
||||
send_at: now,
|
||||
revoked: None,
|
||||
|
||||
@ -1,8 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use metrics::{describe_counter, describe_gauge, describe_histogram, Counter, Gauge, Histogram, Unit};
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub struct RoomMetrics {
|
||||
@ -24,24 +21,6 @@ pub struct RoomMetrics {
|
||||
pub ws_heartbeat_sent_total: Counter,
|
||||
pub ws_heartbeat_timeout_total: Counter,
|
||||
pub ws_idle_timeout_total: Counter,
|
||||
// Atomic backing for snapshot reads (all values stored as f64 for gauges, u64 for counters)
|
||||
pub _rooms_online_val: AtomicU64,
|
||||
pub _users_online_val: AtomicU64,
|
||||
pub _ws_connections_active_val: AtomicU64,
|
||||
pub _ws_connections_total_val: AtomicU64,
|
||||
pub _ws_disconnections_total_val: AtomicU64,
|
||||
pub _messages_sent_val: AtomicU64,
|
||||
pub _messages_persisted_val: AtomicU64,
|
||||
pub _messages_persist_failed_val: AtomicU64,
|
||||
pub _broadcasts_sent_val: AtomicU64,
|
||||
pub _broadcasts_dropped_val: AtomicU64,
|
||||
pub _duplicates_skipped_val: AtomicU64,
|
||||
pub _redis_publish_failed_val: AtomicU64,
|
||||
pub _ws_rate_limit_hits_val: AtomicU64,
|
||||
pub _ws_auth_failures_val: AtomicU64,
|
||||
pub _ws_heartbeat_sent_total_val: AtomicU64,
|
||||
pub _ws_heartbeat_timeout_total_val: AtomicU64,
|
||||
pub _ws_idle_timeout_total_val: AtomicU64,
|
||||
}
|
||||
|
||||
impl Default for RoomMetrics {
|
||||
@ -150,23 +129,6 @@ impl Default for RoomMetrics {
|
||||
ws_heartbeat_sent_total: metrics::counter!("room_ws_heartbeat_sent_total"),
|
||||
ws_heartbeat_timeout_total: metrics::counter!("room_ws_heartbeat_timeout_total"),
|
||||
ws_idle_timeout_total: metrics::counter!("room_ws_idle_timeout_total"),
|
||||
_rooms_online_val: AtomicU64::new(0),
|
||||
_users_online_val: AtomicU64::new(0),
|
||||
_ws_connections_active_val: AtomicU64::new(0),
|
||||
_ws_connections_total_val: AtomicU64::new(0),
|
||||
_ws_disconnections_total_val: AtomicU64::new(0),
|
||||
_messages_sent_val: AtomicU64::new(0),
|
||||
_messages_persisted_val: AtomicU64::new(0),
|
||||
_messages_persist_failed_val: AtomicU64::new(0),
|
||||
_broadcasts_sent_val: AtomicU64::new(0),
|
||||
_broadcasts_dropped_val: AtomicU64::new(0),
|
||||
_duplicates_skipped_val: AtomicU64::new(0),
|
||||
_redis_publish_failed_val: AtomicU64::new(0),
|
||||
_ws_rate_limit_hits_val: AtomicU64::new(0),
|
||||
_ws_auth_failures_val: AtomicU64::new(0),
|
||||
_ws_heartbeat_sent_total_val: AtomicU64::new(0),
|
||||
_ws_heartbeat_timeout_total_val: AtomicU64::new(0),
|
||||
_ws_idle_timeout_total_val: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -201,33 +163,9 @@ impl RoomMetrics {
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn cleanup_stale_rooms(&self, _active_room_ids: &[Uuid]) {
|
||||
// Per-room metrics are registered on-demand; no cleanup needed.
|
||||
}
|
||||
|
||||
pub fn into_arc(self) -> Arc<RoomMetrics> {
|
||||
Arc::new(self)
|
||||
}
|
||||
|
||||
/// Returns a snapshot of all current gauge and counter values as a flat map.
|
||||
pub fn snapshot(&self) -> HashMap<String, serde_json::Value> {
|
||||
let mut m = HashMap::new();
|
||||
m.insert("room_online_rooms".into(), serde_json::json!(self._rooms_online_val.load(Ordering::Relaxed) as f64));
|
||||
m.insert("room_online_users".into(), serde_json::json!(self._users_online_val.load(Ordering::Relaxed) as f64));
|
||||
m.insert("room_ws_connections_active".into(), serde_json::json!(self._ws_connections_active_val.load(Ordering::Relaxed) as f64));
|
||||
m.insert("room_ws_connections_total".into(), serde_json::json!(self._ws_connections_total_val.load(Ordering::Relaxed)));
|
||||
m.insert("room_ws_disconnections_total".into(), serde_json::json!(self._ws_disconnections_total_val.load(Ordering::Relaxed)));
|
||||
m.insert("room_messages_sent_total".into(), serde_json::json!(self._messages_sent_val.load(Ordering::Relaxed)));
|
||||
m.insert("room_messages_persisted_total".into(), serde_json::json!(self._messages_persisted_val.load(Ordering::Relaxed)));
|
||||
m.insert("room_messages_persist_failed_total".into(), serde_json::json!(self._messages_persist_failed_val.load(Ordering::Relaxed)));
|
||||
m.insert("room_broadcasts_sent_total".into(), serde_json::json!(self._broadcasts_sent_val.load(Ordering::Relaxed)));
|
||||
m.insert("room_broadcasts_dropped_total".into(), serde_json::json!(self._broadcasts_dropped_val.load(Ordering::Relaxed)));
|
||||
m.insert("room_duplicates_skipped_total".into(), serde_json::json!(self._duplicates_skipped_val.load(Ordering::Relaxed)));
|
||||
m.insert("room_redis_publish_failed_total".into(), serde_json::json!(self._redis_publish_failed_val.load(Ordering::Relaxed)));
|
||||
m.insert("room_ws_rate_limit_hits_total".into(), serde_json::json!(self._ws_rate_limit_hits_val.load(Ordering::Relaxed)));
|
||||
m.insert("room_ws_auth_failures_total".into(), serde_json::json!(self._ws_auth_failures_val.load(Ordering::Relaxed)));
|
||||
m.insert("room_ws_heartbeat_sent_total".into(), serde_json::json!(self._ws_heartbeat_sent_total_val.load(Ordering::Relaxed)));
|
||||
m.insert("room_ws_heartbeat_timeout_total".into(), serde_json::json!(self._ws_heartbeat_timeout_total_val.load(Ordering::Relaxed)));
|
||||
m.insert("room_ws_idle_timeout_total".into(), serde_json::json!(self._ws_idle_timeout_total_val.load(Ordering::Relaxed)));
|
||||
m
|
||||
}
|
||||
}
|
||||
|
||||
@ -310,6 +310,7 @@ impl RoomService {
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let chunked = super::RoomMessageResponse::detect_chunked(&msg.thinking_content);
|
||||
super::RoomMessageResponse {
|
||||
id: msg.id,
|
||||
seq: msg.seq,
|
||||
@ -322,6 +323,7 @@ impl RoomService {
|
||||
content: msg.content,
|
||||
content_type: msg.content_type.to_string(),
|
||||
thinking_content: msg.thinking_content,
|
||||
thinking_is_chunked: chunked,
|
||||
edited_at: msg.edited_at,
|
||||
send_at: msg.send_at,
|
||||
revoked: msg.revoked,
|
||||
|
||||
@ -3,8 +3,8 @@ use crate::service::RoomService;
|
||||
use crate::ws_context::WsUserContext;
|
||||
use chrono::Utc;
|
||||
use models::rooms::{
|
||||
RoomMemberRole, room, room_ai, room_category, room_member, room_message, room_pin, room_thread,
|
||||
room_message_reaction, room_message_edit_history, room_notifications,
|
||||
RoomMemberRole, room, room_ai, room_attachment, room_category, room_member, room_message,
|
||||
room_message_edit_history, room_message_reaction, room_notifications, room_pin, room_thread,
|
||||
};
|
||||
use models::projects::{project_members, MemberRole as Role};
|
||||
use queue::ProjectRoomEvent;
|
||||
@ -12,8 +12,9 @@ use sea_orm::*;
|
||||
use uuid::Uuid;
|
||||
|
||||
impl RoomService {
|
||||
/// Cache TTL for room list (in seconds).
|
||||
const ROOM_LIST_CACHE_TTL: u64 = 60;
|
||||
/// Cache TTL for room list (in seconds). Kept short to avoid
|
||||
/// stale data without needing expensive SCAN-based invalidation.
|
||||
const ROOM_LIST_CACHE_TTL: u64 = 15;
|
||||
|
||||
pub async fn room_list(
|
||||
&self,
|
||||
@ -226,7 +227,10 @@ impl RoomService {
|
||||
Some(room_model.id),
|
||||
);
|
||||
|
||||
Ok(super::RoomResponse::from(room_model))
|
||||
let version = Self::raw_increment_room_version(&self.cache, room_model.id).await?;
|
||||
let mut resp = super::RoomResponse::from(room_model);
|
||||
resp.version = version;
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
pub async fn room_get(
|
||||
@ -237,7 +241,10 @@ impl RoomService {
|
||||
let user_id = ctx.user_id;
|
||||
let model = self.find_room_or_404(room_id).await?;
|
||||
self.ensure_room_visible_for_user(&model, user_id).await?;
|
||||
Ok(super::RoomResponse::from(model))
|
||||
let version = self.get_room_version(room_id).await?;
|
||||
let mut resp = super::RoomResponse::from(model);
|
||||
resp.version = version;
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
pub async fn room_update(
|
||||
@ -312,7 +319,10 @@ impl RoomService {
|
||||
.await;
|
||||
}
|
||||
|
||||
Ok(super::RoomResponse::from(updated))
|
||||
let version = self.increment_room_version(room_id).await?;
|
||||
let mut resp = super::RoomResponse::from(updated);
|
||||
resp.version = version;
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
pub async fn room_delete(&self, room_id: Uuid, ctx: &WsUserContext) -> Result<(), RoomError> {
|
||||
@ -323,6 +333,11 @@ impl RoomService {
|
||||
|
||||
let txn = self.db.begin().await?;
|
||||
|
||||
room_attachment::Entity::delete_many()
|
||||
.filter(room_attachment::Column::Room.eq(room_id))
|
||||
.exec(&txn)
|
||||
.await?;
|
||||
|
||||
room_message::Entity::delete_many()
|
||||
.filter(room_message::Column::Room.eq(room_id))
|
||||
.exec(&txn)
|
||||
@ -416,48 +431,9 @@ impl RoomService {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Invalidate all room list cache entries for a project.
|
||||
/// Cache entries expire after ROOM_LIST_CACHE_TTL seconds.
|
||||
/// No explicit invalidation needed — the short TTL handles staleness.
|
||||
async fn invalidate_room_list_cache(&self, project_id: Uuid) {
|
||||
let pattern = format!("room:list:{}:*", project_id);
|
||||
if let Ok(mut conn) = self.cache.conn().await {
|
||||
// Use SCAN to find matching keys, then DELETE them
|
||||
let mut cursor: u64 = 0;
|
||||
loop {
|
||||
let (new_cursor, keys): (u64, Vec<String>) = match redis::cmd("SCAN")
|
||||
.arg(cursor)
|
||||
.arg("MATCH")
|
||||
.arg(&pattern)
|
||||
.arg("COUNT")
|
||||
.arg(100)
|
||||
.query_async(&mut conn)
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "invalidate_room_list_cache: SCAN failed");
|
||||
break;
|
||||
}
|
||||
};
|
||||
cursor = new_cursor;
|
||||
|
||||
if !keys.is_empty() {
|
||||
// Delete keys in batches
|
||||
let keys_refs: Vec<&str> = keys.iter().map(|s| s.as_str()).collect();
|
||||
if let Err(e) = redis::cmd("DEL")
|
||||
.arg(&keys_refs)
|
||||
.query_async::<i64>(&mut conn)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(error = %e, "invalidate_room_list_cache: DEL failed");
|
||||
} else {
|
||||
tracing::debug!(keys_count = keys.len(), "invalidate_room_list_cache: deleted");
|
||||
}
|
||||
}
|
||||
|
||||
if cursor == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
tracing::debug!(project_id = %project_id, "room_list cache: relying on TTL expiry");
|
||||
}
|
||||
}
|
||||
|
||||
@ -135,8 +135,24 @@ pub async fn acquire_room_ai_lock(
|
||||
tracing::warn!(
|
||||
room_id = %room_id,
|
||||
elapsed_ms = start.elapsed().as_millis(),
|
||||
"RoomAiLock: timeout waiting for lock"
|
||||
"RoomAiLock: timeout waiting for lock, cleaning up"
|
||||
);
|
||||
// Clean up our own ZSET entry and ticket to prevent ZSET leak
|
||||
if let Ok(mut conn) = cache.conn().await {
|
||||
let _: i32 = redis::cmd("ZREM")
|
||||
.arg(&queue_key)
|
||||
.arg(&request_uid)
|
||||
.query_async(&mut conn)
|
||||
.await
|
||||
.inspect_err(|e| tracing::warn!(error = %e, "timeout ZREM failed"))
|
||||
.unwrap_or(0);
|
||||
let _: i32 = redis::cmd("DEL")
|
||||
.arg(&ticket_key)
|
||||
.query_async(&mut conn)
|
||||
.await
|
||||
.inspect_err(|e| tracing::warn!(error = %e, "timeout DEL ticket failed"))
|
||||
.unwrap_or(0);
|
||||
}
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
@ -183,6 +199,28 @@ pub async fn acquire_room_ai_lock(
|
||||
acquired: true,
|
||||
}));
|
||||
}
|
||||
|
||||
// Lock exists — check if it's stale (previous owner crashed).
|
||||
// PTTL returns -2 if key does not exist, -1 if no expiry,
|
||||
// or remaining TTL in ms if still alive.
|
||||
let pttl: i64 = redis::cmd("PTTL")
|
||||
.arg(&lock_key)
|
||||
.query_async(&mut conn)
|
||||
.await
|
||||
.map_err(|e| RoomError::Internal(format!("PTTL: {}", e)))?;
|
||||
|
||||
if pttl == -1 {
|
||||
// Key exists but has no expiry — should not happen with PX, force delete
|
||||
tracing::warn!(
|
||||
lock_key = %lock_key,
|
||||
"RoomAiLock: lock exists without TTL, force releasing"
|
||||
);
|
||||
let _: i32 = redis::cmd("DEL")
|
||||
.arg(&lock_key)
|
||||
.query_async(&mut conn)
|
||||
.await
|
||||
.map_err(|e| RoomError::Internal(format!("DEL stale lock: {}", e)))?;
|
||||
}
|
||||
} else {
|
||||
let head_ticket_key = format!("ai:room:queue:ticket:{}:{}", room_id, head_uid);
|
||||
let head_exists: i32 = redis::cmd("EXISTS")
|
||||
|
||||
183
libs/room/src/service/ai_mode_dispatch.rs
Normal file
183
libs/room/src/service/ai_mode_dispatch.rs
Normal file
@ -0,0 +1,183 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use db::cache::AppCache;
|
||||
use db::database::AppDatabase;
|
||||
use queue::MessageProducer;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::ai_mode_streaming::run_mode_streaming;
|
||||
use crate::connection::RoomConnectionManager;
|
||||
use agent::chat::{AiRequest, ChatService};
|
||||
|
||||
// ── CoT ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
pub async fn dispatch_cot(
|
||||
chat_service: Arc<ChatService>,
|
||||
request: AiRequest,
|
||||
room_id: Uuid,
|
||||
project_id: Uuid,
|
||||
model_id: Uuid,
|
||||
lock_guard: crate::room_ai_queue::RoomAiLockGuard,
|
||||
db: AppDatabase,
|
||||
cache: AppCache,
|
||||
queue: MessageProducer,
|
||||
room_manager: Arc<RoomConnectionManager>,
|
||||
) {
|
||||
let req = request;
|
||||
run_mode_streaming(
|
||||
chat_service,
|
||||
req,
|
||||
room_id,
|
||||
project_id,
|
||||
model_id,
|
||||
lock_guard,
|
||||
db,
|
||||
cache,
|
||||
queue,
|
||||
room_manager,
|
||||
"cot",
|
||||
Box::new(|chat_svc, ai_req, on_chunk| {
|
||||
Box::pin(async move {
|
||||
let result = chat_svc
|
||||
.process_cot(&ai_req, |step| {
|
||||
let on_chunk = on_chunk.clone();
|
||||
async move {
|
||||
let (ct, content, is_final) = match step {
|
||||
agent::modes::cot::CotStep::Thought(t) => {
|
||||
("thinking".to_string(), t, false)
|
||||
}
|
||||
agent::modes::cot::CotStep::Action { name, args } => {
|
||||
("tool_call".to_string(),
|
||||
serde_json::json!({"name": name, "arguments": args}).to_string(),
|
||||
false)
|
||||
}
|
||||
agent::modes::cot::CotStep::Observation(o) => {
|
||||
("tool_result".to_string(), o, false)
|
||||
}
|
||||
agent::modes::cot::CotStep::Answer(a) => {
|
||||
("answer".to_string(), a, true)
|
||||
}
|
||||
};
|
||||
on_chunk(ct, content, is_final).await;
|
||||
}
|
||||
})
|
||||
.await;
|
||||
result
|
||||
})
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// ── ReWOO ────────────────────────────────────────────────────────────────────
|
||||
|
||||
pub async fn dispatch_rewoo(
|
||||
chat_service: Arc<ChatService>,
|
||||
request: AiRequest,
|
||||
room_id: Uuid,
|
||||
project_id: Uuid,
|
||||
model_id: Uuid,
|
||||
lock_guard: crate::room_ai_queue::RoomAiLockGuard,
|
||||
db: AppDatabase,
|
||||
cache: AppCache,
|
||||
queue: MessageProducer,
|
||||
room_manager: Arc<RoomConnectionManager>,
|
||||
) {
|
||||
let req = request;
|
||||
run_mode_streaming(
|
||||
chat_service.clone(),
|
||||
req,
|
||||
room_id,
|
||||
project_id,
|
||||
model_id,
|
||||
lock_guard,
|
||||
db,
|
||||
cache,
|
||||
queue,
|
||||
room_manager,
|
||||
"rewoo",
|
||||
Box::new(|chat_svc, ai_req, on_chunk| {
|
||||
Box::pin(async move {
|
||||
let result = chat_svc
|
||||
.process_rewoo(&ai_req, |step| {
|
||||
let on_chunk = on_chunk.clone();
|
||||
async move {
|
||||
let (ct, content, is_final) = match step {
|
||||
agent::modes::rewoo::ReWooStep::Plan { raw, .. } => {
|
||||
("tool_call".to_string(), raw, false)
|
||||
}
|
||||
agent::modes::rewoo::ReWooStep::Execution { tool_name, result } => {
|
||||
("tool_result".to_string(), format!("{}: {}", tool_name, result), false)
|
||||
}
|
||||
agent::modes::rewoo::ReWooStep::Synthesis(s) => {
|
||||
("answer".to_string(), s, true)
|
||||
}
|
||||
};
|
||||
on_chunk(ct, content, is_final).await;
|
||||
}
|
||||
})
|
||||
.await;
|
||||
result
|
||||
})
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// ── Reflexion ────────────────────────────────────────────────────────────────
|
||||
|
||||
pub async fn dispatch_reflexion(
|
||||
chat_service: Arc<ChatService>,
|
||||
request: AiRequest,
|
||||
room_id: Uuid,
|
||||
project_id: Uuid,
|
||||
model_id: Uuid,
|
||||
lock_guard: crate::room_ai_queue::RoomAiLockGuard,
|
||||
db: AppDatabase,
|
||||
cache: AppCache,
|
||||
queue: MessageProducer,
|
||||
room_manager: Arc<RoomConnectionManager>,
|
||||
) {
|
||||
let req = request;
|
||||
run_mode_streaming(
|
||||
chat_service.clone(),
|
||||
req,
|
||||
room_id,
|
||||
project_id,
|
||||
model_id,
|
||||
lock_guard,
|
||||
db,
|
||||
cache,
|
||||
queue,
|
||||
room_manager,
|
||||
"reflexion",
|
||||
Box::new(|chat_svc, ai_req, on_chunk| {
|
||||
Box::pin(async move {
|
||||
let result = chat_svc
|
||||
.process_reflexion(&ai_req, |step| {
|
||||
let on_chunk = on_chunk.clone();
|
||||
async move {
|
||||
let (ct, content, is_final) = match step {
|
||||
agent::modes::reflexion::ReflexionStep::Generate(s) => {
|
||||
("thinking".to_string(), s, false)
|
||||
}
|
||||
agent::modes::reflexion::ReflexionStep::Critique(s) => {
|
||||
("tool_call".to_string(), s, false)
|
||||
}
|
||||
agent::modes::reflexion::ReflexionStep::Revise(s) => {
|
||||
("thinking".to_string(), s, false)
|
||||
}
|
||||
agent::modes::reflexion::ReflexionStep::Final(s) => {
|
||||
("answer".to_string(), s, true)
|
||||
}
|
||||
};
|
||||
on_chunk(ct, content, is_final).await;
|
||||
}
|
||||
}, 3)
|
||||
.await;
|
||||
result
|
||||
})
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
324
libs/room/src/service/ai_mode_streaming.rs
Normal file
324
libs/room/src/service/ai_mode_streaming.rs
Normal file
@ -0,0 +1,324 @@
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use chrono::Utc;
|
||||
use db::cache::AppCache;
|
||||
use db::database::AppDatabase;
|
||||
use models::rooms::room_ai;
|
||||
use queue::{MessageProducer, ProjectRoomEvent, RoomMessageEnvelope};
|
||||
use sea_orm::{sea_query::Expr, ColumnTrait, EntityTrait, ExprTrait, QueryFilter};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::sequence::next_room_message_seq_internal;
|
||||
use crate::connection::RoomConnectionManager;
|
||||
use agent::chat::{AiRequest, ChatService};
|
||||
|
||||
pub type RunModeFn = Box<
|
||||
dyn FnOnce(
|
||||
Arc<ChatService>,
|
||||
AiRequest,
|
||||
Arc<dyn Fn(String, String, bool) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>,
|
||||
) -> Pin<Box<dyn std::future::Future<Output = Result<(String, i64, i64), agent::AgentError>> + Send>>
|
||||
+ Send,
|
||||
>;
|
||||
|
||||
pub async fn run_mode_streaming(
|
||||
chat_service: Arc<ChatService>,
|
||||
request: AiRequest,
|
||||
room_id: Uuid,
|
||||
project_id: Uuid,
|
||||
model_id: Uuid,
|
||||
lock_guard: crate::room_ai_queue::RoomAiLockGuard,
|
||||
db: AppDatabase,
|
||||
cache: AppCache,
|
||||
queue: MessageProducer,
|
||||
room_manager: Arc<RoomConnectionManager>,
|
||||
mode_name_str: &str,
|
||||
run: RunModeFn,
|
||||
) {
|
||||
let mode_name = mode_name_str.to_string();
|
||||
use queue::RoomMessageStreamChunkEvent;
|
||||
|
||||
let streaming_msg_id = Uuid::now_v7();
|
||||
let seq = match next_room_message_seq_internal(room_id, &db, &cache).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to get seq for {} streaming", mode_name);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let _ = room_manager
|
||||
.register_stream_channel(streaming_msg_id)
|
||||
.await;
|
||||
|
||||
let initial_event = RoomMessageStreamChunkEvent {
|
||||
message_id: streaming_msg_id,
|
||||
room_id,
|
||||
seq: 0,
|
||||
content: String::new(),
|
||||
done: false,
|
||||
error: None,
|
||||
display_name: Some(request.model.name.clone()),
|
||||
chunk_type: Some("thinking".to_string()),
|
||||
};
|
||||
room_manager.broadcast_stream_chunk(initial_event).await;
|
||||
|
||||
let room_id_inner = room_id;
|
||||
let project_id_inner = project_id;
|
||||
let now = Utc::now();
|
||||
let sender_type = "ai".to_string();
|
||||
let ai_display_name = request.model.name.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let _lock_guard = lock_guard;
|
||||
let cancel = room_manager.register_stream_cancel(room_id_inner).await;
|
||||
|
||||
let ai_typing_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001")
|
||||
.expect("constant UUID should always parse");
|
||||
|
||||
let typing_start = queue::TypingEvent {
|
||||
room_id: room_id_inner,
|
||||
user_id: ai_typing_id,
|
||||
username: ai_display_name.clone(),
|
||||
avatar_url: None,
|
||||
action: "start".to_string(),
|
||||
sender_type: Some("ai".to_string()),
|
||||
};
|
||||
room_manager.broadcast_typing(room_id_inner, typing_start.clone()).await;
|
||||
|
||||
let (typing_cancel_tx, typing_cancel_rx) = tokio::sync::oneshot::channel::<()>();
|
||||
let typing_renew_handle = tokio::spawn({
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
|
||||
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
let mgr = room_manager.clone();
|
||||
let rid = room_id_inner;
|
||||
let evt = typing_start.clone();
|
||||
async move {
|
||||
tokio::select! {
|
||||
_ = typing_cancel_rx => {}
|
||||
_ = async {
|
||||
loop {
|
||||
interval.tick().await;
|
||||
mgr.broadcast_typing(rid, evt.clone()).await;
|
||||
}
|
||||
} => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let chunk_seq: Arc<std::sync::atomic::AtomicU64> =
|
||||
Arc::new(std::sync::atomic::AtomicU64::new(1));
|
||||
let all_chunks: Arc<std::sync::Mutex<Vec<(String, String)>>> =
|
||||
Arc::new(std::sync::Mutex::new(Vec::new()));
|
||||
let answer_buffer: Arc<std::sync::Mutex<String>> =
|
||||
Arc::new(std::sync::Mutex::new(String::new()));
|
||||
|
||||
fn lock_or_recover<T>(mutex: &std::sync::Mutex<T>) -> std::sync::MutexGuard<'_, T> {
|
||||
mutex.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
let on_chunk = {
|
||||
let room_manager = room_manager.clone();
|
||||
let queue = queue.clone();
|
||||
let cancel = cancel.clone();
|
||||
let streaming_msg_id = streaming_msg_id;
|
||||
let room_id = room_id_inner;
|
||||
let chunk_seq = chunk_seq.clone();
|
||||
let ai_display_name = ai_display_name.clone();
|
||||
let all_chunks = all_chunks.clone();
|
||||
let answer_buffer = answer_buffer.clone();
|
||||
|
||||
Arc::new(move |chunk_type: String, content: String, is_answer: bool| {
|
||||
let room_manager = room_manager.clone();
|
||||
let queue = queue.clone();
|
||||
let cancel = cancel.clone();
|
||||
let chunk_seq = chunk_seq.clone();
|
||||
let ai_display_name = ai_display_name.clone();
|
||||
let all_chunks = all_chunks.clone();
|
||||
let answer_buffer = answer_buffer.clone();
|
||||
|
||||
{
|
||||
let mut chunks = lock_or_recover(&all_chunks);
|
||||
chunks.push((chunk_type.clone(), content.clone()));
|
||||
}
|
||||
if is_answer {
|
||||
let mut ab = lock_or_recover(&answer_buffer);
|
||||
ab.push_str(&content);
|
||||
}
|
||||
|
||||
let current_seq = chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
let event = RoomMessageStreamChunkEvent {
|
||||
message_id: streaming_msg_id,
|
||||
room_id,
|
||||
seq: current_seq,
|
||||
content: content.clone(),
|
||||
done: false,
|
||||
error: None,
|
||||
display_name: Some(ai_display_name),
|
||||
chunk_type: Some(chunk_type),
|
||||
};
|
||||
|
||||
Box::pin(async move {
|
||||
if cancel.load(std::sync::atomic::Ordering::Acquire) {
|
||||
return;
|
||||
}
|
||||
queue.publish_stream_chunk(&event).await;
|
||||
room_manager.broadcast_stream_chunk(event).await;
|
||||
}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
|
||||
})
|
||||
};
|
||||
|
||||
let result = run(chat_service, request, on_chunk).await;
|
||||
|
||||
let final_stream_content = lock_or_recover(&answer_buffer).clone();
|
||||
let final_event = RoomMessageStreamChunkEvent {
|
||||
message_id: streaming_msg_id,
|
||||
room_id: room_id_inner,
|
||||
seq: chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
|
||||
content: final_stream_content.clone(),
|
||||
done: true,
|
||||
error: None,
|
||||
display_name: Some(ai_display_name.clone()),
|
||||
chunk_type: Some("answer".to_string()),
|
||||
};
|
||||
queue.publish_stream_chunk(&final_event).await;
|
||||
room_manager.broadcast_stream_chunk(final_event).await;
|
||||
|
||||
let (final_content, err_msg) = match result {
|
||||
Ok((content, _input, _output)) => (content, None),
|
||||
Err(e) => {
|
||||
let msg = "[AI 处理发生错误,请稍后再试]".to_string();
|
||||
tracing::error!(error = %e, "{} streaming failed", mode_name);
|
||||
(String::new(), Some(msg))
|
||||
}
|
||||
};
|
||||
|
||||
let all_chunks_data = lock_or_recover(&all_chunks).clone();
|
||||
let reasoning_chain: String = all_chunks_data
|
||||
.iter()
|
||||
.filter(|(t, _)| t != "answer")
|
||||
.map(|(_, c)| c.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
let content_to_persist = if !final_content.is_empty() {
|
||||
final_content.clone()
|
||||
} else if !reasoning_chain.trim().is_empty() {
|
||||
format!(
|
||||
"[Agent ran through reasoning steps but did not produce a final answer.]\n{}",
|
||||
reasoning_chain.trim_end()
|
||||
)
|
||||
} else {
|
||||
String::from("[No output from reasoning agent]")
|
||||
};
|
||||
let content_to_persist = if let Some(msg) = &err_msg {
|
||||
format!("{}\n[Error: {}]", content_to_persist.trim_end(), msg)
|
||||
} else {
|
||||
content_to_persist
|
||||
};
|
||||
|
||||
let persist_content = content_to_persist.trim().to_string();
|
||||
if persist_content.is_empty() {
|
||||
let _ = typing_cancel_tx.send(());
|
||||
typing_renew_handle.abort();
|
||||
room_manager.unregister_stream_cancel(room_id_inner).await;
|
||||
room_manager.close_stream_channel(streaming_msg_id).await;
|
||||
return;
|
||||
}
|
||||
|
||||
let thinking_content_serialized = {
|
||||
let chunks = lock_or_recover(&all_chunks);
|
||||
if chunks.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let chunks_json = serde_json::json!({
|
||||
"__chunks__": chunks.iter().map(|(t, c)| serde_json::json!({
|
||||
"type": t,
|
||||
"content": c,
|
||||
})).collect::<Vec<_>>(),
|
||||
});
|
||||
Some(chunks_json.to_string())
|
||||
}
|
||||
};
|
||||
let thinking_content_for_event = thinking_content_serialized.clone();
|
||||
|
||||
let envelope = RoomMessageEnvelope {
|
||||
id: streaming_msg_id,
|
||||
dedup_key: Some(format!("{}:{}", room_id_inner, streaming_msg_id)),
|
||||
room_id: room_id_inner,
|
||||
sender_type: sender_type.clone(),
|
||||
sender_id: None,
|
||||
model_id: Some(model_id),
|
||||
thread_id: None,
|
||||
content: persist_content.clone(),
|
||||
content_type: "text".to_string(),
|
||||
thinking_content: thinking_content_serialized,
|
||||
send_at: now,
|
||||
seq,
|
||||
in_reply_to: None,
|
||||
display_name: Some(ai_display_name.clone()),
|
||||
};
|
||||
|
||||
if let Err(e) = queue.publish(room_id_inner, envelope).await {
|
||||
tracing::error!(error = %e, "Failed to publish {} streaming message", mode_name);
|
||||
} else {
|
||||
let now = Utc::now();
|
||||
if let Err(e) = room_ai::Entity::update_many()
|
||||
.col_expr(room_ai::Column::CallCount, Expr::col(room_ai::Column::CallCount).add(1))
|
||||
.col_expr(room_ai::Column::LastCallAt, Expr::value(Some(now)))
|
||||
.filter(room_ai::Column::Room.eq(room_id_inner))
|
||||
.filter(room_ai::Column::Model.eq(model_id))
|
||||
.exec(&db)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(error = %e, "Failed to update room_ai call stats");
|
||||
}
|
||||
|
||||
let msg_event = queue::RoomMessageEvent {
|
||||
id: streaming_msg_id,
|
||||
room_id: room_id_inner,
|
||||
sender_type: sender_type.clone(),
|
||||
sender_id: None,
|
||||
thread_id: None,
|
||||
content: persist_content,
|
||||
content_type: "text".to_string(),
|
||||
thinking_content: thinking_content_for_event,
|
||||
send_at: now,
|
||||
seq,
|
||||
display_name: Some(ai_display_name.clone()),
|
||||
in_reply_to: None,
|
||||
reactions: None,
|
||||
message_id: None,
|
||||
};
|
||||
room_manager.broadcast(room_id_inner, msg_event).await;
|
||||
room_manager.metrics.messages_sent.increment(1);
|
||||
|
||||
let event = ProjectRoomEvent {
|
||||
event_type: crate::RoomEventType::NewMessage.as_str().into(),
|
||||
project_id: project_id_inner,
|
||||
room_id: Some(room_id_inner),
|
||||
category_id: None,
|
||||
message_id: Some(streaming_msg_id),
|
||||
seq: Some(seq),
|
||||
timestamp: now,
|
||||
};
|
||||
queue.publish_project_room_event(project_id_inner, event).await;
|
||||
}
|
||||
|
||||
let _ = typing_cancel_tx.send(());
|
||||
typing_renew_handle.abort();
|
||||
let typing_stop = queue::TypingEvent {
|
||||
room_id: room_id_inner,
|
||||
user_id: ai_typing_id,
|
||||
username: ai_display_name.clone(),
|
||||
avatar_url: None,
|
||||
action: "stop".to_string(),
|
||||
sender_type: Some("ai".to_string()),
|
||||
};
|
||||
room_manager.broadcast_typing(room_id_inner, typing_stop).await;
|
||||
|
||||
room_manager.unregister_stream_cancel(room_id_inner).await;
|
||||
room_manager.close_stream_channel(streaming_msg_id).await;
|
||||
});
|
||||
}
|
||||
@ -75,7 +75,7 @@ pub async fn process_message_ai_nonstreaming(
|
||||
room_id,
|
||||
project_id,
|
||||
Uuid::now_v7(),
|
||||
format!("[AI error: {}]", e),
|
||||
"[AI 处理发生错误,请稍后再试]".to_string(),
|
||||
model_id,
|
||||
Some(model_display_name),
|
||||
)
|
||||
|
||||
@ -29,7 +29,7 @@ pub async fn process_message_ai_react_nonstreaming(
|
||||
let model_display_name = request.model.name.clone();
|
||||
|
||||
let final_answer = chat_service
|
||||
.process_react(&request, |_step| {})
|
||||
.process_react(&request, |_step| async move {})
|
||||
.await;
|
||||
|
||||
match final_answer {
|
||||
@ -77,7 +77,7 @@ pub async fn process_message_ai_react_nonstreaming(
|
||||
room_id,
|
||||
project_id,
|
||||
Uuid::now_v7(),
|
||||
format!("[AI error: {}]", e),
|
||||
"[AI 处理发生错误,请稍后再试]".to_string(),
|
||||
model_id,
|
||||
Some(model_display_name),
|
||||
)
|
||||
|
||||
@ -44,9 +44,41 @@ pub async fn process_message_ai_react_streaming(
|
||||
|
||||
tokio::spawn(async move {
|
||||
let _lock_guard = lock_guard;
|
||||
let cancel = room_manager.register_stream_cancel(room_id_inner).await;
|
||||
|
||||
let ai_typing_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001")
|
||||
.expect("constant UUID should always parse");
|
||||
|
||||
let typing_start = queue::TypingEvent {
|
||||
room_id: room_id_inner,
|
||||
user_id: ai_typing_id,
|
||||
username: ai_display_name.clone(),
|
||||
avatar_url: None,
|
||||
action: "start".to_string(),
|
||||
sender_type: Some("ai".to_string()),
|
||||
};
|
||||
room_manager.broadcast_typing(room_id_inner, typing_start.clone()).await;
|
||||
|
||||
let (typing_cancel_tx, mut typing_cancel_rx) = tokio::sync::oneshot::channel::<()>();
|
||||
let typing_renew_handle = tokio::spawn({
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
|
||||
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
let mgr = room_manager.clone();
|
||||
let rid = room_id_inner;
|
||||
let evt = typing_start.clone();
|
||||
async move {
|
||||
tokio::select! {
|
||||
_ = &mut typing_cancel_rx => {}
|
||||
_ = async {
|
||||
loop {
|
||||
interval.tick().await;
|
||||
mgr.broadcast_typing(rid, evt.clone()).await;
|
||||
}
|
||||
} => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Collect ordered steps for storage and streaming.
|
||||
// Using poison-recovering guards to prevent Mutex poisoning from killing the room.
|
||||
let steps: std::sync::Arc<std::sync::Mutex<Vec<(String, String)>>> =
|
||||
std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
|
||||
let last_action_name: std::sync::Arc<std::sync::Mutex<String>> =
|
||||
@ -54,15 +86,17 @@ pub async fn process_message_ai_react_streaming(
|
||||
let answer_buffer: std::sync::Arc<std::sync::Mutex<String>> =
|
||||
std::sync::Arc::new(std::sync::Mutex::new(String::new()));
|
||||
let step_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
|
||||
let chunk_seq: std::sync::Arc<std::sync::atomic::AtomicU64> = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1));
|
||||
let chunk_seq: std::sync::Arc<std::sync::atomic::AtomicU64> =
|
||||
std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1));
|
||||
|
||||
// Helper: recover from poison instead of panicking.
|
||||
fn lock_or_recover<T>(mutex: &std::sync::Mutex<T>) -> std::sync::MutexGuard<'_, T> {
|
||||
mutex.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
let on_step = {
|
||||
let room_manager = room_manager.clone();
|
||||
let queue = queue.clone();
|
||||
let cancel = cancel.clone();
|
||||
let streaming_msg_id = streaming_msg_id;
|
||||
let room_id = room_id_inner;
|
||||
let step_count = step_count.clone();
|
||||
@ -73,6 +107,8 @@ pub async fn process_message_ai_react_streaming(
|
||||
let last_action_name = last_action_name.clone();
|
||||
move |step: ReactStep| {
|
||||
let room_manager = room_manager.clone();
|
||||
let queue = queue.clone();
|
||||
let cancel = cancel.clone();
|
||||
let (chunk_type, content) = match &step {
|
||||
ReactStep::Thought { step: _, thought } => {
|
||||
("thinking".to_string(), thought.clone())
|
||||
@ -100,8 +136,6 @@ pub async fn process_message_ai_react_streaming(
|
||||
step_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
// Record ordered step for storage — merge consecutive same-type chunks
|
||||
// to ensure strict think→answer→think→answer alternation.
|
||||
{
|
||||
let mut s = lock_or_recover(&steps);
|
||||
if let Some(last) = s.last_mut() {
|
||||
@ -122,43 +156,47 @@ pub async fn process_message_ai_react_streaming(
|
||||
let done = false;
|
||||
let ai_name = ai_display_name_for_step.clone();
|
||||
let current_seq = chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
tokio::spawn(async move {
|
||||
let event = RoomMessageStreamChunkEvent {
|
||||
message_id: streaming_msg_id,
|
||||
room_id,
|
||||
seq: current_seq,
|
||||
content: content.clone(),
|
||||
done,
|
||||
error: None,
|
||||
display_name: Some((*ai_name).clone()),
|
||||
chunk_type: Some(chunk_type),
|
||||
};
|
||||
let event = RoomMessageStreamChunkEvent {
|
||||
message_id: streaming_msg_id,
|
||||
room_id,
|
||||
seq: current_seq,
|
||||
content: content.clone(),
|
||||
done,
|
||||
error: None,
|
||||
display_name: Some((*ai_name).clone()),
|
||||
chunk_type: Some(chunk_type),
|
||||
};
|
||||
|
||||
async move {
|
||||
if cancel.load(std::sync::atomic::Ordering::Acquire) {
|
||||
return;
|
||||
}
|
||||
queue.publish_stream_chunk(&event).await;
|
||||
room_manager.broadcast_stream_chunk(event).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let result = chat_service.process_react(&request, on_step).await;
|
||||
|
||||
// Broadcast final done=true event to close the streaming channel on frontend.
|
||||
let final_stream_content = lock_or_recover(&answer_buffer).clone();
|
||||
room_manager
|
||||
.broadcast_stream_chunk(RoomMessageStreamChunkEvent {
|
||||
message_id: streaming_msg_id,
|
||||
room_id: room_id_inner,
|
||||
seq: chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
|
||||
content: final_stream_content.clone(),
|
||||
done: true,
|
||||
error: None,
|
||||
display_name: Some(ai_display_name.clone()),
|
||||
chunk_type: Some("answer".to_string()),
|
||||
})
|
||||
.await;
|
||||
let final_event = RoomMessageStreamChunkEvent {
|
||||
message_id: streaming_msg_id,
|
||||
room_id: room_id_inner,
|
||||
seq: chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
|
||||
content: final_stream_content.clone(),
|
||||
done: true,
|
||||
error: None,
|
||||
display_name: Some(ai_display_name.clone()),
|
||||
chunk_type: Some("answer".to_string()),
|
||||
};
|
||||
queue.publish_stream_chunk(&final_event).await;
|
||||
room_manager.broadcast_stream_chunk(final_event).await;
|
||||
|
||||
let (final_content, _input_tokens, _output_tokens, err_msg, _should_log) = match result {
|
||||
Ok((content, input, output)) => (content, input, output, None, false),
|
||||
Err(e) => {
|
||||
let msg = format!("[Agent error: {}]", e);
|
||||
let msg = "[AI 处理发生错误,请稍后再试]".to_string();
|
||||
tracing::error!(error = %e, "ReAct streaming failed");
|
||||
(String::new(), 0, 0, Some(msg), true)
|
||||
}
|
||||
@ -183,12 +221,7 @@ pub async fn process_message_ai_react_streaming(
|
||||
String::from("[No output from reasoning agent]")
|
||||
};
|
||||
let content_to_persist = if let Some(msg) = &err_msg {
|
||||
format!(
|
||||
"{}\n[Error during reasoning: {}]",
|
||||
content_to_persist.trim_end(),
|
||||
msg.trim_start_matches("[Agent error: ")
|
||||
.trim_end_matches("]")
|
||||
)
|
||||
format!("{}\n[Error during reasoning: {}]", content_to_persist.trim_end(), msg)
|
||||
} else {
|
||||
content_to_persist
|
||||
};
|
||||
@ -198,7 +231,6 @@ pub async fn process_message_ai_react_streaming(
|
||||
return;
|
||||
}
|
||||
|
||||
// Serialize ordered steps as JSON for ordered replay.
|
||||
let thinking_content_serialized = {
|
||||
let steps = lock_or_recover(&steps);
|
||||
if steps.is_empty() {
|
||||
@ -250,7 +282,6 @@ pub async fn process_message_ai_react_streaming(
|
||||
tracing::warn!(error = %e, "Failed to update room_ai call stats");
|
||||
}
|
||||
|
||||
// Billing handled internally by chat_service.process_react via record_ai_session
|
||||
let msg_event = queue::RoomMessageEvent {
|
||||
id: streaming_msg_id,
|
||||
room_id: room_id_inner,
|
||||
@ -284,6 +315,19 @@ pub async fn process_message_ai_react_streaming(
|
||||
.await;
|
||||
}
|
||||
|
||||
let _ = typing_cancel_tx.send(());
|
||||
typing_renew_handle.abort();
|
||||
let typing_stop = queue::TypingEvent {
|
||||
room_id: room_id_inner,
|
||||
user_id: ai_typing_id,
|
||||
username: ai_display_name.clone(),
|
||||
avatar_url: None,
|
||||
action: "stop".to_string(),
|
||||
sender_type: Some("ai".to_string()),
|
||||
};
|
||||
room_manager.broadcast_typing(room_id_inner, typing_stop).await;
|
||||
|
||||
room_manager.unregister_stream_cancel(room_id_inner).await;
|
||||
room_manager.close_stream_channel(streaming_msg_id).await;
|
||||
});
|
||||
}
|
||||
|
||||
292
libs/room/src/service/ai_service.rs
Normal file
292
libs/room/src/service/ai_service.rs
Normal file
@ -0,0 +1,292 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
use db::cache::AppCache;
|
||||
use db::database::AppDatabase;
|
||||
use models::rooms::room_ai;
|
||||
use queue::MessageProducer;
|
||||
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
|
||||
|
||||
use crate::connection::RoomConnectionManager;
|
||||
use crate::error::RoomError;
|
||||
use crate::service::ai_common::create_and_publish_ai_message;
|
||||
use crate::service::ai_mode_dispatch;
|
||||
use crate::service::ai_nonstreaming;
|
||||
use crate::service::ai_react_nonstreaming;
|
||||
use crate::service::ai_react_streaming;
|
||||
use crate::service::ai_streaming;
|
||||
use crate::service::history;
|
||||
use crate::service::patterns::{mention_bracket_re, mention_tag_re};
|
||||
use agent::chat::{AiRequest, ChatService};
|
||||
|
||||
/// Service responsible for AI message generation orchestration.
|
||||
/// Decides which execution path to use (streaming/nonstreaming, ReAct/chat)
|
||||
/// and dispatches accordingly.
|
||||
#[derive(Clone)]
|
||||
pub struct RoomAiService {
|
||||
db: AppDatabase,
|
||||
cache: AppCache,
|
||||
config: config::AppConfig,
|
||||
queue: MessageProducer,
|
||||
room_manager: Arc<RoomConnectionManager>,
|
||||
chat_service: Option<Arc<ChatService>>,
|
||||
}
|
||||
|
||||
impl RoomAiService {
|
||||
pub fn new(
|
||||
db: AppDatabase,
|
||||
cache: AppCache,
|
||||
config: config::AppConfig,
|
||||
queue: MessageProducer,
|
||||
room_manager: Arc<RoomConnectionManager>,
|
||||
chat_service: Option<Arc<ChatService>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
db,
|
||||
cache,
|
||||
config,
|
||||
queue,
|
||||
room_manager,
|
||||
chat_service,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the message content mentions an AI model configured in this room.
|
||||
pub async fn should_respond(&self, room_id: Uuid, content: &str) -> Result<bool, RoomError> {
|
||||
let ai_configs = history::get_room_ai_configs(&self.db, room_id).await?;
|
||||
if ai_configs.is_empty() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let model_ids: std::collections::HashSet<String> = ai_configs
|
||||
.iter()
|
||||
.map(|c| c.model.to_string())
|
||||
.collect();
|
||||
|
||||
for cap in mention_bracket_re().captures_iter(content) {
|
||||
if let (Some(type_m), Some(id_m)) = (cap.get(1), cap.get(2)) {
|
||||
if type_m.as_str() == "ai" && model_ids.contains(id_m.as_str().trim()) {
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for cap in mention_tag_re().captures_iter(content) {
|
||||
if let (Some(type_m), Some(id_m)) = (cap.get(1), cap.get(2)) {
|
||||
if type_m.as_str() == "ai" && model_ids.contains(id_m.as_str().trim()) {
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
/// Extract the mentioned AI model ID from message content.
|
||||
fn extract_mentioned_model_id(content: &str) -> Option<Uuid> {
|
||||
for cap in mention_bracket_re().captures_iter(content) {
|
||||
if let (Some(type_m), Some(id_m)) = (cap.get(1), cap.get(2)) {
|
||||
if type_m.as_str() == "ai" {
|
||||
if let Ok(uuid) = Uuid::parse_str(id_m.as_str().trim()) {
|
||||
return Some(uuid);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for cap in mention_tag_re().captures_iter(content) {
|
||||
if let (Some(type_m), Some(id_m)) = (cap.get(1), cap.get(2)) {
|
||||
if type_m.as_str() == "ai" {
|
||||
if let Ok(uuid) = Uuid::parse_str(id_m.as_str().trim()) {
|
||||
return Some(uuid);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Main entry point: process an AI message request.
|
||||
/// Handles model lookup, history building, locking, and dispatch.
|
||||
pub async fn process(
|
||||
&self,
|
||||
room_id: Uuid,
|
||||
sender_id: Uuid,
|
||||
content: &str,
|
||||
) -> Result<(), RoomError> {
|
||||
let chat_service = match &self.chat_service {
|
||||
Some(cs) => cs.clone(),
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
let model_id = match Self::extract_mentioned_model_id(content) {
|
||||
Some(id) => id,
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
let ai_config = match room_ai::Entity::find()
|
||||
.filter(room_ai::Column::Room.eq(room_id))
|
||||
.filter(room_ai::Column::Model.eq(model_id))
|
||||
.one(&self.db)
|
||||
.await?
|
||||
{
|
||||
Some(c) => c,
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
// Idempotency check: skip if this content already triggered AI within 60s
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
let mut hasher = DefaultHasher::new();
|
||||
content.hash(&mut hasher);
|
||||
let idemp_key = format!("ai:idempot:{}:{}", room_id, hasher.finish());
|
||||
{
|
||||
let mut conn = self.cache.conn().await.map_err(|e| {
|
||||
RoomError::Internal(format!("cache conn: {}", e))
|
||||
})?;
|
||||
let exists = redis::cmd("SET")
|
||||
.arg(&idemp_key)
|
||||
.arg("1")
|
||||
.arg("NX")
|
||||
.arg("EX")
|
||||
.arg(60u64)
|
||||
.query_async::<Option<String>>(&mut conn)
|
||||
.await
|
||||
.map_err(|e| RoomError::Internal(format!("idemp SET: {}", e)))?
|
||||
.is_some();
|
||||
if !exists {
|
||||
tracing::debug!(room_id = %room_id, "AI idempotency hit, skipping");
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
let lock_guard =
|
||||
match crate::room_ai_queue::acquire_room_ai_lock(&self.cache, room_id).await? {
|
||||
Some(g) => g,
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
let room = crate::service::find_room_or_404(&self.db, room_id).await?;
|
||||
|
||||
let project = models::projects::project::Entity::find_by_id(room.project)
|
||||
.one(&self.db)
|
||||
.await?
|
||||
.ok_or_else(|| RoomError::NotFound("Project not found".to_string()))?;
|
||||
|
||||
let model = models::agents::model::Entity::find_by_id(model_id)
|
||||
.one(&self.db)
|
||||
.await?
|
||||
.ok_or_else(|| RoomError::NotFound("AI model not found".to_string()))?;
|
||||
|
||||
let sender = models::users::User::find_by_id(sender_id)
|
||||
.one(&self.db)
|
||||
.await?
|
||||
.ok_or_else(|| RoomError::NotFound("Sender not found".to_string()))?;
|
||||
|
||||
let history_messages = history::get_room_history(&self.db, room_id, 50).await?;
|
||||
|
||||
let user_ids: Vec<Uuid> = history_messages
|
||||
.iter()
|
||||
.filter_map(|m| m.sender_id)
|
||||
.chain(std::iter::once(sender_id))
|
||||
.collect();
|
||||
|
||||
let user_names = history::get_user_names(&self.db, &user_ids).await;
|
||||
|
||||
let mentions =
|
||||
history::extract_mention_context(&self.db, room.project, content).await;
|
||||
|
||||
let request = AiRequest {
|
||||
db: self.db.clone(),
|
||||
cache: self.cache.clone(),
|
||||
config: self.config.clone(),
|
||||
model,
|
||||
project: project.clone(),
|
||||
sender,
|
||||
room: room.clone(),
|
||||
input: content.to_string(),
|
||||
mention: mentions,
|
||||
history: history_messages,
|
||||
user_names,
|
||||
temperature: ai_config.temperature.unwrap_or(0.7),
|
||||
max_tokens: ai_config.max_tokens.unwrap_or(4096) as i32,
|
||||
top_p: 1.0,
|
||||
frequency_penalty: 0.0,
|
||||
presence_penalty: 0.0,
|
||||
think: ai_config.think,
|
||||
tools: Some(chat_service.tools()),
|
||||
max_tool_depth: 1000,
|
||||
};
|
||||
|
||||
let use_streaming = ai_config.stream;
|
||||
|
||||
match ai_config.agent_type.as_deref() {
|
||||
Some("cot") => {
|
||||
if use_streaming {
|
||||
ai_mode_dispatch::dispatch_cot(
|
||||
chat_service, request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
} else {
|
||||
if let Ok((content, _in, _out)) = chat_service.process_cot(&request, |_step| async {}).await {
|
||||
let _ = create_and_publish_ai_message(
|
||||
&self.db, &self.cache, &self.queue, &self.room_manager,
|
||||
room_id, room.project, uuid::Uuid::new_v4(), content, model_id,
|
||||
Some(request.model.name.clone()),
|
||||
).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some("rewoo") => {
|
||||
if use_streaming {
|
||||
ai_mode_dispatch::dispatch_rewoo(
|
||||
chat_service, request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
}
|
||||
}
|
||||
Some("reflexion") => {
|
||||
if use_streaming {
|
||||
ai_mode_dispatch::dispatch_reflexion(
|
||||
chat_service, request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
}
|
||||
}
|
||||
Some("react") | _ => {
|
||||
if ai_config.agent_type.as_deref() == Some("react") {
|
||||
if use_streaming {
|
||||
ai_react_streaming::process_message_ai_react_streaming(
|
||||
chat_service, request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
} else {
|
||||
ai_react_nonstreaming::process_message_ai_react_nonstreaming(
|
||||
chat_service, request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
}
|
||||
} else if use_streaming {
|
||||
ai_streaming::process_message_ai_streaming(
|
||||
chat_service, request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
} else {
|
||||
ai_nonstreaming::process_message_ai_nonstreaming(
|
||||
chat_service, request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -60,6 +60,7 @@ pub async fn process_message_ai_streaming(
|
||||
|
||||
tokio::spawn(async move {
|
||||
let _lock_guard = lock_guard;
|
||||
let cancel = room_manager.register_stream_cancel(room_id_inner).await;
|
||||
let ai_typing_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001")
|
||||
.expect("constant UUID should always parse");
|
||||
let ai_display_name_for_chunk = ai_display_name.clone();
|
||||
@ -67,15 +68,22 @@ pub async fn process_message_ai_streaming(
|
||||
|
||||
let chunk_count = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0));
|
||||
let room_manager_cb = room_manager.clone();
|
||||
let queue_for_chunk = queue.clone();
|
||||
|
||||
let on_chunk = move |chunk: agent::chat::AiStreamChunk| {
|
||||
Box::pin({
|
||||
let room_manager = room_manager_cb.clone();
|
||||
let queue = queue_for_chunk.clone();
|
||||
let streaming_msg_id = streaming_msg_id;
|
||||
let room_id = room_id_inner;
|
||||
let chunk_count = chunk_count.clone();
|
||||
let ai_display_name_for_chunk = ai_display_name_for_chunk.clone();
|
||||
let cancel = cancel.clone();
|
||||
async move {
|
||||
if cancel.load(std::sync::atomic::Ordering::Acquire) {
|
||||
// Stream was cancelled — drop this chunk
|
||||
return;
|
||||
}
|
||||
let chunk_type_str = match chunk.chunk_type {
|
||||
agent::chat::AiChunkType::Thinking => "thinking",
|
||||
agent::chat::AiChunkType::Answer => "answer",
|
||||
@ -93,6 +101,7 @@ pub async fn process_message_ai_streaming(
|
||||
display_name: Some(ai_display_name_for_chunk),
|
||||
chunk_type: Some(chunk_type_str.to_string()),
|
||||
};
|
||||
queue.publish_stream_chunk(&event).await;
|
||||
room_manager.broadcast_stream_chunk(event).await;
|
||||
}
|
||||
}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
|
||||
@ -257,14 +266,16 @@ pub async fn process_message_ai_streaming(
|
||||
seq: 0,
|
||||
content: String::new(),
|
||||
done: true,
|
||||
error: Some(e.to_string()),
|
||||
error: Some("AI 处理发生错误,请稍后再试".to_string()),
|
||||
display_name: Some(ai_display_name.clone()),
|
||||
chunk_type: None,
|
||||
};
|
||||
queue.publish_stream_chunk(&event).await;
|
||||
room_manager.broadcast_stream_chunk(event).await;
|
||||
}
|
||||
}
|
||||
|
||||
room_manager.unregister_stream_cancel(room_id_inner).await;
|
||||
room_manager.close_stream_channel(streaming_msg_id).await;
|
||||
});
|
||||
}
|
||||
|
||||
@ -1,18 +1,23 @@
|
||||
mod access;
|
||||
mod ai_common;
|
||||
mod ai_mode_dispatch;
|
||||
mod ai_mode_streaming;
|
||||
mod ai_nonstreaming;
|
||||
mod ai_react_nonstreaming;
|
||||
mod ai_react_streaming;
|
||||
mod ai_service;
|
||||
mod ai_streaming;
|
||||
mod history;
|
||||
mod mentions;
|
||||
mod notifications;
|
||||
mod patterns;
|
||||
pub use patterns::{mention_bracket_re, mention_tag_re, user_mention_re};
|
||||
mod sequence;
|
||||
mod workers;
|
||||
|
||||
pub use access::{check_room_access, check_project_member, require_room_member, find_room_or_404};
|
||||
pub use ai_common::create_and_publish_ai_message;
|
||||
pub use ai_service::RoomAiService;
|
||||
pub use ai_nonstreaming::process_message_ai_nonstreaming;
|
||||
pub use ai_react_nonstreaming::process_message_ai_react_nonstreaming;
|
||||
pub use ai_react_streaming::process_message_ai_react_streaming;
|
||||
@ -41,8 +46,6 @@ use agent::embed::EmbedService;
|
||||
use agent::TaskService;
|
||||
use models::agent_task::AgentType;
|
||||
|
||||
use crate::service::patterns::{mention_bracket_re, mention_tag_re};
|
||||
|
||||
const DEFAULT_MAX_CONCURRENT_WORKERS: usize = 1024;
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -57,6 +60,7 @@ pub struct RoomService {
|
||||
pub task_service: Option<Arc<TaskService>>,
|
||||
pub embed_service: Option<Arc<EmbedService>>,
|
||||
pub push_fn: Option<workers::PushNotificationFn>,
|
||||
pub ai_service: RoomAiService,
|
||||
worker_semaphore: Arc<tokio::sync::Semaphore>,
|
||||
dedup_cache: DedupCache,
|
||||
}
|
||||
@ -77,6 +81,14 @@ impl RoomService {
|
||||
) -> Self {
|
||||
let dedup_cache: DedupCache =
|
||||
Arc::new(dashmap::DashMap::with_capacity_and_hasher(10000, Default::default()));
|
||||
let ai_service = RoomAiService::new(
|
||||
db.clone(),
|
||||
cache.clone(),
|
||||
config.clone(),
|
||||
queue.clone(),
|
||||
room_manager.clone(),
|
||||
chat_service.clone(),
|
||||
);
|
||||
Self {
|
||||
db,
|
||||
cache,
|
||||
@ -87,6 +99,7 @@ impl RoomService {
|
||||
chat_service,
|
||||
task_service,
|
||||
embed_service,
|
||||
ai_service,
|
||||
worker_semaphore: Arc::new(tokio::sync::Semaphore::new(
|
||||
max_concurrent_workers.unwrap_or(DEFAULT_MAX_CONCURRENT_WORKERS),
|
||||
)),
|
||||
@ -258,34 +271,7 @@ impl RoomService {
|
||||
}
|
||||
|
||||
pub async fn should_ai_respond(&self, room_id: Uuid, content: &str) -> Result<bool, RoomError> {
|
||||
let ai_configs = history::get_room_ai_configs(&self.db, room_id).await?;
|
||||
if ai_configs.is_empty() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
// Collect all model IDs in this room
|
||||
let model_ids: std::collections::HashSet<String> = ai_configs
|
||||
.iter()
|
||||
.map(|c| c.model.to_string())
|
||||
.collect();
|
||||
|
||||
for cap in mention_bracket_re().captures_iter(content) {
|
||||
if let (Some(type_m), Some(id_m)) = (cap.get(1), cap.get(2)) {
|
||||
if type_m.as_str() == "ai" && model_ids.contains(id_m.as_str().trim()) {
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for cap in mention_tag_re().captures_iter(content) {
|
||||
if let (Some(type_m), Some(id_m)) = (cap.get(1), cap.get(2)) {
|
||||
if type_m.as_str() == "ai" && model_ids.contains(id_m.as_str().trim()) {
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(false)
|
||||
self.ai_service.should_respond(room_id, content).await
|
||||
}
|
||||
|
||||
pub async fn get_room_ai_config(
|
||||
@ -421,66 +407,72 @@ impl RoomService {
|
||||
};
|
||||
|
||||
let use_streaming = ai_config.stream;
|
||||
let is_react = ai_config.agent_type.as_deref() == Some("react");
|
||||
|
||||
if is_react {
|
||||
if use_streaming {
|
||||
ai_react_streaming::process_message_ai_react_streaming(
|
||||
chat_service.clone(),
|
||||
request,
|
||||
room_id,
|
||||
room.project,
|
||||
model_id,
|
||||
lock_guard,
|
||||
self.db.clone(),
|
||||
self.cache.clone(),
|
||||
self.queue.clone(),
|
||||
self.room_manager.clone(),
|
||||
)
|
||||
.await;
|
||||
} else {
|
||||
ai_react_nonstreaming::process_message_ai_react_nonstreaming(
|
||||
chat_service.clone(),
|
||||
request,
|
||||
room_id,
|
||||
room.project,
|
||||
model_id,
|
||||
lock_guard,
|
||||
self.db.clone(),
|
||||
self.cache.clone(),
|
||||
self.queue.clone(),
|
||||
self.room_manager.clone(),
|
||||
)
|
||||
.await;
|
||||
match ai_config.agent_type.as_deref() {
|
||||
Some("cot") => {
|
||||
if use_streaming {
|
||||
ai_mode_dispatch::dispatch_cot(
|
||||
chat_service.clone(), request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
} else {
|
||||
if let Ok((content, _in, _out)) = chat_service.process_cot(&request, |_step| async {}).await {
|
||||
let _ = create_and_publish_ai_message(
|
||||
&self.db, &self.cache, &self.queue, &self.room_manager,
|
||||
room_id, room.project, uuid::Uuid::new_v4(), content, model_id,
|
||||
Some(request.model.name.clone()),
|
||||
).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some("rewoo") => {
|
||||
if use_streaming {
|
||||
ai_mode_dispatch::dispatch_rewoo(
|
||||
chat_service.clone(), request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
}
|
||||
}
|
||||
Some("reflexion") => {
|
||||
if use_streaming {
|
||||
ai_mode_dispatch::dispatch_reflexion(
|
||||
chat_service.clone(), request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
}
|
||||
}
|
||||
Some("react") | _ => {
|
||||
if ai_config.agent_type.as_deref() == Some("react") {
|
||||
if use_streaming {
|
||||
ai_react_streaming::process_message_ai_react_streaming(
|
||||
chat_service.clone(), request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
} else {
|
||||
ai_react_nonstreaming::process_message_ai_react_nonstreaming(
|
||||
chat_service.clone(), request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
}
|
||||
} else if use_streaming {
|
||||
ai_streaming::process_message_ai_streaming(
|
||||
chat_service.clone(), request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
} else {
|
||||
ai_nonstreaming::process_message_ai_nonstreaming(
|
||||
chat_service.clone(), request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
}
|
||||
}
|
||||
} else if use_streaming {
|
||||
ai_streaming::process_message_ai_streaming(
|
||||
chat_service.clone(),
|
||||
request,
|
||||
room_id,
|
||||
room.project,
|
||||
model_id,
|
||||
lock_guard,
|
||||
self.db.clone(),
|
||||
self.cache.clone(),
|
||||
self.queue.clone(),
|
||||
self.room_manager.clone(),
|
||||
)
|
||||
.await;
|
||||
} else {
|
||||
ai_nonstreaming::process_message_ai_nonstreaming(
|
||||
chat_service.clone(),
|
||||
request,
|
||||
room_id,
|
||||
room.project,
|
||||
model_id,
|
||||
lock_guard,
|
||||
self.db.clone(),
|
||||
self.cache.clone(),
|
||||
self.queue.clone(),
|
||||
self.room_manager.clone(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@ -6,6 +6,21 @@ use uuid::Uuid;
|
||||
|
||||
use crate::error::RoomError;
|
||||
|
||||
/// Redis Lua script that atomically INCRs the sequence number and
|
||||
/// reconciles with the database max seq every 1000 increments.
|
||||
/// Returns the final assigned seq (guaranteed > any existing message seq).
|
||||
const ATOMIC_INCR_SCRIPT: &str = r#"
|
||||
local seq = redis.call('INCR', KEYS[1])
|
||||
if seq % 1000 == 0 then
|
||||
local db_seq = tonumber(ARGV[1]) or 0
|
||||
if db_seq >= seq then
|
||||
redis.call('SET', KEYS[1], db_seq + 1)
|
||||
return db_seq + 1
|
||||
end
|
||||
end
|
||||
return seq
|
||||
"#;
|
||||
|
||||
pub async fn next_room_message_seq_internal(
|
||||
room_id: Uuid,
|
||||
db: &AppDatabase,
|
||||
@ -16,34 +31,24 @@ pub async fn next_room_message_seq_internal(
|
||||
RoomError::Internal(format!("failed to get redis connection for seq: {}", e))
|
||||
})?;
|
||||
|
||||
let seq: i64 = redis::cmd("INCR")
|
||||
.arg(&seq_key)
|
||||
.query_async(&mut conn)
|
||||
let db_seq: i64 = RoomMessage::find()
|
||||
.filter(RmCol::Room.eq(room_id))
|
||||
.select_only()
|
||||
.column_as(RmCol::Seq.max(), "max_seq")
|
||||
.into_tuple::<Option<Option<i64>>>()
|
||||
.one(db)
|
||||
.await?
|
||||
.flatten()
|
||||
.flatten()
|
||||
.unwrap_or(0);
|
||||
|
||||
let script = redis::Script::new(ATOMIC_INCR_SCRIPT);
|
||||
let seq: i64 = script
|
||||
.key(&seq_key)
|
||||
.arg(db_seq)
|
||||
.invoke_async(&mut conn)
|
||||
.await
|
||||
.map_err(|e| RoomError::Internal(format!("seq INCR: {}", e)))?;
|
||||
|
||||
// DB reconciliation: only check every 1000 messages
|
||||
if seq % 1000 == 0 {
|
||||
let db_seq: Option<Option<Option<i64>>> = RoomMessage::find()
|
||||
.filter(RmCol::Room.eq(room_id))
|
||||
.select_only()
|
||||
.column_as(RmCol::Seq.max(), "max_seq")
|
||||
.into_tuple::<Option<Option<i64>>>()
|
||||
.one(db)
|
||||
.await?
|
||||
.map(|r| r);
|
||||
let db_seq = db_seq.flatten().flatten().unwrap_or(0);
|
||||
|
||||
if db_seq >= seq {
|
||||
let _: String = redis::cmd("SET")
|
||||
.arg(&seq_key)
|
||||
.arg(db_seq + 1)
|
||||
.query_async(&mut conn)
|
||||
.await
|
||||
.map_err(|e| RoomError::Internal(format!("seq SET: {}", e)))?;
|
||||
return Ok(db_seq + 1);
|
||||
}
|
||||
}
|
||||
.map_err(|e| RoomError::Internal(format!("seq atomic INCR: {}", e)))?;
|
||||
|
||||
Ok(seq)
|
||||
}
|
||||
|
||||
@ -5,7 +5,7 @@ use db::cache::AppCache;
|
||||
use db::database::AppDatabase;
|
||||
use models::rooms::room;
|
||||
use queue::{AgentTaskEvent, MessageProducer};
|
||||
use sea_orm::{EntityTrait, QuerySelect};
|
||||
use sea_orm::EntityTrait;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::connection::{
|
||||
@ -28,11 +28,10 @@ pub async fn start_workers(
|
||||
mut shutdown_rx: tokio::sync::broadcast::Receiver<()>,
|
||||
embed_service: Option<Arc<agent::embed::EmbedService>>,
|
||||
) -> anyhow::Result<()> {
|
||||
// Load rooms with a reasonable cap to prevent resource exhaustion on large instances.
|
||||
// Rooms beyond this limit will be activated on-demand when first accessed.
|
||||
const MAX_INITIAL_ROOMS: u64 = 1000;
|
||||
// Load all rooms. For large deployments with thousands of rooms,
|
||||
// consider implementing distributed worker sharding (consistent hashing)
|
||||
// to avoid all rooms being handled by a single instance.
|
||||
let rooms: Vec<room::Model> = room::Entity::find()
|
||||
.limit(MAX_INITIAL_ROOMS)
|
||||
.all(&db)
|
||||
.await?;
|
||||
let room_ids: Vec<uuid::Uuid> = rooms.iter().map(|r| r.id).collect();
|
||||
@ -62,6 +61,7 @@ pub async fn start_workers(
|
||||
extract_get_redis(queue.clone());
|
||||
|
||||
let worker_room_ids = room_ids.clone();
|
||||
let stream_chunk_room_ids = room_ids.clone();
|
||||
let worker_shutdown = shutdown_rx.resubscribe();
|
||||
let worker_handle = tokio::spawn({
|
||||
let get_redis = get_redis.clone();
|
||||
@ -92,6 +92,25 @@ pub async fn start_workers(
|
||||
})
|
||||
.collect();
|
||||
|
||||
let stream_chunk_handles: Vec<_> = stream_chunk_room_ids
|
||||
.into_iter()
|
||||
.map(|room_id| {
|
||||
let manager = manager.clone();
|
||||
let redis_url = redis_url_clone.clone();
|
||||
let shutdown_rx = shutdown_rx.resubscribe();
|
||||
tokio::spawn(async move {
|
||||
crate::connection::subscribe_room_stream_chunk_events(
|
||||
redis_url,
|
||||
manager,
|
||||
room_id,
|
||||
shutdown_rx,
|
||||
)
|
||||
.await;
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
handles.extend(stream_chunk_handles);
|
||||
|
||||
let project_handles: Vec<_> = project_ids
|
||||
.into_iter()
|
||||
.map(|project_id| {
|
||||
@ -289,15 +308,17 @@ pub fn spawn_room_workers(
|
||||
Default::default(),
|
||||
),
|
||||
),
|
||||
embed_service,
|
||||
embed_service.clone(),
|
||||
);
|
||||
let get_redis: Arc<dyn Fn() -> queue::worker::RedisFuture + Send + Sync> =
|
||||
extract_get_redis(queue.clone());
|
||||
let manager1 = room_manager.clone();
|
||||
let manager2 = room_manager.clone();
|
||||
let manager3 = room_manager.clone();
|
||||
let manager4 = room_manager.clone();
|
||||
let redis_url_clone = redis_url.clone();
|
||||
let redis_url3 = redis_url.clone();
|
||||
let redis_url4 = redis_url.clone();
|
||||
let semaphore = worker_semaphore.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
@ -350,4 +371,15 @@ pub fn spawn_room_workers(
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
let shutdown_rx = manager4.register_room(room_id).await;
|
||||
crate::connection::subscribe_room_stream_chunk_events(
|
||||
redis_url4,
|
||||
manager4,
|
||||
room_id,
|
||||
shutdown_rx,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
@ -156,6 +156,10 @@ pub struct RoomResponse {
|
||||
pub last_msg_at: DateTime<Utc>,
|
||||
#[serde(default)]
|
||||
pub unread_count: i64,
|
||||
/// Monotonically increasing version for conflict detection.
|
||||
/// Incremented on every room mutation (rename, move, delete).
|
||||
#[serde(default)]
|
||||
pub version: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, utoipa::ToSchema)]
|
||||
@ -236,8 +240,14 @@ pub struct RoomMessageResponse {
|
||||
pub content: String,
|
||||
pub content_type: String,
|
||||
/// Accumulated AI reasoning/thinking text.
|
||||
/// When `thinking_is_chunked` is true, this is a JSON string with
|
||||
/// `{"__chunks__": [{"type":"thinking|answer|tool_call","content":"..."},...]}`.
|
||||
/// When false, this is plain text.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub thinking_content: Option<String>,
|
||||
/// Indicates `thinking_content` contains JSON chunks (true) or plain text (false).
|
||||
#[serde(skip_serializing_if = "std::ops::Not::not")]
|
||||
pub thinking_is_chunked: bool,
|
||||
pub edited_at: Option<DateTime<Utc>>,
|
||||
pub send_at: DateTime<Utc>,
|
||||
pub revoked: Option<DateTime<Utc>>,
|
||||
@ -256,6 +266,13 @@ pub struct RoomMessageSearchResult {
|
||||
pub message: RoomMessageResponse,
|
||||
}
|
||||
|
||||
impl RoomMessageResponse {
|
||||
/// Detect if `thinking_content` stores JSON chunks (vs plain text).
|
||||
pub fn detect_chunked(thinking: &Option<String>) -> bool {
|
||||
thinking.as_ref().is_some_and(|s| s.contains("\"__chunks__\""))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, utoipa::ToSchema)]
|
||||
pub struct RoomMessageListResponse {
|
||||
pub messages: Vec<RoomMessageResponse>,
|
||||
|
||||
@ -2,9 +2,14 @@
|
||||
|
||||
use crate::AppService;
|
||||
use crate::error::AppError;
|
||||
use sea_orm::*;
|
||||
use uuid::Uuid;
|
||||
|
||||
impl AppService {
|
||||
/// Record AI usage against a project or workspace.
|
||||
///
|
||||
/// `model_id` is an `ai_model.id`. The active/default model version is resolved
|
||||
/// internally so callers do not need to distinguish ModelId from ModelVersionId.
|
||||
pub async fn record_ai_usage(
|
||||
&self,
|
||||
project_uid: Uuid,
|
||||
@ -13,11 +18,22 @@ impl AppService {
|
||||
output_tokens: i64,
|
||||
) -> Result<agent::billing::BillingRecord, AppError> {
|
||||
use agent::billing::BillingResult;
|
||||
use models::agents::model_version;
|
||||
|
||||
let version_id = model_version::Entity::find()
|
||||
.filter(model_version::Column::ModelId.eq(model_id))
|
||||
.filter(model_version::Column::Status.eq("active"))
|
||||
.order_by_desc(model_version::Column::IsDefault)
|
||||
.order_by_desc(model_version::Column::ReleaseDate)
|
||||
.one(&self.db)
|
||||
.await?
|
||||
.map(|v| v.id)
|
||||
.unwrap_or(model_id);
|
||||
|
||||
match agent::billing::record_ai_usage(
|
||||
&self.db,
|
||||
project_uid,
|
||||
model_id,
|
||||
version_id,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
)
|
||||
|
||||
@ -69,14 +69,15 @@ impl AppService {
|
||||
|
||||
let month_used = project_billing_history::Entity::find()
|
||||
.filter(project_billing_history::Column::Project.eq(project.id))
|
||||
.filter(project_billing_history::Column::Reason.eq("ai_usage_monthly"))
|
||||
.filter(project_billing_history::Column::Reason.like("ai_usage%"))
|
||||
.filter(project_billing_history::Column::CreatedAt.gte(month_start))
|
||||
.filter(project_billing_history::Column::CreatedAt.lt(next_month_start))
|
||||
.order_by_desc(project_billing_history::Column::CreatedAt)
|
||||
.one(&self.db)
|
||||
.all(&self.db)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|m| m.amount)
|
||||
.unwrap_or(Decimal::ZERO);
|
||||
.sum::<Decimal>();
|
||||
let month_used = -month_used;
|
||||
|
||||
Ok(ProjectBillingCurrentResponse {
|
||||
project_uid: project.id,
|
||||
@ -155,7 +156,7 @@ impl AppService {
|
||||
.filter(models::projects::project::Column::CreatedBy.eq(uid))
|
||||
.all(&self.db)
|
||||
.await?;
|
||||
if existing_projects.is_empty() {
|
||||
if existing_projects.len() <= 1 {
|
||||
Decimal::from_f64_retain(DEFAULT_PROJECT_MONTHLY_CREDIT).unwrap_or(Decimal::ZERO)
|
||||
} else {
|
||||
Decimal::ZERO
|
||||
|
||||
@ -249,7 +249,8 @@ impl AppService {
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.map(|r| r.amount.to_f64().unwrap_or_default())
|
||||
.sum()
|
||||
.sum();
|
||||
-month_used
|
||||
}
|
||||
|
||||
/// Get email addresses for workspace owners and admins who have email notifications enabled.
|
||||
|
||||
@ -92,6 +92,7 @@ impl AppService {
|
||||
.into_iter()
|
||||
.map(|m| m.amount.to_f64().unwrap_or_default())
|
||||
.sum::<f64>();
|
||||
let month_used = -month_used;
|
||||
|
||||
Ok(WorkspaceBillingCurrentResponse {
|
||||
workspace_id: ws.id,
|
||||
@ -188,25 +189,22 @@ impl AppService {
|
||||
|
||||
let billing = self.ensure_workspace_billing(ws.id, Some(user_uid)).await?;
|
||||
let now_utc = Utc::now();
|
||||
let new_balance =
|
||||
Decimal::from_f64_retain(billing.balance.to_f64().unwrap_or_default() + params.amount)
|
||||
.unwrap_or(Decimal::ZERO);
|
||||
let amount_dec =
|
||||
Decimal::from_f64_retain(params.amount).unwrap_or(Decimal::ZERO);
|
||||
let new_balance = billing.balance + amount_dec;
|
||||
let currency = billing.currency.clone();
|
||||
|
||||
let _ = workspace_billing::ActiveModel {
|
||||
workspace_id: Unchanged(ws.id),
|
||||
balance: Set(new_balance),
|
||||
updated_at: Set(now_utc),
|
||||
..Default::default()
|
||||
}
|
||||
.update(&self.db)
|
||||
.await;
|
||||
let mut updated: workspace_billing::ActiveModel = billing.into();
|
||||
updated.balance = Set(new_balance);
|
||||
updated.updated_at = Set(now_utc);
|
||||
updated.update(&self.db).await?;
|
||||
|
||||
let _ = workspace_billing_history::ActiveModel {
|
||||
uid: Set(Uuid::now_v7()),
|
||||
workspace_id: Set(ws.id),
|
||||
user_id: Set(Some(user_uid)),
|
||||
amount: Set(Decimal::from_f64_retain(params.amount).unwrap_or(Decimal::ZERO)),
|
||||
currency: Set(billing.currency.clone()),
|
||||
currency: Set(currency),
|
||||
reason: Set(params.reason.unwrap_or_else(|| "credit_added".to_string())),
|
||||
extra: Set(None),
|
||||
created_at: Set(now_utc),
|
||||
|
||||
@ -97,7 +97,7 @@ impl AppService {
|
||||
.filter(workspace_membership::Column::Status.eq("active"))
|
||||
.all(&self.db)
|
||||
.await?;
|
||||
let initial_balance = if existing_workspaces.len() <= 1 {
|
||||
let initial_balance = if existing_workspaces.is_empty() {
|
||||
Decimal::from_f64_retain(30.0).unwrap_or(Decimal::ZERO)
|
||||
} else {
|
||||
Decimal::ZERO
|
||||
|
||||
@ -464,7 +464,10 @@ export const RoomSettingsPanel = memo(function RoomSettingsPanel({
|
||||
style={{ background: 'var(--room-bg)', borderColor: 'var(--room-border)', color: 'var(--room-text)' }}
|
||||
>
|
||||
<SelectValue>
|
||||
{agentType === 'react' ? 'ReAct (multi-step reasoning)' : 'Chat (simple)'}
|
||||
{agentType === 'react' ? 'ReAct' :
|
||||
agentType === 'cot' ? 'CoT' :
|
||||
agentType === 'rewoo' ? 'ReWOO' :
|
||||
agentType === 'reflexion' ? 'Reflexion' : 'Chat'}
|
||||
</SelectValue>
|
||||
</SelectTrigger>
|
||||
<SelectContent style={{ background: 'var(--room-bg)', border: '1px solid var(--room-border)' }}>
|
||||
@ -476,6 +479,18 @@ export const RoomSettingsPanel = memo(function RoomSettingsPanel({
|
||||
<span className="font-medium">ReAct</span>
|
||||
<span className="text-xs ml-2" style={{ color: 'var(--room-text-muted)' }}>Multi-step + tools</span>
|
||||
</SelectItem>
|
||||
<SelectItem value="cot">
|
||||
<span className="font-medium">CoT</span>
|
||||
<span className="text-xs ml-2" style={{ color: 'var(--room-text-muted)' }}>Chain-of-Thought</span>
|
||||
</SelectItem>
|
||||
<SelectItem value="rewoo">
|
||||
<span className="font-medium">ReWOO</span>
|
||||
<span className="text-xs ml-2" style={{ color: 'var(--room-text-muted)' }}>Plan → Execute → Synthesize</span>
|
||||
</SelectItem>
|
||||
<SelectItem value="reflexion">
|
||||
<span className="font-medium">Reflexion</span>
|
||||
<span className="text-xs ml-2" style={{ color: 'var(--room-text-muted)' }}>Generate → Critique → Revise</span>
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
@ -704,12 +719,12 @@ export const RoomSettingsPanel = memo(function RoomSettingsPanel({
|
||||
think
|
||||
</span>
|
||||
)}
|
||||
{config.agent_type === 'react' && (
|
||||
{config.agent_type && ['react', 'cot', 'rewoo', 'reflexion'].includes(config.agent_type) && (
|
||||
<span
|
||||
className="rounded px-1 py-0.5 text-[10px] shrink-0"
|
||||
style={{ background: 'rgba(168,85,247,0.15)', color: '#c084fc' }}
|
||||
>
|
||||
react
|
||||
{config.agent_type}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@ -12,6 +12,9 @@ import { useCallback, useEffect, useState } from 'react';
|
||||
import {
|
||||
applyPaletteToDOM,
|
||||
clearCustomPalette,
|
||||
encodeThemeToken,
|
||||
decodeThemeToken,
|
||||
isThemeToken,
|
||||
loadActivePresetId,
|
||||
loadCustomPalette,
|
||||
resetDOMFromPalette,
|
||||
@ -24,7 +27,7 @@ import { Button } from '@/components/ui/button';
|
||||
import { Sheet, SheetContent, SheetHeader, SheetTitle } from '@/components/ui/sheet';
|
||||
import { cn } from '@/lib/utils';
|
||||
import { useTheme } from '@/contexts';
|
||||
import { Check, RotateCcw, Sliders } from 'lucide-react';
|
||||
import { Check, Copy, Download, RotateCcw, Sliders } from 'lucide-react';
|
||||
import {LanguageSwitcher} from '@/components/shared/LanguageSwitcher';
|
||||
|
||||
// ─── Token definitions ───────────────────────────────────────────────────────
|
||||
@ -246,6 +249,10 @@ export function ThemeSwitcher({ open, onOpenChange }: ThemeSwitcherProps) {
|
||||
// Working copy being edited
|
||||
const [draft, setDraft] = useState<PaletteEntry | null>(null);
|
||||
const [isDirty, setIsDirty] = useState(false);
|
||||
// Token import
|
||||
const [tokenInput, setTokenInput] = useState('');
|
||||
const [tokenError, setTokenError] = useState('');
|
||||
const [copied, setCopied] = useState(false);
|
||||
|
||||
// Reset when panel opens
|
||||
useEffect(() => {
|
||||
@ -297,6 +304,53 @@ export function ThemeSwitcher({ open, onOpenChange }: ThemeSwitcherProps) {
|
||||
applyPreset('default');
|
||||
}, [applyPreset]);
|
||||
|
||||
// Get the current theme's token for sharing
|
||||
const currentToken = (() => {
|
||||
if (activePresetId === 'custom' && customPalette) {
|
||||
return encodeThemeToken(customPalette);
|
||||
}
|
||||
const preset = THEME_PRESETS.find((p) => p.id === activePresetId);
|
||||
if (preset?.palette) {
|
||||
return encodeThemeToken(preset.palette);
|
||||
}
|
||||
return null;
|
||||
})();
|
||||
|
||||
const handleCopyToken = useCallback(() => {
|
||||
if (!currentToken) return;
|
||||
navigator.clipboard.writeText(currentToken).then(() => {
|
||||
setCopied(true);
|
||||
setTimeout(() => setCopied(false), 2000);
|
||||
});
|
||||
}, [currentToken]);
|
||||
|
||||
const handleImportToken = useCallback(() => {
|
||||
const trimmed = tokenInput.trim();
|
||||
if (!trimmed) {
|
||||
setTokenError('Please paste a theme token');
|
||||
return;
|
||||
}
|
||||
if (!isThemeToken(trimmed)) {
|
||||
setTokenError('Invalid token format');
|
||||
return;
|
||||
}
|
||||
const palette = decodeThemeToken(trimmed);
|
||||
if (!palette) {
|
||||
setTokenError('Could not decode token');
|
||||
return;
|
||||
}
|
||||
setTokenError('');
|
||||
setTokenInput('');
|
||||
// Apply as custom palette
|
||||
saveCustomPalette(palette);
|
||||
saveActivePresetId('custom');
|
||||
applyPaletteToDOM(palette);
|
||||
setActivePresetId('custom');
|
||||
setCustomPalette(palette);
|
||||
setDraft({ ...palette });
|
||||
setIsDirty(false);
|
||||
}, [tokenInput]);
|
||||
|
||||
return (
|
||||
<Sheet open={open} onOpenChange={onOpenChange}>
|
||||
<SheetContent className="flex flex-col overflow-y-auto w-[360px] sm:max-w-[360px]">
|
||||
@ -371,6 +425,63 @@ export function ThemeSwitcher({ open, onOpenChange }: ThemeSwitcherProps) {
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* ── Share / Import Token ────────────────────────────────────────── */}
|
||||
<div className="border-t pt-5 space-y-4">
|
||||
<p className="text-xs font-semibold uppercase tracking-wide text-muted-foreground">
|
||||
Share & Import
|
||||
</p>
|
||||
|
||||
{/* Current token display + copy */}
|
||||
{currentToken && (
|
||||
<div className="space-y-1.5">
|
||||
<p className="text-[10px] text-muted-foreground">Current theme token</p>
|
||||
<div className="flex gap-1.5">
|
||||
<input
|
||||
type="text"
|
||||
readOnly
|
||||
value={currentToken}
|
||||
className="flex-1 min-w-0 rounded border bg-muted/50 px-2 py-1.5 text-[10px] font-mono text-muted-foreground truncate focus:outline-none"
|
||||
/>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
className="h-8 gap-1 text-xs shrink-0"
|
||||
onClick={handleCopyToken}
|
||||
>
|
||||
{copied ? <Check className="h-3 w-3" /> : <Copy className="h-3 w-3" />}
|
||||
{copied ? 'Copied' : 'Copy'}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Import token */}
|
||||
<div className="space-y-1.5">
|
||||
<p className="text-[10px] text-muted-foreground">Import a theme token</p>
|
||||
<div className="flex gap-1.5">
|
||||
<input
|
||||
type="text"
|
||||
value={tokenInput}
|
||||
onChange={(e) => { setTokenInput(e.target.value); setTokenError(''); }}
|
||||
placeholder="T1:eyJi..."
|
||||
className="flex-1 min-w-0 rounded border bg-background px-2 py-1.5 text-xs font-mono focus:outline-none focus:ring-1 focus:ring-ring"
|
||||
spellCheck={false}
|
||||
/>
|
||||
<Button
|
||||
size="sm"
|
||||
className="h-8 gap-1 text-xs shrink-0"
|
||||
onClick={handleImportToken}
|
||||
>
|
||||
<Download className="h-3 w-3" />
|
||||
Apply
|
||||
</Button>
|
||||
</div>
|
||||
{tokenError && (
|
||||
<p className="text-[10px] text-destructive">{tokenError}</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* ── Custom token editor ─────────────────────────────────────────── */}
|
||||
{activePresetId === 'custom' && draft && (
|
||||
<div className="border-t pt-5">
|
||||
|
||||
@ -45,7 +45,7 @@ export interface PaletteEntry {
|
||||
badgeRole: string; // tailwind classes for role badge
|
||||
}
|
||||
|
||||
export type ThemePresetId = 'default' | 'custom';
|
||||
export type ThemePresetId = 'default' | 'midnight' | 'forest' | 'sunset' | 'rose' | 'lavender' | 'arctic' | 'nord' | 'dracula' | 'light' | 'custom';
|
||||
|
||||
export interface ThemePreset {
|
||||
id: ThemePresetId;
|
||||
@ -64,6 +64,159 @@ export const THEME_PRESETS: ThemePreset[] = [
|
||||
description: 'Linear / Vercel inspired — neutral + single indigo accent',
|
||||
palette: null, // reads live from CSS vars
|
||||
},
|
||||
{
|
||||
id: 'midnight',
|
||||
label: 'Midnight',
|
||||
description: 'Deep blue-black with electric blue accents',
|
||||
palette: {
|
||||
bg: '#0a0e1a', bgSubtle: '#111827', bgHover: '#1e293b', bgActive: '#334155',
|
||||
border: '#1e293b', borderFocus: '#3b82f6', borderMuted: '#1e293b',
|
||||
text: '#e2e8f0', textMuted: '#94a3b8', textSubtle: '#64748b',
|
||||
accent: '#3b82f6', accentHover: '#60a5fa', accentText: '#ffffff',
|
||||
icon: '#94a3b8', iconHover: '#e2e8f0',
|
||||
surface: '#111827', surface2: '#1e293b',
|
||||
online: '#22c55e', away: '#eab308', offline: '#475569',
|
||||
mentionBg: 'rgba(59,130,246,0.15)', mentionText: '#60a5fa',
|
||||
msgBg: '#111827', msgOwnBg: 'rgba(59,130,246,0.12)',
|
||||
panelBg: '#0f172a', badgeAi: 'bg-blue-500/15 text-blue-400 font-medium', badgeRole: 'bg-slate-700 text-slate-300 font-medium',
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'forest',
|
||||
label: 'Forest',
|
||||
description: 'Warm dark with emerald green accents',
|
||||
palette: {
|
||||
bg: '#0c1210', bgSubtle: '#131f1a', bgHover: '#1a2e26', bgActive: '#243d33',
|
||||
border: '#1a2e26', borderFocus: '#10b981', borderMuted: '#1a2e26',
|
||||
text: '#d1fae5', textMuted: '#6ee7b7', textSubtle: '#34d399',
|
||||
accent: '#10b981', accentHover: '#34d399', accentText: '#021a0f',
|
||||
icon: '#6ee7b7', iconHover: '#a7f3d0',
|
||||
surface: '#131f1a', surface2: '#1a2e26',
|
||||
online: '#10b981', away: '#f59e0b', offline: '#4b5563',
|
||||
mentionBg: 'rgba(16,185,129,0.15)', mentionText: '#34d399',
|
||||
msgBg: '#131f1a', msgOwnBg: 'rgba(16,185,129,0.1)',
|
||||
panelBg: '#0a1410', badgeAi: 'bg-emerald-500/15 text-emerald-400 font-medium', badgeRole: 'bg-emerald-900 text-emerald-300 font-medium',
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'sunset',
|
||||
label: 'Sunset',
|
||||
description: 'Warm dark with orange-amber accents',
|
||||
palette: {
|
||||
bg: '#120e0a', bgSubtle: '#1c1612', bgHover: '#2a1f18', bgActive: '#3d2c20',
|
||||
border: '#2a1f18', borderFocus: '#f59e0b', borderMuted: '#2a1f18',
|
||||
text: '#fef3c7', textMuted: '#fbbf24', textSubtle: '#d97706',
|
||||
accent: '#f59e0b', accentHover: '#fbbf24', accentText: '#1c1207',
|
||||
icon: '#fbbf24', iconHover: '#fde68a',
|
||||
surface: '#1c1612', surface2: '#2a1f18',
|
||||
online: '#22c55e', away: '#f59e0b', offline: '#6b7280',
|
||||
mentionBg: 'rgba(245,158,11,0.15)', mentionText: '#fbbf24',
|
||||
msgBg: '#1c1612', msgOwnBg: 'rgba(245,158,11,0.1)',
|
||||
panelBg: '#0f0c09', badgeAi: 'bg-amber-500/15 text-amber-400 font-medium', badgeRole: 'bg-amber-900 text-amber-300 font-medium',
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'rose',
|
||||
label: 'Rose',
|
||||
description: 'Soft dark with rose-pink accents',
|
||||
palette: {
|
||||
bg: '#130b11', bgSubtle: '#1e131a', bgHover: '#2d1c27', bgActive: '#3f2636',
|
||||
border: '#2d1c27', borderFocus: '#f43f5e', borderMuted: '#2d1c27',
|
||||
text: '#fce7f3', textMuted: '#f9a8d4', textSubtle: '#ec4899',
|
||||
accent: '#f43f5e', accentHover: '#fb7185', accentText: '#ffffff',
|
||||
icon: '#f9a8d4', iconHover: '#fbcfe8',
|
||||
surface: '#1e131a', surface2: '#2d1c27',
|
||||
online: '#22c55e', away: '#f59e0b', offline: '#6b7280',
|
||||
mentionBg: 'rgba(244,63,94,0.15)', mentionText: '#fb7185',
|
||||
msgBg: '#1e131a', msgOwnBg: 'rgba(244,63,94,0.1)',
|
||||
panelBg: '#100a0e', badgeAi: 'bg-rose-500/15 text-rose-400 font-medium', badgeRole: 'bg-rose-900 text-rose-300 font-medium',
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'lavender',
|
||||
label: 'Lavender',
|
||||
description: 'Cool dark with purple-violet accents',
|
||||
palette: {
|
||||
bg: '#0e0b14', bgSubtle: '#16121e', bgHover: '#201a2e', bgActive: '#2e2540',
|
||||
border: '#201a2e', borderFocus: '#8b5cf6', borderMuted: '#201a2e',
|
||||
text: '#ede9fe', textMuted: '#c4b5fd', textSubtle: '#a78bfa',
|
||||
accent: '#8b5cf6', accentHover: '#a78bfa', accentText: '#ffffff',
|
||||
icon: '#c4b5fd', iconHover: '#ddd6fe',
|
||||
surface: '#16121e', surface2: '#201a2e',
|
||||
online: '#22c55e', away: '#f59e0b', offline: '#6b7280',
|
||||
mentionBg: 'rgba(139,92,246,0.15)', mentionText: '#a78bfa',
|
||||
msgBg: '#16121e', msgOwnBg: 'rgba(139,92,246,0.1)',
|
||||
panelBg: '#0b0810', badgeAi: 'bg-violet-500/15 text-violet-400 font-medium', badgeRole: 'bg-violet-900 text-violet-300 font-medium',
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'arctic',
|
||||
label: 'Arctic',
|
||||
description: 'Clean light with cyan-teal accents',
|
||||
palette: {
|
||||
bg: '#f0fdfa', bgSubtle: '#e0f7f0', bgHover: '#ccfbf1', bgActive: '#99f6e4',
|
||||
border: '#99f6e4', borderFocus: '#14b8a6', borderMuted: '#ccfbf1',
|
||||
text: '#134e4a', textMuted: '#0f766e', textSubtle: '#5eead4',
|
||||
accent: '#14b8a6', accentHover: '#2dd4bf', accentText: '#ffffff',
|
||||
icon: '#0f766e', iconHover: '#115e59',
|
||||
surface: '#e0f7f0', surface2: '#ccfbf1',
|
||||
online: '#10b981', away: '#f59e0b', offline: '#94a3b8',
|
||||
mentionBg: 'rgba(20,184,166,0.12)', mentionText: '#0f766e',
|
||||
msgBg: '#e0f7f0', msgOwnBg: 'rgba(20,184,166,0.08)',
|
||||
panelBg: '#f0fdfa', badgeAi: 'bg-teal-500/15 text-teal-700 font-medium', badgeRole: 'bg-teal-100 text-teal-800 font-medium',
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'nord',
|
||||
label: 'Nord',
|
||||
description: 'Arctic blue-gray inspired by Nord theme',
|
||||
palette: {
|
||||
bg: '#2e3440', bgSubtle: '#3b4252', bgHover: '#434c5e', bgActive: '#4c566a',
|
||||
border: '#434c5e', borderFocus: '#88c0d0', borderMuted: '#3b4252',
|
||||
text: '#eceff4', textMuted: '#d8dee9', textSubtle: '#81a1c1',
|
||||
accent: '#88c0d0', accentHover: '#8fbcbb', accentText: '#2e3440',
|
||||
icon: '#d8dee9', iconHover: '#eceff4',
|
||||
surface: '#3b4252', surface2: '#434c5e',
|
||||
online: '#a3be8c', away: '#ebcb8b', offline: '#4c566a',
|
||||
mentionBg: 'rgba(136,192,208,0.15)', mentionText: '#88c0d0',
|
||||
msgBg: '#3b4252', msgOwnBg: 'rgba(136,192,208,0.1)',
|
||||
panelBg: '#2e3440', badgeAi: 'bg-sky-500/15 text-sky-300 font-medium', badgeRole: 'bg-slate-700 text-slate-300 font-medium',
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'dracula',
|
||||
label: 'Dracula',
|
||||
description: 'Classic Dracula theme with pink and purple',
|
||||
palette: {
|
||||
bg: '#282a36', bgSubtle: '#343746', bgHover: '#44475a', bgActive: '#565970',
|
||||
border: '#44475a', borderFocus: '#bd93f9', borderMuted: '#343746',
|
||||
text: '#f8f8f2', textMuted: '#6272a4', textSubtle: '#6272a4',
|
||||
accent: '#bd93f9', accentHover: '#caa9fa', accentText: '#282a36',
|
||||
icon: '#f8f8f2', iconHover: '#f1fa8c',
|
||||
surface: '#343746', surface2: '#44475a',
|
||||
online: '#50fa7b', away: '#f1fa8c', offline: '#6272a4',
|
||||
mentionBg: 'rgba(189,147,249,0.15)', mentionText: '#bd93f9',
|
||||
msgBg: '#343746', msgOwnBg: 'rgba(189,147,249,0.1)',
|
||||
panelBg: '#282a36', badgeAi: 'bg-purple-500/15 text-purple-400 font-medium', badgeRole: 'bg-purple-900 text-purple-300 font-medium',
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'light',
|
||||
label: 'Clean Light',
|
||||
description: 'Minimal light theme with blue accents',
|
||||
palette: {
|
||||
bg: '#ffffff', bgSubtle: '#f8fafc', bgHover: '#f1f5f9', bgActive: '#e2e8f0',
|
||||
border: '#e2e8f0', borderFocus: '#3b82f6', borderMuted: '#f1f5f9',
|
||||
text: '#0f172a', textMuted: '#64748b', textSubtle: '#94a3b8',
|
||||
accent: '#3b82f6', accentHover: '#2563eb', accentText: '#ffffff',
|
||||
icon: '#64748b', iconHover: '#0f172a',
|
||||
surface: '#f8fafc', surface2: '#f1f5f9',
|
||||
online: '#22c55e', away: '#f59e0b', offline: '#cbd5e1',
|
||||
mentionBg: 'rgba(59,130,246,0.08)', mentionText: '#2563eb',
|
||||
msgBg: '#f8fafc', msgOwnBg: 'rgba(59,130,246,0.06)',
|
||||
panelBg: '#ffffff', badgeAi: 'bg-blue-500/10 text-blue-600 font-medium', badgeRole: 'bg-slate-100 text-slate-700 font-medium',
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
/** Well-known CSS vars that map to PaletteEntry keys */
|
||||
@ -237,3 +390,81 @@ export function deactivateCustomPalette() {
|
||||
clearCustomPalette();
|
||||
resetDOMFromPalette();
|
||||
}
|
||||
|
||||
// ─── Theme Token ─────────────────────────────────────────────────────────────
|
||||
// Encode a palette into a compact shareable token, and decode it back.
|
||||
//
|
||||
// Token format: "T1:" + base64url(minified JSON)
|
||||
// Version byte (T1) allows future format changes.
|
||||
// Only the 16 core color fields are encoded (badge classes are derived).
|
||||
|
||||
const TOKEN_VERSION = 'T1:';
|
||||
const TOKEN_FIELDS: (keyof PaletteEntry)[] = [
|
||||
'bg', 'bgSubtle', 'bgHover', 'bgActive',
|
||||
'border', 'borderFocus', 'borderMuted',
|
||||
'text', 'textMuted', 'textSubtle',
|
||||
'accent', 'accentHover', 'accentText',
|
||||
'icon', 'iconHover',
|
||||
'surface', 'surface2',
|
||||
'online', 'away', 'offline',
|
||||
'mentionBg', 'mentionText',
|
||||
'msgBg', 'msgOwnBg',
|
||||
'panelBg',
|
||||
];
|
||||
|
||||
function base64urlEncode(str: string): string {
|
||||
return btoa(str).replace(/\+/g, '-').replace(/\//g, '_').replace(/=+$/, '');
|
||||
}
|
||||
|
||||
function base64urlDecode(str: string): string {
|
||||
let s = str.replace(/-/g, '+').replace(/_/g, '/');
|
||||
while (s.length % 4) s += '=';
|
||||
return atob(s);
|
||||
}
|
||||
|
||||
/** Encode a PaletteEntry into a shareable token string. */
|
||||
export function encodeThemeToken(palette: PaletteEntry): string {
|
||||
const compact: Record<string, string> = {};
|
||||
for (const key of TOKEN_FIELDS) {
|
||||
compact[key] = (palette as unknown as Record<string, string>)[key];
|
||||
}
|
||||
const json = JSON.stringify(compact);
|
||||
return TOKEN_VERSION + base64urlEncode(json);
|
||||
}
|
||||
|
||||
/** Decode a token string back into a PaletteEntry. Returns null if invalid. */
|
||||
export function decodeThemeToken(token: string): PaletteEntry | null {
|
||||
try {
|
||||
if (!token.startsWith(TOKEN_VERSION)) return null;
|
||||
const b64 = token.slice(TOKEN_VERSION.length);
|
||||
const json = base64urlDecode(b64);
|
||||
const compact = JSON.parse(json) as Record<string, string>;
|
||||
|
||||
// Validate that all required fields are present
|
||||
for (const key of TOKEN_FIELDS) {
|
||||
if (typeof compact[key] !== 'string') return null;
|
||||
}
|
||||
|
||||
// Build full palette with derived badge classes
|
||||
return {
|
||||
...compact,
|
||||
badgeAi: 'bg-accent/10 text-accent font-medium',
|
||||
badgeRole: 'bg-muted text-muted-foreground font-medium',
|
||||
} as PaletteEntry;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/** Check if a string looks like a valid theme token. */
|
||||
export function isThemeToken(str: string): boolean {
|
||||
return str.startsWith(TOKEN_VERSION) && str.length > TOKEN_VERSION.length;
|
||||
}
|
||||
|
||||
/** Apply a theme from a token string. Returns true if successful. */
|
||||
export function applyThemeFromToken(token: string): boolean {
|
||||
const palette = decodeThemeToken(token);
|
||||
if (!palette) return false;
|
||||
activateCustomPalette(palette);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -166,7 +166,7 @@ export const MessageInput = forwardRef<MessageInputHandle, MessageInputProps>(fu
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
const baseUrl = import.meta.env.VITE_API_BASE_URL ?? window.location.origin;
|
||||
const res = await fetch(`${baseUrl}/rooms/${activeRoomId}/upload`, {method: 'POST', body: formData});
|
||||
const res = await fetch(`${baseUrl}/api/rooms/${activeRoomId}/upload`, {method: 'POST', body: formData, credentials: 'include'});
|
||||
if (!res.ok) throw new Error('Upload failed');
|
||||
return res.json();
|
||||
};
|
||||
|
||||
@ -74,6 +74,8 @@ export type MessageWithMeta = RoomMessageResponse & {
|
||||
chunk_type?: string;
|
||||
/** Accumulated thinking/reasoning content from AI stream (collapsible) */
|
||||
thinking_content?: string;
|
||||
/** True when thinking_content is JSON chunk array, false for plain text */
|
||||
thinking_is_chunked?: boolean;
|
||||
};
|
||||
|
||||
export type RoomWithCategory = RoomResponse & {
|
||||
@ -83,6 +85,7 @@ export type RoomWithCategory = RoomResponse & {
|
||||
export type UiMessage = MessageWithMeta;
|
||||
|
||||
function wsMessageToUiMessage(wsMsg: RoomMessagePayload): MessageWithMeta {
|
||||
const thinkingIsChunked = wsMsg.thinking_content?.includes('__chunks__') ?? false;
|
||||
return {
|
||||
id: wsMsg.id,
|
||||
seq: wsMsg.seq,
|
||||
@ -99,6 +102,7 @@ function wsMessageToUiMessage(wsMsg: RoomMessagePayload): MessageWithMeta {
|
||||
is_streaming: false,
|
||||
reactions: wsMsg.reactions,
|
||||
thinking_content: wsMsg.thinking_content,
|
||||
thinking_is_chunked: thinkingIsChunked,
|
||||
};
|
||||
}
|
||||
|
||||
@ -162,6 +166,8 @@ interface RoomContextValue {
|
||||
streamingChunks: Map<string, Array<{ type: string; content: string }>>;
|
||||
/** Active AI stream info for typing indicator */
|
||||
activeAiStream: { message_id: string; display_name: string } | null;
|
||||
/** Cancel an active AI streaming session */
|
||||
cancelAiStream: () => Promise<boolean>;
|
||||
|
||||
/** Project repositories for @repository: mention suggestions */
|
||||
projectRepos: ProjectRepositoryItem[];
|
||||
@ -416,7 +422,15 @@ export function RoomProvider({
|
||||
if (abortController.signal.aborted) return prev;
|
||||
if (isInitial) {
|
||||
setIsTransitioningRoom(false);
|
||||
return newMessages;
|
||||
// Merge: preserve any WS messages that arrived during loading
|
||||
const existingIds = new Set(newMessages.map((m) => m.id));
|
||||
const pending = prev.filter((m) => !existingIds.has(m.id));
|
||||
let merged = [...newMessages, ...pending];
|
||||
merged.sort((a, b) => a.seq - b.seq);
|
||||
if (merged.length > MAX_MESSAGES_IN_MEMORY) {
|
||||
merged = merged.slice(-MAX_MESSAGES_IN_MEMORY);
|
||||
}
|
||||
return merged;
|
||||
}
|
||||
const existingIds = new Set(prev.map((m) => m.id));
|
||||
const filtered = newMessages.filter((m) => !existingIds.has(m.id));
|
||||
@ -1197,6 +1211,12 @@ export function RoomProvider({
|
||||
[],
|
||||
);
|
||||
|
||||
const cancelAiStream = useCallback(async () => {
|
||||
const client = wsClientRef.current;
|
||||
if (!client) return false;
|
||||
return client.cancelAiStream(activeRoomIdRef.current ?? '');
|
||||
}, []);
|
||||
|
||||
const updateReadSeq = useCallback(
|
||||
async (seq: number) => {
|
||||
const client = wsClientRef.current;
|
||||
@ -1510,6 +1530,7 @@ export function RoomProvider({
|
||||
deleteRoom,
|
||||
streamingChunks,
|
||||
activeAiStream,
|
||||
cancelAiStream,
|
||||
projectRepos,
|
||||
reposLoading,
|
||||
roomAiConfigs,
|
||||
@ -1565,6 +1586,7 @@ export function RoomProvider({
|
||||
deleteRoom,
|
||||
streamingChunks,
|
||||
activeAiStream,
|
||||
cancelAiStream,
|
||||
projectRepos,
|
||||
reposLoading,
|
||||
roomAiConfigs,
|
||||
|
||||
76
src/hooks/use-ai-streaming.ts
Normal file
76
src/hooks/use-ai-streaming.ts
Normal file
@ -0,0 +1,76 @@
|
||||
import { useCallback, useRef, useState } from 'react';
|
||||
import type { RoomWsClient } from '@/lib/room-ws-client';
|
||||
|
||||
export interface AiStreamChunk {
|
||||
type: string;
|
||||
content: string;
|
||||
seq?: number;
|
||||
}
|
||||
|
||||
export interface ActiveAiStream {
|
||||
message_id: string;
|
||||
display_name: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook managing AI streaming state: streaming chunks, active stream indicator,
|
||||
* and stream cancellation. Separated from the main room context to reduce
|
||||
* the God component size (~1583 lines → ~300).
|
||||
*/
|
||||
export function useAiStreaming(clientRef: React.MutableRefObject<RoomWsClient | null>) {
|
||||
const [streamingChunks, setStreamingChunks] = useState<Map<string, AiStreamChunk[]>>(new Map());
|
||||
const [activeAiStream, setActiveAiStream] = useState<ActiveAiStream | null>(null);
|
||||
// Ref to latest chunks so done handler reads current state (setState is async)
|
||||
const chunksRef = useRef<Map<string, AiStreamChunk[]>>(new Map());
|
||||
|
||||
const clearStreamingState = useCallback((msgId: string) => {
|
||||
setStreamingChunks(prev => { prev.delete(msgId); return new Map(prev); });
|
||||
chunksRef.current.delete(msgId);
|
||||
setActiveAiStream(null);
|
||||
}, []);
|
||||
|
||||
const insertChunk = useCallback((
|
||||
messageId: string,
|
||||
chunkType: string | undefined,
|
||||
content: string,
|
||||
seq: number | undefined,
|
||||
) => {
|
||||
setStreamingChunks(prev => {
|
||||
const next = new Map(prev);
|
||||
const existing: AiStreamChunk[] = next.get(messageId) ?? [];
|
||||
const s = seq ?? existing.length;
|
||||
const newChunk: AiStreamChunk = { type: chunkType ?? 'answer', content, seq: s };
|
||||
const insertIdx = existing.findIndex(c => c.seq != null && c.seq > s);
|
||||
next.set(messageId,
|
||||
insertIdx === -1
|
||||
? [...existing, newChunk]
|
||||
: [...existing.slice(0, insertIdx), newChunk, ...existing.slice(insertIdx)]
|
||||
);
|
||||
chunksRef.current = new Map(next);
|
||||
return next;
|
||||
});
|
||||
}, []);
|
||||
|
||||
const getOrderedChunks = useCallback((msgId: string): AiStreamChunk[] => {
|
||||
return chunksRef.current.get(msgId) ?? [];
|
||||
}, []);
|
||||
|
||||
const cancelAiStream = useCallback(async () => {
|
||||
const client = clientRef.current;
|
||||
if (!client) return false;
|
||||
const roomId = client.getSubscribedRooms().values().next().value;
|
||||
if (!roomId) return false;
|
||||
return client.cancelAiStream(roomId as string);
|
||||
}, [clientRef]);
|
||||
|
||||
return {
|
||||
streamingChunks,
|
||||
activeAiStream,
|
||||
setActiveAiStream,
|
||||
clearStreamingState,
|
||||
insertChunk,
|
||||
getOrderedChunks,
|
||||
cancelAiStream,
|
||||
chunksRef,
|
||||
};
|
||||
}
|
||||
118
src/hooks/use-room-messages.ts
Normal file
118
src/hooks/use-room-messages.ts
Normal file
@ -0,0 +1,118 @@
|
||||
import { useCallback, useState } from 'react';
|
||||
import type { RoomWsClient } from '@/lib/room-ws-client';
|
||||
import type { MessageWithMeta } from '@/contexts/room-context';
|
||||
|
||||
const MAX_MESSAGES_IN_MEMORY = 1000;
|
||||
|
||||
/**
|
||||
* Hook managing room messages state: list, send, edit, revoke.
|
||||
* Separated to reduce the main room context (~1583 lines).
|
||||
*/
|
||||
export function useRoomMessages(clientRef: React.MutableRefObject<RoomWsClient | null>) {
|
||||
const [messages, setMessages] = useState<MessageWithMeta[]>([]);
|
||||
const [messagesLoading, setMessagesLoading] = useState(false);
|
||||
const [isHistoryLoaded, setIsHistoryLoaded] = useState(false);
|
||||
const [isLoadingMore, setIsLoadingMore] = useState(false);
|
||||
const [isTransitioningRoom, setIsTransitioningRoom] = useState(false);
|
||||
const [nextCursor, setNextCursor] = useState<number | null>(null);
|
||||
|
||||
const appendMessage = useCallback((msg: MessageWithMeta) => {
|
||||
setMessages(prev => {
|
||||
const exists = prev.some(m => m.id === msg.id);
|
||||
if (exists) return prev.map(m => m.id === msg.id ? { ...m, ...msg } : m);
|
||||
const next = [...prev, msg];
|
||||
return next.length > MAX_MESSAGES_IN_MEMORY ? next.slice(-MAX_MESSAGES_IN_MEMORY) : next;
|
||||
});
|
||||
}, []);
|
||||
|
||||
const updateMessage = useCallback((msgId: string, updater: (m: MessageWithMeta) => MessageWithMeta) => {
|
||||
setMessages(prev => prev.map(m => m.id === msgId ? updater(m) : m));
|
||||
}, []);
|
||||
|
||||
const removeMessage = useCallback((msgId: string) => {
|
||||
setMessages(prev => prev.filter(m => m.id !== msgId));
|
||||
}, []);
|
||||
|
||||
const clearMessages = useCallback(() => {
|
||||
setMessages([]);
|
||||
setIsHistoryLoaded(false);
|
||||
setNextCursor(null);
|
||||
}, []);
|
||||
|
||||
const sendMessage = useCallback(async (
|
||||
content: string,
|
||||
contentType?: string,
|
||||
inReplyTo?: string,
|
||||
attachmentIds?: string[],
|
||||
) => {
|
||||
const client = clientRef.current;
|
||||
if (!client) return;
|
||||
const roomId = client.getSubscribedRooms().values().next().value;
|
||||
if (!roomId) return;
|
||||
|
||||
const optimistic: MessageWithMeta = {
|
||||
id: crypto.randomUUID(),
|
||||
seq: 0,
|
||||
room: roomId as string,
|
||||
sender_type: 'member',
|
||||
content,
|
||||
content_type: contentType || 'text',
|
||||
send_at: new Date().toISOString(),
|
||||
display_content: content,
|
||||
isOptimistic: true,
|
||||
attachment_ids: attachmentIds,
|
||||
};
|
||||
appendMessage(optimistic);
|
||||
|
||||
try {
|
||||
const result = await client.messageCreate(roomId as string, content, {
|
||||
contentType,
|
||||
inReplyTo,
|
||||
attachmentIds,
|
||||
});
|
||||
setMessages(prev => prev.map(m => m.id === optimistic.id ? {
|
||||
...result,
|
||||
id: result.id,
|
||||
seq: result.seq,
|
||||
isOptimistic: false,
|
||||
} : m));
|
||||
} catch (err) {
|
||||
setMessages(prev => prev.map(m => m.id === optimistic.id ? {
|
||||
...m, isOptimistic: false, isOptimisticError: true,
|
||||
} : m));
|
||||
throw err;
|
||||
}
|
||||
}, [clientRef, appendMessage]);
|
||||
|
||||
const editMessage = useCallback(async (messageId: string, content: string) => {
|
||||
const client = clientRef.current;
|
||||
if (!client) return;
|
||||
await client.messageUpdate(messageId, content);
|
||||
setMessages(prev => prev.map(m => m.id === messageId ? { ...m, content, display_content: content } : m));
|
||||
}, [clientRef]);
|
||||
|
||||
const revokeMessage = useCallback(async (messageId: string) => {
|
||||
const client = clientRef.current;
|
||||
if (!client) return;
|
||||
let rollback: MessageWithMeta | null = null;
|
||||
setMessages(prev => {
|
||||
rollback = prev.find(m => m.id === messageId) ?? null;
|
||||
return prev.filter(m => m.id !== messageId);
|
||||
});
|
||||
try {
|
||||
await client.messageRevoke(messageId);
|
||||
} catch {
|
||||
if (rollback) setMessages(prev => [...prev, rollback!]);
|
||||
}
|
||||
}, [clientRef]);
|
||||
|
||||
return {
|
||||
messages, setMessages, appendMessage, updateMessage, removeMessage, clearMessages,
|
||||
messagesLoading, setMessagesLoading,
|
||||
isHistoryLoaded, setIsHistoryLoaded,
|
||||
isLoadingMore, setIsLoadingMore,
|
||||
isTransitioningRoom, setIsTransitioningRoom,
|
||||
nextCursor, setNextCursor,
|
||||
sendMessage, editMessage, revokeMessage,
|
||||
};
|
||||
}
|
||||
@ -453,6 +453,7 @@ export class RoomWsClient {
|
||||
case 'ai.list': return { path: '/rooms/{room_id}/ai', method: 'GET', pathParams: ['room_id'] };
|
||||
case 'ai.upsert': return { path: '/rooms/{room_id}/ai', method: 'PUT', pathParams: ['room_id'] };
|
||||
case 'ai.delete': return { path: '/rooms/{room_id}/ai/{model_id}', method: 'DELETE', pathParams: ['room_id', 'model_id'] };
|
||||
case 'ai.stop': return { path: '/rooms/{room_id}/ai/stop', method: 'POST', pathParams: ['room_id'] };
|
||||
case 'notification.list': return { path: '/me/notifications', method: 'GET', pathParams: [] };
|
||||
case 'notification.mark_read': return { path: '/me/notifications/{notification_id}/read', method: 'POST', pathParams: ['notification_id'] };
|
||||
case 'notification.mark_all_read': return { path: '/me/notifications/read-all', method: 'POST', pathParams: [] };
|
||||
@ -1203,6 +1204,20 @@ export class RoomWsClient {
|
||||
}
|
||||
}
|
||||
|
||||
/** Cancel an active AI streaming session for a room. */
|
||||
async cancelAiStream(roomId: string): Promise<boolean> {
|
||||
if (this.status !== 'open' || !this.ws) {
|
||||
return false;
|
||||
}
|
||||
try {
|
||||
const data = await this.requestWs<boolean>('ai.stop', { room_id: roomId });
|
||||
return data === true;
|
||||
} catch (err) {
|
||||
console.warn('[RoomWs] cancelAiStream failed:', roomId, err);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private scheduleReconnect(): void {
|
||||
if (!this.shouldReconnect) return;
|
||||
|
||||
|
||||
@ -40,6 +40,7 @@ export type WsAction =
|
||||
| 'ai.list'
|
||||
| 'ai.upsert'
|
||||
| 'ai.delete'
|
||||
| 'ai.stop'
|
||||
| 'notification.list'
|
||||
| 'notification.mark_read'
|
||||
| 'notification.mark_all_read'
|
||||
|
||||
Loading…
Reference in New Issue
Block a user