169 lines
4.6 KiB
Rust
169 lines
4.6 KiB
Rust
use ai::error::AiResult;
|
|
use ai::memory::{MemoryEntry, MemoryProvider};
|
|
use async_trait::async_trait;
|
|
use chrono::Utc;
|
|
use db::sqlx;
|
|
use uuid::Uuid;
|
|
#[derive(Clone)]
|
|
pub struct SimpleMemoryProvider {
|
|
db: db::AppDatabase,
|
|
}
|
|
|
|
impl SimpleMemoryProvider {
|
|
pub fn new(db: db::AppDatabase) -> Self {
|
|
Self { db }
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl MemoryProvider for SimpleMemoryProvider {
|
|
fn name(&self) -> &'static str {
|
|
"simple"
|
|
}
|
|
|
|
async fn save(
|
|
&self,
|
|
session_id: Uuid,
|
|
key: &str,
|
|
value: &str,
|
|
importance: i32,
|
|
) -> AiResult<()> {
|
|
let now = Utc::now();
|
|
sqlx::query(
|
|
"INSERT INTO agent_long_term_memory \
|
|
(id, session, key, value, importance, last_used_at, created_at, updated_at) \
|
|
VALUES ($1, $2, $3, $4, $5, $6, $6, $6) \
|
|
ON CONFLICT ON CONSTRAINT idx_agent_ltm_session_key \
|
|
DO UPDATE SET value = $4, importance = $5, last_used_at = $6, updated_at = $6",
|
|
)
|
|
.bind(Uuid::now_v7())
|
|
.bind(session_id)
|
|
.bind(key)
|
|
.bind(value)
|
|
.bind(importance)
|
|
.bind(now)
|
|
.execute(self.db.writer())
|
|
.await
|
|
.map_err(|e| {
|
|
ai::error::AiError::Response(format!("memory save error: {e}"))
|
|
})?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn recall(
|
|
&self,
|
|
session_id: Uuid,
|
|
query: &str,
|
|
limit: usize,
|
|
) -> AiResult<Vec<MemoryEntry>> {
|
|
use db::sqlx::FromRow;
|
|
|
|
#[derive(Debug, FromRow)]
|
|
struct Row {
|
|
key: String,
|
|
value: String,
|
|
importance: i32,
|
|
last_used_at: Option<chrono::DateTime<Utc>>,
|
|
}
|
|
|
|
let rows: Vec<Row> = sqlx::query_as(
|
|
"SELECT key, value, importance, last_used_at \
|
|
FROM agent_long_term_memory \
|
|
WHERE session = $1 \
|
|
AND deleted_at IS NULL \
|
|
AND (value ILIKE $2 OR key ILIKE $2) \
|
|
ORDER BY importance DESC, last_used_at DESC NULLS LAST \
|
|
LIMIT $3",
|
|
)
|
|
.bind(session_id)
|
|
.bind(format!("%{query}%"))
|
|
.bind(limit as i64)
|
|
.fetch_all(self.db.reader())
|
|
.await
|
|
.map_err(|e| {
|
|
ai::error::AiError::Response(format!("memory recall error: {e}"))
|
|
})?;
|
|
|
|
let entries: Vec<MemoryEntry> = rows
|
|
.into_iter()
|
|
.map(|r| MemoryEntry {
|
|
key: r.key,
|
|
value: r.value,
|
|
importance: r.importance,
|
|
last_used_at: r.last_used_at.map(|dt| dt.to_rfc3339()),
|
|
})
|
|
.collect();
|
|
if !entries.is_empty() {
|
|
let now = Utc::now();
|
|
let _ = sqlx::query(
|
|
"UPDATE agent_long_term_memory \
|
|
SET last_used_at = $1, updated_at = $1 \
|
|
WHERE session = $2 AND value ILIKE $3 AND deleted_at IS NULL",
|
|
)
|
|
.bind(now)
|
|
.bind(session_id)
|
|
.bind(format!("%{query}%"))
|
|
.execute(self.db.writer())
|
|
.await;
|
|
}
|
|
|
|
Ok(entries)
|
|
}
|
|
|
|
async fn forget(&self, session_id: Uuid, key: &str) -> AiResult<()> {
|
|
let now = Utc::now();
|
|
sqlx::query(
|
|
"UPDATE agent_long_term_memory \
|
|
SET deleted_at = $1, updated_at = $1 \
|
|
WHERE session = $2 AND key = $3 AND deleted_at IS NULL",
|
|
)
|
|
.bind(now)
|
|
.bind(session_id)
|
|
.bind(key)
|
|
.execute(self.db.writer())
|
|
.await
|
|
.map_err(|e| {
|
|
ai::error::AiError::Response(format!("memory forget error: {e}"))
|
|
})?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn build_context_block(&self, session_id: Uuid) -> AiResult<String> {
|
|
use db::sqlx::FromRow;
|
|
|
|
#[derive(Debug, FromRow)]
|
|
struct Entry {
|
|
key: String,
|
|
value: String,
|
|
}
|
|
|
|
let rows: Vec<Entry> = sqlx::query_as(
|
|
"SELECT key, value \
|
|
FROM agent_long_term_memory \
|
|
WHERE session = $1 AND deleted_at IS NULL \
|
|
ORDER BY importance DESC, last_used_at DESC NULLS LAST \
|
|
LIMIT 20",
|
|
)
|
|
.bind(session_id)
|
|
.fetch_all(self.db.reader())
|
|
.await
|
|
.map_err(|e| {
|
|
ai::error::AiError::Response(format!("memory context error: {e}"))
|
|
})?;
|
|
|
|
if rows.is_empty() {
|
|
return Ok(String::new());
|
|
}
|
|
|
|
let mut block = String::from("<user_memories>\n");
|
|
for row in &rows {
|
|
block.push_str(&format!("- {}: {}\n", row.key, row.value));
|
|
}
|
|
block.push_str("</user_memories>");
|
|
|
|
Ok(block)
|
|
}
|
|
}
|