From 211cf0ee3e345503aeb1cd6e278aaed913cdcb8d Mon Sep 17 00:00:00 2001 From: ZhenYi <434836402@qq.com> Date: Tue, 28 Apr 2026 09:50:44 +0800 Subject: [PATCH] fix(agent): calculate and record cost in ai_session table - Add record_ai_session() helper calling billing::record_ai_usage() - Replace all Set(None) cost/currency with actual calculated values - Cost computed from model_pricing via Decimal precision --- libs/agent/chat/service.rs | 169 +++++++++++++++++++++---------------- 1 file changed, 94 insertions(+), 75 deletions(-) diff --git a/libs/agent/chat/service.rs b/libs/agent/chat/service.rs index b0e41e1..a3a31c4 100644 --- a/libs/agent/chat/service.rs +++ b/libs/agent/chat/service.rs @@ -11,7 +11,9 @@ use uuid::Uuid; use super::context::RoomMessageContext; use super::{AiChunkType, AiRequest, AiStreamChunk, Mention, StreamCallback}; +use crate::billing; use crate::client::AiClientConfig; + use crate::client::types::{ChatRequestMessage, ToolCall}; use crate::client::{ StreamChunk, StreamChunkType, StreamedToolCall, call_stream, call_with_params, @@ -44,6 +46,48 @@ pub struct ProcessResult { pub output_tokens: i64, } +/// Record an AI session with cost calculation. +async fn record_ai_session( + db: &db::database::AppDatabase, + session_id: Uuid, + room_id: Uuid, + model_id: Uuid, + version_id: Uuid, + input_tokens: i64, + output_tokens: i64, + latency_ms: i64, +) { + let (cost, currency) = match billing::record_ai_usage( + db, + Uuid::nil(), // project_uid not needed for session cost + model_id, + input_tokens, + output_tokens, + ) + .await + { + Ok(record) => (Some(record.cost), Some(record.currency)), + Err(_) => (None, None), + }; + + let _ = models::ai::ai_session::ActiveModel { + id: Set(session_id), + room: Set(room_id), + model: Set(model_id), + version: Set(version_id), + token_input: Set(input_tokens), + token_output: Set(output_tokens), + latency_ms: Set(Some(latency_ms)), + cost: Set(cost), + currency: Set(currency), + error_message: Set(None), + error_code: Set(None), + created_at: Set(chrono::Utc::now()), + } + .insert(db) + .await; +} + /// Service for handling AI chat requests in rooms. pub struct ChatService { ai_base_url: Option, @@ -343,21 +387,16 @@ impl ChatService { text }; // Record session - let _ = models::ai::ai_session::ActiveModel { - id: Set(session_id), - room: Set(request.room.id), - model: Set(request.model.id), - version: Set(version_id.unwrap_or_default()), - token_input: Set(input_tokens), - token_output: Set(output_tokens), - latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)), - cost: Set(None), - currency: Set(None), - error_message: Set(None), - error_code: Set(None), - created_at: Set(chrono::Utc::now()), - } - .insert(&request.db) + record_ai_session( + &request.db, + session_id, + request.room.id, + request.model.id, + version_id.unwrap_or_default(), + input_tokens, + output_tokens, + session_start.elapsed().as_millis() as i64, + ) .await; return Ok(ProcessResult { content, @@ -369,21 +408,16 @@ impl ChatService { } // Record session - let _ = models::ai::ai_session::ActiveModel { - id: Set(session_id), - room: Set(request.room.id), - model: Set(request.model.id), - version: Set(version_id.unwrap_or_default()), - token_input: Set(input_tokens), - token_output: Set(output_tokens), - latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)), - cost: Set(None), - currency: Set(None), - error_message: Set(None), - error_code: Set(None), - created_at: Set(chrono::Utc::now()), - } - .insert(&request.db) + record_ai_session( + &request.db, + session_id, + request.room.id, + request.model.id, + version_id.unwrap_or_default(), + input_tokens, + output_tokens, + session_start.elapsed().as_millis() as i64, + ) .await; return Ok(ProcessResult { content: text, @@ -712,21 +746,16 @@ impl ChatService { content: max_depth_text, }); // Record session - let _ = models::ai::ai_session::ActiveModel { - id: Set(session_id), - room: Set(request.room.id), - model: Set(request.model.id), - version: Set(version_id.unwrap_or_default()), - token_input: Set(total_input_tokens), - token_output: Set(total_output_tokens), - latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)), - cost: Set(None), - currency: Set(None), - error_message: Set(None), - error_code: Set(None), - created_at: Set(chrono::Utc::now()), - } - .insert(&request.db) + record_ai_session( + &request.db, + session_id, + request.room.id, + request.model.id, + version_id.unwrap_or_default(), + total_input_tokens, + total_output_tokens, + session_start.elapsed().as_millis() as i64, + ) .await; return Ok(StreamResult { content: full_content, @@ -753,21 +782,16 @@ impl ChatService { content: response.content.clone(), }); // Record session - let _ = models::ai::ai_session::ActiveModel { - id: Set(session_id), - room: Set(request.room.id), - model: Set(request.model.id), - version: Set(version_id.unwrap_or_default()), - token_input: Set(total_input_tokens), - token_output: Set(total_output_tokens), - latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)), - cost: Set(None), - currency: Set(None), - error_message: Set(None), - error_code: Set(None), - created_at: Set(chrono::Utc::now()), - } - .insert(&request.db) + record_ai_session( + &request.db, + session_id, + request.room.id, + request.model.id, + version_id.unwrap_or_default(), + total_input_tokens, + total_output_tokens, + session_start.elapsed().as_millis() as i64, + ) .await; return Ok(StreamResult { content: full_content, @@ -1190,21 +1214,16 @@ impl ChatService { } let elapsed_ms = session_start.elapsed().as_millis() as i64; - let _ = models::ai::ai_session::ActiveModel { - id: Set(session_id), - room: Set(request.room.id), - model: Set(request.model.id), - version: Set(version_id.unwrap_or_default()), - token_input: Set(total_input_tokens), - token_output: Set(total_output_tokens), - latency_ms: Set(Some(elapsed_ms)), - cost: Set(None), - currency: Set(None), - error_message: Set(None), - error_code: Set(None), - created_at: Set(chrono::Utc::now()), - } - .insert(&request.db) + record_ai_session( + &request.db, + session_id, + request.room.id, + request.model.id, + version_id.unwrap_or_default(), + total_input_tokens, + total_output_tokens, + elapsed_ms, + ) .await; Ok(final_content)