feat: 1.0

This commit is contained in:
zhenyi 2026-05-30 01:38:40 +08:00
parent e1330451a5
commit a835610737
1342 changed files with 111368 additions and 0 deletions

4
.clippy.toml Normal file
View File

@ -0,0 +1,4 @@
# Clippy configuration
doc-valid-idents = ["GitHub", "GitLab", "TypeScript", "WebSocket", "PostgreSQL", "Redis", "OpenAI"]
avoid-breaking-exported-api = true
disallowed-types = []

42
.editorconfig Normal file
View File

@ -0,0 +1,42 @@
root = true
[*]
charset = utf-8
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true
[*.{js,ts,jsx,tsx,json,jsonc,md,yaml,yml,toml}]
indent_style = space
indent_size = 2
[*.py]
indent_style = space
indent_size = 4
[*.go]
indent_style = tab
indent_size = unset
tab_width = 8
[*.rs]
indent_style = space
indent_size = 4
[Makefile]
indent_style = tab
indent_size = unset
[Dockerfile]
indent_style = space
indent_size = 2
[*.md]
trim_trailing_whitespace = false
[*.{yml,yaml}]
indent_style = space
indent_size = 2
[*.toml]
indent_style = space
indent_size = 2

28
app/email/Cargo.toml Normal file
View File

@ -0,0 +1,28 @@
[package]
name = "app-email"
version.workspace = true
edition.workspace = true
authors.workspace = true
description.workspace = true
repository.workspace = true
readme.workspace = true
homepage.workspace = true
license.workspace = true
keywords.workspace = true
categories.workspace = true
documentation.workspace = true
[[bin]]
name = "email-service"
path = "src/main.rs"
[dependencies]
anyhow = { workspace = true }
config = { workspace = true }
email = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal"] }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter", "json"] }
[lints]
workspace = true

24
app/email/src/context.rs Normal file
View File

@ -0,0 +1,24 @@
use config::AppConfig;
use tracing_subscriber::EnvFilter;
pub struct AppContext {
pub config: AppConfig,
}
impl AppContext {
pub fn init() -> anyhow::Result<Self> {
let config = AppConfig::load();
init_tracing(&config)?;
Ok(Self { config })
}
}
fn init_tracing(config: &AppConfig) -> anyhow::Result<()> {
let level = config.log_level()?;
let filter = EnvFilter::try_new(&level)?;
tracing_subscriber::fmt()
.with_env_filter(filter)
.with_target(false)
.init();
Ok(())
}

22
app/email/src/main.rs Normal file
View File

@ -0,0 +1,22 @@
mod context;
use context::AppContext;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let ctx = AppContext::init()?;
tracing::info!("email service starting");
tokio::select! {
result = email::EmailWorker::start(&ctx.config) => {
if let Err(e) = result {
tracing::error!("email worker exited with error: {}", e);
}
}
_ = tokio::signal::ctrl_c() => {
tracing::info!("shutdown signal received, stopping email service");
}
}
Ok(())
}

50
app/gitdata/Cargo.toml Normal file
View File

@ -0,0 +1,50 @@
[package]
name = "app-gitdata"
version.workspace = true
edition.workspace = true
authors.workspace = true
description.workspace = true
repository.workspace = true
readme.workspace = true
homepage.workspace = true
license.workspace = true
keywords.workspace = true
categories.workspace = true
documentation.workspace = true
[[bin]]
name = "gitdata"
path = "src/main.rs"
[[bin]]
name = "gen-openapi"
path = "src/bin/gen-openapi.rs"
[dependencies]
anyhow = { workspace = true }
config = { workspace = true }
cache = { workspace = true }
db = { workspace = true }
service = { workspace = true }
session = { workspace = true }
api = { workspace = true }
email = { workspace = true }
storage = { workspace = true }
git = { workspace = true }
model = { workspace = true }
channel = { workspace = true }
socketio = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal"] }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter", "json"] }
actix-web = { workspace = true, features = ["cookies", "secure-cookies"] }
actix-ws = { workspace = true }
tonic = { workspace = true, features = ["transport"] }
deadpool-redis = { workspace = true }
redis = { workspace = true, features = ["cluster-async", "aio", "tokio-comp", "connection-manager", "cluster"] }
sqlx = { workspace = true, features = ["postgres", "runtime-tokio"] }
serde_json = { workspace = true }
uuid = { workspace = true, features = ["v4", "v7", "serde"] }
[lints]
workspace = true

View File

@ -0,0 +1,7 @@
use std::fs;
fn main() {
let json = api::openapi::openapi_json();
fs::write("openapi.json", json).expect("Failed to write openapi.json");
println!("openapi.json generated successfully");
}

127
app/gitdata/src/context.rs Normal file
View File

@ -0,0 +1,127 @@
use actix_web::cookie::Key;
use cache::{AppCache, AppCacheConfig};
use config::AppConfig;
use db::database::AppDatabase;
use deadpool_redis::{
PoolConfig, Runtime, Timeouts,
cluster::{Config, Pool as RedisPool},
};
use email::AppEmail;
use service::AppService;
use session::storage::RedisClusterSessionStore;
use storage::{AppStorage, AppStorageConfig};
use tonic::transport::Channel;
use channel::{ChannelBus, ChannelBusConfig};
use socketio::SocketIo;
pub struct AppContext {
pub config: AppConfig,
pub service: AppService,
pub session_store: RedisClusterSessionStore,
pub session_key: Key,
pub channel_bus: ChannelBus,
}
impl AppContext {
pub async fn init() -> anyhow::Result<Self> {
let config = AppConfig::load();
init_tracing(&config)?;
tracing::info!("initializing database");
let db = AppDatabase::init(&config).await?;
tracing::info!("initializing cache");
let cache_config = AppCacheConfig::try_from(&config)?;
let cache = AppCache::init(cache_config).await?;
tracing::info!("initializing storage");
let storage_config = AppStorageConfig::try_from(&config)?;
let storage = AppStorage::init(storage_config).await?;
tracing::info!("initializing email");
let email = AppEmail::init(&config).await?;
tracing::info!("connecting to git RPC");
let rpc_addr = config.git_rpc_addr()?;
let rpc_port = config.git_rpc_port()?;
let git_channel =
Channel::from_shared(format!("http://{}:{}", rpc_addr, rpc_port))
.expect("invalid gRPC endpoint")
.connect()
.await?;
let service = AppService {
db,
cache,
email,
storage,
config: config.clone(),
git: git_channel,
redis_pool: init_redis_pool(&config)?,
};
tracing::info!("initializing session store");
let redis_urls = config.redis_urls()?;
let session_store = RedisClusterSessionStore::new(redis_urls).await?;
tracing::info!("initializing session key");
let secret = config.session_secret()?;
let session_key = Key::from(secret.as_bytes());
tracing::info!("initializing channel bus");
let io = SocketIo::new();
let channel_config = ChannelBusConfig {
namespace: "/channel".to_owned(),
signing_secret: Some(secret.clone()),
..Default::default()
};
let channel_bus = ChannelBus::new(
service.db.clone(),
service.cache.clone(),
io,
channel_config,
);
channel_bus.attach().await?;
Ok(Self {
config,
service,
session_store,
session_key,
channel_bus,
})
}
}
fn init_tracing(config: &AppConfig) -> anyhow::Result<()> {
let level = config.log_level()?;
let filter = tracing_subscriber::EnvFilter::try_new(&level)?;
tracing_subscriber::fmt()
.with_env_filter(filter)
.with_target(false)
.init();
Ok(())
}
fn init_redis_pool(config: &AppConfig) -> anyhow::Result<RedisPool> {
let redis_urls = config.redis_urls()?;
let pool_size = config.redis_pool_size()?;
let connect_timeout = config.redis_connect_timeout()?;
let acquire_timeout = config.redis_acquire_timeout()?;
let mut pool_config = PoolConfig::new(pool_size as usize);
pool_config.timeouts = Timeouts {
wait: Some(std::time::Duration::from_secs(acquire_timeout)),
create: Some(std::time::Duration::from_secs(connect_timeout)),
recycle: Some(std::time::Duration::from_secs(connect_timeout)),
};
let cfg = Config {
urls: Some(redis_urls),
connections: None,
pool: Some(pool_config),
read_from_replicas: false,
};
Ok(cfg.create_pool(Some(Runtime::Tokio1))?)
}

111
app/gitdata/src/main.rs Normal file
View File

@ -0,0 +1,111 @@
mod context;
mod shutdown;
use std::time::Instant;
use actix_web::{App, dev::Service};
use context::AppContext;
use service::ai::sync::spawn_model_sync_loop;
const REQUEST_LOG_EXCLUDED_PATHS: &[&str] = &[
"/health",
"/live",
"/ready",
"/metrics",
"/favicon.ico",
"/robots.txt",
];
fn should_log_request(path: &str) -> bool {
!REQUEST_LOG_EXCLUDED_PATHS.contains(&path)
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let ctx = AppContext::init().await?;
let api_port = ctx.config.api_port()?;
tracing::info!("GitDataAI API service starting on 0.0.0.0:{}", api_port);
let service = ctx.service.clone();
let session_store = ctx.session_store.clone();
let session_key = ctx.session_key.clone();
let channel_bus = ctx.channel_bus.clone();
let srv = actix_web::HttpServer::new(move || {
let session_middleware = session::SessionMiddleware::builder(
session_store.clone(),
session_key.clone(),
)
.cookie_secure(false)
.cookie_name("id".to_string())
.session_lifecycle(
session::config::PersistentSession::default()
.session_ttl(actix_web::cookie::time::Duration::days(30)),
)
.build();
App::new()
.app_data(actix_web::web::Data::new(service.clone()))
.app_data(actix_web::web::Data::new(channel_bus.clone()))
.wrap_fn(|req, srv| {
let should_log = should_log_request(req.path());
let method = req.method().clone();
let path = req.path().to_owned();
let peer_addr =
req.connection_info().peer_addr().map(str::to_owned);
let started_at = Instant::now();
let fut = srv.call(req);
async move {
match fut.await {
Ok(res) => {
if should_log {
tracing::info!(
method = %method,
path = %path,
status = res.status().as_u16(),
elapsed_ms = started_at.elapsed().as_millis(),
peer_addr = peer_addr.as_deref().unwrap_or("-"),
"http request"
);
}
Ok(res)
}
Err(err) => {
if should_log {
tracing::warn!(
method = %method,
path = %path,
elapsed_ms = started_at.elapsed().as_millis(),
peer_addr = peer_addr.as_deref().unwrap_or("-"),
error = %err,
"http request failed"
);
}
Err(err)
}
}
}
})
.wrap(session_middleware)
.configure(|cfg| api::configure(cfg, channel_bus.clone()))
})
.bind(format!("0.0.0.0:{}", api_port))?;
spawn_model_sync_loop(ctx.service.clone());
let server = srv.run();
tracing::info!("API server is running");
tokio::select! {
_ = server => {
tracing::info!("API server stopped");
}
_ = shutdown::wait_for_shutdown_signal() => {
tracing::info!("shutdown signal received, stopping gitdata API service");
}
}
Ok(())
}

View File

@ -0,0 +1,25 @@
pub async fn wait_for_shutdown_signal() {
let ctrl_c = async {
tokio::signal::ctrl_c()
.await
.expect("failed to listen for ctrl_c event");
};
#[cfg(unix)]
let terminate = async {
tokio::signal::unix::signal(
tokio::signal::unix::SignalKind::terminate(),
)
.expect("failed to listen for SIGTERM")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
}
}

32
app/gitpod/Cargo.toml Normal file
View File

@ -0,0 +1,32 @@
[package]
name = "app-gitpod"
version.workspace = true
edition.workspace = true
authors.workspace = true
description.workspace = true
repository.workspace = true
readme.workspace = true
homepage.workspace = true
license.workspace = true
keywords.workspace = true
categories.workspace = true
documentation.workspace = true
[[bin]]
name = "gitpod"
path = "src/main.rs"
[dependencies]
anyhow = { workspace = true }
config = { workspace = true }
cache = { workspace = true }
db = { workspace = true }
git = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal"] }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter", "json"] }
deadpool-redis = { workspace = true }
redis = { workspace = true, features = ["cluster-async", "aio", "tokio-comp"] }
[lints]
workspace = true

65
app/gitpod/src/context.rs Normal file
View File

@ -0,0 +1,65 @@
use std::time::Duration;
use cache::{AppCache, AppCacheConfig};
use config::AppConfig;
use db::database::AppDatabase;
use deadpool_redis::{PoolConfig, Runtime, Timeouts, cluster::Config};
pub struct AppContext {
pub config: AppConfig,
pub db: AppDatabase,
pub cache: AppCache,
pub redis_pool: deadpool_redis::cluster::Pool,
}
impl AppContext {
pub async fn init() -> anyhow::Result<Self> {
let config = AppConfig::load();
init_tracing(&config)?;
tracing::info!("initializing database");
let db = AppDatabase::init(&config).await?;
tracing::info!("initializing cache");
let cache_config = AppCacheConfig::try_from(&config)?;
let cache = AppCache::init(cache_config).await?;
tracing::info!("initializing redis pool");
let redis_urls = config.redis_urls()?;
let pool_size = config.redis_pool_size()?;
let connect_timeout = config.redis_connect_timeout()?;
let acquire_timeout = config.redis_acquire_timeout()?;
let mut pool_config = PoolConfig::new(pool_size as usize);
pool_config.timeouts = Timeouts {
wait: Some(Duration::from_secs(acquire_timeout)),
create: Some(Duration::from_secs(connect_timeout)),
recycle: Some(Duration::from_secs(connect_timeout)),
};
let cfg = Config {
urls: Some(redis_urls),
connections: None,
pool: Some(pool_config),
read_from_replicas: false,
};
let redis_pool = cfg.create_pool(Some(Runtime::Tokio1))?;
Ok(Self {
config,
db,
cache,
redis_pool,
})
}
}
fn init_tracing(config: &AppConfig) -> anyhow::Result<()> {
let level = config.log_level()?;
let filter = tracing_subscriber::EnvFilter::try_new(&level)?;
tracing_subscriber::fmt()
.with_env_filter(filter)
.with_target(false)
.init();
Ok(())
}

82
app/gitpod/src/main.rs Normal file
View File

@ -0,0 +1,82 @@
mod context;
mod shutdown;
use context::AppContext;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let ctx = AppContext::init().await?;
let http_port = ctx.config.git_http_port()?;
let ssh_port = ctx.config.ssh_port()?;
let rpc_addr = ctx.config.git_rpc_addr()?;
let rpc_port = ctx.config.git_rpc_port()?;
tracing::info!(
"gitpod service starting (HTTP:{} / SSH:{} / gRPC:{}:{})",
http_port,
ssh_port,
rpc_addr,
rpc_port
);
let http_task = tokio::spawn(git::http::run_http(
ctx.config.clone(),
ctx.db.clone(),
ctx.cache.clone(),
ctx.redis_pool.clone(),
));
let ssh_task = tokio::spawn(git::ssh::run_ssh(
ctx.config.clone(),
ctx.db.clone(),
ctx.cache.clone(),
ctx.redis_pool.clone(),
));
let rpc_addr_parsed =
format!("{}:{}", rpc_addr, rpc_port).parse::<std::net::SocketAddr>()?;
let sync_service =
git::sync::ReceiveSyncService::new(ctx.redis_pool.clone());
let git_server = git::rpc::server::GitServer::new(
rpc_addr_parsed,
ctx.db.clone(),
ctx.cache.clone(),
sync_service,
);
let rpc_task = tokio::spawn(async move {
git_server
.serve()
.await
.map_err(|e| anyhow::anyhow!("{}", e))
});
tokio::select! {
result = http_task => {
match result {
Ok(Ok(())) => tracing::info!("HTTP server stopped"),
Ok(Err(e)) => tracing::error!("HTTP server error: {}", e),
Err(e) => tracing::error!("HTTP task panicked: {}", e),
}
}
result = ssh_task => {
match result {
Ok(Ok(())) => tracing::info!("SSH server stopped"),
Ok(Err(e)) => tracing::error!("SSH server error: {}", e),
Err(e) => tracing::error!("SSH task panicked: {}", e),
}
}
result = rpc_task => {
match result {
Ok(Ok(())) => tracing::info!("gRPC server stopped"),
Ok(Err(e)) => tracing::error!("gRPC server error: {}", e),
Err(e) => tracing::error!("gRPC task panicked: {}", e),
}
}
_ = shutdown::wait_for_shutdown_signal() => {
tracing::info!("shutdown signal received, stopping gitpod service");
}
}
Ok(())
}

View File

@ -0,0 +1,25 @@
pub async fn wait_for_shutdown_signal() {
let ctrl_c = async {
tokio::signal::ctrl_c()
.await
.expect("failed to listen for ctrl_c event");
};
#[cfg(unix)]
let terminate = async {
tokio::signal::unix::signal(
tokio::signal::unix::SignalKind::terminate(),
)
.expect("failed to listen for SIGTERM")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
}
}

36
app/gitsync/Cargo.toml Normal file
View File

@ -0,0 +1,36 @@
[package]
name = "app-gitsync"
version.workspace = true
edition.workspace = true
authors.workspace = true
description.workspace = true
repository.workspace = true
readme.workspace = true
homepage.workspace = true
license.workspace = true
keywords.workspace = true
categories.workspace = true
documentation.workspace = true
[[bin]]
name = "gitsync"
path = "src/main.rs"
[dependencies]
anyhow = { workspace = true }
config = { workspace = true }
cache = { workspace = true }
db = { workspace = true }
git = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal"] }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter", "json"] }
actix-web = { workspace = true }
deadpool-redis = { workspace = true }
redis = { workspace = true, features = ["cluster-async", "aio", "tokio-comp"] }
uuid = { workspace = true }
serde_json = { workspace = true }
sqlx = { workspace = true, features = ["postgres", "runtime-tokio"] }
[lints]
workspace = true

View File

@ -0,0 +1,65 @@
use std::time::Duration;
use cache::{AppCache, AppCacheConfig};
use config::AppConfig;
use db::database::AppDatabase;
use deadpool_redis::{PoolConfig, Runtime, Timeouts, cluster::Config};
pub struct AppContext {
pub config: AppConfig,
pub db: AppDatabase,
pub cache: AppCache,
pub redis_pool: deadpool_redis::cluster::Pool,
}
impl AppContext {
pub async fn init() -> anyhow::Result<Self> {
let config = AppConfig::load();
init_tracing(&config)?;
tracing::info!("initializing database");
let db = AppDatabase::init(&config).await?;
tracing::info!("initializing cache");
let cache_config = AppCacheConfig::try_from(&config)?;
let cache = AppCache::init(cache_config).await?;
tracing::info!("initializing redis pool");
let redis_urls = config.redis_urls()?;
let pool_size = config.redis_pool_size()?;
let connect_timeout = config.redis_connect_timeout()?;
let acquire_timeout = config.redis_acquire_timeout()?;
let mut pool_config = PoolConfig::new(pool_size as usize);
pool_config.timeouts = Timeouts {
wait: Some(Duration::from_secs(acquire_timeout)),
create: Some(Duration::from_secs(connect_timeout)),
recycle: Some(Duration::from_secs(connect_timeout)),
};
let cfg = Config {
urls: Some(redis_urls),
connections: None,
pool: Some(pool_config),
read_from_replicas: false,
};
let redis_pool = cfg.create_pool(Some(Runtime::Tokio1))?;
Ok(Self {
config,
db,
cache,
redis_pool,
})
}
}
fn init_tracing(config: &AppConfig) -> anyhow::Result<()> {
let level = config.log_level()?;
let filter = tracing_subscriber::EnvFilter::try_new(&level)?;
tracing_subscriber::fmt()
.with_env_filter(filter)
.with_target(false)
.init();
Ok(())
}

99
app/gitsync/src/health.rs Normal file
View File

@ -0,0 +1,99 @@
use std::time::Instant;
use actix_web::dev::Service;
use actix_web::{App, HttpResponse, HttpServer, dev::Server, web};
use cache::AppCache;
use db::database::AppDatabase;
const REQUEST_LOG_EXCLUDED_PATHS: &[&str] = &[
"/health",
"/live",
"/ready",
"/metrics",
"/favicon.ico",
"/robots.txt",
];
fn should_log_request(path: &str) -> bool {
!REQUEST_LOG_EXCLUDED_PATHS.contains(&path)
}
async fn health(
db: web::Data<AppDatabase>,
cache: web::Data<AppCache>,
) -> HttpResponse {
let db_ok = sqlx::query("SELECT 1").execute(db.reader()).await.is_ok();
let cache_ok = cache.ping_cluster().await.is_ok();
if db_ok && cache_ok {
HttpResponse::Ok().json(serde_json::json!({
"status": "ok",
"db": "ok",
"cache": "ok",
}))
} else {
HttpResponse::ServiceUnavailable().json(serde_json::json!({
"status": "unhealthy",
"db": if db_ok { "ok" } else { "error" },
"cache": if cache_ok { "ok" } else { "error" },
}))
}
}
pub fn start_health(
port: u16,
db: AppDatabase,
cache: AppCache,
) -> anyhow::Result<Server> {
tracing::info!("health endpoint starting on 0.0.0.0:{}", port);
let srv = HttpServer::new(move || {
App::new()
.app_data(web::Data::new(db.clone()))
.app_data(web::Data::new(cache.clone()))
.wrap_fn(|req, srv| {
let should_log = should_log_request(req.path());
let method = req.method().clone();
let path = req.path().to_owned();
let peer_addr =
req.connection_info().peer_addr().map(str::to_owned);
let started_at = Instant::now();
let fut = srv.call(req);
async move {
match fut.await {
Ok(res) => {
if should_log {
tracing::info!(
method = %method,
path = %path,
status = res.status().as_u16(),
elapsed_ms = started_at.elapsed().as_millis(),
peer_addr = peer_addr.as_deref().unwrap_or("-"),
"http request"
);
}
Ok(res)
}
Err(err) => {
if should_log {
tracing::warn!(
method = %method,
path = %path,
elapsed_ms = started_at.elapsed().as_millis(),
peer_addr = peer_addr.as_deref().unwrap_or("-"),
error = %err,
"http request failed"
);
}
Err(err)
}
}
}
})
.route("/health", web::get().to(health))
})
.bind(format!("0.0.0.0:{}", port))?;
Ok(srv.run())
}

51
app/gitsync/src/main.rs Normal file
View File

@ -0,0 +1,51 @@
mod context;
mod health;
mod shutdown;
use context::AppContext;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let ctx = AppContext::init().await?;
tracing::info!("gitsync service starting");
let health_port = ctx.config.gitsync_health_port();
let health_server =
health::start_health(health_port, ctx.db.clone(), ctx.cache.clone())?;
let health_handle = health_server.handle();
let health_task = tokio::spawn(health_server);
let sync_service =
git::sync::ReceiveSyncService::new(ctx.redis_pool.clone());
let consumer = git::sync::consumer::SyncConsumer::new(sync_service, 5);
let worker = git::sync::worker::SyncWorker::new(
consumer,
ctx.db.clone(),
ctx.cache.clone(),
ctx.redis_pool.clone(),
ctx.config.clone(),
format!("gitsync-{}", uuid::Uuid::new_v4()),
);
let worker_task = tokio::spawn(async move { worker.run().await });
tokio::select! {
result = health_task => {
match result {
Ok(Ok(())) => tracing::info!("health server stopped"),
Ok(Err(e)) => tracing::error!("health server error: {}", e),
Err(e) => tracing::error!("health task panicked: {}", e),
}
}
_ = worker_task => {
tracing::info!("sync worker stopped");
}
_ = shutdown::wait_for_shutdown_signal() => {
tracing::info!("shutdown signal received, stopping gitsync service");
health_handle.stop(true).await;
}
}
Ok(())
}

View File

@ -0,0 +1,25 @@
pub async fn wait_for_shutdown_signal() {
let ctrl_c = async {
tokio::signal::ctrl_c()
.await
.expect("failed to listen for ctrl_c event");
};
#[cfg(unix)]
let terminate = async {
tokio::signal::unix::signal(
tokio::signal::unix::SignalKind::terminate(),
)
.expect("failed to listen for SIGTERM")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
}
}

137
docker/README.md Normal file
View File

@ -0,0 +1,137 @@
# GitDataAI Docker 配置
## 文件说明
### Dockerfile 文件
| 文件名 | 服务 | 说明 |
|--------|------|------|
| `gitdata.Dockerfile` | GitData API | 主 API 服务 |
| `email.Dockerfile` | Email Service | 邮件发送服务 |
| `gitpod.Dockerfile` | GitPod Service | Git 服务 |
| `gitsync.Dockerfile` | GitSync Service | Git 同步服务 |
| `migrate.Dockerfile` | Database Migration | 数据库迁移工具 |
| `web.Dockerfile` | Web Frontend | React 前端应用 |
### 配置文件
| 文件名 | 说明 |
|--------|------|
| `docker-compose.yml` | 完整的开发环境配置 |
| `nginx.conf` | Nginx 反向代理配置 |
## 快速开始
### 1. 启动完整开发环境
```bash
# 进入 docker 目录
cd docker
# 启动所有服务
docker-compose up -d
# 查看服务状态
docker-compose ps
# 查看日志
docker-compose logs -f
```
### 2. 单独构建服务
```bash
# 构建 GitData API
docker build -f docker/gitdata.Dockerfile -t gitdata-api .
# 构建前端
docker build -f docker/web.Dockerfile -t gitdata-web .
```
### 3. 环境变量配置
创建 `.env` 文件配置环境变量:
```bash
# 数据库配置
POSTGRES_USER=gitdata
POSTGRES_PASSWORD=your_secure_password
POSTGRES_DB=app
# MinIO 配置
MINIO_ROOT_USER=admin
MINIO_ROOT_PASSWORD=your_secure_password
```
## 服务端口
| 服务 | 端口 | 说明 |
|------|------|------|
| Web Frontend | 80 | 前端访问入口 |
| GitData API | 8080 | 主 API 服务 |
| Git HTTP | 5023 | Git HTTP 访问 |
| Git RPC | 5030 | Git RPC 服务 |
| SSH | 5022 | SSH Git 访问 |
| GitPod | 5082 | GitPod 服务 |
| GitSync | 5083 | GitSync 健康检查 |
| PostgreSQL | 5432 | 数据库 |
| Redis | 6379 | 缓存 |
| Qdrant | 6333 | 向量数据库 |
| NATS | 4222 | 消息队列 |
| MinIO | 9000/9001 | 对象存储 |
## 生产环境部署
### 1. 修改环境变量
```bash
# 复制示例配置
cp .env.example .env
# 编辑配置文件,修改密码等敏感信息
vim .env
```
### 2. 启动服务
```bash
# 使用生产配置启动
docker-compose -f docker-compose.yml up -d
# 查看服务状态
docker-compose ps
```
### 3. 数据备份
```bash
# 备份 PostgreSQL
docker exec gitdata-postgres pg_dump -U gitdata app > backup.sql
# 备份 MinIO 数据
docker cp gitdata-minio:/data ./minio-backup
```
## 常见问题
### 1. 服务启动失败
检查日志:
```bash
docker-compose logs <service-name>
```
### 2. 数据库连接失败
确保 PostgreSQL 健康检查通过:
```bash
docker-compose ps postgres
```
### 3. 端口冲突
修改 `docker-compose.yml` 中的端口映射:
```yaml
ports:
- "8081:8080" # 修改宿主机端口
```

131
docker/build.sh Executable file
View File

@ -0,0 +1,131 @@
#!/bin/bash
# GitDataAI Docker Build Script
set -e
# Get version from Cargo.toml
PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
CARGO_VERSION=$(grep -m1 'version' "${PROJECT_ROOT}/Cargo.toml" | sed 's/.*"\(.*\)".*/\1/')
# Configuration
REGISTRY=${REGISTRY:-""}
TAG=${TAG:-"${CARGO_VERSION:-latest}"}
PLATFORM=${PLATFORM:-"linux/amd64"}
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
# Services to build
SERVICES=("gitdata" "email" "gitpod" "gitsync" "migrate" "web")
# Function to print colored output
log_info() {
echo -e "${GREEN}[INFO]${NC} $1"
}
log_warn() {
echo -e "${YELLOW}[WARN]${NC} $1"
}
log_error() {
echo -e "${RED}[ERROR]${NC} $1"
}
# Function to build a service
build_service() {
local service=$1
local dockerfile="${PROJECT_ROOT}/docker/${service}.Dockerfile"
local image_name="gitdata-${service}"
# Add registry prefix if set
if [ -n "$REGISTRY" ]; then
image_name="${REGISTRY}/${image_name}"
fi
log_info "Building ${service}..."
if [ ! -f "$dockerfile" ]; then
log_error "Dockerfile not found: ${dockerfile}"
return 1
fi
docker build \
-f "$dockerfile" \
-t "${image_name}:${TAG}" \
--platform "$PLATFORM" \
"$PROJECT_ROOT"
log_info "Successfully built ${image_name}:${TAG}"
}
# Parse command line arguments
BUILD_SERVICES=()
BUILD_ALL=true
while [[ $# -gt 0 ]]; do
case $1 in
--tag|-t)
TAG="$2"
shift 2
;;
--registry|-r)
REGISTRY="$2"
shift 2
;;
--platform|-p)
PLATFORM="$2"
shift 2
;;
--help|-h)
echo "Usage: $0 [OPTIONS] [SERVICE...]"
echo ""
echo "Options:"
echo " -t, --tag TAG Docker image tag (default: latest)"
echo " -r, --registry REG Docker registry prefix"
echo " -p, --platform PLAT Target platform (default: linux/amd64)"
echo " -h, --help Show this help message"
echo ""
echo "Services:"
echo " gitdata Main API service"
echo " email Email service"
echo " gitpod GitPod service"
echo " gitsync GitSync service"
echo " migrate Database migration"
echo " web Web frontend"
echo ""
echo "Examples:"
echo " $0 # Build all services"
echo " $0 gitdata web # Build specific services"
echo " $0 -t v1.0.0 -r registry.com # Build with custom tag and registry"
exit 0
;;
*)
BUILD_SERVICES+=("$1")
BUILD_ALL=false
shift
;;
esac
done
# Build services
log_info "Starting Docker build..."
log_info "Registry: ${REGISTRY:-none}"
log_info "Tag: ${TAG}"
log_info "Platform: ${PLATFORM}"
if [ "$BUILD_ALL" = true ]; then
log_info "Building all services..."
for service in "${SERVICES[@]}"; do
build_service "$service"
done
else
log_info "Building specified services: ${BUILD_SERVICES[*]}"
for service in "${BUILD_SERVICES[@]}"; do
build_service "$service"
done
fi
log_info "Build completed successfully!"

203
docker/docker-compose.yml Normal file
View File

@ -0,0 +1,203 @@
# GitDataAI Docker Compose
# Full stack deployment configuration
services:
# PostgreSQL Database
postgres:
image: postgres:16-alpine
container_name: gitdata-postgres
environment:
POSTGRES_USER: ${POSTGRES_USER:-gitdata}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-gitdata123}
POSTGRES_DB: ${POSTGRES_DB:-app}
volumes:
- postgres_data:/var/lib/postgresql/data
ports:
- "5432:5432"
healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-gitdata}"]
interval: 10s
timeout: 5s
retries: 5
restart: unless-stopped
# Redis Cluster
redis:
image: redis:7-alpine
container_name: gitdata-redis
ports:
- "6379:6379"
volumes:
- redis_data:/data
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 10s
timeout: 5s
retries: 5
restart: unless-stopped
# Qdrant Vector Database
qdrant:
image: qdrant/qdrant:latest
container_name: gitdata-qdrant
ports:
- "6333:6333"
volumes:
- qdrant_data:/qdrant/storage
restart: unless-stopped
# NATS Message Queue
nats:
image: nats:alpine
container_name: gitdata-nats
ports:
- "4222:4222"
- "8222:8222"
command: "--jetstream"
restart: unless-stopped
# MinIO S3 Storage
minio:
image: minio/minio:latest
container_name: gitdata-minio
command: server /data --console-address ":9001"
environment:
MINIO_ROOT_USER: ${MINIO_ROOT_USER:-admin}
MINIO_ROOT_PASSWORD: ${MINIO_ROOT_PASSWORD:-mysecret123}
ports:
- "9000:9000"
- "9001:9001"
volumes:
- minio_data:/data
restart: unless-stopped
# Database Migration
migrate:
build:
context: ..
dockerfile: docker/migrate.Dockerfile
container_name: gitdata-migrate
environment:
DATABASE_URL: postgres://${POSTGRES_USER:-gitdata}:${POSTGRES_PASSWORD:-gitdata123}@postgres:5432/${POSTGRES_DB:-app}
depends_on:
postgres:
condition: service_healthy
restart: "no"
# GitData Main API Service
gitdata:
build:
context: ..
dockerfile: docker/gitdata.Dockerfile
container_name: gitdata-api
environment:
APP_DATABASE_URL: postgres://${POSTGRES_USER:-gitdata}:${POSTGRES_PASSWORD:-gitdata123}@postgres:5432/${POSTGRES_DB:-app}
APP_REDIS_URLS: redis://redis:6379
APP_QDRANT_URL: http://qdrant:6333/
NATS_URL: nats://nats:4222
APP_STORAGE_S3_ENDPOINT_URL: http://minio:9000
APP_STORAGE_S3_ACCESS_KEY_ID: ${MINIO_ROOT_USER:-admin}
APP_STORAGE_S3_SECRET_ACCESS_KEY: ${MINIO_ROOT_PASSWORD:-mysecret123}
ports:
- "8080:8080"
- "5023:5023"
- "5030:5030"
- "5022:5022"
volumes:
- gitdata_repos:/app/data/repos
- gitdata_files:/app/data/files
- gitdata_avatar:/app/data/avatar
depends_on:
postgres:
condition: service_healthy
redis:
condition: service_healthy
qdrant:
condition: service_started
nats:
condition: service_started
minio:
condition: service_started
migrate:
condition: service_completed_successfully
restart: unless-stopped
# Email Service
email:
build:
context: ..
dockerfile: docker/email.Dockerfile
container_name: gitdata-email
environment:
APP_DATABASE_URL: postgres://${POSTGRES_USER:-gitdata}:${POSTGRES_PASSWORD:-gitdata123}@postgres:5432/${POSTGRES_DB:-app}
APP_REDIS_URLS: redis://redis:6379
NATS_URL: nats://nats:4222
depends_on:
postgres:
condition: service_healthy
redis:
condition: service_healthy
nats:
condition: service_started
restart: unless-stopped
# GitPod Service
gitpod:
build:
context: ..
dockerfile: docker/gitpod.Dockerfile
container_name: gitdata-gitpod
environment:
APP_DATABASE_URL: postgres://${POSTGRES_USER:-gitdata}:${POSTGRES_PASSWORD:-gitdata123}@postgres:5432/${POSTGRES_DB:-app}
APP_REDIS_URLS: redis://redis:6379
ports:
- "5082:5082"
volumes:
- gitdata_repos:/app/data/repos
depends_on:
postgres:
condition: service_healthy
redis:
condition: service_healthy
restart: unless-stopped
# GitSync Service
gitsync:
build:
context: ..
dockerfile: docker/gitsync.Dockerfile
container_name: gitdata-gitsync
environment:
APP_DATABASE_URL: postgres://${POSTGRES_USER:-gitdata}:${POSTGRES_PASSWORD:-gitdata123}@postgres:5432/${POSTGRES_DB:-app}
APP_REDIS_URLS: redis://redis:6379
ports:
- "5083:5083"
volumes:
- gitdata_repos:/app/data/repos
depends_on:
postgres:
condition: service_healthy
redis:
condition: service_healthy
restart: unless-stopped
# Web Frontend
web:
build:
context: ..
dockerfile: docker/web.Dockerfile
container_name: gitdata-web
ports:
- "80:80"
depends_on:
- gitdata
restart: unless-stopped
volumes:
postgres_data:
redis_data:
qdrant_data:
minio_data:
gitdata_repos:
gitdata_files:
gitdata_avatar:

