1020 lines
42 KiB
Rust
1020 lines
42 KiB
Rust
use agent::chat::chat_execution;
|
|
use agent::chat::{AiChunkType, AiStreamChunk, normalize_thinking_content};
|
|
use agent::client::AiClientConfig;
|
|
use agent::client::types::ChatRequestMessage;
|
|
use agent::client::{StreamChunk, StreamChunkType};
|
|
use agent::react::PERSONAL_CONTEXT_PROMPT;
|
|
use futures::StreamExt;
|
|
use models::agents::{model, model_version};
|
|
use models::ai::{AiMessage, ai_conversation, ai_message};
|
|
use queue::{ChatMessageEvent, ChatStreamChunkEvent};
|
|
use sea_orm::{
|
|
ActiveModelTrait, ColumnTrait, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder, Set,
|
|
};
|
|
use service::AppService;
|
|
use std::pin::Pin;
|
|
use std::sync::Arc;
|
|
use std::sync::atomic::{AtomicU64, Ordering};
|
|
use tokio_stream::wrappers::ReceiverStream;
|
|
use uuid::Uuid;
|
|
|
|
/// Create an SSE stream that executes AI chat with ReAct tool-calling.
|
|
///
|
|
/// Also publishes chat messages and stream chunks via NATS JetStream for
|
|
/// multi-viewer support. The requesting client receives SSE events, while
|
|
/// other viewers receive chunks via NATS -> WebSocket broadcast.
|
|
pub fn create_chat_sse_stream(
|
|
service: AppService,
|
|
conversation_id: Uuid,
|
|
user_message_id: Uuid,
|
|
model_name: String,
|
|
user_id: Uuid,
|
|
) -> Pin<Box<dyn futures::Stream<Item = Result<actix_web::web::Bytes, actix_web::Error>> + Send>> {
|
|
let (tx, rx) = tokio::sync::mpsc::channel::<String>(100);
|
|
|
|
let cache = service.cache.clone();
|
|
|
|
tokio::spawn(async move {
|
|
// Check for active stream (SSE reconnect recovery) BEFORE starting a new one
|
|
// so the frontend can recover from a page refresh.
|
|
if let Some((msg_id, started_at)) = cache.get_chat_stream_active(conversation_id).await {
|
|
let _ = tx.send(format!(
|
|
"data: {{\"event\":\"recovery\",\"data\":{{\"message_id\":\"{}\",\"started_at\":{}}}}}\n\n",
|
|
msg_id,
|
|
started_at
|
|
)).await;
|
|
let _ = tx
|
|
.send("data: {\"event\":\"done\",\"data\":\"recovery\"}\n\n".to_string())
|
|
.await;
|
|
return;
|
|
}
|
|
|
|
let queue = service.queue_producer.clone();
|
|
let chunk_seq = Arc::new(AtomicU64::new(0));
|
|
|
|
// Build messages from conversation history
|
|
let messages = match build_messages_from_history(&service, conversation_id).await {
|
|
Ok(msgs) => msgs,
|
|
Err(e) => {
|
|
let payload = serde_json::json!({"event":"error","data": e.to_string()});
|
|
let _ = tx.send(format!("data: {}\n\n", payload)).await;
|
|
return;
|
|
}
|
|
};
|
|
|
|
// Get AI config
|
|
let api_key = match service.config.ai_api_key() {
|
|
Ok(k) => k,
|
|
Err(_) => {
|
|
let _ = tx
|
|
.send(
|
|
"data: {\"event\":\"error\",\"data\":\"AI not configured\"}\n\n"
|
|
.to_string(),
|
|
)
|
|
.await;
|
|
return;
|
|
}
|
|
};
|
|
let base_url = match service.config.ai_basic_url() {
|
|
Ok(u) => u,
|
|
Err(_) => {
|
|
let _ = tx
|
|
.send(
|
|
"data: {\"event\":\"error\",\"data\":\"AI not configured\"}\n\n"
|
|
.to_string(),
|
|
)
|
|
.await;
|
|
return;
|
|
}
|
|
};
|
|
|
|
let config = AiClientConfig::new(api_key).with_base_url(&base_url);
|
|
|
|
// Get tools from ChatService if available
|
|
let (tools, tool_registry, embed_service) = match &service.chat_service {
|
|
Some(cs) => (
|
|
cs.tools(),
|
|
cs.tool_registry().cloned(),
|
|
service.embed_service.as_ref().map(|es| (**es).clone()),
|
|
),
|
|
None => (Vec::new(), None, None),
|
|
};
|
|
|
|
// Get project_id and scope from conversation
|
|
let (project_id, conv_project_id, is_personal) =
|
|
match service.find_conversation(conversation_id).await {
|
|
Ok(c) => {
|
|
let conv_project_id = c.project_id;
|
|
(
|
|
conv_project_id.unwrap_or(Uuid::nil()),
|
|
conv_project_id,
|
|
conv_project_id.is_none(),
|
|
)
|
|
}
|
|
Err(_) => {
|
|
let _ = tx
|
|
.send(
|
|
"data: {\"event\":\"error\",\"data\":\"conversation not found\"}\n\n"
|
|
.to_string(),
|
|
)
|
|
.await;
|
|
return;
|
|
}
|
|
};
|
|
|
|
// In personal scope: filter out project/git/repo tools and inject personal context prompt
|
|
let tools = if is_personal {
|
|
tools
|
|
.into_iter()
|
|
.filter(|t| {
|
|
let name = t
|
|
.get("function")
|
|
.and_then(|f| f.get("name"))
|
|
.and_then(|n| n.as_str())
|
|
.unwrap_or("");
|
|
!name.starts_with("project_")
|
|
&& !name.starts_with("git_")
|
|
&& !name.starts_with("repo_")
|
|
&& name != "send_message"
|
|
&& name != "retract_message"
|
|
})
|
|
.collect()
|
|
} else {
|
|
tools
|
|
};
|
|
|
|
// Inject personal context system prompt for non-project chats
|
|
let messages = if is_personal {
|
|
let mut msgs = messages;
|
|
msgs.insert(
|
|
0,
|
|
ChatRequestMessage::system(PERSONAL_CONTEXT_PROMPT.to_string()),
|
|
);
|
|
msgs
|
|
} else {
|
|
messages
|
|
};
|
|
|
|
let (model_record, billing_version_id) = match model::Entity::find()
|
|
.filter(model::Column::Name.eq(&model_name))
|
|
.one(service.db.reader())
|
|
.await
|
|
{
|
|
Ok(Some(m)) => {
|
|
let version_id = model_version::Entity::find()
|
|
.filter(model_version::Column::ModelId.eq(m.id))
|
|
.filter(model_version::Column::Status.eq("active"))
|
|
.order_by_desc(model_version::Column::IsDefault)
|
|
.order_by_desc(model_version::Column::ReleaseDate)
|
|
.one(service.db.reader())
|
|
.await
|
|
.ok()
|
|
.flatten()
|
|
.map(|v| v.id);
|
|
|
|
match version_id {
|
|
Some(version_id) => (m, version_id),
|
|
None => {
|
|
let error_msg = "AI model version is not configured. Please configure an active model version before using AI.";
|
|
let payload = serde_json::json!({"event":"billing_error","data":error_msg});
|
|
let _ = tx.send(format!("data: {}\n\n", payload)).await;
|
|
let _ = tx
|
|
.send(
|
|
"data: {\"event\":\"done\",\"data\":\"billing_error\"}\n\n"
|
|
.to_string(),
|
|
)
|
|
.await;
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
_ => {
|
|
let error_msg = "AI model is not configured. Please sync or configure the model before using AI.";
|
|
let payload = serde_json::json!({"event":"billing_error","data":error_msg});
|
|
let _ = tx.send(format!("data: {}\n\n", payload)).await;
|
|
let _ = tx
|
|
.send("data: {\"event\":\"done\",\"data\":\"billing_error\"}\n\n".to_string())
|
|
.await;
|
|
return;
|
|
}
|
|
};
|
|
|
|
// Pre-flight balance check: verify the selected account can afford a minimal AI call.
|
|
let balance_ok = if is_personal {
|
|
agent::billing::check_user_balance(&service.db, user_id, billing_version_id, 500, 250)
|
|
.await
|
|
} else {
|
|
agent::billing::check_balance(
|
|
&service.db,
|
|
project_id,
|
|
user_id,
|
|
billing_version_id,
|
|
500,
|
|
250,
|
|
)
|
|
.await
|
|
};
|
|
|
|
match balance_ok {
|
|
Ok(true) => {}
|
|
Ok(false) => {
|
|
tracing::warn!(project_id = %project_id, user_id = %user_id, personal = is_personal, "Insufficient balance for chat AI call");
|
|
|
|
let (scope, scope_id) = if is_personal {
|
|
("user", user_id)
|
|
} else {
|
|
("project", project_id)
|
|
};
|
|
let _ = agent::billing::persist_billing_error(
|
|
&service.db, scope, scope_id, "insufficient_balance",
|
|
"Insufficient balance. Your account does not have enough funds for this AI request.",
|
|
Some(serde_json::json!({
|
|
"user_id": user_id.to_string(),
|
|
"project_id": if is_personal { None } else { Some(project_id.to_string()) },
|
|
"model_version_id": billing_version_id.to_string(),
|
|
})),
|
|
).await;
|
|
|
|
let error_msg = "Insufficient balance. Your account does not have enough funds to process this AI request. Please add credits to continue.";
|
|
let payload = serde_json::json!({"event":"billing_error","data":error_msg});
|
|
let _ = tx.send(format!("data: {}\n\n", payload)).await;
|
|
let _ = tx
|
|
.send("data: {\"event\":\"done\",\"data\":\"billing_error\"}\n\n".to_string())
|
|
.await;
|
|
return;
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "Balance check failed");
|
|
let error_msg = format!("Billing check failed: {}", e);
|
|
let payload = serde_json::json!({"event":"billing_error","data":error_msg});
|
|
let _ = tx.send(format!("data: {}\n\n", payload)).await;
|
|
let _ = tx
|
|
.send("data: {\"event\":\"done\",\"data\":\"billing_error\"}\n\n".to_string())
|
|
.await;
|
|
return;
|
|
}
|
|
}
|
|
|
|
let max_tool_depth = 99;
|
|
let assistant_msg_id = Uuid::now_v7();
|
|
|
|
// Determine conversation project_id for chat message event
|
|
// Broadcast chat message start event via NATS
|
|
let chat_msg = ChatMessageEvent {
|
|
message_id: assistant_msg_id,
|
|
conversation_id,
|
|
project_id: conv_project_id,
|
|
sender_id: Uuid::nil(),
|
|
role: "assistant".to_string(),
|
|
content: String::new(),
|
|
model: Some(model_name.clone()),
|
|
input_tokens: None,
|
|
output_tokens: None,
|
|
timestamp: chrono::Utc::now(),
|
|
};
|
|
let _ = queue.publish_chat_message(&chat_msg).await;
|
|
|
|
// Mark stream as active in Redis so page refresh can recover
|
|
let _ = cache
|
|
.set_chat_stream_active(conversation_id, user_message_id)
|
|
.await;
|
|
|
|
// Clear any stale cancel flag before starting
|
|
let _ = cache.clear_chat_stream_cancelled(conversation_id).await;
|
|
|
|
// Cancellation token checked in on_chunk and by a periodic poller.
|
|
let cancelled = Arc::new(std::sync::atomic::AtomicBool::new(false));
|
|
let cancelled_for_on_chunk = cancelled.clone();
|
|
let recorded_chunks = Arc::new(tokio::sync::Mutex::new(Vec::<StreamChunk>::new()));
|
|
|
|
let on_chunk_tx = tx.clone();
|
|
let on_chunk_queue = queue.clone();
|
|
let on_chunk_seq = chunk_seq.clone();
|
|
let on_chunk_conv_id = conversation_id;
|
|
let on_chunk_msg_id = user_message_id;
|
|
let on_chunk_model = model_name.clone();
|
|
let on_chunk_recorded = recorded_chunks.clone();
|
|
|
|
let on_chunk: agent::chat::StreamCallback = Box::new(move |chunk: AiStreamChunk| {
|
|
let tx = on_chunk_tx.clone();
|
|
let queue = on_chunk_queue.clone();
|
|
let seq = on_chunk_seq.fetch_add(1, Ordering::Relaxed);
|
|
let conv_id = on_chunk_conv_id;
|
|
let msg_id = on_chunk_msg_id;
|
|
let model = on_chunk_model.clone();
|
|
let cancelled = cancelled_for_on_chunk.clone();
|
|
let recorded = on_chunk_recorded.clone();
|
|
Box::pin(async move {
|
|
// Check if stream has been cancelled
|
|
if cancelled.load(Ordering::Acquire) {
|
|
return;
|
|
}
|
|
|
|
let chunk_type = chunk.chunk_type.clone();
|
|
let event = match &chunk_type {
|
|
AiChunkType::Thinking => "thinking",
|
|
AiChunkType::Answer => "token",
|
|
AiChunkType::ToolCall => "tool_call",
|
|
AiChunkType::ToolResult => "tool_result",
|
|
};
|
|
let content = match &chunk_type {
|
|
AiChunkType::Thinking => normalize_thinking_content(&chunk.content),
|
|
_ => chunk.content.clone(),
|
|
};
|
|
|
|
// Build structured data payload based on chunk type
|
|
let data_json = match &chunk_type {
|
|
AiChunkType::ToolCall | AiChunkType::ToolResult => {
|
|
// Use structured metadata if available
|
|
if let Some(meta) = chunk.metadata.clone() {
|
|
meta
|
|
} else {
|
|
// Fallback: wrap raw content as display text
|
|
serde_json::json!({"display": content})
|
|
}
|
|
}
|
|
_ => {
|
|
// thinking / answer: send plain text content
|
|
serde_json::Value::String(content.clone())
|
|
}
|
|
};
|
|
let persisted_content = match &chunk_type {
|
|
AiChunkType::ToolCall | AiChunkType::ToolResult => data_json.to_string(),
|
|
_ => content.clone(),
|
|
};
|
|
let persisted_type = match &chunk_type {
|
|
AiChunkType::Thinking => StreamChunkType::Thinking,
|
|
AiChunkType::Answer => StreamChunkType::Answer,
|
|
AiChunkType::ToolCall => StreamChunkType::ToolCall,
|
|
AiChunkType::ToolResult => StreamChunkType::ToolResult,
|
|
};
|
|
recorded.lock().await.push(StreamChunk {
|
|
chunk_type: persisted_type,
|
|
content: persisted_content,
|
|
});
|
|
|
|
let mut sse_json = serde_json::json!({
|
|
"event": event,
|
|
"data": data_json,
|
|
});
|
|
if let Some(children_id) = chunk.children_id {
|
|
sse_json.as_object_mut().unwrap().insert(
|
|
"children_id".to_string(),
|
|
serde_json::Value::String(children_id),
|
|
);
|
|
}
|
|
|
|
let sse = format!(
|
|
"data: {}\n\n",
|
|
serde_json::to_string(&sse_json).unwrap_or_default()
|
|
);
|
|
let _ = tx.send(sse).await;
|
|
|
|
// Also broadcast via NATS for other viewers
|
|
let natts_chunk = ChatStreamChunkEvent {
|
|
conversation_id: conv_id,
|
|
message_id: msg_id,
|
|
seq,
|
|
content: chunk.content,
|
|
done: false,
|
|
error: None,
|
|
chunk_type: Some(event.to_string()),
|
|
model_name: Some(model),
|
|
};
|
|
queue.publish_chat_chunk(&natts_chunk).await;
|
|
}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
|
|
});
|
|
|
|
let cancel_wait = {
|
|
let cache_for_check = cache.clone();
|
|
let conv_id_for_check = conversation_id;
|
|
async move {
|
|
let mut interval = tokio::time::interval(std::time::Duration::from_millis(250));
|
|
loop {
|
|
interval.tick().await;
|
|
if cache_for_check
|
|
.is_chat_stream_cancelled(conv_id_for_check)
|
|
.await
|
|
{
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
// Resolve max_tokens from model config (unlimited if not set)
|
|
let max_tokens = model_record
|
|
.max_output_tokens
|
|
.map(|v| v as u32)
|
|
.unwrap_or(u32::MAX);
|
|
|
|
let execution = chat_execution::execute_chat_stream(
|
|
messages,
|
|
tools,
|
|
&model_name,
|
|
&config,
|
|
0.7, // temperature
|
|
max_tokens, // max_tokens from model config
|
|
max_tool_depth,
|
|
tool_registry.as_ref(),
|
|
service.db.clone(),
|
|
service.cache.clone(),
|
|
service.config.clone(),
|
|
project_id,
|
|
Uuid::nil(), // sender_uid 閳?unknown in Chat API context
|
|
embed_service,
|
|
on_chunk,
|
|
Some(conversation_id),
|
|
Some(service.queue_producer.clone()),
|
|
);
|
|
|
|
let result = tokio::select! {
|
|
result = execution => Some(result),
|
|
_ = cancel_wait => {
|
|
cancelled.store(true, Ordering::Release);
|
|
None
|
|
}
|
|
};
|
|
|
|
// Clear stream active state and cancel flag (streaming finished)
|
|
let _ = cache.clear_chat_stream_active(conversation_id).await;
|
|
let _ = cache.clear_chat_stream_cancelled(conversation_id).await;
|
|
let was_cancelled = cancelled.load(Ordering::Acquire);
|
|
|
|
match result {
|
|
Some(Ok(stream_result)) => {
|
|
if was_cancelled {
|
|
let partial_chunks = recorded_chunks.lock().await.clone();
|
|
if let Some(msg) = persist_assistant_message_from_chunks(
|
|
&service,
|
|
conversation_id,
|
|
user_message_id,
|
|
assistant_msg_id,
|
|
&model_name,
|
|
&partial_chunks,
|
|
&stream_result.content,
|
|
stream_result.input_tokens,
|
|
stream_result.output_tokens,
|
|
"cancelled",
|
|
)
|
|
.await
|
|
{
|
|
update_conversation_after_response(&service, conversation_id, &msg).await;
|
|
}
|
|
let _ = tx
|
|
.send("data: {\"event\":\"done\",\"data\":\"stopped\"}\n\n".to_string())
|
|
.await;
|
|
return;
|
|
}
|
|
// Build ordered content blocks from stream chunks, merging
|
|
// consecutive blocks of the same role (thinking/assistant/tool_call/tool_result).
|
|
let raw_blocks: Vec<(String, String)> = stream_result
|
|
.chunks
|
|
.iter()
|
|
.filter(|c| {
|
|
matches!(
|
|
c.chunk_type,
|
|
StreamChunkType::Thinking
|
|
| StreamChunkType::Answer
|
|
| StreamChunkType::ToolCall
|
|
| StreamChunkType::ToolResult
|
|
)
|
|
})
|
|
.map(|chunk| {
|
|
let role = match chunk.chunk_type {
|
|
StreamChunkType::Thinking => "thinking",
|
|
StreamChunkType::ToolCall => "tool_call",
|
|
StreamChunkType::ToolResult => "tool_result",
|
|
_ => "assistant",
|
|
};
|
|
(role.to_string(), chunk.content.clone())
|
|
})
|
|
.collect();
|
|
|
|
let merged_blocks = merge_consecutive_blocks(raw_blocks);
|
|
// Apply thinking normalization to the fully merged thinking
|
|
// blocks 閳?per-token normalization is meaningless since each
|
|
// chunk is a single token.
|
|
let normalized_blocks: Vec<(String, String)> = merged_blocks
|
|
.into_iter()
|
|
.map(|(role, content)| {
|
|
if role == "thinking" {
|
|
(role, normalize_thinking_content(&content))
|
|
} else {
|
|
(role, content)
|
|
}
|
|
})
|
|
.collect();
|
|
let content_blocks: Vec<serde_json::Value> = normalized_blocks
|
|
.iter()
|
|
.map(|(role, content)| serde_json::json!({ "role": role, "content": content }))
|
|
.collect();
|
|
let content_value = if content_blocks.is_empty() {
|
|
serde_json::json!([{ "role": "assistant", "content": stream_result.content }])
|
|
} else {
|
|
serde_json::json!(content_blocks)
|
|
};
|
|
|
|
// Persist assistant message
|
|
let assistant_msg = ai_message::ActiveModel {
|
|
id: Set(assistant_msg_id),
|
|
conversation_id: Set(conversation_id),
|
|
parent_message_id: Set(Some(user_message_id)),
|
|
role: Set("assistant".to_string()),
|
|
content: Set(content_value),
|
|
model: Set(Some(model_name.clone())),
|
|
is_fork_origin: Set(false),
|
|
stop_reason: Set(Some("stop".to_string())),
|
|
input_tokens: Set(Some(stream_result.input_tokens as i32)),
|
|
output_tokens: Set(Some(stream_result.output_tokens as i32)),
|
|
latency_ms: Set(None),
|
|
metadata: Set(None),
|
|
room_id: Set(None),
|
|
version_group_id: Set(Some(assistant_msg_id)),
|
|
version_number: Set(1),
|
|
is_latest: Set(true),
|
|
created_at: Set(chrono::Utc::now()),
|
|
};
|
|
|
|
let saved = assistant_msg.insert(service.db.writer()).await;
|
|
|
|
if let Ok(msg) = &saved {
|
|
update_conversation_after_response(&service, conversation_id, msg).await;
|
|
|
|
// After AI response, check/update conversation title and emit via SSE
|
|
if let Ok(Some(conv)) = ai_conversation::Entity::find_by_id(conversation_id)
|
|
.one(service.db.reader())
|
|
.await
|
|
{
|
|
let existing_title = conv.title.clone();
|
|
let needs_title = existing_title
|
|
.as_deref()
|
|
.map(|t| t.is_empty() || t == "New Chat")
|
|
.unwrap_or(true);
|
|
|
|
if needs_title {
|
|
// Generate title from first user message
|
|
let first_user_msg = AiMessage::find()
|
|
.filter(ai_message::Column::ConversationId.eq(conversation_id))
|
|
.filter(ai_message::Column::Role.eq("user"))
|
|
.order_by_asc(ai_message::Column::CreatedAt)
|
|
.one(service.db.reader())
|
|
.await
|
|
.ok()
|
|
.flatten();
|
|
|
|
if let Some(user_msg) = first_user_msg {
|
|
let content = match &user_msg.content {
|
|
serde_json::Value::String(s) => s.clone(),
|
|
serde_json::Value::Array(arr) => arr
|
|
.first()
|
|
.and_then(|f| f.get("content"))
|
|
.and_then(|c| c.as_str())
|
|
.unwrap_or("")
|
|
.to_string(),
|
|
other => other.to_string(),
|
|
};
|
|
|
|
// Simple title extraction: first meaningful words
|
|
let title = content
|
|
.split_whitespace()
|
|
.filter(|w| w.len() > 2)
|
|
.take(5)
|
|
.collect::<Vec<_>>()
|
|
.join(" ");
|
|
|
|
if !title.is_empty() {
|
|
let truncated: String = title.chars().take(40).collect();
|
|
|
|
// Save title to DB
|
|
let mut active: ai_conversation::ActiveModel = conv.into();
|
|
active.title = Set(Some(truncated.clone()));
|
|
active.updated_at = Set(chrono::Utc::now());
|
|
let _ = active.update(service.db.writer()).await;
|
|
|
|
// Emit title via SSE
|
|
let title_payload =
|
|
serde_json::json!({"title": truncated}).to_string();
|
|
let _ = tx
|
|
.send(format!(
|
|
"data: {{\"event\":\"title\",\"data\":{}}}\n\n",
|
|
title_payload
|
|
))
|
|
.await;
|
|
}
|
|
}
|
|
} else if let Some(title) = &existing_title {
|
|
// Title already set (e.g. by AI tool) 閳?emit it
|
|
let title_payload = serde_json::json!({"title": title}).to_string();
|
|
let _ = tx
|
|
.send(format!(
|
|
"data: {{\"event\":\"title\",\"data\":{}}}\n\n",
|
|
title_payload
|
|
))
|
|
.await;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Record billing after successful AI response
|
|
let billing_result = if is_personal {
|
|
agent::billing::record_user_ai_usage(
|
|
&service.db,
|
|
user_id,
|
|
billing_version_id,
|
|
stream_result.input_tokens,
|
|
stream_result.output_tokens,
|
|
)
|
|
.await
|
|
} else {
|
|
agent::billing::record_ai_usage(
|
|
&service.db,
|
|
project_id,
|
|
user_id,
|
|
billing_version_id,
|
|
stream_result.input_tokens,
|
|
stream_result.output_tokens,
|
|
)
|
|
.await
|
|
};
|
|
|
|
let mut billing_failed = false;
|
|
match billing_result {
|
|
Ok(agent::billing::BillingResult::Success(record)) => {
|
|
tracing::info!(
|
|
cost = record.cost,
|
|
deducted_from = record.deducted_from.as_str(),
|
|
personal = is_personal,
|
|
"chat_billing_deducted"
|
|
);
|
|
}
|
|
Ok(agent::billing::BillingResult::InsufficientBalance { message }) => {
|
|
billing_failed = true;
|
|
tracing::warn!(
|
|
project_id = %project_id,
|
|
user_id = %user_id,
|
|
personal = is_personal,
|
|
"chat_billing_insufficient_balance"
|
|
);
|
|
let payload = serde_json::json!({"event":"billing_error","data":message});
|
|
let _ = tx.send(format!("data: {}\n\n", payload)).await;
|
|
}
|
|
Err(e) => {
|
|
billing_failed = true;
|
|
tracing::error!(error = %e, "chat_billing_error");
|
|
let payload = serde_json::json!({
|
|
"event":"billing_error",
|
|
"data": format!("Billing failed: {}", e),
|
|
});
|
|
let _ = tx.send(format!("data: {}\n\n", payload)).await;
|
|
}
|
|
}
|
|
|
|
// Broadcast final chat message with token usage
|
|
let final_msg = ChatMessageEvent {
|
|
message_id: assistant_msg_id,
|
|
conversation_id,
|
|
project_id: conv_project_id,
|
|
sender_id: Uuid::nil(),
|
|
role: "assistant".to_string(),
|
|
content: stream_result.content.clone(),
|
|
model: Some(model_name.clone()),
|
|
input_tokens: Some(stream_result.input_tokens as i32),
|
|
output_tokens: Some(stream_result.output_tokens as i32),
|
|
timestamp: chrono::Utc::now(),
|
|
};
|
|
let _ = queue.publish_chat_message(&final_msg).await;
|
|
|
|
// Send final SSE done event
|
|
let done_data = if billing_failed {
|
|
"billing_error"
|
|
} else {
|
|
"ok"
|
|
};
|
|
let _ = tx
|
|
.send(format!(
|
|
"data: {{\"event\":\"done\",\"data\":\"{}\"}}\n\n",
|
|
done_data
|
|
))
|
|
.await;
|
|
}
|
|
None => {
|
|
let partial_chunks = recorded_chunks.lock().await.clone();
|
|
if let Some(msg) = persist_assistant_message_from_chunks(
|
|
&service,
|
|
conversation_id,
|
|
user_message_id,
|
|
assistant_msg_id,
|
|
&model_name,
|
|
&partial_chunks,
|
|
"",
|
|
0,
|
|
0,
|
|
"cancelled",
|
|
)
|
|
.await
|
|
{
|
|
update_conversation_after_response(&service, conversation_id, &msg).await;
|
|
let final_msg = ChatMessageEvent {
|
|
message_id: assistant_msg_id,
|
|
conversation_id,
|
|
project_id: conv_project_id,
|
|
sender_id: Uuid::nil(),
|
|
role: "assistant".to_string(),
|
|
content: assistant_plain_text(&msg.content),
|
|
model: Some(model_name.clone()),
|
|
input_tokens: msg.input_tokens,
|
|
output_tokens: msg.output_tokens,
|
|
timestamp: chrono::Utc::now(),
|
|
};
|
|
let _ = queue.publish_chat_message(&final_msg).await;
|
|
}
|
|
let _ = tx
|
|
.send("data: {\"event\":\"done\",\"data\":\"stopped\"}\n\n".to_string())
|
|
.await;
|
|
}
|
|
Some(Err(e)) => {
|
|
let partial_chunks = recorded_chunks.lock().await.clone();
|
|
if let Some(msg) = persist_assistant_message_from_chunks(
|
|
&service,
|
|
conversation_id,
|
|
user_message_id,
|
|
assistant_msg_id,
|
|
&model_name,
|
|
&partial_chunks,
|
|
"",
|
|
0,
|
|
0,
|
|
"error",
|
|
)
|
|
.await
|
|
{
|
|
update_conversation_after_response(&service, conversation_id, &msg).await;
|
|
let final_msg = ChatMessageEvent {
|
|
message_id: assistant_msg_id,
|
|
conversation_id,
|
|
project_id: conv_project_id,
|
|
sender_id: Uuid::nil(),
|
|
role: "assistant".to_string(),
|
|
content: assistant_plain_text(&msg.content),
|
|
model: Some(model_name.clone()),
|
|
input_tokens: msg.input_tokens,
|
|
output_tokens: msg.output_tokens,
|
|
timestamp: chrono::Utc::now(),
|
|
};
|
|
let _ = queue.publish_chat_message(&final_msg).await;
|
|
}
|
|
let payload = serde_json::json!({"event":"error","data": e.to_string()});
|
|
let _ = tx.send(format!("data: {}\n\n", payload)).await;
|
|
let _ = tx
|
|
.send("data: {\"event\":\"done\",\"data\":\"error\"}\n\n".to_string())
|
|
.await;
|
|
}
|
|
}
|
|
});
|
|
|
|
Box::pin(ReceiverStream::new(rx).map(|msg| Ok(actix_web::web::Bytes::from(msg))))
|
|
}
|
|
|
|
fn content_value_from_chunks(chunks: &[StreamChunk], fallback: &str) -> Option<serde_json::Value> {
|
|
let raw_blocks: Vec<(String, String)> = chunks
|
|
.iter()
|
|
.filter(|c| {
|
|
matches!(
|
|
c.chunk_type,
|
|
StreamChunkType::Thinking
|
|
| StreamChunkType::Answer
|
|
| StreamChunkType::ToolCall
|
|
| StreamChunkType::ToolResult
|
|
)
|
|
})
|
|
.map(|chunk| {
|
|
let role = match chunk.chunk_type {
|
|
StreamChunkType::Thinking => "thinking",
|
|
StreamChunkType::ToolCall => "tool_call",
|
|
StreamChunkType::ToolResult => "tool_result",
|
|
_ => "assistant",
|
|
};
|
|
(role.to_string(), chunk.content.clone())
|
|
})
|
|
.collect();
|
|
|
|
let merged_blocks = merge_consecutive_blocks(raw_blocks);
|
|
let normalized_blocks: Vec<(String, String)> = merged_blocks
|
|
.into_iter()
|
|
.map(|(role, content)| {
|
|
if role == "thinking" {
|
|
(role, normalize_thinking_content(&content))
|
|
} else {
|
|
(role, content)
|
|
}
|
|
})
|
|
.filter(|(_, content)| !content.is_empty())
|
|
.collect();
|
|
|
|
if normalized_blocks.is_empty() && fallback.is_empty() {
|
|
return None;
|
|
}
|
|
|
|
let content_blocks: Vec<serde_json::Value> = normalized_blocks
|
|
.iter()
|
|
.map(|(role, content)| serde_json::json!({ "role": role, "content": content }))
|
|
.collect();
|
|
Some(if content_blocks.is_empty() {
|
|
serde_json::json!([{ "role": "assistant", "content": fallback }])
|
|
} else {
|
|
serde_json::json!(content_blocks)
|
|
})
|
|
}
|
|
|
|
/// Flatten stored assistant content into plain text.
///
/// A JSON string is returned as-is. For an array of `{role, content}` blocks,
/// the string `content` of every block whose role is not `"thinking"` is
/// joined with `\n` (blocks without a string `content` are skipped). Any other
/// JSON value is rendered with its compact JSON text.
fn assistant_plain_text(content: &serde_json::Value) -> String {
    let blocks = match content {
        serde_json::Value::Array(blocks) => blocks,
        serde_json::Value::String(text) => return text.clone(),
        other => return other.to_string(),
    };

    let mut parts: Vec<&str> = Vec::with_capacity(blocks.len());
    for block in blocks {
        let role = block.get("role").and_then(serde_json::Value::as_str);
        if role == Some("thinking") {
            continue;
        }
        if let Some(text) = block.get("content").and_then(serde_json::Value::as_str) {
            parts.push(text);
        }
    }
    parts.join("\n")
}
|
|
|
|
async fn persist_assistant_message_from_chunks(
|
|
service: &AppService,
|
|
conversation_id: Uuid,
|
|
user_message_id: Uuid,
|
|
assistant_msg_id: Uuid,
|
|
model_name: &str,
|
|
chunks: &[StreamChunk],
|
|
fallback: &str,
|
|
input_tokens: i64,
|
|
output_tokens: i64,
|
|
stop_reason: &str,
|
|
) -> Option<ai_message::Model> {
|
|
let content = content_value_from_chunks(chunks, fallback)?;
|
|
let assistant_msg = ai_message::ActiveModel {
|
|
id: Set(assistant_msg_id),
|
|
conversation_id: Set(conversation_id),
|
|
parent_message_id: Set(Some(user_message_id)),
|
|
role: Set("assistant".to_string()),
|
|
content: Set(content),
|
|
model: Set(Some(model_name.to_string())),
|
|
is_fork_origin: Set(false),
|
|
stop_reason: Set(Some(stop_reason.to_string())),
|
|
input_tokens: Set(Some(input_tokens as i32)),
|
|
output_tokens: Set(Some(output_tokens as i32)),
|
|
latency_ms: Set(None),
|
|
metadata: Set(None),
|
|
room_id: Set(None),
|
|
version_group_id: Set(Some(assistant_msg_id)),
|
|
version_number: Set(1),
|
|
is_latest: Set(true),
|
|
created_at: Set(chrono::Utc::now()),
|
|
};
|
|
|
|
match assistant_msg.insert(service.db.writer()).await {
|
|
Ok(msg) => Some(msg),
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, conversation_id = %conversation_id, "failed to persist partial assistant message");
|
|
None
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Update conversation metadata after an AI assistant message is saved.
|
|
async fn update_conversation_after_response(
|
|
service: &AppService,
|
|
conversation_id: Uuid,
|
|
assistant_msg: &ai_message::Model,
|
|
) {
|
|
use models::ai::ai_conversation;
|
|
use sea_orm::EntityTrait;
|
|
|
|
if let Ok(Some(conv)) = ai_conversation::Entity::find_by_id(conversation_id)
|
|
.one(service.db.reader())
|
|
.await
|
|
{
|
|
let input_tokens = assistant_msg.input_tokens.unwrap_or(0) as i64;
|
|
let output_tokens = assistant_msg.output_tokens.unwrap_or(0) as i64;
|
|
let total_tokens = input_tokens + output_tokens;
|
|
|
|
let previous_token_total = conv.token_usage_total.unwrap_or(0);
|
|
let mut active: ai_conversation::ActiveModel = conv.into();
|
|
if let Ok(count) = AiMessage::find()
|
|
.filter(ai_message::Column::ConversationId.eq(conversation_id))
|
|
.count(service.db.reader())
|
|
.await
|
|
{
|
|
active.message_count = Set(count as i32);
|
|
}
|
|
active.token_usage_total = Set(Some(previous_token_total + total_tokens as i32));
|
|
active.updated_at = Set(chrono::Utc::now());
|
|
let _ = active.update(service.db.writer()).await;
|
|
}
|
|
}
|
|
|
|
/// Build ChatRequestMessage list from ai_message conversation history.
|
|
async fn build_messages_from_history(
|
|
service: &AppService,
|
|
conversation_id: Uuid,
|
|
) -> Result<Vec<ChatRequestMessage>, String> {
|
|
let conversation = service
|
|
.find_conversation(conversation_id)
|
|
.await
|
|
.map_err(|e| format!("conversation lookup error: {}", e))?;
|
|
let project_id = conversation.project_id;
|
|
|
|
let msgs = AiMessage::find()
|
|
.filter(ai_message::Column::ConversationId.eq(conversation_id))
|
|
.filter(ai_message::Column::IsLatest.eq(true))
|
|
.order_by_asc(ai_message::Column::CreatedAt)
|
|
.all(service.db.reader())
|
|
.await
|
|
.map_err(|e| format!("db error: {}", e))?;
|
|
|
|
let mut chat_messages = Vec::new();
|
|
|
|
for msg in &msgs {
|
|
let role = msg.role.as_str();
|
|
let content = match &msg.content {
|
|
serde_json::Value::String(s) => s.clone(),
|
|
serde_json::Value::Array(arr) => {
|
|
// Content is ordered blocks: [{role:"thinking",content:"..."}, {role:"assistant","content":"..."}, ...]
|
|
// For assistant messages: concatenate all "assistant" blocks
|
|
// For user/system messages: take the first block's content
|
|
if role == "assistant" {
|
|
arr.iter()
|
|
.filter(|item| {
|
|
item.get("role").and_then(|r| r.as_str()) != Some("thinking")
|
|
})
|
|
.filter_map(|item| item.get("content").and_then(|c| c.as_str()))
|
|
.collect::<Vec<_>>()
|
|
.join("\n")
|
|
} else if let Some(first) = arr.first() {
|
|
first
|
|
.get("content")
|
|
.and_then(|c| c.as_str())
|
|
.unwrap_or("")
|
|
.to_string()
|
|
} else {
|
|
String::new()
|
|
}
|
|
}
|
|
other => other.to_string(),
|
|
};
|
|
|
|
if role == "user" {
|
|
match service
|
|
.build_message_context_prompts(project_id, msg.metadata.as_ref())
|
|
.await
|
|
{
|
|
Ok(prompts) => {
|
|
for prompt in prompts {
|
|
chat_messages.push(ChatRequestMessage::system(prompt));
|
|
}
|
|
}
|
|
Err(error) => {
|
|
tracing::warn!(
|
|
conversation_id = %conversation_id,
|
|
message_id = %msg.id,
|
|
error = %error,
|
|
"failed to build chat message context prompts"
|
|
);
|
|
}
|
|
}
|
|
}
|
|
|
|
match role {
|
|
"user" => chat_messages.push(ChatRequestMessage::user(content)),
|
|
"assistant" => chat_messages.push(ChatRequestMessage::assistant(Some(content), None)),
|
|
"system" => chat_messages.push(ChatRequestMessage::system(content)),
|
|
_ => chat_messages.push(ChatRequestMessage::user(content)),
|
|
}
|
|
}
|
|
|
|
Ok(chat_messages)
|
|
}
|
|
|
|
/// Collapse runs of adjacent content blocks that share a role.
///
/// Blocks arrive one per streamed chunk as `(role, content)` pairs. Blocks
/// with empty content are discarded first. Each remaining block is appended to
/// the previous one when both have the same role. The exception is
/// `tool_call` and `tool_result` blocks, which are never merged and keep one
/// entry each. Example: `[thinking, thinking, assistant, assistant]` becomes
/// `[thinking, assistant]`.
///
/// Contents are concatenated with no separator, because per-token chunks
/// already carry any `\n` they need inside the token text.
fn merge_consecutive_blocks(blocks: Vec<(String, String)>) -> Vec<(String, String)> {
    let mut segments: Vec<(String, String)> = Vec::new();
    for (role, text) in blocks.into_iter().filter(|(_, text)| !text.is_empty()) {
        let is_tool_block = role == "tool_call" || role == "tool_result";
        let continues_previous = !is_tool_block
            && matches!(segments.last(), Some((prev_role, _)) if *prev_role == role);
        if continues_previous {
            if let Some((_, prev_text)) = segments.last_mut() {
                prev_text.push_str(&text);
            }
        } else {
            segments.push((role, text));
        }
    }
    segments
}
|