gitdataai/lib/service/agent/memory_provider.rs
2026-05-30 01:38:40 +08:00

169 lines
4.6 KiB
Rust

use ai::error::AiResult;
use ai::memory::{MemoryEntry, MemoryProvider};
use async_trait::async_trait;
use chrono::Utc;
use db::sqlx;
use uuid::Uuid;
/// Long-term memory provider that persists key/value memories in the
/// `agent_long_term_memory` table of the application database.
#[derive(Clone)]
pub struct SimpleMemoryProvider {
/// Database handle; writes go through `writer()` and reads through `reader()`.
db: db::AppDatabase,
}
impl SimpleMemoryProvider {
pub fn new(db: db::AppDatabase) -> Self {
Self { db }
}
}
#[async_trait]
impl MemoryProvider for SimpleMemoryProvider {
fn name(&self) -> &'static str {
"simple"
}
async fn save(
&self,
session_id: Uuid,
key: &str,
value: &str,
importance: i32,
) -> AiResult<()> {
let now = Utc::now();
sqlx::query(
"INSERT INTO agent_long_term_memory \
(id, session, key, value, importance, last_used_at, created_at, updated_at) \
VALUES ($1, $2, $3, $4, $5, $6, $6, $6) \
ON CONFLICT ON CONSTRAINT idx_agent_ltm_session_key \
DO UPDATE SET value = $4, importance = $5, last_used_at = $6, updated_at = $6",
)
.bind(Uuid::now_v7())
.bind(session_id)
.bind(key)
.bind(value)
.bind(importance)
.bind(now)
.execute(self.db.writer())
.await
.map_err(|e| {
ai::error::AiError::Response(format!("memory save error: {e}"))
})?;
Ok(())
}
async fn recall(
&self,
session_id: Uuid,
query: &str,
limit: usize,
) -> AiResult<Vec<MemoryEntry>> {
use db::sqlx::FromRow;
#[derive(Debug, FromRow)]
struct Row {
key: String,
value: String,
importance: i32,
last_used_at: Option<chrono::DateTime<Utc>>,
}
let rows: Vec<Row> = sqlx::query_as(
"SELECT key, value, importance, last_used_at \
FROM agent_long_term_memory \
WHERE session = $1 \
AND deleted_at IS NULL \
AND (value ILIKE $2 OR key ILIKE $2) \
ORDER BY importance DESC, last_used_at DESC NULLS LAST \
LIMIT $3",
)
.bind(session_id)
.bind(format!("%{query}%"))
.bind(limit as i64)
.fetch_all(self.db.reader())
.await
.map_err(|e| {
ai::error::AiError::Response(format!("memory recall error: {e}"))
})?;
let entries: Vec<MemoryEntry> = rows
.into_iter()
.map(|r| MemoryEntry {
key: r.key,
value: r.value,
importance: r.importance,
last_used_at: r.last_used_at.map(|dt| dt.to_rfc3339()),
})
.collect();
if !entries.is_empty() {
let now = Utc::now();
let _ = sqlx::query(
"UPDATE agent_long_term_memory \
SET last_used_at = $1, updated_at = $1 \
WHERE session = $2 AND value ILIKE $3 AND deleted_at IS NULL",
)
.bind(now)
.bind(session_id)
.bind(format!("%{query}%"))
.execute(self.db.writer())
.await;
}
Ok(entries)
}
async fn forget(&self, session_id: Uuid, key: &str) -> AiResult<()> {
let now = Utc::now();
sqlx::query(
"UPDATE agent_long_term_memory \
SET deleted_at = $1, updated_at = $1 \
WHERE session = $2 AND key = $3 AND deleted_at IS NULL",
)
.bind(now)
.bind(session_id)
.bind(key)
.execute(self.db.writer())
.await
.map_err(|e| {
ai::error::AiError::Response(format!("memory forget error: {e}"))
})?;
Ok(())
}
async fn build_context_block(&self, session_id: Uuid) -> AiResult<String> {
use db::sqlx::FromRow;
#[derive(Debug, FromRow)]
struct Entry {
key: String,
value: String,
}
let rows: Vec<Entry> = sqlx::query_as(
"SELECT key, value \
FROM agent_long_term_memory \
WHERE session = $1 AND deleted_at IS NULL \
ORDER BY importance DESC, last_used_at DESC NULLS LAST \
LIMIT 20",
)
.bind(session_id)
.fetch_all(self.db.reader())
.await
.map_err(|e| {
ai::error::AiError::Response(format!("memory context error: {e}"))
})?;
if rows.is_empty() {
return Ok(String::new());
}
let mut block = String::from("<user_memories>\n");
for row in &rows {
block.push_str(&format!("- {}: {}\n", row.key, row.value));
}
block.push_str("</user_memories>");
Ok(block)
}
}