74
docker/gitdata.Dockerfile Normal file
View File

@ -0,0 +1,74 @@
# GitDataAI Backend - GitData Service
# Multi-stage build for Rust application
# Stage 1: Build the application
FROM rust:1.96-bookworm AS builder
# Install system dependencies
RUN apt-get update && apt-get install -y \
pkg-config \
libssl-dev \
libpq-dev \
cmake \
&& rm -rf /var/lib/apt/lists/*
# Create app directory
WORKDIR /app
# Copy workspace files
COPY Cargo.toml Cargo.lock ./
COPY app/ app/
COPY lib/ lib/
# Build the application in release mode
RUN cargo build --release --bin gitdata
# Stage 2: Create runtime image
FROM debian:bookworm-slim
# Install runtime dependencies
RUN apt-get update && apt-get install -y \
libssl3 \
libpq5 \
ca-certificates \
curl \
&& rm -rf /var/lib/apt/lists/*
# Create non-root user
RUN useradd -r -s /bin/false appuser
# Create directories
RUN mkdir -p /app/data/repos \
/app/data/files \
/app/data/avatar \
/app/logs \
&& chown -R appuser:appuser /app
# Copy binary from builder
COPY --from=builder /app/target/release/gitdata /app/gitdata
# Set ownership
RUN chown -R appuser:appuser /app
# Switch to non-root user
USER appuser
# Set working directory
WORKDIR /app
# Expose ports
# API port
EXPOSE 8080
# Git HTTP port
EXPOSE 5023
# Git RPC port
EXPOSE 5030
# SSH port
EXPOSE 5022
# Health check
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8080/health || exit 1
# Run the application
CMD ["./gitdata"]

65
docker/gitpod.Dockerfile Normal file
View File

@ -0,0 +1,65 @@
# GitDataAI Backend - GitPod Service
# Multi-stage build for Rust application
# Stage 1: Build the application
FROM rust:1.96-bookworm AS builder
# Install system dependencies
RUN apt-get update && apt-get install -y \
pkg-config \
libssl-dev \
libpq-dev \
cmake \
&& rm -rf /var/lib/apt/lists/*
# Create app directory
WORKDIR /app
# Copy workspace files
COPY Cargo.toml Cargo.lock ./
COPY app/ app/
COPY lib/ lib/
# Build the application in release mode
RUN cargo build --release --bin gitpod
# Stage 2: Create runtime image
FROM debian:bookworm-slim
# Install runtime dependencies
RUN apt-get update && apt-get install -y \
libssl3 \
libpq5 \
ca-certificates \
curl \
&& rm -rf /var/lib/apt/lists/*
# Create non-root user
RUN useradd -r -s /bin/false appuser
# Create directories
RUN mkdir -p /app/data/repos \
/app/logs \
&& chown -R appuser:appuser /app
# Copy binary from builder
COPY --from=builder /app/target/release/gitpod /app/gitpod
# Set ownership
RUN chown -R appuser:appuser /app
# Switch to non-root user
USER appuser
# Set working directory
WORKDIR /app
# Expose port
EXPOSE 5082
# Health check
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD curl -f http://localhost:5082/health || exit 1
# Run the application
CMD ["./gitpod"]

65
docker/gitsync.Dockerfile Normal file
View File

@ -0,0 +1,65 @@
# GitDataAI Backend - GitSync Service
# Multi-stage build for Rust application
# Stage 1: Build the application
FROM rust:1.96-bookworm AS builder
# Install system dependencies
RUN apt-get update && apt-get install -y \
pkg-config \
libssl-dev \
libpq-dev \
cmake \
&& rm -rf /var/lib/apt/lists/*
# Create app directory
WORKDIR /app
# Copy workspace files
COPY Cargo.toml Cargo.lock ./
COPY app/ app/
COPY lib/ lib/
# Build the application in release mode
RUN cargo build --release --bin gitsync
# Stage 2: Create runtime image
FROM debian:bookworm-slim
# Install runtime dependencies
RUN apt-get update && apt-get install -y \
libssl3 \
libpq5 \
ca-certificates \
curl \
&& rm -rf /var/lib/apt/lists/*
# Create non-root user
RUN useradd -r -s /bin/false appuser
# Create directories
RUN mkdir -p /app/data/repos \
/app/logs \
&& chown -R appuser:appuser /app
# Copy binary from builder
COPY --from=builder /app/target/release/gitsync /app/gitsync
# Set ownership
RUN chown -R appuser:appuser /app
# Switch to non-root user
USER appuser
# Set working directory
WORKDIR /app
# Expose health check port
EXPOSE 5083
# Health check
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD curl -f http://localhost:5083/health || exit 1
# Run the application
CMD ["./gitsync"]

58
docker/migrate.Dockerfile Normal file
View File

@ -0,0 +1,58 @@
# GitDataAI Database Migration Dockerfile
# Multi-stage build for Rust migration tool
# Stage 1: Build the application
FROM rust:1.96-bookworm AS builder
# Install system dependencies
RUN apt-get update && apt-get install -y \
pkg-config \
libssl-dev \
libpq-dev \
cmake \
&& rm -rf /var/lib/apt/lists/*
# Create app directory
WORKDIR /app
# Copy workspace files
COPY Cargo.toml Cargo.lock ./
COPY app/ app/
COPY lib/ lib/
# Build the migration binary
RUN cargo build --release --bin migrate
# Stage 2: Create runtime image
FROM debian:bookworm-slim
# Install runtime dependencies
RUN apt-get update && apt-get install -y \
libssl3 \
libpq5 \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# Create non-root user
RUN useradd -r -s /bin/false appuser
# Create app directory
RUN mkdir -p /app && chown -R appuser:appuser /app
# Copy binary from builder
COPY --from=builder /app/target/release/migrate /app/migrate
# Copy migration files
COPY --from=builder /app/lib/migrate/sql /app/sql
# Set ownership
RUN chown -R appuser:appuser /app
# Switch to non-root user
USER appuser
# Set working directory
WORKDIR /app
# Run migrations by default
CMD ["./migrate", "up"]

75
docker/nginx.conf Normal file
View File

@ -0,0 +1,75 @@
server {
listen 80;
server_name localhost;
# Gzip compression
gzip on;
gzip_vary on;
gzip_min_length 1024;
gzip_proxied any;
gzip_comp_level 6;
gzip_types
text/plain
text/css
text/xml
text/javascript
application/json
application/javascript
application/xml
application/rss+xml
image/svg+xml;
# Security headers
add_header X-Frame-Options "SAMEORIGIN" always;
add_header X-Content-Type-Options "nosniff" always;
add_header X-XSS-Protection "1; mode=block" always;
add_header Referrer-Policy "strict-origin-when-cross-origin" always;
# Root directory
root /usr/share/nginx/html;
index index.html;
# Enable static asset caching
location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ {
expires 1y;
add_header Cache-Control "public, immutable";
try_files $uri =404;
}
# API proxy (if needed)
location /api/ {
proxy_pass http://gitdata:8080/api/;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection 'upgrade';
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_cache_bypass $http_upgrade;
}
# Socket.IO proxy
location /socket.io/ {
proxy_pass http://gitdata:8080/socket.io/;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# SPA fallback
location / {
try_files $uri $uri/ /index.html;
}
# Health check endpoint
location /health {
access_log off;
return 200 'OK';
add_header Content-Type text/plain;
}
}

128
docker/push.sh Executable file
View File

@ -0,0 +1,128 @@
#!/bin/bash
# GitDataAI Docker Push Script
set -e
# Get version from Cargo.toml
PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
CARGO_VERSION=$(grep -m1 'version' "${PROJECT_ROOT}/Cargo.toml" | sed 's/.*"\(.*\)".*/\1/')
# Configuration
REGISTRY=${REGISTRY:-""}
TAG=${TAG:-"${CARGO_VERSION:-latest}"}
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
# Services to push
SERVICES=("gitdata" "email" "gitpod" "gitsync" "migrate" "web")
# Function to print colored output
log_info() {
echo -e "${GREEN}[INFO]${NC} $1"
}
log_warn() {
echo -e "${YELLOW}[WARN]${NC} $1"
}
log_error() {
echo -e "${RED}[ERROR]${NC} $1"
}
# Function to push a service
push_service() {
local service=$1
local image_name="gitdata-${service}"
# Add registry prefix if set
if [ -n "$REGISTRY" ]; then
image_name="${REGISTRY}/${image_name}"
fi
log_info "Pushing ${service}..."
# Check if image exists locally
if ! docker image inspect "${image_name}:${TAG}" > /dev/null 2>&1; then
log_error "Image not found: ${image_name}:${TAG}"
log_error "Please build the image first with: ./build.sh ${service}"
return 1
fi
docker push "${image_name}:${TAG}"
log_info "Successfully pushed ${image_name}:${TAG}"
}
# Parse command line arguments
PUSH_SERVICES=()
PUSH_ALL=true
while [[ $# -gt 0 ]]; do
case $1 in
--tag|-t)
TAG="$2"
shift 2
;;
--registry|-r)
REGISTRY="$2"
shift 2
;;
--help|-h)
echo "Usage: $0 [OPTIONS] [SERVICE...]"
echo ""
echo "Options:"
echo " -t, --tag TAG Docker image tag (default: latest)"
echo " -r, --registry REG Docker registry prefix (required)"
echo " -h, --help Show this help message"
echo ""
echo "Services:"
echo " gitdata Main API service"
echo " email Email service"
echo " gitpod GitPod service"
echo " gitsync GitSync service"
echo " migrate Database migration"
echo " web Web frontend"
echo ""
echo "Examples:"
echo " $0 -r registry.com # Push all services"
echo " $0 -r registry.com gitdata web # Push specific services"
echo " $0 -r registry.com -t v1.0.0 # Push with custom tag"
exit 0
;;
*)
PUSH_SERVICES+=("$1")
PUSH_ALL=false
shift
;;
esac
done
# Validate registry
if [ -z "$REGISTRY" ]; then
log_error "Registry is required. Use -r or --registry to specify."
echo "Example: $0 -r registry.com"
exit 1
fi
# Push services
log_info "Starting Docker push..."
log_info "Registry: ${REGISTRY}"
log_info "Tag: ${TAG}"
if [ "$PUSH_ALL" = true ]; then
log_info "Pushing all services..."
for service in "${SERVICES[@]}"; do
push_service "$service"
done
else
log_info "Pushing specified services: ${PUSH_SERVICES[*]}"
for service in "${PUSH_SERVICES[@]}"; do
push_service "$service"
done
fi
log_info "Push completed successfully!"

62
docker/web.Dockerfile Normal file
View File

@ -0,0 +1,62 @@
# GitDataAI Frontend Dockerfile
# Multi-stage build for React application with Bun
# Stage 1: Build the application
FROM node:24-bookworm AS builder
# Install bun
RUN npm install -g bun
# Create app directory
WORKDIR /app
# Copy package files
COPY package.json bun.lock ./
# Install dependencies
RUN bun install --frozen-lockfile
# Copy source code
COPY src/ src/
COPY public/ public/
COPY index.html ./
COPY vite.config.ts ./
COPY tsconfig*.json ./
COPY eslint.config.js ./
COPY components.json ./
COPY orval.config.ts ./
# Build the application
RUN bun run build
# Stage 2: Create runtime image with Nginx
FROM nginx:alpine
# Copy custom nginx configuration
COPY docker/nginx.conf /etc/nginx/conf.d/default.conf
# Copy built assets from builder
COPY --from=builder /app/dist /usr/share/nginx/html
# Create non-root user
RUN adduser -D -S -h /var/cache/nginx -s /sbin/nologin -G nginx appuser
# Set ownership
RUN chown -R appuser:nginx /var/cache/nginx \
&& chown -R appuser:nginx /var/log/nginx \
&& chown -R appuser:nginx /etc/nginx/conf.d \
&& touch /var/run/nginx.pid \
&& chown -R appuser:nginx /var/run/nginx.pid
# Switch to non-root user
USER appuser
# Expose port
EXPOSE 80
# Health check
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD wget --no-verbose --tries=1 --spider http://localhost:80/ || exit 1
# Start Nginx
CMD ["nginx", "-g", "daemon off;"]

37
lib/ai/Cargo.toml Normal file
View File

@ -0,0 +1,37 @@
[package]
name = "ai"
version.workspace = true
edition.workspace = true
authors.workspace = true
description.workspace = true
repository.workspace = true
readme.workspace = true
homepage.workspace = true
license.workspace = true
keywords.workspace = true
categories.workspace = true
documentation.workspace = true
[lib]
path = "lib.rs"
name = "ai"
[dependencies]
rig-core = { workspace = true, features = ["derive"] }
tokio = { workspace = true, features = ["full"] }
tokio-util = { workspace = true }
tokio-stream = { workspace = true }
config = { workspace = true }
cache = { workspace = true }
db = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
serde_json = { workspace = true }
serde = { workspace = true, features = ["derive"] }
qdrant-client = { workspace = true, features = ["serde"] }
async-trait = { workspace = true }
redis = { workspace = true }
uuid = { workspace = true, features = ["v4", "v5", "serde"] }
reqwest = { workspace = true }
futures = { workspace = true }
[lints]
workspace = true

543
lib/ai/agent/agent.rs Normal file
View File

@ -0,0 +1,543 @@
use futures::StreamExt;
use rig::agent::AgentBuilder;
use rig::client::CompletionClient;
use rig::streaming::StreamingPrompt;
use rig::tool::ToolDyn;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};
use super::config::AgentConfig;
use super::helpers::{build_input_string, check_token_budget, estimate_tokens};
use super::hooks::{HookChain, HookLlmResponse, HookMessage, HookToolDef, ToolCallOutcome, ToolGuardrailDecision};
use super::persistence::ActiveAgentRun;
use super::request::{AgentRequest, AgentResult, AgentStep, ToolCallRecord};
use super::subagent::run_experts;
use super::RigStreamChunk;
use crate::client::AiClient;
use crate::error::{AiError, AiResult};
pub struct RigAgent {
pub client: AiClient,
pub config: AgentConfig,
pub hooks: HookChain,
}
impl RigAgent {
pub fn new(client: AiClient, config: AgentConfig) -> AiResult<Self> {
config.validate()?;
Ok(Self {
client,
config,
hooks: HookChain::empty(),
})
}
pub fn with_hooks(mut self, hooks: HookChain) -> Self {
self.hooks = hooks;
self
}
pub fn config(&self) -> &AgentConfig {
&self.config
}
pub async fn chat(
&self,
request: AgentRequest,
tools: Vec<Box<dyn ToolDyn>>,
) -> AiResult<String> {
let (mut rx, handle) = self.run(request, tools);
tokio::spawn(async move {
while rx.recv().await.is_some() {}
});
let result = handle.await.map_err(|_| {
AiError::Response("agent task panicked".to_string())
})?;
result.map(|r| r.output)
}
#[allow(clippy::too_many_lines)]
pub fn run(
&self,
request: AgentRequest,
tools: Vec<Box<dyn ToolDyn>>,
) -> (
tokio::sync::mpsc::Receiver<RigStreamChunk>,
tokio::task::JoinHandle<AiResult<AgentResult>>,
) {
let (tx, rx) = mpsc::channel::<RigStreamChunk>(256);
let model_name = self.config.model.clone();
let max_iterations = self.config.max_iterations;
let client = self.client.llm_client().clone();
let ai_client = self.client.clone();
let agent_config = self.config.clone();
let system_prompt = self.config.system_prompt.clone();
let temperature = self.config.temperature;
let max_completion_tokens = self.config.max_completion_tokens;
let max_total_tokens = self.config.max_total_tokens_per_run;
let cancellation = request.cancellation_token.clone();
let timeout = request.timeout;
let hooks = self.hooks.clone();
let filtered_tools: Vec<Box<dyn ToolDyn>> = tools
.into_iter()
.filter(|tool| self.config.is_tool_exposed(&tool.name()))
.collect();
let handle = tokio::spawn(async move {
execute_agent_run(
client,
model_name,
system_prompt,
request,
filtered_tools,
max_iterations,
ai_client,
agent_config,
temperature,
max_completion_tokens,
max_total_tokens,
cancellation,
timeout,
hooks,
tx,
)
.await
});
(rx, handle)
}
}
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
async fn execute_agent_run(
client: rig::providers::openai::Client,
model_name: String,
system_prompt: String,
request: AgentRequest,
tools: Vec<Box<dyn ToolDyn>>,
max_iterations: usize,
ai_client: AiClient,
agent_config: AgentConfig,
temperature: Option<f64>,
max_completion_tokens: Option<u64>,
max_total_tokens: Option<i64>,
cancellation: Option<CancellationToken>,
timeout: Option<std::time::Duration>,
hooks: HookChain,
tx: mpsc::Sender<RigStreamChunk>,
) -> AiResult<AgentResult> {
if let Some(ref ctx) = request.run_context {
let _ = hooks.run_session_start(ctx).await;
}
let model = client.completion_model(&model_name);
let mut agent_builder = AgentBuilder::new(model)
.preamble(&system_prompt)
.tools(tools)
.default_max_turns(max_iterations);
if let Some(temp) = temperature {
agent_builder = agent_builder.temperature(temp);
}
if let Some(mt) = max_completion_tokens {
agent_builder = agent_builder.max_tokens(mt);
}
let agent = agent_builder.build();
let mut input = build_input_string(&request);
// ---- SubAgent execution ----
let expert_outputs = if !request.experts.is_empty() {
let run = ActiveAgentRun {
conversation_id: request.run_context.as_ref().and_then(|c| c.conversation_id),
message_id: None,
invocation_id: request.run_context.as_ref().and_then(|c| c.invocation_id),
session_id: request.run_context.as_ref().and_then(|c| c.session_id),
user_id: request.run_context.as_ref().and_then(|c| c.user_id),
started_at: std::time::Instant::now(),
current_step: 0,
};
let realtime = request.run_context.as_ref().and_then(|c| c.realtime.as_ref());
// Notify frontend that subagents are starting.
for expert in &request.experts {
let _ = tx
.send(RigStreamChunk::SubagentStarted {
subagent_id: expert.id.clone(),
role: expert.role.clone(),
task: expert.task.clone(),
})
.await;
}
match run_experts(&ai_client, &agent_config, &request.experts, realtime, &run).await {
Ok(outputs) => {
for out in &outputs {
let _ = tx
.send(RigStreamChunk::SubagentCompleted {
subagent_id: out.id.clone(),
role: out.role.clone(),
task: out.task.clone(),
output: out.output.clone(),
})
.await;
input.push_str(&format!(
"\n--- Subagent: {} (role: {}) ---\nTask: {}\nResult: {}\n",
out.id, out.role, out.task, out.output
));
}
outputs
}
Err(e) => {
warn!(error = %e, "subagent execution failed, continuing without expert inputs");
let _ = tx
.send(RigStreamChunk::SubagentFailed {
error: e.to_string(),
})
.await;
Vec::new()
}
}
} else {
Vec::new()
};
let estimated_input_tokens = estimate_tokens(&input);
if let Some(limit) = max_total_tokens
&& estimated_input_tokens > limit as u64
{
return Err(AiError::TokenBudgetExceeded {
estimated: estimated_input_tokens,
limit,
});
}
if !hooks.is_empty() {
let hook_messages: Vec<HookMessage> = request
.messages
.iter()
.map(|m| HookMessage {
role: match m {
super::request::AgentMessage::User(_) => "user".to_string(),
super::request::AgentMessage::Assistant(_) => {
"assistant".to_string()
}
},
content: match m {
super::request::AgentMessage::User(c) => Some(c.clone()),
super::request::AgentMessage::Assistant(c) => {
Some(c.clone())
}
},
tool_calls: None,
tool_call_id: None,
})
.collect();
let hook_tools: Vec<HookToolDef> = Vec::new();
let _ = hooks.run_pre_llm_call(&hook_messages, &hook_tools).await;
}
let stream_future = agent
.stream_prompt(&input)
.with_history(Vec::<rig::completion::Message>::new())
.multi_turn(max_iterations);
let stream = if let Some(dur) = timeout {
match tokio::time::timeout(dur, stream_future).await {
Ok(stream) => stream,
Err(_elapsed) => {
let _ = tx
.send(RigStreamChunk::Failed {
error: format!("agent timed out after {}s", dur.as_secs()),
})
.await;
return Err(AiError::Timeout {
seconds: dur.as_secs(),
});
}
}
} else {
stream_future.await
};
tokio::pin!(stream);
let mut steps = Vec::new();
let mut delta_index = 0usize;
let mut current_step_tool_calls: Vec<ToolCallRecord> = Vec::new();
let mut current_step_assistant = String::new();
let mut current_step_reasoning = String::new();
let mut accumulated_output_chars: usize = 0;
while let Some(item) = stream.next().await {
if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) {
let _ = tx
.send(RigStreamChunk::Failed {
error: "cancelled".to_string(),
})
.await;
return Err(AiError::Response("agent run cancelled".to_string()));
}
if let Some(limit) = max_total_tokens
&& check_token_budget(estimated_input_tokens, accumulated_output_chars, limit)
{
let _ = tx
.send(RigStreamChunk::Failed {
error: format!("token budget exceeded: limit {limit}"),
})
.await;
return Err(AiError::TokenBudgetExceeded {
estimated: estimated_input_tokens
+ (accumulated_output_chars as f64 / 2.5).ceil() as u64,
limit,
});
}
match item {
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
rig::streaming::StreamedAssistantContent::Text(text),
)) => {
accumulated_output_chars += text.text.chars().count();
current_step_assistant.push_str(&text.text);
let _ = tx
.send(RigStreamChunk::TextDelta {
index: delta_index,
content: text.text.clone(),
})
.await;
delta_index += 1;
}
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
rig::streaming::StreamedAssistantContent::Reasoning(reasoning),
)) => {
for part in &reasoning.content {
if let rig::completion::message::ReasoningContent::Text {
text, ..
} = part
{
accumulated_output_chars += text.chars().count();
current_step_reasoning.push_str(text);
let _ = tx
.send(RigStreamChunk::Thinking {
index: delta_index,
content: text.clone(),
})
.await;
delta_index += 1;
}
}
}
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
rig::streaming::StreamedAssistantContent::ReasoningDelta {
reasoning, ..
},
)) => {
accumulated_output_chars += reasoning.chars().count();
current_step_reasoning.push_str(&reasoning);
let _ = tx
.send(RigStreamChunk::Thinking {
index: delta_index,
content: reasoning.clone(),
})
.await;
delta_index += 1;
}
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
rig::streaming::StreamedAssistantContent::ToolCall {
tool_call,
internal_call_id: _,
},
)) => {
let args = match &tool_call.function.arguments {
serde_json::Value::String(s) => s.clone(),
v => serde_json::to_string(v).unwrap_or_default(),
};
accumulated_output_chars += args.chars().count();
let tool_name = tool_call.function.name.clone();
let tool_args: serde_json::Value =
serde_json::from_str(&args).unwrap_or_default();
if let Ok(Some(decision)) = hooks.run_pre_tool_call(&tool_name, &tool_args).await {
match decision {
ToolGuardrailDecision::Allow => {}
ToolGuardrailDecision::Block { reason } => {
let _ = tx
.send(RigStreamChunk::ToolCallFinished {
tool_call_id: tool_call.id.clone(),
tool_name: tool_name.clone(),
output: format!("blocked: {reason}"),
error: Some(reason),
})
.await;
current_step_tool_calls.push(ToolCallRecord {
id: tool_call.id.clone(),
name: tool_name.clone(),
arguments: tool_args.clone(),
output: None,
error: Some("blocked by guardrail".to_string()),
elapsed_ms: None,
});
continue;
}
ToolGuardrailDecision::RequireApproval { message } => {
let _ = tx
.send(RigStreamChunk::ToolCallFinished {
tool_call_id: tool_call.id.clone(),
tool_name: tool_name.clone(),
output: format!("awaiting approval: {message}"),
error: None,
})
.await;
current_step_tool_calls.push(ToolCallRecord {
id: tool_call.id.clone(),
name: tool_name.clone(),
arguments: tool_args.clone(),
output: None,
error: Some(format!("requires approval: {message}")),
elapsed_ms: None,
});
continue;
}
}
}
let _ = tx
.send(RigStreamChunk::ToolCallStarted {
tool_call_id: tool_call.id.clone(),
tool_name: tool_name.clone(),
arguments: args.clone(),
})
.await;
current_step_tool_calls.push(ToolCallRecord {
id: tool_call.id.clone(),
name: tool_name.clone(),
arguments: tool_args.clone(),
output: None,
error: None,
elapsed_ms: None,
});
}
Ok(rig::agent::MultiTurnStreamItem::StreamUserItem(
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
)) => {
let content =
super::helpers::tool_result_content_to_string(&tool_result.content);
accumulated_output_chars += content.chars().count();
if let Some(last) = current_step_tool_calls.last_mut()
&& last.id == tool_result.id
{
last.output = Some(serde_json::from_str(&content).unwrap_or_default());
}
let tool_name = current_step_tool_calls
.last()
.map(|tc| tc.name.clone())
.unwrap_or_default();
let _ = tx
.send(RigStreamChunk::ToolCallFinished {
tool_call_id: tool_result.id.clone(),
tool_name,
output: content.clone(),
error: None,
})
.await;
if !hooks.is_empty() {
let outcome = ToolCallOutcome {
name: tool_result.id.clone(),
arguments: serde_json::Value::Null,
output: Some(serde_json::Value::String(content)),
error: None,
elapsed_ms: 0,
};
let _ = hooks.run_post_tool_call(&outcome).await;
}
}
Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => {
let usage = resp.usage();
if !current_step_tool_calls.is_empty() || !current_step_assistant.is_empty() {
let reasoning = (!current_step_reasoning.is_empty())
.then_some(std::mem::take(&mut current_step_reasoning));
steps.push(AgentStep {
index: steps.len(),
assistant: (!current_step_assistant.is_empty())
.then_some(std::mem::take(&mut current_step_assistant)),
reasoning_content: reasoning,
tool_calls: std::mem::take(&mut current_step_tool_calls),
reflection: None,
});
}
let output = steps
.last()
.and_then(|s| s.assistant.clone())
.unwrap_or_default();
if !hooks.is_empty() {
let hook_response = HookLlmResponse {
content: Some(output.clone()),
tool_calls: None,
input_tokens: usage.input_tokens,
output_tokens: usage.output_tokens,
finish_reason: None,
};
let _ = hooks.run_post_llm_call(&hook_response).await;
}
info!(
steps = steps.len(),
input_tokens = usage.input_tokens,
output_tokens = usage.output_tokens,
"agent run completed"
);
let _ = tx
.send(RigStreamChunk::Final {
content: output.clone(),
input_tokens: usage.input_tokens,
output_tokens: usage.output_tokens,
})
.await;
if let Some(ref ctx) = request.run_context {
let _ = hooks.run_session_end(ctx, true).await;
}
return Ok(AgentResult {
output,
steps,
expert_outputs,
input_tokens: usage.input_tokens as i64,
output_tokens: usage.output_tokens as i64,
});
}
Err(e) => {
let err = format!("{e}");
warn!(error = %err, "agent stream error");
let _ = tx.send(RigStreamChunk::Failed { error: err }).await;
if let Some(ref ctx) = request.run_context {
let _ = hooks.run_session_end(ctx, false).await;
}
return Err(AiError::Api(format!("{e}")));
}
_ => {}
}
}
Err(AiError::Response("agent stream ended without final response".to_string()))
}
impl Clone for HookChain {
fn clone(&self) -> Self {
HookChain::empty()
}
}

222
lib/ai/agent/compression.rs Normal file
View File

@ -0,0 +1,222 @@
use crate::error::AiResult;
/// Compression strategy controlling when and how context compaction occurs.
#[derive(Clone, Debug)]
pub struct CompressionStrategy {
/// Token threshold that triggers compaction.
pub threshold_tokens: i64,
/// Target token count after compaction.
pub target_tokens: i64,
/// Number of recent message pairs to always preserve.
pub preserve_last_n_pairs: usize,
/// Optional model override for the compaction LLM call.
pub summary_model: String,
/// Reserve this many tokens for the compaction prompt itself.
pub reserve_tokens: i64,
/// Whether to generate branch summaries when forking.
pub branch_summarization: bool,
/// Custom instructions appended to the compaction prompt.
pub custom_instructions: Option<String>,
/// Maximum word count for compaction summaries.
pub max_summary_words: usize,
}
impl Default for CompressionStrategy {
fn default() -> Self {
Self {
threshold_tokens: 64_000,
target_tokens: 32_000,
preserve_last_n_pairs: 4,
summary_model: String::new(),
reserve_tokens: 16_384,
branch_summarization: true,
custom_instructions: None,
max_summary_words: 1500,
}
}
}
impl CompressionStrategy {
pub fn new(threshold_tokens: i64, target_tokens: i64) -> Self {
Self {
threshold_tokens,
target_tokens,
..Default::default()
}
}
pub fn with_preserve_last(mut self, n: usize) -> Self {
self.preserve_last_n_pairs = n;
self
}
pub fn with_summary_model(mut self, model: impl Into<String>) -> Self {
self.summary_model = model.into();
self
}
pub fn with_reserve_tokens(mut self, tokens: i64) -> Self {
self.reserve_tokens = tokens;
self
}
pub fn with_branch_summarization(mut self, enabled: bool) -> Self {
self.branch_summarization = enabled;
self
}
pub fn with_custom_instructions(mut self, instructions: impl Into<String>) -> Self {
self.custom_instructions = Some(instructions.into());
self
}
pub fn with_max_summary_words(mut self, words: usize) -> Self {
self.max_summary_words = words;
self
}
/// Check whether compaction should be triggered based on current token count.
pub fn should_compact(&self, current_tokens: i64) -> bool {
current_tokens >= self.threshold_tokens
}
}
#[derive(Debug, Clone)]
pub struct CompactionResult {
pub summary: String,
pub messages_compacted: usize,
pub tokens_saved: i64,
/// Whether this was a branch summary (vs. standard compaction).
pub is_branch_summary: bool,
}
impl CompactionResult {
pub fn new(summary: String, messages_compacted: usize, tokens_saved: i64) -> Self {
Self {
summary,
messages_compacted,
tokens_saved,
is_branch_summary: false,
}
}
pub fn branch_summary(summary: String, entries_summarized: usize) -> Self {
Self {
summary,
messages_compacted: entries_summarized,
tokens_saved: 0,
is_branch_summary: true,
}
}
}
/// Build the compaction prompt for standard context compression.
pub fn build_compression_prompt(
existing_summary: Option<&str>,
messages_text: &str,
) -> String {
build_compression_prompt_with_options(existing_summary, messages_text, None, 1500)
}
/// Build the compaction prompt with custom instructions and word limit.
pub fn build_compression_prompt_with_options(
existing_summary: Option<&str>,
messages_text: &str,
custom_instructions: Option<&str>,
max_words: usize,
) -> String {
let custom = custom_instructions
.map(|ci| format!("\n\nAdditional instructions: {ci}"))
.unwrap_or_default();
if let Some(summary) = existing_summary {
format!(
"## Previous Summary\n{summary}\n\n## New Messages\n{messages_text}\n\n\
Combine the previous summary and the new messages into a concise, \
single-paragraph summary of the conversation. Preserve facts, \
decisions, code snippets, and anything essential for continuing \
work. Target up to {max_words} words.{custom} \
Output ONLY the summary text, no preamble.",
)
} else {
format!(
"## Conversation\n{messages_text}\n\n\
Summarise the conversation above into a concise, single-paragraph \
summary. Preserve facts, decisions, code snippets, and anything \
essential for continuing work. Target up to {max_words} words.{custom} \
Output ONLY the summary text, no preamble.",
)
}
}
/// Build a prompt for generating a branch summary.
///
/// Used when the user forks a conversation from a different point in the
/// session tree. Summarizes the divergent branch so context is preserved.
pub fn build_branch_summary_prompt(
branch_messages: &str,
custom_instructions: Option<&str>,
) -> String {
let custom = custom_instructions
.map(|ci| format!("\n\nAdditional instructions: {ci}"))
.unwrap_or_default();
format!(
"## Branch Conversation\n{branch_messages}\n\n\
Summarize the conversation branch above. This summary will be used \
to preserve context when the user navigates away from this branch. \
Focus on key decisions, unresolved questions, and important context.{custom} \
Output ONLY the summary text, no preamble.",
)
}
/// Calculate how many messages to truncate to reach the target token count.
pub fn estimate_truncation(
message_token_counts: &[i64],
current_total: i64,
target: i64,
preserve_last: usize,
) -> AiResult<(usize, i64)> {
let n = message_token_counts.len();
if n <= preserve_last {
return Ok((0, 0));
}
let excess = (current_total - target).max(0);
let mut truncated = 0;
let mut saved = 0i64;
let limit = n - preserve_last;
for i in 0..limit {
if saved >= excess {
break;
}
saved += message_token_counts[i];
truncated += 1;
}
Ok((truncated, saved.min(excess)))
}
/// Calculate compaction parameters for a given set of messages.
///
/// Returns `(messages_to_compact, tokens_saved)` where `messages_to_compact`
/// is the count of oldest messages to summarize, and `tokens_saved` is the
/// estimated token savings.
pub fn plan_compaction(
strategy: &CompressionStrategy,
message_token_counts: &[i64],
current_total: i64,
) -> AiResult<(usize, i64)> {
if !strategy.should_compact(current_total) {
return Ok((0, 0));
}
estimate_truncation(
message_token_counts,
current_total,
strategy.target_tokens,
strategy.preserve_last_n_pairs * 2, // pairs → individual messages
)
}

217
lib/ai/agent/config.rs Normal file
View File

