fix(agent): calculate and record cost in ai_session table

- Add record_ai_session() helper calling billing::record_ai_usage()
- Replace all Set(None) cost/currency with actual calculated values
- Cost is computed from model_pricing with Decimal precision
This commit is contained in:
ZhenYi 2026-04-28 09:50:44 +08:00
parent 7b43f55f41
commit 211cf0ee3e

View File

@ -11,7 +11,9 @@ use uuid::Uuid;
use super::context::RoomMessageContext; use super::context::RoomMessageContext;
use super::{AiChunkType, AiRequest, AiStreamChunk, Mention, StreamCallback}; use super::{AiChunkType, AiRequest, AiStreamChunk, Mention, StreamCallback};
use crate::billing;
use crate::client::AiClientConfig; use crate::client::AiClientConfig;
use crate::client::types::{ChatRequestMessage, ToolCall}; use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::client::{ use crate::client::{
StreamChunk, StreamChunkType, StreamedToolCall, call_stream, call_with_params, StreamChunk, StreamChunkType, StreamedToolCall, call_stream, call_with_params,
@ -44,6 +46,48 @@ pub struct ProcessResult {
pub output_tokens: i64, pub output_tokens: i64,
} }
/// Persist one `ai_session` row, filling in cost and currency from billing.
///
/// `billing::record_ai_usage` is called first with a nil project id, the
/// given `model_id`, and the token counts. When it succeeds, the returned
/// record's `cost` and `currency` are stored on the row; when it fails, both
/// columns are written as `None`.
///
/// This is best-effort: the billing error and the result of the insert are
/// both discarded, and nothing is returned to the caller.
async fn record_ai_session(
    db: &db::database::AppDatabase,
    session_id: Uuid,
    room_id: Uuid,
    model_id: Uuid,
    version_id: Uuid,
    input_tokens: i64,
    output_tokens: i64,
    latency_ms: i64,
) {
    // project_uid not needed for session cost
    let project_uid = Uuid::nil();

    // A billing failure is dropped here; the session row is still written below.
    let usage = billing::record_ai_usage(db, project_uid, model_id, input_tokens, output_tokens)
        .await
        .ok();

    let (cost, currency) = match usage {
        Some(record) => (Some(record.cost), Some(record.currency)),
        None => (None, None),
    };

    let session_row = models::ai::ai_session::ActiveModel {
        id: Set(session_id),
        room: Set(room_id),
        model: Set(model_id),
        version: Set(version_id),
        token_input: Set(input_tokens),
        token_output: Set(output_tokens),
        latency_ms: Set(Some(latency_ms)),
        cost: Set(cost),
        currency: Set(currency),
        error_message: Set(None),
        error_code: Set(None),
        created_at: Set(chrono::Utc::now()),
    };

    // The insert outcome is intentionally ignored.
    let _ = session_row.insert(db).await;
}
/// Service for handling AI chat requests in rooms. /// Service for handling AI chat requests in rooms.
pub struct ChatService { pub struct ChatService {
ai_base_url: Option<String>, ai_base_url: Option<String>,
@ -343,21 +387,16 @@ impl ChatService {
text text
}; };
// Record session // Record session
let _ = models::ai::ai_session::ActiveModel { record_ai_session(
id: Set(session_id), &request.db,
room: Set(request.room.id), session_id,
model: Set(request.model.id), request.room.id,
version: Set(version_id.unwrap_or_default()), request.model.id,
token_input: Set(input_tokens), version_id.unwrap_or_default(),
token_output: Set(output_tokens), input_tokens,
latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)), output_tokens,
cost: Set(None), session_start.elapsed().as_millis() as i64,
currency: Set(None), )
error_message: Set(None),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
}
.insert(&request.db)
.await; .await;
return Ok(ProcessResult { return Ok(ProcessResult {
content, content,
@ -369,21 +408,16 @@ impl ChatService {
} }
// Record session // Record session
let _ = models::ai::ai_session::ActiveModel { record_ai_session(
id: Set(session_id), &request.db,
room: Set(request.room.id), session_id,
model: Set(request.model.id), request.room.id,
version: Set(version_id.unwrap_or_default()), request.model.id,
token_input: Set(input_tokens), version_id.unwrap_or_default(),
token_output: Set(output_tokens), input_tokens,
latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)), output_tokens,
cost: Set(None), session_start.elapsed().as_millis() as i64,
currency: Set(None), )
error_message: Set(None),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
}
.insert(&request.db)
.await; .await;
return Ok(ProcessResult { return Ok(ProcessResult {
content: text, content: text,
@ -712,21 +746,16 @@ impl ChatService {
content: max_depth_text, content: max_depth_text,
}); });
// Record session // Record session
let _ = models::ai::ai_session::ActiveModel { record_ai_session(
id: Set(session_id), &request.db,
room: Set(request.room.id), session_id,
model: Set(request.model.id), request.room.id,
version: Set(version_id.unwrap_or_default()), request.model.id,
token_input: Set(total_input_tokens), version_id.unwrap_or_default(),
token_output: Set(total_output_tokens), total_input_tokens,
latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)), total_output_tokens,
cost: Set(None), session_start.elapsed().as_millis() as i64,
currency: Set(None), )
error_message: Set(None),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
}
.insert(&request.db)
.await; .await;
return Ok(StreamResult { return Ok(StreamResult {
content: full_content, content: full_content,
@ -753,21 +782,16 @@ impl ChatService {
content: response.content.clone(), content: response.content.clone(),
}); });
// Record session // Record session
let _ = models::ai::ai_session::ActiveModel { record_ai_session(
id: Set(session_id), &request.db,
room: Set(request.room.id), session_id,
model: Set(request.model.id), request.room.id,
version: Set(version_id.unwrap_or_default()), request.model.id,
token_input: Set(total_input_tokens), version_id.unwrap_or_default(),
token_output: Set(total_output_tokens), total_input_tokens,
latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)), total_output_tokens,
cost: Set(None), session_start.elapsed().as_millis() as i64,
currency: Set(None), )
error_message: Set(None),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
}
.insert(&request.db)
.await; .await;
return Ok(StreamResult { return Ok(StreamResult {
content: full_content, content: full_content,
@ -1190,21 +1214,16 @@ impl ChatService {
} }
let elapsed_ms = session_start.elapsed().as_millis() as i64; let elapsed_ms = session_start.elapsed().as_millis() as i64;
let _ = models::ai::ai_session::ActiveModel { record_ai_session(
id: Set(session_id), &request.db,
room: Set(request.room.id), session_id,
model: Set(request.model.id), request.room.id,
version: Set(version_id.unwrap_or_default()), request.model.id,
token_input: Set(total_input_tokens), version_id.unwrap_or_default(),
token_output: Set(total_output_tokens), total_input_tokens,
latency_ms: Set(Some(elapsed_ms)), total_output_tokens,
cost: Set(None), elapsed_ms,
currency: Set(None), )
error_message: Set(None),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
}
.insert(&request.db)
.await; .await;
Ok(final_content) Ok(final_content)