138 lines
3.9 KiB
Rust
138 lines
3.9 KiB
Rust
use serde::{Deserialize, Serialize};
|
|
use uuid::Uuid;
|
|
|
|
use crate::ChannelResult;
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct SearchQuery {
|
|
pub query: String,
|
|
pub room_id: Option<Uuid>,
|
|
pub user_id: Option<Uuid>,
|
|
pub limit: u64,
|
|
pub offset: u64,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct SearchResult {
|
|
pub total: u64,
|
|
pub hits: Vec<SearchHit>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct SearchHit {
|
|
pub message_id: Uuid,
|
|
pub room_id: Uuid,
|
|
pub content: String,
|
|
pub highlighted: String,
|
|
pub sender_id: Uuid,
|
|
pub send_at: chrono::DateTime<chrono::Utc>,
|
|
pub score: f64,
|
|
}
|
|
|
|
pub struct SearchEngine {
|
|
db: db::AppDatabase,
|
|
}
|
|
|
|
impl SearchEngine {
|
|
pub fn new(db: db::AppDatabase) -> Self {
|
|
Self { db }
|
|
}
|
|
|
|
pub async fn search(
|
|
&self,
|
|
query: SearchQuery,
|
|
) -> ChannelResult<SearchResult> {
|
|
let search_term = format!("%{}%", escape_like(&query.query));
|
|
let room_filter = query.room_id;
|
|
let user_filter = query.user_id;
|
|
|
|
let count: (i64,) = db::sqlx::query_as(
|
|
"SELECT COUNT(*) FROM room_message \
|
|
WHERE ($1::uuid IS NULL OR room = $1) \
|
|
AND ($2::uuid IS NULL OR author = $2) \
|
|
AND content LIKE $3 ESCAPE '\\' \
|
|
AND deleted_at IS NULL",
|
|
)
|
|
.bind(room_filter)
|
|
.bind(user_filter)
|
|
.bind(&search_term)
|
|
.fetch_one(self.db.reader())
|
|
.await?;
|
|
|
|
let total = count.0 as u64;
|
|
|
|
let messages = db::sqlx::query_as::<_, model::room::RoomMessageModel>(
|
|
"SELECT id, room, seq, thread, parent, author, content, content_type, pinned, \
|
|
system_type, metadata, edited_at, created_at, updated_at, deleted_at \
|
|
FROM room_message \
|
|
WHERE ($1::uuid IS NULL OR room = $1) \
|
|
AND ($2::uuid IS NULL OR author = $2) \
|
|
AND content LIKE $3 ESCAPE '\\' \
|
|
AND deleted_at IS NULL \
|
|
ORDER BY created_at DESC LIMIT $4 OFFSET $5"
|
|
)
|
|
.bind(room_filter)
|
|
.bind(user_filter)
|
|
.bind(&search_term)
|
|
.bind(query.limit as i64)
|
|
.bind(query.offset as i64)
|
|
.fetch_all(self.db.reader())
|
|
.await?;
|
|
|
|
let hits: Vec<SearchHit> = messages
|
|
.into_iter()
|
|
.map(|m| SearchHit {
|
|
message_id: m.id,
|
|
room_id: m.room,
|
|
content: m.content.clone(),
|
|
highlighted: highlight_text(&m.content, &query.query),
|
|
sender_id: m.author,
|
|
send_at: m.created_at,
|
|
score: 1.0,
|
|
})
|
|
.collect();
|
|
|
|
Ok(SearchResult { total, hits })
|
|
}
|
|
}
|
|
/// Escapes `input` for use inside a SQL `LIKE` pattern whose escape character
/// is a backslash.
///
/// Every `\`, `%` and `_` is prefixed with `\`, so the input matches only
/// itself; all other characters pass through untouched.
fn escape_like(input: &str) -> String {
    input
        .chars()
        .fold(String::with_capacity(input.len()), |mut escaped, c| {
            // LIKE metacharacters (and the escape char itself) need a prefix.
            if matches!(c, '\\' | '%' | '_') {
                escaped.push('\\');
            }
            escaped.push(c);
            escaped
        })
}
|
|
/// Wraps the first case-insensitive occurrence of `query` in `content` with
/// `<mark>…</mark>`.
///
/// Returns `content` unchanged when `query` is empty or does not occur.
///
/// Case folding is done character by character with [`char::to_lowercase`].
/// Every byte of the lowercased text remembers which original character
/// produced it. This keeps the highlight aligned with `content` even when
/// lowercasing changes a character's length (e.g. `'İ'` lowercases to two
/// characters). A match that begins or ends inside such an expansion is
/// widened to cover the whole original character. Context-sensitive rules
/// such as the Greek final sigma are not applied to either side.
///
/// NOTE(review): `content` is not HTML-escaped here; if the result is
/// rendered as HTML, the consumer must sanitize it.
fn highlight_text(content: &str, query: &str) -> String {
    // An empty needle "matches" at offset 0 and would emit a stray `<mark></mark>`.
    if query.is_empty() {
        return content.to_string();
    }

    let needle: String = query.chars().flat_map(char::to_lowercase).collect();

    // Lowercased copy of `content`. For each of its bytes, `spans` stores the
    // byte range `(start, end)` of the original character it came from.
    let mut haystack = String::with_capacity(content.len());
    let mut spans: Vec<(usize, usize)> = Vec::with_capacity(content.len());
    for (start, ch) in content.char_indices() {
        haystack.extend(ch.to_lowercase());
        spans.resize(haystack.len(), (start, start + ch.len_utf8()));
    }

    let hit = match haystack.find(needle.as_str()) {
        Some(pos) => pos,
        None => return content.to_string(),
    };

    // `needle` is non-empty (each char lowercases to at least one char), so
    // the last matched byte is `hit + needle.len() - 1` and lies inside `spans`.
    let (mark_start, _) = spans[hit];
    let (_, mark_end) = spans[hit + needle.len() - 1];

    format!(
        "{}<mark>{}</mark>{}",
        &content[..mark_start],
        &content[mark_start..mark_end],
        &content[mark_end..]
    )
}
|