@ -0,0 +1,217 @@
use crate::error::{AiError, AiResult};
pub const DEFAULT_SYSTEM_PROMPT: &str = r#"You are a precise autonomous agent that executes tasks through tool calls.
## Core Principles
- Use tools when they can materially improve correctness or efficiency
- After each action, verify results and adjust approach if needed
- Keep reasoning concise and focus on actionable outcomes
- Return only the final useful answer to the user
## Workflow
1. Analyze the request and plan your approach
2. Execute actions using appropriate tools
3. Review observations and verify assumptions
4. Iterate until the task is complete
5. Provide a clear, concise final response
## Title Generation
If this is the first user message in a new conversation with a default title, you SHOULD call `set_conversation_title` as your first action to create a short, descriptive title (max 100 chars). This helps keep the conversation organized. Only do this once at the very beginning."#;
#[derive(Clone, Debug)]
pub struct AgentConfig {
pub model: String,
pub provider: String,
pub api_mode: String,
pub system_prompt: String,
pub max_iterations: usize,
pub iteration_budget: usize,
pub temperature: Option<f64>,
pub max_completion_tokens: Option<u64>,
pub max_total_tokens_per_run: Option<i64>,
pub enabled_toolsets: Vec<String>,
pub disabled_toolsets: Vec<String>,
pub allowed_tools: Vec<String>,
pub denied_tools: Vec<String>,
pub retry_max_attempts: usize,
pub retry_base_delay_ms: u64,
pub retry_jitter: bool,
pub fallback_model: Option<String>,
pub skip_memory: bool,
pub skip_context_files: bool,
pub skip_compression: bool,
pub quiet_mode: bool,
pub save_trajectories: bool,
pub reasoning_effort: Option<String>,
pub service_tier: Option<String>,
pub platform: Option<String>,
pub session_id: Option<uuid::Uuid>,
}
impl AgentConfig {
pub fn new(model: impl Into<String>) -> AiResult<Self> {
let config = Self {
model: model.into(),
provider: String::new(),
api_mode: String::from("chat_completions"),
system_prompt: DEFAULT_SYSTEM_PROMPT.to_string(),
max_iterations: 64,
iteration_budget: 90,
temperature: Some(0.2),
max_completion_tokens: None,
max_total_tokens_per_run: Some(128_000),
enabled_toolsets: Vec::new(),
disabled_toolsets: Vec::new(),
allowed_tools: Vec::new(),
denied_tools: Vec::new(),
retry_max_attempts: 3,
retry_base_delay_ms: 1_000,
retry_jitter: true,
fallback_model: None,
skip_memory: false,
skip_context_files: false,
skip_compression: false,
quiet_mode: false,
save_trajectories: false,
reasoning_effort: None,
service_tier: None,
platform: None,
session_id: None,
};
config.validate()?;
Ok(config)
}
pub fn validate(&self) -> AiResult<()> {
if self.model.trim().is_empty() {
return Err(AiError::Config("agent model is required".to_string()));
}
if self.max_iterations == 0 {
return Err(AiError::Config(
"agent max_iterations must be greater than 0".to_string(),
));
}
if let Some(tokens) = self.max_total_tokens_per_run
&& tokens <= 0
{
return Err(AiError::Config(
"agent max_total_tokens_per_run must be > 0".to_string(),
));
}
Ok(())
}
pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
self.provider = provider.into();
self
}
pub fn with_api_mode(mut self, mode: impl Into<String>) -> Self {
self.api_mode = mode.into();
self
}
pub fn with_max_iterations(mut self, max: usize) -> Self {
self.max_iterations = max;
self.iteration_budget = self.iteration_budget.max(max);
self
}
pub fn with_iteration_budget(mut self, budget: usize) -> Self {
self.iteration_budget = budget;
self
}
pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
self.system_prompt = prompt.into();
self
}
pub fn with_temperature(mut self, temperature: Option<f64>) -> Self {
self.temperature = temperature;
self
}
pub fn with_max_completion_tokens(mut self, max_completion_tokens: Option<u64>) -> Self {
self.max_completion_tokens = max_completion_tokens;
self
}
pub fn with_max_total_tokens(mut self, limit: Option<i64>) -> Self {
self.max_total_tokens_per_run = limit;
self
}
pub fn with_toolset_policy(mut self, enabled: Vec<String>, disabled: Vec<String>) -> Self {
self.enabled_toolsets = enabled;
self.disabled_toolsets = disabled;
self
}
pub fn with_tool_policy(mut self, allowed_tools: Vec<String>, denied_tools: Vec<String>) -> Self {
self.allowed_tools = allowed_tools;
self.denied_tools = denied_tools;
self
}
pub fn with_retry(mut self, max_attempts: usize, base_delay_ms: u64) -> Self {
self.retry_max_attempts = max_attempts;
self.retry_base_delay_ms = base_delay_ms;
self
}
pub fn with_retry_jitter(mut self, jitter: bool) -> Self {
self.retry_jitter = jitter;
self
}
pub fn with_fallback_model(mut self, fallback_model: impl Into<String>) -> Self {
self.fallback_model = Some(fallback_model.into());
self
}
pub fn with_skip_memory(mut self, skip: bool) -> Self {
self.skip_memory = skip;
self
}
pub fn with_skip_compression(mut self, skip: bool) -> Self {
self.skip_compression = skip;
self
}
pub fn with_quiet_mode(mut self, quiet: bool) -> Self {
self.quiet_mode = quiet;
self
}
pub fn with_platform(mut self, platform: impl Into<String>) -> Self {
self.platform = Some(platform.into());
self
}
pub fn with_session_id(mut self, session_id: uuid::Uuid) -> Self {
self.session_id = Some(session_id);
self
}
pub fn with_reasoning_effort(mut self, effort: impl Into<String>) -> Self {
self.reasoning_effort = Some(effort.into());
self
}
pub fn is_tool_exposed(&self, name: &str) -> bool {
let denied = self.denied_tools.iter().any(|tool| tool == name);
if denied {
return false;
}
if self.allowed_tools.is_empty() {
return true;
}
self.allowed_tools.iter().any(|tool| tool == name)
}
}
pub fn default_system_prompt() -> &'static str {
DEFAULT_SYSTEM_PROMPT
}

View File

@ -0,0 +1,239 @@
use std::time::Duration;
use crate::error::AiError;
/// Categorized error for deciding retry/fallback/fatal strategy.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCategory {
/// Transient error, safe to retry with backoff.
Retryable { reason: String },
/// Authentication or quota error, switch to fallback model.
FallbackModel { reason: String },
/// Non-recoverable error, do not retry.
Fatal { reason: String },
/// Token budget exceeded for this run.
TokenBudgetExceeded,
/// Request timed out.
Timeout,
/// Request was cancelled by the caller.
Cancelled,
/// Provider is overloaded or at capacity, retry with longer delay.
Overloaded { reason: String },
/// Context window exceeded, needs compaction before retry.
ContextWindowExceeded { reason: String },
}
#[derive(Debug, Clone)]
pub struct RetryPolicy {
pub max_attempts: usize,
pub base_delay: Duration,
pub jitter: bool,
pub exponential: bool,
pub switch_to_fallback: bool,
}
impl RetryPolicy {
pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
let ms = if self.exponential {
self.base_delay.as_millis() as u64 * (1u64 << attempt.min(6))
} else {
self.base_delay.as_millis() as u64
};
let ms = if self.jitter {
let half = (ms as f64 * 0.25) as u64;
let lo = ms.saturating_sub(half);
let hi = ms.saturating_add(half);
let mix = ((attempt as u64).wrapping_mul(1_103_515_245)) % (hi - lo + 1);
lo + mix
} else {
ms
};
Duration::from_millis(ms.max(100))
}
}
/// Classify an error into a category for retry/fallback decisions.
///
/// Inspects both the HTTP status code (when available) and the error message
/// content to determine the most appropriate category.
pub fn classify_error(error: &AiError, http_status: Option<u16>) -> ErrorCategory {
// HTTP status-based classification takes precedence
let from_status = match http_status {
Some(429) => Some(ErrorCategory::Retryable {
reason: "rate limited (HTTP 429)".to_string(),
}),
Some(401) | Some(403) => Some(ErrorCategory::FallbackModel {
reason: format!("authentication failed (HTTP {})", http_status.unwrap()),
}),
Some(502) | Some(503) => Some(ErrorCategory::Overloaded {
reason: format!("provider unavailable (HTTP {})", http_status.unwrap()),
}),
Some(504) => Some(ErrorCategory::Timeout),
Some(413) => Some(ErrorCategory::ContextWindowExceeded {
reason: "payload too large (HTTP 413)".to_string(),
}),
Some(s) if (400..500).contains(&s) => Some(ErrorCategory::Fatal {
reason: format!("client error (HTTP {})", s),
}),
Some(s) if (500..600).contains(&s) => Some(ErrorCategory::Retryable {
reason: format!("server error (HTTP {})", s),
}),
_ => None,
};
if let Some(cat) = from_status {
return cat;
}
// Message-based classification
match error {
AiError::Timeout { .. } => ErrorCategory::Timeout,
AiError::TokenBudgetExceeded { .. } => ErrorCategory::TokenBudgetExceeded,
AiError::Api(msg) => classify_api_message(msg),
AiError::Response(msg) => classify_response_message(msg),
AiError::ModelRetriesExhausted { .. } => ErrorCategory::Fatal {
reason: error.to_string(),
},
_ => ErrorCategory::Fatal {
reason: error.to_string(),
},
}
}
/// Classify API error messages by keyword patterns.
fn classify_api_message(msg: &str) -> ErrorCategory {
let lower = msg.to_lowercase();
// Rate limiting
if lower.contains("rate") || lower.contains("too many requests") || lower.contains("throttl") {
return ErrorCategory::Retryable {
reason: msg.to_string(),
};
}
// Overloaded / capacity
if lower.contains("overloaded")
|| lower.contains("capacity")
|| lower.contains("too busy")
|| lower.contains("service unavailable")
{
return ErrorCategory::Overloaded {
reason: msg.to_string(),
};
}
// Authentication / quota
if lower.contains("unauthorized")
|| lower.contains("invalid api key")
|| lower.contains("api key")
|| lower.contains("forbidden")
|| lower.contains("quota exceeded")
|| lower.contains("insufficient")
|| lower.contains("billing")
{
return ErrorCategory::FallbackModel {
reason: msg.to_string(),
};
}
// Context window exceeded
if lower.contains("context length")
|| lower.contains("context window")
|| lower.contains("maximum context")
|| lower.contains("too many tokens")
|| lower.contains("max_tokens")
{
return ErrorCategory::ContextWindowExceeded {
reason: msg.to_string(),
};
}
ErrorCategory::Fatal {
reason: msg.to_string(),
}
}
/// Classify response error messages by keyword patterns.
fn classify_response_message(msg: &str) -> ErrorCategory {
let lower = msg.to_lowercase();
if lower.contains("cancelled") || lower.contains("canceled") {
return ErrorCategory::Cancelled;
}
if lower.contains("timeout") || lower.contains("timed out") {
return ErrorCategory::Timeout;
}
ErrorCategory::Fatal {
reason: msg.to_string(),
}
}
/// Get the recommended retry policy for an error category.
pub fn retry_policy_for(
category: &ErrorCategory,
max_attempts: usize,
base_delay_ms: u64,
) -> RetryPolicy {
match category {
ErrorCategory::Retryable { .. } => RetryPolicy {
max_attempts,
base_delay: Duration::from_millis(base_delay_ms),
jitter: true,
exponential: true,
switch_to_fallback: false,
},
ErrorCategory::Overloaded { .. } => RetryPolicy {
max_attempts: max_attempts.min(5),
base_delay: Duration::from_millis(base_delay_ms.max(5_000)),
jitter: true,
exponential: true,
switch_to_fallback: true,
},
ErrorCategory::FallbackModel { .. } => RetryPolicy {
max_attempts: 1,
base_delay: Duration::from_millis(500),
jitter: false,
exponential: false,
switch_to_fallback: true,
},
ErrorCategory::ContextWindowExceeded { .. } => RetryPolicy {
max_attempts: 1,
base_delay: Duration::from_millis(0),
jitter: false,
exponential: false,
switch_to_fallback: false,
},
ErrorCategory::Timeout => RetryPolicy {
max_attempts: max_attempts.min(2),
base_delay: Duration::from_millis(base_delay_ms.max(2_000)),
jitter: true,
exponential: false,
switch_to_fallback: false,
},
ErrorCategory::TokenBudgetExceeded | ErrorCategory::Cancelled | ErrorCategory::Fatal { .. } => {
RetryPolicy {
max_attempts: 0,
base_delay: Duration::from_millis(0),
jitter: false,
exponential: false,
switch_to_fallback: false,
}
}
}
}
/// Determine whether the error warrants switching to a fallback model.
pub fn should_switch_to_fallback(category: &ErrorCategory) -> bool {
matches!(
category,
ErrorCategory::FallbackModel { .. } | ErrorCategory::Overloaded { .. }
)
}
/// Determine whether compaction should be attempted before retry.
pub fn should_compact_before_retry(category: &ErrorCategory) -> bool {
matches!(category, ErrorCategory::ContextWindowExceeded { .. })
}

179
lib/ai/agent/events.rs Normal file
View File

@ -0,0 +1,179 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Fine-grained agent lifecycle events, inspired by pi's event system.
///
/// Covers the full agent execution lifecycle with enough granularity
/// for UI rendering, telemetry, and extension hooks.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AgentEvent {
// === Agent lifecycle ===
AgentStart,
AgentEnd {
messages: Vec<AgentEventMessage>,
total_input_tokens: u64,
total_output_tokens: u64,
},
// === Turn lifecycle ===
TurnStart {
turn_index: usize,
},
TurnEnd {
turn_index: usize,
assistant_text: Option<String>,
tool_call_count: usize,
},
// === Message lifecycle ===
MessageStart {
role: MessageRole,
},
MessageTextDelta {
index: usize,
delta: String,
},
MessageThinkingDelta {
index: usize,
delta: String,
},
MessageEnd {
role: MessageRole,
},
// === Tool execution lifecycle ===
ToolExecutionStart {
tool_call_id: String,
tool_name: String,
arguments: Value,
},
ToolExecutionUpdate {
tool_call_id: String,
tool_name: String,
partial_output: String,
},
ToolExecutionEnd {
tool_call_id: String,
tool_name: String,
output: Option<Value>,
error: Option<String>,
elapsed_ms: i64,
},
// === Steering / follow-up ===
SteeringMessagesInjected {
count: usize,
},
FollowUpMessagesInjected {
count: usize,
},
// === Context management ===
ContextCompacted {
messages_compacted: usize,
tokens_saved: i64,
},
BranchSummaryCreated {
entry_count: usize,
summary_length: usize,
},
// === Model switching ===
ModelSwitched {
from_model: String,
to_model: String,
reason: String,
},
// === Error and retry ===
ErrorClassified {
category: String,
message: String,
will_retry: bool,
retry_delay_ms: Option<u64>,
},
RetryAttempt {
attempt: usize,
max_attempts: usize,
delay_ms: u64,
},
}
/// Simplified message role for event display.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum MessageRole {
User,
Assistant,
ToolResult,
System,
}
/// A simplified message representation for `AgentEnd` events.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentEventMessage {
pub role: MessageRole,
pub content: String,
pub tool_calls: Vec<EventToolCall>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EventToolCall {
pub id: String,
pub name: String,
pub arguments: Value,
pub output: Option<Value>,
pub error: Option<String>,
}
/// An async-friendly event sink that collects or broadcasts events.
pub struct EventSink {
senders: Vec<tokio::sync::mpsc::UnboundedSender<AgentEvent>>,
}
impl EventSink {
pub fn new() -> Self {
Self {
senders: Vec::new(),
}
}
/// Subscribe to events, returns a receiver.
pub fn subscribe(&mut self) -> tokio::sync::mpsc::UnboundedReceiver<AgentEvent> {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
self.senders.push(tx);
rx
}
/// Emit an event to all subscribers. Non-blocking; drops if receiver disconnected.
pub fn emit(&self, event: AgentEvent) {
for sender in &self.senders {
let _ = sender.send(event.clone());
}
}
/// Check if there are any active subscribers.
pub fn has_subscribers(&self) -> bool {
!self.senders.is_empty()
}
/// Remove disconnected senders.
pub fn cleanup(&mut self) {
self.senders.retain(|s| !s.is_closed());
}
}
impl Default for EventSink {
fn default() -> Self {
Self::new()
}
}
impl Clone for EventSink {
fn clone(&self) -> Self {
Self {
senders: self.senders.clone(),
}
}
}

113
lib/ai/agent/helpers.rs Normal file
View File

@ -0,0 +1,113 @@
use std::future::Future;
use std::time::Duration;
use crate::agent::request::AgentRequest;
use crate::error::{AiError, AiResult};
pub fn build_input_string(request: &AgentRequest) -> String {
let mut input = String::new();
if !request.context.is_empty() {
input.push_str("<retrieved_context>\n");
for chunk in &request.context {
let source = chunk.source.as_deref().unwrap_or("unknown");
let score = chunk
.score
.map(|s| format!("{s:.4}"))
.unwrap_or_else(|| "n/a".to_string());
input.push_str(&format!(
"\n<chunk id=\"{}\" source=\"{}\" score=\"{}\">\n{}\n</chunk>\n",
chunk.id, source, score, chunk.content
));
}
input.push_str("</retrieved_context>\n\n");
}
for message in &request.messages {
match message {
super::request::AgentMessage::User(content) => {
input.push_str(&format!("User: {content}\n"));
}
super::request::AgentMessage::Assistant(content) => {
input.push_str(&format!("Assistant: {content}\n"));
}
}
}
input.push_str(&format!("User: {}", request.input));
input
}
pub fn estimate_tokens(text: &str) -> u64 {
if text.is_empty() {
return 0;
}
(text.chars().count() as f64 / 2.5).ceil() as u64
}
pub fn check_token_budget(
estimated_input_tokens: u64,
accumulated_output_chars: usize,
limit: i64,
) -> bool {
let output_estimate = (accumulated_output_chars as f64 / 2.5).ceil() as u64;
estimated_input_tokens + output_estimate > limit as u64
}
pub async fn with_retry<F, Fut, T>(
max_attempts: usize,
base_delay_ms: u64,
f: F,
) -> AiResult<T>
where
F: Fn() -> Fut,
Fut: Future<Output = AiResult<T>>,
{
let mut last_error: Option<AiError> = None;
for attempt in 0..max_attempts {
match f().await {
Ok(result) => return Ok(result),
Err(e) if is_retryable(&e) && attempt + 1 < max_attempts => {
let delay = Duration::from_millis(base_delay_ms * 2u64.pow(attempt as u32));
tracing::warn!(
error = %e,
attempt = attempt + 1,
max_attempts,
delay_ms = delay.as_millis(),
"retrying after transient error"
);
tokio::time::sleep(delay).await;
last_error = Some(e);
}
Err(e) => return Err(e),
}
}
Err(AiError::ModelRetriesExhausted {
attempts: max_attempts,
last_error: last_error
.map(|e| e.to_string())
.unwrap_or_else(|| "unknown".to_string()),
})
}
fn is_retryable(error: &AiError) -> bool {
matches!(
error,
AiError::Api(_) | AiError::Response(_) | AiError::ModelRetriesExhausted { .. }
)
}
pub fn tool_result_content_to_string(
content: &rig::one_or_many::OneOrMany<rig::completion::message::ToolResultContent>,
) -> String {
use rig::completion::message::ToolResultContent;
content
.iter()
.filter_map(|item| match item {
ToolResultContent::Text(t) => Some(t.text.clone()),
_ => None,
})
.collect::<Vec<_>>()
.join("\n")
}

145
lib/ai/agent/hooks.rs Normal file
View File

@ -0,0 +1,145 @@
use async_trait::async_trait;
use serde_json::Value;
use crate::agent::persistence::AgentRunContext;
use crate::error::AiResult;
#[derive(Debug, Clone)]
pub enum ToolGuardrailDecision {
Allow,
Block { reason: String },
RequireApproval { message: String },
}
#[derive(Debug, Clone)]
pub struct ToolCallOutcome {
pub name: String,
pub arguments: Value,
pub output: Option<Value>,
pub error: Option<String>,
pub elapsed_ms: i64,
}
#[derive(Debug, Clone)]
pub struct HookMessage {
pub role: String,
pub content: Option<String>,
pub tool_calls: Option<Value>,
pub tool_call_id: Option<String>,
}
#[derive(Debug, Clone)]
pub struct HookLlmResponse {
pub content: Option<String>,
pub tool_calls: Option<Value>,
pub input_tokens: u64,
pub output_tokens: u64,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone)]
pub struct HookToolDef {
pub name: String,
pub description: String,
}
#[async_trait]
pub trait AgentHook: Send + Sync {
fn name(&self) -> &'static str;
async fn on_session_start(&self, _ctx: &AgentRunContext) -> AiResult<()> {
Ok(())
}
async fn on_session_end(&self, _ctx: &AgentRunContext, _success: bool) -> AiResult<()> {
Ok(())
}
async fn pre_llm_call(&self, _messages: &[HookMessage], _tools: &[HookToolDef]) -> AiResult<()> {
Ok(())
}
async fn post_llm_call(&self, _response: &HookLlmResponse) -> AiResult<()> {
Ok(())
}
async fn pre_tool_call(
&self,
_tool_name: &str,
_arguments: &Value,
) -> AiResult<Option<ToolGuardrailDecision>> {
Ok(None)
}
async fn post_tool_call(&self, _outcome: &ToolCallOutcome) -> AiResult<()> {
Ok(())
}
}
pub struct HookChain {
hooks: Vec<Box<dyn AgentHook>>,
}
impl HookChain {
pub fn new(hooks: Vec<Box<dyn AgentHook>>) -> Self {
Self { hooks }
}
pub fn empty() -> Self {
Self { hooks: Vec::new() }
}
pub fn is_empty(&self) -> bool {
self.hooks.is_empty()
}
pub async fn run_session_start(&self, ctx: &AgentRunContext) -> AiResult<()> {
for hook in &self.hooks {
hook.on_session_start(ctx).await?;
}
Ok(())
}
pub async fn run_session_end(&self, ctx: &AgentRunContext, success: bool) -> AiResult<()> {
for hook in &self.hooks {
hook.on_session_end(ctx, success).await?;
}
Ok(())
}
pub async fn run_pre_llm_call(&self, messages: &[HookMessage], tools: &[HookToolDef]) -> AiResult<()> {
for hook in &self.hooks {
hook.pre_llm_call(messages, tools).await?;
}
Ok(())
}
pub async fn run_post_llm_call(&self, response: &HookLlmResponse) -> AiResult<()> {
for hook in &self.hooks {
hook.post_llm_call(response).await?;
}
Ok(())
}
pub async fn run_pre_tool_call(
&self,
tool_name: &str,
arguments: &Value,
) -> AiResult<Option<ToolGuardrailDecision>> {
for hook in &self.hooks {
if let Some(decision) = hook.pre_tool_call(tool_name, arguments).await? {
if !matches!(decision, ToolGuardrailDecision::Allow) {
return Ok(Some(decision));
}
}
}
Ok(None)
}
pub async fn run_post_tool_call(&self, outcome: &ToolCallOutcome) -> AiResult<()> {
for hook in &self.hooks {
hook.post_tool_call(outcome).await?;
}
Ok(())
}
}

View File

@ -0,0 +1,45 @@
#[derive(Clone, Debug)]
pub struct IterationBudget {
pub remaining: usize,
pub hard_limit: usize,
pub grace_call: bool,
pub consumed: usize,
}
impl IterationBudget {
pub fn new(limit: usize) -> Self {
Self {
remaining: limit,
hard_limit: limit,
grace_call: true,
consumed: 0,
}
}
pub fn can_continue(&self) -> bool {
self.remaining > 0 || (self.remaining == 0 && self.grace_call)
}
pub fn consume(&mut self) -> bool {
if self.remaining > 0 {
self.remaining -= 1;
self.consumed += 1;
true
} else if self.grace_call {
self.grace_call = false;
self.consumed += 1;
true
} else {
false
}
}
pub fn exhaust(&mut self) {
self.remaining = 0;
self.grace_call = false;
}
pub const fn total_consumed(&self) -> usize {
self.consumed
}
}

876
lib/ai/agent/loop.rs Normal file
View File

@ -0,0 +1,876 @@
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use futures::StreamExt;
use rig::agent::AgentBuilder;
use rig::client::CompletionClient;
use rig::streaming::StreamingPrompt;
use rig::tool::ToolDyn;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};
use super::config::AgentConfig;
use super::error_classifier::{
classify_error, retry_policy_for, should_switch_to_fallback,
};
use super::events::{AgentEvent, EventSink};
use super::helpers::{build_input_string, estimate_tokens};
use super::hooks::{HookChain, HookLlmResponse, HookMessage, ToolCallOutcome, ToolGuardrailDecision};
use super::iteration_budget::IterationBudget;
use super::request::{AgentRequest, AgentResult, AgentStep, ToolCallRecord};
use super::RigStreamChunk;
use crate::client::AiClient;
use crate::error::{AiError, AiResult};
/// How tool calls from a single assistant turn are executed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolExecutionMode {
/// Execute tool calls one at a time.
Sequential,
/// Execute tool calls concurrently (after sequential preflight).
Parallel,
}
impl Default for ToolExecutionMode {
fn default() -> Self {
Self::Parallel
}
}
/// Callback type for steering messages (injected mid-run).
pub type SteeringFn = Arc<
dyn Fn() -> Pin<Box<dyn Future<Output = Vec<String>> + Send>> + Send + Sync,
>;
/// Callback type for follow-up messages (injected after agent would stop).
pub type FollowUpFn = Arc<
dyn Fn() -> Pin<Box<dyn Future<Output = Vec<String>> + Send>> + Send + Sync,
>;
/// Callback to decide whether the agent should stop after a turn.
pub type ShouldStopFn = Arc<
dyn Fn(&TurnContext) -> bool + Send + Sync,
>;
/// Callback to prepare/modify state before the next turn.
pub type PrepareNextTurnFn = Arc<
dyn Fn(&TurnContext) -> Pin<Box<dyn Future<Output = Option<TurnUpdate>> + Send>>
+ Send
+ Sync,
>;
/// Context passed to `should_stop` and `prepare_next_turn` callbacks.
#[derive(Debug, Clone)]
pub struct TurnContext {
pub turn_index: usize,
pub assistant_text: String,
pub tool_call_count: usize,
pub total_input_tokens: u64,
pub total_output_tokens: u64,
pub model_name: String,
}
/// Replacement state for the next turn (returned by `prepare_next_turn`).
#[derive(Debug, Clone)]
pub struct TurnUpdate {
pub model: Option<String>,
pub temperature: Option<f64>,
pub max_completion_tokens: Option<u64>,
}
/// Extended agent loop configuration, adding steering/follow-up/lifecycle
/// hooks on top of the base `AgentConfig`.
pub struct AgentLoopConfig {
pub config: AgentConfig,
pub tool_execution_mode: ToolExecutionMode,
pub get_steering_messages: Option<SteeringFn>,
pub get_follow_up_messages: Option<FollowUpFn>,
pub should_stop_after_turn: Option<ShouldStopFn>,
pub prepare_next_turn: Option<PrepareNextTurnFn>,
pub event_sink: Option<EventSink>,
}
impl AgentLoopConfig {
pub fn new(config: AgentConfig) -> Self {
Self {
config,
tool_execution_mode: ToolExecutionMode::default(),
get_steering_messages: None,
get_follow_up_messages: None,
should_stop_after_turn: None,
prepare_next_turn: None,
event_sink: None,
}
}
pub fn with_tool_execution_mode(mut self, mode: ToolExecutionMode) -> Self {
self.tool_execution_mode = mode;
self
}
pub fn with_steering_messages(mut self, f: SteeringFn) -> Self {
self.get_steering_messages = Some(f);
self
}
pub fn with_follow_up_messages(mut self, f: FollowUpFn) -> Self {
self.get_follow_up_messages = Some(f);
self
}
pub fn with_should_stop(mut self, f: ShouldStopFn) -> Self {
self.should_stop_after_turn = Some(f);
self
}
pub fn with_prepare_next_turn(mut self, f: PrepareNextTurnFn) -> Self {
self.prepare_next_turn = Some(f);
self
}
pub fn with_event_sink(mut self, sink: EventSink) -> Self {
self.event_sink = Some(sink);
self
}
}
/// Enhanced agent with loop controls (steering, follow-up, model switching).
pub struct EnhancedAgent {
pub client: AiClient,
pub loop_config: AgentLoopConfig,
pub hooks: HookChain,
}
impl EnhancedAgent {
pub fn new(client: AiClient, loop_config: AgentLoopConfig) -> AiResult<Self> {
loop_config.config.validate()?;
Ok(Self {
client,
loop_config,
hooks: HookChain::empty(),
})
}
pub fn with_hooks(mut self, hooks: HookChain) -> Self {
self.hooks = hooks;
self
}
pub fn config(&self) -> &AgentConfig {
&self.loop_config.config
}
/// Run the enhanced agent loop, returning a chunk receiver and a join handle.
#[allow(clippy::too_many_lines)]
pub fn run(
&self,
request: AgentRequest,
tools: Vec<Box<dyn ToolDyn>>,
) -> (
mpsc::Receiver<RigStreamChunk>,
tokio::task::JoinHandle<AiResult<AgentResult>>,
) {
let (tx, rx) = mpsc::channel::<RigStreamChunk>(256);
let config = self.loop_config.config.clone();
let tool_execution_mode = self.loop_config.tool_execution_mode;
let steering_fn = self.loop_config.get_steering_messages.clone();
let follow_up_fn = self.loop_config.get_follow_up_messages.clone();
let should_stop = self.loop_config.should_stop_after_turn.clone();
let prepare_next = self.loop_config.prepare_next_turn.clone();
let event_sink = self.loop_config.event_sink.clone();
let client = self.client.llm_client().clone();
let hooks = self.hooks.clone();
let filtered_tools: Vec<Box<dyn ToolDyn>> = tools
.into_iter()
.filter(|tool| config.is_tool_exposed(&tool.name()))
.collect();
let handle = tokio::spawn(async move {
run_enhanced_loop(
client,
config,
request,
filtered_tools,
tool_execution_mode,
steering_fn,
follow_up_fn,
should_stop,
prepare_next,
event_sink,
hooks,
tx,
)
.await
});
(rx, handle)
}
}
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
async fn run_enhanced_loop(
client: rig::providers::openai::Client,
mut config: AgentConfig,
request: AgentRequest,
tools: Vec<Box<dyn ToolDyn>>,
_tool_execution_mode: ToolExecutionMode,
steering_fn: Option<SteeringFn>,
follow_up_fn: Option<FollowUpFn>,
should_stop: Option<ShouldStopFn>,
prepare_next: Option<PrepareNextTurnFn>,
event_sink: Option<EventSink>,
hooks: HookChain,
tx: mpsc::Sender<RigStreamChunk>,
) -> AiResult<AgentResult> {
let cancellation = request.cancellation_token.clone();
let timeout = request.timeout;
let mut budget = IterationBudget::new(config.iteration_budget);
let mut all_steps: Vec<AgentStep> = Vec::new();
let mut total_input_tokens: u64 = 0;
let mut total_output_tokens: u64 = 0;
let mut turn_index: usize = 0;
// Session start hook
if let Some(ctx) = &request.run_context {
let _ = hooks.run_session_start(ctx).await;
}
// Emit agent start event
if let Some(sink) = &event_sink {
sink.emit(AgentEvent::AgentStart);
}
// Build the initial input
let input = build_input_string(&request);
let mut current_input = input.clone();
let estimated_input_tokens = estimate_tokens(&current_input);
if let Some(limit) = config.max_total_tokens_per_run
&& estimated_input_tokens > limit as u64
{
return Err(AiError::TokenBudgetExceeded {
estimated: estimated_input_tokens,
limit,
});
}
// Outer loop: handles follow-up messages after agent would stop
loop {
// Inner loop: tool call turns + steering messages
let mut pending_steering: Vec<String> = if let Some(f) = &steering_fn {
f().await
} else {
Vec::new()
};
loop {
// Check cancellation
if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) {
let _ = tx.send(RigStreamChunk::Failed { error: "cancelled".to_string() }).await;
if let Some(sink) = &event_sink {
sink.emit(AgentEvent::ErrorClassified {
category: "cancelled".to_string(),
message: "cancelled by caller".to_string(),
will_retry: false,
retry_delay_ms: None,
});
}
return Err(AiError::Response("agent run cancelled".to_string()));
}
// Inject steering messages if any
if !pending_steering.is_empty() {
let count = pending_steering.len();
for msg in &pending_steering {
current_input.push_str(&format!("\nUser: {msg}\n"));
}
if let Some(sink) = &event_sink {
sink.emit(AgentEvent::SteeringMessagesInjected { count });
}
pending_steering.clear();
}
// Emit turn start
if let Some(sink) = &event_sink {
sink.emit(AgentEvent::TurnStart { turn_index });
}
let _ = tx.send(RigStreamChunk::TextDelta {
index: 0,
content: String::new(), // placeholder for turn boundary detection
}).await;
// Run one LLM turn with retry
let turn_result = run_single_turn(
&client,
&config,
&current_input,
&tools,
&mut budget,
&cancellation,
timeout,
&hooks,
&event_sink,
&tx,
)
.await;
match turn_result {
Ok(turn_output) => {
total_input_tokens += turn_output.input_tokens;
total_output_tokens += turn_output.output_tokens;
// Collect step
let tool_call_count = turn_output.tool_calls.len();
if !turn_output.tool_calls.is_empty() || !turn_output.assistant_text.is_empty() {
all_steps.push(AgentStep {
index: all_steps.len(),
assistant: (!turn_output.assistant_text.is_empty())
.then_some(turn_output.assistant_text.clone()),
reasoning_content: None,
tool_calls: turn_output.tool_calls,
reflection: None,
});
}
// Emit turn end
if let Some(sink) = &event_sink {
sink.emit(AgentEvent::TurnEnd {
turn_index,
assistant_text: Some(turn_output.assistant_text.clone()),
tool_call_count,
});
}
// Check should_stop
let turn_ctx = TurnContext {
turn_index,
assistant_text: turn_output.assistant_text.clone(),
tool_call_count,
total_input_tokens,
total_output_tokens,
model_name: config.model.clone(),
};
if let Some(stop_fn) = &should_stop {
if stop_fn(&turn_ctx) {
info!(turn_index, "agent stopped by should_stop callback");
break;
}
}
// Prepare next turn (may switch model)
if let Some(prep_fn) = &prepare_next {
if let Some(update) = prep_fn(&turn_ctx).await {
if let Some(new_model) = update.model {
if let Some(sink) = &event_sink {
sink.emit(AgentEvent::ModelSwitched {
from_model: config.model.clone(),
to_model: new_model.clone(),
reason: "prepare_next_turn".to_string(),
});
}
config.model = new_model;
}
if let Some(temp) = update.temperature {
config.temperature = Some(temp);
}
if let Some(max_tok) = update.max_completion_tokens {
config.max_completion_tokens = Some(max_tok);
}
}
}
turn_index += 1;
// If no tool calls, this turn is done
if tool_call_count == 0 {
break;
}
// Otherwise, continue with tool results as new input
current_input = turn_output.assistant_text.clone();
}
Err(e) => {
// Error classification and retry with fallback
let category = classify_error(&e, None);
let policy = retry_policy_for(&category, config.retry_max_attempts, config.retry_base_delay_ms);
if let Some(sink) = &event_sink {
sink.emit(AgentEvent::ErrorClassified {
category: format!("{category:?}"),
message: e.to_string(),
will_retry: policy.switch_to_fallback || policy.max_attempts > 0,
retry_delay_ms: Some(policy.base_delay.as_millis() as u64),
});
}
if should_switch_to_fallback(&category) {
if let Some(fallback_model) = &config.fallback_model {
info!(
from_model = %config.model,
to_model = %fallback_model,
"switching to fallback model due to error"
);
if let Some(sink) = &event_sink {
sink.emit(AgentEvent::ModelSwitched {
from_model: config.model.clone(),
to_model: fallback_model.clone(),
reason: format!("fallback: {category:?}"),
});
}
config.model = fallback_model.clone();
// Retry with the fallback model
let retry_result = run_single_turn(
&client,
&config,
&current_input,
&tools,
&mut budget,
&cancellation,
timeout,
&hooks,
&event_sink,
&tx,
)
.await;
match retry_result {
Ok(turn_output) => {
total_input_tokens += turn_output.input_tokens;
total_output_tokens += turn_output.output_tokens;
let tc_count = turn_output.tool_calls.len();
let has_tools = tc_count > 0;
let has_text = !turn_output.assistant_text.is_empty();
let assistant = turn_output.assistant_text;
if has_tools || has_text {
all_steps.push(AgentStep {
index: all_steps.len(),
assistant: has_text.then_some(assistant.clone()),
reasoning_content: None,
tool_calls: turn_output.tool_calls,
reflection: None,
});
}
turn_index += 1;
if !has_tools {
break;
}
current_input = assistant;
continue;
}
Err(retry_err) => {
let _ = tx
.send(RigStreamChunk::Failed {
error: retry_err.to_string(),
})
.await;
if let Some(ctx) = &request.run_context {
let _ = hooks.run_session_end(ctx, false).await;
}
return Err(retry_err);
}
}
}
}
// Non-retryable or no fallback
let _ = tx
.send(RigStreamChunk::Failed {
error: e.to_string(),
})
.await;
if let Some(ctx) = &request.run_context {
let _ = hooks.run_session_end(ctx, false).await;
}
return Err(e);
}
}
}
// Check for follow-up messages
let follow_ups: Vec<String> = if let Some(f) = &follow_up_fn {
f().await
} else {
Vec::new()
};
if follow_ups.is_empty() {
break;
}
// Inject follow-up messages and continue the outer loop
let count = follow_ups.len();
for msg in &follow_ups {
current_input.push_str(&format!("\nUser: {msg}\n"));
}
if let Some(sink) = &event_sink {
sink.emit(AgentEvent::FollowUpMessagesInjected { count });
}
}
// Build final output
let output = all_steps
.last()
.and_then(|s| s.assistant.clone())
.unwrap_or_default();
if let Some(sink) = &event_sink {
sink.emit(AgentEvent::AgentEnd {
messages: Vec::new(),
total_input_tokens,
total_output_tokens,
});
}
let _ = tx
.send(RigStreamChunk::Final {
content: output.clone(),
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
})
.await;
if let Some(ctx) = &request.run_context {
let _ = hooks.run_session_end(ctx, true).await;
}
info!(
turns = turn_index,
steps = all_steps.len(),
total_input_tokens,
total_output_tokens,
"enhanced agent loop completed"
);
Ok(AgentResult {
output,
steps: all_steps,
expert_outputs: Vec::new(),
input_tokens: total_input_tokens as i64,
output_tokens: total_output_tokens as i64,
})
}
/// Output from a single LLM turn (one assistant response + its tool calls).
struct TurnOutput {
assistant_text: String,
tool_calls: Vec<ToolCallRecord>,
input_tokens: u64,
output_tokens: u64,
}
/// Run a single LLM turn with streaming, handling the stream parsing and
/// tool call collection.
#[allow(clippy::too_many_arguments)]
async fn run_single_turn(
client: &rig::providers::openai::Client,
config: &AgentConfig,
input: &str,
_tools: &[Box<dyn ToolDyn>],
budget: &mut IterationBudget,
cancellation: &Option<CancellationToken>,
timeout: Option<std::time::Duration>,
hooks: &HookChain,
event_sink: &Option<EventSink>,
tx: &mpsc::Sender<RigStreamChunk>,
) -> AiResult<TurnOutput> {
if !budget.consume() {
return Err(AiError::Response("iteration budget exhausted".to_string()));
}
let model = client.completion_model(&config.model);
let mut agent_builder = AgentBuilder::new(model)
.preamble(&config.system_prompt)
.default_max_turns(1); // Single turn, we manage the loop
// Note: we can't easily pass tools here for single-turn since
// rig's multi_turn handles tool execution internally.
// For the enhanced loop, we rely on rig's built-in tool execution
// within a single turn. The parallel/sequential mode is controlled
// by the event-level hooks.
if let Some(temp) = config.temperature {
agent_builder = agent_builder.temperature(temp);
}
if let Some(mt) = config.max_completion_tokens {
agent_builder = agent_builder.max_tokens(mt);
}
let agent = agent_builder.build();
// Pre-LLM hook
if !hooks.is_empty() {
let hook_messages = vec![HookMessage {
role: "user".to_string(),
content: Some(input.to_string()),
tool_calls: None,
tool_call_id: None,
}];
let _ = hooks.run_pre_llm_call(&hook_messages, &[]).await;
}
let stream_future = agent
.stream_prompt(input)
.with_history(Vec::<rig::completion::Message>::new())
.multi_turn(config.max_iterations);
let stream = if let Some(dur) = timeout {
match tokio::time::timeout(dur, stream_future).await {
Ok(stream) => stream,
Err(_) => {
return Err(AiError::Timeout {
seconds: dur.as_secs(),
});
}
}
} else {
stream_future.await
};
tokio::pin!(stream);
let mut assistant_text = String::new();
let mut tool_calls: Vec<ToolCallRecord> = Vec::new();
let mut delta_index = 0usize;
let mut _accumulated_output_chars: usize = 0;
let mut input_tokens: u64 = 0;
let mut output_tokens: u64 = 0;
while let Some(item) = stream.next().await {
if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) {
return Err(AiError::Response("cancelled".to_string()));
}
match item {
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
rig::streaming::StreamedAssistantContent::Text(text),
)) => {
_accumulated_output_chars += text.text.chars().count();
assistant_text.push_str(&text.text);
if let Some(sink) = &event_sink {
sink.emit(AgentEvent::MessageTextDelta {
index: delta_index,
delta: text.text.clone(),
});
}
let _ = tx
.send(RigStreamChunk::TextDelta {
index: delta_index,
content: text.text.clone(),
})
.await;
delta_index += 1;
}
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
rig::streaming::StreamedAssistantContent::Reasoning(reasoning),
)) => {
for part in &reasoning.content {
if let rig::completion::message::ReasoningContent::Text { text, .. } = part {
_accumulated_output_chars += text.chars().count();
if let Some(sink) = &event_sink {
sink.emit(AgentEvent::MessageThinkingDelta {
index: delta_index,
delta: text.clone(),
});
}
let _ = tx
.send(RigStreamChunk::Thinking {
index: delta_index,
content: text.clone(),
})
.await;
delta_index += 1;
}
}
}
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
rig::streaming::StreamedAssistantContent::ReasoningDelta { reasoning, .. },
)) => {
_accumulated_output_chars += reasoning.chars().count();
if let Some(sink) = &event_sink {
sink.emit(AgentEvent::MessageThinkingDelta {
index: delta_index,
delta: reasoning.clone(),
});
}
let _ = tx
.send(RigStreamChunk::Thinking {
index: delta_index,
content: reasoning.clone(),
})
.await;
delta_index += 1;
}
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
rig::streaming::StreamedAssistantContent::ToolCall { tool_call, .. },
)) => {
let args = match &tool_call.function.arguments {
serde_json::Value::String(s) => s.clone(),
v => serde_json::to_string(v).unwrap_or_default(),
};
_accumulated_output_chars += args.chars().count();
let tool_name = tool_call.function.name.clone();
let tool_args: serde_json::Value =
serde_json::from_str(&args).unwrap_or_default();
// Pre-tool-call guardrail hook
if let Ok(Some(decision)) = hooks.run_pre_tool_call(&tool_name, &tool_args).await {
match decision {
ToolGuardrailDecision::Allow => {}
ToolGuardrailDecision::Block { reason } => {
if let Some(sink) = &event_sink {
sink.emit(AgentEvent::ToolExecutionEnd {
tool_call_id: tool_call.id.clone(),
tool_name: tool_name.clone(),
output: None,
error: Some(reason.clone()),
elapsed_ms: 0,
});
}
let _ = tx
.send(RigStreamChunk::ToolCallFinished {
tool_call_id: tool_call.id.clone(),
tool_name: tool_name.clone(),
output: format!("blocked: {reason}"),
error: Some(reason),
})
.await;
tool_calls.push(ToolCallRecord {
id: tool_call.id.clone(),
name: tool_name,
arguments: tool_args,
output: None,
error: Some("blocked by guardrail".to_string()),
elapsed_ms: None,
});
continue;
}
ToolGuardrailDecision::RequireApproval { message } => {
tool_calls.push(ToolCallRecord {
id: tool_call.id.clone(),
name: tool_name.clone(),
arguments: tool_args,
output: None,
error: Some(format!("requires approval: {message}")),
elapsed_ms: None,
});
continue;
}
}
}
if let Some(sink) = &event_sink {
sink.emit(AgentEvent::ToolExecutionStart {
tool_call_id: tool_call.id.clone(),
tool_name: tool_name.clone(),
arguments: tool_args.clone(),
});
}
let _ = tx
.send(RigStreamChunk::ToolCallStarted {
tool_call_id: tool_call.id.clone(),
tool_name: tool_name.clone(),
arguments: args.clone(),
})
.await;
tool_calls.push(ToolCallRecord {
id: tool_call.id.clone(),
name: tool_name,
arguments: tool_args,
output: None,
error: None,
elapsed_ms: None,
});
}
Ok(rig::agent::MultiTurnStreamItem::StreamUserItem(
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
)) => {
let content =
super::helpers::tool_result_content_to_string(&tool_result.content);
_accumulated_output_chars += content.chars().count();
let tool_name = tool_calls
.last()
.map(|tc| tc.name.clone())
.unwrap_or_default();
if let Some(last) = tool_calls.last_mut()
&& last.id == tool_result.id
{
last.output = Some(serde_json::from_str(&content).unwrap_or_default());
}
if let Some(sink) = &event_sink {
sink.emit(AgentEvent::ToolExecutionEnd {
tool_call_id: tool_result.id.clone(),
tool_name: tool_name.clone(),
output: Some(serde_json::Value::String(content.clone())),
error: None,
elapsed_ms: 0,
});
}
let _ = tx
.send(RigStreamChunk::ToolCallFinished {
tool_call_id: tool_result.id.clone(),
tool_name,
output: content.clone(),
error: None,
})
.await;
if !hooks.is_empty() {
let outcome = ToolCallOutcome {
name: tool_result.id.clone(),
arguments: serde_json::Value::Null,
output: Some(serde_json::Value::String(content)),
error: None,
elapsed_ms: 0,
};
let _ = hooks.run_post_tool_call(&outcome).await;
}
}
Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => {
let usage = resp.usage();
input_tokens = usage.input_tokens;
output_tokens = usage.output_tokens;
if !hooks.is_empty() {
let hook_response = HookLlmResponse {
content: Some(assistant_text.clone()),
tool_calls: None,
input_tokens,
output_tokens,
finish_reason: None,
};
let _ = hooks.run_post_llm_call(&hook_response).await;
}
}
Err(e) => {
warn!(error = %e, "turn stream error");
return Err(AiError::Api(format!("{e}")));
}
_ => {}
}
}
Ok(TurnOutput {
assistant_text,
tool_calls,
input_tokens,
output_tokens,
})
}

