160 lines
5.6 KiB
Rust
160 lines
5.6 KiB
Rust
use chrono::Utc;
|
|
use uuid::Uuid;
|
|
|
|
use crate::event::{AgentInfo, RoomInfo, ai};
|
|
use crate::{ChannelBus, ChannelError, ChannelResult};
|
|
|
|
use super::WsOutEvent;
|
|
use super::WsHandler;
|
|
|
|
impl WsHandler {
|
|
pub(super) async fn ai_list(
|
|
bus: &ChannelBus,
|
|
user_id: Uuid,
|
|
room: Uuid,
|
|
) -> ChannelResult<Option<WsOutEvent>> {
|
|
Self::ensure_room_access(bus, user_id, room).await?;
|
|
let rows = db::sqlx::query_as::<_, (Uuid, Option<String>, Option<String>, Option<Uuid>, bool, bool)>(
|
|
"SELECT ra.agent_session, s.name, s.agent_kind, s.model_version, ra.enabled, ra.auto_reply \
|
|
FROM room_ai ra \
|
|
LEFT JOIN agent_session s ON s.id = ra.agent_session AND s.deleted_at IS NULL \
|
|
WHERE ra.room = $1",
|
|
)
|
|
.bind(room)
|
|
.fetch_all(bus.inner.db.reader())
|
|
.await?;
|
|
|
|
let agents = rows
|
|
.into_iter()
|
|
.filter_map(|(agent_session, name, agent_kind, model_version, enabled, auto_reply)| {
|
|
name.map(|n| ai::RoomAiEntry {
|
|
agent_session,
|
|
name: n,
|
|
agent_kind: agent_kind.unwrap_or_default(),
|
|
model_version,
|
|
enabled,
|
|
auto_reply,
|
|
})
|
|
})
|
|
.collect();
|
|
|
|
let ai_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room));
|
|
Ok(Some(WsOutEvent::AiAgentList {
|
|
room: ai_room.clone(),
|
|
data: ai::RoomAiListService {
|
|
room: ai_room,
|
|
agents,
|
|
},
|
|
}))
|
|
}
|
|
|
|
pub(super) async fn ai_upsert(
|
|
bus: &ChannelBus,
|
|
user_id: Uuid,
|
|
room: Uuid,
|
|
model: Uuid,
|
|
) -> ChannelResult<Option<WsOutEvent>> {
|
|
Self::ensure_room_access(bus, user_id, room).await?;
|
|
let session = db::sqlx::query_as::<_, model::agent::AgentSessionModel>(
|
|
"SELECT id, \"user\", wk, name, description, agent_kind, model_version, \
|
|
system_prompt, temperature, max_output_tokens, enabled, created_by, \
|
|
created_at, updated_at, deleted_at \
|
|
FROM agent_session WHERE id = $1 AND deleted_at IS NULL",
|
|
)
|
|
.bind(model)
|
|
.fetch_one(bus.inner.db.reader())
|
|
.await
|
|
.map_err(|e| match e {
|
|
db::sqlx::Error::RowNotFound => ChannelError::RoomNotFound,
|
|
other => ChannelError::Database(other),
|
|
})?;
|
|
db::sqlx::query_as::<_, model::room::RoomAiModel>(
|
|
"INSERT INTO room_ai (room, agent_session, enabled, auto_reply, created_by, created_at, updated_at) \
|
|
VALUES ($1, $2, true, false, $3, now(), now()) \
|
|
ON CONFLICT (room, agent_session) DO UPDATE SET enabled = true, updated_at = now() \
|
|
RETURNING room, agent_session, enabled, auto_reply, created_by, created_at, updated_at",
|
|
)
|
|
.bind(room)
|
|
.bind(model)
|
|
.bind(user_id)
|
|
.fetch_one(bus.inner.db.writer())
|
|
.await?;
|
|
let ai_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room));
|
|
let data = ai::AiAgentJoinedService {
|
|
room: ai_room,
|
|
agent: AgentInfo {
|
|
id: model,
|
|
name: session.name.clone(),
|
|
agent_type: session.agent_kind.clone(),
|
|
model_name: None,
|
|
},
|
|
joined_at: Utc::now(),
|
|
};
|
|
bus.publish_room_event(room, "ai.agent_joined", &data)
|
|
.await?;
|
|
|
|
Ok(Some(WsOutEvent::AiAgentJoined { room: data.room.clone(), data }))
|
|
}
|
|
|
|
pub(super) async fn ai_delete(
|
|
bus: &ChannelBus,
|
|
user_id: Uuid,
|
|
room: Uuid,
|
|
agent_id: Uuid,
|
|
) -> ChannelResult<Option<WsOutEvent>> {
|
|
Self::ensure_room_access(bus, user_id, room).await?;
|
|
let session = db::sqlx::query_as::<_, model::agent::AgentSessionModel>(
|
|
"SELECT id, \"user\", wk, name, description, agent_kind, model_version, \
|
|
system_prompt, temperature, max_output_tokens, enabled, created_by, \
|
|
created_at, updated_at, deleted_at \
|
|
FROM agent_session WHERE id = $1 AND deleted_at IS NULL",
|
|
)
|
|
.bind(agent_id)
|
|
.fetch_optional(bus.inner.db.reader())
|
|
.await?;
|
|
|
|
let result = db::sqlx::query(
|
|
"DELETE FROM room_ai WHERE room = $1 AND agent_session = $2",
|
|
)
|
|
.bind(room)
|
|
.bind(agent_id)
|
|
.execute(bus.inner.db.writer())
|
|
.await?;
|
|
|
|
if result.rows_affected() == 0 {
|
|
return Err(ChannelError::RoomNotFound);
|
|
}
|
|
let ai_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room));
|
|
let agent_info = session.map(|s| AgentInfo {
|
|
id: s.id,
|
|
name: s.name,
|
|
agent_type: s.agent_kind,
|
|
model_name: None,
|
|
}).unwrap_or_else(|| AgentInfo::unknown(agent_id));
|
|
|
|
let data = ai::AiAgentLeftService {
|
|
room: ai_room,
|
|
agent: agent_info,
|
|
left_at: Utc::now(),
|
|
};
|
|
bus.publish_room_event(room, "ai.agent_left", &data).await?;
|
|
|
|
Ok(Some(WsOutEvent::AiAgentLeft { room: data.room.clone(), data }))
|
|
}
|
|
|
|
pub(super) async fn ai_stop(
|
|
bus: &ChannelBus,
|
|
user_id: Uuid,
|
|
room: Uuid,
|
|
) -> ChannelResult<Option<WsOutEvent>> {
|
|
Self::ensure_room_access(bus, user_id, room).await?;
|
|
bus.publish_room_event(
|
|
room,
|
|
"ai.stop",
|
|
&serde_json::json!({"stopped_by": user_id}),
|
|
)
|
|
.await?;
|
|
Ok(None)
|
|
}
|
|
}
|