// File: gitdataai/libs/api/chat/stream.rs

use agent::chat::chat_execution;
use agent::chat::{normalize_thinking_content, AiChunkType, AiStreamChunk};
use agent::client::AiClientConfig;
use agent::client::types::ChatRequestMessage;
use agent::client::StreamChunkType;
use futures::StreamExt;
use models::ai::{ai_message, ai_conversation, AiMessage};
use queue::{ChatMessageEvent, ChatStreamChunkEvent};
use sea_orm::{EntityTrait, QueryFilter, ColumnTrait, QueryOrder, ActiveModelTrait, Set, PaginatorTrait};
use service::AppService;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio_stream::wrappers::ReceiverStream;
use uuid::Uuid;
/// Create an SSE stream that executes AI chat with ReAct tool-calling.
///
/// Also publishes chat messages and stream chunks via NATS JetStream for
/// multi-viewer support. The requesting client receives SSE events, while
/// other viewers receive chunks via NATS → WebSocket broadcast.
pub fn create_chat_sse_stream(
service: AppService,
conversation_id: Uuid,
user_message_id: Uuid,
model_name: String,
user_id: Uuid,
) -> Pin<Box<dyn futures::Stream<Item = Result<actix_web::web::Bytes, actix_web::Error>> + Send>> {
let (tx, rx) = tokio::sync::mpsc::channel::<String>(100);
let cache = service.cache.clone();
tokio::spawn(async move {
// Check for active stream (SSE reconnect recovery) BEFORE starting a new one
// so the frontend can recover from a page refresh.
if let Some((msg_id, started_at)) = cache.get_chat_stream_active(conversation_id).await {
let _ = tx.send(format!(
"data: {{\"event\":\"recovery\",\"data\":{{\"message_id\":\"{}\",\"started_at\":{}}}}}\n\n",
msg_id,
started_at
)).await;
}
let queue = service.queue_producer.clone();
let chunk_seq = Arc::new(AtomicU64::new(0));
// Build messages from conversation history
let messages = match build_messages_from_history(&service, conversation_id).await {
Ok(msgs) => msgs,
Err(e) => {
let _ = tx.send(format!("data: {{\"event\":\"error\",\"data\":\"{}\"}}\n\n", e)).await;
return;
}
};
// Get AI config
let api_key = match service.config.ai_api_key() {
Ok(k) => k,
Err(_) => {
let _ = tx.send("data: {\"event\":\"error\",\"data\":\"AI not configured\"}\n\n".to_string()).await;
return;
}
};
let base_url = match service.config.ai_basic_url() {
Ok(u) => u,
Err(_) => {
let _ = tx.send("data: {\"event\":\"error\",\"data\":\"AI not configured\"}\n\n".to_string()).await;
return;
}
};
let config = AiClientConfig::new(api_key).with_base_url(&base_url);
// Get tools from ChatService if available
let (tools, tool_registry, embed_service) = match &service.chat_service {
Some(cs) => (
cs.tools(),
cs.tool_registry().cloned(),
service.embed_service.as_ref().map(|es| (**es).clone()),
),
None => (Vec::new(), None, None),
};
// Get project_id from conversation
let project_id = match service.find_conversation(conversation_id).await {
Ok(c) => c.project_id.unwrap_or(Uuid::nil()),
Err(_) => {
let _ = tx.send("data: {\"event\":\"error\",\"data\":\"conversation not found\"}\n\n".to_string()).await;
return;
}
};
// Pre-flight balance check: verify project + user can afford at least a minimal AI call
let balance_ok = agent::billing::check_balance(
&service.db, project_id, user_id, Uuid::nil(), 500, 250,
).await;
match balance_ok {
Ok(true) => {},
Ok(false) => {
tracing::warn!(project_id = %project_id, user_id = %user_id, "Insufficient balance for chat AI call");
let _ = agent::billing::persist_billing_error(
&service.db, "user", user_id, "insufficient_balance",
&format!("Insufficient balance. Your account does not have enough funds for this AI request."),
Some(serde_json::json!({
"user_id": user_id.to_string(),
"project_id": project_id.to_string(),
})),
).await;
let error_msg = "Insufficient balance. Your account does not have enough funds to process this AI request. Please add credits to continue.";
let _ = tx.send(format!("data: {{\"event\":\"billing_error\",\"data\":\"{}\"}}\n\n", error_msg)).await;
let _ = tx.send("data: {\"event\":\"done\",\"data\":\"billing_error\"}\n\n".to_string()).await;
return;
},
Err(e) => {
tracing::warn!(error = %e, "Balance check failed, proceeding without pre-flight check");
}
}
let max_tool_depth = 99;
// Determine conversation project_id for chat message event
let conv_project_id = match service.find_conversation(conversation_id).await {
Ok(c) => c.project_id,
Err(_) => None,
};
// Broadcast chat message start event via NATS
let chat_msg = ChatMessageEvent {
message_id: user_message_id,
conversation_id,
project_id: conv_project_id,
sender_id: Uuid::nil(),
role: "assistant".to_string(),
content: String::new(),
model: Some(model_name.clone()),
input_tokens: None,
output_tokens: None,
timestamp: chrono::Utc::now(),
};
let _ = queue.publish_chat_message(&chat_msg).await;
// Mark stream as active in Redis so page refresh can recover
let _ = cache.set_chat_stream_active(conversation_id, user_message_id).await;
let on_chunk_tx = tx.clone();
let on_chunk_queue = queue.clone();
let on_chunk_seq = chunk_seq.clone();
let on_chunk_conv_id = conversation_id;
let on_chunk_msg_id = user_message_id;
let on_chunk_model = model_name.clone();
let on_chunk: agent::chat::StreamCallback = Box::new(move |chunk: AiStreamChunk| {
let tx = on_chunk_tx.clone();
let queue = on_chunk_queue.clone();
let seq = on_chunk_seq.fetch_add(1, Ordering::Relaxed);
let conv_id = on_chunk_conv_id;
let msg_id = on_chunk_msg_id;
let model = on_chunk_model.clone();
Box::pin(async move {
let event = match chunk.chunk_type {
AiChunkType::Thinking => "thinking",
AiChunkType::Answer => "token",
AiChunkType::ToolCall => "tool_call",
AiChunkType::ToolResult => "tool_result",
};
let content = match chunk.chunk_type {
AiChunkType::Thinking => normalize_thinking_content(&chunk.content),
_ => chunk.content.clone(),
};
let sse = format!(
"data: {{\"event\":\"{}\",\"data\":{}}}\n\n",
event,
serde_json::to_string(&content).unwrap_or_default()
);
let _ = tx.send(sse).await;
// Also broadcast via NATS for other viewers
let natts_chunk = ChatStreamChunkEvent {
conversation_id: conv_id,
message_id: msg_id,
seq,
content,
done: false,
error: None,
chunk_type: Some(event.to_string()),
model_name: Some(model),
};
queue.publish_chat_chunk(&natts_chunk).await;
}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
});
let result = chat_execution::execute_chat_stream(
messages,
tools,
&model_name,
&config,
0.7, // temperature
4096, // max_tokens
max_tool_depth,
tool_registry.as_ref(),
service.db.clone(),
service.cache.clone(),
service.config.clone(),
project_id,
Uuid::nil(), // sender_uid — unknown in Chat API context
embed_service,
on_chunk,
Some(conversation_id),
).await;
// Clear stream active state (streaming finished)
let _ = cache.clear_chat_stream_active(conversation_id).await;
match result {
Ok(stream_result) => {
// Build ordered content blocks from stream chunks, merging
// consecutive blocks of the same role (thinking/assistant).
let raw_blocks: Vec<(String, String)> = stream_result.chunks.iter()
.filter(|c| matches!(c.chunk_type, StreamChunkType::Thinking | StreamChunkType::Answer))
.map(|chunk| {
let role = match chunk.chunk_type {
StreamChunkType::Thinking => "thinking",
_ => "assistant",
};
(role.to_string(), chunk.content.clone())
})
.collect();
let merged_blocks = merge_consecutive_blocks(raw_blocks);
// Apply thinking normalization to the fully merged thinking
// blocks — per-token normalization is meaningless since each
// chunk is a single token.
let normalized_blocks: Vec<(String, String)> = merged_blocks.into_iter().map(|(role, content)| {
if role == "thinking" {
(role, normalize_thinking_content(&content))
} else {
(role, content)
}
}).collect();
let content_blocks: Vec<serde_json::Value> = normalized_blocks.iter()
.map(|(role, content)| serde_json::json!({ "role": role, "content": content }))
.collect();
let content_value = if content_blocks.is_empty() {
serde_json::json!([{ "role": "assistant", "content": stream_result.content }])
} else {
serde_json::json!(content_blocks)
};
// Persist assistant message
let assistant_msg_id = Uuid::now_v7();
let assistant_msg = ai_message::ActiveModel {
id: Set(assistant_msg_id),
conversation_id: Set(conversation_id),
parent_message_id: Set(Some(user_message_id)),
role: Set("assistant".to_string()),
content: Set(content_value),
model: Set(Some(model_name.clone())),
is_fork_origin: Set(false),
stop_reason: Set(Some("stop".to_string())),
input_tokens: Set(Some(stream_result.input_tokens as i32)),
output_tokens: Set(Some(stream_result.output_tokens as i32)),
latency_ms: Set(None),
metadata: Set(None),
room_id: Set(None),
version_group_id: Set(Some(assistant_msg_id)),
version_number: Set(1),
is_latest: Set(true),
created_at: Set(chrono::Utc::now()),
};
let saved = assistant_msg.insert(service.db.writer()).await;
if let Ok(msg) = &saved {
update_conversation_after_response(&service, conversation_id, msg).await;
// After AI response, check/update conversation title and emit via SSE
if let Ok(Some(conv)) = ai_conversation::Entity::find_by_id(conversation_id)
.one(service.db.reader()).await
{
let existing_title = conv.title.clone();
let needs_title = existing_title.as_deref().map(|t| t.is_empty() || t == "New Chat").unwrap_or(true);
if needs_title {
// Generate title from first user message
let first_user_msg = AiMessage::find()
.filter(ai_message::Column::ConversationId.eq(conversation_id))
.filter(ai_message::Column::Role.eq("user"))
.order_by_asc(ai_message::Column::CreatedAt)
.one(service.db.reader()).await.ok().flatten();
if let Some(user_msg) = first_user_msg {
let content = match &user_msg.content {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Array(arr) => {
arr.first()
.and_then(|f| f.get("content"))
.and_then(|c| c.as_str())
.unwrap_or("")
.to_string()
}
other => other.to_string(),
};
// Simple title extraction: first meaningful words
let title = content
.split_whitespace()
.filter(|w| w.len() > 2)
.take(5)
.collect::<Vec<_>>()
.join(" ");
if !title.is_empty() {
let truncated: String = title.chars().take(40).collect();
// Save title to DB
let mut active: ai_conversation::ActiveModel = conv.into();
active.title = Set(Some(truncated.clone()));
active.updated_at = Set(chrono::Utc::now());
let _ = active.update(service.db.writer()).await;
// Emit title via SSE
let title_payload = serde_json::json!({"title": truncated}).to_string();
let _ = tx.send(format!("data: {{\"event\":\"title\",\"data\":{}}}\n\n", title_payload)).await;
}
}
} else if let Some(title) = &existing_title {
// Title already set (e.g. by AI tool) — emit it
let title_payload = serde_json::json!({"title": title}).to_string();
let _ = tx.send(format!("data: {{\"event\":\"title\",\"data\":{}}}\n\n", title_payload)).await;
}
}
}
// Broadcast final chat message with token usage
let final_msg = ChatMessageEvent {
message_id: user_message_id,
conversation_id,
project_id: conv_project_id,
sender_id: Uuid::nil(),
role: "assistant".to_string(),
content: stream_result.content.clone(),
model: Some(model_name.clone()),
input_tokens: Some(stream_result.input_tokens as i32),
output_tokens: Some(stream_result.output_tokens as i32),
timestamp: chrono::Utc::now(),
};
let _ = queue.publish_chat_message(&final_msg).await;
// Send final SSE done event
let _ = tx.send("data: {\"event\":\"done\",\"data\":\"ok\"}\n\n".to_string()).await;
}
Err(e) => {
let _ = tx.send(format!("data: {{\"event\":\"error\",\"data\":\"{}\"}}\n\n", e)).await;
}
}
});
Box::pin(ReceiverStream::new(rx).map(|msg| Ok(actix_web::web::Bytes::from(msg))))
}
/// Update conversation metadata after an AI assistant message is saved.
async fn update_conversation_after_response(
service: &AppService,
conversation_id: Uuid,
assistant_msg: &ai_message::Model,
) {
use models::ai::ai_conversation;
use sea_orm::EntityTrait;
if let Ok(Some(conv)) = ai_conversation::Entity::find_by_id(conversation_id)
.one(service.db.reader()).await
{
let input_tokens = assistant_msg.input_tokens.unwrap_or(0) as i64;
let output_tokens = assistant_msg.output_tokens.unwrap_or(0) as i64;
let total_tokens = input_tokens + output_tokens;
let mut active: ai_conversation::ActiveModel = conv.into();
if let Ok(count) = AiMessage::find()
.filter(ai_message::Column::ConversationId.eq(conversation_id))
.count(service.db.reader()).await
{
active.message_count = Set(count as i32);
}
active.token_usage_total = Set(Some(total_tokens as i32));
active.updated_at = Set(chrono::Utc::now());
let _ = active.update(service.db.writer()).await;
}
}
/// Assemble the model-facing message list for a conversation from its stored
/// `ai_message` rows.
///
/// Only rows with `is_latest = true` are used, oldest first by `created_at`.
/// Rows with role `assistant` or `system` use the matching constructor. Every
/// other role, including `user`, is sent as a user message.
///
/// # Errors
/// Returns `"db error: …"` when the message query fails.
async fn build_messages_from_history(
    service: &AppService,
    conversation_id: Uuid,
) -> Result<Vec<ChatRequestMessage>, String> {
    let rows = AiMessage::find()
        .filter(ai_message::Column::ConversationId.eq(conversation_id))
        .filter(ai_message::Column::IsLatest.eq(true))
        .order_by_asc(ai_message::Column::CreatedAt)
        .all(service.db.reader())
        .await
        .map_err(|err| format!("db error: {}", err))?;

    let history = rows
        .iter()
        .map(|row| {
            let text = history_message_text(&row.role, &row.content);
            match row.role.as_str() {
                "assistant" => ChatRequestMessage::assistant(Some(text), None),
                "system" => ChatRequestMessage::system(text),
                // "user" and any unrecognised role both become user turns.
                _ => ChatRequestMessage::user(text),
            }
        })
        .collect();
    Ok(history)
}

/// Flatten a stored message `content` value into the plain text sent to the
/// model.
///
/// - A JSON string is used verbatim; a non-string, non-array value is rendered
///   as JSON text.
/// - For an array of `{role, content}` blocks, assistant rows join the text of
///   every non-`thinking` block with `\n`. Any other role takes only the first
///   block's text, or `""` if there is none.
fn history_message_text(role: &str, content: &serde_json::Value) -> String {
    let blocks = match content {
        serde_json::Value::String(text) => return text.clone(),
        serde_json::Value::Array(blocks) => blocks,
        other => return other.to_string(),
    };

    if role != "assistant" {
        return blocks
            .first()
            .and_then(|block| block.get("content"))
            .and_then(|text| text.as_str())
            .unwrap_or("")
            .to_string();
    }

    let visible: Vec<&str> = blocks
        .iter()
        .filter(|block| block.get("role").and_then(|r| r.as_str()) != Some("thinking"))
        .filter_map(|block| block.get("content").and_then(|text| text.as_str()))
        .collect();
    visible.join("\n")
}
/// Collapse runs of adjacent blocks that share a role into a single block.
///
/// Blocks with empty text are dropped first. The remaining texts of each run
/// are concatenated with no separator, because per-token chunks already carry
/// any newlines they need. Example:
/// `[thinking, thinking, assistant, assistant]` → `[thinking, assistant]`.
fn merge_consecutive_blocks(blocks: Vec<(String, String)>) -> Vec<(String, String)> {
    let mut runs: Vec<(String, String)> = Vec::new();
    for (role, text) in blocks.into_iter().filter(|(_, text)| !text.is_empty()) {
        let continues_run = matches!(runs.last(), Some((last_role, _)) if *last_role == role);
        if !continues_run {
            runs.push((role, text));
        } else if let Some((_, run_text)) = runs.last_mut() {
            run_text.push_str(&text);
        }
    }
    runs
}