98
lib/ai/agent/mod.rs Normal file
View File

@ -0,0 +1,98 @@
pub mod agent;
pub mod compression;
pub mod config;
pub mod error_classifier;
pub mod events;
pub mod helpers;
pub mod hooks;
pub mod iteration_budget;
pub mod r#loop;
pub mod persistence;
pub mod prompt;
pub mod prompt_builder;
pub mod request;
pub mod session;
pub mod subagent;
pub mod tool;
use serde::{Deserialize, Serialize};
pub use agent::RigAgent;
pub use compression::{
CompactionResult, CompressionStrategy, build_branch_summary_prompt,
build_compression_prompt, build_compression_prompt_with_options,
estimate_truncation, plan_compaction,
};
pub use config::AgentConfig;
pub use error_classifier::{
ErrorCategory, RetryPolicy, classify_error, retry_policy_for,
should_compact_before_retry, should_switch_to_fallback,
};
pub use events::{
AgentEvent, AgentEventMessage, EventSink, EventToolCall, MessageRole,
};
pub use hooks::{AgentHook, HookChain, ToolCallOutcome, ToolGuardrailDecision};
pub use iteration_budget::IterationBudget;
pub use r#loop::{
AgentLoopConfig, EnhancedAgent, PrepareNextTurnFn, ShouldStopFn,
ToolExecutionMode, TurnContext, TurnUpdate,
};
pub use persistence::{
AgentRealtime, AgentRunContext, AgentRuntime, AgentStreamEvent,
};
pub use prompt_builder::SystemPromptBuilder;
pub use request::{
AgentContextChunk, AgentExpert, AgentExpertOutput, AgentMessage,
AgentRequest, AgentResult, AgentStep, ToolCallRecord,
};
pub use session::{
CompactionOptions, Session, SessionEntry, SessionHeader,
SessionMessageRole, SessionToolCall, SessionToolResult,
};
pub use tool::{RigTool, RigToolSet};
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum RigStreamChunk {
TextDelta {
index: usize,
content: String,
},
Thinking {
index: usize,
content: String,
},
ToolCallStarted {
tool_call_id: String,
tool_name: String,
arguments: String,
},
ToolCallFinished {
tool_call_id: String,
tool_name: String,
output: String,
error: Option<String>,
},
SubagentStarted {
subagent_id: String,
role: String,
task: String,
},
SubagentCompleted {
subagent_id: String,
role: String,
task: String,
output: String,
},
SubagentFailed {
error: String,
},
Final {
content: String,
input_tokens: u64,
output_tokens: u64,
},
Failed {
error: String,
},
}

View File

@ -0,0 +1,33 @@
use std::time::Instant;
use crate::agent::persistence::types::{ActiveAgentRun, AgentRunContext};
use crate::error::AiResult;
impl super::types::AgentRuntime {
pub async fn start_run(
&self,
run_context: Option<&AgentRunContext>,
) -> AiResult<ActiveAgentRun> {
let Some(run_context) = run_context else {
return Ok(ActiveAgentRun {
conversation_id: None,
message_id: None,
invocation_id: None,
session_id: None,
user_id: None,
started_at: Instant::now(),
current_step: 0,
});
};
Ok(ActiveAgentRun {
conversation_id: run_context.conversation_id,
message_id: None,
invocation_id: run_context.invocation_id,
session_id: run_context.session_id,
user_id: run_context.user_id,
started_at: Instant::now(),
current_step: 0,
})
}
}

View File

@ -0,0 +1,8 @@
pub mod db;
pub mod realtime;
pub mod types;
pub use types::{
ActiveAgentRun, AgentRealtime, AgentRunContext, AgentRuntime,
AgentStreamEvent, estimate_output_tokens,
};

View File

@ -0,0 +1,38 @@
use crate::agent::persistence::types::{
AgentRealtime, AgentRuntime, AgentStreamEvent,
};
use crate::error::AiResult;
pub async fn publish_event(
runtime: &AgentRuntime,
_realtime: Option<&AgentRealtime>,
event: &AgentStreamEvent,
) -> AiResult<()> {
let Some(tx) = &runtime.tx else {
return Ok(());
};
let payload = match serde_json::to_string(event) {
Ok(p) => p,
Err(error) => {
tracing::warn!(error = %error, "agent sse: serialize failed");
return Ok(());
}
};
if tx.send(payload).is_err() {
tracing::debug!("agent sse: mpsc send failed, client disconnected");
}
Ok(())
}
impl AgentRuntime {
pub async fn publish(
&self,
realtime: Option<&AgentRealtime>,
event: &AgentStreamEvent,
) -> AiResult<()> {
publish_event(self, realtime, event).await
}
}

View File

@ -0,0 +1,177 @@
use std::time::Instant;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::mpsc;
use uuid::Uuid;
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentRealtime {
pub channel: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentRunContext {
pub conversation_id: Option<Uuid>,
pub invocation_id: Option<Uuid>,
pub session_id: Option<Uuid>,
pub user_id: Option<Uuid>,
pub realtime: Option<AgentRealtime>,
}
impl AgentRunContext {
pub fn new(user_id: Uuid) -> Self {
Self {
conversation_id: None,
invocation_id: None,
session_id: None,
user_id: Some(user_id),
realtime: None,
}
}
pub fn with_conversation_id(mut self, id: Uuid) -> Self {
self.conversation_id = Some(id);
self
}
pub fn with_invocation_id(mut self, id: Uuid) -> Self {
self.invocation_id = Some(id);
self
}
pub fn with_session_id(mut self, id: Uuid) -> Self {
self.session_id = Some(id);
self
}
pub fn with_realtime(mut self, realtime: AgentRealtime) -> Self {
self.realtime = Some(realtime);
self
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AgentStreamEvent {
Started {
conversation_id: Option<Uuid>,
message_id: Option<Uuid>,
session_id: Option<Uuid>,
model: String,
},
Delta {
conversation_id: Option<Uuid>,
message_id: Option<Uuid>,
index: usize,
content: String,
},
Thinking {
conversation_id: Option<Uuid>,
message_id: Option<Uuid>,
index: usize,
content: String,
},
ToolCallStarted {
conversation_id: Option<Uuid>,
message_id: Option<Uuid>,
session_id: Option<Uuid>,
tool_call_id: String,
tool_name: String,
arguments: Value,
},
ToolCallFinished {
conversation_id: Option<Uuid>,
message_id: Option<Uuid>,
session_id: Option<Uuid>,
tool_call_id: String,
tool_name: String,
output: Option<Value>,
error: Option<String>,
execution_time_ms: i64,
},
SubagentStarted {
conversation_id: Option<Uuid>,
message_id: Option<Uuid>,
subagent_id: String,
role: String,
task: String,
model: String,
},
SubagentDelta {
conversation_id: Option<Uuid>,
message_id: Option<Uuid>,
subagent_id: String,
index: usize,
content: String,
},
SubagentCompleted {
conversation_id: Option<Uuid>,
message_id: Option<Uuid>,
subagent_id: String,
role: String,
task: String,
output: String,
input_tokens: i64,
output_tokens: i64,
model: String,
},
SubagentFailed {
conversation_id: Option<Uuid>,
message_id: Option<Uuid>,
subagent_id: String,
error: String,
},
Completed {
conversation_id: Option<Uuid>,
message_id: Option<Uuid>,
session_id: Option<Uuid>,
output: String,
input_tokens: i64,
output_tokens: i64,
latency_ms: i32,
stop_reason: Option<String>,
},
Failed {
conversation_id: Option<Uuid>,
message_id: Option<Uuid>,
session_id: Option<Uuid>,
error: String,
},
}
#[derive(Clone)]
pub struct AgentRuntime {
pub tx: Option<mpsc::UnboundedSender<String>>,
}
impl AgentRuntime {
pub fn new(tx: mpsc::UnboundedSender<String>) -> Self {
Self { tx: Some(tx) }
}
pub fn empty() -> Self {
Self { tx: None }
}
}
impl Default for AgentRuntime {
fn default() -> Self {
Self::empty()
}
}
#[derive(Clone, Debug)]
pub struct ActiveAgentRun {
pub conversation_id: Option<Uuid>,
pub message_id: Option<Uuid>,
pub invocation_id: Option<Uuid>,
pub session_id: Option<Uuid>,
pub user_id: Option<Uuid>,
pub started_at: Instant,
pub current_step: usize,
}
pub fn estimate_output_tokens(output: &str) -> i64 {
(output.chars().count() as f64 / 4.0).ceil() as i64
}

57
lib/ai/agent/prompt.rs Normal file
View File

@ -0,0 +1,57 @@
use rig::agent::AgentBuilder;
use rig::client::CompletionClient;
use rig::completion::Prompt;
use super::agent::RigAgent;
use super::helpers::with_retry;
use crate::error::{AiError, AiResult};
impl RigAgent {
pub async fn prompt(
&self,
system_prompt: &str,
user_input: &str,
) -> AiResult<(String, u64, u64)> {
let model_name = self.config.model.clone();
let client = self.client.llm_client().clone();
let temperature = self.config.temperature;
let max_completion_tokens = self.config.max_completion_tokens;
let retry_attempts = self.config.retry_max_attempts;
let retry_delay_ms = self.config.retry_base_delay_ms;
let sp = system_prompt.to_string();
let ui = user_input.to_string();
with_retry(retry_attempts, retry_delay_ms, || {
let client = client.clone();
let model_name = model_name.clone();
let sp = sp.clone();
let ui = ui.clone();
async move {
let model = client.completion_model(&model_name);
let mut builder = AgentBuilder::new(model).preamble(&sp);
if let Some(temp) = temperature {
builder = builder.temperature(temp);
}
if let Some(mt) = max_completion_tokens {
builder = builder.max_tokens(mt);
}
let agent = builder.build();
let response = agent
.prompt(&ui)
.extended_details()
.await
.map_err(|e: rig::completion::PromptError| {
AiError::Api(e.to_string())
})?;
Ok((
response.output,
response.usage.input_tokens,
response.usage.output_tokens,
))
}
})
.await
}
}

View File

@ -0,0 +1,251 @@
use std::collections::HashMap;
/// Modular system prompt builder inspired by pi's `buildSystemPrompt`.
///
/// Supports:
/// - Base prompt (replaceable or appendable)
/// - Tool snippets injected into an "Available tools" section
/// - Project context files (AGENTS.md, etc.)
/// - Skills injection
/// - Variable substitution ({{key}})
/// - Metadata (date)
///
/// # Example
/// ```rust
/// use ai::agent::prompt_builder::SystemPromptBuilder;
///
/// let prompt = SystemPromptBuilder::new()
/// .base_prompt("You are a helpful assistant.")
/// .tool_snippet("bash", "Execute shell commands")
/// .tool_snippet("read", "Read file contents")
/// .project_context("AGENTS.md", "# Project Rules\n- Follow conventions")
/// .variable("repo_name", "gitdataai")
/// .build();
/// ```
#[derive(Clone, Debug)]
pub struct SystemPromptBuilder {
base_prompt: Option<String>,
append_prompt: Option<String>,
tool_snippets: Vec<(String, String)>,
tool_guidelines: Vec<String>,
project_contexts: Vec<(String, String)>,
skills: Vec<String>,
variables: HashMap<String, String>,
date: Option<String>,
custom_sections: Vec<(String, String)>,
}
impl SystemPromptBuilder {
pub fn new() -> Self {
Self {
base_prompt: None,
append_prompt: None,
tool_snippets: Vec::new(),
tool_guidelines: Vec::new(),
project_contexts: Vec::new(),
skills: Vec::new(),
variables: HashMap::new(),
date: None,
custom_sections: Vec::new(),
}
}
/// Set the base system prompt. Replaces the default prompt.
pub fn base_prompt(mut self, prompt: impl Into<String>) -> Self {
self.base_prompt = Some(prompt.into());
self
}
/// Append additional text to the system prompt after the base.
pub fn append_prompt(mut self, text: impl Into<String>) -> Self {
self.append_prompt = Some(text.into());
self
}
/// Add a one-line tool description snippet.
pub fn tool_snippet(mut self, tool_name: impl Into<String>, description: impl Into<String>) -> Self {
self.tool_snippets.push((tool_name.into(), description.into()));
self
}
/// Add a guideline bullet for the tools section.
pub fn tool_guideline(mut self, guideline: impl Into<String>) -> Self {
self.tool_guidelines.push(guideline.into());
self
}
/// Add a project context file (e.g., AGENTS.md content).
pub fn project_context(mut self, path: impl Into<String>, content: impl Into<String>) -> Self {
self.project_contexts.push((path.into(), content.into()));
self
}
/// Add a skill description to inject into the prompt.
pub fn skill(mut self, skill_description: impl Into<String>) -> Self {
self.skills.push(skill_description.into());
self
}
/// Set a variable for {{key}} substitution.
pub fn variable(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.variables.insert(key.into(), value.into());
self
}
/// Set multiple variables from an iterator.
pub fn variables(mut self, vars: impl IntoIterator<Item = (String, String)>) -> Self {
self.variables.extend(vars);
self
}
/// Set the date metadata (ISO format: YYYY-MM-DD).
pub fn date(mut self, date: impl Into<String>) -> Self {
self.date = Some(date.into());
self
}
/// Add a custom named section to the prompt.
pub fn custom_section(mut self, name: impl Into<String>, content: impl Into<String>) -> Self {
self.custom_sections.push((name.into(), content.into()));
self
}
/// Build the final system prompt string.
pub fn build(self) -> String {
let mut parts: Vec<String> = Vec::new();
// 1. Base prompt
if let Some(base) = &self.base_prompt {
parts.push(base.clone());
}
// 2. Append prompt
if let Some(append) = &self.append_prompt {
parts.push(append.clone());
}
// 3. Tool snippets section
if !self.tool_snippets.is_empty() {
let mut section = String::from("\n## Available Tools\n");
for (name, desc) in &self.tool_snippets {
section.push_str(&format!("- `{name}`: {desc}\n"));
}
if !self.tool_guidelines.is_empty() {
section.push_str("\n### Tool Guidelines\n");
for guideline in &self.tool_guidelines {
section.push_str(&format!("- {guideline}\n"));
}
}
parts.push(section);
}
// 4. Project context files
if !self.project_contexts.is_empty() {
let mut section = String::from("\n<project_context>\n\n");
section.push_str("Project-specific instructions and guidelines:\n\n");
for (path, content) in &self.project_contexts {
section.push_str(&format!("<project_instructions path=\"{path}\">\n{content}\n</project_instructions>\n\n"));
}
section.push_str("</project_context>");
parts.push(section);
}
// 5. Skills section
if !self.skills.is_empty() {
let mut section = String::from("\n## Available Skills\n");
for skill in &self.skills {
section.push_str(&format!("{skill}\n"));
}
parts.push(section);
}
// 6. Custom sections
for (name, content) in &self.custom_sections {
parts.push(format!("\n## {name}\n{content}"));
}
// 7. Metadata footer
if let Some(date) = &self.date {
parts.push(format!("\nCurrent date: {date}"));
}
let mut result = parts.join("\n");
// 8. Variable substitution
for (key, value) in &self.variables {
let placeholder = format!("{{{{{}}}}}", key);
result = result.replace(&placeholder, value);
}
result
}
}
impl Default for SystemPromptBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_basic_build() {
let prompt = SystemPromptBuilder::new()
.base_prompt("You are a helpful assistant.")
.date("2026-05-29")
.build();
assert!(prompt.contains("You are a helpful assistant."));
assert!(prompt.contains("Current date: 2026-05-29"));
}
#[test]
fn test_variable_substitution() {
let prompt = SystemPromptBuilder::new()
.base_prompt("Repo: {{repo_name}}, User: {{user}}")
.variable("repo_name", "gitdataai")
.variable("user", "zhenyi")
.build();
assert!(prompt.contains("Repo: gitdataai"));
assert!(prompt.contains("User: zhenyi"));
}
#[test]
fn test_tool_snippets() {
let prompt = SystemPromptBuilder::new()
.base_prompt("Agent prompt.")
.tool_snippet("bash", "Execute shell commands")
.tool_snippet("read", "Read file contents")
.build();
assert!(prompt.contains("## Available Tools"));
assert!(prompt.contains("`bash`: Execute shell commands"));
}
#[test]
fn test_project_context() {
let prompt = SystemPromptBuilder::new()
.base_prompt("Base.")
.project_context("AGENTS.md", "# Rules\n- Follow conventions")
.build();
assert!(prompt.contains("<project_context>"));
assert!(prompt.contains("AGENTS.md"));
assert!(prompt.contains("Follow conventions"));
}
#[test]
fn test_custom_section() {
let prompt = SystemPromptBuilder::new()
.base_prompt("Base.")
.custom_section("Memory", "Remember: user prefers Rust")
.build();
assert!(prompt.contains("## Memory"));
assert!(prompt.contains("user prefers Rust"));
}
}

240
lib/ai/agent/request.rs Normal file
View File

@ -0,0 +1,240 @@
use std::time::Duration;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio_util::sync::CancellationToken;
use super::persistence::AgentRunContext;
use crate::error::{AiError, AiResult};
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentRequest {
pub input: String,
pub messages: Vec<AgentMessage>,
pub context: Vec<AgentContextChunk>,
pub experts: Vec<AgentExpert>,
pub run_context: Option<AgentRunContext>,
#[serde(skip)]
pub prefill_messages: Vec<rig::completion::Message>,
#[serde(skip)]
pub cancellation_token: Option<CancellationToken>,
#[serde(skip)]
pub timeout: Option<Duration>,
}
impl AgentRequest {
pub fn new(input: impl Into<String>) -> Self {
Self {
input: input.into(),
messages: Vec::new(),
context: Vec::new(),
experts: Vec::new(),
run_context: None,
prefill_messages: Vec::new(),
cancellation_token: None,
timeout: None,
}
}
pub fn validate(&self) -> AiResult<()> {
if self.input.trim().is_empty() {
return Err(AiError::Config("agent request input is required".to_string()));
}
if self.input.len() > 1_000_000 {
return Err(AiError::Config(
"agent request input exceeds maximum length (1MB)".to_string(),
));
}
if self.experts.len() > 32 {
return Err(AiError::Config(
"agent request experts count exceeds maximum (32)".to_string(),
));
}
Ok(())
}
pub fn with_messages(mut self, messages: Vec<AgentMessage>) -> Self {
self.messages = messages;
self
}
pub fn with_context(mut self, context: Vec<AgentContextChunk>) -> Self {
self.context = context;
self
}
pub fn add_context(mut self, chunk: AgentContextChunk) -> Self {
self.context.push(chunk);
self
}
pub fn with_experts(mut self, experts: Vec<AgentExpert>) -> Self {
self.experts = experts;
self
}
pub fn add_expert(mut self, expert: AgentExpert) -> Self {
self.experts.push(expert);
self
}
pub fn with_run_context(mut self, run_context: AgentRunContext) -> Self {
self.run_context = Some(run_context);
self
}
pub fn with_prefill_messages(mut self, prefill_messages: Vec<rig::completion::Message>) -> Self {
self.prefill_messages = prefill_messages;
self
}
pub fn with_cancellation_token(mut self, cancellation_token: CancellationToken) -> Self {
self.cancellation_token = Some(cancellation_token);
self
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = Some(timeout);
self
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum AgentMessage {
User(String),
Assistant(String),
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentExpert {
pub id: String,
pub role: String,
pub task: String,
pub system_prompt: Option<String>,
pub context: Vec<AgentContextChunk>,
/// Override the master agent's temperature for this subagent.
pub temperature: Option<f64>,
/// Override the master agent's max_completion_tokens for this subagent.
pub max_completion_tokens: Option<u64>,
}
impl AgentExpert {
pub fn new(id: impl Into<String>, role: impl Into<String>, task: impl Into<String>) -> Self {
Self {
id: id.into(),
role: role.into(),
task: task.into(),
system_prompt: None,
context: Vec::new(),
temperature: None,
max_completion_tokens: None,
}
}
pub fn with_system_prompt(mut self, system_prompt: impl Into<String>) -> Self {
self.system_prompt = Some(system_prompt.into());
self
}
pub fn with_context(mut self, context: Vec<AgentContextChunk>) -> Self {
self.context = context;
self
}
pub fn with_temperature(mut self, temperature: f64) -> Self {
self.temperature = Some(temperature);
self
}
pub fn with_max_completion_tokens(mut self, max_tokens: u64) -> Self {
self.max_completion_tokens = Some(max_tokens);
self
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentContextChunk {
pub id: String,
pub content: String,
pub source: Option<String>,
pub score: Option<f32>,
pub metadata: Value,
}
impl AgentContextChunk {
pub fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
Self {
id: id.into(),
content: content.into(),
source: None,
score: None,
metadata: Value::Null,
}
}
}
impl From<crate::rag::RagSearchHit> for AgentContextChunk {
fn from(hit: crate::rag::RagSearchHit) -> Self {
Self {
id: hit.id,
content: hit.content,
source: Some(hit.session_id),
score: Some(hit.score),
metadata: Value::Object(hit.metadata.into_iter().collect()),
}
}
}
impl From<&AgentExpertOutput> for AgentContextChunk {
fn from(output: &AgentExpertOutput) -> Self {
Self {
id: format!("subagent:{}", output.id),
content: output.output.clone(),
source: Some(output.role.clone()),
score: None,
metadata: serde_json::json!({
"kind": "subagent",
"task": output.task,
}),
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentResult {
pub output: String,
pub steps: Vec<AgentStep>,
pub expert_outputs: Vec<AgentExpertOutput>,
pub input_tokens: i64,
pub output_tokens: i64,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentStep {
pub index: usize,
pub assistant: Option<String>,
pub reasoning_content: Option<String>,
pub tool_calls: Vec<ToolCallRecord>,
pub reflection: Option<String>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolCallRecord {
pub id: String,
pub name: String,
pub arguments: Value,
pub output: Option<Value>,
pub error: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub elapsed_ms: Option<i64>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentExpertOutput {
pub id: String,
pub role: String,
pub task: String,
pub output: String,
pub input_tokens: i64,
pub output_tokens: i64,
}

535
lib/ai/agent/session.rs Normal file
View File

@ -0,0 +1,535 @@
use std::time::SystemTime;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use uuid::Uuid;
use crate::error::{AiError, AiResult};
/// Current session file format version.
pub const SESSION_VERSION: u32 = 2;
/// Session metadata header.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionHeader {
pub version: u32,
pub id: Uuid,
pub created_at: String,
pub parent_session: Option<Uuid>,
pub name: Option<String>,
}
impl SessionHeader {
pub fn new() -> Self {
Self {
version: SESSION_VERSION,
id: Uuid::new_v4(),
created_at: iso_now(),
parent_session: None,
name: None,
}
}
pub fn with_parent(mut self, parent: Uuid) -> Self {
self.parent_session = Some(parent);
self
}
pub fn with_name(mut self, name: impl Into<String>) -> Self {
self.name = Some(name.into());
self
}
}
/// Typed session entry — each entry in a session transcript is one of these variants.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SessionEntry {
/// A user or assistant message.
Message {
id: Uuid,
parent_id: Option<Uuid>,
timestamp: String,
role: SessionMessageRole,
content: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<SessionToolCall>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
tool_result: Option<SessionToolResult>,
},
/// A context compaction event (older messages summarized).
Compaction {
id: Uuid,
parent_id: Option<Uuid>,
timestamp: String,
summary: String,
first_kept_entry_id: Uuid,
messages_compacted: usize,
tokens_saved: i64,
#[serde(default, skip_serializing_if = "Option::is_none")]
details: Option<Value>,
},
/// A branch summary (created when forking from a different point in the tree).
BranchSummary {
id: Uuid,
parent_id: Option<Uuid>,
timestamp: String,
from_entry_id: Uuid,
summary: String,
entries_summarized: usize,
#[serde(default, skip_serializing_if = "Option::is_none")]
label: Option<String>,
},
/// Model change during a session.
ModelChange {
id: Uuid,
parent_id: Option<Uuid>,
timestamp: String,
provider: String,
model_id: String,
},
/// Thinking level change during a session.
ThinkingLevelChange {
id: Uuid,
parent_id: Option<Uuid>,
timestamp: String,
level: String,
},
/// Custom extension data (not sent to LLM).
Custom {
id: Uuid,
parent_id: Option<Uuid>,
timestamp: String,
custom_type: String,
data: Option<Value>,
},
}
impl SessionEntry {
pub fn id(&self) -> Uuid {
match self {
Self::Message { id, .. }
| Self::Compaction { id, .. }
| Self::BranchSummary { id, .. }
| Self::ModelChange { id, .. }
| Self::ThinkingLevelChange { id, .. }
| Self::Custom { id, .. } => *id,
}
}
pub fn parent_id(&self) -> Option<Uuid> {
match self {
Self::Message { parent_id, .. }
| Self::Compaction { parent_id, .. }
| Self::BranchSummary { parent_id, .. }
| Self::ModelChange { parent_id, .. }
| Self::ThinkingLevelChange { parent_id, .. }
| Self::Custom { parent_id, .. } => *parent_id,
}
}
pub fn timestamp(&self) -> &str {
match self {
Self::Message { timestamp, .. }
| Self::Compaction { timestamp, .. }
| Self::BranchSummary { timestamp, .. }
| Self::ModelChange { timestamp, .. }
| Self::ThinkingLevelChange { timestamp, .. }
| Self::Custom { timestamp, .. } => timestamp,
}
}
/// Create a user message entry.
pub fn user_message(parent_id: Option<Uuid>, content: impl Into<String>) -> Self {
Self::Message {
id: Uuid::new_v4(),
parent_id,
timestamp: iso_now(),
role: SessionMessageRole::User,
content: content.into(),
tool_calls: None,
tool_result: None,
}
}
/// Create an assistant message entry.
pub fn assistant_message(
parent_id: Option<Uuid>,
content: impl Into<String>,
tool_calls: Option<Vec<SessionToolCall>>,
) -> Self {
Self::Message {
id: Uuid::new_v4(),
parent_id,
timestamp: iso_now(),
role: SessionMessageRole::Assistant,
content: content.into(),
tool_calls,
tool_result: None,
}
}
/// Create a compaction entry.
pub fn compaction(
parent_id: Option<Uuid>,
summary: impl Into<String>,
first_kept_entry_id: Uuid,
messages_compacted: usize,
tokens_saved: i64,
) -> Self {
Self::Compaction {
id: Uuid::new_v4(),
parent_id,
timestamp: iso_now(),
summary: summary.into(),
first_kept_entry_id,
messages_compacted,
tokens_saved,
details: None,
}
}
/// Create a branch summary entry.
pub fn branch_summary(
parent_id: Option<Uuid>,
from_entry_id: Uuid,
summary: impl Into<String>,
entries_summarized: usize,
label: Option<String>,
) -> Self {
Self::BranchSummary {
id: Uuid::new_v4(),
parent_id,
timestamp: iso_now(),
from_entry_id,
summary: summary.into(),
entries_summarized,
label,
}
}
/// Create a model change entry.
pub fn model_change(
parent_id: Option<Uuid>,
provider: impl Into<String>,
model_id: impl Into<String>,
) -> Self {
Self::ModelChange {
id: Uuid::new_v4(),
parent_id,
timestamp: iso_now(),
provider: provider.into(),
model_id: model_id.into(),
}
}
/// Create a custom extension entry.
pub fn custom(
parent_id: Option<Uuid>,
custom_type: impl Into<String>,
data: Option<Value>,
) -> Self {
Self::Custom {
id: Uuid::new_v4(),
parent_id,
timestamp: iso_now(),
custom_type: custom_type.into(),
data,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SessionMessageRole {
User,
Assistant,
ToolResult,
System,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionToolCall {
pub id: String,
pub name: String,
pub arguments: Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionToolResult {
pub tool_call_id: String,
pub tool_name: String,
pub content: String,
pub is_error: bool,
}
/// A full session: header + ordered list of entries forming a tree.
///
/// The tree structure supports forking: entries share `parent_id` links,
/// and the "active branch" is determined by following from the leaf
/// back to the root.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
pub header: SessionHeader,
pub entries: Vec<SessionEntry>,
}
impl Session {
pub fn new() -> Self {
Self {
header: SessionHeader::new(),
entries: Vec::new(),
}
}
pub fn with_name(mut self, name: impl Into<String>) -> Self {
self.header = self.header.with_name(name);
self
}
/// Append an entry to the session.
pub fn push(&mut self, entry: SessionEntry) {
self.entries.push(entry);
}
/// Get the last entry's id (used as parent_id for the next entry).
pub fn last_entry_id(&self) -> Option<Uuid> {
self.entries.last().map(|e| e.id())
}
/// Get all entries on the active branch (from root to leaf).
pub fn active_branch(&self) -> Vec<&SessionEntry> {
if self.entries.is_empty() {
return Vec::new();
}
let mut branch = Vec::new();
let mut current_id = Some(self.entries.last().unwrap().id());
while let Some(id) = current_id {
if let Some(entry) = self.entries.iter().find(|e| e.id() == id) {
branch.push(entry);
current_id = entry.parent_id();
} else {
break;
}
}
branch.reverse();
branch
}
/// Get all message entries on the active branch (for LLM context).
pub fn active_messages(&self) -> Vec<&SessionEntry> {
self.active_branch()
.into_iter()
.filter(|e| matches!(e, SessionEntry::Message { .. } | SessionEntry::Compaction { .. }))
.collect()
}
/// Find all children of a given entry (for tree navigation).
pub fn children_of(&self, parent_id: Uuid) -> Vec<&SessionEntry> {
self.entries
.iter()
.filter(|e| e.parent_id() == Some(parent_id))
.collect()
}
/// Get all leaf entries (entries with no children).
pub fn leaves(&self) -> Vec<&SessionEntry> {
let parent_ids: std::collections::HashSet<Uuid> = self
.entries
.iter()
.filter_map(|e| e.parent_id())
.collect();
self.entries
.iter()
.filter(|e| !parent_ids.contains(&e.id()))
.collect()
}
/// Count total entries.
pub fn entry_count(&self) -> usize {
self.entries.len()
}
/// Fork from a specific entry, creating entries that belong to a new branch.
/// Returns the entries that should be in the new branch (from root to fork point).
pub fn fork_from(&self, fork_entry_id: Uuid) -> AiResult<Session> {
let fork_idx = self
.entries
.iter()
.position(|e| e.id() == fork_entry_id)
.ok_or_else(|| {
AiError::Config(format!("fork entry {fork_entry_id} not found in session"))
})?;
let mut new_session = Session::new();
new_session.header = new_session.header.with_parent(self.header.id);
// Copy entries up to and including the fork point
for entry in &self.entries[..=fork_idx] {
new_session.entries.push(entry.clone());
}
Ok(new_session)
}
/// Find the common ancestor of two entries.
pub fn common_ancestor(&self, id_a: Uuid, id_b: Uuid) -> Option<Uuid> {
let ancestors_a = self.ancestor_chain(id_a);
let ancestors_b: std::collections::HashSet<Uuid> =
self.ancestor_chain(id_b).into_iter().collect();
for ancestor in ancestors_a {
if ancestors_b.contains(&ancestor) {
return Some(ancestor);
}
}
None
}
/// Get the chain of ancestor IDs from an entry back to the root.
fn ancestor_chain(&self, entry_id: Uuid) -> Vec<Uuid> {
let mut chain = Vec::new();
let mut current_id = Some(entry_id);
while let Some(id) = current_id {
chain.push(id);
current_id = self
.entries
.iter()
.find(|e| e.id() == id)
.and_then(|e| e.parent_id());
}
chain
}
}
/// Options for session compaction.
#[derive(Debug, Clone)]
pub struct CompactionOptions {
/// Custom instructions for the compaction LLM call.
pub custom_instructions: Option<String>,
/// Reserve this many tokens for the prompt + LLM response.
pub reserve_tokens: i64,
/// Keep this many recent message pairs untouched.
pub keep_recent_pairs: usize,
/// Whether to generate branch summaries for forked branches.
pub branch_summarization: bool,
}
impl Default for CompactionOptions {
fn default() -> Self {
Self {
custom_instructions: None,
reserve_tokens: 16_384,
keep_recent_pairs: 4,
branch_summarization: true,
}
}
}
fn iso_now() -> String {
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map(|d| {
let secs = d.as_secs();
// Simple ISO 8601 format (UTC)
let days = secs / 86400;
let years = (days * 400) / 146097;
let remaining_days = days - (years * 365 + years / 4 - years / 100 + years / 400);
let month_days = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
let is_leap = (years % 4 == 0 && years % 100 != 0) || years % 400 == 0;
let mut month = 0usize;
let mut day_acc = remaining_days as i64;
for (i, &md) in month_days.iter().enumerate() {
let md = if i == 1 && is_leap { md + 1 } else { md };
if day_acc < md as i64 {
month = i;
break;
}
day_acc -= md as i64;
}
let day = day_acc + 1;
let hour = (secs % 86400) / 3600;
let minute = (secs % 3600) / 60;
let second = secs % 60;
format!(
"{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
1970 + years,
month + 1,
day,
hour,
minute,
second,
)
})
.unwrap_or_else(|_| "1970-01-01T00:00:00Z".to_string())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_session_basic() {
let mut session = Session::new();
let msg1 = SessionEntry::user_message(None, "Hello");
let msg1_id = msg1.id();
session.push(msg1);
let msg2 = SessionEntry::assistant_message(Some(msg1_id), "Hi there!", None);
session.push(msg2);
assert_eq!(session.entry_count(), 2);
assert_eq!(session.active_branch().len(), 2);
}
#[test]
fn test_session_fork() {
let mut session = Session::new();
let msg1 = SessionEntry::user_message(None, "First");
let msg1_id = msg1.id();
session.push(msg1);
let msg2 = SessionEntry::assistant_message(Some(msg1_id), "Reply 1", None);
let msg2_id = msg2.id();
session.push(msg2);
let msg3 = SessionEntry::user_message(Some(msg2_id), "Second");
session.push(msg3);
// Fork from msg2
let forked = session.fork_from(msg2_id).unwrap();
assert_eq!(forked.entry_count(), 2);
assert_eq!(forked.header.parent_session, Some(session.header.id));
}
#[test]
fn test_session_leaves() {
let mut session = Session::new();
let msg1 = SessionEntry::user_message(None, "Root");
let msg1_id = msg1.id();
session.push(msg1);
// Two children branching from root
let msg2a = SessionEntry::assistant_message(Some(msg1_id), "Branch A", None);
let msg2b = SessionEntry::assistant_message(Some(msg1_id), "Branch B", None);
session.push(msg2a);
session.push(msg2b);
let leaves = session.leaves();
assert_eq!(leaves.len(), 2);
}
}

203
lib/ai/agent/subagent.rs Normal file
View File

@ -0,0 +1,203 @@
use rig::agent::AgentBuilder;
use rig::client::CompletionClient;
use rig::completion::Prompt;
use tracing::{debug, warn};
use super::config::AgentConfig;
use super::helpers::with_retry;
use super::persistence::{
ActiveAgentRun, AgentRealtime, AgentRuntime, AgentStreamEvent,
estimate_output_tokens,
};
use super::request::{AgentExpert, AgentExpertOutput};
use crate::client::AiClient;
use crate::error::{AiError, AiResult};
pub async fn run_experts(
client: &AiClient,
config: &AgentConfig,
experts: &[AgentExpert],
realtime: Option<&AgentRealtime>,
run: &ActiveAgentRun,
) -> AiResult<Vec<AgentExpertOutput>> {
let mut outputs = Vec::with_capacity(experts.len());
let mut failed_count = 0;
for expert in experts {
match run_single(client, config, expert, realtime, run).await {
Ok(output) => {
debug!(subagent_id = %output.id, role = %output.role, "subagent completed");
outputs.push(output);
}
Err(error) => {
warn!(subagent_id = %expert.id, role = %expert.role, error = %error, "subagent failed");
let _ = publish_subagent_failed(realtime, run, expert, &error.to_string()).await;
failed_count += 1;
}
}
}
debug!(total = experts.len(), ok = outputs.len(), failed = failed_count, "experts done");
Ok(outputs)
}
async fn run_single(
client: &AiClient,
config: &AgentConfig,
expert: &AgentExpert,
realtime: Option<&AgentRealtime>,
run: &ActiveAgentRun,
) -> AiResult<AgentExpertOutput> {
publish_subagent_started(realtime, run, config, expert).await?;
let rig_client = client.llm_client().clone();
let model_name = config.model.clone();
let temperature = expert.temperature.or(config.temperature);
let max_completion_tokens = expert.max_completion_tokens.or(config.max_completion_tokens);
let retry_attempts = config.retry_max_attempts;
let retry_delay_ms = config.retry_base_delay_ms;
let prompt = expert.system_prompt.clone().unwrap_or_else(|| {
format!(
"You are a specialist subagent. Role: {}. Produce a concise expert answer for the parent chat agent.",
expert.role
)
});
let task = build_expert_task(expert);
let (output, input_tokens_usage, output_tokens_usage) = with_retry(
retry_attempts,
retry_delay_ms,
|| {
let rig_client = rig_client.clone();
let model_name = model_name.clone();
let prompt = prompt.clone();
let task = task.clone();
async move {
let model = rig_client.completion_model(&model_name);
let mut builder = AgentBuilder::new(model).preamble(&prompt);
if let Some(temp) = temperature {
builder = builder.temperature(temp);
}
if let Some(mt) = max_completion_tokens {
builder = builder.max_tokens(mt);
}
let agent = builder.build();
let response = agent
.prompt(&task)
.extended_details()
.await
.map_err(|e: rig::completion::PromptError| {
AiError::Api(e.to_string())
})?;
Ok((
response.output,
response.usage.input_tokens,
response.usage.output_tokens,
))
}
},
)
.await?;
let input_tokens = input_tokens_usage as i64;
let output_tokens = if output_tokens_usage > 0 {
output_tokens_usage as i64
} else {
estimate_output_tokens(&output)
};
let result = AgentExpertOutput {
id: expert.id.clone(),
role: expert.role.clone(),
task: expert.task.clone(),
output,
input_tokens,
output_tokens,
};
publish_subagent_completed(realtime, run, config, &result).await?;
Ok(result)
}
fn build_expert_task(expert: &AgentExpert) -> String {
let mut task = String::new();
if !expert.context.is_empty() {
task.push_str("Retrieved context for this specialist task:\n");
for (index, chunk) in expert.context.iter().enumerate() {
task.push_str(&format!(
"\n[{}] id={} source={}\n{}\n",
index + 1,
chunk.id,
chunk.source.as_deref().unwrap_or("unknown"),
chunk.content
));
}
task.push('\n');
}
task.push_str(&expert.task);
task
}
async fn publish_subagent_started(
realtime: Option<&AgentRealtime>,
run: &ActiveAgentRun,
config: &AgentConfig,
expert: &AgentExpert,
) -> AiResult<()> {
AgentRuntime::default().publish(
realtime,
&AgentStreamEvent::SubagentStarted {
conversation_id: run.conversation_id,
message_id: run.message_id,
subagent_id: expert.id.clone(),
role: expert.role.clone(),
task: expert.task.clone(),
model: config.model.clone(),
},
).await
}
async fn publish_subagent_completed(
realtime: Option<&AgentRealtime>,
run: &ActiveAgentRun,
config: &AgentConfig,
output: &AgentExpertOutput,
) -> AiResult<()> {
AgentRuntime::default().publish(
realtime,
&AgentStreamEvent::SubagentCompleted {
conversation_id: run.conversation_id,
message_id: run.message_id,
subagent_id: output.id.clone(),
role: output.role.clone(),
task: output.task.clone(),
output: output.output.clone(),
input_tokens: output.input_tokens,
output_tokens: output.output_tokens,
model: config.model.clone(),
},
).await
}
async fn publish_subagent_failed(
realtime: Option<&AgentRealtime>,
run: &ActiveAgentRun,
expert: &AgentExpert,
error: &str,
) -> AiResult<()> {
AgentRuntime::default().publish(
realtime,
&AgentStreamEvent::SubagentFailed {
conversation_id: run.conversation_id,
message_id: run.message_id,
subagent_id: expert.id.clone(),
error: error.to_string(),
},
).await
}

158
lib/ai/agent/tool.rs Normal file
View File

@ -0,0 +1,158 @@
use std::pin::Pin;
use std::sync::Arc;
use rig::completion::ToolDefinition as RigToolDefinition;
use rig::tool::ToolDyn;
use serde_json::Value;
use tokio::sync::Mutex;
use crate::tool::tools::FunctionCall;
pub struct RigTool<C>
where
C: Clone + Send + Sync + 'static,
{
context: Arc<Mutex<C>>,
tool: Arc<dyn FunctionCall<Context = C>>,
name: String,
description: String,
schema: Value,
}
impl<C> RigTool<C>
where
C: Clone + Send + Sync + 'static,
{
pub fn new(tool: Arc<dyn FunctionCall<Context = C>>, context: Arc<Mutex<C>>) -> Self {
let name = tool.name().to_string();
let description = tool.description().to_string();
let schema = tool.schema();
Self {
context,
tool,
name,
description,
schema,
}
}
}
impl<C> ToolDyn for RigTool<C>
where
C: Clone + Send + Sync + 'static,
{
fn name(&self) -> String {
self.name.clone()
}
fn definition<'a>(
&'a self,
_prompt: String,
) -> Pin<Box<dyn std::future::Future<Output = RigToolDefinition> + Send + 'a>> {
let name = self.name.clone();
let description = self.description.clone();
let params = self.schema.clone();
Box::pin(async move {
RigToolDefinition {
name,
description,
parameters: params,
}
})
}
fn call<'a>(
&'a self,
args: String,
) -> Pin<
Box<dyn std::future::Future<Output = Result<String, rig::tool::ToolError>> + Send + 'a>,
> {
let tool = self.tool.clone();
let context = self.context.clone();
Box::pin(async move {
let args_value: Value =
serde_json::from_str(&args).map_err(rig::tool::ToolError::JsonError)?;
let mut ctx = context.lock().await;
match tool.call(&mut *ctx, args_value).await {
Ok(value) => serde_json::to_string(&value)
.map_err(rig::tool::ToolError::JsonError),
Err(ai_err) => Err(rig::tool::ToolError::ToolCallError(Box::new(
std::io::Error::other(ai_err.to_string()),
))),
}
})
}
}
pub struct RigToolSet<C>
where
C: Clone + Send + Sync + 'static,
{
tools: Vec<Box<dyn ToolDyn + 'static>>,
context: Option<Arc<Mutex<C>>>,
}
impl<C> RigToolSet<C>
where
C: Clone + Send + Sync + 'static,
{
pub fn new() -> Self {
Self {
tools: Vec::new(),
context: None,
}
}
pub fn from_register(
register: &crate::tool::register::ToolRegister<C>,
context: Arc<Mutex<C>>,
) -> Self {
let mut tools: Vec<Box<dyn ToolDyn + 'static>> = Vec::with_capacity(register.len());
for tool_arc in &register.tools {
tools.push(Box::new(RigTool::new(tool_arc.clone(), context.clone())));
}
Self {
tools,
context: Some(context),
}
}
pub fn is_empty(&self) -> bool {
self.tools.is_empty()
}
pub fn len(&self) -> usize {
self.tools.len()
}
pub fn context(&self) -> Option<&Arc<Mutex<C>>> {
self.context.as_ref()
}
pub fn take_tools(&mut self) -> Vec<Box<dyn ToolDyn + 'static>> {
std::mem::take(&mut self.tools)
}
pub fn into_context(mut self) -> C {
self.context
.take()
.and_then(|arc| Arc::try_unwrap(arc).ok().map(|m| m.into_inner()))
.unwrap_or_else(|| unreachable!("context must be available"))
}
}
impl<C> Default for RigToolSet<C>
where
C: Clone + Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}

219
lib/ai/client.rs Normal file
View File

@ -0,0 +1,219 @@
use std::fmt;
use std::sync::Arc;
use config::AppConfig;
use rig::providers::openai;
use crate::error::{AiError, AiResult};
fn validate_required(scope: &str, field: &str, value: &str) -> AiResult<()> {
if value.trim().is_empty() {
return Err(AiError::Config(format!("{scope} {field} is required")));
}
Ok(())
}
fn config_error(error: impl fmt::Display) -> AiError {
AiError::Config(error.to_string())
}
#[derive(Clone)]
pub struct EndpointConfig {
pub base_url: String,
pub api_key: String,
}
impl EndpointConfig {
pub fn new(base_url: impl Into<String>, api_key: impl Into<String>) -> AiResult<Self> {
let config = Self {
base_url: base_url.into(),
api_key: api_key.into(),
};
config.validate("endpoint")?;
Ok(config)
}
fn validate(&self, scope: &str) -> AiResult<()> {
validate_required(scope, "base_url", &self.base_url)?;
validate_required(scope, "api_key", &self.api_key)?;
if !self.base_url.trim().starts_with("http://")
&& !self.base_url.trim().starts_with("https://")
{
return Err(AiError::Config(format!(
"{scope} base_url must start with http:// or https://"
)));
}
Ok(())
}
pub fn build_client(&self) -> AiResult<openai::Client> {
openai::Client::builder()
.api_key(&self.api_key)
.base_url(self.base_url.trim())
.build()
.map_err(|e| AiError::Config(format!("failed to build rig OpenAI client: {e}")))
}
}
impl fmt::Debug for EndpointConfig {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter
.debug_struct("EndpointConfig")
.field("base_url", &self.base_url)
.field("api_key", &"<redacted>")
.finish()
}
}
#[derive(Clone, Debug)]
pub struct EmbedConfig {
pub endpoint: EndpointConfig,
pub model: String,
pub dimensions: u64,
}
impl EmbedConfig {
pub fn new(
endpoint: EndpointConfig,
model: impl Into<String>,
dimensions: u64,
) -> AiResult<Self> {
let config = Self {
endpoint,
model: model.into(),
dimensions,
};
config.validate()?;
Ok(config)
}
fn validate(&self) -> AiResult<()> {
self.endpoint.validate("embed endpoint")?;
validate_required("embed", "model", &self.model)?;
if self.dimensions == 0 {
return Err(AiError::Config(
"embed dimensions must be greater than 0".to_string(),
));
}
Ok(())
}
}
#[derive(Clone, Debug)]
pub struct AiClientConfig {
pub llm: EndpointConfig,
pub embed: EmbedConfig,
}
impl AiClientConfig {
pub fn new(llm: EndpointConfig, embed: EmbedConfig) -> AiResult<Self> {
let config = Self { llm, embed };
config.validate()?;
Ok(config)
}
pub fn validate(&self) -> AiResult<()> {
self.llm.validate("llm endpoint")?;
self.embed.validate()?;
Ok(())
}
}
impl TryFrom<&AppConfig> for AiClientConfig {
type Error = AiError;
fn try_from(config: &AppConfig) -> Result<Self, Self::Error> {
let llm = EndpointConfig::new(
config.ai_basic_url().map_err(config_error)?,
config.ai_api_key().map_err(config_error)?,
)?;
let embed_endpoint = EndpointConfig::new(
config.get_embed_model_base_url().map_err(config_error)?,
config.get_embed_model_api_key().map_err(config_error)?,
)?;
let embed = EmbedConfig::new(
embed_endpoint,
config.get_embed_model_name().map_err(config_error)?,
config.get_embed_model_dimensions().map_err(config_error)?,
)?;
Self::new(llm, embed)
}
}
#[derive(Clone, Debug)]
pub struct AiClient {
pub(super) llm_client: openai::Client,
pub(super) embed_client: openai::Client,
pub(super) config: Arc<AiClientConfig>,
}
impl AiClient {
pub fn new(config: AiClientConfig) -> AiResult<Self> {
config.validate()?;
Ok(Self {
llm_client: config.llm.build_client()?,
embed_client: config.embed.endpoint.build_client()?,
config: Arc::new(config),
})
}
pub fn from_app_config(config: &AppConfig) -> AiResult<Self> {
Self::new(AiClientConfig::try_from(config)?)
}
pub fn llm_client(&self) -> &openai::Client {
&self.llm_client
}
pub fn embed_client(&self) -> &openai::Client {
&self.embed_client
}
pub fn config(&self) -> &AiClientConfig {
&self.config
}
pub fn llm_config(&self) -> &EndpointConfig {
&self.config.llm
}
pub fn embed_config(&self) -> &EmbedConfig {
&self.config.embed
}
pub fn embed_model(&self) -> &str {
self.config.embed.model.as_str()
}
pub fn embed_dimensions(&self) -> u64 {
self.config.embed.dimensions
}
pub fn embed_dimensions_u32(&self) -> u32 {
u32::try_from(self.config.embed.dimensions).unwrap_or(u32::MAX)
}
}
pub fn build_http_client() -> Result<reqwest::Client, AiError> {
let mut builder = reqwest::Client::builder();
if let Ok(proxy_url) = std::env::var("HTTPS_PROXY")
.or_else(|_| std::env::var("https_proxy"))
.or_else(|_| std::env::var("HTTP_PROXY"))
.or_else(|_| std::env::var("http_proxy"))
{
let proxy_url = proxy_url.trim().trim_matches('"').trim_matches('\'');
let proxy = reqwest::Proxy::all(proxy_url).map_err(|e| {
AiError::Config(format!("Invalid proxy URL '{}': {}", proxy_url, e))
})?;
builder = builder.proxy(proxy);
}
builder.build().map_err(|e| {
AiError::Config(format!("Failed to build HTTP client: {}", e))
})
}

76
lib/ai/embed/client.rs Normal file
View File

@ -0,0 +1,76 @@
use rig::client::EmbeddingsClient;
use rig::embeddings::EmbeddingModel;
use crate::{client::AiClient, error::{AiError, AiResult}};
#[derive(Clone)]
pub struct EmbedClient {
model_name: String,
client: rig::providers::openai::Client,
}
impl EmbedClient {
pub fn new(ai_client: &AiClient) -> AiResult<Self> {
Ok(Self {
model_name: ai_client.embed_model().to_string(),
client: ai_client.embed_client().clone(),
})
}
fn embedding_model(&self) -> impl EmbeddingModel + '_ {
self.client.embedding_model(&self.model_name)
}
pub async fn embed_text(&self, text: String) -> AiResult<Vec<f32>> {
let model = self.embedding_model();
let mut embeddings = model.embed_texts(vec![text])
.await
.map_err(|e| AiError::Api(e.to_string()))?;
embeddings.pop()
.map(|e| e.vec.into_iter().map(|v| v as f32).collect())
.ok_or_else(|| AiError::Response("no embedding returned".to_string()))
}
pub async fn embed_texts(&self, texts: Vec<String>) -> AiResult<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
let model = self.embedding_model();
let embeddings = model.embed_texts(texts)
.await
.map_err(|e| AiError::Api(e.to_string()))?;
Ok(embeddings.into_iter()
.map(|e| e.vec.into_iter().map(|v| v as f32).collect())
.collect())
}
pub async fn embed_texts_chunked(
&self,
texts: Vec<String>,
batch_size: usize,
) -> AiResult<Vec<Vec<f32>>> {
if batch_size == 0 {
return Err(AiError::Config("batch_size must be > 0".to_string()));
}
let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
for chunk in texts.chunks(batch_size) {
let model = self.embedding_model();
let chunk_embeddings = model.embed_texts(chunk.to_vec())
.await
.map_err(|e| AiError::Api(e.to_string()))?;
embeddings.extend(chunk_embeddings.into_iter()
.map(|e| e.vec.into_iter().map(|v| v as f32).collect()));
}
Ok(embeddings)
}
}
pub trait AiClientEmbedExt {
fn embedder(&self) -> AiResult<EmbedClient>;
}
impl AiClientEmbedExt for AiClient {
fn embedder(&self) -> AiResult<EmbedClient> {
EmbedClient::new(self)
}
}

3
lib/ai/embed/mod.rs Normal file
View File

@ -0,0 +1,3 @@
mod client;
pub use client::{AiClientEmbedExt, EmbedClient};

52
lib/ai/error.rs Normal file
View File

@ -0,0 +1,52 @@
pub type AiResult<T> = Result<T, AiError>;
#[derive(Debug, thiserror::Error)]
pub enum AiError {
#[error("ai config error: {0}")]
Config(String),
#[error("ai api error: {0}")]
Api(String),
#[error("qdrant error: {0}")]
Qdrant(Box<qdrant_client::QdrantError>),
#[error("database error: {0}")]
Database(#[from] db::sqlx::Error),
#[error("cache error: {0}")]
Cache(#[from] cache::CacheError),
#[error("redis error: {0}")]
Redis(#[from] redis::RedisError),
#[error("ai response error: {0}")]
Response(String),
#[error("model retries exhausted after {attempts} attempts: {last_error}")]
ModelRetriesExhausted {
attempts: usize,
last_error: String,
},
#[error("agent timeout after {seconds}s")]
Timeout { seconds: u64 },
#[error("tool not found: {tool}")]
ToolNotFound { tool: String },
#[error("tool execution failed: {cause}")]
ToolExecutionFailed { cause: String },
#[error("invalid input in '{field}': {reason}")]
InvalidInput { field: String, reason: String },
#[error("token budget exceeded: used ~{estimated} tokens, limit {limit}")]
TokenBudgetExceeded { estimated: u64, limit: i64 },
}
impl From<qdrant_client::QdrantError> for AiError {
fn from(e: qdrant_client::QdrantError) -> Self {
AiError::Qdrant(Box::new(e))
}
}

8
lib/ai/lib.rs Normal file
View File

@ -0,0 +1,8 @@
pub mod agent;
pub mod client;
pub mod embed;
pub mod error;
pub mod memory;
pub mod rag;
pub mod sync;
pub mod tool;

46
lib/ai/memory/mod.rs Normal file
View File

@ -0,0 +1,46 @@
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::error::AiResult;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
pub key: String,
pub value: String,
pub importance: i32,
pub last_used_at: Option<String>,
}
#[async_trait]
pub trait MemoryProvider: Send + Sync {
fn name(&self) -> &'static str;
async fn save(
&self,
session_id: Uuid,
key: &str,
value: &str,
importance: i32,
) -> AiResult<()>;
async fn recall(
&self,
session_id: Uuid,
query: &str,
limit: usize,
) -> AiResult<Vec<MemoryEntry>>;
async fn forget(&self, session_id: Uuid, key: &str) -> AiResult<()>;
async fn prefetch(
&self,
_session_id: Uuid,
_query: &str,
) -> AiResult<Vec<MemoryEntry>> {
Ok(Vec::new())
}
async fn build_context_block(
&self,
_session_id: Uuid,
) -> AiResult<String> {
Ok(String::new())
}
async fn setup(&self) -> AiResult<()> {
Ok(())
}
}

263
lib/ai/rag/client.rs Normal file
View File

@ -0,0 +1,263 @@
use config::AppConfig;
use qdrant_client::qdrant::{
CreateCollectionBuilder, CreateFieldIndexCollectionBuilder,
DeletePointsBuilder, FieldType, PointStruct, QueryPointsBuilder,
SearchParamsBuilder, UpsertPointsBuilder, VectorParamsBuilder,
};
use qdrant_client::{Qdrant, QdrantError};
use super::{
config::RagConfig,
document::{RagDocument, RagSearchHit},
payload::{
SESSION_ID_KEY, document_payload, hit_from_scored_point, point_id,
},
search::RagSearchOptions,
session::{session_filter, validate_session_id},
};
use crate::{
client::AiClient,
embed::{AiClientEmbedExt, EmbedClient},
error::{AiError, AiResult},
};
#[derive(Clone)]
pub struct RagClient {
qdrant: Qdrant,
embedder: EmbedClient,
config: RagConfig,
}
impl RagClient {
pub fn new(
qdrant: Qdrant,
embedder: EmbedClient,
config: RagConfig,
) -> AiResult<Self> {
config.validate()?;
Ok(Self {
qdrant,
embedder,
config,
})
}
pub fn connect(
ai_client: &AiClient,
config: RagConfig,
) -> AiResult<Self> {
config.validate()?;
let mut builder =
Qdrant::from_url(config.url.trim()).timeout(config.timeout);
if let Some(api_key) = config
.api_key
.as_deref()
.filter(|api_key| !api_key.trim().is_empty())
{
builder = builder.api_key(api_key);
}
Self::new(builder.build()?, ai_client.embedder()?, config)
}
pub fn from_app_config(
ai_client: &AiClient,
config: &AppConfig,
collection_name: impl Into<String>,
) -> AiResult<Self> {
Self::connect(
ai_client,
RagConfig::from_app_config(config, collection_name)?,
)
}
pub fn qdrant(&self) -> &Qdrant {
&self.qdrant
}
pub fn embedder(&self) -> &EmbedClient {
&self.embedder
}
pub fn config(&self) -> &RagConfig {
&self.config
}
pub async fn ensure_collection(&self) -> AiResult<()> {
if !self
.qdrant
.collection_exists(&self.config.collection_name)
.await?
{
self.qdrant
.create_collection(
CreateCollectionBuilder::new(&self.config.collection_name)
.vectors_config(VectorParamsBuilder::new(
self.config.vector_size,
self.config.distance,
)),
)
.await?;
}
match self
.qdrant
.create_field_index(CreateFieldIndexCollectionBuilder::new(
&self.config.collection_name,
SESSION_ID_KEY,
FieldType::Keyword,
))
.await
{
Ok(_) => Ok(()),
Err(QdrantError::ResponseError { .. }) => Ok(()),
Err(error) => Err(error.into()),
}
}
pub async fn upsert_document(
&self,
session_id: impl AsRef<str>,
document: RagDocument,
) -> AiResult<()> {
self.upsert_documents(session_id, vec![document]).await
}
pub async fn upsert_documents(
&self,
session_id: impl AsRef<str>,
documents: Vec<RagDocument>,
) -> AiResult<()> {
let session_id = session_id.as_ref();
validate_session_id(session_id)?;
validate_documents(&documents)?;
let texts: Vec<String> = documents
.iter()
.map(|d| d.content.clone())
.collect();
let vectors = self
.embedder
.embed_texts_chunked(texts, self.config.upsert_batch_size)
.await?;
let points = documents
.iter()
.zip(vectors)
.map(|(document, vector)| {
Ok(PointStruct::new(
point_id(session_id, &document.id),
vector,
document_payload(session_id, document)?,
))
})
.collect::<AiResult<Vec<_>>>()?;
self.qdrant
.upsert_points(
UpsertPointsBuilder::new(&self.config.collection_name, points)
.wait(true),
)
.await?;
Ok(())
}
pub async fn search_session(
&self,
session_id: impl AsRef<str>,
query: impl Into<String>,
) -> AiResult<Vec<RagSearchHit>> {
let options = RagSearchOptions {
limit: self.config.default_search_limit,
exact: self.config.exact_session_search,
};
self.search_session_with_options(session_id, query, options)
.await
}
pub async fn search_session_with_options(
&self,
session_id: impl AsRef<str>,
query: impl Into<String>,
options: RagSearchOptions,
) -> AiResult<Vec<RagSearchHit>> {
let vector = self.embedder.embed_text(query.into()).await?;
self.search_session_by_vector(session_id, vector, options)
.await
}
pub async fn search_session_by_vector(
&self,
session_id: impl AsRef<str>,
vector: Vec<f32>,
options: RagSearchOptions,
) -> AiResult<Vec<RagSearchHit>> {
let session_id = session_id.as_ref();
validate_session_id(session_id)?;
if options.limit == 0 {
return Err(AiError::Config(
"rag search limit must be greater than 0".to_string(),
));
}
let response = self
.qdrant
.query(
QueryPointsBuilder::new(&self.config.collection_name)
.query(vector)
.limit(options.limit)
.filter(session_filter(session_id))
.with_payload(true)
.params(
SearchParamsBuilder::default().exact(options.exact),
),
)
.await?;
Ok(response
.result
.into_iter()
.map(hit_from_scored_point)
.collect())
}
pub async fn clear_session(
&self,
session_id: impl AsRef<str>,
) -> AiResult<()> {
let session_id = session_id.as_ref();
validate_session_id(session_id)?;
self.qdrant
.delete_points(
DeletePointsBuilder::new(&self.config.collection_name)
.points(session_filter(session_id))
.wait(true),
)
.await?;
Ok(())
}
}
fn validate_documents(documents: &[RagDocument]) -> AiResult<()> {
if documents.is_empty() {
return Err(AiError::Config("rag documents are required".to_string()));
}
for document in documents {
if document.id.trim().is_empty() {
return Err(AiError::Config(
"rag document id is required".to_string(),
));
}
if document.content.trim().is_empty() {
return Err(AiError::Config(
"rag document content is required".to_string(),
));
}
}
Ok(())
}

134
lib/ai/rag/config.rs Normal file
View File

@ -0,0 +1,134 @@
use std::time::Duration;
use config::AppConfig;
use qdrant_client::qdrant::Distance;
use crate::error::{AiError, AiResult};
#[derive(Clone, Debug)]
pub struct RagConfig {
pub url: String,
pub api_key: Option<String>,
pub collection_name: String,
pub vector_size: u64,
pub distance: Distance,
pub timeout: Duration,
pub upsert_batch_size: usize,
pub default_search_limit: u64,
pub exact_session_search: bool,
}
impl RagConfig {
pub fn new(
url: impl Into<String>,
collection_name: impl Into<String>,
vector_size: u64,
) -> AiResult<Self> {
let config = Self {
url: url.into(),
api_key: None,
collection_name: collection_name.into(),
vector_size,
distance: Distance::Cosine,
timeout: Duration::from_secs(10),
upsert_batch_size: 64,
default_search_limit: 8,
exact_session_search: true,
};
config.validate()?;
Ok(config)
}
pub fn with_api_key(mut self, api_key: Option<String>) -> Self {
self.api_key = api_key;
self
}
pub fn with_distance(mut self, distance: Distance) -> Self {
self.distance = distance;
self
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
pub fn with_upsert_batch_size(mut self, upsert_batch_size: usize) -> Self {
self.upsert_batch_size = upsert_batch_size;
self
}
pub fn with_default_search_limit(
mut self,
default_search_limit: u64,
) -> Self {
self.default_search_limit = default_search_limit;
self
}
pub fn with_exact_session_search(
mut self,
exact_session_search: bool,
) -> Self {
self.exact_session_search = exact_session_search;
self
}
pub fn validate(&self) -> AiResult<()> {
if self.url.trim().is_empty() {
return Err(AiError::Config("qdrant url is required".to_string()));
}
if !self.url.trim().starts_with("http://")
&& !self.url.trim().starts_with("https://")
{
return Err(AiError::Config(
"qdrant url must start with http:// or https://".to_string(),
));
}
if self.collection_name.trim().is_empty() {
return Err(AiError::Config(
"qdrant collection_name is required".to_string(),
));
}
if self.vector_size == 0 {
return Err(AiError::Config(
"qdrant vector_size must be greater than 0".to_string(),
));
}
if self.upsert_batch_size == 0 {
return Err(AiError::Config(
"qdrant upsert_batch_size must be greater than 0".to_string(),
));
}
if self.default_search_limit == 0 {
return Err(AiError::Config(
"qdrant default_search_limit must be greater than 0"
.to_string(),
));
}
Ok(())
}
}
impl RagConfig {
pub fn from_app_config(
config: &AppConfig,
collection_name: impl Into<String>,
) -> AiResult<Self> {
Ok(Self::new(
config
.qdrant_url()
.map_err(|error| AiError::Config(error.to_string()))?,
collection_name,
config
.get_embed_model_dimensions()
.map_err(|error| AiError::Config(error.to_string()))?,
)?
.with_api_key(
config
.qdrant_api_key()
.map_err(|error| AiError::Config(error.to_string()))?,
))
}
}

44
lib/ai/rag/document.rs Normal file
View File

@ -0,0 +1,44 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RagDocument {
pub id: String,
pub content: String,
pub metadata: HashMap<String, Value>,
}
impl RagDocument {
pub fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
Self {
id: id.into(),
content: content.into(),
metadata: HashMap::new(),
}
}
pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
self.metadata = metadata;
self
}
pub fn metadata_value(
mut self,
key: impl Into<String>,
value: impl Into<Value>,
) -> Self {
self.metadata.insert(key.into(), value.into());
self
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RagSearchHit {
pub id: String,
pub session_id: String,
pub score: f32,
pub content: String,
pub metadata: HashMap<String, Value>,
}

11
lib/ai/rag/mod.rs Normal file
View File

@ -0,0 +1,11 @@
mod client;
mod config;
mod document;
mod payload;
mod search;
mod session;
pub use client::RagClient;
pub use config::RagConfig;
pub use document::{RagDocument, RagSearchHit};
pub use search::RagSearchOptions;

110
lib/ai/rag/payload.rs Normal file
View File

@ -0,0 +1,110 @@
use std::collections::HashMap;
use qdrant_client::Payload;
use qdrant_client::qdrant::{
PointId, ScoredPoint, point_id::PointIdOptions, value::Kind,
};
use serde_json::{Map, Value, json};
use uuid::Uuid;
use super::document::{RagDocument, RagSearchHit};
use crate::error::{AiError, AiResult};
pub(super) const SESSION_ID_KEY: &str = "session_id";
pub(super) const DOCUMENT_ID_KEY: &str = "document_id";
pub(super) const CONTENT_KEY: &str = "content";
pub(super) const METADATA_KEY: &str = "metadata";
pub(super) fn point_id(session_id: &str, document_id: &str) -> u64 {
let ns = Uuid::NAMESPACE_DNS;
let key = format!("{session_id}:{document_id}");
let uuid = Uuid::new_v5(&ns, key.as_bytes());
let bytes = uuid.as_bytes();
u64::from_be_bytes([
bytes[0], bytes[1], bytes[2], bytes[3],
bytes[4], bytes[5], bytes[6], bytes[7],
])
}
pub(super) fn document_payload(
session_id: &str,
document: &RagDocument,
) -> AiResult<Payload> {
Payload::try_from(json!({
SESSION_ID_KEY: session_id,
DOCUMENT_ID_KEY: document.id,
CONTENT_KEY: document.content,
METADATA_KEY: document.metadata,
}))
.map_err(|error| AiError::Config(error.to_string()))
}
pub(super) fn hit_from_scored_point(point: ScoredPoint) -> RagSearchHit {
let id = point_id_to_string(point.id);
let mut payload = qdrant_payload_to_json(point.payload);
let session_id = take_string(&mut payload, SESSION_ID_KEY);
let document_id = take_string(&mut payload, DOCUMENT_ID_KEY);
let content = take_string(&mut payload, CONTENT_KEY);
let metadata = payload
.remove(METADATA_KEY)
.and_then(|value| match value {
Value::Object(object) => Some(object.into_iter().collect()),
_ => None,
})
.unwrap_or_default();
RagSearchHit {
id: if document_id.is_empty() {
id
} else {
document_id
},
session_id,
score: point.score,
content,
metadata,
}
}
fn point_id_to_string(id: Option<PointId>) -> String {
match id.and_then(|id| id.point_id_options) {
Some(PointIdOptions::Num(id)) => id.to_string(),
Some(PointIdOptions::Uuid(id)) => id,
None => String::new(),
}
}
fn qdrant_payload_to_json(
payload: HashMap<String, qdrant_client::qdrant::Value>,
) -> Map<String, Value> {
payload
.into_iter()
.map(|(key, value)| (key, value_to_json(value)))
.collect()
}
fn value_to_json(value: qdrant_client::qdrant::Value) -> Value {
match value.kind {
Some(Kind::NullValue(_)) | None => Value::Null,
Some(Kind::DoubleValue(value)) => json!(value),
Some(Kind::IntegerValue(value)) => json!(value),
Some(Kind::StringValue(value)) => json!(value),
Some(Kind::BoolValue(value)) => json!(value),
Some(Kind::StructValue(value)) => Value::Object(
value
.fields
.into_iter()
.map(|(key, value)| (key, value_to_json(value)))
.collect(),
),
Some(Kind::ListValue(value)) => {
Value::Array(value.values.into_iter().map(value_to_json).collect())
}
}
}
fn take_string(payload: &mut Map<String, Value>, key: &str) -> String {
payload
.remove(key)
.and_then(|value| value.as_str().map(ToOwned::to_owned))
.unwrap_or_default()
}

16
lib/ai/rag/search.rs Normal file
View File

@ -0,0 +1,16 @@
#[derive(Clone, Debug)]
pub struct RagSearchOptions {
pub limit: u64,
pub exact: bool,
}
impl RagSearchOptions {
pub fn new(limit: u64) -> Self {
Self { limit, exact: true }
}
pub fn with_exact(mut self, exact: bool) -> Self {
self.exact = exact;
self
}
}

15
lib/ai/rag/session.rs Normal file
View File

@ -0,0 +1,15 @@
use qdrant_client::qdrant::{Condition, Filter};
use super::payload::SESSION_ID_KEY;
use crate::error::{AiError, AiResult};
pub(super) fn validate_session_id(session_id: &str) -> AiResult<()> {
if session_id.trim().is_empty() {
return Err(AiError::Config("rag session_id is required".to_string()));
}
Ok(())
}
pub(super) fn session_filter(session_id: &str) -> Filter {
Filter::all([Condition::matches(SESSION_ID_KEY, session_id.to_string())])
}

126
lib/ai/sync.rs Normal file
View File

@ -0,0 +1,126 @@
use std::error::Error;
use std::sync::LazyLock;
use tracing::{debug, warn};
use crate::{
client::EndpointConfig,
error::{AiError, AiResult},
};
#[derive(Debug, serde::Deserialize)]
struct ModelsListResponse {
data: Vec<UpstreamModel>,
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct UpstreamModel {
pub id: String,
#[serde(default)]
pub name: Option<String>,
#[serde(default)]
pub owned_by: Option<String>,
#[serde(default)]
pub context_length: Option<i32>,
#[serde(default)]
pub max_output_tokens: Option<i32>,
#[serde(default)]
pub capabilities: Option<UpstreamCapabilities>,
#[serde(default)]
pub pricing: Option<UpstreamPricing>,
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct UpstreamCapabilities {
#[serde(default)]
pub vision: Option<bool>,
#[serde(default)]
pub tool_call: Option<bool>,
#[serde(default)]
pub reasoning: Option<bool>,
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct UpstreamPricing {
#[serde(default)]
pub prompt: Option<String>,
#[serde(default)]
pub completion: Option<String>,
#[serde(default)]
pub input: Option<f64>,
#[serde(default)]
pub output: Option<f64>,
#[serde(default)]
pub cache_read: Option<f64>,
#[serde(default)]
pub unit: Option<String>,
#[serde(default)]
pub currency: Option<String>,
}
static HTTP_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
let mut builder = reqwest::Client::builder();
let proxy_url = std::env::var("HTTPS_PROXY")
.or_else(|_| std::env::var("https_proxy"))
.or_else(|_| std::env::var("HTTP_PROXY"))
.or_else(|_| std::env::var("http_proxy"))
.ok();
if let Some(raw) = &proxy_url {
let url = raw.trim().trim_matches('"').trim_matches('\'');
match reqwest::Proxy::all(url) {
Ok(proxy) => {
debug!(proxy_url = %url, "sync: using proxy");
builder = builder.proxy(proxy);
}
Err(e) => {
warn!(proxy_url = %url, error = %e, "sync: invalid proxy URL, skipping");
}
}
}
#[allow(clippy::expect_used)]
builder.build().expect("failed to build reqwest HTTP client — check system TLS configuration")
});
pub async fn list_models(
config: &EndpointConfig,
) -> AiResult<Vec<UpstreamModel>> {
let base = config.base_url.trim_end_matches('/');
let url = if base.ends_with("/v1") {
format!("{}/models", base)
} else {
format!("{}/v1/models", base)
};
debug!(url = %url, "listing models from upstream");
let resp = HTTP_CLIENT
.get(&url)
.header("Authorization", format!("Bearer {}", config.api_key.trim()))
.send()
.await
.map_err(|e| {
tracing::error!(
error = %e,
source = ?e.source(),
"list_models: request failed with full cause chain"
);
AiError::Response(format!("failed to list models: {}", e))
})?;
let body = resp
.text()
.await
.map_err(|e| AiError::Response(format!("failed to read models body: {}", e)))?;
if let Ok(parsed) = serde_json::from_str::<ModelsListResponse>(&body) {
debug!(count = parsed.data.len(), "parsed models in standard format");
return Ok(parsed.data);
}
if let Ok(parsed) = serde_json::from_str::<Vec<UpstreamModel>>(&body) {
debug!(count = parsed.len(), "parsed models in array format");
return Ok(parsed);
}
warn!(
body = %body.chars().take(500).collect::<String>(),
"list_models: unknown response format"
);
Err(AiError::Response(format!(
"unexpected /v1/models response format (first 200 chars): {}",
body.chars().take(200).collect::<String>()
)))
}

5
lib/ai/tool/mod.rs Normal file
View File

@ -0,0 +1,5 @@
pub mod register;
pub mod tools;
pub mod toolset;
pub use toolset::{Toolset, ToolsetRegistry, toolset_names};

65
lib/ai/tool/register.rs Normal file
View File

@ -0,0 +1,65 @@
use crate::tool::tools::FunctionCall;
use std::collections::HashMap;
use std::sync::Arc;
#[derive(Clone)]
pub struct ToolRegister<C>
where
C: Clone + Send + Sync + 'static,
{
pub tools: Vec<Arc<dyn FunctionCall<Context = C>>>,
index: HashMap<String, usize>,
}
impl<C> ToolRegister<C>
where
C: Clone + Send + Sync + 'static,
{
pub fn new() -> Self {
ToolRegister {
tools: Vec::new(),
index: HashMap::new(),
}
}
pub fn register<T>(&mut self, tool: T)
where
T: FunctionCall<Context = C> + 'static,
{
let idx = self.tools.len();
self.index.insert(tool.name().to_string(), idx);
self.tools.push(Arc::new(tool));
}
pub fn with_tool<T>(mut self, tool: T) -> Self
where
T: FunctionCall<Context = C> + 'static,
{
self.register(tool);
self
}
pub fn get(
&self,
name: &str,
) -> Option<Arc<dyn FunctionCall<Context = C>>> {
self.index.get(name).map(|&idx| self.tools[idx].clone())
}
pub fn is_empty(&self) -> bool {
self.tools.is_empty()
}
pub fn len(&self) -> usize {
self.tools.len()
}
}
impl<C> Default for ToolRegister<C>
where
C: Clone + Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}

18
lib/ai/tool/tools.rs Normal file
View File

@ -0,0 +1,18 @@
use crate::error::AiResult;
use async_trait::async_trait;
use serde_json::Value;
#[async_trait]
pub trait FunctionCall: Send + Sync {
type Context;
fn name(&self) -> &'static str;
fn description(&self) -> &'static str {
""
}
fn schema(&self) -> Value;
async fn call(
&self,
context: &mut Self::Context,
args: Value,
) -> AiResult<Value>;
}

146
lib/ai/tool/toolset.rs Normal file
View File

@ -0,0 +1,146 @@
use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Toolset {
pub name: String,
pub description: String,
pub tools: Vec<String>,
pub requires_env: Vec<String>,
}
impl Toolset {
pub fn new(
name: impl Into<String>,
description: impl Into<String>,
) -> Self {
Self {
name: name.into(),
description: description.into(),
tools: Vec::new(),
requires_env: Vec::new(),
}
}
pub fn with_tool(mut self, tool_name: impl Into<String>) -> Self {
self.tools.push(tool_name.into());
self
}
pub fn with_tools(mut self, tool_names: impl IntoIterator<Item = impl Into<String>>) -> Self {
self.tools.extend(tool_names.into_iter().map(Into::into));
self
}
pub fn with_required_env(
mut self,
env_vars: impl IntoIterator<Item = impl Into<String>>,
) -> Self {
self.requires_env.extend(env_vars.into_iter().map(Into::into));
self
}
pub fn is_available(&self) -> bool {
for env_var in &self.requires_env {
if std::env::var(env_var).is_err() {
return false;
}
}
true
}
pub fn contains(&self, tool_name: &str) -> bool {
self.tools.iter().any(|t| t == tool_name)
}
}
pub mod toolset_names {
pub const CORE: &str = "core";
pub const TERMINAL: &str = "terminal";
pub const WEB: &str = "web";
pub const FILE: &str = "file";
pub const MEMORY: &str = "memory";
pub const VISION: &str = "vision";
pub const SEARCH: &str = "search";
pub const BROWSER: &str = "browser";
pub const CODE_EXECUTION: &str = "code_execution";
pub const DELEGATION: &str = "delegation";
}
#[derive(Clone, Debug, Default)]
pub struct ToolsetRegistry {
toolsets: HashMap<String, Toolset>,
tool_index: HashMap<String, String>,
}
impl ToolsetRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn register(&mut self, toolset: Toolset) {
let name = toolset.name.clone();
for tool in &toolset.tools {
self.tool_index.insert(tool.clone(), name.clone());
}
self.toolsets.insert(name, toolset);
}
pub fn get(&self, name: &str) -> Option<&Toolset> {
self.toolsets.get(name)
}
pub fn toolset_for(&self, tool_name: &str) -> Option<&str> {
self.tool_index.get(tool_name).map(String::as_str)
}
pub fn resolve_tool_names(
&self,
enabled: &[String],
disabled: &[String],
default_all: bool,
) -> Vec<String> {
let mut names = HashSet::new();
let mut denied = HashSet::new();
for ts_name in disabled {
if let Some(ts) = self.toolsets.get(ts_name) {
for tool in &ts.tools {
denied.insert(tool.clone());
}
}
}
if enabled.is_empty() && default_all {
for ts in self.toolsets.values() {
if !disabled.contains(&ts.name) && ts.is_available() {
for tool in &ts.tools {
if !denied.contains(tool) {
names.insert(tool.clone());
}
}
}
}
} else {
for ts_name in enabled {
if let Some(ts) = self.toolsets.get(ts_name) {
if ts.is_available() {
for tool in &ts.tools {
if !denied.contains(tool) {
names.insert(tool.clone());
}
}
}
}
}
}
let mut sorted: Vec<String> = names.into_iter().collect();
sorted.sort();
sorted
}
pub fn iter(&self) -> impl Iterator<Item = &Toolset> {
self.toolsets.values()
}
pub fn all_tool_names(&self) -> Vec<String> {
let mut names: Vec<String> = self.tool_index.keys().cloned().collect();
names.sort();
names
}
}

44
lib/api/Cargo.toml Normal file
View File

@ -0,0 +1,44 @@
[package]
name = "api"
version.workspace = true
edition.workspace = true
authors.workspace = true
description.workspace = true
repository.workspace = true
readme.workspace = true
homepage.workspace = true
license.workspace = true
keywords.workspace = true
categories.workspace = true
documentation.workspace = true
[lib]
path = "src/lib.rs"
name = "api"
[dependencies]
service = { workspace = true }
session = { workspace = true }
config = { workspace = true }
db = { workspace = true }
model = { workspace = true }
git = { workspace = true }
channel = { workspace = true }
socketio = { workspace = true }
actix-web = { workspace = true }
actix-ws = { workspace = true }
utoipa = { workspace = true, features = ["chrono", "uuid", "actix_extras"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
uuid = { workspace = true, features = ["v4", "v7", "serde"] }
chrono = { workspace = true }
async-stream = { workspace = true }
tokio-stream = { workspace = true }
base64 = { workspace = true }
comrak = { workspace = true }
redis = { workspace = true, features = ["cluster-async", "aio", "tokio-comp"] }
storage = { workspace = true }
[lints]
workspace = true

View File

@ -0,0 +1,296 @@
use actix_web::{HttpResponse, web, web::ServiceConfig};
use service::AppService;
use service::agent::conversation::{
ConversationResponse, ConversationWithSessionResponse, CreateConversation, MessageResponse, UpdateConversation,
};
use service::agent::types::{AgentRunRequest, AgentRunResponse};
use session::Session;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::UnboundedReceiverStream;
use uuid::Uuid;
use crate::error::{ApiError, ok_json};
pub fn configure(cfg: &mut ServiceConfig) {
cfg.service(
web::resource("/sessions/{session_id}/conversations")
.route(web::get().to(list_conversations))
.route(web::post().to(create_conversation)),
)
.service(
web::resource("/conversations")
.route(web::get().to(list_all_conversations)),
)
.service(
web::resource("/conversations/{id}")
.route(web::get().to(get_conversation))
.route(web::patch().to(update_conversation))
.route(web::delete().to(delete_conversation)),
)
.service(
web::resource("/conversations/{id}/messages")
.route(web::get().to(list_messages))
.route(web::post().to(send_message)),
)
.service(
web::resource("/conversations/{id}/stream")
.route(web::post().to(stream_agent)),
)
.service(
web::resource("/conversations/{id}/fork")
.route(web::post().to(fork_conversation)),
);
}
#[utoipa::path(
get, path = "/api/v1/agent/sessions/{session_id}/conversations",
params(("session_id" = Uuid, Path)),
responses((status = 200, body = Vec<ConversationResponse>)),
security(("session" = []))
)]
pub async fn list_conversations(
session: Session,
service: web::Data<AppService>,
path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(service.agent_conversation_list(user_id, path.into_inner()).await?)
}
#[utoipa::path(
post, path = "/api/v1/agent/sessions/{session_id}/conversations",
params(("session_id" = Uuid, Path)),
request_body = CreateConversation,
responses((status = 200, body = ConversationResponse)),
security(("session" = []))
)]
pub async fn create_conversation(
session: Session,
service: web::Data<AppService>,
path: web::Path<Uuid>,
body: web::Json<CreateConversation>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(
service
.agent_conversation_create(user_id, path.into_inner(), body.into_inner())
.await?,
)
}
#[derive(Debug, serde::Deserialize, utoipa::IntoParams)]
pub struct ListAllConversationsQuery {
pub wk: Option<String>,
}
#[utoipa::path(
get, path = "/api/v1/agent/conversations",
params(("wk" = Option<String>, Query, description = "Filter by workspace name")),
responses((status = 200, body = Vec<ConversationWithSessionResponse>)),
security(("session" = []))
)]
pub async fn list_all_conversations(
session: Session,
service: web::Data<AppService>,
query: web::Query<ListAllConversationsQuery>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(
service
.agent_conversation_list_all(user_id, query.wk.as_deref())
.await?,
)
}
#[utoipa::path(
get, path = "/api/v1/agent/conversations/{id}",
params(("id" = Uuid, Path)),
responses((status = 200, body = ConversationResponse)),
security(("session" = []))
)]
pub async fn get_conversation(
session: Session,
service: web::Data<AppService>,
path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(service.agent_conversation_get(user_id, path.into_inner()).await?)
}
#[utoipa::path(
patch, path = "/api/v1/agent/conversations/{id}",
params(("id" = Uuid, Path)),
request_body = UpdateConversation,
responses((status = 200, body = ConversationResponse)),
security(("session" = []))
)]
pub async fn update_conversation(
session: Session,
service: web::Data<AppService>,
path: web::Path<Uuid>,
body: web::Json<UpdateConversation>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(
service
.agent_conversation_update(user_id, path.into_inner(), body.into_inner())
.await?,
)
}
#[utoipa::path(
delete, path = "/api/v1/agent/conversations/{id}",
params(("id" = Uuid, Path)),
responses((status = 200)),
security(("session" = []))
)]
pub async fn delete_conversation(
session: Session,
service: web::Data<AppService>,
path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
service.agent_conversation_delete(user_id, path.into_inner()).await?;
Ok(HttpResponse::Ok().json(serde_json::json!({ "deleted": true })))
}
#[utoipa::path(
post, path = "/api/v1/agent/conversations/{id}/archive",
params(("id" = Uuid, Path)),
responses((status = 200, body = ConversationResponse)),
security(("session" = []))
)]
pub async fn archive_conversation(
session: Session,
service: web::Data<AppService>,
path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(service.agent_conversation_archive(user_id, path.into_inner()).await?)
}
#[utoipa::path(
post, path = "/api/v1/agent/conversations/{id}/unarchive",
params(("id" = Uuid, Path)),
responses((status = 200, body = ConversationResponse)),
security(("session" = []))
)]
pub async fn unarchive_conversation(
session: Session,
service: web::Data<AppService>,
path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(service.agent_conversation_unarchive(user_id, path.into_inner()).await?)
}
#[utoipa::path(
get, path = "/api/v1/agent/conversations/{id}/messages",
params(("id" = Uuid, Path), ("before" = Option<Uuid>, Query), ("limit" = Option<u32>, Query)),
responses((status = 200, body = Vec<MessageResponse>)),
security(("session" = []))
)]
pub async fn list_messages(
session: Session,
service: web::Data<AppService>,
path: web::Path<Uuid>,
query: web::Query<MessageListQuery>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(
service
.agent_message_list(user_id, path.into_inner(), query.limit, query.before)
.await?,
)
}
#[derive(Debug, serde::Deserialize, utoipa::IntoParams)]
pub struct MessageListQuery {
pub limit: Option<u32>,
pub before: Option<Uuid>,
}
#[utoipa::path(
post, path = "/api/v1/agent/conversations/{id}/messages",
params(("id" = Uuid, Path)),
request_body = AgentRunRequest,
responses((status = 200, body = AgentRunResponse)),
security(("session" = []))
)]
pub async fn send_message(
session: Session,
service: web::Data<AppService>,
path: web::Path<Uuid>,
body: web::Json<AgentRunRequest>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
let conversation_id = path.into_inner();
let mut req = body.into_inner();
req.conversation_id = Some(conversation_id);
ok_json(service.agent_run(user_id, req).await?)
}
#[utoipa::path(
post, path = "/api/v1/agent/conversations/{id}/stream",
params(("id" = Uuid, Path)),
request_body = AgentRunRequest,
responses((status = 200, description = "SSE stream")),
security(("session" = []))
)]
pub async fn stream_agent(
session: Session,
service: web::Data<AppService>,
path: web::Path<Uuid>,
body: web::Json<AgentRunRequest>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
let conversation_id = path.into_inner();
let mut req = body.into_inner();
req.conversation_id = Some(conversation_id);
let rx = service.agent_run_streaming(user_id, req).await?;
let stream = UnboundedReceiverStream::new(rx).map(|payload| {
let frame = if payload.starts_with("data:") {
payload
} else {
format!("data: {}\n\n", payload)
};
Ok::<_, actix_web::Error>(actix_web::web::Bytes::from(frame))
});
Ok(HttpResponse::Ok()
.content_type("text/event-stream")
.insert_header(("Cache-Control", "no-cache"))
.insert_header(("Connection", "keep-alive"))
.insert_header(("X-Accel-Buffering", "no"))
.streaming(stream))
}
#[derive(Debug, serde::Deserialize, utoipa::ToSchema)]
pub struct ForkConversationRequest {
pub message_id: Option<Uuid>,
pub title: Option<String>,
}
#[utoipa::path(
post, path = "/api/v1/agent/conversations/{id}/fork",
params(("id" = Uuid, Path)),
request_body = ForkConversationRequest,
responses((status = 200, body = ConversationResponse)),
security(("session" = []))
)]
pub async fn fork_conversation(
session: Session,
service: web::Data<AppService>,
path: web::Path<Uuid>,
body: web::Json<ForkConversationRequest>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(
service
.agent_conversation_fork(
user_id,
path.into_inner(),
body.message_id,
body.title.as_deref(),
)
.await?,
)
}

12
lib/api/src/agent/mod.rs Normal file
View File

@ -0,0 +1,12 @@
pub mod conversation;
pub mod session;
use actix_web::{web, web::ServiceConfig};
pub fn configure(cfg: &mut ServiceConfig) {
cfg.service(
web::scope("/agent")
.configure(session::configure)
.configure(conversation::configure),
);
}

View File

@ -0,0 +1,162 @@
use actix_web::{HttpResponse, web, web::ServiceConfig};
use service::AppService;
use service::agent::session::{
AgentSessionResponse, CreateAgentSession, UpdateAgentSession,
};
use session::Session;
use uuid::Uuid;
use crate::error::{ApiError, ok_json};
pub fn configure(cfg: &mut ServiceConfig) {
cfg.service(
web::resource("/sessions")
.route(web::get().to(list_sessions))
.route(web::post().to(create_session)),
)
.service(
web::resource("/sessions/search")
.route(web::get().to(search_sessions)),
)
.service(
web::resource("/sessions/{id}")
.route(web::get().to(get_session))
.route(web::patch().to(update_session))
.route(web::delete().to(delete_session)),
)
.service(
web::resource("/sessions/{id}/toolsets")
.route(web::patch().to(update_session_toolsets)),
);
}
#[utoipa::path(
get, path = "/api/v1/agent/sessions",
responses((status = 200, body = Vec<AgentSessionResponse>)),
security(("session" = []))
)]
pub async fn list_sessions(
session: Session,
service: web::Data<AppService>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(service.agent_session_list(user_id).await?)
}
#[utoipa::path(
post, path = "/api/v1/agent/sessions",
request_body = CreateAgentSession,
responses((status = 200, body = AgentSessionResponse)),
security(("session" = []))
)]
pub async fn create_session(
session: Session,
service: web::Data<AppService>,
body: web::Json<CreateAgentSession>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(service.agent_session_create(user_id, body.into_inner()).await?)
}
#[utoipa::path(
get, path = "/api/v1/agent/sessions/{id}",
params(("id" = Uuid, Path)),
responses((status = 200, body = AgentSessionResponse)),
security(("session" = []))
)]
pub async fn get_session(
session: Session,
service: web::Data<AppService>,
path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(service.agent_session_get(user_id, path.into_inner()).await?)
}
#[utoipa::path(
patch, path = "/api/v1/agent/sessions/{id}",
params(("id" = Uuid, Path)),
request_body = UpdateAgentSession,
responses((status = 200, body = AgentSessionResponse)),
security(("session" = []))
)]
pub async fn update_session(
session: Session,
service: web::Data<AppService>,
path: web::Path<Uuid>,
body: web::Json<UpdateAgentSession>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(service.agent_session_update(user_id, path.into_inner(), body.into_inner()).await?)
}
#[utoipa::path(
delete, path = "/api/v1/agent/sessions/{id}",
params(("id" = Uuid, Path)),
responses((status = 200)),
security(("session" = []))
)]
pub async fn delete_session(
session: Session,
service: web::Data<AppService>,
path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
service.agent_session_delete(user_id, path.into_inner()).await?;
Ok(HttpResponse::Ok().json(serde_json::json!({ "deleted": true })))
}
#[derive(Debug, serde::Deserialize, utoipa::IntoParams)]
pub struct SearchQuery {
pub q: String,
#[serde(default = "default_limit")]
pub limit: u32,
}
const fn default_limit() -> u32 {
20
}
#[utoipa::path(
get, path = "/api/v1/agent/sessions/search",
params(("q" = String, Query), ("limit" = Option<u32>, Query)),
responses((status = 200, body = Vec<AgentSessionResponse>)),
security(("session" = []))
)]
pub async fn search_sessions(
session: Session,
service: web::Data<AppService>,
query: web::Query<SearchQuery>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(
service
.agent_session_search(user_id, &query.q, query.limit)
.await?,
)
}
#[derive(Debug, serde::Deserialize, utoipa::ToSchema)]
pub struct UpdateToolsetsRequest {
pub enabled: Option<Vec<String>>,
pub disabled: Option<Vec<String>>,
}
#[utoipa::path(
patch, path = "/api/v1/agent/sessions/{id}/toolsets",
params(("id" = Uuid, Path)),
request_body = UpdateToolsetsRequest,
responses((status = 200, body = AgentSessionResponse)),
security(("session" = []))
)]
pub async fn update_session_toolsets(
session: Session,
service: web::Data<AppService>,
path: web::Path<Uuid>,
body: web::Json<UpdateToolsetsRequest>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
ok_json(
service
.agent_session_update_toolsets(
user_id,
path.into_inner(),
body.enabled.clone(),
body.disabled.clone(),
)
.await?,
)
}

47
lib/api/src/ai/mod.rs Normal file
View File

@ -0,0 +1,47 @@
pub mod model;
pub mod provider;
use actix_web::web;
use actix_web::web::ServiceConfig;
pub fn configure(cfg: &mut ServiceConfig) {
cfg.service(
web::scope("/ai")
.service(
web::resource("/providers")
.route(web::get().to(provider::list_providers)),
)
.service(
web::resource("/providers/{id}")
.route(web::get().to(provider::get_provider)),
)
.service(
web::resource("/models")
.route(web::get().to(model::list_models)),
)
.service(
web::resource("/models/{id}")
.route(web::get().to(model::get_model)),
)
.service(
web::resource("/models/{id}/versions")
.route(web::get().to(model::list_versions)),
)
.service(
web::resource("/models/{id}/card")
.route(web::get().to(model::get_card)),
)
.service(
web::resource("/models/{id}/tags")
.route(web::get().to(model::list_tags)),
)
.service(
web::resource("/models/{id}/discussions")
.route(web::get().to(model::list_discussions)),
)
.service(
web::resource("/models/{id}/likes")
.route(web::get().to(model::list_likes)),
),
);
}

139
lib/api/src/ai/model.rs Normal file
View File

@ -0,0 +1,139 @@
use crate::error::{ApiError, ok_json};
use actix_web::{HttpResponse, web};
use serde::Deserialize;
use service::AppService;
use service::Pagination;
use service::ai::types::{
AiDiscussionResponse, AiLikeResponse, AiModelCardResponse, AiModelFilter,
AiModelListItem, AiModelResponse, AiModelVersionResponse,
};
use session::Session;
use uuid::Uuid;
#[derive(Deserialize, utoipa::IntoParams)]
pub struct ModelIdPath {
pub id: Uuid,
}
#[utoipa::path(
get, path = "/api/v1/ai/models",
params(AiModelFilter, Pagination),
responses((status = 200, body = Vec<AiModelListItem>), (status = 401, description = "Unauthorized")),
security(("session" = []))
)]
pub async fn list_models(
session: Session,
service: web::Data<AppService>,
filter: web::Query<AiModelFilter>,
pagination: web::Query<Pagination>,
) -> Result<HttpResponse, ApiError> {
ok_json(
service
.ai_model_list(
&session,
filter.into_inner(),
pagination.into_inner(),
)
.await?,
)
}
#[utoipa::path(
get, path = "/api/v1/ai/models/{id}",
params(("id" = String, Path)), responses((status = 200, body = AiModelResponse),
(status = 401, description = "Unauthorized"), (status = 404, description = "Not found")),
security(("session" = []))
)]
pub async fn get_model(
session: Session,
service: web::Data<AppService>,
path: web::Path<ModelIdPath>,
) -> Result<HttpResponse, ApiError> {
ok_json(service.ai_model_get(&session, path.into_inner().id).await?)
}
#[utoipa::path(
get, path = "/api/v1/ai/models/{id}/versions",
params(("id" = String, Path)), responses((status = 200, body = Vec<AiModelVersionResponse>),
(status = 401, description = "Unauthorized")),
security(("session" = []))
)]
pub async fn list_versions(
session: Session,
service: web::Data<AppService>,
path: web::Path<ModelIdPath>,
) -> Result<HttpResponse, ApiError> {
ok_json(
service
.ai_model_versions(&session, path.into_inner().id)
.await?,
)
}
#[utoipa::path(
get, path = "/api/v1/ai/models/{id}/card",
params(("id" = String, Path)), responses((status = 200, body = Option<AiModelCardResponse>),
(status = 401, description = "Unauthorized")),
security(("session" = []))
)]
pub async fn get_card(
session: Session,
service: web::Data<AppService>,
path: web::Path<ModelIdPath>,
) -> Result<HttpResponse, ApiError> {
ok_json(
service
.ai_model_card(&session, path.into_inner().id)
.await?,
)
}
#[utoipa::path(
get, path = "/api/v1/ai/models/{id}/tags",
params(("id" = String, Path)), responses((status = 200, body = Vec<String>)),
security(("session" = []))
)]
pub async fn list_tags(
session: Session,
service: web::Data<AppService>,
path: web::Path<ModelIdPath>,
) -> Result<HttpResponse, ApiError> {
ok_json(
service
.ai_model_tags(&session, path.into_inner().id)
.await?,
)
}
#[utoipa::path(
get, path = "/api/v1/ai/models/{id}/discussions",
params(("id" = String, Path), Pagination),
responses((status = 200, body = Vec<AiDiscussionResponse>)),
security(("session" = []))
)]
pub async fn list_discussions(
session: Session,
service: web::Data<AppService>,
path: web::Path<ModelIdPath>,
pagination: web::Query<Pagination>,
) -> Result<HttpResponse, ApiError> {
ok_json(
service
.ai_model_discussions(
&session,
path.into_inner().id,
pagination.into_inner(),
)
.await?,
)
}
#[utoipa::path(
get, path = "/api/v1/ai/models/{id}/likes",
params(("id" = String, Path)), responses((status = 200, body = Vec<AiLikeResponse>)),
security(("session" = []))
)]
pub async fn list_likes(
session: Session,
service: web::Data<AppService>,
path: web::Path<ModelIdPath>,
) -> Result<HttpResponse, ApiError> {
ok_json(
service
.ai_model_likes(&session, path.into_inner().id)
.await?,
)
}

View File

@ -0,0 +1,39 @@
use crate::error::{ApiError, ok_json};
use actix_web::{HttpResponse, web};
use serde::Deserialize;
use service::AppService;
use service::ai::types::AiProviderResponse;
use session::Session;
#[derive(Deserialize, utoipa::IntoParams)]
pub struct ProviderIdPath {
pub id: uuid::Uuid,
}
#[utoipa::path(
get, path = "/api/v1/ai/providers",
responses((status = 200, body = Vec<AiProviderResponse>), (status = 401, description = "Unauthorized")),
security(("session" = []))
)]
pub async fn list_providers(
session: Session,
service: web::Data<AppService>,
) -> Result<HttpResponse, ApiError> {
ok_json(service.ai_provider_list(&session).await?)
}
#[utoipa::path(
get, path = "/api/v1/ai/providers/{id}",
params(("id" = String, Path)), responses((status = 200, body = AiProviderResponse),
(status = 401, description = "Unauthorized"), (status = 404, description = "Not found")),
security(("session" = []))
)]
pub async fn get_provider(
session: Session,
service: web::Data<AppService>,
path: web::Path<ProviderIdPath>,
) -> Result<HttpResponse, ApiError> {
ok_json(
service
.ai_provider_get(&session, path.into_inner().id)
.await?,
)
}

View File

@ -0,0 +1,29 @@
use actix_web::{HttpResponse, web};
use serde::Serialize;
use service::{
AppService,
auth::captcha::{CaptchaQuery, CaptchaResponse},
};
use session::Session;
use crate::error::ApiError;
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
Ok(HttpResponse::Ok().json(data))
}
#[utoipa::path(
get,
path = "/api/v1/auth/captcha",
params(CaptchaQuery),
responses((status = 200, body = CaptchaResponse)),
tag = "auth"
)]
pub async fn captcha(
session: Session,
query: web::Query<CaptchaQuery>,
service: web::Data<AppService>,
) -> Result<HttpResponse, ApiError> {
let result = service.auth_captcha(&session, query.into_inner()).await?;
ok_json(result)
}

64
lib/api/src/auth/email.rs Normal file
View File

@ -0,0 +1,64 @@
use actix_web::{HttpResponse, web};
use serde::Serialize;
use service::{
AppService,
auth::email::{EmailChangeRequest, EmailResponse, EmailVerifyRequest},
};
use session::Session;
use crate::error::ApiError;
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
Ok(HttpResponse::Ok().json(data))
}
fn ok() -> Result<HttpResponse, ApiError> {
Ok(HttpResponse::Ok().finish())
}
#[utoipa::path(
get,
path = "/api/v1/auth/email",
responses((status = 200, body = EmailResponse)),
tag = "auth"
)]
pub async fn get_email(
session: Session,
service: web::Data<AppService>,
) -> Result<HttpResponse, ApiError> {
let result = service.auth_get_email(&session).await?;
ok_json(result)
}
#[utoipa::path(
post,
path = "/api/v1/auth/email",
request_body = EmailChangeRequest,
responses((status = 200)),
tag = "auth"
)]
pub async fn email_change_request(
session: Session,
params: web::Json<EmailChangeRequest>,
service: web::Data<AppService>,
) -> Result<HttpResponse, ApiError> {
service
.auth_email_change_request(&session, params.into_inner())
.await?;
ok()
}
#[utoipa::path(
post,
path = "/api/v1/auth/email/verify",
request_body = EmailVerifyRequest,
responses((status = 200)),
tag = "auth"
)]
pub async fn email_verify(
params: web::Json<EmailVerifyRequest>,
service: web::Data<AppService>,
) -> Result<HttpResponse, ApiError> {
service.auth_email_verify(params.into_inner()).await?;
ok()
}

25
lib/api/src/auth/login.rs Normal file
View File

@ -0,0 +1,25 @@
use actix_web::{HttpResponse, web};
use service::{AppService, auth::login::LoginParams};
use session::Session;
use crate::error::ApiError;
fn ok() -> Result<HttpResponse, ApiError> {
Ok(HttpResponse::Ok().finish())
}
#[utoipa::path(
post,
path = "/api/v1/auth/login",
request_body = LoginParams,
responses((status = 200)),
tag = "auth"
)]
pub async fn login(
session: Session,
params: web::Json<LoginParams>,
service: web::Data<AppService>,
) -> Result<HttpResponse, ApiError> {
service.auth_login(params.into_inner(), session).await?;
ok()
}

View File

@ -0,0 +1,23 @@
use actix_web::{HttpResponse, web};
use service::AppService;
use session::Session;
use crate::error::ApiError;
fn ok() -> Result<HttpResponse, ApiError> {
Ok(HttpResponse::Ok().finish())
}
#[utoipa::path(
post,
path = "/api/v1/auth/logout",
responses((status = 200)),
tag = "auth"
)]
pub async fn logout(
session: Session,
service: web::Data<AppService>,
) -> Result<HttpResponse, ApiError> {
service.auth_logout(&session).await?;
ok()
}

24
lib/api/src/auth/me.rs Normal file
View File

@ -0,0 +1,24 @@
use actix_web::{HttpResponse, web};
use serde::Serialize;
use service::{AppService, auth::me::ContextMe};
use session::Session;
use crate::error::ApiError;
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
Ok(HttpResponse::Ok().json(data))
}
#[utoipa::path(
get,
path = "/api/v1/auth/me",
responses((status = 200, body = ContextMe)),
tag = "auth"
)]
pub async fn me(
session: Session,
service: web::Data<AppService>,
) -> Result<HttpResponse, ApiError> {
let result = service.auth_me(session).await?;
ok_json(result)
}

75
lib/api/src/auth/mod.rs Normal file
View File

@ -0,0 +1,75 @@
pub mod captcha;
pub mod email;
pub mod login;
pub mod logout;
pub mod me;
pub mod register;
pub mod reset_pass;
pub mod rsa;
pub mod totp;
use actix_web::{web, web::ServiceConfig};
pub fn configure(cfg: &mut ServiceConfig) {
cfg.service(
web::scope("/auth")
.service(
web::resource("/captcha")
.route(web::get().to(captcha::captcha)),
)
.service(
web::resource("/login").route(web::post().to(login::login)),
)
.service(
web::resource("/logout").route(web::post().to(logout::logout)),
)
.service(web::resource("/me").route(web::get().to(me::me)))
.service(
web::resource("/register")
.route(web::post().to(register::register)),
)
.service(
web::scope("/reset-password")
.service(web::resource("/request").route(
web::post().to(reset_pass::reset_password_request),
))
.service(web::resource("/verify").route(
web::post().to(reset_pass::reset_password_verify),
)),
)
.service(web::resource("/public-key").route(web::get().to(rsa::rsa)))
.service(
web::scope("/2fa")
.service(
web::resource("/enable")
.route(web::post().to(totp::enable_2fa)),
)
.service(
web::resource("/verify")
.route(web::post().to(totp::verify_2fa)),
)
.service(
web::resource("")
.route(web::get().to(totp::status_2fa))
.route(web::delete().to(totp::disable_2fa)),
)
.service(
web::resource("/backup-codes").route(
web::post().to(totp::regenerate_backup_codes),
),
),
)
.service(
web::scope("/email")
.service(
web::resource("")
.route(web::get().to(email::get_email))
.route(web::put().to(email::email_change_request)),
)
.service(
web::resource("/verify")
.route(web::post().to(email::email_verify)),
),
),
);
}

View File

@ -0,0 +1,26 @@
use actix_web::{HttpResponse, web};
use serde::Serialize;
use service::{AppService, auth::register::RegisterParams};
use session::Session;
use crate::error::ApiError;
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
Ok(HttpResponse::Ok().json(data))
}
#[utoipa::path(
post,
path = "/api/v1/auth/register",
request_body = RegisterParams,
responses((status = 200)),
tag = "auth"
)]
pub async fn register(
session: Session,
params: web::Json<RegisterParams>,
service: web::Data<AppService>,
) -> Result<HttpResponse, ApiError> {
let result = service.auth_register(params.into_inner(), &session).await?;
ok_json(result)
}

View File

@ -0,0 +1,47 @@
use actix_web::{HttpResponse, web};
use service::{
AppService,
auth::reset_pass::{ResetPasswordRequest, ResetPasswordVerifyParams},
};
use session::Session;
use crate::error::ApiError;
fn ok() -> Result<HttpResponse, ApiError> {
Ok(HttpResponse::Ok().finish())
}
#[utoipa::path(
post,
path = "/api/v1/auth/reset-password/request",
request_body = ResetPasswordRequest,
responses((status = 200)),
tag = "auth"
)]
pub async fn reset_password_request(
params: web::Json<ResetPasswordRequest>,
service: web::Data<AppService>,
) -> Result<HttpResponse, ApiError> {
service
.auth_reset_password_request(params.into_inner())
.await?;
ok()
}
#[utoipa::path(
post,
path = "/api/v1/auth/reset-password/verify",
request_body = ResetPasswordVerifyParams,
responses((status = 200)),
tag = "auth"
)]
pub async fn reset_password_verify(
session: Session,
params: web::Json<ResetPasswordVerifyParams>,
service: web::Data<AppService>,
) -> Result<HttpResponse, ApiError> {
service
.auth_reset_password_verify(&session, params.into_inner())
.await?;
ok()
}

24
lib/api/src/auth/rsa.rs Normal file
View File

@ -0,0 +1,24 @@
use actix_web::{HttpResponse, web};
use serde::Serialize;
use service::{AppService, auth::rsa::RsaResponse};
use session::Session;
use crate::error::ApiError;
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
Ok(HttpResponse::Ok().json(data))
}
#[utoipa::path(
get,
path = "/api/v1/auth/public-key",
responses((status = 200, body = RsaResponse)),
tag = "auth"
)]
pub async fn rsa(
session: Session,
service: web::Data<AppService>,
) -> Result<HttpResponse, ApiError> {
let result = service.auth_rsa(&session).await?;
ok_json(result)
}

102
lib/api/src/auth/totp.rs Normal file
View File

@ -0,0 +1,102 @@
use actix_web::{HttpResponse, web};
use serde::Serialize;
use service::{
AppService,
auth::totp::{
Disable2FAParams, Enable2FAResponse, Get2FAStatusResponse,
Verify2FAParams,
},
};
use session::Session;
use crate::error::ApiError;
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
Ok(HttpResponse::Ok().json(data))
}
fn ok() -> Result<HttpResponse, ApiError> {
Ok(HttpResponse::Ok().finish())
}
#[utoipa::path(
post,
path = "/api/v1/auth/2fa/enable",
responses((status = 200, body = Enable2FAResponse)),
tag = "auth"
)]
pub async fn enable_2fa(
session: Session,
service: web::Data<AppService>,
) -> Result<HttpResponse, ApiError> {
let result = service.auth_2fa_enable(&session).await?;
ok_json(result)
}
#[utoipa::path(
post,
path = "/api/v1/auth/2fa/verify",
request_body = Verify2FAParams,
responses((status = 200)),
tag = "auth"
)]
pub async fn verify_2fa(
session: Session,
params: web::Json<Verify2FAParams>,
service: web::Data<AppService>,
) -> Result<HttpResponse, ApiError> {
service
.auth_2fa_verify_and_enable(&session, params.into_inner())
.await?;
ok()
}
#[utoipa::path(
delete,
path = "/api/v1/auth/2fa",
request_body = Disable2FAParams,
responses((status = 200)),
tag = "auth"
)]
pub async fn disable_2fa(
session: Session,
params: web::Json<Disable2FAParams>,
service: web::Data<AppService>,
) -> Result<HttpResponse, ApiError> {
service
.auth_2fa_disable(&session, params.into_inner())
.await?;
ok()
}
#[utoipa::path(
get,
path = "/api/v1/auth/2fa",
responses((status = 200, body = Get2FAStatusResponse)),
tag = "auth"
)]
pub async fn status_2fa(
session: Session,
service: web::Data<AppService>,
) -> Result<HttpResponse, ApiError> {
let result = service.auth_2fa_status(&session).await?;
ok_json(result)
}
#[utoipa::path(
post,
path = "/api/v1/auth/2fa/backup-codes",
request_body = String,
responses((status = 200)),
tag = "auth"
)]
pub async fn regenerate_backup_codes(
session: Session,
params: web::Json<String>,
service: web::Data<AppService>,
) -> Result<HttpResponse, ApiError> {
let result = service
.auth_2fa_regenerate_backup_codes(&session, params.into_inner())
.await?;
ok_json(result)
}

207
lib/api/src/channel/mod.rs Normal file
View File

@ -0,0 +1,207 @@
pub mod rest;
pub mod rest_ai;
pub mod rest_interact;
pub mod rest_member;
pub mod rest_message;
pub mod rest_room;
pub mod rest_voice;
pub mod token;
pub use channel::ChannelBus;
use actix_web::web::ServiceConfig;
pub fn configure(cfg: &mut ServiceConfig, bus: ChannelBus) {
socketio::configure_at(cfg, "/socket.io", bus.io().clone());
socketio::configure_at(cfg, "/socket.io/", bus.io().clone());
cfg.service(
actix_web::web::resource("/ping")
.route(actix_web::web::get().to(rest::ping)),
)
.service(
actix_web::web::resource("/csrf")
.route(actix_web::web::get().to(rest::csrf_token)),
);
cfg.service(
actix_web::web::resource("/rooms/{room_id}/messages")
.route(actix_web::web::get().to(rest_message::list_messages))
.route(actix_web::web::post().to(rest_message::create_message)),
)
.service(
actix_web::web::resource("/rooms/{room_id}/messages/around")
.route(actix_web::web::get().to(rest_message::messages_around)),
)
.service(
actix_web::web::resource("/rooms/{room_id}/messages/missed")
.route(actix_web::web::get().to(rest_message::missed_messages)),
)
.service(
actix_web::web::resource("/messages/{message_id}")
.route(actix_web::web::patch().to(rest_message::update_message))
.route(actix_web::web::delete().to(rest_message::revoke_message)),
)
.service(
actix_web::web::resource("/search")
.route(actix_web::web::get().to(rest_message::search)),
);
cfg.service(
actix_web::web::resource("/rooms")
.route(actix_web::web::get().to(rest_room::list_rooms))
.route(actix_web::web::post().to(rest_room::room_create)),
)
.service(
actix_web::web::resource("/rooms/{room_id}")
.route(actix_web::web::get().to(rest_room::room_get))
.route(actix_web::web::patch().to(rest_room::room_update))
.route(actix_web::web::delete().to(rest_room::room_delete)),
)
.service(
actix_web::web::resource("/rooms/{room_id}/subscribe")
.route(actix_web::web::post().to(rest_room::subscribe))
.route(actix_web::web::delete().to(rest_room::unsubscribe)),
)
.service(
actix_web::web::resource("/rooms/{room_id}/members")
.route(actix_web::web::post().to(rest_room::access_grant)),
)
.service(
actix_web::web::resource("/workspaces/{workspace_id}/members")
.route(actix_web::web::get().to(rest_member::list_workspace_members)),
)
.service(
actix_web::web::resource("/rooms/{room_id}/members/{user_id}")
.route(actix_web::web::delete().to(rest_room::access_revoke)),
)
.service(
actix_web::web::resource("/workspaces/{workspace_id}/categories")
.route(actix_web::web::post().to(rest_room::category_create)),
)
.service(
actix_web::web::resource("/categories/{category_id}")
.route(actix_web::web::patch().to(rest_room::category_update))
.route(actix_web::web::delete().to(rest_room::category_delete)),
);
cfg.service(
actix_web::web::resource("/rooms/{room_id}/reactions")
.route(actix_web::web::post().to(rest_interact::reaction_add))
.route(actix_web::web::delete().to(rest_interact::reaction_remove)),
)
.service(
actix_web::web::resource("/rooms/{room_id}/threads")
.route(actix_web::web::post().to(rest_interact::thread_create)),
)
.service(
actix_web::web::resource("/threads/{thread_id}/resolve")
.route(actix_web::web::patch().to(rest_interact::thread_resolve)),
)
.service(
actix_web::web::resource("/threads/{thread_id}/archive")
.route(actix_web::web::patch().to(rest_interact::thread_archive)),
)
.service(
actix_web::web::resource("/rooms/{room_id}/pins")
.route(actix_web::web::post().to(rest_interact::pin_add))
.route(actix_web::web::delete().to(rest_interact::pin_remove)),
)
.service(
actix_web::web::resource("/rooms/{room_id}/drafts")
.route(actix_web::web::put().to(rest_interact::draft_save))
.route(actix_web::web::delete().to(rest_interact::draft_clear)),
)
.service(
actix_web::web::resource("/rooms/{room_id}/typing")
.route(actix_web::web::post().to(rest_interact::typing)),
);
cfg.service(
actix_web::web::resource("/rooms/{room_id}/read-receipt")
.route(actix_web::web::post().to(rest_member::read_receipt)),
)
.service(
actix_web::web::resource("/rooms/{room_id}/dnd")
.route(actix_web::web::patch().to(rest_member::dnd_update)),
)
.service(
actix_web::web::resource("/notifications/{id}/read").route(
actix_web::web::patch().to(rest_member::notification_mark_read),
),
)
.service(actix_web::web::resource("/notifications/read-all").route(
actix_web::web::post().to(rest_member::notification_mark_all_read),
))
.service(
actix_web::web::resource("/notifications/{id}").route(
actix_web::web::delete().to(rest_member::notification_archive),
),
)
.service(
actix_web::web::resource("/presence")
.route(actix_web::web::post().to(rest_member::presence_update)),
)
.service(
actix_web::web::resource("/custom-status").route(
actix_web::web::post().to(rest_member::custom_status_update),
),
)
.service(
actix_web::web::resource("/invites")
.route(actix_web::web::post().to(rest_member::invite_create)),
)
.service(
actix_web::web::resource("/invites/accept")
.route(actix_web::web::post().to(rest_member::invite_accept)),
)
.service(
actix_web::web::resource("/invites/{id}")
.route(actix_web::web::delete().to(rest_member::invite_revoke)),
)
.service(
actix_web::web::resource("/workspaces/{workspace_id}/bans")
.route(actix_web::web::post().to(rest_member::ban_create)),
)
.service(
actix_web::web::resource("/workspaces/{workspace_id}/bans/{user_id}")
.route(actix_web::web::delete().to(rest_member::ban_remove)),
);
cfg.service(
actix_web::web::resource("/rooms/{room_id}/voice/join")
.route(actix_web::web::post().to(rest_voice::voice_join)),
)
.service(
actix_web::web::resource("/rooms/{room_id}/voice/leave")
.route(actix_web::web::post().to(rest_voice::voice_leave)),
)
.service(
actix_web::web::resource("/rooms/{room_id}/voice/mute")
.route(actix_web::web::post().to(rest_voice::voice_mute)),
)
.service(
actix_web::web::resource("/rooms/{room_id}/voice/deaf")
.route(actix_web::web::post().to(rest_voice::voice_deaf)),
)
.service(
actix_web::web::resource("/rooms/{room_id}/screen-share")
.route(actix_web::web::post().to(rest_voice::screen_share)),
);
cfg.service(
actix_web::web::resource("/rooms/{room_id}/ai/stop")
.route(actix_web::web::post().to(rest_ai::ai_stop)),
)
.service(
actix_web::web::resource("/rooms/{room_id}/ai")
.route(actix_web::web::get().to(rest_ai::ai_list))
.route(actix_web::web::post().to(rest_ai::ai_add)),
)
.service(
actix_web::web::resource("/rooms/{room_id}/ai/{agent_session_id}")
.route(actix_web::web::delete().to(rest_ai::ai_remove)),
)
.service(
actix_web::web::resource("/users/summary/{username}")
.route(actix_web::web::get().to(rest_ai::user_summary)),
);
cfg.service(
actix_web::web::resource("/token")
.route(actix_web::web::post().to(token::generate_token)),
);
cfg.app_data(actix_web::web::Data::new(bus));
}

110
lib/api/src/channel/rest.rs Normal file
View File

@ -0,0 +1,110 @@
use actix_web::{HttpRequest, HttpResponse, web};
use channel::http::{WsHandler, WsInMessage, WsOutEvent};
use channel::{ChannelBus, ChannelError};
use session::SessionExt;
use uuid::Uuid;
use crate::error::ApiError;
pub(crate) fn extract_user(req: &HttpRequest) -> Result<Uuid, ApiError> {
req.get_session()
.user()
.ok_or_else(|| ApiError(service::error::AppError::Unauthorized))
}
pub(crate) fn channel_err(e: ChannelError) -> ApiError {
ApiError(match e {
ChannelError::Unauthorized | ChannelError::TokenInvalidOrExpired => {
service::error::AppError::Unauthorized
}
ChannelError::AccessDenied => {
service::error::AppError::PermissionDenied
}
ChannelError::Validation(msg) => {
service::error::AppError::BadRequest(msg)
}
ChannelError::RateLimitExceeded => {
service::error::AppError::BadRequest("rate limit exceeded".into())
}
ChannelError::RenewalLimitExceeded => {
service::error::AppError::BadRequest(
"renewal limit exceeded".into(),
)
}
ChannelError::RoomNotFound => {
service::error::AppError::NotFound("room not found".into())
}
ChannelError::UserNotFound => {
service::error::AppError::NotFound("user not found".into())
}
ChannelError::Internal(msg) => {
service::error::AppError::InternalServerError(msg)
}
ChannelError::Database(e) => {
service::error::AppError::InternalServerError(e.to_string())
}
ChannelError::Cache(e) => {
service::error::AppError::InternalServerError(e.to_string())
}
ChannelError::SocketIo(e) => {
service::error::AppError::InternalServerError(e.to_string())
}
ChannelError::Serialization(e) => {
service::error::AppError::InternalServerError(e.to_string())
}
ChannelError::Redis(e) => {
service::error::AppError::InternalServerError(e.to_string())
}
ChannelError::Storage(e) => {
service::error::AppError::InternalServerError(e.to_string())
}
})
}
pub(crate) fn ok_json(event: Option<WsOutEvent>) -> HttpResponse {
match event {
Some(e) => HttpResponse::Ok().json(e),
None => HttpResponse::NoContent().finish(),
}
}
pub(crate) fn created_json(event: Option<WsOutEvent>) -> HttpResponse {
match event {
Some(e) => HttpResponse::Created().json(e),
None => HttpResponse::NoContent().finish(),
}
}
#[utoipa::path(
get,
path = "/api/v1/ws/ping",
responses((status = 200, description = "Pong with protocol version")),
tag = "channel",
)]
pub async fn ping(
req: HttpRequest,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let result = WsHandler::handle(&bus, user_id, WsInMessage::Ping)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
get,
path = "/api/v1/ws/csrf",
responses((status = 200, description = "CSRF token")),
tag = "channel",
)]
pub async fn csrf_token(
req: HttpRequest,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let result = WsHandler::handle(&bus, user_id, WsInMessage::CsrfToken)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}

View File

@ -0,0 +1,120 @@
use actix_web::{HttpRequest, HttpResponse, web};
use channel::ChannelBus;
use channel::http::{WsHandler, WsInMessage};
use serde::Deserialize;
use uuid::Uuid;
use super::rest::{channel_err, created_json, extract_user, ok_json};
use crate::error::ApiError;
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct AiAddRequest {
pub agent_session: Uuid,
}
#[utoipa::path(
get,
path = "/api/v1/ws/rooms/{room_id}/ai",
responses((status = 200, description = "AI agents in room")),
tag = "channel",
)]
pub async fn ai_list(
req: HttpRequest,
room_id: web::Path<Uuid>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::AiList {
room: room_id.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
post,
path = "/api/v1/ws/rooms/{room_id}/ai",
request_body = AiAddRequest,
responses((status = 201, description = "AI agent added to room")),
tag = "channel",
)]
pub async fn ai_add(
req: HttpRequest,
room_id: web::Path<Uuid>,
body: web::Json<AiAddRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::AiUpsert {
room: room_id.into_inner(),
model: body.agent_session,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(created_json(result))
}
#[utoipa::path(
delete,
path = "/api/v1/ws/rooms/{room_id}/ai/{agent_session_id}",
responses((status = 200, description = "AI agent removed from room")),
tag = "channel",
)]
pub async fn ai_remove(
req: HttpRequest,
path: web::Path<(Uuid, Uuid)>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let (room, agent_id) = path.into_inner();
let msg = WsInMessage::AiDelete { room, agent_id };
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
post,
path = "/api/v1/ws/rooms/{room_id}/ai/stop",
responses((status = 204, description = "AI agent stopped")),
tag = "channel",
)]
pub async fn ai_stop(
req: HttpRequest,
room_id: web::Path<Uuid>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::AiStop {
room: room_id.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
get,
path = "/api/v1/ws/users/summary/{username}",
responses((status = 200, description = "User summary")),
tag = "channel",
)]
pub async fn user_summary(
req: HttpRequest,
username: web::Path<String>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::UserSummary {
username: username.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}

View File

@ -0,0 +1,274 @@
use actix_web::{HttpRequest, HttpResponse, web};
use channel::ChannelBus;
use channel::http::{WsHandler, WsInMessage};
use serde::Deserialize;
use uuid::Uuid;
use super::rest::{channel_err, created_json, extract_user, ok_json};
use crate::error::ApiError;
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct ReactionRequest {
pub message: Uuid,
pub emoji: String,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct ThreadCreateRequest {
pub parent: i64,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct TypingRequest {
pub action: TypingAction,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub enum TypingAction {
Start,
Stop,
}
#[utoipa::path(
post,
path = "/api/v1/ws/rooms/{room_id}/reactions",
request_body = ReactionRequest,
responses((status = 204, description = "Reaction added")),
tag = "channel",
)]
pub async fn reaction_add(
req: HttpRequest,
room_id: web::Path<Uuid>,
body: web::Json<ReactionRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::ReactionAdd {
room: room_id.into_inner(),
message: body.message,
emoji: body.emoji.clone(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
delete,
path = "/api/v1/ws/rooms/{room_id}/reactions",
request_body = ReactionRequest,
responses((status = 204, description = "Reaction removed")),
tag = "channel",
)]
pub async fn reaction_remove(
req: HttpRequest,
room_id: web::Path<Uuid>,
body: web::Json<ReactionRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::ReactionRemove {
room: room_id.into_inner(),
message: body.message,
emoji: body.emoji.clone(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
post,
path = "/api/v1/ws/rooms/{room_id}/threads",
request_body = ThreadCreateRequest,
responses((status = 201, description = "Thread created")),
tag = "channel",
)]
pub async fn thread_create(
req: HttpRequest,
room_id: web::Path<Uuid>,
body: web::Json<ThreadCreateRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::ThreadCreate {
room: room_id.into_inner(),
parent: body.parent,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(created_json(result))
}
#[utoipa::path(
patch,
path = "/api/v1/ws/threads/{thread_id}/resolve",
responses((status = 200, description = "Thread resolved")),
tag = "channel",
)]
pub async fn thread_resolve(
req: HttpRequest,
thread_id: web::Path<Uuid>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::ThreadResolve {
thread_id: thread_id.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
patch,
path = "/api/v1/ws/threads/{thread_id}/archive",
responses((status = 200, description = "Thread archived")),
tag = "channel",
)]
pub async fn thread_archive(
req: HttpRequest,
thread_id: web::Path<Uuid>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::ThreadArchive {
thread_id: thread_id.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct PinRequest {
pub message: Uuid,
}
#[utoipa::path(
post,
path = "/api/v1/ws/rooms/{room_id}/pins",
request_body = PinRequest,
responses((status = 204, description = "Message pinned")),
tag = "channel",
)]
pub async fn pin_add(
req: HttpRequest,
room_id: web::Path<Uuid>,
body: web::Json<PinRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::PinAdd {
room: room_id.into_inner(),
message: body.message,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
delete,
path = "/api/v1/ws/rooms/{room_id}/pins",
request_body = PinRequest,
responses((status = 204, description = "Pin removed")),
tag = "channel",
)]
pub async fn pin_remove(
req: HttpRequest,
room_id: web::Path<Uuid>,
body: web::Json<PinRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::PinRemove {
room: room_id.into_inner(),
message: body.message,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct DraftSaveRequest {
pub content: String,
}
#[utoipa::path(
put,
path = "/api/v1/ws/rooms/{room_id}/drafts",
request_body = DraftSaveRequest,
responses((status = 204, description = "Draft saved")),
tag = "channel",
)]
pub async fn draft_save(
req: HttpRequest,
room_id: web::Path<Uuid>,
body: web::Json<DraftSaveRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::DraftSave {
room: room_id.into_inner(),
content: body.content.clone(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
delete,
path = "/api/v1/ws/rooms/{room_id}/drafts",
responses((status = 204, description = "Draft cleared")),
tag = "channel",
)]
pub async fn draft_clear(
req: HttpRequest,
room_id: web::Path<Uuid>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::DraftClear {
room: room_id.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
post,
path = "/api/v1/ws/rooms/{room_id}/typing",
request_body = TypingRequest,
responses((status = 204, description = "Typing indicator broadcasted")),
tag = "channel",
)]
pub async fn typing(
req: HttpRequest,
room_id: web::Path<Uuid>,
body: web::Json<TypingRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let room = room_id.into_inner();
let msg = match body.action {
TypingAction::Start => WsInMessage::TypingStart { room },
TypingAction::Stop => WsInMessage::TypingStop { room },
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}

View File

@ -0,0 +1,375 @@
use actix_web::{HttpRequest, HttpResponse, web};
use channel::ChannelBus;
use channel::http::{WsHandler, WsInMessage};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use super::rest::{channel_err, created_json, extract_user, ok_json};
use crate::error::ApiError;
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct ReadReceiptRequest {
pub last_read_seq: i64,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct DndRequest {
pub do_not_disturb: Option<bool>,
pub dnd_start_hour: Option<i16>,
pub dnd_end_hour: Option<i16>,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct NotificationMarkAllReadRequest {
pub workspace_id: Option<Uuid>,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct PresenceUpdateRequest {
#[schema(example = "online")]
pub status: String,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct CustomStatusRequest {
pub emoji: Option<String>,
pub text: Option<String>,
pub expires_at: Option<DateTime<Utc>>,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct InviteCreateRequest {
pub workspace: Uuid,
pub room: Option<Uuid>,
pub max_uses: Option<i32>,
pub expires_at: Option<DateTime<Utc>>,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct InviteAcceptRequest {
pub code: String,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct BanCreateRequest {
pub user: Uuid,
pub reason: Option<String>,
pub expires_at: Option<DateTime<Utc>>,
}
#[utoipa::path(
post,
path = "/api/v1/ws/rooms/{room_id}/read-receipt",
request_body = ReadReceiptRequest,
responses((status = 200, description = "Read receipt saved")),
tag = "channel",
)]
pub async fn read_receipt(
req: HttpRequest,
room_id: web::Path<Uuid>,
body: web::Json<ReadReceiptRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::ReadReceipt {
room: room_id.into_inner(),
last_read_seq: body.last_read_seq,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
patch,
path = "/api/v1/ws/rooms/{room_id}/dnd",
request_body = DndRequest,
responses((status = 204, description = "DND updated")),
tag = "channel",
)]
pub async fn dnd_update(
req: HttpRequest,
room_id: web::Path<Uuid>,
body: web::Json<DndRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::StateUpdateDnd {
room: room_id.into_inner(),
do_not_disturb: body.do_not_disturb,
dnd_start_hour: body.dnd_start_hour,
dnd_end_hour: body.dnd_end_hour,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
patch,
path = "/api/v1/ws/notifications/{id}/read",
responses((status = 204, description = "Notification marked read")),
tag = "channel",
)]
pub async fn notification_mark_read(
req: HttpRequest,
id: web::Path<Uuid>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::NotificationMarkRead {
id: id.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
post,
path = "/api/v1/ws/notifications/read-all",
request_body = NotificationMarkAllReadRequest,
responses((status = 204, description = "All notifications marked read")),
tag = "channel",
)]
pub async fn notification_mark_all_read(
req: HttpRequest,
body: web::Json<NotificationMarkAllReadRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::NotificationMarkAllRead {
workspace_id: body.workspace_id,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
delete,
path = "/api/v1/ws/notifications/{id}",
responses((status = 204, description = "Notification archived")),
tag = "channel",
)]
pub async fn notification_archive(
req: HttpRequest,
id: web::Path<Uuid>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::NotificationArchive {
id: id.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
post,
path = "/api/v1/ws/presence",
request_body = PresenceUpdateRequest,
responses((status = 204, description = "Presence updated")),
tag = "channel",
)]
pub async fn presence_update(
req: HttpRequest,
body: web::Json<PresenceUpdateRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let status: channel::event::presence::UserPresenceStatus =
serde_json::from_value(serde_json::Value::String(body.status.clone()))
.map_err(|e| {
ApiError(service::error::AppError::BadRequest(e.to_string()))
})?;
let msg = WsInMessage::PresenceUpdate { status };
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
post,
path = "/api/v1/ws/custom-status",
request_body = CustomStatusRequest,
responses((status = 204, description = "Custom status updated")),
tag = "channel",
)]
pub async fn custom_status_update(
req: HttpRequest,
body: web::Json<CustomStatusRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::CustomStatusUpdate {
emoji: body.emoji.clone(),
text: body.text.clone(),
expires_at: body.expires_at,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
post,
path = "/api/v1/ws/invites",
request_body = InviteCreateRequest,
responses((status = 201, description = "Invite created")),
tag = "channel",
)]
pub async fn invite_create(
req: HttpRequest,
body: web::Json<InviteCreateRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::InviteCreate {
workspace: body.workspace,
room: body.room,
max_uses: body.max_uses,
expires_at: body.expires_at,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(created_json(result))
}
#[utoipa::path(
post,
path = "/api/v1/ws/invites/accept",
request_body = InviteAcceptRequest,
responses((status = 200, description = "Invite accepted")),
tag = "channel",
)]
pub async fn invite_accept(
req: HttpRequest,
body: web::Json<InviteAcceptRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::InviteAccept {
code: body.code.clone(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
delete,
path = "/api/v1/ws/invites/{id}",
responses((status = 204, description = "Invite revoked")),
tag = "channel",
)]
pub async fn invite_revoke(
req: HttpRequest,
id: web::Path<Uuid>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::InviteRevoke {
id: id.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
post,
path = "/api/v1/ws/workspaces/{workspace_id}/bans",
request_body = BanCreateRequest,
responses((status = 201, description = "User banned")),
tag = "channel",
)]
pub async fn ban_create(
req: HttpRequest,
workspace_id: web::Path<Uuid>,
body: web::Json<BanCreateRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::BanCreate {
workspace: workspace_id.into_inner(),
user: body.user,
reason: body.reason.clone(),
expires_at: body.expires_at,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
delete,
path = "/api/v1/ws/workspaces/{workspace_id}/bans/{user_id}",
responses((status = 204, description = "User unbanned")),
tag = "channel",
)]
pub async fn ban_remove(
req: HttpRequest,
path: web::Path<(Uuid, Uuid)>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let (workspace, target_user) = path.into_inner();
let msg = WsInMessage::BanRemove {
workspace,
user: target_user,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct RoomMember {
pub id: Uuid,
pub username: String,
pub display_name: String,
pub avatar_url: String,
}
#[utoipa::path(
get,
path = "/api/v1/ws/workspaces/{workspace_id}/members",
responses((status = 200, description = "Workspace members list")),
tag = "channel",
)]
pub async fn list_workspace_members(
req: HttpRequest,
workspace_id: web::Path<Uuid>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let _user_id = extract_user(&req)?;
let workspace = workspace_id.into_inner();
let members = bus.list_workspace_members(workspace).await.map_err(channel_err)?;
let result: Vec<RoomMember> = members
.into_iter()
.map(|(id, username, display_name, avatar_url)| RoomMember {
id,
username,
display_name,
avatar_url,
})
.collect();
Ok(HttpResponse::Ok().json(result))
}

View File

@ -0,0 +1,225 @@
use actix_web::{HttpRequest, HttpResponse, web};
use channel::ChannelBus;
use channel::http::{WsHandler, WsInMessage};
use serde::Deserialize;
use uuid::Uuid;
use super::rest::{channel_err, created_json, extract_user, ok_json};
use crate::error::ApiError;
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct CreateMessageRequest {
pub content: String,
pub content_type: Option<String>,
pub thread: Option<Uuid>,
pub in_reply_to: Option<Uuid>,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct UpdateMessageRequest {
pub content: String,
}
#[derive(Debug, Deserialize, utoipa::IntoParams)]
pub struct MessageListParams {
pub before_seq: Option<i64>,
pub after_seq: Option<i64>,
pub limit: Option<u64>,
}
#[derive(Debug, Deserialize, utoipa::IntoParams)]
pub struct MessageAroundParams {
pub seq: i64,
pub limit: Option<u64>,
}
#[derive(Debug, Deserialize, utoipa::IntoParams)]
pub struct MissedMessagesParams {
pub after_seq: i64,
pub limit: Option<i64>,
}
#[derive(Debug, Deserialize, utoipa::IntoParams)]
pub struct SearchParams {
pub q: String,
pub room: Option<Uuid>,
pub limit: Option<u64>,
pub offset: Option<u64>,
}
#[utoipa::path(
post,
path = "/api/v1/ws/rooms/{room_id}/messages",
request_body = CreateMessageRequest,
responses((status = 201, description = "Message created")),
tag = "channel",
)]
pub async fn create_message(
req: HttpRequest,
room_id: web::Path<Uuid>,
body: web::Json<CreateMessageRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::MessageCreate {
room: room_id.into_inner(),
content: body.content.clone(),
content_type: body.content_type.clone(),
thread: body.thread,
in_reply_to: body.in_reply_to,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(created_json(result))
}
#[utoipa::path(
patch,
path = "/api/v1/ws/messages/{message_id}",
request_body = UpdateMessageRequest,
responses((status = 200, description = "Message updated")),
tag = "channel",
)]
pub async fn update_message(
req: HttpRequest,
message_id: web::Path<Uuid>,
body: web::Json<UpdateMessageRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::MessageUpdate {
message: message_id.into_inner(),
content: body.content.clone(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
delete,
path = "/api/v1/ws/messages/{message_id}",
responses((status = 200, description = "Message revoked")),
tag = "channel",
)]
pub async fn revoke_message(
req: HttpRequest,
message_id: web::Path<Uuid>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::MessageRevoke {
message: message_id.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
get,
path = "/api/v1/ws/rooms/{room_id}/messages",
params(MessageListParams),
responses((status = 200, description = "Message list")),
tag = "channel",
)]
pub async fn list_messages(
req: HttpRequest,
room_id: web::Path<Uuid>,
params: web::Query<MessageListParams>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::MessageList {
room: room_id.into_inner(),
before_seq: params.before_seq,
after_seq: params.after_seq,
limit: params.limit,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
get,
path = "/api/v1/ws/rooms/{room_id}/messages/around",
params(MessageAroundParams),
responses((status = 200, description = "Messages around seq")),
tag = "channel",
)]
pub async fn messages_around(
req: HttpRequest,
room_id: web::Path<Uuid>,
params: web::Query<MessageAroundParams>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::MessageAround {
room: room_id.into_inner(),
seq: params.seq,
limit: params.limit,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
get,
path = "/api/v1/ws/rooms/{room_id}/messages/missed",
params(MissedMessagesParams),
responses((status = 200, description = "Missed messages")),
tag = "channel",
)]
pub async fn missed_messages(
req: HttpRequest,
room_id: web::Path<Uuid>,
params: web::Query<MissedMessagesParams>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::MissedMessages {
room: room_id.into_inner(),
after_seq: params.after_seq,
limit: params.limit,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
get,
path = "/api/v1/ws/search",
params(SearchParams),
responses((status = 200, description = "Search results")),
tag = "channel",
)]
pub async fn search(
req: HttpRequest,
params: web::Query<SearchParams>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::Search {
q: params.q.clone(),
room: params.room,
start_time: None,
end_time: None,
sender_id: None,
content_type: None,
limit: params.limit,
offset: params.offset,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}

View File

@ -0,0 +1,321 @@
use actix_web::{HttpRequest, HttpResponse, web};
use channel::ChannelBus;
use channel::http::{WsHandler, WsInMessage};
use serde::Deserialize;
use uuid::Uuid;
use super::rest::{channel_err, created_json, extract_user, ok_json};
use crate::error::ApiError;
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct RoomCreateRequest {
pub workspace: Uuid,
pub room_name: String,
pub public: bool,
pub category: Option<Uuid>,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct RoomUpdateRequest {
pub room_name: Option<String>,
pub public: Option<bool>,
pub category: Option<Uuid>,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct AccessRequest {
pub user: Uuid,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct CategoryCreateRequest {
pub name: String,
pub position: Option<i32>,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct CategoryUpdateRequest {
pub name: Option<String>,
pub position: Option<i32>,
}
#[utoipa::path(
get,
path = "/api/v1/ws/rooms",
responses((status = 200, description = "List of rooms")),
tag = "channel",
)]
pub async fn list_rooms(
req: HttpRequest,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let rooms = bus.list_user_rooms(user_id)
.await
.map_err(channel_err)?;
let categories = bus.list_user_categories(user_id)
.await
.map_err(channel_err)?;
let workspace_id = if let Some(r) = rooms.first() {
Some(r.workspace_id)
} else {
bus.first_workspace_id(user_id).await.unwrap_or(None)
};
Ok(HttpResponse::Ok().json(serde_json::json!({
"rooms": rooms,
"categories": categories,
"workspace_id": workspace_id,
})))
}
#[utoipa::path(
post,
path = "/api/v1/ws/rooms/{room_id}/subscribe",
responses((status = 204, description = "Subscribed, user room cache refreshed")),
tag = "channel",
)]
pub async fn subscribe(
req: HttpRequest,
room_id: web::Path<Uuid>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::Subscribe {
room: room_id.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
delete,
path = "/api/v1/ws/rooms/{room_id}/subscribe",
responses((status = 204, description = "Unsubscribed")),
tag = "channel",
)]
pub async fn unsubscribe(
req: HttpRequest,
room_id: web::Path<Uuid>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::Unsubscribe {
room: room_id.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
get,
path = "/api/v1/ws/rooms/{room_id}",
responses((status = 200, description = "Room info")),
tag = "channel",
)]
pub async fn room_get(
req: HttpRequest,
room_id: web::Path<Uuid>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::RoomGet {
room: room_id.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
post,
path = "/api/v1/ws/rooms",
request_body = RoomCreateRequest,
responses((status = 201, description = "Room created")),
tag = "channel",
)]
pub async fn room_create(
req: HttpRequest,
body: web::Json<RoomCreateRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::RoomCreate {
workspace: body.workspace,
room_name: body.room_name.clone(),
public: body.public,
category: body.category,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(created_json(result))
}
#[utoipa::path(
patch,
path = "/api/v1/ws/rooms/{room_id}",
request_body = RoomUpdateRequest,
responses((status = 200, description = "Room updated")),
tag = "channel",
)]
pub async fn room_update(
req: HttpRequest,
room_id: web::Path<Uuid>,
body: web::Json<RoomUpdateRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::RoomUpdate {
room: room_id.into_inner(),
room_name: body.room_name.clone(),
public: body.public,
category: body.category,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
delete,
path = "/api/v1/ws/rooms/{room_id}",
responses((status = 204, description = "Room deleted")),
tag = "channel",
)]
pub async fn room_delete(
req: HttpRequest,
room_id: web::Path<Uuid>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::RoomDelete {
room: room_id.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
post,
path = "/api/v1/ws/rooms/{room_id}/members",
request_body = AccessRequest,
responses((status = 204, description = "Access granted")),
tag = "channel",
)]
pub async fn access_grant(
req: HttpRequest,
room_id: web::Path<Uuid>,
body: web::Json<AccessRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::AccessGrant {
room: room_id.into_inner(),
user: body.user,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
delete,
path = "/api/v1/ws/rooms/{room_id}/members/{user_id}",
responses((status = 204, description = "Access revoked")),
tag = "channel",
)]
pub async fn access_revoke(
req: HttpRequest,
path: web::Path<(Uuid, Uuid)>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let (room, target_user) = path.into_inner();
let msg = WsInMessage::AccessRevoke {
room,
user: target_user,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
post,
path = "/api/v1/ws/workspaces/{workspace_id}/categories",
request_body = CategoryCreateRequest,
responses((status = 201, description = "Category created")),
tag = "channel",
)]
pub async fn category_create(
req: HttpRequest,
workspace_id: web::Path<Uuid>,
body: web::Json<CategoryCreateRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::CategoryCreate {
workspace: workspace_id.into_inner(),
name: body.name.clone(),
position: body.position,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(created_json(result))
}
#[utoipa::path(
patch,
path = "/api/v1/ws/categories/{category_id}",
request_body = CategoryUpdateRequest,
responses((status = 200, description = "Category updated")),
tag = "channel",
)]
pub async fn category_update(
req: HttpRequest,
category_id: web::Path<Uuid>,
body: web::Json<CategoryUpdateRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::CategoryUpdate {
id: category_id.into_inner(),
name: body.name.clone(),
position: body.position,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
delete,
path = "/api/v1/ws/categories/{category_id}",
responses((status = 204, description = "Category deleted")),
tag = "channel",
)]
pub async fn category_delete(
req: HttpRequest,
category_id: web::Path<Uuid>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::CategoryDelete {
id: category_id.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}

View File

@ -0,0 +1,137 @@
use actix_web::{HttpRequest, HttpResponse, web};
use channel::ChannelBus;
use channel::http::{WsHandler, WsInMessage};
use serde::Deserialize;
use uuid::Uuid;
use super::rest::{channel_err, extract_user, ok_json};
use crate::error::ApiError;
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct VoiceMuteRequest {
pub muted: bool,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct VoiceDeafRequest {
pub deafened: bool,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct ScreenShareRequest {
pub start: bool,
}
#[utoipa::path(
post,
path = "/api/v1/ws/rooms/{room_id}/voice/join",
responses((status = 204, description = "Joined voice channel")),
tag = "channel",
)]
pub async fn voice_join(
req: HttpRequest,
room_id: web::Path<Uuid>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::VoiceJoin {
room: room_id.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
post,
path = "/api/v1/ws/rooms/{room_id}/voice/leave",
responses((status = 204, description = "Left voice channel")),
tag = "channel",
)]
pub async fn voice_leave(
req: HttpRequest,
room_id: web::Path<Uuid>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::VoiceLeave {
room: room_id.into_inner(),
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
post,
path = "/api/v1/ws/rooms/{room_id}/voice/mute",
request_body = VoiceMuteRequest,
responses((status = 204, description = "Mute toggled")),
tag = "channel",
)]
pub async fn voice_mute(
req: HttpRequest,
room_id: web::Path<Uuid>,
body: web::Json<VoiceMuteRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::VoiceMute {
room: room_id.into_inner(),
muted: body.muted,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
post,
path = "/api/v1/ws/rooms/{room_id}/voice/deaf",
request_body = VoiceDeafRequest,
responses((status = 204, description = "Deaf toggled")),
tag = "channel",
)]
pub async fn voice_deaf(
req: HttpRequest,
room_id: web::Path<Uuid>,
body: web::Json<VoiceDeafRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::VoiceDeaf {
room: room_id.into_inner(),
deafened: body.deafened,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}
#[utoipa::path(
post,
path = "/api/v1/ws/rooms/{room_id}/screen-share",
request_body = ScreenShareRequest,
responses((status = 204, description = "Screen share toggled")),
tag = "channel",
)]
pub async fn screen_share(
req: HttpRequest,
room_id: web::Path<Uuid>,
body: web::Json<ScreenShareRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = extract_user(&req)?;
let msg = WsInMessage::ScreenShare {
room: room_id.into_inner(),
start: body.start,
};
let result = WsHandler::handle(&bus, user_id, msg)
.await
.map_err(channel_err)?;
Ok(ok_json(result))
}

View File

@ -0,0 +1,53 @@
use actix_web::{HttpRequest, HttpResponse, web};
use channel::{ChannelBus, ChannelTokenApply, TOKEN_TTL_SECS};
use serde::Deserialize;
use session::SessionExt;
use crate::error::ApiError;
use super::rest::channel_err;
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct TokenRequest {
pub device_id: String,
pub client_id: String,
}
#[derive(Debug, serde::Serialize, utoipa::ToSchema)]
pub struct TokenResponse {
pub access_token: String,
pub expires_in_secs: u64,
}
#[utoipa::path(
post,
path = "/api/v1/ws/token",
request_body = TokenRequest,
responses((status = 200, body = TokenResponse)),
tag = "channel"
)]
pub async fn generate_token(
req: HttpRequest,
body: web::Json<TokenRequest>,
bus: web::Data<ChannelBus>,
) -> Result<HttpResponse, ApiError> {
let user_id = req
.get_session()
.user()
.ok_or_else(|| ApiError(service::error::AppError::Unauthorized))?;
let apply = ChannelTokenApply {
device_id: body.device_id.clone(),
client_id: body.client_id.clone(),
};
let token = bus
.apply_access_token(user_id, apply)
.await
.map_err(channel_err)?;
Ok(HttpResponse::Ok().json(TokenResponse {
access_token: token.access_token,
expires_in_secs: TOKEN_TTL_SECS,
}))
}

91
lib/api/src/error.rs Normal file
View File

@ -0,0 +1,91 @@
use actix_web::{HttpResponse, error::ResponseError, http::StatusCode};
use serde::Serialize;
use service::error::AppError;
pub fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
Ok(HttpResponse::Ok().json(data))
}
pub struct ApiError(pub AppError);
impl From<AppError> for ApiError {
fn from(err: AppError) -> Self {
ApiError(err)
}
}
impl std::fmt::Display for ApiError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl std::fmt::Debug for ApiError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl ResponseError for ApiError {
fn status_code(&self) -> StatusCode {
match &self.0 {
AppError::Unauthorized => StatusCode::UNAUTHORIZED,
AppError::UserNotFound => StatusCode::NOT_FOUND,
AppError::InvalidPassword => StatusCode::UNAUTHORIZED,
AppError::PasswordTooWeak => StatusCode::BAD_REQUEST,
AppError::CaptchaError => StatusCode::BAD_REQUEST,
AppError::TwoFactorRequired => {
StatusCode::from_u16(402).unwrap_or(StatusCode::BAD_REQUEST)
}
AppError::TwoFactorAlreadyEnabled => StatusCode::CONFLICT,
AppError::TwoFactorNotSetup => StatusCode::NOT_FOUND,
AppError::InvalidTwoFactorCode => StatusCode::BAD_REQUEST,
AppError::TwoFactorNotEnabled => StatusCode::NOT_FOUND,
AppError::RsaGenerationError => StatusCode::INTERNAL_SERVER_ERROR,
AppError::RsaDecodeError => StatusCode::BAD_REQUEST,
AppError::UserNameExists => StatusCode::CONFLICT,
AppError::EmailExists => StatusCode::CONFLICT,
AppError::AccountAlreadyExists => StatusCode::CONFLICT,
AppError::TxnError => StatusCode::INTERNAL_SERVER_ERROR,
AppError::PasswordHashError(_) => StatusCode::INTERNAL_SERVER_ERROR,
AppError::DatabaseError(_) => StatusCode::INTERNAL_SERVER_ERROR,
AppError::DoMainNotSet => StatusCode::INTERNAL_SERVER_ERROR,
AppError::InternalError => StatusCode::INTERNAL_SERVER_ERROR,
AppError::InternalServerError(_) => {
StatusCode::INTERNAL_SERVER_ERROR
}
AppError::PermissionDenied => StatusCode::FORBIDDEN,
AppError::ProjectNotFound => StatusCode::NOT_FOUND,
AppError::NoPower => StatusCode::FORBIDDEN,
AppError::RoleParseError => StatusCode::BAD_REQUEST,
AppError::ProjectNameAlreadyExists => StatusCode::CONFLICT,
AppError::RepoNameAlreadyExists => StatusCode::CONFLICT,
AppError::AvatarUploadError(_) => StatusCode::BAD_REQUEST,
AppError::RepoNotFound => StatusCode::NOT_FOUND,
AppError::RepoForBidAccess => StatusCode::FORBIDDEN,
AppError::SerdeError(_) => StatusCode::BAD_REQUEST,
AppError::Io(_) => StatusCode::INTERNAL_SERVER_ERROR,
AppError::BadRequest(_) => StatusCode::BAD_REQUEST,
AppError::Forbidden(_) => StatusCode::FORBIDDEN,
AppError::Conflict(_) => StatusCode::CONFLICT,
AppError::NotFound(_) => StatusCode::NOT_FOUND,
AppError::InvalidResetToken => StatusCode::BAD_REQUEST,
AppError::ResetTokenExpired => StatusCode::BAD_REQUEST,
AppError::ResetTokenUsed => StatusCode::BAD_REQUEST,
AppError::IssueNotFound => StatusCode::NOT_FOUND,
AppError::LabelNotFound => StatusCode::NOT_FOUND,
AppError::MilestoneNotFound => StatusCode::NOT_FOUND,
AppError::PullRequestNotFound => StatusCode::NOT_FOUND,
AppError::CommentNotFound => StatusCode::NOT_FOUND,
AppError::GitRpcError(_) => StatusCode::INTERNAL_SERVER_ERROR,
AppError::AiError(_) => StatusCode::INTERNAL_SERVER_ERROR,
}
}
fn error_response(&self) -> HttpResponse {
let status = self.status_code();
let message = self.0.to_string();
HttpResponse::build(status)
.json(serde_json::json!({ "error": message }))
}
}

View File

@ -0,0 +1,47 @@
use actix_web::{HttpResponse, web};
use serde::{Deserialize, Serialize};
use service::AppService;
use session::Session;
use crate::error::ApiError;
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
Ok(HttpResponse::Ok().json(data))
}
#[derive(Deserialize, utoipa::IntoParams)]
pub struct WkRepoPath {
pub wk: String,
pub repo: String,
}
#[derive(Deserialize, utoipa::IntoParams)]
pub struct ArchiveQuery {
#[serde(default = "default_format")]
pub format: String,
pub tree: Option<String>,
pub prefix: Option<String>,
pub pathspec: Option<Vec<String>>,
}
fn default_format() -> String {
"tar".to_string()
}
#[utoipa::path(
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/archive",
params(WkRepoPath, ArchiveQuery),
responses((status = 200, description = "Archive download")),
security(("session" = []))
)]
pub async fn archive(
session: Session,
service: web::Data<AppService>,
path: web::Path<WkRepoPath>,
query: web::Query<ArchiveQuery>,
) -> Result<HttpResponse, ApiError> {
let WkRepoPath { wk, repo } = path.into_inner();
match query.format.as_str() {
"zip" => ok_json(service.git_archive_zip(&session, &wk, &repo, None).await?),
_ => ok_json(service.git_archive_tar(&session, &wk, &repo, None).await?),
}
}

62
lib/api/src/git/blame.rs Normal file
View File

@ -0,0 +1,62 @@
use actix_web::{HttpResponse, web};
use serde::{Deserialize, Serialize};
use service::AppService;
use session::Session;
use crate::error::ApiError;
use crate::git::dto;
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
Ok(HttpResponse::Ok().json(data))
}
#[derive(Deserialize, utoipa::IntoParams)]
pub struct WkRepoPath {
pub wk: String,
pub repo: String,
}
#[derive(Deserialize, utoipa::IntoParams)]
pub struct BlameQuery {
pub path: String,
pub rev: Option<String>,
pub start_line: Option<u32>,
pub end_line: Option<u32>,
}
#[utoipa::path(
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/blame",
params(WkRepoPath, BlameQuery),
responses((status = 200, description = "Blame result", body = dto::BlameFileResponseDto)),
security(("session" = []))
)]
pub async fn blame_file(
session: Session,
service: web::Data<AppService>,
path: web::Path<WkRepoPath>,
query: web::Query<BlameQuery>,
) -> Result<HttpResponse, ApiError> {
let WkRepoPath { wk, repo } = path.into_inner();
match (query.start_line, query.end_line) {
(Some(start), Some(end)) => {
let data: dto::BlameFileResponseDto = service
.git_blame_hunk(
&session, &wk, &repo, query.path.clone(),
query.rev.clone(), start, end,
)
.await?
.into();
ok_json(data)
}
_ => {
let data: dto::BlameFileResponseDto = service
.git_blame_file(
&session, &wk, &repo, query.path.clone(),
query.rev.clone(), None,
)
.await?
.into();
ok_json(data)
}
}
}

97
lib/api/src/git/blob.rs Normal file
View File

@ -0,0 +1,97 @@
use actix_web::{HttpResponse, web};
use serde::{Deserialize, Serialize};
use service::AppService;
use session::Session;
use crate::error::ApiError;
use crate::git::dto;
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
Ok(HttpResponse::Ok().json(data))
}
#[derive(Deserialize, utoipa::IntoParams)]
pub struct WkRepoPath {
pub wk: String,
pub repo: String,
}
#[derive(Deserialize, utoipa::IntoParams)]
pub struct WkRepoBlobPath {
pub wk: String,
pub repo: String,
pub oid: String,
}
#[derive(Deserialize, utoipa::IntoParams)]
pub struct BlobPathQuery {
pub path: Option<String>,
}
#[derive(Serialize, utoipa::ToSchema)]
pub struct BlobInfoResponse {
#[serde(flatten)]
pub load: dto::BlobLoadResponseDto,
pub size: u64,
pub is_binary: bool,
}
#[utoipa::path(
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/blobs/{oid}",
params(WkRepoBlobPath, BlobPathQuery),
responses((status = 200, description = "Blob info", body = BlobInfoResponse)),
security(("session" = []))
)]
pub async fn blob_info(
session: Session,
service: web::Data<AppService>,
path: web::Path<WkRepoBlobPath>,
query: web::Query<BlobPathQuery>,
) -> Result<HttpResponse, ApiError> {
let WkRepoBlobPath { wk, repo, oid } = path.into_inner();
let path_opt = query.path.clone().unwrap_or_default();
let load: dto::BlobLoadResponseDto = service
.git_blob_load(&session, &wk, &repo, oid.clone(), path_opt.clone())
.await?
.into();
let size_resp: dto::BlobSizeResponseDto = service
.git_blob_size(&session, &wk, &repo, oid.clone(), path_opt)
.await?
.into();
let binary_resp: dto::BlobIsBinaryResponseDto = service
.git_blob_is_binary(&session, &wk, &repo, oid)
.await?
.into();
ok_json(BlobInfoResponse {
load,
size: size_resp.size,
is_binary: binary_resp.is_binary,
})
}
#[derive(Deserialize, utoipa::ToSchema)]
pub struct BlobUploadBody {
pub path: String,
pub blob: Vec<u8>,
}
#[utoipa::path(
post, path = "/api/v1/workspace/{wk}/repos/{repo}/git/blobs",
params(WkRepoPath),
request_body = BlobUploadBody,
responses((status = 200, description = "Upload result", body = dto::BlobUploadResponseDto)),
security(("session" = []))
)]
pub async fn blob_upload(
session: Session,
service: web::Data<AppService>,
path: web::Path<WkRepoPath>,
params: web::Json<BlobUploadBody>,
) -> Result<HttpResponse, ApiError> {
let WkRepoPath { wk, repo } = path.into_inner();
let p = params.into_inner();
let data: dto::BlobUploadResponseDto = service
.git_blob_upload(&session, &wk, &repo, p.path, p.blob)
.await?
.into();
ok_json(data)
}

205
lib/api/src/git/branch.rs Normal file
View File

@ -0,0 +1,205 @@
use actix_web::{HttpResponse, web};
use serde::{Deserialize, Serialize};
use service::{AppService, Pagination};
use session::Session;
use crate::error::ApiError;
use crate::git::dto;
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
Ok(HttpResponse::Ok().json(data))
}
#[derive(Deserialize, utoipa::IntoParams)]
pub struct WkRepoPath {
pub wk: String,
pub repo: String,
}
#[derive(Deserialize, utoipa::IntoParams)]
pub struct WkRepoBranchPath {
pub wk: String,
pub repo: String,
pub name: String,
}
#[derive(Deserialize, utoipa::IntoParams)]
pub struct BranchDeleteQuery {
#[serde(default)]
pub force: bool,
}
#[derive(Deserialize, utoipa::IntoParams)]
pub struct BranchListQuery {
#[serde(default)]
pub summary: bool,
#[serde(default)]
pub default_only: bool,
}
#[derive(Deserialize, utoipa::ToSchema)]
pub struct RenameBranchBody {
pub new_branch: String,
#[serde(default)]
pub force: bool,
}
#[utoipa::path(
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches",
params(WkRepoPath, Pagination, BranchListQuery),
responses((status = 200, description = "Branch list or summary", body = dto::BranchListResponseDto)),
security(("session" = []))
)]
pub async fn list_branches(
session: Session,
service: web::Data<AppService>,
path: web::Path<WkRepoPath>,
pagination: web::Query<Pagination>,
query: web::Query<BranchListQuery>,
) -> Result<HttpResponse, ApiError> {
let WkRepoPath { wk, repo } = path.into_inner();
if query.summary {
let data: dto::BranchSummaryResponseDto = service
.git_branch_summary(&session, &wk, &repo)
.await?
.into();
return ok_json(data);
}
if query.default_only {
let data: dto::BranchHeadResponseDto = service
.git_branch_head(&session, &wk, &repo)
.await?
.into();
return ok_json(data);
}
let data: dto::BranchListResponseDto = service
.git_branch_list(&session, &wk, &repo, pagination.into_inner())
.await?
.into();
ok_json(data)
}
#[utoipa::path(
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches/{name}",
params(WkRepoBranchPath),
responses((status = 200, description = "Branch info", body = dto::BranchInfoResponseDto)),
security(("session" = []))
)]
pub async fn branch_info(
session: Session,
service: web::Data<AppService>,
path: web::Path<WkRepoBranchPath>,
) -> Result<HttpResponse, ApiError> {
let WkRepoBranchPath { wk, repo, name } = path.into_inner();
let data: dto::BranchInfoResponseDto = service
.git_branch_info(&session, &wk, &repo, name)
.await?
.into();
ok_json(data)
}
#[utoipa::path(
post, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches",
params(WkRepoPath),
request_body = Object, description = "BranchForkParams { name, oid, force }",
responses((status = 200, description = "Branch created")),
security(("session" = []))
)]
pub async fn fork_branch(
session: Session,
service: web::Data<AppService>,
path: web::Path<WkRepoPath>,
params: web::Json<git::rpc::proto::BranchForkParams>,
) -> Result<HttpResponse, ApiError> {
let WkRepoPath { wk, repo } = path.into_inner();
let data = service
.git_branch_fork(&session, &wk, &repo, params.into_inner())
.await?;
ok_json(data)
}
#[utoipa::path(
patch, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches/{name}",
params(WkRepoBranchPath),
request_body = RenameBranchBody,
responses((status = 200, description = "Branch renamed")),
security(("session" = []))
)]
pub async fn rename_branch(
session: Session,
service: web::Data<AppService>,
path: web::Path<WkRepoBranchPath>,
body: web::Json<RenameBranchBody>,
) -> Result<HttpResponse, ApiError> {
let WkRepoBranchPath { wk, repo, name } = path.into_inner();
let body = body.into_inner();
let params = git::rpc::proto::BranchReNameParams {
old_branch: name,
new_branch: body.new_branch,
force: body.force,
};
let data = service
.git_branch_rename(&session, &wk, &repo, params)
.await?;
ok_json(data)
}
#[utoipa::path(
delete, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches/{name}",
params(WkRepoBranchPath, BranchDeleteQuery),
responses((status = 200, description = "Branch deleted")),
security(("session" = []))
)]
pub async fn delete_branch(
session: Session,
service: web::Data<AppService>,
path: web::Path<WkRepoBranchPath>,
query: web::Query<BranchDeleteQuery>,
) -> Result<HttpResponse, ApiError> {
let WkRepoBranchPath { wk, repo, name } = path.into_inner();
let params = git::rpc::proto::BranchDeleteParams {
name,
force: query.force,
};
let data = service
.git_branch_delete(&session, &wk, &repo, params)
.await?;
ok_json(data)
}
#[derive(Deserialize, utoipa::IntoParams)]
pub struct AheadBehindQuery {
pub remote_branch: String,
}
#[utoipa::path(
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches/{name}/ahead-behind",
params(WkRepoBranchPath, AheadBehindQuery),
responses((status = 200, description = "Ahead/behind counts", body = dto::BranchAheadBehindResponseDto)),
security(("session" = []))
)]
pub async fn ahead_behind(
session: Session,
service: web::Data<AppService>,
path: web::Path<WkRepoBranchPath>,
query: web::Query<AheadBehindQuery>,
) -> Result<HttpResponse, ApiError> {
let WkRepoBranchPath { wk, repo, name } = path.into_inner();
let data: dto::BranchAheadBehindResponseDto = service
.git_branch_ahead_behind(&session, &wk, &repo, name, query.remote_branch.clone())
.await?
.into();
ok_json(data)
}
#[utoipa::path(
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches/{name}/upstream",
params(WkRepoBranchPath),
responses((status = 200, description = "Upstream branch", body = dto::BranchUpstreamResponseDto)),
security(("session" = []))
)]
pub async fn branch_upstream(
session: Session,
service: web::Data<AppService>,
path: web::Path<WkRepoBranchPath>,
) -> Result<HttpResponse, ApiError> {
let WkRepoBranchPath { wk, repo, name } = path.into_inner();
let data: dto::BranchUpstreamResponseDto = service
.git_branch_upstream(&session, &wk, &repo, name)
.await?
.into();
ok_json(data)
}

169
lib/api/src/git/commit.rs Normal file
View File

@ -0,0 +1,169 @@
use actix_web::{HttpResponse, web};
use git::rpc::proto as p;
use serde::{Deserialize, Serialize};
use service::AppService;
use session::Session;
use crate::error::ApiError;
use crate::git::dto;
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
Ok(HttpResponse::Ok().json(data))
}
#[derive(Deserialize, utoipa::IntoParams)]
pub struct WkRepoPath {
pub wk: String,
pub repo: String,
}
#[derive(Deserialize, utoipa::IntoParams)]
pub struct WkRepoCommitPath {
pub wk: String,
pub repo: String,
pub oid: String,
}
#[derive(Deserialize, utoipa::IntoParams)]
pub struct HistoryQuery {
pub limit: Option<u64>,
pub skip: Option<u64>,
pub sort: Option<i32>,
pub branch: Option<String>,
}
#[derive(Deserialize, utoipa::IntoParams)]
pub struct CommitListQuery {
#[serde(default)]
pub summary: bool,
#[serde(default)]
pub refs: bool,
pub prefix: Option<String>,
}
#[utoipa::path(
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/commits",
params(WkRepoPath, CommitListQuery),
responses(
(status = 200, description = "Commit list / summary / refs / prefix", body = dto::CommitHistoryResponseDto),
),
security(("session" = []))
)]
pub async fn list_commits(
session: Session,
service: web::Data<AppService>,
path: web::Path<WkRepoPath>,
query: web::Query<CommitListQuery>,
) -> Result<HttpResponse, ApiError> {
let WkRepoPath { wk, repo } = path.into_inner();
if let Some(prefix) = &query.prefix {
let data: dto::CommitPrefixResponseDto = service
.git_commit_prefix(&session, &wk, &repo, prefix.clone())
.await?
.into();
return ok_json(data);
}
if query.refs {
let data: dto::CommitRefsResponseDto = service
.git_commit_refs(&session, &wk, &repo)
.await?
.into();
return ok_json(data);
}
if query.summary {
let data: dto::CommitSummaryResponseDto = service
.git_commit_summary(&session, &wk, &repo)
.await?
.into();
return ok_json(data);
}
let data: dto::CommitHistoryResponseDto = service
.git_commit_history(&session, &wk, &repo, 20, 0, 0, None)
.await?
.into();
ok_json(data)
}
#[utoipa::path(
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/commits/history",
params(WkRepoPath, HistoryQuery),
responses((status = 200, description = "Commit history", body = dto::CommitHistoryResponseDto)),
security(("session" = []))
)]
pub async fn commit_history(
session: Session,
service: web::Data<AppService>,
path: web::Path<WkRepoPath>,
query: web::Query<HistoryQuery>,
) -> Result<HttpResponse, ApiError> {
let WkRepoPath { wk, repo } = path.into_inner();
let data: dto::CommitHistoryResponseDto = service
.git_commit_history(
&session, &wk, &repo,
query.limit.unwrap_or(20),
query.skip.unwrap_or(0),
query.sort.unwrap_or(0),
query.branch.clone(),
)
.await?
.into();
ok_json(data)
}
#[utoipa::path(
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/commits/{oid}",
params(WkRepoCommitPath),
responses((status = 200, description = "Commit info", body = dto::CommitInfoResponseDto)),
security(("session" = []))
)]
pub async fn commit_info(
session: Session,
service: web::Data<AppService>,
path: web::Path<WkRepoCommitPath>,
) -> Result<HttpResponse, ApiError> {
let WkRepoCommitPath { wk, repo, oid } = path.into_inner();
let data: dto::CommitInfoResponseDto = service
.git_commit_info(&session, &wk, &repo, oid)
.await?
.into();
ok_json(data)
}
#[utoipa::path(
post, path = "/api/v1/workspace/{wk}/repos/{repo}/git/commits/walk",
params(WkRepoPath),
request_body = Object, description = "CommitWalkParams",
responses((status = 200, description = "Walk result", body = dto::CommitHistoryResponseDto)),
security(("session" = []))
)]
pub async fn commit_walk(
session: Session,
service: web::Data<AppService>,
path: web::Path<WkRepoPath>,
params: web::Json<p::CommitWalkParams>,
) -> Result<HttpResponse, ApiError> {
let WkRepoPath { wk, repo } = path.into_inner();
let proto_resp = service
.git_commit_walk(&session, &wk, &repo, params.into_inner())
.await?;
ok_json(dto::CommitHistoryResponseDto {
commits: proto_resp.commits.into_iter().map(Into::into).collect(),
})
}
#[utoipa::path(
post, path = "/api/v1/workspace/{wk}/repos/{repo}/git/commits/cherry-pick",
params(WkRepoPath),
request_body = Object, description = "CommitCherryPickParams",
responses((status = 200, description = "Cherry-pick result", body = dto::CherryPickResponseDto)),
security(("session" = []))
)]
pub async fn cherry_pick(
session: Session,
service: web::Data<AppService>,
path: web::Path<WkRepoPath>,
params: web::Json<p::CommitCherryPickParams>,
) -> Result<HttpResponse, ApiError> {
let WkRepoPath { wk, repo } = path.into_inner();
let data: dto::CherryPickResponseDto = service
.git_cherry_pick(&session, &wk, &repo, params.into_inner())
.await?
.into();
ok_json(data)
}

Some files were not shown because too many files have changed in this diff Show More